Merge branch 'main' into chroma

Bwook (Byoungwook) Kim 2025-10-12 21:38:38 +09:00 committed by kimbwook
commit f856e53323
1881 changed files with 886579 additions and 84028 deletions

View file

@ -21,7 +21,9 @@ async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: dict[Ap
deps[Api.safety],
deps[Api.tool_runtime],
deps[Api.tool_groups],
deps[Api.conversations],
policy,
Api.telemetry in deps,
)
await impl.initialize()
return impl

View file

@ -7,8 +7,6 @@
import copy
import json
import re
import secrets
import string
import uuid
import warnings
from collections.abc import AsyncGenerator
@ -50,11 +48,17 @@ from llama_stack.apis.inference import (
CompletionMessage,
Inference,
Message,
OpenAIAssistantMessageParam,
OpenAIChatCompletionRequestWithExtraBody,
OpenAIDeveloperMessageParam,
OpenAIMessageParam,
OpenAISystemMessageParam,
OpenAIToolMessageParam,
OpenAIUserMessageParam,
SamplingParams,
StopReason,
SystemMessage,
ToolDefinition,
ToolParamDefinition,
ToolResponse,
ToolResponseMessage,
UserMessage,
@ -68,17 +72,17 @@ from llama_stack.models.llama.datatypes import (
BuiltinTool,
ToolCall,
)
from llama_stack.providers.utils.inference.openai_compat import (
convert_message_to_openai_dict_new,
convert_openai_chat_completion_stream,
convert_tooldef_to_openai_tool,
)
from llama_stack.providers.utils.kvstore import KVStore
from llama_stack.providers.utils.telemetry import tracing
from .persistence import AgentPersistence
from .safety import SafetyException, ShieldRunnerMixin
def make_random_string(length: int = 8):
return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length))
TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
MEMORY_QUERY_TOOL = "knowledge_search"
WEB_SEARCH_TOOL = "web_search"
@ -100,6 +104,7 @@ class ChatAgent(ShieldRunnerMixin):
persistence_store: KVStore,
created_at: str,
policy: list[AccessRule],
telemetry_enabled: bool = False,
):
self.agent_id = agent_id
self.agent_config = agent_config
@ -110,6 +115,7 @@ class ChatAgent(ShieldRunnerMixin):
self.tool_runtime_api = tool_runtime_api
self.tool_groups_api = tool_groups_api
self.created_at = created_at
self.telemetry_enabled = telemetry_enabled
ShieldRunnerMixin.__init__(
self,
@ -177,29 +183,31 @@ class ChatAgent(ShieldRunnerMixin):
return messages
async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
span = tracing.get_current_span()
if span:
span.set_attribute("session_id", request.session_id)
span.set_attribute("agent_id", self.agent_id)
span.set_attribute("request", request.model_dump_json())
turn_id = str(uuid.uuid4())
span.set_attribute("turn_id", turn_id)
if self.agent_config.name:
span.set_attribute("agent_name", self.agent_config.name)
turn_id = str(uuid.uuid4())
if self.telemetry_enabled:
span = tracing.get_current_span()
if span is not None:
span.set_attribute("session_id", request.session_id)
span.set_attribute("agent_id", self.agent_id)
span.set_attribute("request", request.model_dump_json())
span.set_attribute("turn_id", turn_id)
if self.agent_config.name:
span.set_attribute("agent_name", self.agent_config.name)
await self._initialize_tools(request.toolgroups)
async for chunk in self._run_turn(request, turn_id):
yield chunk
async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator:
span = tracing.get_current_span()
if span:
span.set_attribute("agent_id", self.agent_id)
span.set_attribute("session_id", request.session_id)
span.set_attribute("request", request.model_dump_json())
span.set_attribute("turn_id", request.turn_id)
if self.agent_config.name:
span.set_attribute("agent_name", self.agent_config.name)
if self.telemetry_enabled:
span = tracing.get_current_span()
if span is not None:
span.set_attribute("agent_id", self.agent_id)
span.set_attribute("session_id", request.session_id)
span.set_attribute("request", request.model_dump_json())
span.set_attribute("turn_id", request.turn_id)
if self.agent_config.name:
span.set_attribute("agent_name", self.agent_config.name)
await self._initialize_tools()
async for chunk in self._run_turn(request):
@ -385,9 +393,12 @@ class ChatAgent(ShieldRunnerMixin):
touchpoint: str,
) -> AsyncGenerator:
async with tracing.span("run_shields") as span:
span.set_attribute("input", [m.model_dump_json() for m in messages])
if self.telemetry_enabled and span is not None:
span.set_attribute("input", [m.model_dump_json() for m in messages])
if len(shields) == 0:
span.set_attribute("output", "no shields")
if len(shields) == 0:
span.set_attribute("output", "no shields")
return
step_id = str(uuid.uuid4())
@ -420,7 +431,8 @@ class ChatAgent(ShieldRunnerMixin):
)
)
)
span.set_attribute("output", e.violation.model_dump_json())
if self.telemetry_enabled and span is not None:
span.set_attribute("output", e.violation.model_dump_json())
yield CompletionMessage(
content=str(e),
@ -443,7 +455,8 @@ class ChatAgent(ShieldRunnerMixin):
)
)
)
span.set_attribute("output", "no violations")
if self.telemetry_enabled and span is not None:
span.set_attribute("output", "no violations")
async def _run(
self,
@ -505,26 +518,95 @@ class ChatAgent(ShieldRunnerMixin):
tool_calls = []
content = ""
stop_reason = None
stop_reason: StopReason | None = None
async with tracing.span("inference") as span:
if self.agent_config.name:
span.set_attribute("agent_name", self.agent_config.name)
async for chunk in await self.inference_api.chat_completion(
self.agent_config.model,
input_messages,
tools=self.tool_defs,
tool_prompt_format=self.agent_config.tool_config.tool_prompt_format,
if self.telemetry_enabled and span is not None:
if self.agent_config.name:
span.set_attribute("agent_name", self.agent_config.name)
def _serialize_nested(value):
"""Recursively serialize nested Pydantic models to dicts."""
from pydantic import BaseModel
if isinstance(value, BaseModel):
return value.model_dump(mode="json")
elif isinstance(value, dict):
return {k: _serialize_nested(v) for k, v in value.items()}
elif isinstance(value, list):
return [_serialize_nested(item) for item in value]
else:
return value
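# Illustrative example (message shape assumed): a tool message whose content is a
# list of Pydantic text parts, e.g. {"role": "tool", "content": [SomePart(text="42")]},
# comes back with every BaseModel replaced by its model_dump(mode="json") dict:
# {"role": "tool", "content": [{"text": "42", ...}]}.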
def _add_type(openai_msg: dict) -> OpenAIMessageParam:
# Serialize any nested Pydantic models to plain dicts
openai_msg = _serialize_nested(openai_msg)
role = openai_msg.get("role")
if role == "user":
return OpenAIUserMessageParam(**openai_msg)
elif role == "system":
return OpenAISystemMessageParam(**openai_msg)
elif role == "assistant":
return OpenAIAssistantMessageParam(**openai_msg)
elif role == "tool":
return OpenAIToolMessageParam(**openai_msg)
elif role == "developer":
return OpenAIDeveloperMessageParam(**openai_msg)
else:
raise ValueError(f"Unknown message role: {role}")
# Convert messages to OpenAI format
openai_messages: list[OpenAIMessageParam] = [
_add_type(await convert_message_to_openai_dict_new(message)) for message in input_messages
]
# Convert tool definitions to OpenAI format
openai_tools = [convert_tooldef_to_openai_tool(x) for x in (self.tool_defs or [])]
# Extract tool_choice from tool_config for OpenAI compatibility
# Note: tool_choice can only be provided when tools are also provided
tool_choice = None
if openai_tools and self.agent_config.tool_config and self.agent_config.tool_config.tool_choice:
tc = self.agent_config.tool_config.tool_choice
tool_choice_str = tc.value if hasattr(tc, "value") else str(tc)
# Convert tool_choice to OpenAI format
if tool_choice_str in ("auto", "none", "required"):
tool_choice = tool_choice_str
else:
# It's a specific tool name, wrap it in the proper format
tool_choice = {"type": "function", "function": {"name": tool_choice_str}}
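# Illustrative outcomes (tool name assumed): tool_choice "auto" is passed through
# as the string "auto", while a specific name such as "get_weather" becomes
# {"type": "function", "function": {"name": "get_weather"}}.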
# Convert sampling params to OpenAI format (temperature, top_p, max_tokens)
temperature = getattr(getattr(sampling_params, "strategy", None), "temperature", None)
top_p = getattr(getattr(sampling_params, "strategy", None), "top_p", None)
max_tokens = getattr(sampling_params, "max_tokens", None)
# Use OpenAI chat completion
params = OpenAIChatCompletionRequestWithExtraBody(
model=self.agent_config.model,
messages=openai_messages,
tools=openai_tools if openai_tools else None,
tool_choice=tool_choice,
response_format=self.agent_config.response_format,
temperature=temperature,
top_p=top_p,
max_tokens=max_tokens,
stream=True,
sampling_params=sampling_params,
tool_config=self.agent_config.tool_config,
):
)
openai_stream = await self.inference_api.openai_chat_completion(params)
# Convert OpenAI stream back to Llama Stack format
response_stream = convert_openai_chat_completion_stream(
openai_stream, enable_incremental_tool_calls=True
)
async for chunk in response_stream:
event = chunk.event
if event.event_type == ChatCompletionResponseEventType.start:
continue
elif event.event_type == ChatCompletionResponseEventType.complete:
stop_reason = StopReason.end_of_turn
stop_reason = event.stop_reason or StopReason.end_of_turn
continue
delta = event.delta
@ -533,7 +615,7 @@ class ChatAgent(ShieldRunnerMixin):
tool_calls.append(delta.tool_call)
elif delta.parse_status == ToolCallParseStatus.failed:
# If we cannot parse the tools, set the content to the unparsed raw text
content = delta.tool_call
content = str(delta.tool_call)
if stream:
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
@ -560,20 +642,19 @@ class ChatAgent(ShieldRunnerMixin):
else:
raise ValueError(f"Unexpected delta type {type(delta)}")
if event.stop_reason is not None:
stop_reason = event.stop_reason
span.set_attribute("stop_reason", stop_reason)
span.set_attribute(
"input",
json.dumps([json.loads(m.model_dump_json()) for m in input_messages]),
)
output_attr = json.dumps(
{
"content": content,
"tool_calls": [json.loads(t.model_dump_json()) for t in tool_calls],
}
)
span.set_attribute("output", output_attr)
if self.telemetry_enabled and span is not None:
span.set_attribute("stop_reason", stop_reason or StopReason.end_of_turn)
span.set_attribute(
"input",
json.dumps([json.loads(m.model_dump_json()) for m in input_messages]),
)
output_attr = json.dumps(
{
"content": content,
"tool_calls": [json.loads(t.model_dump_json()) for t in tool_calls],
}
)
span.set_attribute("output", output_attr)
n_iter += 1
await self.storage.set_num_infer_iters_in_turn(session_id, turn_id, n_iter)
@ -681,7 +762,9 @@ class ChatAgent(ShieldRunnerMixin):
{
"tool_name": tool_call.tool_name,
"input": message.model_dump_json(),
},
}
if self.telemetry_enabled
else {},
) as span:
tool_execution_start_time = datetime.now(UTC).isoformat()
tool_result = await self.execute_tool_call_maybe(
@ -696,7 +779,8 @@ class ChatAgent(ShieldRunnerMixin):
call_id=tool_call.call_id,
content=tool_result.content,
)
span.set_attribute("output", result_message.model_dump_json())
if self.telemetry_enabled and span is not None:
span.set_attribute("output", result_message.model_dump_json())
# Store tool execution step
tool_execution_step = ToolExecutionStep(
@ -790,18 +874,12 @@ class ChatAgent(ShieldRunnerMixin):
for tool_def in self.agent_config.client_tools:
if tool_name_to_def.get(tool_def.name, None):
raise ValueError(f"Tool {tool_def.name} already exists")
# Use input_schema from ToolDef directly
tool_name_to_def[tool_def.name] = ToolDefinition(
tool_name=tool_def.name,
description=tool_def.description,
parameters={
param.name: ToolParamDefinition(
param_type=param.parameter_type,
description=param.description,
required=param.required,
default=param.default,
)
for param in tool_def.parameters
},
input_schema=tool_def.input_schema,
)
for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups:
toolgroup_name, input_tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name)
@ -811,42 +889,34 @@ class ChatAgent(ShieldRunnerMixin):
[t.identifier for t in (await self.tool_groups_api.list_tool_groups()).data]
)
raise ValueError(f"Toolgroup {toolgroup_name} not found, available toolgroups: {available_tool_groups}")
if input_tool_name is not None and not any(tool.identifier == input_tool_name for tool in tools.data):
if input_tool_name is not None and not any(tool.name == input_tool_name for tool in tools.data):
raise ValueError(
f"Tool {input_tool_name} not found in toolgroup {toolgroup_name}. Available tools: {', '.join([tool.identifier for tool in tools.data])}"
f"Tool {input_tool_name} not found in toolgroup {toolgroup_name}. Available tools: {', '.join([tool.name for tool in tools.data])}"
)
for tool_def in tools.data:
if toolgroup_name.startswith("builtin") and toolgroup_name != RAG_TOOL_GROUP:
identifier: str | BuiltinTool | None = tool_def.identifier
identifier: str | BuiltinTool | None = tool_def.name
if identifier == "web_search":
identifier = BuiltinTool.brave_search
else:
identifier = BuiltinTool(identifier)
else:
# add if tool_name is unspecified or the tool_def identifier is the same as the tool_name
if input_tool_name in (None, tool_def.identifier):
identifier = tool_def.identifier
if input_tool_name in (None, tool_def.name):
identifier = tool_def.name
else:
identifier = None
if tool_name_to_def.get(identifier, None):
raise ValueError(f"Tool {identifier} already exists")
if identifier:
tool_name_to_def[tool_def.identifier] = ToolDefinition(
tool_name_to_def[identifier] = ToolDefinition(
tool_name=identifier,
description=tool_def.description,
parameters={
param.name: ToolParamDefinition(
param_type=param.parameter_type,
description=param.description,
required=param.required,
default=param.default,
)
for param in tool_def.parameters
},
input_schema=tool_def.input_schema,
)
tool_name_to_args[tool_def.identifier] = toolgroup_to_args.get(toolgroup_name, {})
tool_name_to_args[identifier] = toolgroup_to_args.get(toolgroup_name, {})
self.tool_defs, self.tool_name_to_args = (
list(tool_name_to_def.values()),
@ -890,12 +960,18 @@ class ChatAgent(ShieldRunnerMixin):
tool_name_str = tool_name
logger.info(f"executing tool call: {tool_name_str} with args: {tool_call.arguments}")
try:
args = json.loads(tool_call.arguments)
except json.JSONDecodeError as e:
raise ValueError(f"Failed to parse arguments for tool call: {tool_call.arguments}") from e
result = await self.tool_runtime_api.invoke_tool(
tool_name=tool_name_str,
kwargs={
"session_id": session_id,
# get the arguments generated by the model and augment with toolgroup arg overrides for the agent
**tool_call.arguments,
**args,
**self.tool_name_to_args.get(tool_name_str, {}),
},
)
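# Illustrative example (argument values assumed): if the model emitted
# arguments '{"query": "llama"}' and the agent has a toolgroup override
# {"top_k": 3}, the tool receives
# kwargs {"session_id": session_id, "query": "llama", "top_k": 3}.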
@ -920,7 +996,7 @@ async def get_raw_document_text(document: Document) -> str:
DeprecationWarning,
stacklevel=2,
)
elif not (document.mime_type.startswith("text/") or document.mime_type == "application/yaml"):
elif not (document.mime_type.startswith("text/") or document.mime_type in ("application/yaml", "application/json")):
raise ValueError(f"Unexpected document mime type: {document.mime_type}")
if isinstance(document.content, URL):

View file

@ -30,6 +30,7 @@ from llama_stack.apis.agents import (
)
from llama_stack.apis.agents.openai_responses import OpenAIResponseText
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.conversations import Conversations
from llama_stack.apis.inference import (
Inference,
ToolConfig,
@ -63,7 +64,9 @@ class MetaReferenceAgentsImpl(Agents):
safety_api: Safety,
tool_runtime_api: ToolRuntime,
tool_groups_api: ToolGroups,
conversations_api: Conversations,
policy: list[AccessRule],
telemetry_enabled: bool = False,
):
self.config = config
self.inference_api = inference_api
@ -71,6 +74,8 @@ class MetaReferenceAgentsImpl(Agents):
self.safety_api = safety_api
self.tool_runtime_api = tool_runtime_api
self.tool_groups_api = tool_groups_api
self.conversations_api = conversations_api
self.telemetry_enabled = telemetry_enabled
self.in_memory_store = InmemoryKVStoreImpl()
self.openai_responses_impl: OpenAIResponsesImpl | None = None
@ -86,6 +91,7 @@ class MetaReferenceAgentsImpl(Agents):
tool_runtime_api=self.tool_runtime_api,
responses_store=self.responses_store,
vector_io_api=self.vector_io_api,
conversations_api=self.conversations_api,
)
async def create_agent(
@ -135,6 +141,7 @@ class MetaReferenceAgentsImpl(Agents):
),
created_at=agent_info.created_at,
policy=self.policy,
telemetry_enabled=self.telemetry_enabled,
)
async def create_agent_session(
@ -322,6 +329,7 @@ class MetaReferenceAgentsImpl(Agents):
model: str,
instructions: str | None = None,
previous_response_id: str | None = None,
conversation: str | None = None,
store: bool | None = True,
stream: bool | None = False,
temperature: float | None = None,
@ -329,12 +337,14 @@ class MetaReferenceAgentsImpl(Agents):
tools: list[OpenAIResponseInputTool] | None = None,
include: list[str] | None = None,
max_infer_iters: int | None = 10,
shields: list | None = None,
) -> OpenAIResponseObject:
return await self.openai_responses_impl.create_openai_response(
input,
model,
instructions,
previous_response_id,
conversation,
store,
stream,
temperature,
@ -342,6 +352,7 @@ class MetaReferenceAgentsImpl(Agents):
tools,
include,
max_infer_iters,
shields,
)
async def list_openai_responses(

View file

@ -8,7 +8,7 @@ import time
import uuid
from collections.abc import AsyncIterator
from pydantic import BaseModel
from pydantic import BaseModel, TypeAdapter
from llama_stack.apis.agents import Order
from llama_stack.apis.agents.openai_responses import (
@ -24,24 +24,33 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseText,
OpenAIResponseTextFormat,
)
from llama_stack.apis.common.errors import (
InvalidConversationIdError,
)
from llama_stack.apis.conversations import Conversations
from llama_stack.apis.conversations.conversations import ConversationItem
from llama_stack.apis.inference import (
Inference,
OpenAIMessageParam,
OpenAISystemMessageParam,
)
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.log import get_logger
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
from llama_stack.providers.utils.responses.responses_store import (
ResponsesStore,
_OpenAIResponseObjectWithInputAndMessages,
)
from .streaming import StreamingResponseOrchestrator
from .tool_executor import ToolExecutor
from .types import ChatCompletionContext
from .types import ChatCompletionContext, ToolContext
from .utils import (
convert_response_input_to_chat_messages,
convert_response_text_to_chat_response_format,
)
logger = get_logger(name=__name__, category="openai::responses")
logger = get_logger(name=__name__, category="openai_responses")
class OpenAIResponsePreviousResponseWithInputItems(BaseModel):
@ -57,12 +66,14 @@ class OpenAIResponsesImpl:
tool_runtime_api: ToolRuntime,
responses_store: ResponsesStore,
vector_io_api: VectorIO, # VectorIO
conversations_api: Conversations,
):
self.inference_api = inference_api
self.tool_groups_api = tool_groups_api
self.tool_runtime_api = tool_runtime_api
self.responses_store = responses_store
self.vector_io_api = vector_io_api
self.conversations_api = conversations_api
self.tool_executor = ToolExecutor(
tool_groups_api=tool_groups_api,
tool_runtime_api=tool_runtime_api,
@ -72,26 +83,52 @@ class OpenAIResponsesImpl:
async def _prepend_previous_response(
self,
input: str | list[OpenAIResponseInput],
previous_response_id: str | None = None,
previous_response: _OpenAIResponseObjectWithInputAndMessages,
):
new_input_items = previous_response.input.copy()
new_input_items.extend(previous_response.output)
if isinstance(input, str):
new_input_items.append(OpenAIResponseMessage(content=input, role="user"))
else:
new_input_items.extend(input)
return new_input_items
async def _process_input_with_previous_response(
self,
input: str | list[OpenAIResponseInput],
tools: list[OpenAIResponseInputTool] | None,
previous_response_id: str | None,
) -> tuple[str | list[OpenAIResponseInput], list[OpenAIMessageParam]]:
"""Process input with optional previous response context.
Returns:
tuple: (all_input for storage, messages for chat completion, tool context)
"""
tool_context = ToolContext(tools)
if previous_response_id:
previous_response_with_input = await self.responses_store.get_response_object(previous_response_id)
previous_response: _OpenAIResponseObjectWithInputAndMessages = (
await self.responses_store.get_response_object(previous_response_id)
)
all_input = await self._prepend_previous_response(input, previous_response)
# previous response input items
new_input_items = previous_response_with_input.input
# previous response output items
new_input_items.extend(previous_response_with_input.output)
# new input items from the current request
if isinstance(input, str):
new_input_items.append(OpenAIResponseMessage(content=input, role="user"))
if previous_response.messages:
# Use stored messages directly and convert only new input
message_adapter = TypeAdapter(list[OpenAIMessageParam])
messages = message_adapter.validate_python(previous_response.messages)
new_messages = await convert_response_input_to_chat_messages(input, previous_messages=messages)
messages.extend(new_messages)
else:
new_input_items.extend(input)
# Backward compatibility: reconstruct from inputs
messages = await convert_response_input_to_chat_messages(all_input)
input = new_input_items
tool_context.recover_tools_from_previous_response(previous_response)
else:
all_input = input
messages = await convert_response_input_to_chat_messages(input)
return input
return all_input, messages, tool_context
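A hedged usage sketch mirroring the call made later in _create_streaming_response (the input string and response id are illustrative):
all_input, messages, tool_context = await self._process_input_with_previous_response(
input="And in celsius?",
tools=None,
previous_response_id="resp_abc123",  # pass None for a fresh request
)
# all_input is what gets persisted with the response, messages drive chat completion,
# and tool_context allows cached MCP tool listings to be reused.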
async def _prepend_instructions(self, messages, instructions):
if instructions:
@ -102,7 +139,7 @@ class OpenAIResponsesImpl:
response_id: str,
) -> OpenAIResponseObject:
response_with_input = await self.responses_store.get_response_object(response_id)
return OpenAIResponseObject(**{k: v for k, v in response_with_input.model_dump().items() if k != "input"})
return response_with_input.to_response_object()
async def list_openai_responses(
self,
@ -138,6 +175,7 @@ class OpenAIResponsesImpl:
self,
response: OpenAIResponseObject,
input: str | list[OpenAIResponseInput],
messages: list[OpenAIMessageParam],
) -> None:
new_input_id = f"msg_{uuid.uuid4()}"
if isinstance(input, str):
@ -165,6 +203,7 @@ class OpenAIResponsesImpl:
await self.responses_store.store_response_object(
response_object=response,
input=input_items_data,
messages=messages,
)
async def create_openai_response(
@ -173,6 +212,7 @@ class OpenAIResponsesImpl:
model: str,
instructions: str | None = None,
previous_response_id: str | None = None,
conversation: str | None = None,
store: bool | None = True,
stream: bool | None = False,
temperature: float | None = None,
@ -180,15 +220,36 @@ class OpenAIResponsesImpl:
tools: list[OpenAIResponseInputTool] | None = None,
include: list[str] | None = None,
max_infer_iters: int | None = 10,
shields: list | None = None,
):
stream = bool(stream)
text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text
# Shields parameter received via extra_body - not yet implemented
if shields is not None:
raise NotImplementedError("Shields parameter is not yet implemented in the meta-reference provider")
if conversation is not None and previous_response_id is not None:
raise ValueError(
"Mutually exclusive parameters: 'previous_response_id' and 'conversation'. Ensure you are only providing one of these parameters."
)
original_input = input # needed for syncing to Conversations
if conversation is not None:
if not conversation.startswith("conv_"):
raise InvalidConversationIdError(conversation)
# Check conversation exists (raises ConversationNotFoundError if not)
_ = await self.conversations_api.get_conversation(conversation)
input = await self._load_conversation_context(conversation, input)
stream_gen = self._create_streaming_response(
input=input,
original_input=original_input,
model=model,
instructions=instructions,
previous_response_id=previous_response_id,
conversation=conversation,
store=store,
temperature=temperature,
text=text,
@ -199,24 +260,42 @@ class OpenAIResponsesImpl:
if stream:
return stream_gen
else:
response = None
async for stream_chunk in stream_gen:
if stream_chunk.type == "response.completed":
if response is not None:
raise ValueError("The response stream completed multiple times! Earlier response: {response}")
response = stream_chunk.response
# don't leave the generator half complete!
final_response = None
final_event_type = None
failed_response = None
if response is None:
raise ValueError("The response stream never completed")
return response
async for stream_chunk in stream_gen:
if stream_chunk.type in {"response.completed", "response.incomplete"}:
if final_response is not None:
raise ValueError(
"The response stream produced multiple terminal responses! "
f"Earlier response from {final_event_type}"
)
final_response = stream_chunk.response
final_event_type = stream_chunk.type
elif stream_chunk.type == "response.failed":
failed_response = stream_chunk.response
if failed_response is not None:
error_message = (
failed_response.error.message
if failed_response and failed_response.error
else "Response stream failed without error details"
)
raise RuntimeError(f"OpenAI response failed: {error_message}")
if final_response is None:
raise ValueError("The response stream never reached a terminal state")
return final_response
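A minimal, hedged usage sketch of the non-streaming path (responses_impl, the model id, and the conversation id are placeholders; conversation and previous_response_id are mutually exclusive):
response = await responses_impl.create_openai_response(
input="Summarize our discussion so far",
model="llama3.3-70b-instruct",
conversation="conv_0123456789",  # must start with "conv_"
store=True,
stream=False,
)
print(response.status)  # "completed" or "incomplete"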
async def _create_streaming_response(
self,
input: str | list[OpenAIResponseInput],
model: str,
original_input: str | list[OpenAIResponseInput] | None = None,
instructions: str | None = None,
previous_response_id: str | None = None,
conversation: str | None = None,
store: bool | None = True,
temperature: float | None = None,
text: OpenAIResponseText | None = None,
@ -224,8 +303,9 @@ class OpenAIResponsesImpl:
max_infer_iters: int | None = 10,
) -> AsyncIterator[OpenAIResponseObjectStream]:
# Input preprocessing
input = await self._prepend_previous_response(input, previous_response_id)
messages = await convert_response_input_to_chat_messages(input)
all_input, messages, tool_context = await self._process_input_with_previous_response(
input, tools, previous_response_id
)
await self._prepend_instructions(messages, instructions)
# Structured outputs
@ -237,10 +317,12 @@ class OpenAIResponsesImpl:
response_tools=tools,
temperature=temperature,
response_format=response_format,
tool_context=tool_context,
inputs=all_input,
)
# Create orchestrator and delegate streaming logic
response_id = f"resp-{uuid.uuid4()}"
response_id = f"resp_{uuid.uuid4()}"
created_at = int(time.time())
orchestrator = StreamingResponseOrchestrator(
@ -255,17 +337,110 @@ class OpenAIResponsesImpl:
# Stream the response
final_response = None
failed_response = None
async for stream_chunk in orchestrator.create_response():
if stream_chunk.type == "response.completed":
if stream_chunk.type in {"response.completed", "response.incomplete"}:
final_response = stream_chunk.response
elif stream_chunk.type == "response.failed":
failed_response = stream_chunk.response
yield stream_chunk
# Store the response if requested
if store and final_response:
await self._store_response(
response=final_response,
input=input,
)
# Store and sync immediately after yielding terminal events
# This ensures the storage/syncing happens even if the consumer breaks early
if (
stream_chunk.type in {"response.completed", "response.incomplete"}
and store
and final_response
and failed_response is None
):
await self._store_response(
response=final_response,
input=all_input,
messages=orchestrator.final_messages,
)
if stream_chunk.type in {"response.completed", "response.incomplete"} and conversation and final_response:
# for Conversations, we need to use the original_input if it's available, otherwise use input
sync_input = original_input if original_input is not None else input
await self._sync_response_to_conversation(conversation, sync_input, final_response)
async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject:
return await self.responses_store.delete_response_object(response_id)
async def _load_conversation_context(
self, conversation_id: str, content: str | list[OpenAIResponseInput]
) -> list[OpenAIResponseInput]:
"""Load conversation history and merge with provided content."""
conversation_items = await self.conversations_api.list(conversation_id, order="asc")
context_messages = []
for item in conversation_items.data:
if isinstance(item, OpenAIResponseMessage):
if item.role == "user":
context_messages.append(
OpenAIResponseMessage(
role="user", content=item.content, id=item.id if hasattr(item, "id") else None
)
)
elif item.role == "assistant":
context_messages.append(
OpenAIResponseMessage(
role="assistant", content=item.content, id=item.id if hasattr(item, "id") else None
)
)
# add new content to context
if isinstance(content, str):
context_messages.append(OpenAIResponseMessage(role="user", content=content))
elif isinstance(content, list):
context_messages.extend(content)
return context_messages
async def _sync_response_to_conversation(
self, conversation_id: str, content: str | list[OpenAIResponseInput], response: OpenAIResponseObject
) -> None:
"""Sync content and response messages to the conversation."""
conversation_items = []
# add user content message(s)
if isinstance(content, str):
conversation_items.append(
{"type": "message", "role": "user", "content": [{"type": "input_text", "text": content}]}
)
elif isinstance(content, list):
for item in content:
if not isinstance(item, OpenAIResponseMessage):
raise NotImplementedError(f"Unsupported input item type: {type(item)}")
if item.role == "user":
if isinstance(item.content, str):
conversation_items.append(
{
"type": "message",
"role": "user",
"content": [{"type": "input_text", "text": item.content}],
}
)
elif isinstance(item.content, list):
conversation_items.append({"type": "message", "role": "user", "content": item.content})
else:
raise NotImplementedError(f"Unsupported user message content type: {type(item.content)}")
elif item.role == "assistant":
if isinstance(item.content, list):
conversation_items.append({"type": "message", "role": "assistant", "content": item.content})
else:
raise NotImplementedError(f"Unsupported assistant message content type: {type(item.content)}")
else:
raise NotImplementedError(f"Unsupported message role: {item.role}")
# add assistant response message
for output_item in response.output:
if isinstance(output_item, OpenAIResponseMessage) and output_item.role == "assistant":
if hasattr(output_item, "content") and isinstance(output_item.content, list):
conversation_items.append({"type": "message", "role": "assistant", "content": output_item.content})
if conversation_items:
adapter = TypeAdapter(list[ConversationItem])
validated_items = adapter.validate_python(conversation_items)
await self.conversations_api.add_items(conversation_id, validated_items)
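A hedged sketch of the item payload this sync validates with TypeAdapter(list[ConversationItem]) for a plain string input (the text value is illustrative; assistant items reuse the content list taken verbatim from response.output):
conversation_items = [
{"type": "message", "role": "user", "content": [{"type": "input_text", "text": "hello"}]},
]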

View file

@ -10,18 +10,26 @@ from typing import Any
from llama_stack.apis.agents.openai_responses import (
AllowedToolsFilter,
ApprovalFilter,
MCPListToolsTool,
OpenAIResponseContentPartOutputText,
OpenAIResponseContentPartReasoningText,
OpenAIResponseContentPartRefusal,
OpenAIResponseError,
OpenAIResponseInputTool,
OpenAIResponseInputToolMCP,
OpenAIResponseMCPApprovalRequest,
OpenAIResponseObject,
OpenAIResponseObjectStream,
OpenAIResponseObjectStreamResponseCompleted,
OpenAIResponseObjectStreamResponseContentPartAdded,
OpenAIResponseObjectStreamResponseContentPartDone,
OpenAIResponseObjectStreamResponseCreated,
OpenAIResponseObjectStreamResponseFailed,
OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta,
OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone,
OpenAIResponseObjectStreamResponseIncomplete,
OpenAIResponseObjectStreamResponseInProgress,
OpenAIResponseObjectStreamResponseMcpCallArgumentsDelta,
OpenAIResponseObjectStreamResponseMcpCallArgumentsDone,
OpenAIResponseObjectStreamResponseMcpListToolsCompleted,
@ -29,20 +37,31 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseObjectStreamResponseOutputItemAdded,
OpenAIResponseObjectStreamResponseOutputItemDone,
OpenAIResponseObjectStreamResponseOutputTextDelta,
OpenAIResponseObjectStreamResponseReasoningTextDelta,
OpenAIResponseObjectStreamResponseReasoningTextDone,
OpenAIResponseObjectStreamResponseRefusalDelta,
OpenAIResponseObjectStreamResponseRefusalDone,
OpenAIResponseOutput,
OpenAIResponseOutputMessageFunctionToolCall,
OpenAIResponseOutputMessageMCPListTools,
OpenAIResponseText,
OpenAIResponseUsage,
OpenAIResponseUsageInputTokensDetails,
OpenAIResponseUsageOutputTokensDetails,
WebSearchToolTypes,
)
from llama_stack.apis.inference import (
Inference,
OpenAIAssistantMessageParam,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAIChatCompletionRequestWithExtraBody,
OpenAIChatCompletionToolCall,
OpenAIChoice,
OpenAIMessageParam,
)
from llama_stack.log import get_logger
from llama_stack.providers.utils.telemetry import tracing
from .types import ChatCompletionContext, ChatCompletionResult
from .utils import convert_chat_choice_to_response_message, is_function_tool_call
@ -50,6 +69,27 @@ from .utils import convert_chat_choice_to_response_message, is_function_tool_cal
logger = get_logger(name=__name__, category="agents::meta_reference")
def convert_tooldef_to_chat_tool(tool_def):
"""Convert a ToolDef to OpenAI ChatCompletionToolParam format.
Args:
tool_def: ToolDef from the tools API
Returns:
ChatCompletionToolParam suitable for OpenAI chat completion
"""
from llama_stack.models.llama.datatypes import ToolDefinition
from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
internal_tool_def = ToolDefinition(
tool_name=tool_def.name,
description=tool_def.description,
input_schema=tool_def.input_schema,
)
return convert_tooldef_to_openai_tool(internal_tool_def)
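A hedged usage sketch (the tool name and schema are illustrative); the result follows the OpenAI function-tool parameter shape:
chat_tool = convert_tooldef_to_chat_tool(weather_tool_def)
# -> {"type": "function",
#     "function": {"name": "get_weather", "description": "...",
#                  "parameters": {"type": "object", "properties": {...}}}}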
class StreamingResponseOrchestrator:
def __init__(
self,
@ -70,117 +110,378 @@ class StreamingResponseOrchestrator:
self.tool_executor = tool_executor
self.sequence_number = 0
# Store MCP tool mapping that gets built during tool processing
self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = {}
self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = ctx.tool_context.previous_tools or {}
# Track final messages after all tool executions
self.final_messages: list[OpenAIMessageParam] = []
# mapping for annotations
self.citation_files: dict[str, str] = {}
# Track accumulated usage across all inference calls
self.accumulated_usage: OpenAIResponseUsage | None = None
async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]:
# Initialize output messages
output_messages: list[OpenAIResponseOutput] = []
# Create initial response and emit response.created immediately
initial_response = OpenAIResponseObject(
def _clone_outputs(self, outputs: list[OpenAIResponseOutput]) -> list[OpenAIResponseOutput]:
cloned: list[OpenAIResponseOutput] = []
for item in outputs:
if hasattr(item, "model_copy"):
cloned.append(item.model_copy(deep=True))
else:
cloned.append(item)
return cloned
def _snapshot_response(
self,
status: str,
outputs: list[OpenAIResponseOutput],
*,
error: OpenAIResponseError | None = None,
) -> OpenAIResponseObject:
return OpenAIResponseObject(
created_at=self.created_at,
id=self.response_id,
model=self.ctx.model,
object="response",
status="in_progress",
output=output_messages.copy(),
status=status,
output=self._clone_outputs(outputs),
text=self.text,
tools=self.ctx.available_tools(),
error=error,
usage=self.accumulated_usage,
)
yield OpenAIResponseObjectStreamResponseCreated(response=initial_response)
async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]:
output_messages: list[OpenAIResponseOutput] = []
# Process all tools (including MCP tools) and emit streaming events
if self.ctx.response_tools:
async for stream_event in self._process_tools(self.ctx.response_tools, output_messages):
yield stream_event
# Emit response.created followed by response.in_progress to align with OpenAI streaming
yield OpenAIResponseObjectStreamResponseCreated(
response=self._snapshot_response("in_progress", output_messages)
)
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseInProgress(
response=self._snapshot_response("in_progress", output_messages),
sequence_number=self.sequence_number,
)
async for stream_event in self._process_tools(output_messages):
yield stream_event
n_iter = 0
messages = self.ctx.messages.copy()
final_status = "completed"
last_completion_result: ChatCompletionResult | None = None
while True:
completion_result = await self.inference_api.openai_chat_completion(
model=self.ctx.model,
messages=messages,
tools=self.ctx.chat_tools,
stream=True,
temperature=self.ctx.temperature,
response_format=self.ctx.response_format,
try:
while True:
# Text is the default response format for chat completion, so we don't need to pass it
# (some providers don't support non-empty response_format when tools are present)
response_format = None if self.ctx.response_format.type == "text" else self.ctx.response_format
logger.debug(f"calling openai_chat_completion with tools: {self.ctx.chat_tools}")
params = OpenAIChatCompletionRequestWithExtraBody(
model=self.ctx.model,
messages=messages,
tools=self.ctx.chat_tools,
stream=True,
temperature=self.ctx.temperature,
response_format=response_format,
stream_options={
"include_usage": True,
},
)
completion_result = await self.inference_api.openai_chat_completion(params)
# Process streaming chunks and build complete response
completion_result_data = None
async for stream_event_or_result in self._process_streaming_chunks(completion_result, output_messages):
if isinstance(stream_event_or_result, ChatCompletionResult):
completion_result_data = stream_event_or_result
else:
yield stream_event_or_result
if not completion_result_data:
raise ValueError("Streaming chunk processor failed to return completion data")
last_completion_result = completion_result_data
current_response = self._build_chat_completion(completion_result_data)
(
function_tool_calls,
non_function_tool_calls,
approvals,
next_turn_messages,
) = self._separate_tool_calls(current_response, messages)
# add any approval requests required
for tool_call in approvals:
async for evt in self._add_mcp_approval_request(
tool_call.function.name, tool_call.function.arguments, output_messages
):
yield evt
# Handle choices with no tool calls
for choice in current_response.choices:
if not (choice.message.tool_calls and self.ctx.response_tools):
output_messages.append(
await convert_chat_choice_to_response_message(
choice,
self.citation_files,
message_id=completion_result_data.message_item_id,
)
)
# Execute tool calls and coordinate results
async for stream_event in self._coordinate_tool_execution(
function_tool_calls,
non_function_tool_calls,
completion_result_data,
output_messages,
next_turn_messages,
):
yield stream_event
messages = next_turn_messages
if not function_tool_calls and not non_function_tool_calls:
break
if function_tool_calls:
logger.info("Exiting inference loop since there is a function (client-side) tool call")
break
n_iter += 1
if n_iter >= self.max_infer_iters:
logger.info(
f"Exiting inference loop since iteration count({n_iter}) exceeds {self.max_infer_iters=}"
)
final_status = "incomplete"
break
if last_completion_result and last_completion_result.finish_reason == "length":
final_status = "incomplete"
except Exception as exc: # noqa: BLE001
self.final_messages = messages.copy()
self.sequence_number += 1
error = OpenAIResponseError(code="internal_error", message=str(exc))
failure_response = self._snapshot_response("failed", output_messages, error=error)
yield OpenAIResponseObjectStreamResponseFailed(
response=failure_response,
sequence_number=self.sequence_number,
)
return
# Process streaming chunks and build complete response
completion_result_data = None
async for stream_event_or_result in self._process_streaming_chunks(completion_result, output_messages):
if isinstance(stream_event_or_result, ChatCompletionResult):
completion_result_data = stream_event_or_result
else:
yield stream_event_or_result
if not completion_result_data:
raise ValueError("Streaming chunk processor failed to return completion data")
current_response = self._build_chat_completion(completion_result_data)
self.final_messages = messages.copy()
function_tool_calls, non_function_tool_calls, next_turn_messages = self._separate_tool_calls(
current_response, messages
if final_status == "incomplete":
self.sequence_number += 1
final_response = self._snapshot_response("incomplete", output_messages)
yield OpenAIResponseObjectStreamResponseIncomplete(
response=final_response,
sequence_number=self.sequence_number,
)
else:
final_response = self._snapshot_response("completed", output_messages)
yield OpenAIResponseObjectStreamResponseCompleted(response=final_response)
# Handle choices with no tool calls
for choice in current_response.choices:
if not (choice.message.tool_calls and self.ctx.response_tools):
output_messages.append(await convert_chat_choice_to_response_message(choice))
# Execute tool calls and coordinate results
async for stream_event in self._coordinate_tool_execution(
function_tool_calls,
non_function_tool_calls,
completion_result_data,
output_messages,
next_turn_messages,
):
yield stream_event
if not function_tool_calls and not non_function_tool_calls:
break
if function_tool_calls:
logger.info("Exiting inference loop since there is a function (client-side) tool call")
break
n_iter += 1
if n_iter >= self.max_infer_iters:
logger.info(f"Exiting inference loop since iteration count({n_iter}) exceeds {self.max_infer_iters=}")
break
messages = next_turn_messages
# Create final response
final_response = OpenAIResponseObject(
created_at=self.created_at,
id=self.response_id,
model=self.ctx.model,
object="response",
status="completed",
text=self.text,
output=output_messages,
)
# Emit response.completed
yield OpenAIResponseObjectStreamResponseCompleted(response=final_response)
def _separate_tool_calls(self, current_response, messages) -> tuple[list, list, list]:
def _separate_tool_calls(self, current_response, messages) -> tuple[list, list, list, list]:
"""Separate tool calls into function and non-function categories."""
function_tool_calls = []
non_function_tool_calls = []
approvals = []
next_turn_messages = messages.copy()
for choice in current_response.choices:
next_turn_messages.append(choice.message)
logger.debug(f"Choice message content: {choice.message.content}")
logger.debug(f"Choice message tool_calls: {choice.message.tool_calls}")
if choice.message.tool_calls and self.ctx.response_tools:
for tool_call in choice.message.tool_calls:
if is_function_tool_call(tool_call, self.ctx.response_tools):
function_tool_calls.append(tool_call)
else:
non_function_tool_calls.append(tool_call)
if self._approval_required(tool_call.function.name):
approval_response = self.ctx.approval_response(
tool_call.function.name, tool_call.function.arguments
)
if approval_response:
if approval_response.approve:
logger.info(f"Approval granted for {tool_call.id} on {tool_call.function.name}")
non_function_tool_calls.append(tool_call)
else:
logger.info(f"Approval denied for {tool_call.id} on {tool_call.function.name}")
next_turn_messages.pop()
else:
logger.info(f"Requesting approval for {tool_call.id} on {tool_call.function.name}")
approvals.append(tool_call)
next_turn_messages.pop()
else:
non_function_tool_calls.append(tool_call)
return function_tool_calls, non_function_tool_calls, next_turn_messages
return function_tool_calls, non_function_tool_calls, approvals, next_turn_messages
def _accumulate_chunk_usage(self, chunk: OpenAIChatCompletionChunk) -> None:
"""Accumulate usage from a streaming chunk into the response usage format."""
if not chunk.usage:
return
if self.accumulated_usage is None:
# Convert from chat completion format to response format
self.accumulated_usage = OpenAIResponseUsage(
input_tokens=chunk.usage.prompt_tokens,
output_tokens=chunk.usage.completion_tokens,
total_tokens=chunk.usage.total_tokens,
input_tokens_details=(
OpenAIResponseUsageInputTokensDetails(cached_tokens=chunk.usage.prompt_tokens_details.cached_tokens)
if chunk.usage.prompt_tokens_details
else None
),
output_tokens_details=(
OpenAIResponseUsageOutputTokensDetails(
reasoning_tokens=chunk.usage.completion_tokens_details.reasoning_tokens
)
if chunk.usage.completion_tokens_details
else None
),
)
else:
# Accumulate across multiple inference calls
self.accumulated_usage = OpenAIResponseUsage(
input_tokens=self.accumulated_usage.input_tokens + chunk.usage.prompt_tokens,
output_tokens=self.accumulated_usage.output_tokens + chunk.usage.completion_tokens,
total_tokens=self.accumulated_usage.total_tokens + chunk.usage.total_tokens,
# Use latest non-null details
input_tokens_details=(
OpenAIResponseUsageInputTokensDetails(cached_tokens=chunk.usage.prompt_tokens_details.cached_tokens)
if chunk.usage.prompt_tokens_details
else self.accumulated_usage.input_tokens_details
),
output_tokens_details=(
OpenAIResponseUsageOutputTokensDetails(
reasoning_tokens=chunk.usage.completion_tokens_details.reasoning_tokens
)
if chunk.usage.completion_tokens_details
else self.accumulated_usage.output_tokens_details
),
)
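# Illustrative accumulation across two inference calls (token counts assumed):
# call 1 usage: prompt_tokens=120, completion_tokens=30, total_tokens=150
# call 2 usage: prompt_tokens=80,  completion_tokens=45, total_tokens=125
# accumulated_usage -> input_tokens=200, output_tokens=75, total_tokens=275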
async def _handle_reasoning_content_chunk(
self,
reasoning_content: str,
reasoning_part_emitted: bool,
reasoning_content_index: int,
message_item_id: str,
message_output_index: int,
) -> AsyncIterator[OpenAIResponseObjectStream]:
# Emit content_part.added event for first reasoning chunk
if not reasoning_part_emitted:
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseContentPartAdded(
content_index=reasoning_content_index,
response_id=self.response_id,
item_id=message_item_id,
output_index=message_output_index,
part=OpenAIResponseContentPartReasoningText(
text="", # Will be filled incrementally via reasoning deltas
),
sequence_number=self.sequence_number,
)
# Emit reasoning_text.delta event
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseReasoningTextDelta(
content_index=reasoning_content_index,
delta=reasoning_content,
item_id=message_item_id,
output_index=message_output_index,
sequence_number=self.sequence_number,
)
async def _handle_refusal_content_chunk(
self,
refusal_content: str,
refusal_part_emitted: bool,
refusal_content_index: int,
message_item_id: str,
message_output_index: int,
) -> AsyncIterator[OpenAIResponseObjectStream]:
# Emit content_part.added event for first refusal chunk
if not refusal_part_emitted:
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseContentPartAdded(
content_index=refusal_content_index,
response_id=self.response_id,
item_id=message_item_id,
output_index=message_output_index,
part=OpenAIResponseContentPartRefusal(
refusal="", # Will be filled incrementally via refusal deltas
),
sequence_number=self.sequence_number,
)
# Emit refusal.delta event
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseRefusalDelta(
content_index=refusal_content_index,
delta=refusal_content,
item_id=message_item_id,
output_index=message_output_index,
sequence_number=self.sequence_number,
)
async def _emit_reasoning_done_events(
self,
reasoning_text_accumulated: list[str],
reasoning_content_index: int,
message_item_id: str,
message_output_index: int,
) -> AsyncIterator[OpenAIResponseObjectStream]:
final_reasoning_text = "".join(reasoning_text_accumulated)
# Emit reasoning_text.done event
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseReasoningTextDone(
content_index=reasoning_content_index,
text=final_reasoning_text,
item_id=message_item_id,
output_index=message_output_index,
sequence_number=self.sequence_number,
)
# Emit content_part.done for reasoning
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseContentPartDone(
content_index=reasoning_content_index,
response_id=self.response_id,
item_id=message_item_id,
output_index=message_output_index,
part=OpenAIResponseContentPartReasoningText(
text=final_reasoning_text,
),
sequence_number=self.sequence_number,
)
async def _emit_refusal_done_events(
self,
refusal_text_accumulated: list[str],
refusal_content_index: int,
message_item_id: str,
message_output_index: int,
) -> AsyncIterator[OpenAIResponseObjectStream]:
final_refusal_text = "".join(refusal_text_accumulated)
# Emit refusal.done event
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseRefusalDone(
content_index=refusal_content_index,
refusal=final_refusal_text,
item_id=message_item_id,
output_index=message_output_index,
sequence_number=self.sequence_number,
)
# Emit content_part.done for refusal
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseContentPartDone(
content_index=refusal_content_index,
response_id=self.response_id,
item_id=message_item_id,
output_index=message_output_index,
part=OpenAIResponseContentPartRefusal(
refusal=final_refusal_text,
),
sequence_number=self.sequence_number,
)
async def _process_streaming_chunks(
self, completion_result, output_messages: list[OpenAIResponseOutput]
@ -200,11 +501,23 @@ class StreamingResponseOrchestrator:
tool_call_item_ids: dict[int, str] = {}
# Track content parts for streaming events
content_part_emitted = False
reasoning_part_emitted = False
refusal_part_emitted = False
content_index = 0
reasoning_content_index = 1 # reasoning is a separate content part
refusal_content_index = 2 # refusal is a separate content part
message_output_index = len(output_messages)
reasoning_text_accumulated = []
refusal_text_accumulated = []
async for chunk in completion_result:
chat_response_id = chunk.id
chunk_created = chunk.created
chunk_model = chunk.model
# Accumulate usage from chunks (typically in final chunk with stream_options)
self._accumulate_chunk_usage(chunk)
for chunk_choice in chunk.choices:
# Emit incremental text content as delta events
if chunk_choice.delta.content:
@ -213,8 +526,10 @@ class StreamingResponseOrchestrator:
content_part_emitted = True
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseContentPartAdded(
content_index=content_index,
response_id=self.response_id,
item_id=message_item_id,
output_index=message_output_index,
part=OpenAIResponseContentPartOutputText(
text="", # Will be filled incrementally via text deltas
),
@ -222,10 +537,10 @@ class StreamingResponseOrchestrator:
)
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseOutputTextDelta(
content_index=0,
content_index=content_index,
delta=chunk_choice.delta.content,
item_id=message_item_id,
output_index=0,
output_index=message_output_index,
sequence_number=self.sequence_number,
)
@ -234,6 +549,32 @@ class StreamingResponseOrchestrator:
if chunk_choice.finish_reason:
chunk_finish_reason = chunk_choice.finish_reason
# Handle reasoning content if present (non-standard field for o1/o3 models)
if hasattr(chunk_choice.delta, "reasoning_content") and chunk_choice.delta.reasoning_content:
async for event in self._handle_reasoning_content_chunk(
reasoning_content=chunk_choice.delta.reasoning_content,
reasoning_part_emitted=reasoning_part_emitted,
reasoning_content_index=reasoning_content_index,
message_item_id=message_item_id,
message_output_index=message_output_index,
):
yield event
reasoning_part_emitted = True
reasoning_text_accumulated.append(chunk_choice.delta.reasoning_content)
# Handle refusal content if present
if chunk_choice.delta.refusal:
async for event in self._handle_refusal_content_chunk(
refusal_content=chunk_choice.delta.refusal,
refusal_part_emitted=refusal_part_emitted,
refusal_content_index=refusal_content_index,
message_item_id=message_item_id,
message_output_index=message_output_index,
):
yield event
refusal_part_emitted = True
refusal_text_accumulated.append(chunk_choice.delta.refusal)
# Aggregate tool call arguments across chunks
if chunk_choice.delta.tool_calls:
for tool_call in chunk_choice.delta.tool_calls:
@ -298,8 +639,11 @@ class StreamingResponseOrchestrator:
# Emit arguments.done events for completed tool calls (differentiate between MCP and function calls)
for tool_call_index in sorted(chat_response_tool_calls.keys()):
tool_call = chat_response_tool_calls[tool_call_index]
# Ensure that arguments, if sent back to the inference provider, are not None
tool_call.function.arguments = tool_call.function.arguments or "{}"
tool_call_item_id = tool_call_item_ids[tool_call_index]
final_arguments = chat_response_tool_calls[tool_call_index].function.arguments or ""
final_arguments = tool_call.function.arguments
tool_call_name = chat_response_tool_calls[tool_call_index].function.name
# Check if this is an MCP tool call
@ -322,14 +666,36 @@ class StreamingResponseOrchestrator:
final_text = "".join(chat_response_content)
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseContentPartDone(
content_index=content_index,
response_id=self.response_id,
item_id=message_item_id,
output_index=message_output_index,
part=OpenAIResponseContentPartOutputText(
text=final_text,
),
sequence_number=self.sequence_number,
)
# Emit reasoning done events if reasoning content was streamed
if reasoning_part_emitted:
async for event in self._emit_reasoning_done_events(
reasoning_text_accumulated=reasoning_text_accumulated,
reasoning_content_index=reasoning_content_index,
message_item_id=message_item_id,
message_output_index=message_output_index,
):
yield event
# Emit refusal done events if refusal content was streamed
if refusal_part_emitted:
async for event in self._emit_refusal_done_events(
refusal_text_accumulated=refusal_text_accumulated,
refusal_content_index=refusal_content_index,
message_item_id=message_item_id,
message_output_index=message_output_index,
):
yield event
# Clear content when there are tool calls (OpenAI spec behavior)
if chat_response_tool_calls:
chat_response_content = []
@ -414,6 +780,8 @@ class StreamingResponseOrchestrator:
tool_call_log = result.final_output_message
tool_response_message = result.final_input_message
self.sequence_number = result.sequence_number
if result.citation_files:
self.citation_files.update(result.citation_files)
if tool_call_log:
output_messages.append(tool_call_log)
@ -462,29 +830,21 @@ class StreamingResponseOrchestrator:
sequence_number=self.sequence_number,
)
async def _process_tools(
async def _process_new_tools(
self, tools: list[OpenAIResponseInputTool], output_messages: list[OpenAIResponseOutput]
) -> AsyncIterator[OpenAIResponseObjectStream]:
"""Process all tools and emit appropriate streaming events."""
from openai.types.chat import ChatCompletionToolParam
from llama_stack.apis.tools import Tool
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
from llama_stack.apis.tools import ToolDef
from llama_stack.models.llama.datatypes import ToolDefinition
from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
def make_openai_tool(tool_name: str, tool: Tool) -> ChatCompletionToolParam:
def make_openai_tool(tool_name: str, tool: ToolDef) -> ChatCompletionToolParam:
tool_def = ToolDefinition(
tool_name=tool_name,
description=tool.description,
parameters={
param.name: ToolParamDefinition(
param_type=param.parameter_type,
description=param.description,
required=param.required,
default=param.default,
)
for param in tool.parameters
},
input_schema=tool.input_schema,
)
return convert_tooldef_to_openai_tool(tool_def)
@ -525,7 +885,6 @@ class StreamingResponseOrchestrator:
yield OpenAIResponseObjectStreamResponseMcpListToolsInProgress(
sequence_number=self.sequence_number,
)
try:
# Parse allowed/never allowed tools
always_allowed = None
@ -538,14 +897,22 @@ class StreamingResponseOrchestrator:
never_allowed = mcp_tool.allowed_tools.never
# Call list_mcp_tools
tool_defs = await list_mcp_tools(
endpoint=mcp_tool.server_url,
headers=mcp_tool.headers or {},
)
tool_defs = None
list_id = f"mcp_list_{uuid.uuid4()}"
attributes = {
"server_label": mcp_tool.server_label,
"server_url": mcp_tool.server_url,
"mcp_list_tools_id": list_id,
}
async with tracing.span("list_mcp_tools", attributes):
tool_defs = await list_mcp_tools(
endpoint=mcp_tool.server_url,
headers=mcp_tool.headers or {},
)
# Create the MCP list tools message
mcp_list_message = OpenAIResponseOutputMessageMCPListTools(
id=f"mcp_list_{uuid.uuid4()}",
id=list_id,
server_label=mcp_tool.server_label,
tools=[],
)
@ -556,23 +923,7 @@ class StreamingResponseOrchestrator:
continue
if not always_allowed or t.name in always_allowed:
# Add to chat tools for inference
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
tool_def = ToolDefinition(
tool_name=t.name,
description=t.description,
parameters={
param.name: ToolParamDefinition(
param_type=param.parameter_type,
description=param.description,
required=param.required,
default=param.default,
)
for param in t.parameters
},
)
openai_tool = convert_tooldef_to_openai_tool(tool_def)
openai_tool = convert_tooldef_to_chat_tool(t)
if self.ctx.chat_tools is None:
self.ctx.chat_tools = []
self.ctx.chat_tools.append(openai_tool)
@ -587,48 +938,129 @@ class StreamingResponseOrchestrator:
MCPListToolsTool(
name=t.name,
description=t.description,
input_schema={
input_schema=t.input_schema
or {
"type": "object",
"properties": {
p.name: {
"type": p.parameter_type,
"description": p.description,
}
for p in t.parameters
},
"required": [p.name for p in t.parameters if p.required],
"properties": {},
"required": [],
},
)
)
# Add the MCP list message to output
output_messages.append(mcp_list_message)
# Emit output_item.added for the MCP list tools message
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseOutputItemAdded(
response_id=self.response_id,
item=mcp_list_message,
output_index=len(output_messages) - 1,
sequence_number=self.sequence_number,
)
# Emit mcp_list_tools.completed
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseMcpListToolsCompleted(
sequence_number=self.sequence_number,
)
# Emit output_item.done for the MCP list tools message
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseOutputItemDone(
response_id=self.response_id,
item=mcp_list_message,
output_index=len(output_messages) - 1,
sequence_number=self.sequence_number,
)
async for stream_event in self._add_mcp_list_tools(mcp_list_message, output_messages):
yield stream_event
except Exception as e:
# TODO: Emit mcp_list_tools.failed event if needed
logger.exception(f"Failed to list MCP tools from {mcp_tool.server_url}: {e}")
raise
async def _process_tools(
self, output_messages: list[OpenAIResponseOutput]
) -> AsyncIterator[OpenAIResponseObjectStream]:
# Handle all mcp tool lists from previous response that are still valid:
for tool in self.ctx.tool_context.previous_tool_listings:
async for evt in self._reuse_mcp_list_tools(tool, output_messages):
yield evt
# Process all remaining tools (including MCP tools) and emit streaming events
if self.ctx.tool_context.tools_to_process:
async for stream_event in self._process_new_tools(self.ctx.tool_context.tools_to_process, output_messages):
yield stream_event
def _approval_required(self, tool_name: str) -> bool:
if tool_name not in self.mcp_tool_to_server:
return False
mcp_server = self.mcp_tool_to_server[tool_name]
if mcp_server.require_approval == "always":
return True
if mcp_server.require_approval == "never":
return False
if isinstance(mcp_server, ApprovalFilter):
if tool_name in mcp_server.always:
return True
if tool_name in mcp_server.never:
return False
return True
async def _add_mcp_approval_request(
self, tool_name: str, arguments: str, output_messages: list[OpenAIResponseOutput]
) -> AsyncIterator[OpenAIResponseObjectStream]:
mcp_server = self.mcp_tool_to_server[tool_name]
mcp_approval_request = OpenAIResponseMCPApprovalRequest(
arguments=arguments,
id=f"approval_{uuid.uuid4()}",
name=tool_name,
server_label=mcp_server.server_label,
)
output_messages.append(mcp_approval_request)
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseOutputItemAdded(
response_id=self.response_id,
item=mcp_approval_request,
output_index=len(output_messages) - 1,
sequence_number=self.sequence_number,
)
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseOutputItemDone(
response_id=self.response_id,
item=mcp_approval_request,
output_index=len(output_messages) - 1,
sequence_number=self.sequence_number,
)
async def _add_mcp_list_tools(
self, mcp_list_message: OpenAIResponseOutputMessageMCPListTools, output_messages: list[OpenAIResponseOutput]
) -> AsyncIterator[OpenAIResponseObjectStream]:
# Add the MCP list message to output
output_messages.append(mcp_list_message)
# Emit output_item.added for the MCP list tools message
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseOutputItemAdded(
response_id=self.response_id,
item=mcp_list_message,
output_index=len(output_messages) - 1,
sequence_number=self.sequence_number,
)
# Emit mcp_list_tools.completed
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseMcpListToolsCompleted(
sequence_number=self.sequence_number,
)
# Emit output_item.done for the MCP list tools message
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseOutputItemDone(
response_id=self.response_id,
item=mcp_list_message,
output_index=len(output_messages) - 1,
sequence_number=self.sequence_number,
)
async def _reuse_mcp_list_tools(
self, original: OpenAIResponseOutputMessageMCPListTools, output_messages: list[OpenAIResponseOutput]
) -> AsyncIterator[OpenAIResponseObjectStream]:
for t in original.tools:
from llama_stack.models.llama.datatypes import ToolDefinition
from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
# rebuild a ToolDefinition from the stored input_schema...
tool_def = ToolDefinition(
tool_name=t.name,
description=t.description,
input_schema=t.input_schema,
)
# ...then convert that to an OpenAI chat-completions tool
openai_tool = convert_tooldef_to_openai_tool(tool_def)
if self.ctx.chat_tools is None:
self.ctx.chat_tools = []
self.ctx.chat_tools.append(openai_tool)
mcp_list_message = OpenAIResponseOutputMessageMCPListTools(
id=f"mcp_list_{uuid.uuid4()}",
server_label=original.server_label,
tools=original.tools,
)
async for stream_event in self._add_mcp_list_tools(mcp_list_message, output_messages):
yield stream_event

View file

@ -11,6 +11,9 @@ from collections.abc import AsyncIterator
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseInputToolFileSearch,
OpenAIResponseInputToolMCP,
OpenAIResponseObjectStreamResponseFileSearchCallCompleted,
OpenAIResponseObjectStreamResponseFileSearchCallInProgress,
OpenAIResponseObjectStreamResponseFileSearchCallSearching,
OpenAIResponseObjectStreamResponseMcpCallCompleted,
OpenAIResponseObjectStreamResponseMcpCallFailed,
OpenAIResponseObjectStreamResponseMcpCallInProgress,
@ -35,6 +38,7 @@ from llama_stack.apis.inference import (
from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.log import get_logger
from llama_stack.providers.utils.telemetry import tracing
from .types import ChatCompletionContext, ToolExecutionResult
@ -94,7 +98,10 @@ class ToolExecutor:
# Yield the final result
yield ToolExecutionResult(
sequence_number=sequence_number, final_output_message=output_message, final_input_message=input_message
sequence_number=sequence_number,
final_output_message=output_message,
final_input_message=input_message,
citation_files=result.metadata.get("citation_files") if result and result.metadata else None,
)
async def _execute_knowledge_search_via_vector_store(
@ -129,8 +136,6 @@ class ToolExecutor:
for results in all_results:
search_results.extend(results)
# Convert search results to tool result format matching memory.py
# Format the results as interleaved content similar to memory.py
content_items = []
content_items.append(
TextContentItem(
@ -138,27 +143,58 @@ class ToolExecutor:
)
)
unique_files = set()
for i, result_item in enumerate(search_results):
chunk_text = result_item.content[0].text if result_item.content else ""
metadata_text = f"document_id: {result_item.file_id}, score: {result_item.score}"
# Get file_id from attributes if result_item.file_id is empty
file_id = result_item.file_id or (
result_item.attributes.get("document_id") if result_item.attributes else None
)
metadata_text = f"document_id: {file_id}, score: {result_item.score}"
if result_item.attributes:
metadata_text += f", attributes: {result_item.attributes}"
text_content = f"[{i + 1}] {metadata_text}\n{chunk_text}\n"
text_content = f"[{i + 1}] {metadata_text} (cite as <|{file_id}|>)\n{chunk_text}\n"
content_items.append(TextContentItem(text=text_content))
unique_files.add(file_id)
content_items.append(TextContentItem(text="END of knowledge_search tool results.\n"))
citation_instruction = ""
if unique_files:
citation_instruction = (
" Cite sources immediately at the end of sentences before punctuation, using `<|file-id|>` format (e.g., 'This is a fact <|file-Cn3MSNn72ENTiiq11Qda4A|>.'). "
"Do not add extra punctuation. Use only the file IDs provided (do not invent new ones)."
)
content_items.append(
TextContentItem(
text=f'The above results were retrieved to help answer the user\'s query: "{query}". Use them as supporting information only in answering this query.\n',
text=f'The above results were retrieved to help answer the user\'s query: "{query}". Use them as supporting information only in answering this query.{citation_instruction}\n',
)
)
# handle missing attributes in results produced by older versions
citation_files = {}
for result in search_results:
file_id = result.file_id
if not file_id and result.attributes:
file_id = result.attributes.get("document_id")
filename = result.filename
if not filename and result.attributes:
filename = result.attributes.get("filename")
if not filename:
filename = "unknown"
citation_files[file_id] = filename
return ToolInvocationResult(
content=content_items,
metadata={
"document_ids": [r.file_id for r in search_results],
"chunks": [r.content[0].text if r.content else "" for r in search_results],
"scores": [r.score for r in search_results],
"citation_files": citation_files,
},
)
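For orientation, a minimal sketch of the tool result shape produced above, assuming two hypothetical search hits; the file IDs, filenames, chunks and scores are illustrative, and the citation_files map is what later feeds annotation extraction.
# Illustrative sketch only; all values below are hypothetical.
example_result = ToolInvocationResult(
    content=[TextContentItem(text="knowledge_search tool found 2 chunks:\n...")],
    metadata={
        "document_ids": ["file-abc123", "file-def456"],
        "chunks": ["first chunk text", "second chunk text"],
        "scores": [0.92, 0.85],
        "citation_files": {"file-abc123": "report.pdf", "file-def456": "notes.md"},
    },
)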
@ -188,7 +224,13 @@ class ToolExecutor:
output_index=output_index,
sequence_number=sequence_number,
)
# Note: knowledge_search and other custom tools don't have specific streaming events in OpenAI spec
elif function_name == "knowledge_search":
sequence_number += 1
progress_event = OpenAIResponseObjectStreamResponseFileSearchCallInProgress(
item_id=item_id,
output_index=output_index,
sequence_number=sequence_number,
)
if progress_event:
yield ToolExecutionResult(stream_event=progress_event, sequence_number=sequence_number)
@ -203,6 +245,16 @@ class ToolExecutor:
)
yield ToolExecutionResult(stream_event=searching_event, sequence_number=sequence_number)
# For file search, emit searching event
if function_name == "knowledge_search":
sequence_number += 1
searching_event = OpenAIResponseObjectStreamResponseFileSearchCallSearching(
item_id=item_id,
output_index=output_index,
sequence_number=sequence_number,
)
yield ToolExecutionResult(stream_event=searching_event, sequence_number=sequence_number)
async def _execute_tool(
self,
function_name: str,
@ -219,12 +271,18 @@ class ToolExecutor:
from llama_stack.providers.utils.tools.mcp import invoke_mcp_tool
mcp_tool = mcp_tool_to_server[function_name]
result = await invoke_mcp_tool(
endpoint=mcp_tool.server_url,
headers=mcp_tool.headers or {},
tool_name=function_name,
kwargs=tool_kwargs,
)
attributes = {
"server_label": mcp_tool.server_label,
"server_url": mcp_tool.server_url,
"tool_name": function_name,
}
async with tracing.span("invoke_mcp_tool", attributes):
result = await invoke_mcp_tool(
endpoint=mcp_tool.server_url,
headers=mcp_tool.headers or {},
tool_name=function_name,
kwargs=tool_kwargs,
)
elif function_name == "knowledge_search":
response_file_search_tool = next(
(t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)),
@ -234,15 +292,20 @@ class ToolExecutor:
# Use vector_stores.search API instead of knowledge_search tool
# to support filters and ranking_options
query = tool_kwargs.get("query", "")
result = await self._execute_knowledge_search_via_vector_store(
query=query,
response_file_search_tool=response_file_search_tool,
)
async with tracing.span("knowledge_search", {}):
result = await self._execute_knowledge_search_via_vector_store(
query=query,
response_file_search_tool=response_file_search_tool,
)
else:
result = await self.tool_runtime_api.invoke_tool(
tool_name=function_name,
kwargs=tool_kwargs,
)
attributes = {
"tool_name": function_name,
}
async with tracing.span("invoke_tool", attributes):
result = await self.tool_runtime_api.invoke_tool(
tool_name=function_name,
kwargs=tool_kwargs,
)
except Exception as e:
error_exc = e
@ -278,7 +341,13 @@ class ToolExecutor:
output_index=output_index,
sequence_number=sequence_number,
)
# Note: knowledge_search and other custom tools don't have specific completion events in OpenAI spec
elif function_name == "knowledge_search":
sequence_number += 1
completion_event = OpenAIResponseObjectStreamResponseFileSearchCallCompleted(
item_id=item_id,
output_index=output_index,
sequence_number=sequence_number,
)
if completion_event:
yield ToolExecutionResult(stream_event=completion_event, sequence_number=sequence_number)

View file

@ -10,9 +10,20 @@ from openai.types.chat import ChatCompletionToolParam
from pydantic import BaseModel
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseInput,
OpenAIResponseInputTool,
OpenAIResponseInputToolFileSearch,
OpenAIResponseInputToolFunction,
OpenAIResponseInputToolMCP,
OpenAIResponseInputToolWebSearch,
OpenAIResponseMCPApprovalRequest,
OpenAIResponseMCPApprovalResponse,
OpenAIResponseObject,
OpenAIResponseObjectStream,
OpenAIResponseOutput,
OpenAIResponseOutputMessageMCPListTools,
OpenAIResponseTool,
OpenAIResponseToolMCP,
)
from llama_stack.apis.inference import OpenAIChatCompletionToolCall, OpenAIMessageParam, OpenAIResponseFormatParam
@ -24,6 +35,7 @@ class ToolExecutionResult(BaseModel):
sequence_number: int
final_output_message: OpenAIResponseOutput | None = None
final_input_message: OpenAIMessageParam | None = None
citation_files: dict[str, str] | None = None
@dataclass
@ -51,6 +63,86 @@ class ChatCompletionResult:
return bool(self.tool_calls)
class ToolContext(BaseModel):
"""Holds information about tools from this and (if relevant)
previous response in order to facilitate reuse of previous
listings where appropriate."""
# tools argument passed into current request:
current_tools: list[OpenAIResponseInputTool]
# reconstructed map of tool -> mcp server from previous response:
previous_tools: dict[str, OpenAIResponseInputToolMCP]
# reusable mcp-list-tools objects from previous response:
previous_tool_listings: list[OpenAIResponseOutputMessageMCPListTools]
# tool arguments from current request that still need to be processed:
tools_to_process: list[OpenAIResponseInputTool]
def __init__(
self,
current_tools: list[OpenAIResponseInputTool] | None,
):
super().__init__(
current_tools=current_tools or [],
previous_tools={},
previous_tool_listings=[],
tools_to_process=current_tools or [],
)
def recover_tools_from_previous_response(
self,
previous_response: OpenAIResponseObject,
):
"""Determine which mcp_list_tools objects from previous response we can reuse."""
if self.current_tools and previous_response.tools:
previous_tools_by_label: dict[str, OpenAIResponseToolMCP] = {}
for tool in previous_response.tools:
if isinstance(tool, OpenAIResponseToolMCP):
previous_tools_by_label[tool.server_label] = tool
# collect tool definitions which are the same in current and previous requests:
tools_to_process = []
matched: dict[str, OpenAIResponseInputToolMCP] = {}
for tool in self.current_tools:
if isinstance(tool, OpenAIResponseInputToolMCP) and tool.server_label in previous_tools_by_label:
previous_tool = previous_tools_by_label[tool.server_label]
if previous_tool.allowed_tools == tool.allowed_tools:
matched[tool.server_label] = tool
else:
tools_to_process.append(tool)
else:
tools_to_process.append(tool)
# tools that are not the same or were not previously defined need to be processed:
self.tools_to_process = tools_to_process
# for all matched definitions, get the mcp_list_tools objects from the previous output:
self.previous_tool_listings = [
obj for obj in previous_response.output if obj.type == "mcp_list_tools" and obj.server_label in matched
]
# reconstruct the tool to server mappings that can be reused:
for listing in self.previous_tool_listings:
definition = matched[listing.server_label]
for tool in listing.tools:
self.previous_tools[tool.name] = definition
def available_tools(self) -> list[OpenAIResponseTool]:
if not self.current_tools:
return []
def convert_tool(tool: OpenAIResponseInputTool) -> OpenAIResponseTool:
if isinstance(tool, OpenAIResponseInputToolWebSearch):
return tool
if isinstance(tool, OpenAIResponseInputToolFileSearch):
return tool
if isinstance(tool, OpenAIResponseInputToolFunction):
return tool
if isinstance(tool, OpenAIResponseInputToolMCP):
return OpenAIResponseToolMCP(
server_label=tool.server_label,
allowed_tools=tool.allowed_tools,
)
return [convert_tool(tool) for tool in self.current_tools]
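A brief, hedged sketch of how this ToolContext is intended to be driven on a follow-up request; request_tools and previous_response are illustrative names for the current request's tools and the stored prior response.
# Illustrative usage; assumes request_tools: list[OpenAIResponseInputTool] and a stored previous_response.
tool_context = ToolContext(current_tools=request_tools)
if previous_response is not None:
    tool_context.recover_tools_from_previous_response(previous_response)
# tool_context.previous_tool_listings -> mcp_list_tools outputs that can be replayed as-is
# tool_context.tools_to_process -> tool definitions that still need a fresh list_mcp_tools call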
class ChatCompletionContext(BaseModel):
model: str
messages: list[OpenAIMessageParam]
@ -58,3 +150,45 @@ class ChatCompletionContext(BaseModel):
chat_tools: list[ChatCompletionToolParam] | None = None
temperature: float | None
response_format: OpenAIResponseFormatParam
tool_context: ToolContext | None
approval_requests: list[OpenAIResponseMCPApprovalRequest] = []
approval_responses: dict[str, OpenAIResponseMCPApprovalResponse] = {}
def __init__(
self,
model: str,
messages: list[OpenAIMessageParam],
response_tools: list[OpenAIResponseInputTool] | None,
temperature: float | None,
response_format: OpenAIResponseFormatParam,
tool_context: ToolContext,
inputs: list[OpenAIResponseInput] | str,
):
super().__init__(
model=model,
messages=messages,
response_tools=response_tools,
temperature=temperature,
response_format=response_format,
tool_context=tool_context,
)
if not isinstance(inputs, str):
self.approval_requests = [input for input in inputs if input.type == "mcp_approval_request"]
self.approval_responses = {
input.approval_request_id: input for input in inputs if input.type == "mcp_approval_response"
}
def approval_response(self, tool_name: str, arguments: str) -> OpenAIResponseMCPApprovalResponse | None:
request = self._approval_request(tool_name, arguments)
return self.approval_responses.get(request.id, None) if request else None
def _approval_request(self, tool_name: str, arguments: str) -> OpenAIResponseMCPApprovalRequest | None:
for request in self.approval_requests:
if request.name == tool_name and request.arguments == arguments:
return request
return None
def available_tools(self) -> list[OpenAIResponseTool]:
if not self.tool_context:
return []
return self.tool_context.available_tools()
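A short illustrative sketch of the approval lookup when an MCP tool call arrives; the tool name and arguments are hypothetical, and the approve field on the response object is assumed.
# Illustrative only; "get_weather" and its arguments are hypothetical.
approval = ctx.approval_response("get_weather", '{"city": "Paris"}')
if approval is None:
    pass  # no prior decision: emit an OpenAIResponseMCPApprovalRequest and defer the call
else:
    pass  # a decision exists; proceed or refuse based on it (e.g. approval.approve, field name assumed)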

View file

@ -4,15 +4,19 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import re
import uuid
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseAnnotationFileCitation,
OpenAIResponseInput,
OpenAIResponseInputFunctionToolCallOutput,
OpenAIResponseInputMessageContent,
OpenAIResponseInputMessageContentImage,
OpenAIResponseInputMessageContentText,
OpenAIResponseInputTool,
OpenAIResponseMCPApprovalRequest,
OpenAIResponseMCPApprovalResponse,
OpenAIResponseMessage,
OpenAIResponseOutputMessageContent,
OpenAIResponseOutputMessageContentOutputText,
@ -43,7 +47,12 @@ from llama_stack.apis.inference import (
)
async def convert_chat_choice_to_response_message(choice: OpenAIChoice) -> OpenAIResponseMessage:
async def convert_chat_choice_to_response_message(
choice: OpenAIChoice,
citation_files: dict[str, str] | None = None,
*,
message_id: str | None = None,
) -> OpenAIResponseMessage:
"""Convert an OpenAI Chat Completion choice into an OpenAI Response output message."""
output_content = ""
if isinstance(choice.message.content, str):
@ -55,9 +64,11 @@ async def convert_chat_choice_to_response_message(choice: OpenAIChoice) -> OpenA
f"Llama Stack OpenAI Responses does not yet support output content type: {type(choice.message.content)}"
)
annotations, clean_text = _extract_citations_from_text(output_content, citation_files or {})
return OpenAIResponseMessage(
id=f"msg_{uuid.uuid4()}",
content=[OpenAIResponseOutputMessageContentOutputText(text=output_content)],
id=message_id or f"msg_{uuid.uuid4()}",
content=[OpenAIResponseOutputMessageContentOutputText(text=clean_text, annotations=annotations)],
status="completed",
role="assistant",
)
@ -95,9 +106,13 @@ async def convert_response_content_to_chat_content(
async def convert_response_input_to_chat_messages(
input: str | list[OpenAIResponseInput],
previous_messages: list[OpenAIMessageParam] | None = None,
) -> list[OpenAIMessageParam]:
"""
Convert the input from an OpenAI Response API request into OpenAI Chat Completion messages.
:param input: The input to convert
:param previous_messages: Optional previous messages to check for function_call references
"""
messages: list[OpenAIMessageParam] = []
if isinstance(input, list):
@ -149,6 +164,11 @@ async def convert_response_input_to_chat_messages(
elif isinstance(input_item, OpenAIResponseOutputMessageMCPListTools):
# the tool list will be handled separately
pass
elif isinstance(input_item, OpenAIResponseMCPApprovalRequest) or isinstance(
input_item, OpenAIResponseMCPApprovalResponse
):
# these are handled by the responses impl itself and not passed through to chat completions
pass
else:
content = await convert_response_content_to_chat_content(input_item.content)
message_type = await get_message_type_by_role(input_item.role)
@ -156,16 +176,53 @@ async def convert_response_input_to_chat_messages(
raise ValueError(
f"Llama Stack OpenAI Responses does not yet support message role '{input_item.role}' in this context"
)
# Skip user messages that duplicate the last user message in previous_messages
# This handles cases where input includes context for function_call_outputs
if previous_messages and input_item.role == "user":
last_user_msg = None
for msg in reversed(previous_messages):
if isinstance(msg, OpenAIUserMessageParam):
last_user_msg = msg
break
if last_user_msg:
last_user_content = getattr(last_user_msg, "content", None)
if last_user_content == content:
continue # Skip duplicate user message
messages.append(message_type(content=content))
if len(tool_call_results):
raise ValueError(
f"Received function_call_output(s) with call_id(s) {tool_call_results.keys()}, but no corresponding function_call"
)
# Check if unpaired function_call_outputs reference function_calls from previous messages
if previous_messages:
previous_call_ids = _extract_tool_call_ids(previous_messages)
for call_id in list(tool_call_results.keys()):
if call_id in previous_call_ids:
# Valid: this output references a call from previous messages
# Add the tool message
messages.append(tool_call_results[call_id])
del tool_call_results[call_id]
# If still have unpaired outputs, error
if len(tool_call_results):
raise ValueError(
f"Received function_call_output(s) with call_id(s) {tool_call_results.keys()}, but no corresponding function_call"
)
else:
messages.append(OpenAIUserMessageParam(content=input))
return messages
def _extract_tool_call_ids(messages: list[OpenAIMessageParam]) -> set[str]:
"""Extract all tool_call IDs from messages."""
call_ids = set()
for msg in messages:
if isinstance(msg, OpenAIAssistantMessageParam):
tool_calls = getattr(msg, "tool_calls", None)
if tool_calls:
for tool_call in tool_calls:
# tool_call is a Pydantic model, use attribute access
call_ids.add(tool_call.id)
return call_ids
async def convert_response_text_to_chat_response_format(
text: OpenAIResponseText,
) -> OpenAIResponseFormatParam:
@ -193,6 +250,53 @@ async def get_message_type_by_role(role: str):
return role_to_type.get(role)
def _extract_citations_from_text(
text: str, citation_files: dict[str, str]
) -> tuple[list[OpenAIResponseAnnotationFileCitation], str]:
"""Extract citation markers from text and create annotations
Args:
text: The text containing citation markers like <|file-Cn3MSNn72ENTiiq11Qda4A|>
citation_files: Dictionary mapping file_id to filename
Returns:
Tuple of (annotations_list, clean_text_without_markers)
"""
file_id_regex = re.compile(r"<\|(?P<file_id>file-[A-Za-z0-9_-]+)\|>")
annotations = []
parts = []
total_len = 0
last_end = 0
for m in file_id_regex.finditer(text):
# segment before the marker
prefix = text[last_end : m.start()]
# drop one space if it exists (since marker is at sentence end)
if prefix.endswith(" "):
prefix = prefix[:-1]
parts.append(prefix)
total_len += len(prefix)
fid = m.group(1)
if fid in citation_files:
annotations.append(
OpenAIResponseAnnotationFileCitation(
file_id=fid,
filename=citation_files[fid],
index=total_len, # index points to punctuation
)
)
last_end = m.end()
parts.append(text[last_end:])
cleaned_text = "".join(parts)
return annotations, cleaned_text
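A brief usage sketch for the extraction helper above; the filename is hypothetical, while the file id follows the marker convention used earlier.
# Illustrative example; "geography.txt" is a hypothetical filename.
annotations, clean = _extract_citations_from_text(
    "Paris is the capital of France <|file-Cn3MSNn72ENTiiq11Qda4A|>.",
    {"file-Cn3MSNn72ENTiiq11Qda4A": "geography.txt"},
)
# clean == "Paris is the capital of France."
# annotations[0].filename == "geography.txt"; annotations[0].index points at the trailing "."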
def is_function_tool_call(
tool_call: OpenAIChatCompletionToolCall,
tools: list[OpenAIResponseInputTool],

View file

@ -22,6 +22,8 @@ from llama_stack.apis.files import Files, OpenAIFilePurpose
from llama_stack.apis.inference import (
Inference,
OpenAIAssistantMessageParam,
OpenAIChatCompletionRequestWithExtraBody,
OpenAICompletionRequestWithExtraBody,
OpenAIDeveloperMessageParam,
OpenAIMessageParam,
OpenAISystemMessageParam,
@ -178,9 +180,9 @@ class ReferenceBatchesImpl(Batches):
# TODO: set expiration time for garbage collection
if endpoint not in ["/v1/chat/completions", "/v1/completions"]:
if endpoint not in ["/v1/chat/completions", "/v1/completions", "/v1/embeddings"]:
raise ValueError(
f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions, /v1/completions. Code: invalid_value. Param: endpoint",
f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions, /v1/completions, /v1/embeddings. Code: invalid_value. Param: endpoint",
)
if completion_window != "24h":
@ -425,18 +427,23 @@ class ReferenceBatchesImpl(Batches):
valid = False
if batch.endpoint == "/v1/chat/completions":
required_params = [
required_params: list[tuple[str, Any, str]] = [
("model", str, "a string"),
# messages is specific to /v1/chat/completions
# we could skip validating messages here and let inference fail. however,
# that would be a very expensive way to find out the messages are wrong.
("messages", list, "an array"), # TODO: allow messages to be a string?
]
else: # /v1/completions
elif batch.endpoint == "/v1/completions":
required_params = [
("model", str, "a string"),
("prompt", str, "a string"), # TODO: allow prompt to be a list of strings??
]
else: # /v1/embeddings
required_params = [
("model", str, "a string"),
("input", (str, list), "a string or array of strings"),
]
for param, expected_type, type_string in required_params:
if param not in body:
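For context, a hedged example of a single /v1/embeddings batch request body that would pass the validation above; the model name is illustrative.
# Illustrative request body for one /v1/embeddings batch line ("all-MiniLM-L6-v2" is a hypothetical model id).
example_embeddings_body = {"model": "all-MiniLM-L6-v2", "input": ["hello world", "second string"]}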
@ -601,7 +608,8 @@ class ReferenceBatchesImpl(Batches):
# TODO(SECURITY): review body for security issues
if request.url == "/v1/chat/completions":
request.body["messages"] = [convert_to_openai_message_param(msg) for msg in request.body["messages"]]
chat_response = await self.inference_api.openai_chat_completion(**request.body)
chat_params = OpenAIChatCompletionRequestWithExtraBody(**request.body)
chat_response = await self.inference_api.openai_chat_completion(chat_params)
# this is for mypy, we don't allow streaming so we'll get the right type
assert hasattr(chat_response, "model_dump_json"), "Chat response must have model_dump_json method"
@ -614,8 +622,9 @@ class ReferenceBatchesImpl(Batches):
"body": chat_response.model_dump_json(),
},
}
else: # /v1/completions
completion_response = await self.inference_api.openai_completion(**request.body)
elif request.url == "/v1/completions":
completion_params = OpenAICompletionRequestWithExtraBody(**request.body)
completion_response = await self.inference_api.openai_completion(completion_params)
# this is for mypy, we don't allow streaming so we'll get the right type
assert hasattr(completion_response, "model_dump_json"), (
@ -630,6 +639,20 @@ class ReferenceBatchesImpl(Batches):
"body": completion_response.model_dump_json(),
},
}
else: # /v1/embeddings
embeddings_response = await self.inference_api.openai_embeddings(**request.body)
assert hasattr(embeddings_response, "model_dump_json"), (
"Embeddings response must have model_dump_json method"
)
return {
"id": request_id,
"custom_id": request.custom_id,
"response": {
"status_code": 200,
"request_id": request_id, # TODO: should this be different?
"body": embeddings_response.model_dump_json(),
},
}
except Exception as e:
logger.info(f"Error processing request {request.custom_id} in batch {batch_id}: {e}")
return {

View file

@ -12,7 +12,14 @@ from llama_stack.apis.agents import Agents, StepType
from llama_stack.apis.benchmarks import Benchmark
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.inference import Inference, SystemMessage, UserMessage
from llama_stack.apis.inference import (
Inference,
OpenAIChatCompletionRequestWithExtraBody,
OpenAICompletionRequestWithExtraBody,
OpenAISystemMessageParam,
OpenAIUserMessageParam,
UserMessage,
)
from llama_stack.apis.scoring import Scoring
from llama_stack.providers.datatypes import BenchmarksProtocolPrivate
from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
@ -159,31 +166,42 @@ class MetaReferenceEvalImpl(
) -> list[dict[str, Any]]:
candidate = benchmark_config.eval_candidate
assert candidate.sampling_params.max_tokens is not None, "SamplingParams.max_tokens must be provided"
sampling_params = {"max_tokens": candidate.sampling_params.max_tokens}
generations = []
for x in tqdm(input_rows):
if ColumnName.completion_input.value in x:
if candidate.sampling_params.stop:
sampling_params["stop"] = candidate.sampling_params.stop
input_content = json.loads(x[ColumnName.completion_input.value])
response = await self.inference_api.completion(
params = OpenAICompletionRequestWithExtraBody(
model=candidate.model,
content=input_content,
sampling_params=candidate.sampling_params,
prompt=input_content,
**sampling_params,
)
generations.append({ColumnName.generated_answer.value: response.completion_message.content})
response = await self.inference_api.openai_completion(params)
generations.append({ColumnName.generated_answer.value: response.choices[0].text})
elif ColumnName.chat_completion_input.value in x:
chat_completion_input_json = json.loads(x[ColumnName.chat_completion_input.value])
input_messages = [UserMessage(**x) for x in chat_completion_input_json if x["role"] == "user"]
input_messages = [
OpenAIUserMessageParam(**x) for x in chat_completion_input_json if x["role"] == "user"
]
messages = []
if candidate.system_message:
messages.append(candidate.system_message)
messages += [SystemMessage(**x) for x in chat_completion_input_json if x["role"] == "system"]
messages += [OpenAISystemMessageParam(**x) for x in chat_completion_input_json if x["role"] == "system"]
messages += input_messages
response = await self.inference_api.chat_completion(
model_id=candidate.model,
params = OpenAIChatCompletionRequestWithExtraBody(
model=candidate.model,
messages=messages,
sampling_params=candidate.sampling_params,
**sampling_params,
)
generations.append({ColumnName.generated_answer.value: response.completion_message.content})
response = await self.inference_api.openai_chat_completion(params)
generations.append({ColumnName.generated_answer.value: response.choices[0].message.content})
else:
raise ValueError("Invalid input row")

View file

@ -9,11 +9,12 @@ import uuid
from pathlib import Path
from typing import Annotated
from fastapi import File, Form, Response, UploadFile
from fastapi import Depends, File, Form, Response, UploadFile
from llama_stack.apis.common.errors import ResourceNotFoundError
from llama_stack.apis.common.responses import Order
from llama_stack.apis.files import (
ExpiresAfter,
Files,
ListOpenAIFileResponse,
OpenAIFileDeleteResponse,
@ -21,7 +22,9 @@ from llama_stack.apis.files import (
OpenAIFilePurpose,
)
from llama_stack.core.datatypes import AccessRule
from llama_stack.core.id_generation import generate_object_id
from llama_stack.log import get_logger
from llama_stack.providers.utils.files.form_data import parse_expires_after
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl
@ -44,7 +47,7 @@ class LocalfsFilesImpl(Files):
storage_path.mkdir(parents=True, exist_ok=True)
# Initialize SQL store for metadata
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.config.metadata_store))
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.config.metadata_store), self.policy)
await self.sql_store.create_table(
"openai_files",
{
@ -63,7 +66,7 @@ class LocalfsFilesImpl(Files):
def _generate_file_id(self) -> str:
"""Generate a unique file ID for OpenAI API."""
return f"file-{uuid.uuid4().hex}"
return generate_object_id("file", lambda: f"file-{uuid.uuid4().hex}")
def _get_file_path(self, file_id: str) -> Path:
"""Get the filesystem path for a file ID."""
@ -74,7 +77,7 @@ class LocalfsFilesImpl(Files):
if not self.sql_store:
raise RuntimeError("Files provider not initialized")
row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id})
row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
if not row:
raise ResourceNotFoundError(file_id, "File", "client.files.list()")
@ -86,15 +89,16 @@ class LocalfsFilesImpl(Files):
self,
file: Annotated[UploadFile, File()],
purpose: Annotated[OpenAIFilePurpose, Form()],
expires_after_anchor: Annotated[str | None, Form(alias="expires_after[anchor]")] = None,
expires_after_seconds: Annotated[int | None, Form(alias="expires_after[seconds]")] = None,
expires_after: Annotated[ExpiresAfter | None, Depends(parse_expires_after)] = None,
) -> OpenAIFileObject:
"""Upload a file that can be used across various endpoints."""
if not self.sql_store:
raise RuntimeError("Files provider not initialized")
if expires_after_anchor is not None or expires_after_seconds is not None:
raise NotImplementedError("File expiration is not supported by this provider")
if expires_after is not None:
logger.warning(
f"File expiration is not supported by this provider, ignoring expires_after: {expires_after}"
)
file_id = self._generate_file_id()
file_path = self._get_file_path(file_id)
@ -150,7 +154,6 @@ class LocalfsFilesImpl(Files):
paginated_result = await self.sql_store.fetch_all(
table="openai_files",
policy=self.policy,
where=where_conditions if where_conditions else None,
order_by=[("created_at", order.value)],
cursor=("id", after) if after else None,

View file

@ -18,7 +18,7 @@ def model_checkpoint_dir(model_id) -> str:
assert checkpoint_dir.exists(), (
f"Could not find checkpoints in: {model_local_dir(model_id)}. "
f"If you try to use the native llama model, Please download model using `llama download --model-id {model_id}`"
f"Otherwise, please save you model checkpoint under {model_local_dir(model_id)}"
f"If you try to use the native llama model, please download the model using `llama-model download --source meta --model-id {model_id}` (see https://github.com/meta-llama/llama-models). "
f"Otherwise, please save your model checkpoint under {model_local_dir(model_id)}"
)
return str(checkpoint_dir)

View file

@ -5,43 +5,17 @@
# the root directory of this source tree.
import asyncio
import os
import sys
from collections.abc import AsyncGenerator
from collections.abc import AsyncIterator
from pydantic import BaseModel
from termcolor import cprint
from llama_stack.apis.common.content_types import (
TextDelta,
ToolCallDelta,
ToolCallParseStatus,
)
from llama_stack.apis.inference import (
BatchChatCompletionResponse,
BatchCompletionResponse,
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseEvent,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
CompletionMessage,
CompletionRequest,
CompletionResponse,
CompletionResponseStreamChunk,
InferenceProvider,
InterleavedContent,
LogProbConfig,
Message,
ResponseFormat,
SamplingParams,
StopReason,
TokenLogProbs,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
UserMessage,
OpenAIChatCompletionRequestWithExtraBody,
OpenAICompletionRequestWithExtraBody,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.log import get_logger
@ -59,15 +33,6 @@ from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
build_hf_repo_model_entry,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
augment_content_with_response_format_prompt,
chat_completion_request_to_messages,
convert_request_to_raw,
)
from .config import MetaReferenceInferenceConfig
from .generators import LlamaGenerator
@ -84,8 +49,6 @@ def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_
class MetaReferenceInferenceImpl(
OpenAICompletionToLlamaStackMixin,
OpenAIChatCompletionToLlamaStackMixin,
SentenceTransformerEmbeddingMixin,
InferenceProvider,
ModelsProtocolPrivate,
@ -102,6 +65,12 @@ class MetaReferenceInferenceImpl(
if self.config.create_distributed_process_group:
self.generator.stop()
async def openai_completion(
self,
params: OpenAICompletionRequestWithExtraBody,
) -> OpenAICompletion:
raise NotImplementedError("OpenAI completion not supported by meta reference provider")
async def should_refresh_models(self) -> bool:
return False
@ -167,15 +136,10 @@ class MetaReferenceInferenceImpl(
self.llama_model = llama_model
log.info("Warming up...")
await self.completion(
model_id=model_id,
content="Hello, world!",
sampling_params=SamplingParams(max_tokens=10),
)
await self.chat_completion(
model_id=model_id,
messages=[UserMessage(content="Hi how are you?")],
sampling_params=SamplingParams(max_tokens=20),
await self.openai_chat_completion(
model=model_id,
messages=[{"role": "user", "content": "Hi how are you?"}],
max_tokens=20,
)
log.info("Warmed up!")
@ -187,451 +151,8 @@ class MetaReferenceInferenceImpl(
elif request.model != self.model_id:
raise RuntimeError(f"Model mismatch: request model: {request.model} != loaded model: {self.model_id}")
async def completion(
async def openai_chat_completion(
self,
model_id: str,
content: InterleavedContent,
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
) -> CompletionResponse | CompletionResponseStreamChunk:
if sampling_params is None:
sampling_params = SamplingParams()
if logprobs:
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
content = augment_content_with_response_format_prompt(response_format, content)
request = CompletionRequest(
model=model_id,
content=content,
sampling_params=sampling_params,
response_format=response_format,
stream=stream,
logprobs=logprobs,
)
self.check_model(request)
request = await convert_request_to_raw(request)
if request.stream:
return self._stream_completion(request)
else:
results = await self._nonstream_completion([request])
return results[0]
async def batch_completion(
self,
model_id: str,
content_batch: list[InterleavedContent],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
) -> BatchCompletionResponse:
if sampling_params is None:
sampling_params = SamplingParams()
if logprobs:
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
content_batch = [
augment_content_with_response_format_prompt(response_format, content) for content in content_batch
]
request_batch = []
for content in content_batch:
request = CompletionRequest(
model=model_id,
content=content,
sampling_params=sampling_params,
response_format=response_format,
stream=stream,
logprobs=logprobs,
)
self.check_model(request)
request = await convert_request_to_raw(request)
request_batch.append(request)
results = await self._nonstream_completion(request_batch)
return BatchCompletionResponse(batch=results)
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
tokenizer = self.generator.formatter.tokenizer
def impl():
stop_reason = None
for token_results in self.generator.completion([request]):
token_result = token_results[0]
if token_result.token == tokenizer.eot_id:
stop_reason = StopReason.end_of_turn
text = ""
elif token_result.token == tokenizer.eom_id:
stop_reason = StopReason.end_of_message
text = ""
else:
text = token_result.text
logprobs = None
if stop_reason is None:
if request.logprobs:
assert len(token_result.logprobs) == 1
logprobs = [TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]})]
yield CompletionResponseStreamChunk(
delta=text,
stop_reason=stop_reason,
logprobs=logprobs if request.logprobs else None,
)
if stop_reason is None:
yield CompletionResponseStreamChunk(
delta="",
stop_reason=StopReason.out_of_tokens,
)
if self.config.create_distributed_process_group:
async with SEMAPHORE:
for x in impl():
yield x
else:
for x in impl():
yield x
async def _nonstream_completion(self, request_batch: list[CompletionRequest]) -> list[CompletionResponse]:
tokenizer = self.generator.formatter.tokenizer
first_request = request_batch[0]
class ItemState(BaseModel):
tokens: list[int] = []
logprobs: list[TokenLogProbs] = []
stop_reason: StopReason | None = None
finished: bool = False
def impl():
states = [ItemState() for _ in request_batch]
results = []
for token_results in self.generator.completion(request_batch):
for result in token_results:
idx = result.batch_idx
state = states[idx]
if state.finished or result.ignore_token:
continue
state.finished = result.finished
if first_request.logprobs:
state.logprobs.append(TokenLogProbs(logprobs_by_token={result.text: result.logprobs[0]}))
state.tokens.append(result.token)
if result.token == tokenizer.eot_id:
state.stop_reason = StopReason.end_of_turn
elif result.token == tokenizer.eom_id:
state.stop_reason = StopReason.end_of_message
for state in states:
if state.stop_reason is None:
state.stop_reason = StopReason.out_of_tokens
if state.tokens[-1] in self.generator.formatter.tokenizer.stop_tokens:
state.tokens = state.tokens[:-1]
content = self.generator.formatter.tokenizer.decode(state.tokens)
results.append(
CompletionResponse(
content=content,
stop_reason=state.stop_reason,
logprobs=state.logprobs if first_request.logprobs else None,
)
)
return results
if self.config.create_distributed_process_group:
async with SEMAPHORE:
return impl()
else:
return impl()
async def chat_completion(
self,
model_id: str,
messages: list[Message],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
if logprobs:
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
# wrapper request to make it easier to pass around (internal only, not exposed to API)
request = ChatCompletionRequest(
model=model_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
response_format=response_format,
stream=stream,
logprobs=logprobs,
tool_config=tool_config or ToolConfig(),
)
self.check_model(request)
# augment and rewrite messages depending on the model
request.messages = chat_completion_request_to_messages(request, self.llama_model.core_model_id.value)
# download media and convert to raw content so we can send it to the model
request = await convert_request_to_raw(request)
if self.config.create_distributed_process_group:
if SEMAPHORE.locked():
raise RuntimeError("Only one concurrent request is supported")
if request.stream:
return self._stream_chat_completion(request)
else:
results = await self._nonstream_chat_completion([request])
return results[0]
async def batch_chat_completion(
self,
model_id: str,
messages_batch: list[list[Message]],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
tools: list[ToolDefinition] | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> BatchChatCompletionResponse:
if sampling_params is None:
sampling_params = SamplingParams()
if logprobs:
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
# wrapper request to make it easier to pass around (internal only, not exposed to API)
request_batch = []
for messages in messages_batch:
request = ChatCompletionRequest(
model=model_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
response_format=response_format,
logprobs=logprobs,
tool_config=tool_config or ToolConfig(),
)
self.check_model(request)
# augment and rewrite messages depending on the model
request.messages = chat_completion_request_to_messages(request, self.llama_model.core_model_id.value)
# download media and convert to raw content so we can send it to the model
request = await convert_request_to_raw(request)
request_batch.append(request)
if self.config.create_distributed_process_group:
if SEMAPHORE.locked():
raise RuntimeError("Only one concurrent request is supported")
results = await self._nonstream_chat_completion(request_batch)
return BatchChatCompletionResponse(batch=results)
async def _nonstream_chat_completion(
self, request_batch: list[ChatCompletionRequest]
) -> list[ChatCompletionResponse]:
tokenizer = self.generator.formatter.tokenizer
first_request = request_batch[0]
class ItemState(BaseModel):
tokens: list[int] = []
logprobs: list[TokenLogProbs] = []
stop_reason: StopReason | None = None
finished: bool = False
def impl():
states = [ItemState() for _ in request_batch]
for token_results in self.generator.chat_completion(request_batch):
first = token_results[0]
if not first.finished and not first.ignore_token:
if os.environ.get("LLAMA_MODELS_DEBUG", "0") in ("1", "2"):
cprint(first.text, color="cyan", end="", file=sys.stderr)
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2":
cprint(f"<{first.token}>", color="magenta", end="", file=sys.stderr)
for result in token_results:
idx = result.batch_idx
state = states[idx]
if state.finished or result.ignore_token:
continue
state.finished = result.finished
if first_request.logprobs:
state.logprobs.append(TokenLogProbs(logprobs_by_token={result.text: result.logprobs[0]}))
state.tokens.append(result.token)
if result.token == tokenizer.eot_id:
state.stop_reason = StopReason.end_of_turn
elif result.token == tokenizer.eom_id:
state.stop_reason = StopReason.end_of_message
results = []
for state in states:
if state.stop_reason is None:
state.stop_reason = StopReason.out_of_tokens
raw_message = self.generator.formatter.decode_assistant_message(state.tokens, state.stop_reason)
results.append(
ChatCompletionResponse(
completion_message=CompletionMessage(
content=raw_message.content,
stop_reason=raw_message.stop_reason,
tool_calls=raw_message.tool_calls,
),
logprobs=state.logprobs if first_request.logprobs else None,
)
)
return results
if self.config.create_distributed_process_group:
async with SEMAPHORE:
return impl()
else:
return impl()
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
tokenizer = self.generator.formatter.tokenizer
def impl():
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta=TextDelta(text=""),
)
)
tokens = []
logprobs = []
stop_reason = None
ipython = False
for token_results in self.generator.chat_completion([request]):
token_result = token_results[0]
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1":
cprint(token_result.text, color="cyan", end="", file=sys.stderr)
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2":
cprint(f"<{token_result.token}>", color="magenta", end="", file=sys.stderr)
if token_result.token == tokenizer.eot_id:
stop_reason = StopReason.end_of_turn
text = ""
elif token_result.token == tokenizer.eom_id:
stop_reason = StopReason.end_of_message
text = ""
else:
text = token_result.text
if request.logprobs:
assert len(token_result.logprobs) == 1
logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
tokens.append(token_result.token)
if not ipython and token_result.text.startswith("<|python_tag|>"):
ipython = True
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
tool_call="",
parse_status=ToolCallParseStatus.started,
),
)
)
continue
if token_result.token == tokenizer.eot_id:
stop_reason = StopReason.end_of_turn
text = ""
elif token_result.token == tokenizer.eom_id:
stop_reason = StopReason.end_of_message
text = ""
else:
text = token_result.text
if ipython:
delta = ToolCallDelta(
tool_call=text,
parse_status=ToolCallParseStatus.in_progress,
)
else:
delta = TextDelta(text=text)
if stop_reason is None:
if request.logprobs:
assert len(token_result.logprobs) == 1
logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=delta,
stop_reason=stop_reason,
logprobs=logprobs if request.logprobs else None,
)
)
if stop_reason is None:
stop_reason = StopReason.out_of_tokens
message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
parsed_tool_calls = len(message.tool_calls) > 0
if ipython and not parsed_tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
tool_call="",
parse_status=ToolCallParseStatus.failed,
),
stop_reason=stop_reason,
)
)
for tool_call in message.tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
tool_call=tool_call,
parse_status=ToolCallParseStatus.succeeded,
),
stop_reason=stop_reason,
)
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta=TextDelta(text=""),
stop_reason=stop_reason,
)
)
if self.config.create_distributed_process_group:
async with SEMAPHORE:
for x in impl():
yield x
else:
for x in impl():
yield x
params: OpenAIChatCompletionRequestWithExtraBody,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
raise NotImplementedError("OpenAI chat completion not supported by meta-reference inference provider")

View file

@ -27,8 +27,6 @@ class ModelRunner:
def __call__(self, task: Any):
if task[0] == "chat_completion":
return self.llama.chat_completion(task[1])
elif task[0] == "completion":
return self.llama.completion(task[1])
else:
raise ValueError(f"Unexpected task type {task[0]}")

View file

@ -4,19 +4,17 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import AsyncGenerator
from collections.abc import AsyncIterator
from llama_stack.apis.inference import (
CompletionResponse,
InferenceProvider,
LogProbConfig,
Message,
ResponseFormat,
SamplingParams,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
OpenAIChatCompletionRequestWithExtraBody,
OpenAICompletionRequestWithExtraBody,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
)
from llama_stack.apis.models import ModelType
from llama_stack.log import get_logger
@ -26,7 +24,6 @@ from llama_stack.providers.utils.inference.embedding_mixin import (
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
)
from .config import SentenceTransformersInferenceConfig
@ -36,7 +33,6 @@ log = get_logger(name=__name__, category="inference")
class SentenceTransformersInferenceImpl(
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
SentenceTransformerEmbeddingMixin,
InferenceProvider,
ModelsProtocolPrivate,
@ -74,28 +70,14 @@ class SentenceTransformersInferenceImpl(
async def unregister_model(self, model_id: str) -> None:
pass
async def completion(
async def openai_completion(
self,
model_id: str,
content: str,
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
) -> CompletionResponse | AsyncGenerator:
raise ValueError("Sentence transformers don't support completion")
params: OpenAICompletionRequestWithExtraBody,
) -> OpenAICompletion:
raise NotImplementedError("OpenAI completion not supported by sentence transformers provider")
async def chat_completion(
async def openai_chat_completion(
self,
model_id: str,
messages: list[Message],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> AsyncGenerator:
raise ValueError("Sentence transformers don't support chat completion")
params: OpenAIChatCompletionRequestWithExtraBody,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
raise NotImplementedError("OpenAI chat completion not supported by sentence transformers provider")

View file

@ -68,9 +68,7 @@ public class FunctionTagCustomToolGenerator {
{
"name": "{{t.tool_name}}",
"description": "{{t.description}}",
"parameters": {
"type": "dict",
"properties": { {{t.parameters}} }
"input_schema": { {{t.input_schema}} }
}
{{/let}}

View file

@ -104,9 +104,10 @@ class LoraFinetuningSingleDevice:
if not any(p.exists() for p in paths):
checkpoint_dir = checkpoint_dir / "original"
hf_repo = model.huggingface_repo or f"meta-llama/{model.descriptor()}"
assert checkpoint_dir.exists(), (
f"Could not find checkpoints in: {model_local_dir(model.descriptor())}. "
f"Please download model using `llama download --model-id {model.descriptor()}`"
f"Please download the model using `huggingface-cli download {hf_repo} --local-dir ~/.llama/{model.descriptor()}`"
)
return str(checkpoint_dir)

View file

@ -10,7 +10,13 @@ from string import Template
from typing import Any
from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
from llama_stack.apis.inference import Inference, Message, UserMessage
from llama_stack.apis.inference import (
Inference,
Message,
OpenAIChatCompletionRequestWithExtraBody,
OpenAIUserMessageParam,
UserMessage,
)
from llama_stack.apis.safety import (
RunShieldResponse,
Safety,
@ -290,20 +296,21 @@ class LlamaGuardShield:
else:
shield_input_message = self.build_text_shield_input(messages)
# TODO: llama-stack inference protocol has issues with non-streaming inference code
response = await self.inference_api.chat_completion(
model_id=self.model,
params = OpenAIChatCompletionRequestWithExtraBody(
model=self.model,
messages=[shield_input_message],
stream=False,
temperature=0.0, # default is 1, which is too high for safety
)
content = response.completion_message.content
response = await self.inference_api.openai_chat_completion(params)
content = response.choices[0].message.content
content = content.strip()
return self.get_shield_response(content)
def build_text_shield_input(self, messages: list[Message]) -> UserMessage:
return UserMessage(content=self.build_prompt(messages))
def build_text_shield_input(self, messages: list[Message]) -> OpenAIUserMessageParam:
return OpenAIUserMessageParam(role="user", content=self.build_prompt(messages))
def build_vision_shield_input(self, messages: list[Message]) -> UserMessage:
def build_vision_shield_input(self, messages: list[Message]) -> OpenAIUserMessageParam:
conversation = []
most_recent_img = None
@ -335,7 +342,7 @@ class LlamaGuardShield:
prompt.append(most_recent_img)
prompt.append(self.build_prompt(conversation[::-1]))
return UserMessage(content=prompt)
return OpenAIUserMessageParam(role="user", content=prompt)
def build_prompt(self, messages: list[Message]) -> str:
categories = self.get_safety_categories()
@ -377,11 +384,12 @@ class LlamaGuardShield:
# TODO: Add Image based support for OpenAI Moderations
shield_input_message = self.build_text_shield_input(messages)
response = await self.inference_api.openai_chat_completion(
params = OpenAIChatCompletionRequestWithExtraBody(
model=self.model,
messages=[shield_input_message],
stream=False,
)
response = await self.inference_api.openai_chat_completion(params)
content = response.choices[0].message.content
content = content.strip()
return self.get_moderation_object(content)

View file

@ -6,7 +6,7 @@
import re
from typing import Any
from llama_stack.apis.inference import Inference, UserMessage
from llama_stack.apis.inference import Inference, OpenAIChatCompletionRequestWithExtraBody
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
@ -55,15 +55,17 @@ class LlmAsJudgeScoringFn(RegisteredBaseScoringFn):
generated_answer=generated_answer,
)
judge_response = await self.inference_api.chat_completion(
model_id=fn_def.params.judge_model,
params = OpenAIChatCompletionRequestWithExtraBody(
model=fn_def.params.judge_model,
messages=[
UserMessage(
content=judge_input_msg,
),
{
"role": "user",
"content": judge_input_msg,
}
],
)
content = judge_response.completion_message.content
judge_response = await self.inference_api.openai_chat_completion(params)
content = judge_response.choices[0].message.content
rating_regexes = fn_def.params.judge_score_regexes
judge_rating = None
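A hedged sketch of how `judge_score_regexes` might be applied to the judge model's reply; the actual scoring function may differ, this only illustrates the intent of the lines above.

```python
import re


def extract_rating(content: str, rating_regexes: list[str]) -> int | None:
    """Return the first numeric rating matched by any of the configured regexes."""
    for pattern in rating_regexes:
        if match := re.search(pattern, content):
            return int(match.group(1))  # assumes the first capture group holds the score
    return None


extract_rating("Rating: [[4]]", [r"Rating: \[\[(\d+)\]\]"])  # -> 4
```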

View file

@ -30,7 +30,7 @@ class TelemetryConfig(BaseModel):
description="The service name to use for telemetry",
)
sinks: list[TelemetrySink] = Field(
default=[TelemetrySink.CONSOLE, TelemetrySink.SQLITE],
default=[TelemetrySink.SQLITE],
description="List of telemetry sinks to enable (possible values: otel_trace, otel_metric, sqlite, console)",
)
sqlite_db_path: str = Field(
@ -49,7 +49,7 @@ class TelemetryConfig(BaseModel):
def sample_run_config(cls, __distro_dir__: str, db_name: str = "trace_store.db") -> dict[str, Any]:
return {
"service_name": "${env.OTEL_SERVICE_NAME:=\u200b}",
"sinks": "${env.TELEMETRY_SINKS:=console,sqlite}",
"sinks": "${env.TELEMETRY_SINKS:=sqlite}",
"sqlite_db_path": "${env.SQLITE_STORE_DIR:=" + __distro_dir__ + "}/" + db_name,
"otel_exporter_otlp_endpoint": "${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}",
}
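Because the default sink list is now sqlite-only, console output has to be requested explicitly. A hedged sketch (the import path is assumed from the provider layout) of doing so in code; the same effect comes from `TELEMETRY_SINKS=console,sqlite` in the run config.

```python
from llama_stack.providers.inline.telemetry.meta_reference.config import (
    TelemetryConfig,
    TelemetrySink,
)

# Opt back into console logging alongside the sqlite trace store.
config = TelemetryConfig(sinks=[TelemetrySink.CONSOLE, TelemetrySink.SQLITE])
```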

View file

@ -130,11 +130,9 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
trace.get_tracer_provider().force_flush()
async def log_event(self, event: Event, ttl_seconds: int = 604800) -> None:
logger.debug(f"DEBUG: log_event called with event type: {type(event).__name__}")
if isinstance(event, UnstructuredLogEvent):
self._log_unstructured(event, ttl_seconds)
elif isinstance(event, MetricEvent):
logger.debug("DEBUG: Routing MetricEvent to _log_metric")
self._log_metric(event)
elif isinstance(event, StructuredLogEvent):
self._log_structured(event, ttl_seconds)
@ -224,10 +222,6 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
return _GLOBAL_STORAGE["gauges"][name]
def _log_metric(self, event: MetricEvent) -> None:
# Always log to console if console sink is enabled (debug)
if TelemetrySink.CONSOLE in self.config.sinks:
logger.debug(f"METRIC: {event.metric}={event.value} {event.unit} {event.attributes}")
# Add metric as an event to the current span
try:
with self._lock:

View file

@ -8,7 +8,7 @@
from jinja2 import Template
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import OpenAIUserMessageParam
from llama_stack.apis.inference import OpenAIChatCompletionRequestWithExtraBody, OpenAIUserMessageParam
from llama_stack.apis.tools.rag_tool import (
DefaultRAGQueryGeneratorConfig,
LLMRAGQueryGeneratorConfig,
@ -65,11 +65,12 @@ async def llm_rag_query_generator(
model = config.model
message = OpenAIUserMessageParam(content=rendered_content)
response = await inference_api.openai_chat_completion(
params = OpenAIChatCompletionRequestWithExtraBody(
model=model,
messages=[message],
stream=False,
)
response = await inference_api.openai_chat_completion(params)
query = response.choices[0].message.content

View file

@ -8,8 +8,6 @@ import asyncio
import base64
import io
import mimetypes
import secrets
import string
from typing import Any
import httpx
@ -33,7 +31,6 @@ from llama_stack.apis.tools import (
ToolDef,
ToolGroup,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from llama_stack.apis.vector_io import (
@ -53,10 +50,6 @@ from .context_retriever import generate_rag_query
log = get_logger(name=__name__, category="tool_runtime")
def make_random_string(length: int = 8):
return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length))
async def raw_data_from_doc(doc: RAGDocument) -> tuple[bytes, str]:
"""Get raw binary data and mime type from a RAGDocument for file upload."""
if isinstance(doc.content, URL):
@ -301,13 +294,16 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
ToolDef(
name="knowledge_search",
description="Search for information in a database.",
parameters=[
ToolParameter(
name="query",
description="The query to search for. Can be a natural language sentence or keywords.",
parameter_type="string",
),
],
input_schema={
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "The query to search for. Can be a natural language sentence or keywords.",
}
},
"required": ["query"],
},
),
]
)
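The same shape as the change above, written out on its own: tool definitions now carry a standard JSON Schema under `input_schema` rather than a list of `ToolParameter` objects. The tool name here is illustrative.

```python
from llama_stack.apis.tools import ToolDef

web_search_tool = ToolDef(
    name="web_search",
    description="Search the web for a query.",
    input_schema={
        "type": "object",
        "properties": {
            "query": {"type": "string", "description": "The search query."},
        },
        "required": ["query"],
    },
)
```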
@ -329,5 +325,8 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
return ToolInvocationResult(
content=result.content or [],
metadata=result.metadata,
metadata={
**(result.metadata or {}),
"citation_files": getattr(result, "citation_files", None),
},
)

View file

@ -200,12 +200,10 @@ class FaissIndex(EmbeddingIndex):
class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
def __init__(self, config: FaissVectorIOConfig, inference_api: Inference, files_api: Files | None) -> None:
super().__init__(files_api=files_api, kvstore=None)
self.config = config
self.inference_api = inference_api
self.files_api = files_api
self.cache: dict[str, VectorDBWithIndex] = {}
self.kvstore: KVStore | None = None
self.openai_vector_stores: dict[str, dict[str, Any]] = {}
async def initialize(self) -> None:
self.kvstore = await kvstore_impl(self.config.kvstore)
@ -227,8 +225,8 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
await self.initialize_openai_vector_stores()
async def shutdown(self) -> None:
# Cleanup if needed
pass
# Clean up mixin resources (file batch tasks)
await super().shutdown()
async def health(self) -> HealthResponse:
"""

View file

@ -410,12 +410,10 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
"""
def __init__(self, config, inference_api: Inference, files_api: Files | None) -> None:
super().__init__(files_api=files_api, kvstore=None)
self.config = config
self.inference_api = inference_api
self.files_api = files_api
self.cache: dict[str, VectorDBWithIndex] = {}
self.openai_vector_stores: dict[str, dict[str, Any]] = {}
self.kvstore: KVStore | None = None
async def initialize(self) -> None:
self.kvstore = await kvstore_impl(self.config.kvstore)
@ -436,8 +434,8 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
await self.initialize_openai_vector_stores()
async def shutdown(self) -> None:
# nothing to do since we don't maintain a persistent connection
pass
# Clean up mixin resources (file batch tasks)
await super().shutdown()
async def list_vector_dbs(self) -> list[VectorDB]:
return [v.vector_db for v in self.cache.values()]

View file

@ -32,9 +32,12 @@ def available_providers() -> list[ProviderSpec]:
Api.inference,
Api.safety,
Api.vector_io,
Api.vector_dbs,
Api.tool_runtime,
Api.tool_groups,
Api.conversations,
],
optional_api_dependencies=[
Api.telemetry,
],
description="Meta's reference implementation of an agent system that can use tools, access vector databases, and perform complex reasoning tasks.",
),

View file

@ -52,9 +52,7 @@ def available_providers() -> list[ProviderSpec]:
api=Api.inference,
adapter_type="cerebras",
provider_type="remote::cerebras",
pip_packages=[
"cerebras_cloud_sdk",
],
pip_packages=[],
module="llama_stack.providers.remote.inference.cerebras",
config_class="llama_stack.providers.remote.inference.cerebras.CerebrasImplConfig",
description="Cerebras inference provider for running models on Cerebras Cloud platform.",
@ -142,7 +140,7 @@ def available_providers() -> list[ProviderSpec]:
api=Api.inference,
adapter_type="databricks",
provider_type="remote::databricks",
pip_packages=[],
pip_packages=["databricks-sdk"],
module="llama_stack.providers.remote.inference.databricks",
config_class="llama_stack.providers.remote.inference.databricks.DatabricksImplConfig",
description="Databricks inference provider for running models on Databricks' unified analytics platform.",
@ -169,7 +167,7 @@ def available_providers() -> list[ProviderSpec]:
api=Api.inference,
adapter_type="openai",
provider_type="remote::openai",
pip_packages=["litellm"],
pip_packages=[],
module="llama_stack.providers.remote.inference.openai",
config_class="llama_stack.providers.remote.inference.openai.OpenAIConfig",
provider_data_validator="llama_stack.providers.remote.inference.openai.config.OpenAIProviderDataValidator",
@ -179,7 +177,7 @@ def available_providers() -> list[ProviderSpec]:
api=Api.inference,
adapter_type="anthropic",
provider_type="remote::anthropic",
pip_packages=["litellm"],
pip_packages=["anthropic"],
module="llama_stack.providers.remote.inference.anthropic",
config_class="llama_stack.providers.remote.inference.anthropic.AnthropicConfig",
provider_data_validator="llama_stack.providers.remote.inference.anthropic.config.AnthropicProviderDataValidator",
@ -189,9 +187,7 @@ def available_providers() -> list[ProviderSpec]:
api=Api.inference,
adapter_type="gemini",
provider_type="remote::gemini",
pip_packages=[
"litellm",
],
pip_packages=[],
module="llama_stack.providers.remote.inference.gemini",
config_class="llama_stack.providers.remote.inference.gemini.GeminiConfig",
provider_data_validator="llama_stack.providers.remote.inference.gemini.config.GeminiProviderDataValidator",
@ -202,7 +198,6 @@ def available_providers() -> list[ProviderSpec]:
adapter_type="vertexai",
provider_type="remote::vertexai",
pip_packages=[
"litellm",
"google-cloud-aiplatform",
],
module="llama_stack.providers.remote.inference.vertexai",
@ -233,9 +228,7 @@ Available Models:
api=Api.inference,
adapter_type="groq",
provider_type="remote::groq",
pip_packages=[
"litellm",
],
pip_packages=[],
module="llama_stack.providers.remote.inference.groq",
config_class="llama_stack.providers.remote.inference.groq.GroqConfig",
provider_data_validator="llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator",
@ -245,7 +238,7 @@ Available Models:
api=Api.inference,
adapter_type="llama-openai-compat",
provider_type="remote::llama-openai-compat",
pip_packages=["litellm"],
pip_packages=[],
module="llama_stack.providers.remote.inference.llama_openai_compat",
config_class="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaCompatConfig",
provider_data_validator="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaProviderDataValidator",
@ -255,9 +248,7 @@ Available Models:
api=Api.inference,
adapter_type="sambanova",
provider_type="remote::sambanova",
pip_packages=[
"litellm",
],
pip_packages=[],
module="llama_stack.providers.remote.inference.sambanova",
config_class="llama_stack.providers.remote.inference.sambanova.SambaNovaImplConfig",
provider_data_validator="llama_stack.providers.remote.inference.sambanova.config.SambaNovaProviderDataValidator",
@ -277,7 +268,7 @@ Available Models:
api=Api.inference,
adapter_type="watsonx",
provider_type="remote::watsonx",
pip_packages=["ibm_watsonx_ai"],
pip_packages=["litellm"],
module="llama_stack.providers.remote.inference.watsonx",
config_class="llama_stack.providers.remote.inference.watsonx.WatsonXConfig",
provider_data_validator="llama_stack.providers.remote.inference.watsonx.WatsonXProviderDataValidator",
@ -287,7 +278,7 @@ Available Models:
api=Api.inference,
provider_type="remote::azure",
adapter_type="azure",
pip_packages=["litellm"],
pip_packages=[],
module="llama_stack.providers.remote.inference.azure",
config_class="llama_stack.providers.remote.inference.azure.AzureConfig",
provider_data_validator="llama_stack.providers.remote.inference.azure.config.AzureProviderDataValidator",

View file

@ -11,6 +11,7 @@ from llama_stack.providers.datatypes import (
ProviderSpec,
RemoteProviderSpec,
)
from llama_stack.providers.registry.vector_io import DEFAULT_VECTOR_IO_DEPS
def available_providers() -> list[ProviderSpec]:
@ -18,9 +19,8 @@ def available_providers() -> list[ProviderSpec]:
InlineProviderSpec(
api=Api.tool_runtime,
provider_type="inline::rag-runtime",
pip_packages=[
"chardet",
"pypdf",
pip_packages=DEFAULT_VECTOR_IO_DEPS
+ [
"tqdm",
"numpy",
"scikit-learn",

View file

@ -12,13 +12,16 @@ from llama_stack.providers.datatypes import (
RemoteProviderSpec,
)
# Common dependencies for all vector IO providers that support document processing
DEFAULT_VECTOR_IO_DEPS = ["chardet", "pypdf"]
def available_providers() -> list[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.vector_io,
provider_type="inline::meta-reference",
pip_packages=["faiss-cpu"],
pip_packages=["faiss-cpu"] + DEFAULT_VECTOR_IO_DEPS,
module="llama_stack.providers.inline.vector_io.faiss",
config_class="llama_stack.providers.inline.vector_io.faiss.FaissVectorIOConfig",
deprecation_warning="Please use the `inline::faiss` provider instead.",
@ -29,7 +32,7 @@ def available_providers() -> list[ProviderSpec]:
InlineProviderSpec(
api=Api.vector_io,
provider_type="inline::faiss",
pip_packages=["faiss-cpu"],
pip_packages=["faiss-cpu"] + DEFAULT_VECTOR_IO_DEPS,
module="llama_stack.providers.inline.vector_io.faiss",
config_class="llama_stack.providers.inline.vector_io.faiss.FaissVectorIOConfig",
api_dependencies=[Api.inference],
@ -82,7 +85,7 @@ more details about Faiss in general.
InlineProviderSpec(
api=Api.vector_io,
provider_type="inline::sqlite-vec",
pip_packages=["sqlite-vec"],
pip_packages=["sqlite-vec"] + DEFAULT_VECTOR_IO_DEPS,
module="llama_stack.providers.inline.vector_io.sqlite_vec",
config_class="llama_stack.providers.inline.vector_io.sqlite_vec.SQLiteVectorIOConfig",
api_dependencies=[Api.inference],
@ -289,7 +292,7 @@ See [sqlite-vec's GitHub repo](https://github.com/asg017/sqlite-vec/tree/main) f
InlineProviderSpec(
api=Api.vector_io,
provider_type="inline::sqlite_vec",
pip_packages=["sqlite-vec"],
pip_packages=["sqlite-vec"] + DEFAULT_VECTOR_IO_DEPS,
module="llama_stack.providers.inline.vector_io.sqlite_vec",
config_class="llama_stack.providers.inline.vector_io.sqlite_vec.SQLiteVectorIOConfig",
deprecation_warning="Please use the `inline::sqlite-vec` provider (notice the hyphen instead of underscore) instead.",
@ -303,7 +306,7 @@ Please refer to the sqlite-vec provider documentation.
api=Api.vector_io,
adapter_type="chromadb",
provider_type="remote::chromadb",
pip_packages=["chromadb-client"],
pip_packages=["chromadb-client"] + DEFAULT_VECTOR_IO_DEPS,
module="llama_stack.providers.remote.vector_io.chroma",
config_class="llama_stack.providers.remote.vector_io.chroma.ChromaVectorIOConfig",
api_dependencies=[Api.inference],
@ -345,7 +348,7 @@ See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introducti
InlineProviderSpec(
api=Api.vector_io,
provider_type="inline::chromadb",
pip_packages=["chromadb"],
pip_packages=["chromadb"] + DEFAULT_VECTOR_IO_DEPS,
module="llama_stack.providers.inline.vector_io.chroma",
config_class="llama_stack.providers.inline.vector_io.chroma.ChromaVectorIOConfig",
api_dependencies=[Api.inference],
@ -389,7 +392,7 @@ See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introducti
api=Api.vector_io,
adapter_type="pgvector",
provider_type="remote::pgvector",
pip_packages=["psycopg2-binary"],
pip_packages=["psycopg2-binary"] + DEFAULT_VECTOR_IO_DEPS,
module="llama_stack.providers.remote.vector_io.pgvector",
config_class="llama_stack.providers.remote.vector_io.pgvector.PGVectorVectorIOConfig",
api_dependencies=[Api.inference],
@ -410,7 +413,7 @@ There are three implementations of search for PGVectorIndex available:
- How it works:
- Uses PostgreSQL's vector extension (pgvector) to perform similarity search
- Compares query embeddings against stored embeddings using Cosine distance or other distance metrics
- Eg. SQL query: SELECT document, embedding <=> %s::vector AS distance FROM table ORDER BY distance
- Eg. SQL query: SELECT document, embedding &lt;=&gt; %s::vector AS distance FROM table ORDER BY distance
- Characteristics:
- Semantic understanding - finds documents similar in meaning even if they don't share keywords
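A hedged sketch of the cosine-distance query described above, issued through psycopg2; the table name, column names, and embedding values are illustrative.

```python
import psycopg2

conn = psycopg2.connect("dbname=llamastack user=postgres")
query_embedding = "[0.1, 0.2, 0.3]"  # pgvector accepts this bracketed text form
with conn.cursor() as cur:
    cur.execute(
        "SELECT document, embedding <=> %s::vector AS distance "
        "FROM chunks ORDER BY distance LIMIT 5",
        (query_embedding,),
    )
    rows = cur.fetchall()
```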
@ -500,7 +503,7 @@ See [PGVector's documentation](https://github.com/pgvector/pgvector) for more de
api=Api.vector_io,
adapter_type="weaviate",
provider_type="remote::weaviate",
pip_packages=["weaviate-client"],
pip_packages=["weaviate-client>=4.16.5"] + DEFAULT_VECTOR_IO_DEPS,
module="llama_stack.providers.remote.vector_io.weaviate",
config_class="llama_stack.providers.remote.vector_io.weaviate.WeaviateVectorIOConfig",
provider_data_validator="llama_stack.providers.remote.vector_io.weaviate.WeaviateRequestProviderData",
@ -541,7 +544,7 @@ See [Weaviate's documentation](https://weaviate.io/developers/weaviate) for more
InlineProviderSpec(
api=Api.vector_io,
provider_type="inline::qdrant",
pip_packages=["qdrant-client"],
pip_packages=["qdrant-client"] + DEFAULT_VECTOR_IO_DEPS,
module="llama_stack.providers.inline.vector_io.qdrant",
config_class="llama_stack.providers.inline.vector_io.qdrant.QdrantVectorIOConfig",
api_dependencies=[Api.inference],
@ -594,7 +597,7 @@ See the [Qdrant documentation](https://qdrant.tech/documentation/) for more deta
api=Api.vector_io,
adapter_type="qdrant",
provider_type="remote::qdrant",
pip_packages=["qdrant-client"],
pip_packages=["qdrant-client"] + DEFAULT_VECTOR_IO_DEPS,
module="llama_stack.providers.remote.vector_io.qdrant",
config_class="llama_stack.providers.remote.vector_io.qdrant.QdrantVectorIOConfig",
api_dependencies=[Api.inference],
@ -607,7 +610,7 @@ Please refer to the inline provider documentation.
api=Api.vector_io,
adapter_type="milvus",
provider_type="remote::milvus",
pip_packages=["pymilvus>=2.4.10"],
pip_packages=["pymilvus>=2.4.10"] + DEFAULT_VECTOR_IO_DEPS,
module="llama_stack.providers.remote.vector_io.milvus",
config_class="llama_stack.providers.remote.vector_io.milvus.MilvusVectorIOConfig",
api_dependencies=[Api.inference],
@ -633,7 +636,13 @@ To use Milvus in your Llama Stack project, follow these steps:
## Installation
You can install Milvus using pymilvus:
If you want to use inline Milvus, you can install:
```bash
pip install pymilvus[milvus-lite]
```
If you want to use remote Milvus, you can install:
```bash
pip install pymilvus
@ -807,7 +816,7 @@ For more details on TLS configuration, refer to the [TLS setup guide](https://mi
InlineProviderSpec(
api=Api.vector_io,
provider_type="inline::milvus",
pip_packages=["pymilvus>=2.4.10"],
pip_packages=["pymilvus[milvus-lite]>=2.4.10"] + DEFAULT_VECTOR_IO_DEPS,
module="llama_stack.providers.inline.vector_io.milvus",
config_class="llama_stack.providers.inline.vector_io.milvus.MilvusVectorIOConfig",
api_dependencies=[Api.inference],

View file

@ -14,7 +14,6 @@ from llama_stack.apis.datasets import Datasets
from llama_stack.apis.inference import Inference
from llama_stack.apis.scoring import Scoring, ScoringResult
from llama_stack.providers.datatypes import BenchmarksProtocolPrivate
from llama_stack.providers.remote.inference.nvidia.models import MODEL_ENTRIES
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from .....apis.common.job_types import Job, JobStatus
@ -45,7 +44,7 @@ class NVIDIAEvalImpl(
self.inference_api = inference_api
self.agents_api = agents_api
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
ModelRegistryHelper.__init__(self)
async def initialize(self) -> None: ...

View file

@ -10,7 +10,7 @@ from typing import Annotated, Any
import boto3
from botocore.exceptions import BotoCoreError, ClientError, NoCredentialsError
from fastapi import File, Form, Response, UploadFile
from fastapi import Depends, File, Form, Response, UploadFile
from llama_stack.apis.common.errors import ResourceNotFoundError
from llama_stack.apis.common.responses import Order
@ -23,6 +23,8 @@ from llama_stack.apis.files import (
OpenAIFilePurpose,
)
from llama_stack.core.datatypes import AccessRule
from llama_stack.core.id_generation import generate_object_id
from llama_stack.providers.utils.files.form_data import parse_expires_after
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl
@ -137,7 +139,7 @@ class S3FilesImpl(Files):
where: dict[str, str | dict] = {"id": file_id}
if not return_expired:
where["expires_at"] = {">": self._now()}
if not (row := await self.sql_store.fetch_one("openai_files", policy=self.policy, where=where)):
if not (row := await self.sql_store.fetch_one("openai_files", where=where)):
raise ResourceNotFoundError(file_id, "File", "files.list()")
return row
@ -164,7 +166,7 @@ class S3FilesImpl(Files):
self._client = _create_s3_client(self._config)
await _create_bucket_if_not_exists(self._client, self._config)
self._sql_store = AuthorizedSqlStore(sqlstore_impl(self._config.metadata_store))
self._sql_store = AuthorizedSqlStore(sqlstore_impl(self._config.metadata_store), self.policy)
await self._sql_store.create_table(
"openai_files",
{
@ -195,23 +197,14 @@ class S3FilesImpl(Files):
self,
file: Annotated[UploadFile, File()],
purpose: Annotated[OpenAIFilePurpose, Form()],
expires_after_anchor: Annotated[str | None, Form(alias="expires_after[anchor]")] = None,
expires_after_seconds: Annotated[int | None, Form(alias="expires_after[seconds]")] = None,
expires_after: Annotated[ExpiresAfter | None, Depends(parse_expires_after)] = None,
) -> OpenAIFileObject:
file_id = f"file-{uuid.uuid4().hex}"
file_id = generate_object_id("file", lambda: f"file-{uuid.uuid4().hex}")
filename = getattr(file, "filename", None) or "uploaded_file"
created_at = self._now()
expires_after = None
if expires_after_anchor is not None or expires_after_seconds is not None:
# we use ExpiresAfter to validate input
expires_after = ExpiresAfter(
anchor=expires_after_anchor, # type: ignore[arg-type]
seconds=expires_after_seconds, # type: ignore[arg-type]
)
# the default is no expiration.
# to implement no expiration we set an expiration beyond the max.
# we'll hide this fact from users when returning the file object.
@ -268,7 +261,6 @@ class S3FilesImpl(Files):
paginated_result = await self.sql_store.fetch_all(
table="openai_files",
policy=self.policy,
where=where_conditions,
order_by=[("created_at", order.value)],
cursor=("id", after) if after else None,

View file

@ -4,18 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
from .config import AnthropicConfig
class AnthropicProviderDataValidator(BaseModel):
anthropic_api_key: str | None = None
async def get_adapter_impl(config: AnthropicConfig, _deps):
from .anthropic import AnthropicInferenceAdapter
impl = AnthropicInferenceAdapter(config)
impl = AnthropicInferenceAdapter(config=config)
await impl.initialize()
return impl

View file

@ -4,31 +4,33 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from collections.abc import Iterable
from anthropic import AsyncAnthropic
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import AnthropicConfig
from .models import MODEL_ENTRIES
class AnthropicInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
def __init__(self, config: AnthropicConfig) -> None:
LiteLLMOpenAIMixin.__init__(
self,
MODEL_ENTRIES,
litellm_provider_name="anthropic",
api_key_from_config=config.api_key,
provider_data_api_key_field="anthropic_api_key",
)
self.config = config
class AnthropicInferenceAdapter(OpenAIMixin):
config: AnthropicConfig
async def initialize(self) -> None:
await super().initialize()
async def shutdown(self) -> None:
await super().shutdown()
get_api_key = LiteLLMOpenAIMixin.get_api_key
provider_data_api_key_field: str = "anthropic_api_key"
# source: https://docs.claude.com/en/docs/build-with-claude/embeddings
# TODO: add support for voyageai, which is where these models are hosted
# embedding_model_metadata = {
# "voyage-3-large": {"embedding_dimension": 1024, "context_length": 32000}, # supports dimensions 256, 512, 1024, 2048
# "voyage-3.5": {"embedding_dimension": 1024, "context_length": 32000}, # supports dimensions 256, 512, 1024, 2048
# "voyage-3.5-lite": {"embedding_dimension": 1024, "context_length": 32000}, # supports dimensions 256, 512, 1024, 2048
# "voyage-code-3": {"embedding_dimension": 1024, "context_length": 32000}, # supports dimensions 256, 512, 1024, 2048
# "voyage-finance-2": {"embedding_dimension": 1024, "context_length": 32000},
# "voyage-law-2": {"embedding_dimension": 1024, "context_length": 16000},
# "voyage-multimodal-3": {"embedding_dimension": 1024, "context_length": 32000},
# }
def get_base_url(self):
return "https://api.anthropic.com/v1"
async def list_provider_model_ids(self) -> Iterable[str]:
return [m.id async for m in AsyncAnthropic(api_key=self.get_api_key()).models.list()]
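A hedged sketch of what `provider_data_api_key_field` enables: with the LiteLLM mixin gone, the Anthropic key can still be supplied per request through the `X-LlamaStack-Provider-Data` header rather than the provider config. The endpoint path and key value are illustrative.

```python
import json

import httpx

headers = {
    "X-LlamaStack-Provider-Data": json.dumps({"anthropic_api_key": "sk-ant-..."}),
}
resp = httpx.post(
    "http://localhost:8321/v1/chat/completions",
    headers=headers,
    json={
        "model": "claude-3-5-haiku-latest",
        "messages": [{"role": "user", "content": "hello"}],
    },
)
```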

View file

@ -8,6 +8,7 @@ from typing import Any
from pydantic import BaseModel, Field
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@ -19,12 +20,7 @@ class AnthropicProviderDataValidator(BaseModel):
@json_schema_type
class AnthropicConfig(BaseModel):
api_key: str | None = Field(
default=None,
description="API key for Anthropic models",
)
class AnthropicConfig(RemoteInferenceProviderConfig):
@classmethod
def sample_run_config(cls, api_key: str = "${env.ANTHROPIC_API_KEY:=}", **kwargs) -> dict[str, Any]:
return {

View file

@ -1,40 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.models import ModelType
from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry,
)
LLM_MODEL_IDS = [
"claude-3-5-sonnet-latest",
"claude-3-7-sonnet-latest",
"claude-3-5-haiku-latest",
]
SAFETY_MODELS_ENTRIES = []
MODEL_ENTRIES = (
[ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS]
+ [
ProviderModelEntry(
provider_model_id="voyage-3",
model_type=ModelType.embedding,
metadata={"embedding_dimension": 1024, "context_length": 32000},
),
ProviderModelEntry(
provider_model_id="voyage-3-lite",
model_type=ModelType.embedding,
metadata={"embedding_dimension": 512, "context_length": 32000},
),
ProviderModelEntry(
provider_model_id="voyage-code-3",
model_type=ModelType.embedding,
metadata={"embedding_dimension": 1024, "context_length": 32000},
),
]
+ SAFETY_MODELS_ENTRIES
)

View file

@ -10,6 +10,6 @@ from .config import AzureConfig
async def get_adapter_impl(config: AzureConfig, _deps):
from .azure import AzureInferenceAdapter
impl = AzureInferenceAdapter(config)
impl = AzureInferenceAdapter(config=config)
await impl.initialize()
return impl

View file

@ -4,33 +4,17 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from urllib.parse import urljoin
from llama_stack.apis.inference import ChatCompletionRequest
from llama_stack.providers.utils.inference.litellm_openai_mixin import (
LiteLLMOpenAIMixin,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import AzureConfig
from .models import MODEL_ENTRIES
class AzureInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
def __init__(self, config: AzureConfig) -> None:
LiteLLMOpenAIMixin.__init__(
self,
MODEL_ENTRIES,
litellm_provider_name="azure",
api_key_from_config=config.api_key.get_secret_value(),
provider_data_api_key_field="azure_api_key",
openai_compat_api_base=str(config.api_base),
)
self.config = config
class AzureInferenceAdapter(OpenAIMixin):
config: AzureConfig
# Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin
get_api_key = LiteLLMOpenAIMixin.get_api_key
provider_data_api_key_field: str = "azure_api_key"
def get_base_url(self) -> str:
"""
@ -39,26 +23,3 @@ class AzureInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
Returns the Azure API base URL from the configuration.
"""
return urljoin(str(self.config.api_base), "/openai/v1")
async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]:
# Get base parameters from parent
params = await super()._get_params(request)
# Add Azure specific parameters
provider_data = self.get_request_provider_data()
if provider_data:
if getattr(provider_data, "azure_api_key", None):
params["api_key"] = provider_data.azure_api_key
if getattr(provider_data, "azure_api_base", None):
params["api_base"] = provider_data.azure_api_base
if getattr(provider_data, "azure_api_version", None):
params["api_version"] = provider_data.azure_api_version
if getattr(provider_data, "azure_api_type", None):
params["api_type"] = provider_data.azure_api_type
else:
params["api_key"] = self.config.api_key.get_secret_value()
params["api_base"] = str(self.config.api_base)
params["api_version"] = self.config.api_version
params["api_type"] = self.config.api_type
return params
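A quick check of the base-URL construction retained above; the resource name is illustrative.

```python
from urllib.parse import urljoin

api_base = "https://your-resource-name.openai.azure.com"
assert urljoin(api_base, "/openai/v1") == "https://your-resource-name.openai.azure.com/openai/v1"
```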

View file

@ -9,6 +9,7 @@ from typing import Any
from pydantic import BaseModel, Field, HttpUrl, SecretStr
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@ -30,10 +31,7 @@ class AzureProviderDataValidator(BaseModel):
@json_schema_type
class AzureConfig(BaseModel):
api_key: SecretStr = Field(
description="Azure API key for Azure",
)
class AzureConfig(RemoteInferenceProviderConfig):
api_base: HttpUrl = Field(
description="Azure API base for Azure (e.g., https://your-resource-name.openai.azure.com)",
)

View file

@ -1,28 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry,
)
# https://learn.microsoft.com/en-us/azure/ai-foundry/openai/concepts/models?tabs=global-standard%2Cstandard-chat-completions
LLM_MODEL_IDS = [
"gpt-5",
"gpt-5-mini",
"gpt-5-nano",
"gpt-5-chat",
"o1",
"o1-mini",
"o3-mini",
"o4-mini",
"gpt-4.1",
"gpt-4.1-mini",
"gpt-4.1-nano",
]
SAFETY_MODELS_ENTRIES = list[ProviderModelEntry]()
MODEL_ENTRIES = [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS] + SAFETY_MODELS_ENTRIES

View file

@ -5,31 +5,21 @@
# the root directory of this source tree.
import json
from collections.abc import AsyncGenerator, AsyncIterator
from collections.abc import AsyncIterator
from botocore.client import BaseClient
from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseStreamChunk,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
OpenAIChatCompletionRequestWithExtraBody,
OpenAICompletionRequestWithExtraBody,
OpenAIEmbeddingsResponse,
ResponseFormat,
SamplingParams,
TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
)
from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
from llama_stack.providers.utils.bedrock.client import create_bedrock_client
@ -37,18 +27,10 @@ from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompatCompletionChoice,
OpenAICompatCompletionResponse,
OpenAICompletionToLlamaStackMixin,
get_sampling_strategy_options,
process_chat_completion_response,
process_chat_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
content_has_media,
interleaved_content_as_str,
)
from .models import MODEL_ENTRIES
@ -94,11 +76,9 @@ def _to_inference_profile_id(model_id: str, region: str = None) -> str:
class BedrockInferenceAdapter(
ModelRegistryHelper,
Inference,
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
):
def __init__(self, config: BedrockConfig) -> None:
ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
self._config = config
self._client = None
@ -115,82 +95,6 @@ class BedrockInferenceAdapter(
if self._client is not None:
self._client.close()
async def completion(
self,
model_id: str,
content: InterleavedContent,
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
) -> AsyncGenerator:
raise NotImplementedError()
async def chat_completion(
self,
model_id: str,
messages: list[Message],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
model=model.provider_resource_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
response_format=response_format,
stream=stream,
logprobs=logprobs,
tool_config=tool_config,
)
if stream:
return self._stream_chat_completion(request)
else:
return await self._nonstream_chat_completion(request)
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
params = await self._get_params_for_chat_completion(request)
res = self.client.invoke_model(**params)
chunk = next(res["body"])
result = json.loads(chunk.decode("utf-8"))
choice = OpenAICompatCompletionChoice(
finish_reason=result["stop_reason"],
text=result["generation"],
)
response = OpenAICompatCompletionResponse(choices=[choice])
return process_chat_completion_response(response, request)
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
params = await self._get_params_for_chat_completion(request)
res = self.client.invoke_model_with_response_stream(**params)
event_stream = res["body"]
async def _generate_and_convert_to_openai_compat():
for chunk in event_stream:
chunk = chunk["chunk"]["bytes"]
result = json.loads(chunk.decode("utf-8"))
choice = OpenAICompatCompletionChoice(
finish_reason=result["stop_reason"],
text=result["generation"],
)
yield OpenAICompatCompletionResponse(choices=[choice])
stream = _generate_and_convert_to_openai_compat()
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
async def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> dict:
bedrock_model = request.model
@ -218,36 +122,6 @@ class BedrockInferenceAdapter(
),
}
async def embeddings(
self,
model_id: str,
contents: list[str] | list[InterleavedContentItem],
text_truncation: TextTruncation | None = TextTruncation.none,
output_dimension: int | None = None,
task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)
# Convert foundation model ID to inference profile ID
region_name = self.client.meta.region_name
inference_profile_id = _to_inference_profile_id(model.provider_resource_id, region_name)
embeddings = []
for content in contents:
assert not content_has_media(content), "Bedrock does not support media for embeddings"
input_text = interleaved_content_as_str(content)
input_body = {"inputText": input_text}
body = json.dumps(input_body)
response = self.client.invoke_model(
body=body,
modelId=inference_profile_id,
accept="application/json",
contentType="application/json",
)
response_body = json.loads(response.get("body").read())
embeddings.append(response_body.get("embedding"))
return EmbeddingsResponse(embeddings=embeddings)
async def openai_embeddings(
self,
model: str,
@ -257,3 +131,15 @@ class BedrockInferenceAdapter(
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
raise NotImplementedError()
async def openai_completion(
self,
params: OpenAICompletionRequestWithExtraBody,
) -> OpenAICompletion:
raise NotImplementedError("OpenAI completion not supported by the Bedrock provider")
async def openai_chat_completion(
self,
params: OpenAIChatCompletionRequestWithExtraBody,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
raise NotImplementedError("OpenAI chat completion not supported by the Bedrock provider")

View file

@ -12,7 +12,7 @@ async def get_adapter_impl(config: CerebrasImplConfig, _deps):
assert isinstance(config, CerebrasImplConfig), f"Unexpected config type: {type(config)}"
impl = CerebrasInferenceAdapter(config)
impl = CerebrasInferenceAdapter(config=config)
await impl.initialize()

View file

@ -4,198 +4,19 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import AsyncGenerator
from urllib.parse import urljoin
from cerebras.cloud.sdk import AsyncCerebras
from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
CompletionRequest,
CompletionResponse,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
OpenAIEmbeddingsResponse,
ResponseFormat,
SamplingParams,
TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
TopKSamplingStrategy,
)
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
get_sampling_options,
process_chat_completion_response,
process_chat_completion_stream_response,
process_completion_response,
process_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
completion_request_to_prompt,
)
from llama_stack.apis.inference import OpenAIEmbeddingsResponse
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import CerebrasImplConfig
from .models import MODEL_ENTRIES
class CerebrasInferenceAdapter(
ModelRegistryHelper,
Inference,
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
):
def __init__(self, config: CerebrasImplConfig) -> None:
ModelRegistryHelper.__init__(
self,
model_entries=MODEL_ENTRIES,
)
self.config = config
class CerebrasInferenceAdapter(OpenAIMixin):
config: CerebrasImplConfig
# TODO: make this use provider data, etc. like other providers
self.client = AsyncCerebras(
base_url=self.config.base_url,
api_key=self.config.api_key.get_secret_value(),
)
async def initialize(self) -> None:
return
async def shutdown(self) -> None:
pass
async def completion(
self,
model_id: str,
content: InterleavedContent,
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
request = CompletionRequest(
model=model.provider_resource_id,
content=content,
sampling_params=sampling_params,
response_format=response_format,
stream=stream,
logprobs=logprobs,
)
if stream:
return self._stream_completion(
request,
)
else:
return await self._nonstream_completion(request)
async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
params = await self._get_params(request)
r = await self.client.completions.create(**params)
return process_completion_response(r)
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
stream = await self.client.completions.create(**params)
async for chunk in process_completion_stream_response(stream):
yield chunk
async def chat_completion(
self,
model_id: str,
messages: list[Message],
sampling_params: SamplingParams | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
model=model.provider_resource_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
response_format=response_format,
stream=stream,
logprobs=logprobs,
tool_config=tool_config,
)
if stream:
return self._stream_chat_completion(request)
else:
return await self._nonstream_chat_completion(request)
async def _nonstream_chat_completion(self, request: CompletionRequest) -> CompletionResponse:
params = await self._get_params(request)
r = await self.client.completions.create(**params)
return process_chat_completion_response(r, request)
async def _stream_chat_completion(self, request: CompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
stream = await self.client.completions.create(**params)
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
async def _get_params(self, request: ChatCompletionRequest | CompletionRequest) -> dict:
if request.sampling_params and isinstance(request.sampling_params.strategy, TopKSamplingStrategy):
raise ValueError("`top_k` not supported by Cerebras")
prompt = ""
if isinstance(request, ChatCompletionRequest):
prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model))
elif isinstance(request, CompletionRequest):
prompt = await completion_request_to_prompt(request)
else:
raise ValueError(f"Unknown request type {type(request)}")
return {
"model": request.model,
"prompt": prompt,
"stream": request.stream,
**get_sampling_options(request.sampling_params),
}
async def embeddings(
self,
model_id: str,
contents: list[str] | list[InterleavedContentItem],
text_truncation: TextTruncation | None = TextTruncation.none,
output_dimension: int | None = None,
task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
raise NotImplementedError()
def get_base_url(self) -> str:
return urljoin(self.config.base_url, "v1")
async def openai_embeddings(
self,

View file

@ -7,23 +7,20 @@
import os
from typing import Any
from pydantic import BaseModel, Field, SecretStr
from pydantic import Field
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
DEFAULT_BASE_URL = "https://api.cerebras.ai"
@json_schema_type
class CerebrasImplConfig(BaseModel):
class CerebrasImplConfig(RemoteInferenceProviderConfig):
base_url: str = Field(
default=os.environ.get("CEREBRAS_BASE_URL", DEFAULT_BASE_URL),
description="Base URL for the Cerebras API",
)
api_key: SecretStr | None = Field(
default=os.environ.get("CEREBRAS_API_KEY"),
description="Cerebras API Key",
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.CEREBRAS_API_KEY:=}", **kwargs) -> dict[str, Any]:

View file

@ -1,28 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry,
)
SAFETY_MODELS_ENTRIES = []
# https://inference-docs.cerebras.ai/models
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"llama3.1-8b",
CoreModelId.llama3_1_8b_instruct.value,
),
build_hf_repo_model_entry(
"llama-3.3-70b",
CoreModelId.llama3_3_70b_instruct.value,
),
build_hf_repo_model_entry(
"llama-4-scout-17b-16e-instruct",
CoreModelId.llama4_scout_17b_16e_instruct.value,
),
] + SAFETY_MODELS_ENTRIES

View file

@ -5,11 +5,12 @@
# the root directory of this source tree.
from .config import DatabricksImplConfig
from .databricks import DatabricksInferenceAdapter
async def get_adapter_impl(config: DatabricksImplConfig, _deps):
from .databricks import DatabricksInferenceAdapter
assert isinstance(config, DatabricksImplConfig), f"Unexpected config type: {type(config)}"
impl = DatabricksInferenceAdapter(config)
impl = DatabricksInferenceAdapter(config=config)
await impl.initialize()
return impl

View file

@ -6,27 +6,29 @@
from typing import Any
from pydantic import BaseModel, Field
from pydantic import Field, SecretStr
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class DatabricksImplConfig(BaseModel):
url: str = Field(
class DatabricksImplConfig(RemoteInferenceProviderConfig):
url: str | None = Field(
default=None,
description="The URL for the Databricks model serving endpoint",
)
api_token: str = Field(
auth_credential: SecretStr | None = Field(
default=None,
alias="api_token",
description="The Databricks API token",
)
@classmethod
def sample_run_config(
cls,
url: str = "${env.DATABRICKS_URL:=}",
api_token: str = "${env.DATABRICKS_API_TOKEN:=}",
url: str = "${env.DATABRICKS_HOST:=}",
api_token: str = "${env.DATABRICKS_TOKEN:=}",
**kwargs: Any,
) -> dict[str, Any]:
return {

View file

@ -4,165 +4,41 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import AsyncGenerator
from collections.abc import Iterable
from openai import OpenAI
from databricks.sdk import WorkspaceClient
from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
OpenAIEmbeddingsResponse,
ResponseFormat,
SamplingParams,
TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
build_hf_repo_model_entry,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
get_sampling_options,
process_chat_completion_response,
process_chat_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
)
from llama_stack.apis.inference import OpenAICompletion, OpenAICompletionRequestWithExtraBody
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import DatabricksImplConfig
SAFETY_MODELS_ENTRIES = []
# https://docs.databricks.com/aws/en/machine-learning/model-serving/foundation-model-overview
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"databricks-meta-llama-3-1-70b-instruct",
CoreModelId.llama3_1_70b_instruct.value,
),
build_hf_repo_model_entry(
"databricks-meta-llama-3-1-405b-instruct",
CoreModelId.llama3_1_405b_instruct.value,
),
] + SAFETY_MODELS_ENTRIES
logger = get_logger(name=__name__, category="inference::databricks")
class DatabricksInferenceAdapter(
ModelRegistryHelper,
Inference,
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
):
def __init__(self, config: DatabricksImplConfig) -> None:
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
self.config = config
class DatabricksInferenceAdapter(OpenAIMixin):
config: DatabricksImplConfig
async def initialize(self) -> None:
return
# source: https://docs.databricks.com/aws/en/machine-learning/foundation-model-apis/supported-models
embedding_model_metadata: dict[str, dict[str, int]] = {
"databricks-gte-large-en": {"embedding_dimension": 1024, "context_length": 8192},
"databricks-bge-large-en": {"embedding_dimension": 1024, "context_length": 512},
}
async def shutdown(self) -> None:
pass
def get_base_url(self) -> str:
return f"{self.config.url}/serving-endpoints"
async def completion(
async def list_provider_model_ids(self) -> Iterable[str]:
return [
endpoint.name
for endpoint in WorkspaceClient(
host=self.config.url, token=self.get_api_key()
).serving_endpoints.list() # TODO: this is not async
]
async def openai_completion(
self,
model: str,
content: InterleavedContent,
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
) -> AsyncGenerator:
raise NotImplementedError()
async def chat_completion(
self,
model: str,
messages: list[Message],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
request = ChatCompletionRequest(
model=model,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
stream=stream,
logprobs=logprobs,
tool_config=tool_config,
)
client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
if stream:
return self._stream_chat_completion(request, client)
else:
return await self._nonstream_chat_completion(request, client)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest, client: OpenAI
) -> ChatCompletionResponse:
params = self._get_params(request)
r = client.completions.create(**params)
return process_chat_completion_response(r, request)
async def _stream_chat_completion(self, request: ChatCompletionRequest, client: OpenAI) -> AsyncGenerator:
params = self._get_params(request)
async def _to_async_generator():
s = client.completions.create(**params)
for chunk in s:
yield chunk
stream = _to_async_generator()
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
def _get_params(self, request: ChatCompletionRequest) -> dict:
return {
"model": request.model,
"prompt": chat_completion_request_to_prompt(request, self.get_llama_model(request.model)),
"stream": request.stream,
**get_sampling_options(request.sampling_params),
}
async def embeddings(
self,
model_id: str,
contents: list[str] | list[InterleavedContentItem],
text_truncation: TextTruncation | None = TextTruncation.none,
output_dimension: int | None = None,
task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
raise NotImplementedError()
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
params: OpenAICompletionRequestWithExtraBody,
) -> OpenAICompletion:
raise NotImplementedError()
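A hedged sketch of what the new wiring implies: the adapter is constructed with a keyword `config` (as in the updated `get_adapter_impl`) and all OpenAI-compatible traffic targets `<workspace url>/serving-endpoints`. The workspace URL and token are illustrative.

```python
from llama_stack.providers.remote.inference.databricks.config import DatabricksImplConfig
from llama_stack.providers.remote.inference.databricks.databricks import DatabricksInferenceAdapter

config = DatabricksImplConfig(
    url="https://my-workspace.cloud.databricks.com",
    api_token="dapi-xxxx",  # stored as auth_credential via the field alias
)
adapter = DatabricksInferenceAdapter(config=config)
assert adapter.get_base_url() == "https://my-workspace.cloud.databricks.com/serving-endpoints"
```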

View file

@ -17,6 +17,6 @@ async def get_adapter_impl(config: FireworksImplConfig, _deps):
from .fireworks import FireworksInferenceAdapter
assert isinstance(config, FireworksImplConfig), f"Unexpected config type: {type(config)}"
impl = FireworksInferenceAdapter(config)
impl = FireworksInferenceAdapter(config=config)
await impl.initialize()
return impl

View file

@ -6,7 +6,7 @@
from typing import Any
from pydantic import Field, SecretStr
from pydantic import Field
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@ -18,10 +18,6 @@ class FireworksImplConfig(RemoteInferenceProviderConfig):
default="https://api.fireworks.ai/inference/v1",
description="The URL for the Fireworks server",
)
api_key: SecretStr | None = Field(
default=None,
description="The Fireworks.ai API Key",
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.FIREWORKS_API_KEY:=}", **kwargs) -> dict[str, Any]:

View file

@ -4,434 +4,24 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import AsyncGenerator, AsyncIterator
from typing import Any
from fireworks.client import Fireworks
from openai import AsyncOpenAI
from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
CompletionRequest,
CompletionResponse,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIEmbeddingsResponse,
OpenAIMessageParam,
OpenAIResponseFormatParam,
ResponseFormat,
ResponseFormatType,
SamplingParams,
TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionToLlamaStackMixin,
convert_message_to_openai_dict,
get_sampling_options,
prepare_openai_completion_params,
process_chat_completion_response,
process_chat_completion_stream_response,
process_completion_response,
process_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
completion_request_to_prompt,
content_has_media,
interleaved_content_as_str,
request_has_media,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import FireworksImplConfig
from .models import MODEL_ENTRIES
logger = get_logger(name=__name__, category="inference::fireworks")
class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
def __init__(self, config: FireworksImplConfig) -> None:
ModelRegistryHelper.__init__(self, MODEL_ENTRIES, config.allowed_models)
self.config = config
class FireworksInferenceAdapter(OpenAIMixin):
config: FireworksImplConfig
async def initialize(self) -> None:
pass
embedding_model_metadata: dict[str, dict[str, int]] = {
"nomic-ai/nomic-embed-text-v1.5": {"embedding_dimension": 768, "context_length": 8192},
"accounts/fireworks/models/qwen3-embedding-8b": {"embedding_dimension": 4096, "context_length": 40960},
}
async def shutdown(self) -> None:
pass
provider_data_api_key_field: str = "fireworks_api_key"
def _get_api_key(self) -> str:
config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
if config_api_key:
return config_api_key
else:
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.fireworks_api_key:
raise ValueError(
'Pass Fireworks API Key in the header X-LlamaStack-Provider-Data as { "fireworks_api_key": <your api key>}'
)
return provider_data.fireworks_api_key
def _get_base_url(self) -> str:
def get_base_url(self) -> str:
return "https://api.fireworks.ai/inference/v1"
def _get_client(self) -> Fireworks:
fireworks_api_key = self._get_api_key()
return Fireworks(api_key=fireworks_api_key)
def _get_openai_client(self) -> AsyncOpenAI:
return AsyncOpenAI(base_url=self._get_base_url(), api_key=self._get_api_key())
async def completion(
self,
model_id: str,
content: InterleavedContent,
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
request = CompletionRequest(
model=model.provider_resource_id,
content=content,
sampling_params=sampling_params,
response_format=response_format,
stream=stream,
logprobs=logprobs,
)
if stream:
return self._stream_completion(request)
else:
return await self._nonstream_completion(request)
async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
params = await self._get_params(request)
r = await self._get_client().completion.acreate(**params)
return process_completion_response(r)
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
# Wrap the synchronous streaming call from the Fireworks client in an async generator
async def _to_async_generator():
stream = self._get_client().completion.create(**params)
for chunk in stream:
yield chunk
stream = _to_async_generator()
async for chunk in process_completion_stream_response(stream):
yield chunk
def _build_options(
self,
sampling_params: SamplingParams | None,
fmt: ResponseFormat,
logprobs: LogProbConfig | None,
) -> dict:
options = get_sampling_options(sampling_params)
options.setdefault("max_tokens", 512)
if fmt:
if fmt.type == ResponseFormatType.json_schema.value:
options["response_format"] = {
"type": "json_object",
"schema": fmt.json_schema,
}
elif fmt.type == ResponseFormatType.grammar.value:
options["response_format"] = {
"type": "grammar",
"grammar": fmt.bnf,
}
else:
raise ValueError(f"Unknown response format {fmt.type}")
if logprobs and logprobs.top_k:
options["logprobs"] = logprobs.top_k
if options["logprobs"] <= 0 or options["logprobs"] >= 5:
raise ValueError("Required range: 0 < top_k < 5")
return options
async def chat_completion(
self,
model_id: str,
messages: list[Message],
sampling_params: SamplingParams | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
model=model.provider_resource_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
response_format=response_format,
stream=stream,
logprobs=logprobs,
tool_config=tool_config,
)
if stream:
return self._stream_chat_completion(request)
else:
return await self._nonstream_chat_completion(request)
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
params = await self._get_params(request)
if "messages" in params:
r = await self._get_client().chat.completions.acreate(**params)
else:
r = await self._get_client().completion.acreate(**params)
return process_chat_completion_response(r, request)
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
async def _to_async_generator():
if "messages" in params:
stream = self._get_client().chat.completions.acreate(**params)
else:
stream = self._get_client().completion.acreate(**params)
async for chunk in stream:
yield chunk
stream = _to_async_generator()
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
async def _get_params(self, request: ChatCompletionRequest | CompletionRequest) -> dict:
input_dict = {}
media_present = request_has_media(request)
llama_model = self.get_llama_model(request.model)
if isinstance(request, ChatCompletionRequest):
# TODO: tools are never added to the request, so we need to add them here
if media_present or not llama_model:
input_dict["messages"] = [
await convert_message_to_openai_dict(m, download=True) for m in request.messages
]
else:
input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
else:
assert not media_present, "Fireworks does not support media for Completion requests"
input_dict["prompt"] = await completion_request_to_prompt(request)
# Fireworks always prepends with BOS
if "prompt" in input_dict:
if input_dict["prompt"].startswith("<|begin_of_text|>"):
input_dict["prompt"] = input_dict["prompt"][len("<|begin_of_text|>") :]
params = {
"model": request.model,
**input_dict,
"stream": bool(request.stream),
**self._build_options(request.sampling_params, request.response_format, request.logprobs),
}
logger.debug(f"params to fireworks: {params}")
return params
async def embeddings(
self,
model_id: str,
contents: list[str] | list[InterleavedContentItem],
text_truncation: TextTruncation | None = TextTruncation.none,
output_dimension: int | None = None,
task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)
kwargs = {}
if model.metadata.get("embedding_dimension"):
kwargs["dimensions"] = model.metadata.get("embedding_dimension")
assert all(not content_has_media(content) for content in contents), (
"Fireworks does not support media for embeddings"
)
response = self._get_client().embeddings.create(
model=model.provider_resource_id,
input=[interleaved_content_as_str(content) for content in contents],
**kwargs,
)
embeddings = [data.embedding for data in response.data]
return EmbeddingsResponse(embeddings=embeddings)
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
raise NotImplementedError()
async def openai_completion(
self,
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
) -> OpenAICompletion:
model_obj = await self.model_store.get_model(model)
# Fireworks always prepends with BOS
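# so strip a leading BOS token from the caller's prompt to avoid doubling it,
# e.g. "<|begin_of_text|>Hello" -> "Hello"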
if isinstance(prompt, str) and prompt.startswith("<|begin_of_text|>"):
prompt = prompt[len("<|begin_of_text|>") :]
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
prompt=prompt,
best_of=best_of,
echo=echo,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
top_p=top_p,
user=user,
)
return await self._get_openai_client().completions.create(**params)
async def openai_chat_completion(
self,
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
model_obj = await self.model_store.get_model(model)
# Divert Llama models through the Llama Stack inference APIs because the
# Fireworks OpenAI-compatible chat completions API does not support
# tool calls properly.
llama_model = self.get_llama_model(model_obj.provider_resource_id)
if llama_model:
return await OpenAIChatCompletionToLlamaStackMixin.openai_chat_completion(
self,
model=model,
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
params = await prepare_openai_completion_params(
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
logger.debug(f"fireworks params: {params}")
return await self._get_openai_client().chat.completions.create(model=model_obj.provider_resource_id, **params)
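For context, a minimal sketch of how a caller could supply the per-request key consumed through `provider_data_api_key_field = "fireworks_api_key"` above. The server URL and endpoint path are illustrative assumptions, not part of this change; only the header name and key field come from the adapter.

```python
# Hypothetical client-side call; URL and path are assumptions for illustration.
import json

import requests

provider_data = {"fireworks_api_key": "fw-xxxxxxxx"}  # per-request key, not stored server-side
resp = requests.post(
    "http://localhost:8321/v1/openai/v1/chat/completions",  # assumed Llama Stack server endpoint
    headers={"X-LlamaStack-Provider-Data": json.dumps(provider_data)},
    json={
        "model": "accounts/fireworks/models/llama-v3p1-8b-instruct",
        "messages": [{"role": "user", "content": "Say hello."}],
    },
)
print(resp.json())
```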

View file

@ -1,70 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.models import ModelType
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry,
build_hf_repo_model_entry,
)
SAFETY_MODELS_ENTRIES = [
build_hf_repo_model_entry(
"accounts/fireworks/models/llama-guard-3-8b",
CoreModelId.llama_guard_3_8b.value,
),
build_hf_repo_model_entry(
"accounts/fireworks/models/llama-guard-3-11b-vision",
CoreModelId.llama_guard_3_11b_vision.value,
),
]
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"accounts/fireworks/models/llama-v3p1-8b-instruct",
CoreModelId.llama3_1_8b_instruct.value,
),
build_hf_repo_model_entry(
"accounts/fireworks/models/llama-v3p1-70b-instruct",
CoreModelId.llama3_1_70b_instruct.value,
),
build_hf_repo_model_entry(
"accounts/fireworks/models/llama-v3p1-405b-instruct",
CoreModelId.llama3_1_405b_instruct.value,
),
build_hf_repo_model_entry(
"accounts/fireworks/models/llama-v3p2-3b-instruct",
CoreModelId.llama3_2_3b_instruct.value,
),
build_hf_repo_model_entry(
"accounts/fireworks/models/llama-v3p2-11b-vision-instruct",
CoreModelId.llama3_2_11b_vision_instruct.value,
),
build_hf_repo_model_entry(
"accounts/fireworks/models/llama-v3p2-90b-vision-instruct",
CoreModelId.llama3_2_90b_vision_instruct.value,
),
build_hf_repo_model_entry(
"accounts/fireworks/models/llama-v3p3-70b-instruct",
CoreModelId.llama3_3_70b_instruct.value,
),
build_hf_repo_model_entry(
"accounts/fireworks/models/llama4-scout-instruct-basic",
CoreModelId.llama4_scout_17b_16e_instruct.value,
),
build_hf_repo_model_entry(
"accounts/fireworks/models/llama4-maverick-instruct-basic",
CoreModelId.llama4_maverick_17b_128e_instruct.value,
),
ProviderModelEntry(
provider_model_id="nomic-ai/nomic-embed-text-v1.5",
model_type=ModelType.embedding,
metadata={
"embedding_dimension": 768,
"context_length": 8192,
},
),
] + SAFETY_MODELS_ENTRIES

View file

@ -4,18 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
from .config import GeminiConfig
class GeminiProviderDataValidator(BaseModel):
gemini_api_key: str | None = None
async def get_adapter_impl(config: GeminiConfig, _deps):
from .gemini import GeminiInferenceAdapter
impl = GeminiInferenceAdapter(config)
impl = GeminiInferenceAdapter(config=config)
await impl.initialize()
return impl

View file

@ -8,6 +8,7 @@ from typing import Any
from pydantic import BaseModel, Field
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@ -19,12 +20,7 @@ class GeminiProviderDataValidator(BaseModel):
@json_schema_type
class GeminiConfig(BaseModel):
api_key: str | None = Field(
default=None,
description="API key for Gemini models",
)
class GeminiConfig(RemoteInferenceProviderConfig):
@classmethod
def sample_run_config(cls, api_key: str = "${env.GEMINI_API_KEY:=}", **kwargs) -> dict[str, Any]:
return {

View file

@ -4,31 +4,18 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import GeminiConfig
from .models import MODEL_ENTRIES
class GeminiInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
def __init__(self, config: GeminiConfig) -> None:
LiteLLMOpenAIMixin.__init__(
self,
MODEL_ENTRIES,
litellm_provider_name="gemini",
api_key_from_config=config.api_key,
provider_data_api_key_field="gemini_api_key",
)
self.config = config
class GeminiInferenceAdapter(OpenAIMixin):
config: GeminiConfig
get_api_key = LiteLLMOpenAIMixin.get_api_key
provider_data_api_key_field: str = "gemini_api_key"
embedding_model_metadata: dict[str, dict[str, int]] = {
"text-embedding-004": {"embedding_dimension": 768, "context_length": 2048},
}
def get_base_url(self):
return "https://generativelanguage.googleapis.com/v1beta/openai/"
async def initialize(self) -> None:
await super().initialize()
async def shutdown(self) -> None:
await super().shutdown()
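A minimal sketch of what the OpenAIMixin wiring above amounts to: an OpenAI-compatible client pointed at Gemini's OpenAI-compatible endpoint returned by `get_base_url()`. The environment variable and model id are assumptions for illustration; the adapter itself resolves the key from config or provider data.

```python
# Illustrative only: roughly the client configuration OpenAIMixin builds for this adapter.
import os

from openai import OpenAI

client = OpenAI(
    base_url="https://generativelanguage.googleapis.com/v1beta/openai/",  # get_base_url() above
    api_key=os.environ["GEMINI_API_KEY"],  # assumed env var for the sketch
)
resp = client.chat.completions.create(
    model="gemini-2.0-flash",
    messages=[{"role": "user", "content": "Hello"}],
)
print(resp.choices[0].message.content)
```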

View file

@ -1,34 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.models import ModelType
from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry,
)
LLM_MODEL_IDS = [
"gemini-1.5-flash",
"gemini-1.5-pro",
"gemini-2.0-flash",
"gemini-2.0-flash-lite",
"gemini-2.5-flash",
"gemini-2.5-flash-lite",
"gemini-2.5-pro",
]
SAFETY_MODELS_ENTRIES = []
MODEL_ENTRIES = (
[ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS]
+ [
ProviderModelEntry(
provider_model_id="text-embedding-004",
model_type=ModelType.embedding,
metadata={"embedding_dimension": 768, "context_length": 2048},
),
]
+ SAFETY_MODELS_ENTRIES
)

View file

@ -4,14 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import Inference
from .config import GroqConfig
async def get_adapter_impl(config: GroqConfig, _deps) -> Inference:
async def get_adapter_impl(config: GroqConfig, _deps):
# import dynamically so the import is used only when it is needed
from .groq import GroqInferenceAdapter
adapter = GroqInferenceAdapter(config)
adapter = GroqInferenceAdapter(config=config)
return adapter

View file

@ -8,6 +8,7 @@ from typing import Any
from pydantic import BaseModel, Field
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@ -19,13 +20,7 @@ class GroqProviderDataValidator(BaseModel):
@json_schema_type
class GroqConfig(BaseModel):
api_key: str | None = Field(
# The Groq client library loads the GROQ_API_KEY environment variable by default
default=None,
description="The Groq API key",
)
class GroqConfig(RemoteInferenceProviderConfig):
url: str = Field(
default="https://api.groq.com",
description="The URL for the Groq AI server",

View file

@ -6,33 +6,13 @@
from llama_stack.providers.remote.inference.groq.config import GroqConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .models import MODEL_ENTRIES
class GroqInferenceAdapter(OpenAIMixin):
config: GroqConfig
class GroqInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
_config: GroqConfig
def __init__(self, config: GroqConfig):
LiteLLMOpenAIMixin.__init__(
self,
model_entries=MODEL_ENTRIES,
litellm_provider_name="groq",
api_key_from_config=config.api_key,
provider_data_api_key_field="groq_api_key",
)
self.config = config
# Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin
get_api_key = LiteLLMOpenAIMixin.get_api_key
provider_data_api_key_field: str = "groq_api_key"
def get_base_url(self) -> str:
return f"{self.config.url}/openai/v1"
async def initialize(self):
await super().initialize()
async def shutdown(self):
await super().shutdown()

View file

@ -1,48 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.models.llama.sku_list import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry,
build_model_entry,
)
SAFETY_MODELS_ENTRIES = []
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"llama3-8b-8192",
CoreModelId.llama3_1_8b_instruct.value,
),
build_model_entry(
"llama-3.1-8b-instant",
CoreModelId.llama3_1_8b_instruct.value,
),
build_hf_repo_model_entry(
"llama3-70b-8192",
CoreModelId.llama3_70b_instruct.value,
),
build_hf_repo_model_entry(
"llama-3.3-70b-versatile",
CoreModelId.llama3_3_70b_instruct.value,
),
# Groq only contains a preview version for llama-3.2-3b
# Preview models aren't recommended for production use, but we include this one
# to pass the test fixture
# TODO(aidand): Replace this with a stable model once Groq supports it
build_hf_repo_model_entry(
"llama-3.2-3b-preview",
CoreModelId.llama3_2_3b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/llama-4-scout-17b-16e-instruct",
CoreModelId.llama4_scout_17b_16e_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/llama-4-maverick-17b-128e-instruct",
CoreModelId.llama4_maverick_17b_128e_instruct.value,
),
] + SAFETY_MODELS_ENTRIES

View file

@ -4,14 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import InferenceProvider
from .config import LlamaCompatConfig
async def get_adapter_impl(config: LlamaCompatConfig, _deps) -> InferenceProvider:
async def get_adapter_impl(config: LlamaCompatConfig, _deps):
# import dynamically so the import is used only when it is needed
from .llama import LlamaCompatInferenceAdapter
adapter = LlamaCompatInferenceAdapter(config)
adapter = LlamaCompatInferenceAdapter(config=config)
return adapter

View file

@ -8,6 +8,7 @@ from typing import Any
from pydantic import BaseModel, Field
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@ -19,12 +20,7 @@ class LlamaProviderDataValidator(BaseModel):
@json_schema_type
class LlamaCompatConfig(BaseModel):
api_key: str | None = Field(
default=None,
description="The Llama API key",
)
class LlamaCompatConfig(RemoteInferenceProviderConfig):
openai_compat_api_base: str = Field(
default="https://api.llama.com/compat/v1/",
description="The URL for the Llama API server",

View file

@ -3,44 +3,27 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference.inference import (
OpenAICompletion,
OpenAICompletionRequestWithExtraBody,
OpenAIEmbeddingsResponse,
)
from llama_stack.log import get_logger
from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .models import MODEL_ENTRIES
logger = get_logger(name=__name__, category="inference::llama_openai_compat")
class LlamaCompatInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
class LlamaCompatInferenceAdapter(OpenAIMixin):
config: LlamaCompatConfig
provider_data_api_key_field: str = "llama_api_key"
"""
Llama API Inference Adapter for Llama Stack.
Note: The inheritance order is important here. OpenAIMixin must come before
LiteLLMOpenAIMixin to ensure that OpenAIMixin.check_model_availability()
is used instead of ModelRegistryHelper.check_model_availability().
- OpenAIMixin.check_model_availability() queries the Llama API to check if a model exists
- ModelRegistryHelper.check_model_availability() (inherited by LiteLLMOpenAIMixin) just returns False and shows a warning
"""
_config: LlamaCompatConfig
def __init__(self, config: LlamaCompatConfig):
LiteLLMOpenAIMixin.__init__(
self,
model_entries=MODEL_ENTRIES,
litellm_provider_name="meta_llama",
api_key_from_config=config.api_key,
provider_data_api_key_field="llama_api_key",
openai_compat_api_base=config.openai_compat_api_base,
)
self.config = config
# Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin
get_api_key = LiteLLMOpenAIMixin.get_api_key
def get_base_url(self) -> str:
"""
Get the base URL for OpenAI mixin.
@ -49,8 +32,18 @@ class LlamaCompatInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
"""
return self.config.openai_compat_api_base
async def initialize(self):
await super().initialize()
async def openai_completion(
self,
params: OpenAICompletionRequestWithExtraBody,
) -> OpenAICompletion:
raise NotImplementedError()
async def shutdown(self):
await super().shutdown()
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
raise NotImplementedError()

View file

@ -1,25 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry,
)
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"Llama-3.3-70B-Instruct",
CoreModelId.llama3_3_70b_instruct.value,
),
build_hf_repo_model_entry(
"Llama-4-Scout-17B-16E-Instruct-FP8",
CoreModelId.llama4_scout_17b_16e_instruct.value,
),
build_hf_repo_model_entry(
"Llama-4-Maverick-17B-128E-Instruct-FP8",
CoreModelId.llama4_maverick_17b_128e_instruct.value,
),
]

View file

@ -39,32 +39,13 @@ client = LlamaStackAsLibraryClient("nvidia")
client.initialize()
```
### Create Completion
The following example shows how to create a completion for an NVIDIA NIM.
> [!NOTE]
> The hosted NVIDIA Llama NIMs (for example ```meta-llama/Llama-3.1-8B-Instruct```) that have ```NVIDIA_BASE_URL="https://integrate.api.nvidia.com"``` do not support the ```completion``` method, while locally deployed NIMs do.
```python
response = client.inference.completion(
model_id="meta-llama/Llama-3.1-8B-Instruct",
content="Complete the sentence using one word: Roses are red, violets are :",
stream=False,
sampling_params={
"max_tokens": 50,
},
)
print(f"Response: {response.content}")
```
### Create Chat Completion
The following example shows how to create a chat completion for an NVIDIA NIM.
```python
response = client.inference.chat_completion(
model_id="meta-llama/Llama-3.1-8B-Instruct",
response = client.chat.completions.create(
model="meta-llama/Llama-3.1-8B-Instruct",
messages=[
{
"role": "system",
@ -76,11 +57,9 @@ response = client.inference.chat_completion(
},
],
stream=False,
sampling_params={
"max_tokens": 50,
},
max_tokens=50,
)
print(f"Response: {response.completion_message.content}")
print(f"Response: {response.choices[0].message.content}")
```
### Tool Calling Example ###
@ -108,15 +87,15 @@ tool_definition = ToolDefinition(
},
)
tool_response = client.inference.chat_completion(
model_id="meta-llama/Llama-3.1-8B-Instruct",
tool_response = client.chat.completions.create(
model="meta-llama/Llama-3.1-8B-Instruct",
messages=[{"role": "user", "content": "What's the weather like in San Francisco?"}],
tools=[tool_definition],
)
print(f"Tool Response: {tool_response.completion_message.content}")
if tool_response.completion_message.tool_calls:
for tool_call in tool_response.completion_message.tool_calls:
print(f"Tool Response: {tool_response.choices[0].message.content}")
if tool_response.choices[0].message.tool_calls:
for tool_call in tool_response.choices[0].message.tool_calls:
print(f"Tool Called: {tool_call.tool_name}")
print(f"Arguments: {tool_call.arguments}")
```
@ -142,8 +121,8 @@ response_format = JsonSchemaResponseFormat(
type=ResponseFormatType.json_schema, json_schema=person_schema
)
structured_response = client.inference.chat_completion(
model_id="meta-llama/Llama-3.1-8B-Instruct",
structured_response = client.chat.completions.create(
model="meta-llama/Llama-3.1-8B-Instruct",
messages=[
{
"role": "user",
@ -153,7 +132,7 @@ structured_response = client.inference.chat_completion(
response_format=response_format,
)
print(f"Structured Response: {structured_response.completion_message.content}")
print(f"Structured Response: {structured_response.choices[0].message.content}")
```
### Create Embeddings
@ -186,8 +165,8 @@ def load_image_as_base64(image_path):
image_path = {path_to_the_image}
demo_image_b64 = load_image_as_base64(image_path)
vlm_response = client.inference.chat_completion(
model_id="nvidia/vila",
vlm_response = client.chat.completions.create(
model="nvidia/vila",
messages=[
{
"role": "user",
@ -207,5 +186,5 @@ vlm_response = client.inference.chat_completion(
],
)
print(f"VLM Response: {vlm_response.completion_message.content}")
print(f"VLM Response: {vlm_response.choices[0].message.content}")
```

View file

@ -15,7 +15,8 @@ async def get_adapter_impl(config: NVIDIAConfig, _deps) -> Inference:
if not isinstance(config, NVIDIAConfig):
raise RuntimeError(f"Unexpected config type: {type(config)}")
adapter = NVIDIAInferenceAdapter(config)
adapter = NVIDIAInferenceAdapter(config=config)
await adapter.initialize()
return adapter

View file

@ -7,13 +7,14 @@
import os
from typing import Any
from pydantic import BaseModel, Field, SecretStr
from pydantic import Field
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class NVIDIAConfig(BaseModel):
class NVIDIAConfig(RemoteInferenceProviderConfig):
"""
Configuration for the NVIDIA NIM inference endpoint.
@ -39,10 +40,6 @@ class NVIDIAConfig(BaseModel):
default_factory=lambda: os.getenv("NVIDIA_BASE_URL", "https://integrate.api.nvidia.com"),
description="A base url for accessing the NVIDIA NIM",
)
api_key: SecretStr | None = Field(
default_factory=lambda: SecretStr(os.getenv("NVIDIA_API_KEY")),
description="The NVIDIA API key, only needed of using the hosted service",
)
timeout: int = Field(
default=60,
description="Timeout for the HTTP requests",

View file

@ -1,109 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.models import ModelType
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry,
build_hf_repo_model_entry,
)
SAFETY_MODELS_ENTRIES = []
# https://docs.nvidia.com/nim/large-language-models/latest/supported-llm-agnostic-architectures.html
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"meta/llama3-8b-instruct",
CoreModelId.llama3_8b_instruct.value,
),
build_hf_repo_model_entry(
"meta/llama3-70b-instruct",
CoreModelId.llama3_70b_instruct.value,
),
build_hf_repo_model_entry(
"meta/llama-3.1-8b-instruct",
CoreModelId.llama3_1_8b_instruct.value,
),
build_hf_repo_model_entry(
"meta/llama-3.1-70b-instruct",
CoreModelId.llama3_1_70b_instruct.value,
),
build_hf_repo_model_entry(
"meta/llama-3.1-405b-instruct",
CoreModelId.llama3_1_405b_instruct.value,
),
build_hf_repo_model_entry(
"meta/llama-3.2-1b-instruct",
CoreModelId.llama3_2_1b_instruct.value,
),
build_hf_repo_model_entry(
"meta/llama-3.2-3b-instruct",
CoreModelId.llama3_2_3b_instruct.value,
),
build_hf_repo_model_entry(
"meta/llama-3.2-11b-vision-instruct",
CoreModelId.llama3_2_11b_vision_instruct.value,
),
build_hf_repo_model_entry(
"meta/llama-3.2-90b-vision-instruct",
CoreModelId.llama3_2_90b_vision_instruct.value,
),
build_hf_repo_model_entry(
"meta/llama-3.3-70b-instruct",
CoreModelId.llama3_3_70b_instruct.value,
),
ProviderModelEntry(
provider_model_id="nvidia/vila",
model_type=ModelType.llm,
),
# NeMo Retriever Text Embedding models -
#
# https://docs.nvidia.com/nim/nemo-retriever/text-embedding/latest/support-matrix.html
#
# +-----------------------------------+--------+-----------+-----------+------------+
# | Model ID | Max | Publisher | Embedding | Dynamic |
# | | Tokens | | Dimension | Embeddings |
# +-----------------------------------+--------+-----------+-----------+------------+
# | nvidia/llama-3.2-nv-embedqa-1b-v2 | 8192 | NVIDIA | 2048 | Yes |
# | nvidia/nv-embedqa-e5-v5 | 512 | NVIDIA | 1024 | No |
# | nvidia/nv-embedqa-mistral-7b-v2 | 512 | NVIDIA | 4096 | No |
# | snowflake/arctic-embed-l | 512 | Snowflake | 1024 | No |
# +-----------------------------------+--------+-----------+-----------+------------+
ProviderModelEntry(
provider_model_id="nvidia/llama-3.2-nv-embedqa-1b-v2",
model_type=ModelType.embedding,
metadata={
"embedding_dimension": 2048,
"context_length": 8192,
},
),
ProviderModelEntry(
provider_model_id="nvidia/nv-embedqa-e5-v5",
model_type=ModelType.embedding,
metadata={
"embedding_dimension": 1024,
"context_length": 512,
},
),
ProviderModelEntry(
provider_model_id="nvidia/nv-embedqa-mistral-7b-v2",
model_type=ModelType.embedding,
metadata={
"embedding_dimension": 4096,
"context_length": 512,
},
),
ProviderModelEntry(
provider_model_id="snowflake/arctic-embed-l",
model_type=ModelType.embedding,
metadata={
"embedding_dimension": 1024,
"context_length": 512,
},
),
# TODO(mf): how do we handle Nemotron models?
# "Llama3.1-Nemotron-51B-Instruct" -> "meta/llama-3.1-nemotron-51b-instruct",
] + SAFETY_MODELS_ENTRIES

View file

@ -4,63 +4,26 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import warnings
from collections.abc import AsyncIterator
from openai import NOT_GIVEN, APIConnectionError
from openai import NOT_GIVEN
from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
TextContentItem,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseStreamChunk,
CompletionRequest,
CompletionResponse,
CompletionResponseStreamChunk,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
OpenAIEmbeddingData,
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
ResponseFormat,
SamplingParams,
TextTruncation,
ToolChoice,
ToolConfig,
)
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import ToolDefinition, ToolPromptFormat
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
convert_openai_chat_completion_choice,
convert_openai_chat_completion_stream,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack.providers.utils.inference.prompt_adapter import content_has_media
from . import NVIDIAConfig
from .models import MODEL_ENTRIES
from .openai_utils import (
convert_chat_completion_request,
convert_completion_request,
convert_openai_completion_choice,
convert_openai_completion_stream,
)
from .utils import _is_nvidia_hosted
logger = get_logger(name=__name__, category="inference::nvidia")
class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
class NVIDIAInferenceAdapter(OpenAIMixin):
config: NVIDIAConfig
"""
NVIDIA Inference Adapter for Llama Stack.
@ -74,28 +37,22 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
- ModelRegistryHelper.check_model_availability() just returns False and shows a warning
"""
def __init__(self, config: NVIDIAConfig) -> None:
# TODO(mf): filter by available models
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
# source: https://docs.nvidia.com/nim/nemo-retriever/text-embedding/latest/support-matrix.html
embedding_model_metadata: dict[str, dict[str, int]] = {
"nvidia/llama-3.2-nv-embedqa-1b-v2": {"embedding_dimension": 2048, "context_length": 8192},
"nvidia/nv-embedqa-e5-v5": {"embedding_dimension": 512, "context_length": 1024},
"nvidia/nv-embedqa-mistral-7b-v2": {"embedding_dimension": 512, "context_length": 4096},
"snowflake/arctic-embed-l": {"embedding_dimension": 512, "context_length": 1024},
}
logger.info(f"Initializing NVIDIAInferenceAdapter({config.url})...")
async def initialize(self) -> None:
logger.info(f"Initializing NVIDIAInferenceAdapter({self.config.url})...")
if _is_nvidia_hosted(config):
if not config.api_key:
if _is_nvidia_hosted(self.config):
if not self.config.auth_credential:
raise RuntimeError(
"API key is required for hosted NVIDIA NIM. Either provide an API key or use a self-hosted NIM."
)
# elif self._config.api_key:
#
# we don't raise this warning because a user may have deployed their
# self-hosted NIM with an API key requirement.
#
# warnings.warn(
# "API key is not required for self-hosted NVIDIA NIM. "
# "Consider removing the api_key from the configuration."
# )
self._config = config
def get_api_key(self) -> str:
"""
@ -103,7 +60,13 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
:return: The NVIDIA API key
"""
return self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"
if self.config.auth_credential:
return self.config.auth_credential.get_secret_value()
if not _is_nvidia_hosted(self.config):
return "NO KEY REQUIRED"
return None
def get_base_url(self) -> str:
"""
@ -111,103 +74,7 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
:return: The NVIDIA API base URL
"""
return f"{self._config.url}/v1" if self._config.append_api_version else self._config.url
async def completion(
self,
model_id: str,
content: InterleavedContent,
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
) -> CompletionResponse | AsyncIterator[CompletionResponseStreamChunk]:
if sampling_params is None:
sampling_params = SamplingParams()
if content_has_media(content):
raise NotImplementedError("Media is not supported")
# ToDo: check health of NeMo endpoints and enable this
# removing this health check as NeMo customizer endpoint health check is returning 404
# await check_health(self._config) # this raises errors
provider_model_id = await self._get_provider_model_id(model_id)
request = convert_completion_request(
request=CompletionRequest(
model=provider_model_id,
content=content,
sampling_params=sampling_params,
response_format=response_format,
stream=stream,
logprobs=logprobs,
),
n=1,
)
try:
response = await self.client.completions.create(**request)
except APIConnectionError as e:
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
if stream:
return convert_openai_completion_stream(response)
else:
# we pass n=1 to get only one completion
return convert_openai_completion_choice(response.choices[0])
async def embeddings(
self,
model_id: str,
contents: list[str] | list[InterleavedContentItem],
text_truncation: TextTruncation | None = TextTruncation.none,
output_dimension: int | None = None,
task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
if any(content_has_media(content) for content in contents):
raise NotImplementedError("Media is not supported")
#
# Llama Stack: contents = list[str] | list[InterleavedContentItem]
# ->
# OpenAI: input = str | list[str]
#
# we can ignore str and always pass list[str] to OpenAI
#
flat_contents = [content.text if isinstance(content, TextContentItem) else content for content in contents]
input = [content.text if isinstance(content, TextContentItem) else content for content in flat_contents]
provider_model_id = await self._get_provider_model_id(model_id)
extra_body = {}
if text_truncation is not None:
text_truncation_options = {
TextTruncation.none: "NONE",
TextTruncation.end: "END",
TextTruncation.start: "START",
}
extra_body["truncate"] = text_truncation_options[text_truncation]
if output_dimension is not None:
extra_body["dimensions"] = output_dimension
if task_type is not None:
task_type_options = {
EmbeddingTaskType.document: "passage",
EmbeddingTaskType.query: "query",
}
extra_body["input_type"] = task_type_options[task_type]
response = await self.client.embeddings.create(
model=provider_model_id,
input=input,
extra_body=extra_body,
)
#
# OpenAI: CreateEmbeddingResponse(data=[Embedding(embedding=list[float], ...)], ...)
# ->
# Llama Stack: EmbeddingsResponse(embeddings=list[list[float]])
#
return EmbeddingsResponse(embeddings=[embedding.embedding for embedding in response.data])
return f"{self.config.url}/v1" if self.config.append_api_version else self.config.url
async def openai_embeddings(
self,
@ -259,49 +126,3 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
model=response.model,
usage=usage,
)
async def chat_completion(
self,
model_id: str,
messages: list[Message],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
if sampling_params is None:
sampling_params = SamplingParams()
if tool_prompt_format:
warnings.warn("tool_prompt_format is not supported by NVIDIA NIM, ignoring", stacklevel=2)
# await check_health(self._config) # this raises errors
provider_model_id = await self._get_provider_model_id(model_id)
request = await convert_chat_completion_request(
request=ChatCompletionRequest(
model=provider_model_id,
messages=messages,
sampling_params=sampling_params,
response_format=response_format,
tools=tools,
stream=stream,
logprobs=logprobs,
tool_config=tool_config,
),
n=1,
)
try:
response = await self.client.chat.completions.create(**request)
except APIConnectionError as e:
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
if stream:
return convert_openai_chat_completion_stream(response, enable_incremental_tool_calls=False)
else:
# we pass n=1 to get only one completion
return convert_openai_chat_completion_choice(response.choices[0])

View file

@ -1,217 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import warnings
from collections.abc import AsyncGenerator
from typing import Any
from openai import AsyncStream
from openai.types.chat.chat_completion import (
Choice as OpenAIChoice,
)
from openai.types.completion import Completion as OpenAICompletion
from openai.types.completion_choice import Logprobs as OpenAICompletionLogprobs
from llama_stack.apis.inference import (
ChatCompletionRequest,
CompletionRequest,
CompletionResponse,
CompletionResponseStreamChunk,
GreedySamplingStrategy,
JsonSchemaResponseFormat,
TokenLogProbs,
TopKSamplingStrategy,
TopPSamplingStrategy,
)
from llama_stack.providers.utils.inference.openai_compat import (
_convert_openai_finish_reason,
convert_message_to_openai_dict_new,
convert_tooldef_to_openai_tool,
)
async def convert_chat_completion_request(
request: ChatCompletionRequest,
n: int = 1,
) -> dict:
"""
Convert a ChatCompletionRequest to an OpenAI API-compatible dictionary.
"""
# model -> model
# messages -> messages
# sampling_params TODO(mattf): review strategy
# strategy=greedy -> nvext.top_k = -1, temperature = temperature
# strategy=top_p -> nvext.top_k = -1, top_p = top_p
# strategy=top_k -> nvext.top_k = top_k
# temperature -> temperature
# top_p -> top_p
# top_k -> nvext.top_k
# max_tokens -> max_tokens
# repetition_penalty -> nvext.repetition_penalty
# response_format -> GrammarResponseFormat TODO(mf)
# response_format -> JsonSchemaResponseFormat: response_format = "json_object" & nvext["guided_json"] = json_schema
# tools -> tools
# tool_choice ("auto", "required") -> tool_choice
# tool_prompt_format -> TBD
# stream -> stream
# logprobs -> logprobs
if request.response_format and not isinstance(request.response_format, JsonSchemaResponseFormat):
raise ValueError(
f"Unsupported response format: {request.response_format}. Only JsonSchemaResponseFormat is supported."
)
nvext = {}
payload: dict[str, Any] = dict(
model=request.model,
messages=[await convert_message_to_openai_dict_new(message) for message in request.messages],
stream=request.stream,
n=n,
extra_body=dict(nvext=nvext),
extra_headers={
b"User-Agent": b"llama-stack: nvidia-inference-adapter",
},
)
if request.response_format:
# server bug - setting guided_json changes the behavior of response_format resulting in an error
# payload.update(response_format="json_object")
nvext.update(guided_json=request.response_format.json_schema)
if request.tools:
payload.update(tools=[convert_tooldef_to_openai_tool(tool) for tool in request.tools])
if request.tool_config.tool_choice:
payload.update(
tool_choice=request.tool_config.tool_choice.value
) # we cannot include tool_choice w/o tools, server will complain
if request.logprobs:
payload.update(logprobs=True)
payload.update(top_logprobs=request.logprobs.top_k)
if request.sampling_params:
nvext.update(repetition_penalty=request.sampling_params.repetition_penalty)
if request.sampling_params.max_tokens:
payload.update(max_tokens=request.sampling_params.max_tokens)
strategy = request.sampling_params.strategy
if isinstance(strategy, TopPSamplingStrategy):
nvext.update(top_k=-1)
payload.update(top_p=strategy.top_p)
payload.update(temperature=strategy.temperature)
elif isinstance(strategy, TopKSamplingStrategy):
if strategy.top_k != -1 and strategy.top_k < 1:
warnings.warn("top_k must be -1 or >= 1", stacklevel=2)
nvext.update(top_k=strategy.top_k)
elif isinstance(strategy, GreedySamplingStrategy):
nvext.update(top_k=-1)
else:
raise ValueError(f"Unsupported sampling strategy: {strategy}")
return payload
def convert_completion_request(
request: CompletionRequest,
n: int = 1,
) -> dict:
"""
Convert a CompletionRequest to an OpenAI API-compatible dictionary.
"""
# model -> model
# prompt -> prompt
# sampling_params TODO(mattf): review strategy
# strategy=greedy -> nvext.top_k = -1, temperature = temperature
# strategy=top_p -> nvext.top_k = -1, top_p = top_p
# strategy=top_k -> nvext.top_k = top_k
# temperature -> temperature
# top_p -> top_p
# top_k -> nvext.top_k
# max_tokens -> max_tokens
# repetition_penalty -> nvext.repetition_penalty
# response_format -> nvext.guided_json
# stream -> stream
# logprobs.top_k -> logprobs
nvext = {}
payload: dict[str, Any] = dict(
model=request.model,
prompt=request.content,
stream=request.stream,
extra_body=dict(nvext=nvext),
extra_headers={
b"User-Agent": b"llama-stack: nvidia-inference-adapter",
},
n=n,
)
if request.response_format:
# this is not openai compliant, it is a nim extension
nvext.update(guided_json=request.response_format.json_schema)
if request.logprobs:
payload.update(logprobs=request.logprobs.top_k)
if request.sampling_params:
nvext.update(repetition_penalty=request.sampling_params.repetition_penalty)
if request.sampling_params.max_tokens:
payload.update(max_tokens=request.sampling_params.max_tokens)
if request.sampling_params.strategy == "top_p":
nvext.update(top_k=-1)
payload.update(top_p=request.sampling_params.top_p)
elif request.sampling_params.strategy == "top_k":
if request.sampling_params.top_k != -1 and request.sampling_params.top_k < 1:
warnings.warn("top_k must be -1 or >= 1", stacklevel=2)
nvext.update(top_k=request.sampling_params.top_k)
elif request.sampling_params.strategy == "greedy":
nvext.update(top_k=-1)
payload.update(temperature=request.sampling_params.temperature)
return payload
def _convert_openai_completion_logprobs(
logprobs: OpenAICompletionLogprobs | None,
) -> list[TokenLogProbs] | None:
"""
Convert an OpenAI CompletionLogprobs into a list of TokenLogProbs.
"""
if not logprobs:
return None
return [TokenLogProbs(logprobs_by_token=logprobs) for logprobs in logprobs.top_logprobs]
def convert_openai_completion_choice(
choice: OpenAIChoice,
) -> CompletionResponse:
"""
Convert an OpenAI Completion Choice into a CompletionResponse.
"""
return CompletionResponse(
content=choice.text,
stop_reason=_convert_openai_finish_reason(choice.finish_reason),
logprobs=_convert_openai_completion_logprobs(choice.logprobs),
)
async def convert_openai_completion_stream(
stream: AsyncStream[OpenAICompletion],
) -> AsyncGenerator[CompletionResponse, None]:
"""
Convert a stream of OpenAI Completions into a stream
of CompletionResponseStreamChunks.
"""
async for chunk in stream:
choice = chunk.choices[0]
yield CompletionResponseStreamChunk(
delta=choice.text,
stop_reason=_convert_openai_finish_reason(choice.finish_reason),
logprobs=_convert_openai_completion_logprobs(choice.logprobs),
)

View file

@ -4,53 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import httpx
from llama_stack.log import get_logger
from . import NVIDIAConfig
logger = get_logger(name=__name__, category="inference::nvidia")
def _is_nvidia_hosted(config: NVIDIAConfig) -> bool:
return "integrate.api.nvidia.com" in config.url
async def _get_health(url: str) -> tuple[bool, bool]:
"""
Query {url}/v1/health/{live,ready} to check if the server is running and ready
Args:
url (str): URL of the server
Returns:
Tuple[bool, bool]: (is_live, is_ready)
"""
async with httpx.AsyncClient() as client:
live = await client.get(f"{url}/v1/health/live")
ready = await client.get(f"{url}/v1/health/ready")
return live.status_code == 200, ready.status_code == 200
async def check_health(config: NVIDIAConfig) -> None:
"""
Check if the server is running and ready
Args:
config (NVIDIAConfig): configuration of the server
Raises:
RuntimeError: If the server is not running or ready
"""
if not _is_nvidia_hosted(config):
logger.info("Checking NVIDIA NIM health...")
try:
is_live, is_ready = await _get_health(config.url)
if not is_live:
raise ConnectionError("NVIDIA NIM is not running")
if not is_ready:
raise ConnectionError("NVIDIA NIM is not ready")
# TODO(mf): should we wait for the server to be ready?
except httpx.ConnectError as e:
raise ConnectionError(f"Failed to connect to NVIDIA NIM: {e}") from e

View file

@ -10,6 +10,6 @@ from .config import OllamaImplConfig
async def get_adapter_impl(config: OllamaImplConfig, _deps):
from .ollama import OllamaInferenceAdapter
impl = OllamaInferenceAdapter(config)
impl = OllamaInferenceAdapter(config=config)
await impl.initialize()
return impl

View file

@ -6,17 +6,17 @@
from typing import Any
from pydantic import BaseModel, Field
from pydantic import Field, SecretStr
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
DEFAULT_OLLAMA_URL = "http://localhost:11434"
class OllamaImplConfig(BaseModel):
class OllamaImplConfig(RemoteInferenceProviderConfig):
auth_credential: SecretStr | None = Field(default=None, exclude=True)
url: str = DEFAULT_OLLAMA_URL
refresh_models: bool = Field(
default=False,
description="Whether to refresh models periodically",
)
@classmethod
def sample_run_config(cls, url: str = "${env.OLLAMA_URL:=http://localhost:11434}", **kwargs) -> dict[str, Any]:

View file

@ -1,106 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.models import ModelType
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry,
build_hf_repo_model_entry,
build_model_entry,
)
SAFETY_MODELS_ENTRIES = [
# The Llama Guard models don't have their full fp16 versions
# so we are going to alias their default version to the canonical SKU
build_hf_repo_model_entry(
"llama-guard3:8b",
CoreModelId.llama_guard_3_8b.value,
),
build_hf_repo_model_entry(
"llama-guard3:1b",
CoreModelId.llama_guard_3_1b.value,
),
]
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"llama3.1:8b-instruct-fp16",
CoreModelId.llama3_1_8b_instruct.value,
),
build_model_entry(
"llama3.1:8b",
CoreModelId.llama3_1_8b_instruct.value,
),
build_hf_repo_model_entry(
"llama3.1:70b-instruct-fp16",
CoreModelId.llama3_1_70b_instruct.value,
),
build_model_entry(
"llama3.1:70b",
CoreModelId.llama3_1_70b_instruct.value,
),
build_hf_repo_model_entry(
"llama3.1:405b-instruct-fp16",
CoreModelId.llama3_1_405b_instruct.value,
),
build_model_entry(
"llama3.1:405b",
CoreModelId.llama3_1_405b_instruct.value,
),
build_hf_repo_model_entry(
"llama3.2:1b-instruct-fp16",
CoreModelId.llama3_2_1b_instruct.value,
),
build_model_entry(
"llama3.2:1b",
CoreModelId.llama3_2_1b_instruct.value,
),
build_hf_repo_model_entry(
"llama3.2:3b-instruct-fp16",
CoreModelId.llama3_2_3b_instruct.value,
),
build_model_entry(
"llama3.2:3b",
CoreModelId.llama3_2_3b_instruct.value,
),
build_hf_repo_model_entry(
"llama3.2-vision:11b-instruct-fp16",
CoreModelId.llama3_2_11b_vision_instruct.value,
),
build_model_entry(
"llama3.2-vision:latest",
CoreModelId.llama3_2_11b_vision_instruct.value,
),
build_hf_repo_model_entry(
"llama3.2-vision:90b-instruct-fp16",
CoreModelId.llama3_2_90b_vision_instruct.value,
),
build_model_entry(
"llama3.2-vision:90b",
CoreModelId.llama3_2_90b_vision_instruct.value,
),
build_hf_repo_model_entry(
"llama3.3:70b",
CoreModelId.llama3_3_70b_instruct.value,
),
ProviderModelEntry(
provider_model_id="all-minilm:l6-v2",
aliases=["all-minilm"],
model_type=ModelType.embedding,
metadata={
"embedding_dimension": 384,
"context_length": 512,
},
),
ProviderModelEntry(
provider_model_id="nomic-embed-text",
model_type=ModelType.embedding,
metadata={
"embedding_dimension": 768,
"context_length": 8192,
},
),
] + SAFETY_MODELS_ENTRIES

View file

@ -6,94 +6,49 @@
import asyncio
import base64
from collections.abc import AsyncGenerator, AsyncIterator
from typing import Any
from ollama import AsyncClient as AsyncOllamaClient
from llama_stack.apis.common.content_types import (
ImageContentItem,
InterleavedContent,
InterleavedContentItem,
TextContentItem,
)
from llama_stack.apis.common.errors import UnsupportedModelError
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseStreamChunk,
CompletionRequest,
CompletionResponse,
CompletionResponseStreamChunk,
EmbeddingsResponse,
EmbeddingTaskType,
GrammarResponseFormat,
InferenceProvider,
JsonSchemaResponseFormat,
LogProbConfig,
Message,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAIMessageParam,
OpenAIResponseFormatParam,
ResponseFormat,
SamplingParams,
TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.apis.models import Model
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import (
HealthResponse,
HealthStatus,
ModelsProtocolPrivate,
)
from llama_stack.providers.remote.inference.ollama.config import OllamaImplConfig
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAICompatCompletionChoice,
OpenAICompatCompletionResponse,
get_sampling_options,
prepare_openai_completion_params,
process_chat_completion_response,
process_chat_completion_stream_response,
process_completion_response,
process_completion_stream_response,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
completion_request_to_prompt,
content_has_media,
convert_image_content_to_url,
interleaved_content_as_str,
localize_image_content,
request_has_media,
)
from .models import MODEL_ENTRIES
logger = get_logger(name=__name__, category="inference::ollama")
class OllamaInferenceAdapter(
OpenAIMixin,
InferenceProvider,
ModelsProtocolPrivate,
):
class OllamaInferenceAdapter(OpenAIMixin):
config: OllamaImplConfig
# automatically set by the resolver when instantiating the provider
__provider_id__: str
def __init__(self, config: OllamaImplConfig) -> None:
self.register_helper = ModelRegistryHelper(MODEL_ENTRIES)
self.config = config
self._clients: dict[asyncio.AbstractEventLoop, AsyncOllamaClient] = {}
embedding_model_metadata: dict[str, dict[str, int]] = {
"all-minilm:l6-v2": {
"embedding_dimension": 384,
"context_length": 512,
},
"nomic-embed-text:latest": {
"embedding_dimension": 768,
"context_length": 8192,
},
"nomic-embed-text:v1.5": {
"embedding_dimension": 768,
"context_length": 8192,
},
"nomic-embed-text:137m-v1.5-fp16": {
"embedding_dimension": 768,
"context_length": 8192,
},
}
download_images: bool = True
_clients: dict[asyncio.AbstractEventLoop, AsyncOllamaClient] = {}
@property
def ollama_client(self) -> AsyncOllamaClient:
@ -104,7 +59,7 @@ class OllamaInferenceAdapter(
return self._clients[loop]
def get_api_key(self):
return "NO_KEY"
return "NO KEY REQUIRED"
def get_base_url(self):
return self.config.url.rstrip("/") + "/v1"
@ -117,62 +72,6 @@ class OllamaInferenceAdapter(
f"Ollama Server is not running (message: {r['message']}). Make sure to start it using `ollama serve` in a separate terminal"
)
async def should_refresh_models(self) -> bool:
return self.config.refresh_models
async def list_models(self) -> list[Model] | None:
provider_id = self.__provider_id__
response = await self.ollama_client.list()
# always add the two embedding models which can be pulled on demand
models = [
Model(
identifier="all-minilm:l6-v2",
provider_resource_id="all-minilm:l6-v2",
provider_id=provider_id,
metadata={
"embedding_dimension": 384,
"context_length": 512,
},
model_type=ModelType.embedding,
),
# add all-minilm alias
Model(
identifier="all-minilm",
provider_resource_id="all-minilm:l6-v2",
provider_id=provider_id,
metadata={
"embedding_dimension": 384,
"context_length": 512,
},
model_type=ModelType.embedding,
),
Model(
identifier="nomic-embed-text",
provider_resource_id="nomic-embed-text:latest",
provider_id=provider_id,
metadata={
"embedding_dimension": 768,
"context_length": 8192,
},
model_type=ModelType.embedding,
),
]
for m in response.models:
# kill embedding models since we don't know dimensions for them
if "bert" in m.details.family:
continue
models.append(
Model(
identifier=m.model,
provider_resource_id=m.model,
provider_id=provider_id,
metadata={},
model_type=ModelType.llm,
)
)
return models
async def health(self) -> HealthResponse:
"""
Performs a health check by verifying connectivity to the Ollama server.
@ -190,343 +89,14 @@ class OllamaInferenceAdapter(
async def shutdown(self) -> None:
self._clients.clear()
async def unregister_model(self, model_id: str) -> None:
pass
async def _get_model(self, model_id: str) -> Model:
if not self.model_store:
raise ValueError("Model store not set")
return await self.model_store.get_model(model_id)
async def completion(
self,
model_id: str,
content: InterleavedContent,
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
) -> CompletionResponse | AsyncGenerator[CompletionResponseStreamChunk, None]:
if sampling_params is None:
sampling_params = SamplingParams()
model = await self._get_model(model_id)
if model.provider_resource_id is None:
raise ValueError(f"Model {model_id} has no provider_resource_id set")
request = CompletionRequest(
model=model.provider_resource_id,
content=content,
sampling_params=sampling_params,
response_format=response_format,
stream=stream,
logprobs=logprobs,
)
if stream:
return self._stream_completion(request)
else:
return await self._nonstream_completion(request)
async def _stream_completion(
self, request: CompletionRequest
) -> AsyncGenerator[CompletionResponseStreamChunk, None]:
params = await self._get_params(request)
async def _generate_and_convert_to_openai_compat():
s = await self.ollama_client.generate(**params)
async for chunk in s:
choice = OpenAICompatCompletionChoice(
finish_reason=chunk["done_reason"] if chunk["done"] else None,
text=chunk["response"],
)
yield OpenAICompatCompletionResponse(
choices=[choice],
)
stream = _generate_and_convert_to_openai_compat()
async for chunk in process_completion_stream_response(stream):
yield chunk
async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
params = await self._get_params(request)
r = await self.ollama_client.generate(**params)
choice = OpenAICompatCompletionChoice(
finish_reason=r["done_reason"] if r["done"] else None,
text=r["response"],
)
response = OpenAICompatCompletionResponse(
choices=[choice],
)
return process_completion_response(response)
async def chat_completion(
self,
model_id: str,
messages: list[Message],
sampling_params: SamplingParams | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
if sampling_params is None:
sampling_params = SamplingParams()
model = await self._get_model(model_id)
if model.provider_resource_id is None:
raise ValueError(f"Model {model_id} has no provider_resource_id set")
request = ChatCompletionRequest(
model=model.provider_resource_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
stream=stream,
logprobs=logprobs,
response_format=response_format,
tool_config=tool_config,
)
if stream:
return self._stream_chat_completion(request)
else:
return await self._nonstream_chat_completion(request)
async def _get_params(self, request: ChatCompletionRequest | CompletionRequest) -> dict:
sampling_options = get_sampling_options(request.sampling_params)
# This is needed since the Ollama API expects num_predict to be set
# for early truncation instead of max_tokens.
if sampling_options.get("max_tokens") is not None:
sampling_options["num_predict"] = sampling_options["max_tokens"]
input_dict: dict[str, Any] = {}
media_present = request_has_media(request)
llama_model = self.register_helper.get_llama_model(request.model)
if isinstance(request, ChatCompletionRequest):
if media_present or not llama_model:
contents = [await convert_message_to_openai_dict_for_ollama(m) for m in request.messages]
# flatten the list of lists
input_dict["messages"] = [item for sublist in contents for item in sublist]
else:
input_dict["raw"] = True
input_dict["prompt"] = await chat_completion_request_to_prompt(
request,
llama_model,
)
else:
assert not media_present, "Ollama does not support media for Completion requests"
input_dict["prompt"] = await completion_request_to_prompt(request)
input_dict["raw"] = True
if fmt := request.response_format:
if isinstance(fmt, JsonSchemaResponseFormat):
input_dict["format"] = fmt.json_schema
elif isinstance(fmt, GrammarResponseFormat):
raise NotImplementedError("Grammar response format is not supported")
else:
raise ValueError(f"Unknown response format type: {fmt.type}")
params = {
"model": request.model,
**input_dict,
"options": sampling_options,
"stream": request.stream,
}
logger.debug(f"params to ollama: {params}")
return params
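For illustration, a chat request that sets max_tokens and a JSON-schema response format would come out of _get_params shaped roughly like the sketch below; the model name, schema, and sampling values are hypothetical. Note how max_tokens is mirrored into num_predict and the schema is passed through Ollama's format field.
# Hypothetical shape of the params dict produced by _get_params above.
example_params = {
    "model": "llama3.2:3b",
    "messages": [{"role": "user", "content": "Reply with a JSON object containing a 'name' field."}],
    "format": {"type": "object", "properties": {"name": {"type": "string"}}},
    "options": {"temperature": 0.7, "max_tokens": 256, "num_predict": 256},
    "stream": False,
}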
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
params = await self._get_params(request)
if "messages" in params:
r = await self.ollama_client.chat(**params)
else:
r = await self.ollama_client.generate(**params)
if "message" in r:
choice = OpenAICompatCompletionChoice(
finish_reason=r["done_reason"] if r["done"] else None,
text=r["message"]["content"],
)
else:
choice = OpenAICompatCompletionChoice(
finish_reason=r["done_reason"] if r["done"] else None,
text=r["response"],
)
response = OpenAICompatCompletionResponse(
choices=[choice],
)
return process_chat_completion_response(response, request)
async def _stream_chat_completion(
self, request: ChatCompletionRequest
) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
params = await self._get_params(request)
async def _generate_and_convert_to_openai_compat():
if "messages" in params:
s = await self.ollama_client.chat(**params)
else:
s = await self.ollama_client.generate(**params)
async for chunk in s:
if "message" in chunk:
choice = OpenAICompatCompletionChoice(
finish_reason=chunk["done_reason"] if chunk["done"] else None,
text=chunk["message"]["content"],
)
else:
choice = OpenAICompatCompletionChoice(
finish_reason=chunk["done_reason"] if chunk["done"] else None,
text=chunk["response"],
)
yield OpenAICompatCompletionResponse(
choices=[choice],
)
stream = _generate_and_convert_to_openai_compat()
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
async def embeddings(
self,
model_id: str,
contents: list[str] | list[InterleavedContentItem],
text_truncation: TextTruncation | None = TextTruncation.none,
output_dimension: int | None = None,
task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
model = await self._get_model(model_id)
assert all(not content_has_media(content) for content in contents), (
"Ollama does not support media for embeddings"
)
response = await self.ollama_client.embed(
model=model.provider_resource_id,
input=[interleaved_content_as_str(content) for content in contents],
)
embeddings = response["embeddings"]
return EmbeddingsResponse(embeddings=embeddings)
async def register_model(self, model: Model) -> Model:
try:
model = await self.register_helper.register_model(model)
except ValueError:
pass # Ignore statically unknown model, will check live listing
if await self.check_model_availability(model.provider_model_id):
return model
elif await self.check_model_availability(f"{model.provider_model_id}:latest"):
model.provider_resource_id = f"{model.provider_model_id}:latest"
logger.warning(
f"Imprecise provider resource id was used but 'latest' is available in Ollama - using '{model.provider_model_id}'"
)
return model
if model.model_type == ModelType.embedding:
response = await self.ollama_client.list()
if model.provider_resource_id not in [m.model for m in response.models]:
await self.ollama_client.pull(model.provider_resource_id)
# we use list() here instead of ps() -
# - ps() only lists running models, not available models
# - models not currently running are run by the ollama server as needed
response = await self.ollama_client.list()
available_models = [m.model for m in response.models]
provider_resource_id = model.provider_resource_id
assert provider_resource_id is not None # mypy
if provider_resource_id not in available_models:
available_models_latest = [m.model.split(":latest")[0] for m in response.models]
if provider_resource_id in available_models_latest:
logger.warning(
f"Imprecise provider resource id was used but 'latest' is available in Ollama - using '{model.provider_resource_id}:latest'"
)
return model
raise UnsupportedModelError(provider_resource_id, available_models)
# mutating this should be considered an anti-pattern
model.provider_resource_id = provider_resource_id
return model
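A minimal standalone sketch of the ':latest' fallback used in register_model above, assuming a plain list of available model ids rather than the Ollama client; the function name and error type are illustrative.
def resolve_ollama_model_id(requested: str, available: list[str]) -> str:
    # Accept an exact match first, then try the ":latest" tag, mirroring the
    # fallback in register_model; anything else is treated as unsupported.
    if requested in available:
        return requested
    if f"{requested}:latest" in available:
        return f"{requested}:latest"
    raise ValueError(f"{requested} not found; available: {available}")

# resolve_ollama_model_id("llama3.2", ["llama3.2:latest", "all-minilm:l6-v2"]) -> "llama3.2:latest"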
async def openai_chat_completion(
self,
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
model_obj = await self._get_model(model)
# Ollama does not support image urls, so we need to download the image and convert it to base64
async def _convert_message(m: OpenAIMessageParam) -> OpenAIMessageParam:
if isinstance(m.content, list):
for c in m.content:
if c.type == "image_url" and c.image_url and c.image_url.url:
localize_result = await localize_image_content(c.image_url.url)
if localize_result is None:
raise ValueError(f"Failed to localize image content from {c.image_url.url}")
content, format = localize_result
c.image_url.url = f"data:image/{format};base64,{base64.b64encode(content).decode('utf-8')}"
return m
messages = [await _convert_message(m) for m in messages]
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
return await OpenAIMixin.openai_chat_completion(self, **params)
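A self-contained sketch of the image-localization step performed by _convert_message above, using only the standard library instead of the localize_image_content helper; the image format is assumed rather than detected, and error handling is omitted.
import base64
from urllib.request import urlopen

def image_url_to_data_url(url: str, image_format: str = "png") -> str:
    # Download the image bytes and inline them as a base64 data URL, since
    # Ollama does not accept remote image URLs.
    with urlopen(url) as resp:
        content = resp.read()
    return f"data:image/{image_format};base64,{base64.b64encode(content).decode('utf-8')}"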
async def convert_message_to_openai_dict_for_ollama(message: Message) -> list[dict]:
async def _convert_content(content) -> dict:
if isinstance(content, ImageContentItem):
return {
"role": message.role,
"images": [await convert_image_content_to_url(content, download=True, include_format=False)],
}
else:
text = content.text if isinstance(content, TextContentItem) else content
assert isinstance(text, str)
return {
"role": message.role,
"content": text,
}
if isinstance(message.content, list):
return [await _convert_content(c) for c in message.content]
else:
return [await _convert_content(message.content)]
raise UnsupportedModelError(model.provider_model_id, list(self._model_cache.keys()))
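For example, a user message carrying one image item followed by one text item would, under the conversion above, flatten into two Ollama-style dicts along these lines (values hypothetical):
converted = [
    {"role": "user", "images": ["<base64-encoded image bytes>"]},
    {"role": "user", "content": "What is in this picture?"},
]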

View file

@@ -4,18 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
from .config import OpenAIConfig
class OpenAIProviderDataValidator(BaseModel):
openai_api_key: str | None = None
async def get_adapter_impl(config: OpenAIConfig, _deps):
from .openai import OpenAIInferenceAdapter
impl = OpenAIInferenceAdapter(config)
impl = OpenAIInferenceAdapter(config=config)
await impl.initialize()
return impl

View file

@@ -8,6 +8,7 @@ from typing import Any
from pydantic import BaseModel, Field
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@@ -19,11 +20,7 @@ class OpenAIProviderDataValidator(BaseModel):
@json_schema_type
class OpenAIConfig(BaseModel):
api_key: str | None = Field(
default=None,
description="API key for OpenAI models",
)
class OpenAIConfig(RemoteInferenceProviderConfig):
base_url: str = Field(
default="https://api.openai.com/v1",
description="Base URL for OpenAI API",

View file

@@ -1,60 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from dataclasses import dataclass
from llama_stack.apis.models import ModelType
from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry,
)
LLM_MODEL_IDS = [
"gpt-3.5-turbo-0125",
"gpt-3.5-turbo",
"gpt-3.5-turbo-instruct",
"gpt-4",
"gpt-4-turbo",
"gpt-4o",
"gpt-4o-2024-08-06",
"gpt-4o-mini",
"gpt-4o-audio-preview",
"chatgpt-4o-latest",
"o1",
"o1-mini",
"o3-mini",
"o4-mini",
]
@dataclass
class EmbeddingModelInfo:
"""Structured representation of embedding model information."""
embedding_dimension: int
context_length: int
EMBEDDING_MODEL_IDS: dict[str, EmbeddingModelInfo] = {
"text-embedding-3-small": EmbeddingModelInfo(1536, 8192),
"text-embedding-3-large": EmbeddingModelInfo(3072, 8192),
}
SAFETY_MODELS_ENTRIES = []
MODEL_ENTRIES = (
[ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS]
+ [
ProviderModelEntry(
provider_model_id=model_id,
model_type=ModelType.embedding,
metadata={
"embedding_dimension": model_info.embedding_dimension,
"context_length": model_info.context_length,
},
)
for model_id, model_info in EMBEDDING_MODEL_IDS.items()
]
+ SAFETY_MODELS_ENTRIES
)

View file

@@ -5,60 +5,29 @@
# the root directory of this source tree.
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import OpenAIConfig
from .models import MODEL_ENTRIES
logger = get_logger(name=__name__, category="inference::openai")
#
# This OpenAI adapter implements Inference methods using two mixins -
# This OpenAI adapter implements Inference methods using OpenAIMixin
#
# | Inference Method | Implementation Source |
# |----------------------------|--------------------------|
# | completion | LiteLLMOpenAIMixin |
# | chat_completion | LiteLLMOpenAIMixin |
# | embedding | LiteLLMOpenAIMixin |
# | batch_completion | LiteLLMOpenAIMixin |
# | batch_chat_completion | LiteLLMOpenAIMixin |
# | openai_completion | OpenAIMixin |
# | openai_chat_completion | OpenAIMixin |
# | openai_embeddings | OpenAIMixin |
#
class OpenAIInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
class OpenAIInferenceAdapter(OpenAIMixin):
"""
OpenAI Inference Adapter for Llama Stack.
Note: The inheritance order is important here. OpenAIMixin must come before
LiteLLMOpenAIMixin to ensure that OpenAIMixin.check_model_availability()
is used instead of ModelRegistryHelper.check_model_availability().
- OpenAIMixin.check_model_availability() queries the OpenAI API to check if a model exists
- ModelRegistryHelper.check_model_availability() (inherited by LiteLLMOpenAIMixin) just returns False and shows a warning
"""
def __init__(self, config: OpenAIConfig) -> None:
LiteLLMOpenAIMixin.__init__(
self,
MODEL_ENTRIES,
litellm_provider_name="openai",
api_key_from_config=config.api_key,
provider_data_api_key_field="openai_api_key",
)
self.config = config
# we set is_openai_compat so users can use the canonical
# openai model names like "gpt-4" or "gpt-3.5-turbo"
# and the model name will be translated to litellm's
# "openai/gpt-4" or "openai/gpt-3.5-turbo" transparently.
# if we do not set this, users will be exposed to the
# litellm specific model names, an abstraction leak.
self.is_openai_compat = True
config: OpenAIConfig
# Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin
get_api_key = LiteLLMOpenAIMixin.get_api_key
provider_data_api_key_field: str = "openai_api_key"
embedding_model_metadata: dict[str, dict[str, int]] = {
"text-embedding-3-small": {"embedding_dimension": 1536, "context_length": 8192},
"text-embedding-3-large": {"embedding_dimension": 3072, "context_length": 8192},
}
def get_base_url(self) -> str:
"""
@@ -67,9 +36,3 @@ class OpenAIInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
Returns the OpenAI API base URL from the configuration.
"""
return self.config.base_url
async def initialize(self) -> None:
await super().initialize()
async def shutdown(self) -> None:
await super().shutdown()

View file

@@ -6,13 +6,14 @@
from typing import Any
from pydantic import BaseModel, Field, SecretStr
from pydantic import Field, SecretStr
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class PassthroughImplConfig(BaseModel):
class PassthroughImplConfig(RemoteInferenceProviderConfig):
url: str = Field(
default=None,
description="The URL for the passthrough endpoint",

View file

@@ -4,54 +4,32 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import AsyncGenerator, AsyncIterator
from collections.abc import AsyncIterator
from typing import Any
from llama_stack_client import AsyncLlamaStackClient
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import (
ChatCompletionResponse,
ChatCompletionResponseStreamChunk,
CompletionMessage,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAIChatCompletionRequestWithExtraBody,
OpenAICompletion,
OpenAICompletionRequestWithExtraBody,
OpenAIEmbeddingsResponse,
OpenAIMessageParam,
OpenAIResponseFormatParam,
ResponseFormat,
SamplingParams,
TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.models import Model
from llama_stack.core.library_client import convert_pydantic_to_json_value, convert_to_pydantic
from llama_stack.core.library_client import convert_pydantic_to_json_value
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
from .config import PassthroughImplConfig
class PassthroughInferenceAdapter(Inference):
def __init__(self, config: PassthroughImplConfig) -> None:
ModelRegistryHelper.__init__(self, [])
ModelRegistryHelper.__init__(self)
self.config = config
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
async def unregister_model(self, model_id: str) -> None:
pass
@@ -89,126 +67,6 @@ class PassthroughInferenceAdapter(Inference):
provider_data=provider_data,
)
async def completion(
self,
model_id: str,
content: InterleavedContent,
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
client = self._get_client()
model = await self.model_store.get_model(model_id)
request_params = {
"model_id": model.provider_resource_id,
"content": content,
"sampling_params": sampling_params,
"response_format": response_format,
"stream": stream,
"logprobs": logprobs,
}
request_params = {key: value for key, value in request_params.items() if value is not None}
# cast everything to json dict
json_params = self.cast_value_to_json_dict(request_params)
# only pass through the not None params
return await client.inference.completion(**json_params)
async def chat_completion(
self,
model_id: str,
messages: list[Message],
sampling_params: SamplingParams | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
# TODO: revisit this remove tool_calls from messages logic
for message in messages:
if hasattr(message, "tool_calls"):
message.tool_calls = None
request_params = {
"model_id": model.provider_resource_id,
"messages": messages,
"sampling_params": sampling_params,
"tools": tools,
"tool_choice": tool_choice,
"tool_prompt_format": tool_prompt_format,
"response_format": response_format,
"stream": stream,
"logprobs": logprobs,
}
# only pass through the not None params
request_params = {key: value for key, value in request_params.items() if value is not None}
# cast everything to json dict
json_params = self.cast_value_to_json_dict(request_params)
if stream:
return self._stream_chat_completion(json_params)
else:
return await self._nonstream_chat_completion(json_params)
async def _nonstream_chat_completion(self, json_params: dict[str, Any]) -> ChatCompletionResponse:
client = self._get_client()
response = await client.inference.chat_completion(**json_params)
return ChatCompletionResponse(
completion_message=CompletionMessage(
content=response.completion_message.content.text,
stop_reason=response.completion_message.stop_reason,
tool_calls=response.completion_message.tool_calls,
),
logprobs=response.logprobs,
)
async def _stream_chat_completion(self, json_params: dict[str, Any]) -> AsyncGenerator:
client = self._get_client()
stream_response = await client.inference.chat_completion(**json_params)
async for chunk in stream_response:
chunk = chunk.to_dict()
# temporary hack to remove the metrics from the response
chunk["metrics"] = []
chunk = convert_to_pydantic(ChatCompletionResponseStreamChunk, chunk)
yield chunk
async def embeddings(
self,
model_id: str,
contents: list[InterleavedContent],
text_truncation: TextTruncation | None = TextTruncation.none,
output_dimension: int | None = None,
task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
client = self._get_client()
model = await self.model_store.get_model(model_id)
return await client.inference.embeddings(
model_id=model.provider_resource_id,
contents=contents,
text_truncation=text_truncation,
output_dimension=output_dimension,
task_type=task_type,
)
async def openai_embeddings(
self,
model: str,
@@ -221,110 +79,31 @@ class PassthroughInferenceAdapter(Inference):
async def openai_completion(
self,
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
params: OpenAICompletionRequestWithExtraBody,
) -> OpenAICompletion:
client = self._get_client()
model_obj = await self.model_store.get_model(model)
model_obj = await self.model_store.get_model(params.model)
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
prompt=prompt,
best_of=best_of,
echo=echo,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
top_p=top_p,
user=user,
guided_choice=guided_choice,
prompt_logprobs=prompt_logprobs,
)
params = params.model_copy()
params.model = model_obj.provider_resource_id
return await client.inference.openai_completion(**params)
request_params = params.model_dump(exclude_none=True)
return await client.inference.openai_completion(**request_params)
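Condensed, the new passthrough pattern amounts to copying the request, re-pointing it at the provider's model id, and forwarding only the fields the caller actually set. A runnable sketch of that idea on a plain dict (the real code operates on the pydantic request object via model_copy and model_dump):
def repoint_model(request: dict, provider_model_id: str) -> dict:
    # Swap in the provider's model id and drop unset (None) fields before forwarding.
    return {k: v for k, v in {**request, "model": provider_model_id}.items() if v is not None}

# repoint_model({"model": "my-alias", "prompt": "hi", "temperature": None}, "provider/real-model")
# -> {"model": "provider/real-model", "prompt": "hi"}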
async def openai_chat_completion(
self,
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
params: OpenAIChatCompletionRequestWithExtraBody,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
client = self._get_client()
model_obj = await self.model_store.get_model(model)
model_obj = await self.model_store.get_model(params.model)
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
params = params.model_copy()
params.model = model_obj.provider_resource_id
return await client.inference.openai_chat_completion(**params)
request_params = params.model_dump(exclude_none=True)
return await client.inference.openai_chat_completion(**request_params)
def cast_value_to_json_dict(self, request_params: dict[str, Any]) -> dict[str, Any]:
json_params = {}

View file

@@ -11,6 +11,6 @@ async def get_adapter_impl(config: RunpodImplConfig, _deps):
from .runpod import RunpodInferenceAdapter
assert isinstance(config, RunpodImplConfig), f"Unexpected config type: {type(config)}"
impl = RunpodInferenceAdapter(config)
impl = RunpodInferenceAdapter(config=config)
await impl.initialize()
return impl

View file

@@ -6,19 +6,21 @@
from typing import Any
from pydantic import BaseModel, Field
from pydantic import Field, SecretStr
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class RunpodImplConfig(BaseModel):
class RunpodImplConfig(RemoteInferenceProviderConfig):
url: str | None = Field(
default=None,
description="The URL for the Runpod model serving endpoint",
)
api_token: str | None = Field(
auth_credential: SecretStr | None = Field(
default=None,
alias="api_token",
description="The API token",
)

View file

@@ -3,155 +3,40 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import AsyncGenerator
from openai import OpenAI
from collections.abc import AsyncIterator
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.inference import OpenAIEmbeddingsResponse
# from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, build_hf_repo_model_entry
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
get_sampling_options,
process_chat_completion_response,
process_chat_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
from llama_stack.apis.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAIChatCompletionRequestWithExtraBody,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import RunpodImplConfig
# https://docs.runpod.io/serverless/vllm/overview#compatible-models
# https://github.com/runpod-workers/worker-vllm/blob/main/README.md#compatible-model-architectures
RUNPOD_SUPPORTED_MODELS = {
"Llama3.1-8B": "meta-llama/Llama-3.1-8B",
"Llama3.1-70B": "meta-llama/Llama-3.1-70B",
"Llama3.1-405B:bf16-mp8": "meta-llama/Llama-3.1-405B",
"Llama3.1-405B": "meta-llama/Llama-3.1-405B-FP8",
"Llama3.1-405B:bf16-mp16": "meta-llama/Llama-3.1-405B",
"Llama3.1-8B-Instruct": "meta-llama/Llama-3.1-8B-Instruct",
"Llama3.1-70B-Instruct": "meta-llama/Llama-3.1-70B-Instruct",
"Llama3.1-405B-Instruct:bf16-mp8": "meta-llama/Llama-3.1-405B-Instruct",
"Llama3.1-405B-Instruct": "meta-llama/Llama-3.1-405B-Instruct-FP8",
"Llama3.1-405B-Instruct:bf16-mp16": "meta-llama/Llama-3.1-405B-Instruct",
"Llama3.2-1B": "meta-llama/Llama-3.2-1B",
"Llama3.2-3B": "meta-llama/Llama-3.2-3B",
}
SAFETY_MODELS_ENTRIES = []
class RunpodInferenceAdapter(OpenAIMixin):
"""
Adapter for RunPod's OpenAI-compatible API endpoints.
Supports vLLM serverless endpoints, whether self-hosted or public.
Works with any RunPod endpoint that exposes an OpenAI-compatible API.
"""
# Create MODEL_ENTRIES from RUNPOD_SUPPORTED_MODELS for compatibility with starter template
MODEL_ENTRIES = [
build_hf_repo_model_entry(provider_model_id, model_descriptor)
for provider_model_id, model_descriptor in RUNPOD_SUPPORTED_MODELS.items()
] + SAFETY_MODELS_ENTRIES
config: RunpodImplConfig
def get_base_url(self) -> str:
"""Get base URL for OpenAI client."""
return self.config.url
class RunpodInferenceAdapter(
ModelRegistryHelper,
Inference,
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
):
def __init__(self, config: RunpodImplConfig) -> None:
ModelRegistryHelper.__init__(self, stack_to_provider_models_map=RUNPOD_SUPPORTED_MODELS)
self.config = config
async def initialize(self) -> None:
return
async def shutdown(self) -> None:
pass
async def completion(
async def openai_chat_completion(
self,
model: str,
content: InterleavedContent,
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
) -> AsyncGenerator:
raise NotImplementedError()
params: OpenAIChatCompletionRequestWithExtraBody,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
"""Override to add RunPod-specific stream_options requirement."""
params = params.model_copy()
async def chat_completion(
self,
model: str,
messages: list[Message],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
request = ChatCompletionRequest(
model=model,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
stream=stream,
logprobs=logprobs,
tool_config=tool_config,
)
if params.stream and not params.stream_options:
params.stream_options = {"include_usage": True}
client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
if stream:
return self._stream_chat_completion(request, client)
else:
return await self._nonstream_chat_completion(request, client)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest, client: OpenAI
) -> ChatCompletionResponse:
params = self._get_params(request)
r = client.completions.create(**params)
return process_chat_completion_response(r, request)
async def _stream_chat_completion(self, request: ChatCompletionRequest, client: OpenAI) -> AsyncGenerator:
params = self._get_params(request)
async def _to_async_generator():
s = client.completions.create(**params)
for chunk in s:
yield chunk
stream = _to_async_generator()
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
def _get_params(self, request: ChatCompletionRequest) -> dict:
return {
"model": self.map_to_provider_model(request.model),
"prompt": chat_completion_request_to_prompt(request),
"stream": request.stream,
**get_sampling_options(request.sampling_params),
}
async def embeddings(
self,
model: str,
contents: list[str] | list[InterleavedContentItem],
text_truncation: TextTruncation | None = TextTruncation.none,
output_dimension: int | None = None,
task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
raise NotImplementedError()
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
raise NotImplementedError()
return await super().openai_chat_completion(params)
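Separated from the removed legacy methods interleaved in this hunk, the new RunPod behavior is simply: when streaming, ask the server to include usage chunks, then defer to the OpenAI-compatible mixin. A runnable sketch of that tweak on a plain dict (the real override mutates the pydantic request copy):
def ensure_usage_in_stream(request: dict) -> dict:
    # Streaming requests to RunPod must set stream_options to get usage back.
    if request.get("stream") and not request.get("stream_options"):
        request = {**request, "stream_options": {"include_usage": True}}
    return request

# ensure_usage_in_stream({"model": "x", "stream": True})
# -> {"model": "x", "stream": True, "stream_options": {"include_usage": True}}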

View file

@@ -4,15 +4,13 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import Inference
from .config import SambaNovaImplConfig
async def get_adapter_impl(config: SambaNovaImplConfig, _deps) -> Inference:
async def get_adapter_impl(config: SambaNovaImplConfig, _deps):
from .sambanova import SambaNovaInferenceAdapter
assert isinstance(config, SambaNovaImplConfig), f"Unexpected config type: {type(config)}"
impl = SambaNovaInferenceAdapter(config)
impl = SambaNovaInferenceAdapter(config=config)
await impl.initialize()
return impl

View file

@@ -6,8 +6,9 @@
from typing import Any
from pydantic import BaseModel, Field, SecretStr
from pydantic import BaseModel, Field
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@@ -19,15 +20,11 @@ class SambaNovaProviderDataValidator(BaseModel):
@json_schema_type
class SambaNovaImplConfig(BaseModel):
class SambaNovaImplConfig(RemoteInferenceProviderConfig):
url: str = Field(
default="https://api.sambanova.ai/v1",
description="The URL for the SambaNova AI server",
)
api_key: SecretStr | None = Field(
default=None,
description="The SambaNova cloud API Key",
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY:=}", **kwargs) -> dict[str, Any]:

View file

@@ -1,28 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry,
)
SAFETY_MODELS_ENTRIES = []
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"Meta-Llama-3.1-8B-Instruct",
CoreModelId.llama3_1_8b_instruct.value,
),
build_hf_repo_model_entry(
"Meta-Llama-3.3-70B-Instruct",
CoreModelId.llama3_3_70b_instruct.value,
),
build_hf_repo_model_entry(
"Llama-4-Maverick-17B-128E-Instruct",
CoreModelId.llama4_maverick_17b_128e_instruct.value,
),
] + SAFETY_MODELS_ENTRIES

View file

@@ -5,42 +5,20 @@
# the root directory of this source tree.
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import SambaNovaImplConfig
from .models import MODEL_ENTRIES
class SambaNovaInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
class SambaNovaInferenceAdapter(OpenAIMixin):
config: SambaNovaImplConfig
provider_data_api_key_field: str = "sambanova_api_key"
download_images: bool = True # SambaNova does not support image downloads server-size, perform them on the client
"""
SambaNova Inference Adapter for Llama Stack.
Note: The inheritance order is important here. OpenAIMixin must come before
LiteLLMOpenAIMixin to ensure that OpenAIMixin.check_model_availability()
is used instead of LiteLLMOpenAIMixin.check_model_availability().
- OpenAIMixin.check_model_availability() queries the /v1/models to check if a model exists
- LiteLLMOpenAIMixin.check_model_availability() checks the static registry within LiteLLM
"""
def __init__(self, config: SambaNovaImplConfig):
self.config = config
self.environment_available_models = []
LiteLLMOpenAIMixin.__init__(
self,
model_entries=MODEL_ENTRIES,
litellm_provider_name="sambanova",
api_key_from_config=self.config.api_key.get_secret_value() if self.config.api_key else None,
provider_data_api_key_field="sambanova_api_key",
openai_compat_api_base=self.config.url,
download_images=True, # SambaNova requires base64 image encoding
json_schema_strict=False, # SambaNova doesn't support strict=True yet
)
# Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin
get_api_key = LiteLLMOpenAIMixin.get_api_key
def get_base_url(self) -> str:
"""
Get the base URL for OpenAI mixin.

View file

@@ -7,11 +7,14 @@
from pydantic import BaseModel, Field, SecretStr
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class TGIImplConfig(BaseModel):
class TGIImplConfig(RemoteInferenceProviderConfig):
auth_credential: SecretStr | None = Field(default=None, exclude=True)
url: str = Field(
description="The URL for the TGI serving endpoint",
)

View file

@@ -5,79 +5,21 @@
# the root directory of this source tree.
from collections.abc import AsyncGenerator
from collections.abc import Iterable
from huggingface_hub import AsyncInferenceClient, HfApi
from pydantic import SecretStr
from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
CompletionRequest,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
OpenAIEmbeddingsResponse,
ResponseFormat,
ResponseFormatType,
SamplingParams,
TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.models import Model
from llama_stack.apis.models.models import ModelType
from llama_stack.apis.inference import OpenAIEmbeddingsResponse
from llama_stack.log import get_logger
from llama_stack.models.llama.sku_list import all_registered_models
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
build_hf_repo_model_entry,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAICompatCompletionChoice,
OpenAICompatCompletionResponse,
get_sampling_options,
process_chat_completion_response,
process_chat_completion_stream_response,
process_completion_response,
process_completion_stream_response,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_model_input_info,
completion_request_to_prompt_model_input_info,
)
from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
log = get_logger(name=__name__, category="inference::tgi")
def build_hf_repo_model_entries():
return [
build_hf_repo_model_entry(
model.huggingface_repo,
model.descriptor(),
)
for model in all_registered_models()
if model.huggingface_repo
]
class _HfAdapter(
OpenAIMixin,
Inference,
ModelsProtocolPrivate,
):
class _HfAdapter(OpenAIMixin):
url: str
api_key: SecretStr
@@ -87,234 +29,14 @@ class _HfAdapter(
overwrite_completion_id = True # TGI always returns id=""
def __init__(self) -> None:
self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
self.huggingface_repo_to_llama_model_id = {
model.huggingface_repo: model.descriptor() for model in all_registered_models() if model.huggingface_repo
}
def get_api_key(self):
return self.api_key.get_secret_value()
return "NO KEY REQUIRED"
def get_base_url(self):
return self.url
async def shutdown(self) -> None:
pass
async def list_models(self) -> list[Model] | None:
models = []
async for model in self.client.models.list():
models.append(
Model(
identifier=model.id,
provider_resource_id=model.id,
provider_id=self.__provider_id__,
metadata={},
model_type=ModelType.llm,
)
)
return models
async def register_model(self, model: Model) -> Model:
if model.provider_resource_id != self.model_id:
raise ValueError(
f"Model {model.provider_resource_id} does not match the model {self.model_id} served by TGI."
)
return model
async def unregister_model(self, model_id: str) -> None:
pass
async def completion(
self,
model_id: str,
content: InterleavedContent,
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
request = CompletionRequest(
model=model.provider_resource_id,
content=content,
sampling_params=sampling_params,
response_format=response_format,
stream=stream,
logprobs=logprobs,
)
if stream:
return self._stream_completion(request)
else:
return await self._nonstream_completion(request)
def _get_max_new_tokens(self, sampling_params, input_tokens):
return min(
sampling_params.max_tokens or (self.max_tokens - input_tokens),
self.max_tokens - input_tokens - 1,
)
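A worked example of the token-budget computation above, restated as a standalone function with hypothetical numbers (4096-token context, 1000 input tokens):
def max_new_tokens(context_limit: int, input_tokens: int, requested: int | None) -> int:
    # Honor the requested cap but never exceed the remaining context minus one token.
    return min(requested or (context_limit - input_tokens), context_limit - input_tokens - 1)

assert max_new_tokens(4096, 1000, 512) == 512    # explicit request fits within the budget
assert max_new_tokens(4096, 1000, None) == 3095  # no request: fall back to remaining budget - 1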
def _build_options(
self,
sampling_params: SamplingParams | None = None,
fmt: ResponseFormat = None,
):
options = get_sampling_options(sampling_params)
# TGI does not support temperature=0 when using greedy sampling
# We set it to 1e-3 instead, anything lower outputs garbage from TGI
# We can use top_p sampling strategy to specify lower temperature
if abs(options["temperature"]) < 1e-10:
options["temperature"] = 1e-3
# delete key "max_tokens" from options since its not supported by the API
options.pop("max_tokens", None)
if fmt:
if fmt.type == ResponseFormatType.json_schema.value:
options["grammar"] = {
"type": "json",
"value": fmt.json_schema,
}
elif fmt.type == ResponseFormatType.grammar.value:
raise ValueError("Grammar response format not supported yet")
else:
raise ValueError(f"Unexpected response format: {fmt.type}")
return options
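For illustration, a greedy request with a JSON-schema response format would come out of _build_options roughly as follows (schema contents hypothetical): the temperature is clamped to 1e-3 and the schema is wrapped in TGI's grammar field.
example_options = {
    "temperature": 1e-3,  # clamped from 0 to avoid TGI's greedy-sampling issue
    "grammar": {"type": "json", "value": {"type": "object", "properties": {"answer": {"type": "string"}}}},
}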
async def _get_params_for_completion(self, request: CompletionRequest) -> dict:
prompt, input_tokens = await completion_request_to_prompt_model_input_info(request)
return dict(
prompt=prompt,
stream=request.stream,
details=True,
max_new_tokens=self._get_max_new_tokens(request.sampling_params, input_tokens),
stop_sequences=["<|eom_id|>", "<|eot_id|>"],
**self._build_options(request.sampling_params, request.response_format),
)
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
params = await self._get_params_for_completion(request)
async def _generate_and_convert_to_openai_compat():
s = await self.hf_client.text_generation(**params)
async for chunk in s:
token_result = chunk.token
finish_reason = None
if chunk.details:
finish_reason = chunk.details.finish_reason
choice = OpenAICompatCompletionChoice(text=token_result.text, finish_reason=finish_reason)
yield OpenAICompatCompletionResponse(
choices=[choice],
)
stream = _generate_and_convert_to_openai_compat()
async for chunk in process_completion_stream_response(stream):
yield chunk
async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator:
params = await self._get_params_for_completion(request)
r = await self.hf_client.text_generation(**params)
choice = OpenAICompatCompletionChoice(
finish_reason=r.details.finish_reason,
text="".join(t.text for t in r.details.tokens),
)
response = OpenAICompatCompletionResponse(
choices=[choice],
)
return process_completion_response(response)
async def chat_completion(
self,
model_id: str,
messages: list[Message],
sampling_params: SamplingParams | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
model=model.provider_resource_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
response_format=response_format,
stream=stream,
logprobs=logprobs,
tool_config=tool_config,
)
if stream:
return self._stream_chat_completion(request)
else:
return await self._nonstream_chat_completion(request)
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
params = await self._get_params(request)
r = await self.hf_client.text_generation(**params)
choice = OpenAICompatCompletionChoice(
finish_reason=r.details.finish_reason,
text="".join(t.text for t in r.details.tokens),
)
response = OpenAICompatCompletionResponse(
choices=[choice],
)
return process_chat_completion_response(response, request)
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
async def _generate_and_convert_to_openai_compat():
s = await self.hf_client.text_generation(**params)
async for chunk in s:
token_result = chunk.token
choice = OpenAICompatCompletionChoice(text=token_result.text)
yield OpenAICompatCompletionResponse(
choices=[choice],
)
stream = _generate_and_convert_to_openai_compat()
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
async def _get_params(self, request: ChatCompletionRequest) -> dict:
prompt, input_tokens = await chat_completion_request_to_model_input_info(
request, self.register_helper.get_llama_model(request.model)
)
return dict(
prompt=prompt,
stream=request.stream,
details=True,
max_new_tokens=self._get_max_new_tokens(request.sampling_params, input_tokens),
stop_sequences=["<|eom_id|>", "<|eot_id|>"],
**self._build_options(request.sampling_params, request.response_format),
)
async def embeddings(
self,
model_id: str,
contents: list[str] | list[InterleavedContentItem],
text_truncation: TextTruncation | None = TextTruncation.none,
output_dimension: int | None = None,
task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
raise NotImplementedError()
async def list_provider_model_ids(self) -> Iterable[str]:
return [self.model_id]
async def openai_embeddings(
self,

View file

@@ -17,6 +17,6 @@ async def get_adapter_impl(config: TogetherImplConfig, _deps):
from .together import TogetherInferenceAdapter
assert isinstance(config, TogetherImplConfig), f"Unexpected config type: {type(config)}"
impl = TogetherInferenceAdapter(config)
impl = TogetherInferenceAdapter(config=config)
await impl.initialize()
return impl

View file

@@ -6,7 +6,7 @@
from typing import Any
from pydantic import Field, SecretStr
from pydantic import Field
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@@ -18,10 +18,6 @@ class TogetherImplConfig(RemoteInferenceProviderConfig):
default="https://api.together.xyz/v1",
description="The URL for the Together AI server",
)
api_key: SecretStr | None = Field(
default=None,
description="The Together AI API Key",
)
@classmethod
def sample_run_config(cls, **kwargs) -> dict[str, Any]:

View file

@@ -1,103 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry,
build_hf_repo_model_entry,
)
SAFETY_MODELS_ENTRIES = [
build_hf_repo_model_entry(
"meta-llama/Llama-Guard-3-8B",
CoreModelId.llama_guard_3_8b.value,
),
build_hf_repo_model_entry(
"meta-llama/Llama-Guard-3-11B-Vision-Turbo",
CoreModelId.llama_guard_3_11b_vision.value,
),
]
# source: https://docs.together.ai/docs/serverless-models#embedding-models
EMBEDDING_MODEL_ENTRIES = {
"togethercomputer/m2-bert-80M-32k-retrieval": ProviderModelEntry(
provider_model_id="togethercomputer/m2-bert-80M-32k-retrieval",
metadata={
"embedding_dimension": 768,
"context_length": 32768,
},
),
"BAAI/bge-large-en-v1.5": ProviderModelEntry(
provider_model_id="BAAI/bge-large-en-v1.5",
metadata={
"embedding_dimension": 1024,
"context_length": 512,
},
),
"BAAI/bge-base-en-v1.5": ProviderModelEntry(
provider_model_id="BAAI/bge-base-en-v1.5",
metadata={
"embedding_dimension": 768,
"context_length": 512,
},
),
"Alibaba-NLP/gte-modernbert-base": ProviderModelEntry(
provider_model_id="Alibaba-NLP/gte-modernbert-base",
metadata={
"embedding_dimension": 768,
"context_length": 8192,
},
),
"intfloat/multilingual-e5-large-instruct": ProviderModelEntry(
provider_model_id="intfloat/multilingual-e5-large-instruct",
metadata={
"embedding_dimension": 1024,
"context_length": 512,
},
),
}
MODEL_ENTRIES = (
[
build_hf_repo_model_entry(
"meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
CoreModelId.llama3_1_8b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
CoreModelId.llama3_1_70b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
CoreModelId.llama3_1_405b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/Llama-3.2-3B-Instruct-Turbo",
CoreModelId.llama3_2_3b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo",
CoreModelId.llama3_2_11b_vision_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo",
CoreModelId.llama3_2_90b_vision_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/Llama-3.3-70B-Instruct-Turbo",
CoreModelId.llama3_3_70b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/Llama-4-Scout-17B-16E-Instruct",
CoreModelId.llama4_scout_17b_16e_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
CoreModelId.llama4_maverick_17b_128e_instruct.value,
),
]
+ SAFETY_MODELS_ENTRIES
+ list(EMBEDDING_MODEL_ENTRIES.values())
)

View file

@@ -4,109 +4,47 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import AsyncGenerator
from openai import NOT_GIVEN, AsyncOpenAI
from collections.abc import Iterable
from together import AsyncTogether
from together.constants import BASE_URL
from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
CompletionRequest,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
OpenAIEmbeddingsResponse,
ResponseFormat,
ResponseFormatType,
SamplingParams,
TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.inference.inference import OpenAIEmbeddingUsage
from llama_stack.apis.models import Model, ModelType
from llama_stack.apis.models import Model
from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (
convert_message_to_openai_dict,
get_sampling_options,
process_chat_completion_response,
process_chat_completion_stream_response,
process_completion_response,
process_completion_stream_response,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
completion_request_to_prompt,
content_has_media,
interleaved_content_as_str,
request_has_media,
)
from .config import TogetherImplConfig
from .models import EMBEDDING_MODEL_ENTRIES, MODEL_ENTRIES
logger = get_logger(name=__name__, category="inference::together")
class TogetherInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, NeedsRequestProviderData):
def __init__(self, config: TogetherImplConfig) -> None:
ModelRegistryHelper.__init__(self, MODEL_ENTRIES, config.allowed_models)
self.config = config
self._model_cache: dict[str, Model] = {}
class TogetherInferenceAdapter(OpenAIMixin, NeedsRequestProviderData):
config: TogetherImplConfig
def get_api_key(self):
return self.config.api_key.get_secret_value()
embedding_model_metadata: dict[str, dict[str, int]] = {
"togethercomputer/m2-bert-80M-32k-retrieval": {"embedding_dimension": 768, "context_length": 32768},
"BAAI/bge-large-en-v1.5": {"embedding_dimension": 1024, "context_length": 512},
"BAAI/bge-base-en-v1.5": {"embedding_dimension": 768, "context_length": 512},
"Alibaba-NLP/gte-modernbert-base": {"embedding_dimension": 768, "context_length": 8192},
"intfloat/multilingual-e5-large-instruct": {"embedding_dimension": 1024, "context_length": 512},
}
_model_cache: dict[str, Model] = {}
provider_data_api_key_field: str = "together_api_key"
def get_base_url(self):
return BASE_URL
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
async def completion(
self,
model_id: str,
content: InterleavedContent,
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
request = CompletionRequest(
model=model.provider_resource_id,
content=content,
sampling_params=sampling_params,
response_format=response_format,
stream=stream,
logprobs=logprobs,
)
if stream:
return self._stream_completion(request)
else:
return await self._nonstream_completion(request)
def _get_client(self) -> AsyncTogether:
together_api_key = None
config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
config_api_key = self.config.auth_credential.get_secret_value() if self.config.auth_credential else None
if config_api_key:
together_api_key = config_api_key
else:
@@ -118,177 +56,9 @@ class TogetherInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, Need
together_api_key = provider_data.together_api_key
return AsyncTogether(api_key=together_api_key)
def _get_openai_client(self) -> AsyncOpenAI:
together_client = self._get_client().client
return AsyncOpenAI(
base_url=together_client.base_url,
api_key=together_client.api_key,
)
async def _nonstream_completion(self, request: CompletionRequest) -> ChatCompletionResponse:
params = await self._get_params(request)
client = self._get_client()
r = await client.completions.create(**params)
return process_completion_response(r)
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
client = self._get_client()
stream = await client.completions.create(**params)
async for chunk in process_completion_stream_response(stream):
yield chunk
def _build_options(
self,
sampling_params: SamplingParams | None,
logprobs: LogProbConfig | None,
fmt: ResponseFormat,
) -> dict:
options = get_sampling_options(sampling_params)
if fmt:
if fmt.type == ResponseFormatType.json_schema.value:
options["response_format"] = {
"type": "json_object",
"schema": fmt.json_schema,
}
elif fmt.type == ResponseFormatType.grammar.value:
raise NotImplementedError("Grammar response format not supported yet")
else:
raise ValueError(f"Unknown response format {fmt.type}")
if logprobs and logprobs.top_k:
if logprobs.top_k != 1:
raise ValueError(
f"Unsupported value: Together only supports logprobs top_k=1. {logprobs.top_k} was provided",
)
options["logprobs"] = 1
return options
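For illustration, a JSON-schema response format combined with top_k=1 logprobs would produce Together options roughly like the sketch below (schema and sampling values hypothetical):
example_options = {
    "temperature": 0.7,
    "response_format": {"type": "json_object", "schema": {"type": "object", "properties": {"ok": {"type": "boolean"}}}},
    "logprobs": 1,  # Together only accepts top_k=1
}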
async def chat_completion(
self,
model_id: str,
messages: list[Message],
sampling_params: SamplingParams | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
model=model.provider_resource_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
response_format=response_format,
stream=stream,
logprobs=logprobs,
tool_config=tool_config,
)
if stream:
return self._stream_chat_completion(request)
else:
return await self._nonstream_chat_completion(request)
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
params = await self._get_params(request)
client = self._get_client()
if "messages" in params:
r = await client.chat.completions.create(**params)
else:
r = await client.completions.create(**params)
return process_chat_completion_response(r, request)
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
client = self._get_client()
if "messages" in params:
stream = await client.chat.completions.create(**params)
else:
stream = await client.completions.create(**params)
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
async def _get_params(self, request: ChatCompletionRequest | CompletionRequest) -> dict:
input_dict = {}
media_present = request_has_media(request)
llama_model = self.get_llama_model(request.model)
if isinstance(request, ChatCompletionRequest):
if media_present or not llama_model:
input_dict["messages"] = [await convert_message_to_openai_dict(m) for m in request.messages]
else:
input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
else:
assert not media_present, "Together does not support media for Completion requests"
input_dict["prompt"] = await completion_request_to_prompt(request)
params = {
"model": request.model,
**input_dict,
"stream": request.stream,
**self._build_options(request.sampling_params, request.logprobs, request.response_format),
}
logger.debug(f"params to together: {params}")
return params
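To make the branching above concrete, a rough sketch of the two param shapes _get_params can emit for chat requests (model id, prompt text, and message content are placeholders):
# Text-only request against a model with a known Llama mapping -> raw prompt:
params = {"model": "meta-llama/Llama-3.1-8B-Instruct", "prompt": "<|begin_of_text|>...", "stream": False}
# Request containing media, or a model without a Llama mapping -> OpenAI-style messages:
params = {"model": "meta-llama/Llama-3.1-8B-Instruct", "messages": [{"role": "user", "content": "Hi"}], "stream": False}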
async def embeddings(
self,
model_id: str,
contents: list[str] | list[InterleavedContentItem],
text_truncation: TextTruncation | None = TextTruncation.none,
output_dimension: int | None = None,
task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)
assert all(not content_has_media(content) for content in contents), (
"Together does not support media for embeddings"
)
client = self._get_client()
r = await client.embeddings.create(
model=model.provider_resource_id,
input=[interleaved_content_as_str(content) for content in contents],
)
embeddings = [item.embedding for item in r.data]
return EmbeddingsResponse(embeddings=embeddings)
async def list_models(self) -> list[Model] | None:
self._model_cache = {}
async def list_provider_model_ids(self) -> Iterable[str]:
# Together's /v1/models is not compatible with OpenAI's /v1/models. Together support ticket #13355 -> will not fix, use Together's own client
for m in await self._get_client().models.list():
if m.type == "embedding":
if m.id not in EMBEDDING_MODEL_ENTRIES:
logger.warning(f"Unknown embedding dimension for model {m.id}, skipping.")
continue
self._model_cache[m.id] = Model(
provider_id=self.__provider_id__,
provider_resource_id=EMBEDDING_MODEL_ENTRIES[m.id].provider_model_id,
identifier=m.id,
model_type=ModelType.embedding,
metadata=EMBEDDING_MODEL_ENTRIES[m.id].metadata,
)
else:
self._model_cache[m.id] = Model(
provider_id=self.__provider_id__,
provider_resource_id=m.id,
identifier=m.id,
model_type=ModelType.llm,
)
return self._model_cache.values()
async def should_refresh_models(self) -> bool:
return True
async def check_model_availability(self, model):
return model in self._model_cache
return [m.id for m in await self._get_client().models.list()]
async def openai_embeddings(
self,
@ -303,10 +73,9 @@ class TogetherInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, Need
the standard OpenAI embeddings endpoint.
The endpoint -
- does not return usage information
- not all models return usage information
- does not support user param, returns 400 Unrecognized request arguments supplied: user
- does not support dimensions param, returns 400 Unrecognized request arguments supplied: dimensions
- does not support encoding_format param, always returns floats, never base64
"""
# Together support ticket #13332 -> will not fix
if user is not None:
@ -314,13 +83,11 @@ class TogetherInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, Need
# Together support ticket #13333 -> escalated
if dimensions is not None:
raise ValueError("Together's embeddings endpoint does not support dimensions param.")
# Together support ticket #13331 -> will not fix, compute client side
if encoding_format not in (None, NOT_GIVEN, "float"):
raise ValueError("Together's embeddings endpoint only supports encoding_format='float'.")
response = await self.client.embeddings.create(
model=await self._get_provider_model_id(model),
input=input,
encoding_format=encoding_format,
)
response.model = model  # return to the user the same model id they provided, avoiding exposure of the provider model id
@ -333,4 +100,4 @@ class TogetherInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, Need
)
response.usage = OpenAIEmbeddingUsage(prompt_tokens=-1, total_tokens=-1)
return response
return response # type: ignore[no-any-return]
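Given these restrictions, a hedged usage sketch; the adapter variable, model id, and parameter names are assumptions for illustration only:
# Hypothetical call; only float embeddings are returned, and user/dimensions params are rejected.
response = await adapter.openai_embeddings(
    model="togethercomputer/m2-bert-80M-32k-retrieval",
    input=["hello world"],
)
vector = response.data[0].embedding        # list of floats
print(response.usage.prompt_tokens)        # may be -1 when Together does not report usage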

View file

@ -10,6 +10,6 @@ from .config import VertexAIConfig
async def get_adapter_impl(config: VertexAIConfig, _deps):
from .vertexai import VertexAIInferenceAdapter
impl = VertexAIInferenceAdapter(config)
impl = VertexAIInferenceAdapter(config=config)
await impl.initialize()
return impl

View file

@ -6,8 +6,9 @@
from typing import Any
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, SecretStr
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@ -23,7 +24,9 @@ class VertexAIProviderDataValidator(BaseModel):
@json_schema_type
class VertexAIConfig(BaseModel):
class VertexAIConfig(RemoteInferenceProviderConfig):
auth_credential: SecretStr | None = Field(default=None, exclude=True)
project: str = Field(
description="Google Cloud project ID for Vertex AI",
)

View file

@ -1,20 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry,
)
# Vertex AI model IDs with vertex_ai/ prefix as required by litellm
LLM_MODEL_IDS = [
"vertex_ai/gemini-2.0-flash",
"vertex_ai/gemini-2.5-flash",
"vertex_ai/gemini-2.5-pro",
]
SAFETY_MODELS_ENTRIES = list[ProviderModelEntry]()
MODEL_ENTRIES = [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS] + SAFETY_MODELS_ENTRIES

View file

@ -4,31 +4,19 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
import google.auth.transport.requests
from google.auth import default
from llama_stack.apis.inference import ChatCompletionRequest
from llama_stack.providers.utils.inference.litellm_openai_mixin import (
LiteLLMOpenAIMixin,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import VertexAIConfig
from .models import MODEL_ENTRIES
class VertexAIInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
def __init__(self, config: VertexAIConfig) -> None:
LiteLLMOpenAIMixin.__init__(
self,
MODEL_ENTRIES,
litellm_provider_name="vertex_ai",
api_key_from_config=None, # Vertex AI uses ADC, not API keys
provider_data_api_key_field="vertex_project", # Use project for validation
)
self.config = config
class VertexAIInferenceAdapter(OpenAIMixin):
config: VertexAIConfig
provider_data_api_key_field: str = "vertex_project"
def get_api_key(self) -> str:
"""
@ -43,8 +31,7 @@ class VertexAIInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
credentials.refresh(google.auth.transport.requests.Request())
return str(credentials.token)
except Exception:
# If we can't get credentials, return empty string to let LiteLLM handle it
# This allows the LiteLLM mixin to work with ADC directly
# If we can't get credentials, return empty string to let the env work with ADC directly
return ""
def get_base_url(self) -> str:
@ -55,23 +42,3 @@ class VertexAIInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
Source: https://cloud.google.com/vertex-ai/generative-ai/docs/start/openai
"""
return f"https://{self.config.location}-aiplatform.googleapis.com/v1/projects/{self.config.project}/locations/{self.config.location}/endpoints/openapi"
async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]:
# Get base parameters from parent
params = await super()._get_params(request)
# Add Vertex AI specific parameters
provider_data = self.get_request_provider_data()
if provider_data:
if getattr(provider_data, "vertex_project", None):
params["vertex_project"] = provider_data.vertex_project
if getattr(provider_data, "vertex_location", None):
params["vertex_location"] = provider_data.vertex_location
else:
params["vertex_project"] = self.config.project
params["vertex_location"] = self.config.location
# Remove api_key since Vertex AI uses ADC
params.pop("api_key", None)
return params

View file

@ -17,6 +17,6 @@ async def get_adapter_impl(config: VLLMInferenceAdapterConfig, _deps):
from .vllm import VLLMInferenceAdapter
assert isinstance(config, VLLMInferenceAdapterConfig), f"Unexpected config type: {type(config)}"
impl = VLLMInferenceAdapter(config)
impl = VLLMInferenceAdapter(config=config)
await impl.initialize()
return impl

View file

@ -6,13 +6,14 @@
from pathlib import Path
from pydantic import BaseModel, Field, field_validator
from pydantic import Field, SecretStr, field_validator
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class VLLMInferenceAdapterConfig(BaseModel):
class VLLMInferenceAdapterConfig(RemoteInferenceProviderConfig):
url: str | None = Field(
default=None,
description="The URL for the vLLM model serving endpoint",
@ -21,18 +22,15 @@ class VLLMInferenceAdapterConfig(BaseModel):
default=4096,
description="Maximum number of tokens to generate.",
)
api_token: str | None = Field(
default="fake",
auth_credential: SecretStr | None = Field(
default=None,
alias="api_token",
description="The API token",
)
tls_verify: bool | str = Field(
default=True,
description="Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file.",
)
refresh_models: bool = Field(
default=False,
description="Whether to refresh models periodically",
)
@field_validator("tls_verify")
@classmethod

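A minimal sketch of how the renamed field is populated; the api_token alias keeps existing configs working while the value is stored as a SecretStr (the url and token below are placeholders):
from pydantic import SecretStr

config = VLLMInferenceAdapterConfig(
    url="http://localhost:8000/v1",
    api_token="my-secret-token",  # accepted via the alias, stored in auth_credential
)
assert isinstance(config.auth_credential, SecretStr)
assert config.auth_credential.get_secret_value() == "my-secret-token"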
View file

@ -3,305 +3,43 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from collections.abc import AsyncGenerator, AsyncIterator
from typing import Any
from collections.abc import AsyncIterator
from urllib.parse import urljoin
import httpx
from openai import APIConnectionError, AsyncOpenAI
from openai.types.chat.chat_completion_chunk import (
ChatCompletionChunk as OpenAIChatCompletionChunk,
)
from pydantic import ConfigDict
from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
TextDelta,
ToolCallDelta,
ToolCallParseStatus,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseEvent,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
CompletionMessage,
CompletionRequest,
CompletionResponse,
CompletionResponseStreamChunk,
EmbeddingsResponse,
EmbeddingTaskType,
GrammarResponseFormat,
Inference,
JsonSchemaResponseFormat,
LogProbConfig,
Message,
ModelStore,
ResponseFormat,
SamplingParams,
TextTruncation,
OpenAIChatCompletion,
OpenAIChatCompletionRequestWithExtraBody,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
from llama_stack.models.llama.sku_list import all_registered_models
from llama_stack.providers.datatypes import (
HealthResponse,
HealthStatus,
ModelsProtocolPrivate,
)
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
build_hf_repo_model_entry,
)
from llama_stack.providers.utils.inference.openai_compat import (
UnparseableToolCall,
convert_message_to_openai_dict,
convert_openai_chat_completion_stream,
convert_tool_call,
get_sampling_options,
process_chat_completion_stream_response,
process_completion_response,
process_completion_stream_response,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack.providers.utils.inference.prompt_adapter import (
completion_request_to_prompt,
content_has_media,
interleaved_content_as_str,
request_has_media,
)
from .config import VLLMInferenceAdapterConfig
log = get_logger(name=__name__, category="inference::vllm")
def build_hf_repo_model_entries():
return [
build_hf_repo_model_entry(
model.huggingface_repo,
model.descriptor(),
)
for model in all_registered_models()
if model.huggingface_repo
]
class VLLMInferenceAdapter(OpenAIMixin):
config: VLLMInferenceAdapterConfig
model_config = ConfigDict(arbitrary_types_allowed=True)
def _convert_to_vllm_tool_calls_in_response(
tool_calls,
) -> list[ToolCall]:
if not tool_calls:
return []
provider_data_api_key_field: str = "vllm_api_token"
return [
ToolCall(
call_id=call.id,
tool_name=call.function.name,
arguments=json.loads(call.function.arguments),
arguments_json=call.function.arguments,
)
for call in tool_calls
]
def _convert_to_vllm_tools_in_request(tools: list[ToolDefinition]) -> list[dict]:
compat_tools = []
for tool in tools:
properties = {}
compat_required = []
if tool.parameters:
for tool_key, tool_param in tool.parameters.items():
properties[tool_key] = {"type": tool_param.param_type}
if tool_param.description:
properties[tool_key]["description"] = tool_param.description
if tool_param.default:
properties[tool_key]["default"] = tool_param.default
if tool_param.required:
compat_required.append(tool_key)
# The tool.tool_name can be a str or a BuiltinTool enum. If
# it's the latter, convert to a string.
tool_name = tool.tool_name
if isinstance(tool_name, BuiltinTool):
tool_name = tool_name.value
compat_tool = {
"type": "function",
"function": {
"name": tool_name,
"description": tool.description,
"parameters": {
"type": "object",
"properties": properties,
"required": compat_required,
},
},
}
compat_tools.append(compat_tool)
return compat_tools
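As an illustration of this mapping, one ToolDefinition (fields invented for the example) converts to the OpenAI-style dict shown in the trailing comment:
example_tool = ToolDefinition(
    tool_name="get_weather",
    description="Get the current weather",
    parameters={
        "city": ToolParamDefinition(param_type="string", description="City name", required=True),
    },
)
converted = _convert_to_vllm_tools_in_request([example_tool])
# converted[0]["function"]["parameters"] == {
#     "type": "object",
#     "properties": {"city": {"type": "string", "description": "City name"}},
#     "required": ["city"],
# }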
def _convert_to_vllm_finish_reason(finish_reason: str) -> StopReason:
return {
"stop": StopReason.end_of_turn,
"length": StopReason.out_of_tokens,
"tool_calls": StopReason.end_of_message,
}.get(finish_reason, StopReason.end_of_turn)
def _process_vllm_chat_completion_end_of_stream(
finish_reason: str | None,
last_chunk_content: str | None,
current_event_type: ChatCompletionResponseEventType,
tool_call_bufs: dict[str, UnparseableToolCall] | None = None,
) -> list[OpenAIChatCompletionChunk]:
chunks = []
if finish_reason is not None:
stop_reason = _convert_to_vllm_finish_reason(finish_reason)
else:
stop_reason = StopReason.end_of_message
tool_call_bufs = tool_call_bufs or {}
for _index, tool_call_buf in sorted(tool_call_bufs.items()):
args_str = tool_call_buf.arguments or "{}"
try:
args = json.loads(args_str)
chunks.append(
ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=current_event_type,
delta=ToolCallDelta(
tool_call=ToolCall(
call_id=tool_call_buf.call_id,
tool_name=tool_call_buf.tool_name,
arguments=args,
arguments_json=args_str,
),
parse_status=ToolCallParseStatus.succeeded,
),
)
)
)
except Exception as e:
log.warning(f"Failed to parse tool call buffer arguments: {args_str} \nError: {e}")
chunks.append(
ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
tool_call=str(tool_call_buf),
parse_status=ToolCallParseStatus.failed,
),
)
)
)
chunks.append(
ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta=TextDelta(text=last_chunk_content or ""),
logprobs=None,
stop_reason=stop_reason,
)
)
)
return chunks
async def _process_vllm_chat_completion_stream_response(
stream: AsyncGenerator[OpenAIChatCompletionChunk, None],
) -> AsyncGenerator:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta=TextDelta(text=""),
)
)
event_type = ChatCompletionResponseEventType.progress
tool_call_bufs: dict[str, UnparseableToolCall] = {}
end_of_stream_processed = False
async for chunk in stream:
if not chunk.choices:
log.warning("vLLM failed to generation any completions - check the vLLM server logs for an error.")
return
choice = chunk.choices[0]
if choice.delta.tool_calls:
for delta_tool_call in choice.delta.tool_calls:
tool_call = convert_tool_call(delta_tool_call)
if delta_tool_call.index not in tool_call_bufs:
tool_call_bufs[delta_tool_call.index] = UnparseableToolCall()
tool_call_buf = tool_call_bufs[delta_tool_call.index]
tool_call_buf.tool_name += str(tool_call.tool_name)
tool_call_buf.call_id += tool_call.call_id
tool_call_buf.arguments += (
tool_call.arguments if isinstance(tool_call.arguments, str) else json.dumps(tool_call.arguments)
)
if choice.finish_reason:
chunks = _process_vllm_chat_completion_end_of_stream(
finish_reason=choice.finish_reason,
last_chunk_content=choice.delta.content,
current_event_type=event_type,
tool_call_bufs=tool_call_bufs,
)
for c in chunks:
yield c
end_of_stream_processed = True
elif not choice.delta.tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=event_type,
delta=TextDelta(text=choice.delta.content or ""),
logprobs=None,
)
)
event_type = ChatCompletionResponseEventType.progress
if end_of_stream_processed:
return
# the stream ended without a chunk containing finish_reason - we have to generate the
# respective completion chunks manually
chunks = _process_vllm_chat_completion_end_of_stream(
finish_reason=None, last_chunk_content=None, current_event_type=event_type, tool_call_bufs=tool_call_bufs
)
for c in chunks:
yield c
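A hedged sketch of how this generator is driven; the client below is an AsyncOpenAI instance pointed at a vLLM server, and the model and message values are placeholders:
async def demo(client: AsyncOpenAI) -> None:
    stream = await client.chat.completions.create(
        model="my-model",
        messages=[{"role": "user", "content": "Hello"}],
        stream=True,
    )
    # Each upstream OpenAI chunk becomes one or more llama-stack stream chunks.
    async for chunk in _process_vllm_chat_completion_stream_response(stream):
        print(chunk.event.event_type, chunk.event.delta)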
class VLLMInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin, Inference, ModelsProtocolPrivate):
# automatically set by the resolver when instantiating the provider
__provider_id__: str
model_store: ModelStore | None = None
def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
LiteLLMOpenAIMixin.__init__(
self,
build_hf_repo_model_entries(),
litellm_provider_name="vllm",
api_key_from_config=config.api_token,
provider_data_api_key_field="vllm_api_token",
openai_compat_api_base=config.url,
)
self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
self.config = config
get_api_key = LiteLLMOpenAIMixin.get_api_key
def get_api_key(self) -> str | None:
if self.config.auth_credential:
return self.config.auth_credential.get_secret_value()
return "NO KEY REQUIRED"
def get_base_url(self) -> str:
"""Get the base URL from config."""
@ -315,31 +53,6 @@ class VLLMInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin, Inference, ModelsPro
"You must provide a URL in run.yaml (or via the VLLM_URL environment variable) to use vLLM."
)
async def should_refresh_models(self) -> bool:
# Strictly respecting the refresh_models directive
return self.config.refresh_models
async def list_models(self) -> list[Model] | None:
models = []
async for m in self.client.models.list():
model_type = ModelType.llm # unclear how to determine embedding vs. llm models
models.append(
Model(
identifier=m.id,
provider_resource_id=m.id,
provider_id=self.__provider_id__,
metadata={},
model_type=model_type,
)
)
return models
async def shutdown(self) -> None:
pass
async def unregister_model(self, model_id: str) -> None:
pass
async def health(self) -> HealthResponse:
"""
Performs a health check by verifying connectivity to the remote vLLM server.
@ -361,216 +74,38 @@ class VLLMInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin, Inference, ModelsPro
except Exception as e:
return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
async def _get_model(self, model_id: str) -> Model:
if not self.model_store:
raise ValueError("Model store not set")
return await self.model_store.get_model(model_id)
def get_extra_client_params(self):
return {"http_client": httpx.AsyncClient(verify=self.config.tls_verify)}
async def completion(  # type: ignore[override]  # Return type is more specific than the base class, which allows for both streaming and non-streaming responses.
self,
model_id: str,
content: InterleavedContent,
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
) -> CompletionResponse | AsyncGenerator[CompletionResponseStreamChunk, None]:
if sampling_params is None:
sampling_params = SamplingParams()
model = await self._get_model(model_id)
if model.provider_resource_id is None:
raise ValueError(f"Model {model_id} has no provider_resource_id set")
request = CompletionRequest(
model=model.provider_resource_id,
content=content,
sampling_params=sampling_params,
response_format=response_format,
stream=stream,
logprobs=logprobs,
)
if stream:
return self._stream_completion(request)
else:
return await self._nonstream_completion(request)
async def check_model_availability(self, model: str) -> bool:
"""
Skip the check when running without authentication.
"""
if not self.config.auth_credential:
model_ids = []
async for m in self.client.models.list():
if m.id == model: # Found exact match
return True
model_ids.append(m.id)
raise ValueError(f"Model '{model}' not found. Available models: {model_ids}")
log.warning(f"Not checking model availability for {model} as API token may trigger OAuth workflow")
return True
async def chat_completion(
async def openai_chat_completion(
self,
model_id: str,
messages: list[Message],
sampling_params: SamplingParams | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
if sampling_params is None:
sampling_params = SamplingParams()
model = await self._get_model(model_id)
if model.provider_resource_id is None:
raise ValueError(f"Model {model_id} has no provider_resource_id set")
params: OpenAIChatCompletionRequestWithExtraBody,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
params = params.model_copy()
# Apply vLLM-specific defaults
if params.max_tokens is None and self.config.max_tokens:
params.max_tokens = self.config.max_tokens
# This is to be consistent with OpenAI API and support vLLM <= v0.6.3
# References:
# * https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
# * https://github.com/vllm-project/vllm/pull/10000
if not tools and tool_config is not None:
tool_config.tool_choice = ToolChoice.none
request = ChatCompletionRequest(
model=model.provider_resource_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
stream=stream,
logprobs=logprobs,
response_format=response_format,
tool_config=tool_config,
)
if stream:
return self._stream_chat_completion_with_client(request, self.client)
else:
return await self._nonstream_chat_completion(request, self.client)
if not params.tools and params.tool_choice is not None:
params.tool_choice = ToolChoice.none.value
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest, client: AsyncOpenAI
) -> ChatCompletionResponse:
assert self.client is not None
params = await self._get_params(request)
r = await client.chat.completions.create(**params)
choice = r.choices[0]
result = ChatCompletionResponse(
completion_message=CompletionMessage(
content=choice.message.content or "",
stop_reason=_convert_to_vllm_finish_reason(choice.finish_reason),
tool_calls=_convert_to_vllm_tool_calls_in_response(choice.message.tool_calls),
),
logprobs=None,
)
return result
async def _stream_chat_completion(self, response: Any) -> AsyncIterator[ChatCompletionResponseStreamChunk]:
# This method is called from LiteLLMOpenAIMixin.chat_completion
# The response parameter contains the litellm response
# We need to convert it to our format
async def _stream_generator():
async for chunk in response:
yield chunk
async for chunk in convert_openai_chat_completion_stream(
_stream_generator(), enable_incremental_tool_calls=True
):
yield chunk
async def _stream_chat_completion_with_client(
self, request: ChatCompletionRequest, client: AsyncOpenAI
) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
"""Helper method for streaming with explicit client parameter."""
assert self.client is not None
params = await self._get_params(request)
stream = await client.chat.completions.create(**params)
if request.tools:
res = _process_vllm_chat_completion_stream_response(stream)
else:
res = process_chat_completion_stream_response(stream, request)
async for chunk in res:
yield chunk
async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
if self.client is None:
raise RuntimeError("Client is not initialized")
params = await self._get_params(request)
r = await self.client.completions.create(**params)
return process_completion_response(r)
async def _stream_completion(
self, request: CompletionRequest
) -> AsyncGenerator[CompletionResponseStreamChunk, None]:
if self.client is None:
raise RuntimeError("Client is not initialized")
params = await self._get_params(request)
stream = await self.client.completions.create(**params)
async for chunk in process_completion_stream_response(stream):
yield chunk
async def register_model(self, model: Model) -> Model:
try:
model = await self.register_helper.register_model(model)
except ValueError:
pass # Ignore statically unknown model, will check live listing
try:
res = await self.client.models.list()
except APIConnectionError as e:
raise ValueError(
f"Failed to connect to vLLM at {self.config.url}. Please check if vLLM is running and accessible at that URL."
) from e
available_models = [m.id async for m in res]
if model.provider_resource_id not in available_models:
raise ValueError(
f"Model {model.provider_resource_id} is not being served by vLLM. "
f"Available models: {', '.join(available_models)}"
)
return model
async def _get_params(self, request: ChatCompletionRequest | CompletionRequest) -> dict:
options = get_sampling_options(request.sampling_params)
if "max_tokens" not in options:
options["max_tokens"] = self.config.max_tokens
input_dict: dict[str, Any] = {}
# Only include the 'tools' param if there are any; sending an empty list to vLLM can break things.
if isinstance(request, ChatCompletionRequest) and request.tools:
input_dict = {"tools": _convert_to_vllm_tools_in_request(request.tools)}
if isinstance(request, ChatCompletionRequest):
input_dict["messages"] = [await convert_message_to_openai_dict(m, download=True) for m in request.messages]
else:
assert not request_has_media(request), "vLLM does not support media for Completion requests"
input_dict["prompt"] = await completion_request_to_prompt(request)
if fmt := request.response_format:
if isinstance(fmt, JsonSchemaResponseFormat):
input_dict["extra_body"] = {"guided_json": fmt.json_schema}
elif isinstance(fmt, GrammarResponseFormat):
raise NotImplementedError("Grammar response format not supported yet")
else:
raise ValueError(f"Unknown response format {fmt.type}")
if request.logprobs and request.logprobs.top_k:
input_dict["logprobs"] = request.logprobs.top_k
return {
"model": request.model,
**input_dict,
"stream": request.stream,
**options,
}
async def embeddings(
self,
model_id: str,
contents: list[str] | list[InterleavedContentItem],
text_truncation: TextTruncation | None = TextTruncation.none,
output_dimension: int | None = None,
task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
model = await self._get_model(model_id)
kwargs = {}
assert model.model_type == ModelType.embedding
assert model.metadata.get("embedding_dimension")
kwargs["dimensions"] = model.metadata.get("embedding_dimension")
assert all(not content_has_media(content) for content in contents), "VLLM does not support media for embeddings"
response = await self.client.embeddings.create(
model=model.provider_resource_id,
input=[interleaved_content_as_str(content) for content in contents],
**kwargs,
)
embeddings = [data.embedding for data in response.data]
return EmbeddingsResponse(embeddings=embeddings)
return await super().openai_chat_completion(params)
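A minimal usage sketch of the defaults this override applies, meant to run inside an async context; the request values are invented and the field names assume the OpenAI-style request model used above. With no max_tokens supplied the config value is filled in, and with no tools the tool_choice is forced to "none" before delegating to the mixin.
params = OpenAIChatCompletionRequestWithExtraBody(
    model="meta-llama/Llama-3.1-8B-Instruct",
    messages=[{"role": "user", "content": "Hello"}],
    tool_choice="auto",  # no tools supplied, so this is rewritten to "none"
)
response = await adapter.openai_chat_completion(params)
# params.max_tokens was None, so adapter.config.max_tokens (default 4096) was applied
# before the request reached the OpenAI-compatible vLLM endpoint.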
