pre-commit fixes

This commit is contained in:
Chantal D Gama Rose 2025-03-14 13:56:05 -07:00
parent 967dd0aa08
commit 7e211f8553
314 changed files with 5574 additions and 11369 deletions

View file

@ -41,16 +41,36 @@ from llama_stack.schema_utils import json_schema_type, register_schema, webmetho
class Attachment(BaseModel):
"""An attachment to an agent turn.
:param content: The content of the attachment.
:param mime_type: The MIME type of the attachment.
"""
content: InterleavedContent | URL
mime_type: str
class Document(BaseModel):
"""A document to be used by an agent.
:param content: The content of the document.
:param mime_type: The MIME type of the document.
"""
content: InterleavedContent | URL
mime_type: str
class StepCommon(BaseModel):
"""A common step in an agent turn.
:param turn_id: The ID of the turn.
:param step_id: The ID of the step.
:param started_at: The time the step started.
:param completed_at: The time the step completed.
"""
turn_id: str
step_id: str
started_at: Optional[datetime] = None
@ -58,6 +78,14 @@ class StepCommon(BaseModel):
class StepType(Enum):
"""Type of the step in an agent turn.
:cvar inference: The step is an inference step that calls an LLM.
:cvar tool_execution: The step is a tool execution step that executes a tool call.
:cvar shield_call: The step is a shield call step that checks for safety violations.
:cvar memory_retrieval: The step is a memory retrieval step that retrieves context for vector dbs.
"""
inference = "inference"
tool_execution = "tool_execution"
shield_call = "shield_call"
@ -66,6 +94,11 @@ class StepType(Enum):
@json_schema_type
class InferenceStep(StepCommon):
"""An inference step in an agent turn.
:param model_response: The response from the LLM.
"""
model_config = ConfigDict(protected_namespaces=())
step_type: Literal[StepType.inference.value] = StepType.inference.value
@ -74,6 +107,12 @@ class InferenceStep(StepCommon):
@json_schema_type
class ToolExecutionStep(StepCommon):
"""A tool execution step in an agent turn.
:param tool_calls: The tool calls to execute.
:param tool_responses: The tool responses from the tool calls.
"""
step_type: Literal[StepType.tool_execution.value] = StepType.tool_execution.value
tool_calls: List[ToolCall]
tool_responses: List[ToolResponse]
@ -81,13 +120,25 @@ class ToolExecutionStep(StepCommon):
@json_schema_type
class ShieldCallStep(StepCommon):
"""A shield call step in an agent turn.
:param violation: The violation from the shield call.
"""
step_type: Literal[StepType.shield_call.value] = StepType.shield_call.value
violation: Optional[SafetyViolation]
@json_schema_type
class MemoryRetrievalStep(StepCommon):
"""A memory retrieval step in an agent turn.
:param vector_db_ids: The IDs of the vector databases to retrieve context from.
:param inserted_context: The context retrieved from the vector databases.
"""
step_type: Literal[StepType.memory_retrieval.value] = StepType.memory_retrieval.value
# TODO: should this be List[str]?
vector_db_ids: str
inserted_context: InterleavedContent
@ -148,7 +199,7 @@ AgentToolGroup = register_schema(
class AgentConfigCommon(BaseModel):
sampling_params: Optional[SamplingParams] = SamplingParams()
sampling_params: Optional[SamplingParams] = Field(default_factory=SamplingParams)
input_shields: Optional[List[str]] = Field(default_factory=list)
output_shields: Optional[List[str]] = Field(default_factory=list)
@ -183,6 +234,23 @@ class AgentConfig(AgentConfigCommon):
response_format: Optional[ResponseFormat] = None
@json_schema_type
class Agent(BaseModel):
agent_id: str
agent_config: AgentConfig
created_at: datetime
@json_schema_type
class ListAgentsResponse(BaseModel):
data: List[Agent]
@json_schema_type
class ListAgentSessionsResponse(BaseModel):
data: List[Session]
class AgentConfigOverridablePerTurn(AgentConfigCommon):
instructions: Optional[str] = None
@ -296,16 +364,13 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
stream: Optional[bool] = False
tool_config: Optional[ToolConfig] = None
# TODO (xiyan): temporary flag, will remove for 0.1.5
allow_turn_resume: Optional[bool] = False
@json_schema_type
class AgentTurnResumeRequest(BaseModel):
agent_id: str
session_id: str
turn_id: str
tool_responses: List[ToolResponseMessage]
tool_responses: List[ToolResponse]
stream: Optional[bool] = False
@ -338,7 +403,13 @@ class Agents(Protocol):
async def create_agent(
self,
agent_config: AgentConfig,
) -> AgentCreateResponse: ...
) -> AgentCreateResponse:
"""Create an agent with the given configuration.
:param agent_config: The configuration for the agent.
:returns: An AgentCreateResponse with the agent ID.
"""
...
@webmethod(route="/agents/{agent_id}/session/{session_id}/turn", method="POST")
async def create_agent_turn(
@ -355,8 +426,19 @@ class Agents(Protocol):
documents: Optional[List[Document]] = None,
toolgroups: Optional[List[AgentToolGroup]] = None,
tool_config: Optional[ToolConfig] = None,
allow_turn_resume: Optional[bool] = False,
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]:
"""Create a new turn for an agent.
:param agent_id: The ID of the agent to create the turn for.
:param session_id: The ID of the session to create the turn for.
:param messages: List of messages to start the turn with.
:param stream: (Optional) If True, generate an SSE event stream of the response. Defaults to False.
:param documents: (Optional) List of documents to create the turn with.
:param toolgroups: (Optional) List of toolgroups to create the turn with, will be used in addition to the agent's config toolgroups for the request.
:param tool_config: (Optional) The tool configuration to create the turn with, will be used to override the agent's tool_config.
:returns: If stream=False, returns a Turn object.
If stream=True, returns an SSE event stream of AgentTurnResponseStreamChunk
"""
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume",
@ -367,7 +449,7 @@ class Agents(Protocol):
agent_id: str,
session_id: str,
turn_id: str,
tool_responses: List[ToolResponseMessage],
tool_responses: List[ToolResponse],
stream: Optional[bool] = False,
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]:
"""Resume an agent turn with executed tool call responses.
@ -392,7 +474,15 @@ class Agents(Protocol):
agent_id: str,
session_id: str,
turn_id: str,
) -> Turn: ...
) -> Turn:
"""Retrieve an agent turn by its ID.
:param agent_id: The ID of the agent to get the turn for.
:param session_id: The ID of the session to get the turn for.
:param turn_id: The ID of the turn to get.
:returns: A Turn.
"""
...
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/step/{step_id}",
@ -404,14 +494,30 @@ class Agents(Protocol):
session_id: str,
turn_id: str,
step_id: str,
) -> AgentStepResponse: ...
) -> AgentStepResponse:
"""Retrieve an agent step by its ID.
:param agent_id: The ID of the agent to get the step for.
:param session_id: The ID of the session to get the step for.
:param turn_id: The ID of the turn to get the step for.
:param step_id: The ID of the step to get.
:returns: An AgentStepResponse.
"""
...
@webmethod(route="/agents/{agent_id}/session", method="POST")
async def create_agent_session(
self,
agent_id: str,
session_name: str,
) -> AgentSessionCreateResponse: ...
) -> AgentSessionCreateResponse:
"""Create a new session for an agent.
:param agent_id: The ID of the agent to create the session for.
:param session_name: The name of the session to create.
:returns: An AgentSessionCreateResponse.
"""
...
@webmethod(route="/agents/{agent_id}/session/{session_id}", method="GET")
async def get_agents_session(
@ -419,17 +525,64 @@ class Agents(Protocol):
session_id: str,
agent_id: str,
turn_ids: Optional[List[str]] = None,
) -> Session: ...
) -> Session:
"""Retrieve an agent session by its ID.
:param session_id: The ID of the session to get.
:param agent_id: The ID of the agent to get the session for.
:param turn_ids: (Optional) List of turn IDs to filter the session by.
"""
...
@webmethod(route="/agents/{agent_id}/session/{session_id}", method="DELETE")
async def delete_agents_session(
self,
session_id: str,
agent_id: str,
) -> None: ...
) -> None:
"""Delete an agent session by its ID.
:param session_id: The ID of the session to delete.
:param agent_id: The ID of the agent to delete the session for.
"""
...
@webmethod(route="/agents/{agent_id}", method="DELETE")
async def delete_agent(
self,
agent_id: str,
) -> None: ...
) -> None:
"""Delete an agent by its ID.
:param agent_id: The ID of the agent to delete.
"""
...
@webmethod(route="/agents", method="GET")
async def list_agents(self) -> ListAgentsResponse:
"""List all agents.
:returns: A ListAgentsResponse.
"""
...
@webmethod(route="/agents/{agent_id}", method="GET")
async def get_agent(self, agent_id: str) -> Agent:
"""Describe an agent by its ID.
:param agent_id: ID of the agent.
:returns: An Agent of the agent.
"""
...
@webmethod(route="/agents/{agent_id}/sessions", method="GET")
async def list_agent_sessions(
self,
agent_id: str,
) -> ListAgentSessionsResponse:
"""List all session(s) of a given agent.
:param agent_id: The ID of the agent to list sessions for.
:returns: A ListAgentSessionsResponse.
"""
...

View file

@ -40,7 +40,7 @@ class BatchInference(Protocol):
self,
model: str,
content_batch: List[InterleavedContent],
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
) -> BatchCompletionResponse: ...
@ -50,7 +50,7 @@ class BatchInference(Protocol):
self,
model: str,
messages_batch: List[List[Message]],
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
# zero-shot tool definitions as input to the model
tools: Optional[List[ToolDefinition]] = list,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,

View file

@ -14,6 +14,14 @@ from llama_stack.schema_utils import json_schema_type, webmethod
@json_schema_type
class PaginatedRowsResult(BaseModel):
"""
A paginated list of rows from a dataset.
:param rows: The rows in the current page.
:param total_count: The total number of rows in the dataset.
:param next_page_token: The token to get the next page of rows.
"""
# the rows obey the DatasetSchema for the given dataset
rows: List[Dict[str, Any]]
total_count: int
@ -36,7 +44,15 @@ class DatasetIO(Protocol):
rows_in_page: int,
page_token: Optional[str] = None,
filter_condition: Optional[str] = None,
) -> PaginatedRowsResult: ...
) -> PaginatedRowsResult:
"""Get a paginated list of rows from a dataset.
:param dataset_id: The ID of the dataset to get the rows from.
:param rows_in_page: The number of rows to get per page.
:param page_token: The token to get the next page of rows.
:param filter_condition: (Optional) A condition to filter the rows by.
"""
...
@webmethod(route="/datasetio/rows", method="POST")
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: ...

View file

@ -19,6 +19,13 @@ from llama_stack.schema_utils import json_schema_type, register_schema, webmetho
@json_schema_type
class ModelCandidate(BaseModel):
"""A model candidate for evaluation.
:param model: The model ID to evaluate.
:param sampling_params: The sampling parameters for the model.
:param system_message: (Optional) The system message providing instructions or context to the model.
"""
type: Literal["model"] = "model"
model: str
sampling_params: SamplingParams
@ -27,6 +34,11 @@ class ModelCandidate(BaseModel):
@json_schema_type
class AgentCandidate(BaseModel):
"""An agent candidate for evaluation.
:param config: The configuration for the agent candidate.
"""
type: Literal["agent"] = "agent"
config: AgentConfig
@ -39,6 +51,13 @@ EvalCandidate = register_schema(
@json_schema_type
class BenchmarkConfig(BaseModel):
"""A benchmark configuration for evaluation.
:param eval_candidate: The candidate to evaluate.
:param scoring_params: Map between scoring function id and parameters for each scoring function you want to run
:param num_examples: (Optional) The number of examples to evaluate. If not provided, all examples in the dataset will be evaluated
"""
eval_candidate: EvalCandidate
scoring_params: Dict[str, ScoringFnParams] = Field(
description="Map between scoring function id and parameters for each scoring function you want to run",
@ -53,18 +72,32 @@ class BenchmarkConfig(BaseModel):
@json_schema_type
class EvaluateResponse(BaseModel):
"""The response from an evaluation.
:param generations: The generations from the evaluation.
:param scores: The scores from the evaluation.
"""
generations: List[Dict[str, Any]]
# each key in the dict is a scoring function name
scores: Dict[str, ScoringResult]
class Eval(Protocol):
"""Llama Stack Evaluation API for running evaluations on model and agent candidates."""
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs", method="POST")
async def run_eval(
self,
benchmark_id: str,
task_config: BenchmarkConfig,
) -> Job: ...
benchmark_config: BenchmarkConfig,
) -> Job:
"""Run an evaluation on a benchmark.
:param benchmark_id: The ID of the benchmark to run the evaluation on.
:param benchmark_config: The configuration for the benchmark.
:return: The job that was created to run the evaluation.
"""
@webmethod(route="/eval/benchmarks/{benchmark_id}/evaluations", method="POST")
async def evaluate_rows(
@ -72,14 +105,41 @@ class Eval(Protocol):
benchmark_id: str,
input_rows: List[Dict[str, Any]],
scoring_functions: List[str],
task_config: BenchmarkConfig,
) -> EvaluateResponse: ...
benchmark_config: BenchmarkConfig,
) -> EvaluateResponse:
"""Evaluate a list of rows on a benchmark.
:param benchmark_id: The ID of the benchmark to run the evaluation on.
:param input_rows: The rows to evaluate.
:param scoring_functions: The scoring functions to use for the evaluation.
:param benchmark_config: The configuration for the benchmark.
:return: EvaluateResponse object containing generations and scores
"""
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET")
async def job_status(self, benchmark_id: str, job_id: str) -> Optional[JobStatus]: ...
async def job_status(self, benchmark_id: str, job_id: str) -> Optional[JobStatus]:
"""Get the status of a job.
:param benchmark_id: The ID of the benchmark to run the evaluation on.
:param job_id: The ID of the job to get the status of.
:return: The status of the evaluationjob.
"""
...
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="DELETE")
async def job_cancel(self, benchmark_id: str, job_id: str) -> None: ...
async def job_cancel(self, benchmark_id: str, job_id: str) -> None:
"""Cancel a job.
:param benchmark_id: The ID of the benchmark to run the evaluation on.
:param job_id: The ID of the job to cancel.
"""
...
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result", method="GET")
async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse: ...
async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse:
"""Get the result of a job.
:param benchmark_id: The ID of the benchmark to run the evaluation on.
:param job_id: The ID of the job to get the result of.
:return: The result of the job.
"""

View file

@ -117,13 +117,11 @@ class ToolResponseMessage(BaseModel):
:param role: Must be "tool" to identify this as a tool response
:param call_id: Unique identifier for the tool call this response is for
:param tool_name: Name of the tool that was called
:param content: The response content from the tool
"""
role: Literal["tool"] = "tool"
call_id: str
tool_name: Union[BuiltinTool, str]
content: InterleavedContent
@ -278,14 +276,14 @@ ResponseFormat = register_schema(
class CompletionRequest(BaseModel):
model: str
content: InterleavedContent
sampling_params: Optional[SamplingParams] = SamplingParams()
sampling_params: Optional[SamplingParams] = Field(default_factory=SamplingParams)
response_format: Optional[ResponseFormat] = None
stream: Optional[bool] = False
logprobs: Optional[LogProbConfig] = None
@json_schema_type
class CompletionResponse(BaseModel):
class CompletionResponse(MetricResponseMixin):
"""Response from a completion request.
:param content: The generated completion text
@ -299,7 +297,7 @@ class CompletionResponse(BaseModel):
@json_schema_type
class CompletionResponseStreamChunk(BaseModel):
class CompletionResponseStreamChunk(MetricResponseMixin):
"""A chunk of a streamed completion response.
:param delta: New content generated since last chunk. This can be one or more tokens.
@ -357,7 +355,7 @@ class ToolConfig(BaseModel):
class ChatCompletionRequest(BaseModel):
model: str
messages: List[Message]
sampling_params: Optional[SamplingParams] = SamplingParams()
sampling_params: Optional[SamplingParams] = Field(default_factory=SamplingParams)
tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
tool_config: Optional[ToolConfig] = Field(default_factory=ToolConfig)
@ -368,7 +366,7 @@ class ChatCompletionRequest(BaseModel):
@json_schema_type
class ChatCompletionResponseStreamChunk(MetricResponseMixin, BaseModel):
class ChatCompletionResponseStreamChunk(MetricResponseMixin):
"""A chunk of a streamed chat completion response.
:param event: The event containing the new content
@ -378,7 +376,7 @@ class ChatCompletionResponseStreamChunk(MetricResponseMixin, BaseModel):
@json_schema_type
class ChatCompletionResponse(MetricResponseMixin, BaseModel):
class ChatCompletionResponse(MetricResponseMixin):
"""Response from a chat completion request.
:param completion_message: The complete response message
@ -444,7 +442,7 @@ class Inference(Protocol):
self,
model_id: str,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
@ -467,7 +465,7 @@ class Inference(Protocol):
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = None,

View file

@ -17,6 +17,13 @@ ScoringResultRow = Dict[str, Any]
@json_schema_type
class ScoringResult(BaseModel):
"""
A scoring result for a single row.
:param score_rows: The scoring result for each row. Each row is a map of column name to value.
:param aggregated_results: Map of metric name to aggregated value
"""
score_rows: List[ScoringResultRow]
# aggregated metrics to value
aggregated_results: Dict[str, Any]
@ -30,6 +37,12 @@ class ScoreBatchResponse(BaseModel):
@json_schema_type
class ScoreResponse(BaseModel):
"""
The response from scoring.
:param results: A map of scoring function name to ScoringResult.
"""
# each key in the dict is a scoring function name
results: Dict[str, ScoringResult]
@ -55,4 +68,11 @@ class Scoring(Protocol):
self,
input_rows: List[Dict[str, Any]],
scoring_functions: Dict[str, Optional[ScoringFnParams]],
) -> ScoreResponse: ...
) -> ScoreResponse:
"""Score a list of rows.
:param input_rows: The rows to score.
:param scoring_functions: The scoring functions to use for the scoring.
:return: ScoreResponse object containing rows and aggregated results
"""
...

View file

@ -96,6 +96,13 @@ class MetricEvent(EventCommon):
unit: str
@json_schema_type
class MetricInResponse(BaseModel):
metric: str
value: Union[int, float]
unit: Optional[str] = None
# This is a short term solution to allow inference API to return metrics
# The ideal way to do this is to have a way for all response types to include metrics
# and all metric events logged to the telemetry API to be inlcuded with the response
@ -117,7 +124,7 @@ class MetricEvent(EventCommon):
class MetricResponseMixin(BaseModel):
metrics: Optional[List[MetricEvent]] = None
metrics: Optional[List[MetricInResponse]] = None
@json_schema_type

View file

@ -64,7 +64,7 @@ class ModelDescribe(Subcommand):
]
if model.recommended_sampling_params is not None:
sampling_params = model.recommended_sampling_params.dict()
sampling_params = model.recommended_sampling_params.model_dump()
for k in ("max_tokens", "repetition_penalty"):
del sampling_params[k]
rows.append(

View file

@ -13,7 +13,7 @@ from llama_stack.cli.subcommand import Subcommand
from llama_stack.cli.table import print_table
from llama_stack.models.llama.datatypes import CoreModelId, ModelFamily, is_multimodal, model_family
ROOT_DIR = Path(__file__).parent.parent
ROOT_DIR = Path(__file__).parent.parent.parent
class ModelPromptFormat(Subcommand):
@ -44,6 +44,12 @@ class ModelPromptFormat(Subcommand):
default="llama3_1",
help="Model Family (llama3_1, llama3_X, etc.)",
)
self.parser.add_argument(
"-l",
"--list",
action="store_true",
help="List all available models",
)
def _run_model_template_cmd(self, args: argparse.Namespace) -> None:
import importlib.resources

View file

@ -39,7 +39,7 @@ from llama_stack.distribution.resolver import InvalidProviderError
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.distribution.utils.exec import formulate_run_args, run_with_pty
from llama_stack.distribution.utils.image_types import ImageType
from llama_stack.distribution.utils.image_types import LlamaStackImageType
from llama_stack.providers.datatypes import Api
TEMPLATES_PATH = Path(__file__).parent.parent.parent / "templates"
@ -170,7 +170,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
)
sys.exit(1)
if build_config.image_type == ImageType.container.value and not args.image_name:
if build_config.image_type == LlamaStackImageType.CONTAINER.value and not args.image_name:
cprint(
"Please specify --image-name when building a container from a config file",
color="red",
@ -226,7 +226,7 @@ def _generate_run_config(
"""
apis = list(build_config.distribution_spec.providers.keys())
run_config = StackRunConfig(
container_image=(image_name if build_config.image_type == ImageType.container.value else None),
container_image=(image_name if build_config.image_type == LlamaStackImageType.CONTAINER.value else None),
image_name=image_name,
apis=apis,
providers={},
@ -248,7 +248,7 @@ def _generate_run_config(
config_type = instantiate_class_type(provider_registry[Api(api)][provider_type].config_class)
if hasattr(config_type, "sample_run_config"):
config = config_type.sample_run_config(__distro_dir__=f"distributions/{image_name}")
config = config_type.sample_run_config(__distro_dir__=f"~/.llama/distributions/{image_name}")
else:
config = {}
@ -279,16 +279,16 @@ def _run_stack_build_command_from_build_config(
template_name: Optional[str] = None,
config_path: Optional[str] = None,
) -> str:
if build_config.image_type == ImageType.container.value:
if build_config.image_type == LlamaStackImageType.CONTAINER.value:
if template_name:
image_name = f"distribution-{template_name}"
else:
if not image_name:
raise ValueError("Please specify an image name when building a container image without a template")
elif build_config.image_type == ImageType.conda.value:
elif build_config.image_type == LlamaStackImageType.CONDA.value:
if not image_name:
raise ValueError("Please specify an image name when building a conda image")
elif build_config.image_type == ImageType.venv.value:
elif build_config.image_type == LlamaStackImageType.VENV.value:
if not image_name and os.environ.get("UV_SYSTEM_PYTHON"):
image_name = "__system__"
if not image_name:

View file

@ -16,7 +16,7 @@ class StackBuild(Subcommand):
"build",
prog="llama stack build",
description="Build a Llama stack container",
formatter_class=argparse.RawTextHelpFormatter,
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._run_stack_build_command)

View file

@ -5,15 +5,15 @@
# the root directory of this source tree.
import argparse
import logging
import os
from pathlib import Path
from llama_stack.cli.subcommand import Subcommand
from llama_stack.log import get_logger
REPO_ROOT = Path(__file__).parent.parent.parent.parent
logger = logging.getLogger(__name__)
logger = get_logger(name=__name__, category="server")
class StackRun(Subcommand):
@ -23,7 +23,7 @@ class StackRun(Subcommand):
"run",
prog="llama stack run",
description="""Start the server for a Llama Stack Distribution. You should have already built (or downloaded) and configured the distribution.""",
formatter_class=argparse.RawTextHelpFormatter,
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._run_stack_run_cmd)
@ -37,12 +37,13 @@ class StackRun(Subcommand):
self.parser.add_argument(
"--port",
type=int,
help="Port to run the server on. It can also be passed via the env var LLAMA_STACK_PORT. Defaults to 8321",
help="Port to run the server on. It can also be passed via the env var LLAMA_STACK_PORT.",
default=int(os.getenv("LLAMA_STACK_PORT", 8321)),
)
self.parser.add_argument(
"--image-name",
type=str,
default=os.environ.get("CONDA_DEFAULT_ENV"),
help="Name of the image to run. Defaults to the current conda environment",
)
self.parser.add_argument(
@ -79,12 +80,8 @@ class StackRun(Subcommand):
def _run_stack_run_cmd(self, args: argparse.Namespace) -> None:
import yaml
from llama_stack.distribution.build import ImageType
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.utils.config_dirs import (
BUILDS_BASE_DIR,
DISTRIBS_BASE_DIR,
)
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.distribution.utils.exec import formulate_run_args, run_with_pty
config_file = Path(args.config)
@ -97,14 +94,6 @@ class StackRun(Subcommand):
if config_file.exists():
template_name = args.config
if not config_file.exists() and not has_yaml_suffix:
# check if it's a build config saved to conda dir
config_file = Path(BUILDS_BASE_DIR / ImageType.conda.value / f"{args.config}-run.yaml")
if not config_file.exists() and not has_yaml_suffix:
# check if it's a build config saved to container dir
config_file = Path(BUILDS_BASE_DIR / ImageType.container.value / f"{args.config}-run.yaml")
if not config_file.exists() and not has_yaml_suffix:
# check if it's a build config saved to ~/.llama dir
config_file = Path(DISTRIBS_BASE_DIR / f"llamastack-{args.config}" / f"{args.config}-run.yaml")

View file

@ -16,7 +16,7 @@ from termcolor import cprint
from llama_stack.distribution.datatypes import BuildConfig, Provider
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.utils.exec import run_command, run_with_pty
from llama_stack.distribution.utils.image_types import ImageType
from llama_stack.distribution.utils.image_types import LlamaStackImageType
from llama_stack.providers.datatypes import Api
log = logging.getLogger(__name__)
@ -95,7 +95,7 @@ def build_image(
normal_deps, special_deps = get_provider_dependencies(build_config.distribution_spec.providers)
normal_deps += SERVER_DEPENDENCIES
if build_config.image_type == ImageType.container.value:
if build_config.image_type == LlamaStackImageType.CONTAINER.value:
script = str(importlib.resources.files("llama_stack") / "distribution/build_container.sh")
args = [
script,
@ -104,7 +104,7 @@ def build_image(
container_base,
" ".join(normal_deps),
]
elif build_config.image_type == ImageType.conda.value:
elif build_config.image_type == LlamaStackImageType.CONDA.value:
script = str(importlib.resources.files("llama_stack") / "distribution/build_conda_env.sh")
args = [
script,
@ -112,7 +112,7 @@ def build_image(
str(build_file_path),
" ".join(normal_deps),
]
elif build_config.image_type == ImageType.venv.value:
elif build_config.image_type == LlamaStackImageType.VENV.value:
script = str(importlib.resources.files("llama_stack") / "distribution/build_venv.sh")
args = [
script,

View file

@ -39,7 +39,7 @@ def configure_single_provider(registry: Dict[str, ProviderSpec], provider: Provi
return Provider(
provider_id=provider.provider_id,
provider_type=provider.provider_type,
config=cfg.dict(),
config=cfg.model_dump(),
)

View file

@ -32,7 +32,10 @@ from termcolor import cprint
from llama_stack.distribution.build import print_pip_install_help
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.datatypes import Api
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.request_headers import (
PROVIDER_DATA_VAR,
request_provider_data_context,
)
from llama_stack.distribution.resolver import ProviderRegistry
from llama_stack.distribution.server.endpoints import get_all_api_endpoints
from llama_stack.distribution.stack import (
@ -41,8 +44,10 @@ from llama_stack.distribution.stack import (
redact_sensitive_fields,
replace_env_vars,
)
from llama_stack.distribution.utils.context import preserve_contexts_async_generator
from llama_stack.distribution.utils.exec import in_notebook
from llama_stack.providers.utils.telemetry.tracing import (
CURRENT_TRACE_CONTEXT,
end_trace,
setup_logger,
start_trace,
@ -160,6 +165,9 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
except StopAsyncIteration:
pass
finally:
pending = asyncio.all_tasks(loop)
if pending:
loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
loop.close()
return sync_generator()
@ -262,21 +270,25 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
if not self.endpoint_impls:
raise ValueError("Client not initialized")
# Create headers with provider data if available
headers = {}
if self.provider_data:
set_request_provider_data({"X-LlamaStack-Provider-Data": json.dumps(self.provider_data)})
headers["X-LlamaStack-Provider-Data"] = json.dumps(self.provider_data)
if stream:
response = await self._call_streaming(
cast_to=cast_to,
options=options,
stream_cls=stream_cls,
)
else:
response = await self._call_non_streaming(
cast_to=cast_to,
options=options,
)
return response
# Use context manager for provider data
with request_provider_data_context(headers):
if stream:
response = await self._call_streaming(
cast_to=cast_to,
options=options,
stream_cls=stream_cls,
)
else:
response = await self._call_non_streaming(
cast_to=cast_to,
options=options,
)
return response
def _find_matching_endpoint(self, method: str, path: str) -> tuple[Any, dict]:
"""Find the matching endpoint implementation for a given method and path.
@ -374,9 +386,11 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
finally:
await end_trace()
wrapped_gen = preserve_contexts_async_generator(gen(), [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR])
mock_response = httpx.Response(
status_code=httpx.codes.OK,
content=gen(),
content=wrapped_gen,
headers={
"Content-Type": "application/json",
},

View file

@ -4,16 +4,35 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import contextvars
import json
import logging
import threading
from typing import Any, Dict
from typing import Any, ContextManager, Dict, Optional
from .utils.dynamic import instantiate_class_type
log = logging.getLogger(__name__)
_THREAD_LOCAL = threading.local()
# Context variable for request provider data
PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None)
class RequestProviderDataContext(ContextManager):
"""Context manager for request provider data"""
def __init__(self, provider_data: Optional[Dict[str, Any]] = None):
self.provider_data = provider_data
self.token = None
def __enter__(self):
# Save the current value and set the new one
self.token = PROVIDER_DATA_VAR.set(self.provider_data)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
# Restore the previous value
if self.token is not None:
PROVIDER_DATA_VAR.reset(self.token)
class NeedsRequestProviderData:
@ -26,7 +45,7 @@ class NeedsRequestProviderData:
if not validator_class:
raise ValueError(f"Provider {provider_type} does not have a validator")
val = getattr(_THREAD_LOCAL, "provider_data_header_value", None)
val = PROVIDER_DATA_VAR.get()
if not val:
return None
@ -36,25 +55,32 @@ class NeedsRequestProviderData:
return provider_data
except Exception as e:
log.error(f"Error parsing provider data: {e}")
return None
def set_request_provider_data(headers: Dict[str, str]):
def parse_request_provider_data(headers: Dict[str, str]) -> Optional[Dict[str, Any]]:
"""Parse provider data from request headers"""
keys = [
"X-LlamaStack-Provider-Data",
"x-llamastack-provider-data",
]
val = None
for key in keys:
val = headers.get(key, None)
if val:
break
if not val:
return
return None
try:
val = json.loads(val)
return json.loads(val)
except json.JSONDecodeError:
log.error("Provider data not encoded as a JSON object!", val)
return
log.error("Provider data not encoded as a JSON object!")
return None
_THREAD_LOCAL.provider_data_header_value = val
def request_provider_data_context(headers: Dict[str, str]) -> ContextManager:
"""Context manager that sets request provider data from headers for the duration of the context"""
provider_data = parse_request_provider_data(headers)
return RequestProviderDataContext(provider_data)

View file

@ -7,7 +7,6 @@ import importlib
import inspect
from typing import Any, Dict, List, Set, Tuple
from llama_stack import logcat
from llama_stack.apis.agents import Agents
from llama_stack.apis.benchmarks import Benchmarks
from llama_stack.apis.datasetio import DatasetIO
@ -35,6 +34,7 @@ from llama_stack.distribution.datatypes import (
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.store import DistributionRegistry
from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import (
Api,
BenchmarksProtocolPrivate,
@ -50,6 +50,8 @@ from llama_stack.providers.datatypes import (
VectorDBsProtocolPrivate,
)
logger = get_logger(name=__name__, category="core")
class InvalidProviderError(Exception):
pass
@ -163,7 +165,9 @@ def specs_for_autorouted_apis(apis_to_serve: List[str] | Set[str]) -> Dict[str,
module="llama_stack.distribution.routers",
routing_table_api=info.routing_table_api,
api_dependencies=[info.routing_table_api],
deps__=[info.routing_table_api.value],
# Add telemetry as an optional dependency to all auto-routed providers
optional_api_dependencies=[Api.telemetry],
deps__=([info.routing_table_api.value, Api.telemetry.value]),
),
)
}
@ -184,7 +188,7 @@ def validate_and_prepare_providers(
specs = {}
for provider in providers:
if not provider.provider_id or provider.provider_id == "__disabled__":
logcat.warning("core", f"Provider `{provider.provider_type}` for API `{api}` is disabled")
logger.warning(f"Provider `{provider.provider_type}` for API `{api}` is disabled")
continue
validate_provider(provider, api, provider_registry)
@ -206,11 +210,10 @@ def validate_provider(provider: Provider, api: Api, provider_registry: ProviderR
p = provider_registry[api][provider.provider_type]
if p.deprecation_error:
logcat.error("core", p.deprecation_error)
logger.error(p.deprecation_error)
raise InvalidProviderError(p.deprecation_error)
elif p.deprecation_warning:
logcat.warning(
"core",
logger.warning(
f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
)
@ -244,9 +247,10 @@ def sort_providers_by_deps(
)
)
logcat.debug("core", f"Resolved {len(sorted_providers)} providers")
logger.debug(f"Resolved {len(sorted_providers)} providers")
for api_str, provider in sorted_providers:
logcat.debug("core", f" {api_str} => {provider.provider_id}")
logger.debug(f" {api_str} => {provider.provider_id}")
logger.debug("")
return sorted_providers
@ -387,7 +391,7 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
obj_params = set(obj_sig.parameters)
obj_params.discard("self")
if not (proto_params <= obj_params):
logcat.error("core", f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
logger.error(f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
missing_methods.append((name, "signature_mismatch"))
else:
# Check if the method is actually implemented in the class

View file

@ -45,7 +45,7 @@ async def get_routing_table_impl(
return impl
async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> Any:
async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: Dict[str, Any]) -> Any:
from .routers import (
DatasetIORouter,
EvalRouter,
@ -65,9 +65,17 @@ async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) ->
"eval": EvalRouter,
"tool_runtime": ToolRuntimeRouter,
}
api_to_deps = {
"inference": {"telemetry": Api.telemetry},
}
if api.value not in api_to_routers:
raise ValueError(f"API {api.value} not found in router map")
impl = api_to_routers[api.value](routing_table)
api_to_dep_impl = {}
for dep_name, dep_api in api_to_deps.get(api.value, {}).items():
if dep_api in deps:
api_to_dep_impl[dep_name] = deps[dep_api]
impl = api_to_routers[api.value](routing_table, **api_to_dep_impl)
await impl.initialize()
return impl

View file

@ -4,9 +4,9 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, AsyncGenerator, Dict, List, Optional
import time
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
from llama_stack import logcat
from llama_stack.apis.common.content_types import (
URL,
InterleavedContent,
@ -21,6 +21,10 @@ from llama_stack.apis.eval import (
JobStatus,
)
from llama_stack.apis.inference import (
ChatCompletionResponse,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
CompletionMessage,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
@ -28,13 +32,14 @@ from llama_stack.apis.inference import (
Message,
ResponseFormat,
SamplingParams,
StopReason,
TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.models import ModelType
from llama_stack.apis.models import Model, ModelType
from llama_stack.apis.safety import RunShieldResponse, Safety
from llama_stack.apis.scoring import (
ScoreBatchResponse,
@ -43,6 +48,7 @@ from llama_stack.apis.scoring import (
ScoringFnParams,
)
from llama_stack.apis.shields import Shield
from llama_stack.apis.telemetry import MetricEvent, MetricInResponse, Telemetry
from llama_stack.apis.tools import (
RAGDocument,
RAGQueryConfig,
@ -52,7 +58,13 @@ from llama_stack.apis.tools import (
ToolRuntime,
)
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.log import get_logger
from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.providers.datatypes import RoutingTable
from llama_stack.providers.utils.telemetry.tracing import get_current_span
logger = get_logger(name=__name__, category="core")
class VectorIORouter(VectorIO):
@ -62,15 +74,15 @@ class VectorIORouter(VectorIO):
self,
routing_table: RoutingTable,
) -> None:
logcat.debug("core", "Initializing VectorIORouter")
logger.debug("Initializing VectorIORouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logcat.debug("core", "VectorIORouter.initialize")
logger.debug("VectorIORouter.initialize")
pass
async def shutdown(self) -> None:
logcat.debug("core", "VectorIORouter.shutdown")
logger.debug("VectorIORouter.shutdown")
pass
async def register_vector_db(
@ -81,7 +93,7 @@ class VectorIORouter(VectorIO):
provider_id: Optional[str] = None,
provider_vector_db_id: Optional[str] = None,
) -> None:
logcat.debug("core", f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
await self.routing_table.register_vector_db(
vector_db_id,
embedding_model,
@ -96,8 +108,7 @@ class VectorIORouter(VectorIO):
chunks: List[Chunk],
ttl_seconds: Optional[int] = None,
) -> None:
logcat.debug(
"core",
logger.debug(
f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
)
return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds)
@ -108,7 +119,7 @@ class VectorIORouter(VectorIO):
query: InterleavedContent,
params: Optional[Dict[str, Any]] = None,
) -> QueryChunksResponse:
logcat.debug("core", f"VectorIORouter.query_chunks: {vector_db_id}")
logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)
@ -118,16 +129,21 @@ class InferenceRouter(Inference):
def __init__(
self,
routing_table: RoutingTable,
telemetry: Optional[Telemetry] = None,
) -> None:
logcat.debug("core", "Initializing InferenceRouter")
logger.debug("Initializing InferenceRouter")
self.routing_table = routing_table
self.telemetry = telemetry
if self.telemetry:
self.tokenizer = Tokenizer.get_instance()
self.formatter = ChatFormat(self.tokenizer)
async def initialize(self) -> None:
logcat.debug("core", "InferenceRouter.initialize")
logger.debug("InferenceRouter.initialize")
pass
async def shutdown(self) -> None:
logcat.debug("core", "InferenceRouter.shutdown")
logger.debug("InferenceRouter.shutdown")
pass
async def register_model(
@ -138,17 +154,81 @@ class InferenceRouter(Inference):
metadata: Optional[Dict[str, Any]] = None,
model_type: Optional[ModelType] = None,
) -> None:
logcat.debug(
"core",
logger.debug(
f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}",
)
await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
def _construct_metrics(
self, prompt_tokens: int, completion_tokens: int, total_tokens: int, model: Model
) -> List[MetricEvent]:
"""Constructs a list of MetricEvent objects containing token usage metrics.
Args:
prompt_tokens: Number of tokens in the prompt
completion_tokens: Number of tokens in the completion
total_tokens: Total number of tokens used
model: Model object containing model_id and provider_id
Returns:
List of MetricEvent objects with token usage metrics
"""
span = get_current_span()
if span is None:
logger.warning("No span found for token usage metrics")
return []
metrics = [
("prompt_tokens", prompt_tokens),
("completion_tokens", completion_tokens),
("total_tokens", total_tokens),
]
metric_events = []
for metric_name, value in metrics:
metric_events.append(
MetricEvent(
trace_id=span.trace_id,
span_id=span.span_id,
metric=metric_name,
value=value,
timestamp=time.time(),
unit="tokens",
attributes={
"model_id": model.model_id,
"provider_id": model.provider_id,
},
)
)
return metric_events
async def _compute_and_log_token_usage(
self,
prompt_tokens: int,
completion_tokens: int,
total_tokens: int,
model: Model,
) -> List[MetricInResponse]:
metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
if self.telemetry:
for metric in metrics:
await self.telemetry.log_event(metric)
return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]
async def _count_tokens(
self,
messages: List[Message] | InterleavedContent,
tool_prompt_format: Optional[ToolPromptFormat] = None,
) -> Optional[int]:
if isinstance(messages, list):
encoded = self.formatter.encode_dialog_prompt(messages, tool_prompt_format)
else:
encoded = self.formatter.encode_content(messages)
return len(encoded.tokens) if encoded and encoded.tokens else 0
async def chat_completion(
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = None,
@ -156,11 +236,12 @@ class InferenceRouter(Inference):
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
logcat.debug(
"core",
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
logger.debug(
f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
)
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.routing_table.get_model(model_id)
if model is None:
raise ValueError(f"Model '{model_id}' not found")
@ -205,22 +286,60 @@ class InferenceRouter(Inference):
tool_config=tool_config,
)
provider = self.routing_table.get_provider_impl(model_id)
prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format)
if stream:
return (chunk async for chunk in await provider.chat_completion(**params))
async def stream_generator():
completion_text = ""
async for chunk in await provider.chat_completion(**params):
if chunk.event.event_type == ChatCompletionResponseEventType.progress:
if chunk.event.delta.type == "text":
completion_text += chunk.event.delta.text
if chunk.event.event_type == ChatCompletionResponseEventType.complete:
completion_tokens = await self._count_tokens(
[CompletionMessage(content=completion_text, stop_reason=StopReason.end_of_turn)],
tool_config.tool_prompt_format,
)
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
metrics = await self._compute_and_log_token_usage(
prompt_tokens or 0,
completion_tokens or 0,
total_tokens,
model,
)
chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
yield chunk
return stream_generator()
else:
return await provider.chat_completion(**params)
response = await provider.chat_completion(**params)
completion_tokens = await self._count_tokens(
[response.completion_message],
tool_config.tool_prompt_format,
)
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
metrics = await self._compute_and_log_token_usage(
prompt_tokens or 0,
completion_tokens or 0,
total_tokens,
model,
)
response.metrics = metrics if response.metrics is None else response.metrics + metrics
return response
async def completion(
self,
model_id: str,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
logcat.debug(
"core",
if sampling_params is None:
sampling_params = SamplingParams()
logger.debug(
f"InferenceRouter.completion: {model_id=}, {stream=}, {content=}, {sampling_params=}, {response_format=}",
)
model = await self.routing_table.get_model(model_id)
@ -237,10 +356,41 @@ class InferenceRouter(Inference):
stream=stream,
logprobs=logprobs,
)
prompt_tokens = await self._count_tokens(content)
if stream:
return (chunk async for chunk in await provider.completion(**params))
async def stream_generator():
completion_text = ""
async for chunk in await provider.completion(**params):
if hasattr(chunk, "delta"):
completion_text += chunk.delta
if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
completion_tokens = await self._count_tokens(completion_text)
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
metrics = await self._compute_and_log_token_usage(
prompt_tokens or 0,
completion_tokens or 0,
total_tokens,
model,
)
chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
yield chunk
return stream_generator()
else:
return await provider.completion(**params)
response = await provider.completion(**params)
completion_tokens = await self._count_tokens(response.content)
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
metrics = await self._compute_and_log_token_usage(
prompt_tokens or 0,
completion_tokens or 0,
total_tokens,
model,
)
response.metrics = metrics if response.metrics is None else response.metrics + metrics
return response
async def embeddings(
self,
@ -250,7 +400,7 @@ class InferenceRouter(Inference):
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
logcat.debug("core", f"InferenceRouter.embeddings: {model_id}")
logger.debug(f"InferenceRouter.embeddings: {model_id}")
model = await self.routing_table.get_model(model_id)
if model is None:
raise ValueError(f"Model '{model_id}' not found")
@ -270,15 +420,15 @@ class SafetyRouter(Safety):
self,
routing_table: RoutingTable,
) -> None:
logcat.debug("core", "Initializing SafetyRouter")
logger.debug("Initializing SafetyRouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logcat.debug("core", "SafetyRouter.initialize")
logger.debug("SafetyRouter.initialize")
pass
async def shutdown(self) -> None:
logcat.debug("core", "SafetyRouter.shutdown")
logger.debug("SafetyRouter.shutdown")
pass
async def register_shield(
@ -288,7 +438,7 @@ class SafetyRouter(Safety):
provider_id: Optional[str] = None,
params: Optional[Dict[str, Any]] = None,
) -> Shield:
logcat.debug("core", f"SafetyRouter.register_shield: {shield_id}")
logger.debug(f"SafetyRouter.register_shield: {shield_id}")
return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)
async def run_shield(
@ -297,7 +447,7 @@ class SafetyRouter(Safety):
messages: List[Message],
params: Dict[str, Any] = None,
) -> RunShieldResponse:
logcat.debug("core", f"SafetyRouter.run_shield: {shield_id}")
logger.debug(f"SafetyRouter.run_shield: {shield_id}")
return await self.routing_table.get_provider_impl(shield_id).run_shield(
shield_id=shield_id,
messages=messages,
@ -310,15 +460,15 @@ class DatasetIORouter(DatasetIO):
self,
routing_table: RoutingTable,
) -> None:
logcat.debug("core", "Initializing DatasetIORouter")
logger.debug("Initializing DatasetIORouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logcat.debug("core", "DatasetIORouter.initialize")
logger.debug("DatasetIORouter.initialize")
pass
async def shutdown(self) -> None:
logcat.debug("core", "DatasetIORouter.shutdown")
logger.debug("DatasetIORouter.shutdown")
pass
async def get_rows_paginated(
@ -328,7 +478,9 @@ class DatasetIORouter(DatasetIO):
page_token: Optional[str] = None,
filter_condition: Optional[str] = None,
) -> PaginatedRowsResult:
logcat.debug("core", f"DatasetIORouter.get_rows_paginated: {dataset_id}, rows_in_page={rows_in_page}")
logger.debug(
f"DatasetIORouter.get_rows_paginated: {dataset_id}, rows_in_page={rows_in_page}",
)
return await self.routing_table.get_provider_impl(dataset_id).get_rows_paginated(
dataset_id=dataset_id,
rows_in_page=rows_in_page,
@ -337,7 +489,7 @@ class DatasetIORouter(DatasetIO):
)
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
logcat.debug("core", f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
logger.debug(f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
return await self.routing_table.get_provider_impl(dataset_id).append_rows(
dataset_id=dataset_id,
rows=rows,
@ -349,15 +501,15 @@ class ScoringRouter(Scoring):
self,
routing_table: RoutingTable,
) -> None:
logcat.debug("core", "Initializing ScoringRouter")
logger.debug("Initializing ScoringRouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logcat.debug("core", "ScoringRouter.initialize")
logger.debug("ScoringRouter.initialize")
pass
async def shutdown(self) -> None:
logcat.debug("core", "ScoringRouter.shutdown")
logger.debug("ScoringRouter.shutdown")
pass
async def score_batch(
@ -366,7 +518,7 @@ class ScoringRouter(Scoring):
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
logcat.debug("core", f"ScoringRouter.score_batch: {dataset_id}")
logger.debug(f"ScoringRouter.score_batch: {dataset_id}")
res = {}
for fn_identifier in scoring_functions.keys():
score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch(
@ -387,7 +539,7 @@ class ScoringRouter(Scoring):
input_rows: List[Dict[str, Any]],
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
) -> ScoreResponse:
logcat.debug("core", f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions")
logger.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions")
res = {}
# look up and map each scoring function to its provider impl
for fn_identifier in scoring_functions.keys():
@ -405,26 +557,26 @@ class EvalRouter(Eval):
self,
routing_table: RoutingTable,
) -> None:
logcat.debug("core", "Initializing EvalRouter")
logger.debug("Initializing EvalRouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logcat.debug("core", "EvalRouter.initialize")
logger.debug("EvalRouter.initialize")
pass
async def shutdown(self) -> None:
logcat.debug("core", "EvalRouter.shutdown")
logger.debug("EvalRouter.shutdown")
pass
async def run_eval(
self,
benchmark_id: str,
task_config: BenchmarkConfig,
benchmark_config: BenchmarkConfig,
) -> Job:
logcat.debug("core", f"EvalRouter.run_eval: {benchmark_id}")
logger.debug(f"EvalRouter.run_eval: {benchmark_id}")
return await self.routing_table.get_provider_impl(benchmark_id).run_eval(
benchmark_id=benchmark_id,
task_config=task_config,
benchmark_config=benchmark_config,
)
async def evaluate_rows(
@ -432,14 +584,14 @@ class EvalRouter(Eval):
benchmark_id: str,
input_rows: List[Dict[str, Any]],
scoring_functions: List[str],
task_config: BenchmarkConfig,
benchmark_config: BenchmarkConfig,
) -> EvaluateResponse:
logcat.debug("core", f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows(
benchmark_id=benchmark_id,
input_rows=input_rows,
scoring_functions=scoring_functions,
task_config=task_config,
benchmark_config=benchmark_config,
)
async def job_status(
@ -447,7 +599,7 @@ class EvalRouter(Eval):
benchmark_id: str,
job_id: str,
) -> Optional[JobStatus]:
logcat.debug("core", f"EvalRouter.job_status: {benchmark_id}, {job_id}")
logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)
async def job_cancel(
@ -455,7 +607,7 @@ class EvalRouter(Eval):
benchmark_id: str,
job_id: str,
) -> None:
logcat.debug("core", f"EvalRouter.job_cancel: {benchmark_id}, {job_id}")
logger.debug(f"EvalRouter.job_cancel: {benchmark_id}, {job_id}")
await self.routing_table.get_provider_impl(benchmark_id).job_cancel(
benchmark_id,
job_id,
@ -466,7 +618,7 @@ class EvalRouter(Eval):
benchmark_id: str,
job_id: str,
) -> EvaluateResponse:
logcat.debug("core", f"EvalRouter.job_result: {benchmark_id}, {job_id}")
logger.debug(f"EvalRouter.job_result: {benchmark_id}, {job_id}")
return await self.routing_table.get_provider_impl(benchmark_id).job_result(
benchmark_id,
job_id,
@ -479,7 +631,7 @@ class ToolRuntimeRouter(ToolRuntime):
self,
routing_table: RoutingTable,
) -> None:
logcat.debug("core", "Initializing ToolRuntimeRouter.RagToolImpl")
logger.debug("Initializing ToolRuntimeRouter.RagToolImpl")
self.routing_table = routing_table
async def query(
@ -488,7 +640,7 @@ class ToolRuntimeRouter(ToolRuntime):
vector_db_ids: List[str],
query_config: Optional[RAGQueryConfig] = None,
) -> RAGQueryResult:
logcat.debug("core", f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
return await self.routing_table.get_provider_impl("knowledge_search").query(
content, vector_db_ids, query_config
)
@ -499,9 +651,8 @@ class ToolRuntimeRouter(ToolRuntime):
vector_db_id: str,
chunk_size_in_tokens: int = 512,
) -> None:
logcat.debug(
"core",
f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}",
logger.debug(
f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
)
return await self.routing_table.get_provider_impl("insert_into_memory").insert(
documents, vector_db_id, chunk_size_in_tokens
@ -511,7 +662,7 @@ class ToolRuntimeRouter(ToolRuntime):
self,
routing_table: RoutingTable,
) -> None:
logcat.debug("core", "Initializing ToolRuntimeRouter")
logger.debug("Initializing ToolRuntimeRouter")
self.routing_table = routing_table
# HACK ALERT this should be in sync with "get_all_api_endpoints()"
@ -520,15 +671,15 @@ class ToolRuntimeRouter(ToolRuntime):
setattr(self, f"rag_tool.{method}", getattr(self.rag_tool, method))
async def initialize(self) -> None:
logcat.debug("core", "ToolRuntimeRouter.initialize")
logger.debug("ToolRuntimeRouter.initialize")
pass
async def shutdown(self) -> None:
logcat.debug("core", "ToolRuntimeRouter.shutdown")
logger.debug("ToolRuntimeRouter.shutdown")
pass
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> Any:
logcat.debug("core", f"ToolRuntimeRouter.invoke_tool: {tool_name}")
logger.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}")
return await self.routing_table.get_provider_impl(tool_name).invoke_tool(
tool_name=tool_name,
kwargs=kwargs,
@ -537,5 +688,5 @@ class ToolRuntimeRouter(ToolRuntime):
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]:
logcat.debug("core", f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)

View file

@ -309,13 +309,14 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
if provider_vector_db_id is None:
provider_vector_db_id = vector_db_id
if provider_id is None:
# If provider_id not specified, use the only provider if it supports this shield type
if len(self.impls_by_provider_id) == 1:
if len(self.impls_by_provider_id) > 0:
provider_id = list(self.impls_by_provider_id.keys())[0]
if len(self.impls_by_provider_id) > 1:
logger.warning(
f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}."
)
else:
raise ValueError(
"No provider specified and multiple providers available. Please specify a provider_id."
)
raise ValueError("No provider available. Please configure a vector_io provider.")
model = await self.get_object_by_identifier("model", embedding_model)
if model is None:
raise ValueError(f"Model {embedding_model} not found")
@ -366,7 +367,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
provider_id = list(self.impls_by_provider_id.keys())[0]
else:
raise ValueError(
"No provider specified and multiple providers available. Please specify a provider_id."
f"No provider specified and multiple providers available. Please specify a provider_id. Available providers: {self.impls_by_provider_id.keys()}"
)
if metadata is None:
metadata = {}

View file

@ -6,12 +6,9 @@
import argparse
import asyncio
import functools
import inspect
import json
import logging
import os
import signal
import sys
import traceback
import warnings
@ -28,10 +25,12 @@ from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, ValidationError
from typing_extensions import Annotated
from llama_stack import logcat
from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.request_headers import (
PROVIDER_DATA_VAR,
request_provider_data_context,
)
from llama_stack.distribution.resolver import InvalidProviderError
from llama_stack.distribution.stack import (
construct_stack,
@ -39,12 +38,15 @@ from llama_stack.distribution.stack import (
replace_env_vars,
validate_env_pair,
)
from llama_stack.distribution.utils.context import preserve_contexts_async_generator
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api
from llama_stack.providers.inline.telemetry.meta_reference.config import TelemetryConfig
from llama_stack.providers.inline.telemetry.meta_reference.telemetry import (
TelemetryAdapter,
)
from llama_stack.providers.utils.telemetry.tracing import (
CURRENT_TRACE_CONTEXT,
end_trace,
setup_logger,
start_trace,
@ -54,8 +56,7 @@ from .endpoints import get_all_api_endpoints
REPO_ROOT = Path(__file__).parent.parent.parent.parent
logging.basicConfig(level=logging.INFO, format="%(levelname)s %(asctime)s %(name)s:%(lineno)d: %(message)s")
logcat.init()
logger = get_logger(name=__name__, category="server")
def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
@ -117,78 +118,32 @@ def translate_exception(exc: Exception) -> Union[HTTPException, RequestValidatio
)
def handle_signal(app, signum, _) -> None:
async def shutdown(app):
"""Initiate a graceful shutdown of the application.
Handled by the lifespan context manager. The shutdown process involves
shutting down all implementations registered in the application.
"""
Handle incoming signals and initiate a graceful shutdown of the application.
This function is intended to be used as a signal handler for various signals
(e.g., SIGINT, SIGTERM). Upon receiving a signal, it will print a message
indicating the received signal and initiate a shutdown process.
Args:
app: The application instance containing implementations to be shut down.
signum (int): The signal number received.
frame: The current stack frame (not used in this function).
The shutdown process involves:
- Shutting down all implementations registered in the application.
- Gathering all running asyncio tasks.
- Cancelling all gathered tasks.
- Waiting for all tasks to finish.
- Stopping the event loop.
Note:
This function schedules the shutdown process as an asyncio task and does
not block the current execution.
"""
signame = signal.Signals(signum).name
logcat.info("server", f"Received signal {signame} ({signum}). Exiting gracefully...")
async def shutdown():
for impl in app.__llama_stack_impls__.values():
impl_name = impl.__class__.__name__
logger.info("Shutting down %s", impl_name)
try:
# Gracefully shut down implementations
for impl in app.__llama_stack_impls__.values():
impl_name = impl.__class__.__name__
logcat.info("server", f"Shutting down {impl_name}")
try:
if hasattr(impl, "shutdown"):
await asyncio.wait_for(impl.shutdown(), timeout=5)
else:
logcat.warning("server", f"No shutdown method for {impl_name}")
except asyncio.TimeoutError:
logcat.exception("server", f"Shutdown timeout for {impl_name}")
except Exception as e:
logcat.exception("server", f"Failed to shutdown {impl_name}: {e}")
# Gather all running tasks
loop = asyncio.get_running_loop()
tasks = [task for task in asyncio.all_tasks(loop) if task is not asyncio.current_task()]
# Cancel all tasks
for task in tasks:
task.cancel()
# Wait for all tasks to finish
try:
await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), timeout=10)
except asyncio.TimeoutError:
logcat.exception("server", "Timeout while waiting for tasks to finish")
except asyncio.CancelledError:
pass
finally:
loop.stop()
loop = asyncio.get_running_loop()
loop.create_task(shutdown())
if hasattr(impl, "shutdown"):
await asyncio.wait_for(impl.shutdown(), timeout=5)
else:
logger.warning("No shutdown method for %s", impl_name)
except asyncio.TimeoutError:
logger.exception("Shutdown timeout for %s ", impl_name, exc_info=True)
except (Exception, asyncio.CancelledError) as e:
logger.exception("Failed to shutdown %s: %s", impl_name, {e})
@asynccontextmanager
async def lifespan(app: FastAPI):
logcat.info("server", "Starting up")
logger.info("Starting up")
yield
logcat.info("server", "Shutting down")
for impl in app.__llama_stack_impls__.values():
await impl.shutdown()
logger.info("Shutting down")
await shutdown(app)
def is_streaming_request(func_name: str, request: Request, **kwargs):
@ -204,15 +159,14 @@ async def maybe_await(value):
async def sse_generator(event_gen):
try:
event_gen = await event_gen
async for item in event_gen:
async for item in await event_gen:
yield create_sse_event(item)
await asyncio.sleep(0.01)
except asyncio.CancelledError:
logcat.info("server", "Generator cancelled")
logger.info("Generator cancelled")
await event_gen.aclose()
except Exception as e:
logcat.exception("server", "Error in sse_generator")
logger.exception("Error in sse_generator")
yield create_sse_event(
{
"error": {
@ -224,18 +178,22 @@ async def sse_generator(event_gen):
def create_dynamic_typed_route(func: Any, method: str, route: str):
async def endpoint(request: Request, **kwargs):
set_request_provider_data(request.headers)
# Use context manager for request provider data
with request_provider_data_context(request.headers):
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
try:
if is_streaming:
return StreamingResponse(sse_generator(func(**kwargs)), media_type="text/event-stream")
else:
value = func(**kwargs)
return await maybe_await(value)
except Exception as e:
logcat.exception("server", f"Error in {func.__name__}")
raise translate_exception(e) from e
try:
if is_streaming:
gen = preserve_contexts_async_generator(
sse_generator(func(**kwargs)), [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR]
)
return StreamingResponse(gen, media_type="text/event-stream")
else:
value = func(**kwargs)
return await maybe_await(value)
except Exception as e:
logger.exception(f"Error executing endpoint {route=} {method=}")
raise translate_exception(e) from e
sig = inspect.signature(func)
@ -264,7 +222,7 @@ class TracingMiddleware:
self.app = app
async def __call__(self, scope, receive, send):
path = scope["path"]
path = scope.get("path", "")
await start_trace(path, {"__location__": "server"})
try:
return await self.app(scope, receive, send)
@ -313,8 +271,6 @@ class ClientVersionMiddleware:
def main():
logcat.init()
"""Start the LlamaStack server."""
parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
parser.add_argument(
@ -354,10 +310,10 @@ def main():
for env_pair in args.env:
try:
key, value = validate_env_pair(env_pair)
logcat.info("server", f"Setting CLI environment variable {key} => {value}")
logger.info(f"Setting CLI environment variable {key} => {value}")
os.environ[key] = value
except ValueError as e:
logcat.error("server", f"Error: {str(e)}")
logger.error(f"Error: {str(e)}")
sys.exit(1)
if args.yaml_config:
@ -365,12 +321,12 @@ def main():
config_file = Path(args.yaml_config)
if not config_file.exists():
raise ValueError(f"Config file {config_file} does not exist")
logcat.info("server", f"Using config file: {config_file}")
logger.info(f"Using config file: {config_file}")
elif args.template:
config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.template / "run.yaml"
if not config_file.exists():
raise ValueError(f"Template {args.template} does not exist")
logcat.info("server", f"Using template {args.template} config file: {config_file}")
logger.info(f"Using template {args.template} config file: {config_file}")
else:
raise ValueError("Either --yaml-config or --template must be provided")
@ -378,10 +334,9 @@ def main():
config = replace_env_vars(yaml.safe_load(fp))
config = StackRunConfig(**config)
logcat.info("server", "Run configuration:")
logger.info("Run configuration:")
safe_config = redact_sensitive_fields(config.model_dump())
for log_line in yaml.dump(safe_config, indent=2).split("\n"):
logcat.info("server", log_line)
logger.info(yaml.dump(safe_config, indent=2))
app = FastAPI(lifespan=lifespan)
app.add_middleware(TracingMiddleware)
@ -391,7 +346,7 @@ def main():
try:
impls = asyncio.run(construct_stack(config))
except InvalidProviderError as e:
logcat.error("server", f"Error: {str(e)}")
logger.error(f"Error: {str(e)}")
sys.exit(1)
if Api.telemetry in impls:
@ -436,12 +391,10 @@ def main():
)
)
logcat.debug("server", f"serving APIs: {apis_to_serve}")
logger.debug(f"serving APIs: {apis_to_serve}")
app.exception_handler(RequestValidationError)(global_exception_handler)
app.exception_handler(Exception)(global_exception_handler)
signal.signal(signal.SIGINT, functools.partial(handle_signal, app))
signal.signal(signal.SIGTERM, functools.partial(handle_signal, app))
app.__llama_stack_impls__ = impls
@ -463,15 +416,17 @@ def main():
"ssl_keyfile": keyfile,
"ssl_certfile": certfile,
}
logcat.info("server", f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")
logger.info(f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")
listen_host = ["::", "0.0.0.0"] if not args.disable_ipv6 else "0.0.0.0"
logcat.info("server", f"Listening on {listen_host}:{port}")
logger.info(f"Listening on {listen_host}:{port}")
uvicorn_config = {
"app": app,
"host": listen_host,
"port": port,
"lifespan": "on",
"log_level": logger.getEffectiveLevel(),
}
if ssl_config:
uvicorn_config.update(ssl_config)

View file

@ -7,12 +7,11 @@
import importlib.resources
import os
import re
import tempfile
from typing import Any, Dict, Optional
import yaml
from termcolor import colored
from llama_stack import logcat
from llama_stack.apis.agents import Agents
from llama_stack.apis.batch_inference import BatchInference
from llama_stack.apis.benchmarks import Benchmarks
@ -33,12 +32,16 @@ from llama_stack.apis.telemetry import Telemetry
from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
from llama_stack.apis.vector_dbs import VectorDBs
from llama_stack.apis.vector_io import VectorIO
from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.distribution.datatypes import Provider, StackRunConfig
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
from llama_stack.distribution.store.registry import create_dist_registry
from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api
logger = get_logger(name=__name__, category="core")
class LlamaStack(
VectorDBs,
@ -99,9 +102,8 @@ async def register_resources(run_config: StackRunConfig, impls: Dict[Api, Any]):
objects_to_process = response.data if hasattr(response, "data") else response
for obj in objects_to_process:
logcat.debug(
"core",
f"{rsrc.capitalize()}: {colored(obj.identifier, 'white', attrs=['bold'])} served by {colored(obj.provider_id, 'white', attrs=['bold'])}",
logger.debug(
f"{rsrc.capitalize()}: {obj.identifier} served by {obj.provider_id}",
)
@ -228,3 +230,53 @@ def get_stack_run_config_from_template(template: str) -> StackRunConfig:
run_config = yaml.safe_load(path.open())
return StackRunConfig(**replace_env_vars(run_config))
def run_config_from_adhoc_config_spec(
adhoc_config_spec: str, provider_registry: Optional[ProviderRegistry] = None
) -> StackRunConfig:
"""
Create an adhoc distribution from a list of API providers.
The list should be of the form "api=provider", e.g. "inference=fireworks". If you have
multiple pairs, separate them with commas or semicolons, e.g. "inference=fireworks,safety=llama-guard,agents=meta-reference"
"""
api_providers = adhoc_config_spec.replace(";", ",").split(",")
provider_registry = provider_registry or get_provider_registry()
distro_dir = tempfile.mkdtemp()
provider_configs_by_api = {}
for api_provider in api_providers:
api_str, provider = api_provider.split("=")
api = Api(api_str)
providers_by_type = provider_registry[api]
provider_spec = providers_by_type.get(provider)
if not provider_spec:
provider_spec = providers_by_type.get(f"inline::{provider}")
if not provider_spec:
provider_spec = providers_by_type.get(f"remote::{provider}")
if not provider_spec:
raise ValueError(
f"Provider {provider} (or remote::{provider} or inline::{provider}) not found for API {api}"
)
# call method "sample_run_config" on the provider spec config class
provider_config_type = instantiate_class_type(provider_spec.config_class)
provider_config = replace_env_vars(provider_config_type.sample_run_config(__distro_dir__=distro_dir))
provider_configs_by_api[api_str] = [
Provider(
provider_id=provider,
provider_type=provider_spec.provider_type,
config=provider_config,
)
]
config = StackRunConfig(
image_name="distro-test",
apis=list(provider_configs_by_api.keys()),
providers=provider_configs_by_api,
)
return config

View file

@ -100,12 +100,15 @@ esac
if [[ "$env_type" == "venv" || "$env_type" == "conda" ]]; then
set -x
$PYTHON_BINARY -m llama_stack.distribution.server.server \
--yaml-config "$yaml_config" \
--port "$port" \
$env_vars \
$other_args
elif [[ "$env_type" == "container" ]]; then
set -x
# Check if container command is available
if ! is_command_available $CONTAINER_BINARY; then
printf "${RED}Error: ${CONTAINER_BINARY} command not found. Is ${CONTAINER_BINARY} installed and in your PATH?${NC}" >&2
@ -141,8 +144,6 @@ elif [[ "$env_type" == "container" ]]; then
version_tag=$(curl -s $URL | jq -r '.info.version')
fi
set -x
$CONTAINER_BINARY run $CONTAINER_OPTS -it \
-p $port:$port \
$env_vars \

View file

@ -17,7 +17,7 @@ llama stack run together
2. (Optional) Register datasets and eval tasks as resources. If you want to run pre-configured evaluation flows (e.g. Evaluations (Generation + Scoring) Page).
```bash
$ llama-stack-client datasets register \
llama-stack-client datasets register \
--dataset-id "mmlu" \
--provider-id "huggingface" \
--url "https://huggingface.co/datasets/llamastack/evals" \
@ -26,7 +26,7 @@ $ llama-stack-client datasets register \
```
```bash
$ llama-stack-client benchmarks register \
llama-stack-client benchmarks register \
--eval-task-id meta-reference-mmlu \
--provider-id meta-reference \
--dataset-id mmlu \

View file

@ -212,7 +212,7 @@ def run_evaluation_3():
benchmark_id=selected_benchmark,
input_rows=[r],
scoring_functions=benchmarks[selected_benchmark].scoring_functions,
task_config=benchmark_config,
benchmark_config=benchmark_config,
)
for k in r.keys():

View file

@ -7,7 +7,6 @@
import streamlit as st
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.event_logger import EventLogger
from llama_stack_client.types.agent_create_params import AgentConfig
from llama_stack_client.types.memory_insert_params import Document
from modules.api import llama_stack_api
from modules.utils import data_url_from_file
@ -124,13 +123,14 @@ def rag_chat_page():
else:
strategy = {"type": "greedy"}
agent_config = AgentConfig(
agent = Agent(
llama_stack_api.client,
model=selected_model,
instructions=system_prompt,
sampling_params={
"strategy": strategy,
},
toolgroups=[
tools=[
dict(
name="builtin::rag/knowledge_search",
args={
@ -138,12 +138,7 @@ def rag_chat_page():
},
)
],
tool_choice="auto",
tool_prompt_format="json",
enable_session_persistence=False,
)
agent = Agent(llama_stack_api.client, agent_config)
session_id = agent.create_session("rag-session")
# Chat input

View file

@ -13,6 +13,4 @@ DISTRIBS_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "distributions"
DEFAULT_CHECKPOINT_DIR = LLAMA_STACK_CONFIG_DIR / "checkpoints"
BUILDS_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "builds"
RUNTIME_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "runtime"

View file

@ -20,14 +20,14 @@ import importlib
import json
from pathlib import Path
from llama_stack.distribution.utils.image_types import ImageType
from llama_stack.distribution.utils.image_types import LlamaStackImageType
def formulate_run_args(image_type, image_name, config, template_name) -> list:
env_name = ""
if image_type == ImageType.container.value or config.container_image:
if image_type == LlamaStackImageType.CONTAINER.value or config.container_image:
env_name = f"distribution-{template_name}" if template_name else config.container_image
elif image_type == ImageType.conda.value:
elif image_type == LlamaStackImageType.CONDA.value:
current_conda_env = os.environ.get("CONDA_DEFAULT_ENV")
env_name = image_name or current_conda_env
if not env_name:

View file

@ -4,10 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
import enum
class ImageType(Enum):
container = "container"
conda = "conda"
venv = "venv"
class LlamaStackImageType(enum.Enum):
CONTAINER = "container"
CONDA = "conda"
VENV = "venv"

View file

@ -1,204 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""
Category-based logging utility for llama-stack.
This module provides a wrapper over the standard Python logging module that supports
categorized logging with environment variable control.
Usage:
from llama_stack import logcat
logcat.info("server", "Starting up...")
logcat.debug("inference", "Processing request...")
Environment variable:
LLAMA_STACK_LOGGING: Semicolon-separated list of category=level pairs
Example: "server=debug;inference=warning"
"""
import datetime
import logging
import os
from typing import Dict
# ANSI color codes for terminal output
COLORS = {
"RESET": "\033[0m",
"DEBUG": "\033[36m", # Cyan
"INFO": "\033[32m", # Green
"WARNING": "\033[33m", # Yellow
"ERROR": "\033[31m", # Red
"CRITICAL": "\033[35m", # Magenta
"DIM": "\033[2m", # Dimmed text
"YELLOW_DIM": "\033[2;33m", # Dimmed yellow
}
# Static list of valid categories representing various parts of the Llama Stack
# server codebase
CATEGORIES = [
"core",
"server",
"router",
"inference",
"agents",
"safety",
"eval",
"tools",
"client",
]
_logger = logging.getLogger("llama_stack")
_logger.propagate = False
_default_level = logging.INFO
# Category-level mapping (can be modified by environment variables)
_category_levels: Dict[str, int] = {}
class TerminalStreamHandler(logging.StreamHandler):
def __init__(self, stream=None):
super().__init__(stream)
self.is_tty = hasattr(self.stream, "isatty") and self.stream.isatty()
def format(self, record):
record.is_tty = self.is_tty
return super().format(record)
class ColoredFormatter(logging.Formatter):
"""Custom formatter with colors and fixed-width level names"""
def format(self, record):
levelname = record.levelname
# Use only time with milliseconds, not date
timestamp = datetime.datetime.now().strftime("%H:%M:%S.%f")[:-3] # HH:MM:SS.mmm format
file_info = f"{record.filename}:{record.lineno}"
# Get category from extra if available
category = getattr(record, "category", None)
msg = record.getMessage()
if getattr(record, "is_tty", False):
color = COLORS.get(levelname, COLORS["RESET"])
if category:
category_formatted = f"{COLORS['YELLOW_DIM']}{category}{COLORS['RESET']} "
formatted_msg = (
f"{color}{levelname:<7}{COLORS['RESET']} {COLORS['DIM']}{timestamp}{COLORS['RESET']} "
f"{file_info:<20} {category_formatted}{msg}"
)
else:
formatted_msg = (
f"{color}{levelname:<7}{COLORS['RESET']} {COLORS['DIM']}{timestamp}{COLORS['RESET']}] "
f"{file_info:<20} {msg}"
)
else:
if category:
formatted_msg = f"{levelname:<7} {timestamp} {file_info:<20} [{category}] {msg}"
else:
formatted_msg = f"{levelname:<7} {timestamp} {file_info:<20} {msg}"
return formatted_msg
def init(default_level: int = logging.INFO) -> None:
global _default_level, _category_levels, _logger
_default_level = default_level
_logger.setLevel(logging.DEBUG)
_logger.handlers = [] # Clear existing handlers
# Add our custom handler with the colored formatter
handler = TerminalStreamHandler()
formatter = ColoredFormatter()
handler.setFormatter(formatter)
_logger.addHandler(handler)
for category in CATEGORIES:
_category_levels[category] = default_level
env_config = os.environ.get("LLAMA_STACK_LOGGING", "")
if env_config:
for pair in env_config.split(";"):
if not pair.strip():
continue
try:
category, level = pair.split("=", 1)
category = category.strip().lower()
level = level.strip().lower()
level_value = {
"debug": logging.DEBUG,
"info": logging.INFO,
"warning": logging.WARNING,
"warn": logging.WARNING,
"error": logging.ERROR,
"critical": logging.CRITICAL,
}.get(level)
if level_value is None:
_logger.warning(f"Unknown log level '{level}' for category '{category}'")
continue
if category == "all":
for cat in CATEGORIES:
_category_levels[cat] = level_value
else:
if category in CATEGORIES:
_category_levels[category] = level_value
else:
_logger.warning(f"Unknown logging category: {category}")
except ValueError:
_logger.warning(f"Invalid logging configuration: {pair}")
def _should_log(level: int, category: str) -> bool:
category = category.lower()
if category not in _category_levels:
return False
category_level = _category_levels[category]
return level >= category_level
def _log(level: int, level_name: str, category: str, msg: str, *args, **kwargs) -> None:
if _should_log(level, category):
kwargs.setdefault("extra", {})["category"] = category.lower()
getattr(_logger, level_name)(msg, *args, stacklevel=3, **kwargs)
def debug(category: str, msg: str, *args, **kwargs) -> None:
_log(logging.DEBUG, "debug", category, msg, *args, **kwargs)
def info(category: str, msg: str, *args, **kwargs) -> None:
_log(logging.INFO, "info", category, msg, *args, **kwargs)
def warning(category: str, msg: str, *args, **kwargs) -> None:
_log(logging.WARNING, "warning", category, msg, *args, **kwargs)
def warn(category: str, msg: str, *args, **kwargs) -> None:
warning(category, msg, *args, **kwargs)
def error(category: str, msg: str, *args, **kwargs) -> None:
_log(logging.ERROR, "error", category, msg, *args, **kwargs)
def critical(category: str, msg: str, *args, **kwargs) -> None:
_log(logging.CRITICAL, "critical", category, msg, *args, **kwargs)
def exception(category: str, msg: str, *args, **kwargs) -> None:
if _should_log(logging.ERROR, category):
kwargs.setdefault("extra", {})["category"] = category.lower()
_logger.exception(msg, *args, stacklevel=2, **kwargs)

View file

@ -4,14 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from typing import Any, Dict
from llama_stack.distribution.datatypes import Api, ProviderSpec
from llama_stack.distribution.datatypes import Api
from .config import MetaReferenceAgentsImplConfig
async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: Dict[Api, ProviderSpec]):
async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: Dict[Api, Any]):
from .agents import MetaReferenceAgentsImpl
impl = MetaReferenceAgentsImpl(

View file

@ -12,12 +12,11 @@ import secrets
import string
import uuid
from datetime import datetime
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
from typing import AsyncGenerator, List, Optional, Union
from urllib.parse import urlparse
import httpx
from llama_stack import logcat
from llama_stack.apis.agents import (
AgentConfig,
AgentToolGroup,
@ -31,7 +30,6 @@ from llama_stack.apis.agents import (
AgentTurnResponseStreamChunk,
AgentTurnResponseTurnAwaitingInputPayload,
AgentTurnResponseTurnCompletePayload,
AgentTurnResponseTurnStartPayload,
AgentTurnResumeRequest,
Attachment,
Document,
@ -68,6 +66,7 @@ from llama_stack.apis.tools import (
ToolRuntime,
)
from llama_stack.apis.vector_io import VectorIO
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import (
BuiltinTool,
ToolCall,
@ -89,6 +88,8 @@ MEMORY_QUERY_TOOL = "knowledge_search"
WEB_SEARCH_TOOL = "web_search"
RAG_TOOL_GROUP = "builtin::rag"
logger = get_logger(name=__name__, category="agents")
class ChatAgent(ShieldRunnerMixin):
def __init__(
@ -152,7 +153,6 @@ class ChatAgent(ShieldRunnerMixin):
messages.append(
ToolResponseMessage(
call_id=response.call_id,
tool_name=response.tool_name,
content=response.content,
)
)
@ -180,120 +180,58 @@ class ChatAgent(ShieldRunnerMixin):
return messages
async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
with tracing.span("create_and_execute_turn") as span:
await self._initialize_tools(request.toolgroups)
async with tracing.span("create_and_execute_turn") as span:
span.set_attribute("session_id", request.session_id)
span.set_attribute("agent_id", self.agent_id)
span.set_attribute("request", request.model_dump_json())
assert request.stream is True, "Non-streaming not supported"
session_info = await self.storage.get_session_info(request.session_id)
if session_info is None:
raise ValueError(f"Session {request.session_id} not found")
turns = await self.storage.get_session_turns(request.session_id)
messages = await self.get_messages_from_turns(turns)
messages.extend(request.messages)
turn_id = str(uuid.uuid4())
span.set_attribute("turn_id", turn_id)
start_time = datetime.now().astimezone().isoformat()
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseTurnStartPayload(
turn_id=turn_id,
)
)
)
steps = []
output_message = None
async for chunk in self.run(
session_id=request.session_id,
turn_id=turn_id,
input_messages=messages,
sampling_params=self.agent_config.sampling_params,
stream=request.stream,
documents=request.documents,
toolgroups_for_turn=request.toolgroups,
):
if isinstance(chunk, CompletionMessage):
logcat.info(
"agents",
f"returning result from the agent turn: {chunk}",
)
output_message = chunk
continue
assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
event = chunk.event
if event.payload.event_type == AgentTurnResponseEventType.step_complete.value:
steps.append(event.payload.step_details)
async for chunk in self._run_turn(request, turn_id):
yield chunk
assert output_message is not None
turn = Turn(
turn_id=turn_id,
session_id=request.session_id,
input_messages=request.messages,
output_message=output_message,
started_at=start_time,
completed_at=datetime.now().astimezone().isoformat(),
steps=steps,
)
await self.storage.add_turn_to_session(request.session_id, turn)
if output_message.tool_calls and request.allow_turn_resume:
chunk = AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseTurnAwaitingInputPayload(
turn=turn,
)
)
)
else:
chunk = AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseTurnCompletePayload(
turn=turn,
)
)
)
yield chunk
async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator:
with tracing.span("resume_turn") as span:
await self._initialize_tools()
async with tracing.span("resume_turn") as span:
span.set_attribute("agent_id", self.agent_id)
span.set_attribute("session_id", request.session_id)
span.set_attribute("turn_id", request.turn_id)
span.set_attribute("request", request.model_dump_json())
assert request.stream is True, "Non-streaming not supported"
async for chunk in self._run_turn(request):
yield chunk
session_info = await self.storage.get_session_info(request.session_id)
if session_info is None:
raise ValueError(f"Session {request.session_id} not found")
async def _run_turn(
self,
request: Union[AgentTurnCreateRequest, AgentTurnResumeRequest],
turn_id: Optional[str] = None,
) -> AsyncGenerator:
assert request.stream is True, "Non-streaming not supported"
turns = await self.storage.get_session_turns(request.session_id)
if len(turns) == 0:
raise ValueError("No turns found for session")
is_resume = isinstance(request, AgentTurnResumeRequest)
session_info = await self.storage.get_session_info(request.session_id)
if session_info is None:
raise ValueError(f"Session {request.session_id} not found")
messages = await self.get_messages_from_turns(turns)
messages.extend(request.tool_responses)
turns = await self.storage.get_session_turns(request.session_id)
if is_resume and len(turns) == 0:
raise ValueError("No turns found for session")
steps = []
messages = await self.get_messages_from_turns(turns)
if is_resume:
tool_response_messages = [
ToolResponseMessage(call_id=x.call_id, content=x.content) for x in request.tool_responses
]
messages.extend(tool_response_messages)
last_turn = turns[-1]
last_turn_messages = self.turn_to_messages(last_turn)
last_turn_messages = [
x for x in last_turn_messages if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage)
]
last_turn_messages.extend(tool_response_messages)
# TODO: figure out whether we should add the tool responses to the last turn messages
last_turn_messages.extend(request.tool_responses)
# get the steps from the turn id
steps = []
steps = turns[-1].steps
# get steps from the turn
steps = last_turn.steps
# mark tool execution step as complete
# if there's no tool execution in progress step (due to storage, or tool call parsing on client),
@ -306,14 +244,7 @@ class ChatAgent(ShieldRunnerMixin):
step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())),
turn_id=request.turn_id,
tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []),
tool_responses=[
ToolResponse(
call_id=x.call_id,
tool_name=x.tool_name,
content=x.content,
)
for x in request.tool_responses
],
tool_responses=request.tool_responses,
completed_at=now,
started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now),
)
@ -327,62 +258,66 @@ class ChatAgent(ShieldRunnerMixin):
)
)
)
input_messages = last_turn_messages
output_message = None
async for chunk in self.run(
session_id=request.session_id,
turn_id=request.turn_id,
input_messages=messages,
sampling_params=self.agent_config.sampling_params,
stream=request.stream,
):
if isinstance(chunk, CompletionMessage):
output_message = chunk
continue
turn_id = request.turn_id
start_time = last_turn.started_at
else:
messages.extend(request.messages)
start_time = datetime.now().astimezone().isoformat()
input_messages = request.messages
assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
event = chunk.event
if event.payload.event_type == AgentTurnResponseEventType.step_complete.value:
steps.append(event.payload.step_details)
output_message = None
async for chunk in self.run(
session_id=request.session_id,
turn_id=turn_id,
input_messages=messages,
sampling_params=self.agent_config.sampling_params,
stream=request.stream,
documents=request.documents if not is_resume else None,
):
if isinstance(chunk, CompletionMessage):
output_message = chunk
continue
yield chunk
assert output_message is not None
last_turn_start_time = datetime.now().astimezone().isoformat()
if len(turns) > 0:
last_turn_start_time = turns[-1].started_at
turn = Turn(
turn_id=request.turn_id,
session_id=request.session_id,
input_messages=last_turn_messages,
output_message=output_message,
started_at=last_turn_start_time,
completed_at=datetime.now().astimezone().isoformat(),
steps=steps,
)
await self.storage.add_turn_to_session(request.session_id, turn)
if output_message.tool_calls:
chunk = AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseTurnAwaitingInputPayload(
turn=turn,
)
)
)
else:
chunk = AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseTurnCompletePayload(
turn=turn,
)
)
)
assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
event = chunk.event
if event.payload.event_type == AgentTurnResponseEventType.step_complete.value:
steps.append(event.payload.step_details)
yield chunk
assert output_message is not None
turn = Turn(
turn_id=turn_id,
session_id=request.session_id,
input_messages=input_messages,
output_message=output_message,
started_at=start_time,
completed_at=datetime.now().astimezone().isoformat(),
steps=steps,
)
await self.storage.add_turn_to_session(request.session_id, turn)
if output_message.tool_calls:
chunk = AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseTurnAwaitingInputPayload(
turn=turn,
)
)
)
else:
chunk = AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseTurnCompletePayload(
turn=turn,
)
)
)
yield chunk
async def run(
self,
session_id: str,
@ -391,7 +326,6 @@ class ChatAgent(ShieldRunnerMixin):
sampling_params: SamplingParams,
stream: bool = False,
documents: Optional[List[Document]] = None,
toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
) -> AsyncGenerator:
# Doing async generators makes downstream code much simpler and everything amenable to
# streaming. However, it also makes things complicated here because AsyncGenerators cannot
@ -414,7 +348,6 @@ class ChatAgent(ShieldRunnerMixin):
sampling_params,
stream,
documents,
toolgroups_for_turn,
):
if isinstance(res, bool):
return
@ -446,7 +379,7 @@ class ChatAgent(ShieldRunnerMixin):
shields: List[str],
touchpoint: str,
) -> AsyncGenerator:
with tracing.span("run_shields") as span:
async with tracing.span("run_shields") as span:
span.set_attribute("input", [m.model_dump_json() for m in messages])
if len(shields) == 0:
span.set_attribute("output", "no shields")
@ -515,27 +448,19 @@ class ChatAgent(ShieldRunnerMixin):
sampling_params: SamplingParams,
stream: bool = False,
documents: Optional[List[Document]] = None,
toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
) -> AsyncGenerator:
# TODO: simplify all of this code, it can be simpler
toolgroup_args = {}
toolgroups = set()
for toolgroup in self.agent_config.toolgroups + (toolgroups_for_turn or []):
if isinstance(toolgroup, AgentToolGroupWithArgs):
tool_group_name, tool_name = self._parse_toolgroup_name(toolgroup.name)
toolgroups.add(tool_group_name)
toolgroup_args[tool_group_name] = toolgroup.args
else:
toolgroups.add(toolgroup)
tool_defs, tool_to_group = await self._get_tool_defs(toolgroups_for_turn)
if documents:
await self.handle_documents(session_id, documents, input_messages, tool_defs)
await self.handle_documents(session_id, documents, input_messages)
session_info = await self.storage.get_session_info(session_id)
# if the session has a memory bank id, let the memory tool use it
if session_info and session_info.vector_db_id:
toolgroup_args[RAG_TOOL_GROUP]["vector_db_ids"].append(session_info.vector_db_id)
for tool_name in self.tool_name_to_args.keys():
if tool_name == MEMORY_QUERY_TOOL:
if "vector_db_ids" not in self.tool_name_to_args[tool_name]:
self.tool_name_to_args[tool_name]["vector_db_ids"] = [session_info.vector_db_id]
else:
self.tool_name_to_args[tool_name]["vector_db_ids"].append(session_info.vector_db_id)
output_attachments = []
@ -561,11 +486,11 @@ class ChatAgent(ShieldRunnerMixin):
content = ""
stop_reason = None
with tracing.span("inference") as span:
async with tracing.span("inference") as span:
async for chunk in await self.inference_api.chat_completion(
self.agent_config.model,
input_messages,
tools=tool_defs,
tools=self.tool_defs,
tool_prompt_format=self.agent_config.tool_config.tool_prompt_format,
response_format=self.agent_config.response_format,
stream=True,
@ -664,7 +589,7 @@ class ChatAgent(ShieldRunnerMixin):
)
if n_iter >= self.agent_config.max_infer_iters:
logcat.info("agents", f"done with MAX iterations ({n_iter}), exiting.")
logger.info(f"done with MAX iterations ({n_iter}), exiting.")
# NOTE: mark end_of_turn to indicate to client that we are done with the turn
# Do not continue the tool call loop after this point
message.stop_reason = StopReason.end_of_turn
@ -672,7 +597,7 @@ class ChatAgent(ShieldRunnerMixin):
break
if stop_reason == StopReason.out_of_tokens:
logcat.info("agents", "out of token budget, exiting.")
logger.info("out of token budget, exiting.")
yield message
break
@ -686,10 +611,10 @@ class ChatAgent(ShieldRunnerMixin):
message.content = [message.content] + output_attachments
yield message
else:
logcat.debug("agents", f"completion message with EOM (iter: {n_iter}): {str(message)}")
logger.debug(f"completion message with EOM (iter: {n_iter}): {str(message)}")
input_messages = input_messages + [message]
else:
logcat.debug("agents", f"completion message (iter: {n_iter}) from the model: {str(message)}")
logger.debug(f"completion message (iter: {n_iter}) from the model: {str(message)}")
# 1. Start the tool execution step and progress
step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk(
@ -738,7 +663,7 @@ class ChatAgent(ShieldRunnerMixin):
tool_name = tool_call.tool_name
if isinstance(tool_name, BuiltinTool):
tool_name = tool_name.value
with tracing.span(
async with tracing.span(
"tool_execution",
{
"tool_name": tool_name,
@ -747,12 +672,9 @@ class ChatAgent(ShieldRunnerMixin):
) as span:
tool_execution_start_time = datetime.now().astimezone().isoformat()
tool_call = message.tool_calls[0]
tool_result = await execute_tool_call_maybe(
self.tool_runtime_api,
tool_result = await self.execute_tool_call_maybe(
session_id,
tool_call,
toolgroup_args,
tool_to_group,
)
if tool_result.content is None:
raise ValueError(
@ -761,7 +683,6 @@ class ChatAgent(ShieldRunnerMixin):
result_messages = [
ToolResponseMessage(
call_id=tool_call.call_id,
tool_name=tool_call.tool_name,
content=tool_result.content,
)
]
@ -781,7 +702,7 @@ class ChatAgent(ShieldRunnerMixin):
tool_responses=[
ToolResponse(
call_id=result_message.call_id,
tool_name=result_message.tool_name,
tool_name=tool_call.tool_name,
content=result_message.content,
metadata=tool_result.metadata,
)
@ -805,9 +726,16 @@ class ChatAgent(ShieldRunnerMixin):
input_messages = input_messages + [message, result_message]
async def _get_tool_defs(
self, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None
) -> Tuple[List[ToolDefinition], Dict[str, str]]:
async def _initialize_tools(
self,
toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
) -> None:
toolgroup_to_args = {}
for toolgroup in (self.agent_config.toolgroups or []) + (toolgroups_for_turn or []):
if isinstance(toolgroup, AgentToolGroupWithArgs):
tool_group_name, _ = self._parse_toolgroup_name(toolgroup.name)
toolgroup_to_args[tool_group_name] = toolgroup.args
# Determine which tools to include
tool_groups_to_include = toolgroups_for_turn or self.agent_config.toolgroups or []
agent_config_toolgroups = []
@ -816,8 +744,10 @@ class ChatAgent(ShieldRunnerMixin):
if name not in agent_config_toolgroups:
agent_config_toolgroups.append(name)
toolgroup_to_args = toolgroup_to_args or {}
tool_name_to_def = {}
tool_to_group = {}
tool_name_to_args = {}
for tool_def in self.agent_config.client_tools:
if tool_name_to_def.get(tool_def.name, None):
@ -835,53 +765,38 @@ class ChatAgent(ShieldRunnerMixin):
for param in tool_def.parameters
},
)
tool_to_group[tool_def.name] = "__client_tools__"
for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups:
toolgroup_name, tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name)
toolgroup_name, input_tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name)
tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name)
if not tools.data:
available_tool_groups = ", ".join(
[t.identifier for t in (await self.tool_groups_api.list_tool_groups()).data]
)
raise ValueError(f"Toolgroup {toolgroup_name} not found, available toolgroups: {available_tool_groups}")
if tool_name is not None and not any(tool.identifier == tool_name for tool in tools.data):
if input_tool_name is not None and not any(tool.identifier == input_tool_name for tool in tools.data):
raise ValueError(
f"Tool {tool_name} not found in toolgroup {toolgroup_name}. Available tools: {', '.join([tool.identifier for tool in tools.data])}"
f"Tool {input_tool_name} not found in toolgroup {toolgroup_name}. Available tools: {', '.join([tool.identifier for tool in tools.data])}"
)
for tool_def in tools.data:
if toolgroup_name.startswith("builtin") and toolgroup_name != RAG_TOOL_GROUP:
tool_name = tool_def.identifier
built_in_type = BuiltinTool.brave_search
if tool_name == "web_search":
built_in_type = BuiltinTool.brave_search
identifier: str | BuiltinTool | None = tool_def.identifier
if identifier == "web_search":
identifier = BuiltinTool.brave_search
else:
built_in_type = BuiltinTool(tool_name)
identifier = BuiltinTool(identifier)
else:
# add if tool_name is unspecified or the tool_def identifier is the same as the tool_name
if input_tool_name in (None, tool_def.identifier):
identifier = tool_def.identifier
else:
identifier = None
if tool_name_to_def.get(built_in_type, None):
raise ValueError(f"Tool {built_in_type} already exists")
tool_name_to_def[built_in_type] = ToolDefinition(
tool_name=built_in_type,
description=tool_def.description,
parameters={
param.name: ToolParamDefinition(
param_type=param.parameter_type,
description=param.description,
required=param.required,
default=param.default,
)
for param in tool_def.parameters
},
)
tool_to_group[built_in_type] = tool_def.toolgroup_id
continue
if tool_name_to_def.get(tool_def.identifier, None):
raise ValueError(f"Tool {tool_def.identifier} already exists")
if tool_name in (None, tool_def.identifier):
if tool_name_to_def.get(identifier, None):
raise ValueError(f"Tool {identifier} already exists")
if identifier:
tool_name_to_def[tool_def.identifier] = ToolDefinition(
tool_name=tool_def.identifier,
tool_name=identifier,
description=tool_def.description,
parameters={
param.name: ToolParamDefinition(
@ -893,9 +808,9 @@ class ChatAgent(ShieldRunnerMixin):
for param in tool_def.parameters
},
)
tool_to_group[tool_def.identifier] = tool_def.toolgroup_id
tool_name_to_args[tool_def.identifier] = toolgroup_to_args.get(toolgroup_name, {})
return list(tool_name_to_def.values()), tool_to_group
self.tool_defs, self.tool_name_to_args = list(tool_name_to_def.values()), tool_name_to_args
def _parse_toolgroup_name(self, toolgroup_name_with_maybe_tool_name: str) -> tuple[str, Optional[str]]:
"""Parse a toolgroup name into its components.
@ -914,15 +829,46 @@ class ChatAgent(ShieldRunnerMixin):
tool_group, tool_name = split_names[0], None
return tool_group, tool_name
async def execute_tool_call_maybe(
self,
session_id: str,
tool_call: ToolCall,
) -> ToolInvocationResult:
tool_name = tool_call.tool_name
registered_tool_names = [tool_def.tool_name for tool_def in self.tool_defs]
if tool_name not in registered_tool_names:
raise ValueError(
f"Tool {tool_name} not found in provided tools, registered tools: {', '.join([str(x) for x in registered_tool_names])}"
)
if isinstance(tool_name, BuiltinTool):
if tool_name == BuiltinTool.brave_search:
tool_name_str = WEB_SEARCH_TOOL
else:
tool_name_str = tool_name.value
else:
tool_name_str = tool_name
logger.info(f"executing tool call: {tool_name_str} with args: {tool_call.arguments}")
result = await self.tool_runtime_api.invoke_tool(
tool_name=tool_name_str,
kwargs={
"session_id": session_id,
# get the arguments generated by the model and augment with toolgroup arg overrides for the agent
**tool_call.arguments,
**self.tool_name_to_args.get(tool_name_str, {}),
},
)
logger.debug(f"tool call {tool_name_str} completed with result: {result}")
return result
async def handle_documents(
self,
session_id: str,
documents: List[Document],
input_messages: List[Message],
tool_defs: Dict[str, ToolDefinition],
) -> None:
memory_tool = any(tool_def.tool_name == MEMORY_QUERY_TOOL for tool_def in tool_defs)
code_interpreter_tool = any(tool_def.tool_name == BuiltinTool.code_interpreter for tool_def in tool_defs)
memory_tool = any(tool_def.tool_name == MEMORY_QUERY_TOOL for tool_def in self.tool_defs)
code_interpreter_tool = any(tool_def.tool_name == BuiltinTool.code_interpreter for tool_def in self.tool_defs)
content_items = []
url_items = []
pattern = re.compile("^(https?://|file://|data:)")
@ -1032,7 +978,7 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa
path = urlparse(uri).path
basename = os.path.basename(path)
filepath = f"{tempdir}/{make_random_string() + basename}"
logcat.info("agents", f"Downloading {url} -> {filepath}")
logger.info(f"Downloading {url} -> {filepath}")
async with httpx.AsyncClient() as client:
r = await client.get(uri)
@ -1050,42 +996,10 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa
return ToolResponseMessage(
call_id="",
tool_name=BuiltinTool.code_interpreter,
content=content,
)
async def execute_tool_call_maybe(
tool_runtime_api: ToolRuntime,
session_id: str,
tool_call: ToolCall,
toolgroup_args: Dict[str, Dict[str, Any]],
tool_to_group: Dict[str, str],
) -> ToolInvocationResult:
name = tool_call.tool_name
group_name = tool_to_group.get(name, None)
if group_name is None:
raise ValueError(f"Tool {name} not found in any tool group")
if isinstance(name, BuiltinTool):
if name == BuiltinTool.brave_search:
name = WEB_SEARCH_TOOL
else:
name = name.value
logcat.info("agents", f"executing tool call: {name} with args: {tool_call.arguments}")
result = await tool_runtime_api.invoke_tool(
tool_name=name,
kwargs={
"session_id": session_id,
# get the arguments generated by the model and augment with toolgroup arg overrides for the agent
**tool_call.arguments,
**toolgroup_args.get(group_name, {}),
},
)
logcat.debug("agents", f"tool call {name} completed with result: {result}")
return result
def _interpret_content_as_attachment(
content: str,
) -> Optional[Attachment]:

View file

@ -12,6 +12,7 @@ import uuid
from typing import AsyncGenerator, List, Optional, Union
from llama_stack.apis.agents import (
Agent,
AgentConfig,
AgentCreateResponse,
Agents,
@ -21,12 +22,15 @@ from llama_stack.apis.agents import (
AgentTurnCreateRequest,
AgentTurnResumeRequest,
Document,
ListAgentSessionsResponse,
ListAgentsResponse,
Session,
Turn,
)
from llama_stack.apis.inference import (
Inference,
ToolConfig,
ToolResponse,
ToolResponseMessage,
UserMessage,
)
@ -83,7 +87,7 @@ class MetaReferenceAgentsImpl(Agents):
agent_id=agent_id,
)
async def get_agent(self, agent_id: str) -> ChatAgent:
async def _get_agent_impl(self, agent_id: str) -> ChatAgent:
agent_config = await self.persistence_store.get(
key=f"agent:{agent_id}",
)
@ -119,7 +123,7 @@ class MetaReferenceAgentsImpl(Agents):
agent_id: str,
session_name: str,
) -> AgentSessionCreateResponse:
agent = await self.get_agent(agent_id)
agent = await self._get_agent_impl(agent_id)
session_id = await agent.create_session(session_name)
return AgentSessionCreateResponse(
@ -140,7 +144,6 @@ class MetaReferenceAgentsImpl(Agents):
documents: Optional[List[Document]] = None,
stream: Optional[bool] = False,
tool_config: Optional[ToolConfig] = None,
allow_turn_resume: Optional[bool] = False,
) -> AsyncGenerator:
request = AgentTurnCreateRequest(
agent_id=agent_id,
@ -150,7 +153,6 @@ class MetaReferenceAgentsImpl(Agents):
toolgroups=toolgroups,
documents=documents,
tool_config=tool_config,
allow_turn_resume=allow_turn_resume,
)
if stream:
return self._create_agent_turn_streaming(request)
@ -161,7 +163,7 @@ class MetaReferenceAgentsImpl(Agents):
self,
request: AgentTurnCreateRequest,
) -> AsyncGenerator:
agent = await self.get_agent(request.agent_id)
agent = await self._get_agent_impl(request.agent_id)
async for event in agent.create_and_execute_turn(request):
yield event
@ -170,7 +172,7 @@ class MetaReferenceAgentsImpl(Agents):
agent_id: str,
session_id: str,
turn_id: str,
tool_responses: List[ToolResponseMessage],
tool_responses: List[ToolResponse],
stream: Optional[bool] = False,
) -> AsyncGenerator:
request = AgentTurnResumeRequest(
@ -189,12 +191,12 @@ class MetaReferenceAgentsImpl(Agents):
self,
request: AgentTurnResumeRequest,
) -> AsyncGenerator:
agent = await self.get_agent(request.agent_id)
agent = await self._get_agent_impl(request.agent_id)
async for event in agent.resume_turn(request):
yield event
async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
agent = await self.get_agent(agent_id)
agent = await self._get_agent_impl(agent_id)
turn = await agent.storage.get_session_turn(session_id, turn_id)
return turn
@ -211,7 +213,7 @@ class MetaReferenceAgentsImpl(Agents):
session_id: str,
turn_ids: Optional[List[str]] = None,
) -> Session:
agent = await self.get_agent(agent_id)
agent = await self._get_agent_impl(agent_id)
session_info = await agent.storage.get_session_info(session_id)
if session_info is None:
raise ValueError(f"Session {session_id} not found")
@ -233,3 +235,15 @@ class MetaReferenceAgentsImpl(Agents):
async def shutdown(self) -> None:
pass
async def list_agents(self) -> ListAgentsResponse:
pass
async def get_agent(self, agent_id: str) -> Agent:
pass
async def list_agent_sessions(
self,
agent_id: str,
) -> ListAgentSessionsResponse:
pass

View file

@ -10,6 +10,7 @@ from typing import List
from llama_stack.apis.inference import Message
from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
from llama_stack.providers.utils.telemetry import tracing
log = logging.getLogger(__name__)
@ -32,15 +33,14 @@ class ShieldRunnerMixin:
self.output_shields = output_shields
async def run_multiple_shields(self, messages: List[Message], identifiers: List[str]) -> None:
responses = await asyncio.gather(
*[
self.safety_api.run_shield(
async def run_shield_with_span(identifier: str):
async with tracing.span(f"run_shield_{identifier}"):
return await self.safety_api.run_shield(
shield_id=identifier,
messages=messages,
)
for identifier in identifiers
]
)
responses = await asyncio.gather(*[run_shield_with_span(identifier) for identifier in identifiers])
for identifier, response in zip(identifiers, responses, strict=False):
if not response.violation:
continue

View file

@ -1,400 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import tempfile
from typing import AsyncIterator, List, Optional, Union
import pytest
from llama_stack.apis.agents import (
AgentConfig,
AgentToolGroupWithArgs,
AgentTurnCreateRequest,
AgentTurnResponseTurnCompletePayload,
StepType,
)
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.inference import (
ChatCompletionResponse,
ChatCompletionResponseEvent,
ChatCompletionResponseStreamChunk,
CompletionMessage,
LogProbConfig,
Message,
ResponseFormat,
SamplingParams,
ToolChoice,
ToolDefinition,
ToolPromptFormat,
UserMessage,
)
from llama_stack.apis.safety import RunShieldResponse
from llama_stack.apis.tools import (
Tool,
ToolDef,
ToolGroup,
ToolHost,
ToolInvocationResult,
)
from llama_stack.apis.vector_io import QueryChunksResponse
from llama_stack.models.llama.datatypes import BuiltinTool
from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
MEMORY_QUERY_TOOL,
)
from llama_stack.providers.inline.agents.meta_reference.agents import (
MetaReferenceAgentsImpl,
MetaReferenceAgentsImplConfig,
)
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
class MockInferenceAPI:
async def chat_completion(
self,
model: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = None,
tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
async def stream_response():
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type="start",
delta="",
)
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type="progress",
delta="AI is a fascinating field...",
)
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type="complete",
delta="",
stop_reason="end_of_turn",
)
)
if stream:
return stream_response()
else:
return ChatCompletionResponse(
completion_message=CompletionMessage(
role="assistant",
content="Mock response",
stop_reason="end_of_turn",
),
logprobs={"token_logprobs": [0.1, 0.2, 0.3]} if logprobs else None,
)
class MockSafetyAPI:
async def run_shield(self, shield_id: str, messages: List[Message]) -> RunShieldResponse:
return RunShieldResponse(violation=None)
class MockVectorIOAPI:
def __init__(self):
self.chunks = {}
async def insert_chunks(self, vector_db_id, chunks, ttl_seconds=None):
for chunk in chunks:
metadata = chunk.metadata
self.chunks[vector_db_id][metadata["document_id"]] = chunk
async def query_chunks(self, vector_db_id, query, params=None):
if vector_db_id not in self.chunks:
raise ValueError(f"Bank {vector_db_id} not found")
chunks = list(self.chunks[vector_db_id].values())
scores = [1.0] * len(chunks)
return QueryChunksResponse(chunks=chunks, scores=scores)
class MockToolGroupsAPI:
async def register_tool_group(self, toolgroup_id: str, provider_id: str, mcp_endpoint=None, args=None) -> None:
pass
async def get_tool_group(self, toolgroup_id: str) -> ToolGroup:
return ToolGroup(
identifier=toolgroup_id,
provider_resource_id=toolgroup_id,
)
async def list_tool_groups(self) -> List[ToolGroup]:
return []
async def list_tools(self, tool_group_id: Optional[str] = None) -> List[Tool]:
if tool_group_id == MEMORY_TOOLGROUP:
return [
Tool(
identifier=MEMORY_QUERY_TOOL,
provider_resource_id=MEMORY_QUERY_TOOL,
toolgroup_id=MEMORY_TOOLGROUP,
tool_host=ToolHost.client,
description="Mock tool",
provider_id="builtin::rag",
parameters=[],
)
]
if tool_group_id == CODE_INTERPRETER_TOOLGROUP:
return [
Tool(
identifier="code_interpreter",
provider_resource_id="code_interpreter",
toolgroup_id=CODE_INTERPRETER_TOOLGROUP,
tool_host=ToolHost.client,
description="Mock tool",
provider_id="builtin::code_interpreter",
parameters=[],
)
]
return []
async def get_tool(self, tool_name: str) -> Tool:
return Tool(
identifier=tool_name,
provider_resource_id=tool_name,
toolgroup_id="mock_group",
tool_host=ToolHost.client,
description="Mock tool",
provider_id="mock_provider",
parameters=[],
)
async def unregister_tool_group(self, tool_group_id: str) -> None:
pass
class MockToolRuntimeAPI:
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]:
return []
async def invoke_tool(self, tool_name: str, args: dict) -> ToolInvocationResult:
return ToolInvocationResult(content={"result": "Mock tool result"})
@pytest.fixture
def mock_inference_api():
return MockInferenceAPI()
@pytest.fixture
def mock_safety_api():
return MockSafetyAPI()
@pytest.fixture
def mock_vector_io_api():
return MockVectorIOAPI()
@pytest.fixture
def mock_tool_groups_api():
return MockToolGroupsAPI()
@pytest.fixture
def mock_tool_runtime_api():
return MockToolRuntimeAPI()
@pytest.fixture
async def get_agents_impl(
mock_inference_api,
mock_safety_api,
mock_vector_io_api,
mock_tool_runtime_api,
mock_tool_groups_api,
):
sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
impl = MetaReferenceAgentsImpl(
config=MetaReferenceAgentsImplConfig(
persistence_store=SqliteKVStoreConfig(
db_name=sqlite_file.name,
),
),
inference_api=mock_inference_api,
safety_api=mock_safety_api,
vector_io_api=mock_vector_io_api,
tool_runtime_api=mock_tool_runtime_api,
tool_groups_api=mock_tool_groups_api,
)
await impl.initialize()
return impl
@pytest.fixture
async def get_chat_agent(get_agents_impl):
impl = await get_agents_impl
agent_config = AgentConfig(
model="test_model",
instructions="You are a helpful assistant.",
toolgroups=[],
tool_choice=ToolChoice.auto,
enable_session_persistence=False,
input_shields=["test_shield"],
)
response = await impl.create_agent(agent_config)
return await impl.get_agent(response.agent_id)
MEMORY_TOOLGROUP = "builtin::rag"
CODE_INTERPRETER_TOOLGROUP = "builtin::code_interpreter"
@pytest.fixture
async def get_chat_agent_with_tools(get_agents_impl, request):
impl = await get_agents_impl
toolgroups = request.param
agent_config = AgentConfig(
model="test_model",
instructions="You are a helpful assistant.",
toolgroups=toolgroups,
tool_choice=ToolChoice.auto,
enable_session_persistence=False,
input_shields=["test_shield"],
)
response = await impl.create_agent(agent_config)
return await impl.get_agent(response.agent_id)
@pytest.mark.asyncio
async def test_chat_agent_create_and_execute_turn(get_chat_agent):
chat_agent = await get_chat_agent
session_id = await chat_agent.create_session("Test Session")
request = AgentTurnCreateRequest(
agent_id=chat_agent.agent_id,
session_id=session_id,
messages=[UserMessage(content="Hello")],
stream=True,
)
responses = []
async for response in chat_agent.create_and_execute_turn(request):
responses.append(response)
assert len(responses) > 0
assert (
len(responses) == 7
) # TurnStart, ShieldCallStart, ShieldCallComplete, StepStart, StepProgress, StepComplete, TurnComplete
assert responses[0].event.payload.turn_id is not None
@pytest.mark.asyncio
async def test_run_multiple_shields_wrapper(get_chat_agent):
chat_agent = await get_chat_agent
messages = [UserMessage(content="Test message")]
shields = ["test_shield"]
responses = [
chunk
async for chunk in chat_agent.run_multiple_shields_wrapper(
turn_id="test_turn_id",
messages=messages,
shields=shields,
touchpoint="user-input",
)
]
assert len(responses) == 2 # StepStart, StepComplete
assert responses[0].event.payload.step_type.value == "shield_call"
assert not responses[1].event.payload.step_details.violation
@pytest.mark.asyncio
async def test_chat_agent_complex_turn(get_chat_agent):
chat_agent = await get_chat_agent
session_id = await chat_agent.create_session("Test Session")
request = AgentTurnCreateRequest(
agent_id=chat_agent.agent_id,
session_id=session_id,
messages=[UserMessage(content="Tell me about AI and then use a tool.")],
stream=True,
)
responses = []
async for response in chat_agent.create_and_execute_turn(request):
responses.append(response)
assert len(responses) > 0
step_types = [
response.event.payload.step_type for response in responses if hasattr(response.event.payload, "step_type")
]
assert StepType.shield_call in step_types, "Shield call step is missing"
assert StepType.inference in step_types, "Inference step is missing"
event_types = [
response.event.payload.event_type for response in responses if hasattr(response.event.payload, "event_type")
]
assert "turn_start" in event_types, "Start event is missing"
assert "turn_complete" in event_types, "Complete event is missing"
assert any(isinstance(response.event.payload, AgentTurnResponseTurnCompletePayload) for response in responses), (
"Turn complete event is missing"
)
turn_complete_payload = next(
response.event.payload
for response in responses
if isinstance(response.event.payload, AgentTurnResponseTurnCompletePayload)
)
turn = turn_complete_payload.turn
assert turn.input_messages == request.messages, "Input messages do not match"
@pytest.mark.asyncio
@pytest.mark.parametrize(
"toolgroups, expected_memory, expected_code_interpreter",
[
([], False, False), # no tools
([MEMORY_TOOLGROUP], True, False), # memory only
([CODE_INTERPRETER_TOOLGROUP], False, True), # code interpreter only
([MEMORY_TOOLGROUP, CODE_INTERPRETER_TOOLGROUP], True, True), # all tools
],
)
async def test_chat_agent_tools(get_agents_impl, toolgroups, expected_memory, expected_code_interpreter):
impl = await get_agents_impl
agent_config = AgentConfig(
model="test_model",
instructions="You are a helpful assistant.",
toolgroups=toolgroups,
tool_choice=ToolChoice.auto,
enable_session_persistence=False,
input_shields=["test_shield"],
)
response = await impl.create_agent(agent_config)
chat_agent = await impl.get_agent(response.agent_id)
tool_defs, _ = await chat_agent._get_tool_defs()
if expected_memory:
assert MEMORY_QUERY_TOOL in tool_defs
if expected_code_interpreter:
assert BuiltinTool.code_interpreter in tool_defs
if expected_memory and expected_code_interpreter:
# override the tools for turn
new_tool_defs, _ = await chat_agent._get_tool_defs(
toolgroups_for_turn=[
AgentToolGroupWithArgs(
name=MEMORY_TOOLGROUP,
args={"vector_dbs": ["test_vector_db"]},
)
]
)
assert MEMORY_QUERY_TOOL in new_tool_defs
assert BuiltinTool.code_interpreter not in new_tool_defs

View file

@ -4,12 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from .config import LocalFSDatasetIOConfig
async def get_provider_impl(
config: LocalFSDatasetIOConfig,
_deps,
_deps: Dict[str, Any],
):
from .datasetio import LocalFSDatasetIOImpl

View file

@ -3,9 +3,10 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from pydantic import BaseModel
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
from llama_stack.providers.utils.kvstore.config import (
KVStoreConfig,
SqliteKVStoreConfig,
@ -13,6 +14,13 @@ from llama_stack.providers.utils.kvstore.config import (
class LocalFSDatasetIOConfig(BaseModel):
kvstore: KVStoreConfig = SqliteKVStoreConfig(
db_path=(RUNTIME_BASE_DIR / "localfs_datasetio.db").as_posix()
) # Uses SQLite config specific to localfs storage
kvstore: KVStoreConfig
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
return {
"kvstore": SqliteKVStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="localfs_datasetio.db",
)
}

View file

@ -172,7 +172,7 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
new_rows_df = dataset_impl._validate_dataset_schema(new_rows_df)
dataset_impl.df = pandas.concat([dataset_impl.df, new_rows_df], ignore_index=True)
url = str(dataset_info.dataset_def.url)
url = str(dataset_info.dataset_def.url.uri)
parsed_url = urlparse(url)
if parsed_url.scheme == "file" or not parsed_url.scheme:

View file

@ -3,16 +3,16 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from typing import Any, Dict
from llama_stack.distribution.datatypes import Api, ProviderSpec
from llama_stack.distribution.datatypes import Api
from .config import MetaReferenceEvalConfig
async def get_provider_impl(
config: MetaReferenceEvalConfig,
deps: Dict[Api, ProviderSpec],
deps: Dict[Api, Any],
):
from .eval import MetaReferenceEvalImpl

View file

@ -3,9 +3,10 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from pydantic import BaseModel
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
from llama_stack.providers.utils.kvstore.config import (
KVStoreConfig,
SqliteKVStoreConfig,
@ -13,6 +14,13 @@ from llama_stack.providers.utils.kvstore.config import (
class MetaReferenceEvalConfig(BaseModel):
kvstore: KVStoreConfig = SqliteKVStoreConfig(
db_path=(RUNTIME_BASE_DIR / "meta_reference_eval.db").as_posix()
) # Uses SQLite config specific to Meta Reference Eval storage
kvstore: KVStoreConfig
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
return {
"kvstore": SqliteKVStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="meta_reference_eval.db",
)
}

View file

@ -83,7 +83,7 @@ class MetaReferenceEvalImpl(
async def run_eval(
self,
benchmark_id: str,
task_config: BenchmarkConfig,
benchmark_config: BenchmarkConfig,
) -> Job:
task_def = self.benchmarks[benchmark_id]
dataset_id = task_def.dataset_id
@ -92,13 +92,13 @@ class MetaReferenceEvalImpl(
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.eval.value))
all_rows = await self.datasetio_api.get_rows_paginated(
dataset_id=dataset_id,
rows_in_page=(-1 if task_config.num_examples is None else task_config.num_examples),
rows_in_page=(-1 if benchmark_config.num_examples is None else benchmark_config.num_examples),
)
res = await self.evaluate_rows(
benchmark_id=benchmark_id,
input_rows=all_rows.rows,
scoring_functions=scoring_functions,
task_config=task_config,
benchmark_config=benchmark_config,
)
# TODO: currently needs to wait for generation before returning
@ -108,9 +108,9 @@ class MetaReferenceEvalImpl(
return Job(job_id=job_id)
async def _run_agent_generation(
self, input_rows: List[Dict[str, Any]], task_config: BenchmarkConfig
self, input_rows: List[Dict[str, Any]], benchmark_config: BenchmarkConfig
) -> List[Dict[str, Any]]:
candidate = task_config.eval_candidate
candidate = benchmark_config.eval_candidate
create_response = await self.agents_api.create_agent(candidate.config)
agent_id = create_response.agent_id
@ -151,9 +151,9 @@ class MetaReferenceEvalImpl(
return generations
async def _run_model_generation(
self, input_rows: List[Dict[str, Any]], task_config: BenchmarkConfig
self, input_rows: List[Dict[str, Any]], benchmark_config: BenchmarkConfig
) -> List[Dict[str, Any]]:
candidate = task_config.eval_candidate
candidate = benchmark_config.eval_candidate
assert candidate.sampling_params.max_tokens is not None, "SamplingParams.max_tokens must be provided"
generations = []
@ -189,13 +189,13 @@ class MetaReferenceEvalImpl(
benchmark_id: str,
input_rows: List[Dict[str, Any]],
scoring_functions: List[str],
task_config: BenchmarkConfig,
benchmark_config: BenchmarkConfig,
) -> EvaluateResponse:
candidate = task_config.eval_candidate
candidate = benchmark_config.eval_candidate
if candidate.type == "agent":
generations = await self._run_agent_generation(input_rows, task_config)
generations = await self._run_agent_generation(input_rows, benchmark_config)
elif candidate.type == "model":
generations = await self._run_model_generation(input_rows, task_config)
generations = await self._run_model_generation(input_rows, benchmark_config)
else:
raise ValueError(f"Invalid candidate type: {candidate.type}")
@ -204,9 +204,9 @@ class MetaReferenceEvalImpl(
input_r | generated_r for input_r, generated_r in zip(input_rows, generations, strict=False)
]
if task_config.scoring_params is not None:
if benchmark_config.scoring_params is not None:
scoring_functions_dict = {
scoring_fn_id: task_config.scoring_params.get(scoring_fn_id, None)
scoring_fn_id: benchmark_config.scoring_params.get(scoring_fn_id, None)
for scoring_fn_id in scoring_functions
}
else:

View file

@ -4,14 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Union
from typing import Any, Dict, Union
from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
async def get_provider_impl(
config: Union[MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig],
_deps,
_deps: Dict[str, Any],
):
from .inference import MetaReferenceInferenceImpl

View file

@ -136,11 +136,13 @@ class MetaReferenceInferenceImpl(
self,
model_id: str,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
if sampling_params is None:
sampling_params = SamplingParams()
if logprobs:
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
@ -244,7 +246,7 @@ class MetaReferenceInferenceImpl(
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
@ -253,6 +255,8 @@ class MetaReferenceInferenceImpl(
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
if logprobs:
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"

View file

@ -4,6 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from llama_stack.providers.inline.inference.sentence_transformers.config import (
SentenceTransformersInferenceConfig,
)
@ -11,7 +13,7 @@ from llama_stack.providers.inline.inference.sentence_transformers.config import
async def get_provider_impl(
config: SentenceTransformersInferenceConfig,
_deps,
_deps: Dict[str, Any],
):
from .sentence_transformers import SentenceTransformersInferenceImpl

View file

@ -53,7 +53,7 @@ class SentenceTransformersInferenceImpl(
self,
model_id: str,
content: str,
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
@ -64,7 +64,7 @@ class SentenceTransformersInferenceImpl(
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,

View file

@ -4,12 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from typing import Any, Dict
from .config import VLLMConfig
async def get_provider_impl(config: VLLMConfig, _deps) -> Any:
async def get_provider_impl(config: VLLMConfig, _deps: Dict[str, Any]):
from .vllm import VLLMInferenceImpl
impl = VLLMInferenceImpl(config)

View file

@ -4,20 +4,21 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel, Field, field_validator
from typing import Any, Dict
from pydantic import BaseModel, Field
from llama_stack.providers.utils.inference import supported_inference_models
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class VLLMConfig(BaseModel):
"""Configuration for the vLLM inference provider."""
"""Configuration for the vLLM inference provider.
Note that the model name is no longer part of this static configuration.
You can bind an instance of this provider to a specific model with the
``models.register()`` API call."""
model: str = Field(
default="Llama3.2-3B-Instruct",
description="Model descriptor from `llama model list`",
)
tensor_parallel_size: int = Field(
default=1,
description="Number of tensor parallel replicas (number of GPUs to use).",
@ -26,32 +27,27 @@ class VLLMConfig(BaseModel):
default=4096,
description="Maximum number of tokens to generate.",
)
max_model_len: int = Field(default=4096, description="Maximum context length to use during serving.")
max_num_seqs: int = Field(default=4, description="Maximum parallel batch size for generation.")
enforce_eager: bool = Field(
default=False,
description="Whether to use eager mode for inference (otherwise cuda graphs are used).",
)
gpu_memory_utilization: float = Field(
default=0.3,
description=(
"How much GPU memory will be allocated when this provider has finished "
"loading, including memory that was already allocated before loading."
),
)
@classmethod
def sample_run_config(cls):
def sample_run_config(cls, **kwargs: Any) -> Dict[str, Any]:
return {
"model": "${env.INFERENCE_MODEL:Llama3.2-3B-Instruct}",
"tensor_parallel_size": "${env.TENSOR_PARALLEL_SIZE:1}",
"max_tokens": "${env.MAX_TOKENS:4096}",
"max_model_len": "${env.MAX_MODEL_LEN:4096}",
"max_num_seqs": "${env.MAX_NUM_SEQS:4}",
"enforce_eager": "${env.ENFORCE_EAGER:False}",
"gpu_memory_utilization": "${env.GPU_MEMORY_UTILIZATION:0.7}",
"gpu_memory_utilization": "${env.GPU_MEMORY_UTILIZATION:0.3}",
}
@field_validator("model")
@classmethod
def validate_model(cls, model: str) -> str:
permitted_models = supported_inference_models()
descriptors = [m.descriptor() for m in permitted_models]
repos = [m.huggingface_repo for m in permitted_models]
if model not in (descriptors + repos):
model_list = "\n\t".join(repos)
raise ValueError(f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]")
return model

View file

@ -4,45 +4,71 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import os
import json
import re
import uuid
from typing import AsyncGenerator, List, Optional
from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
# These vLLM modules contain names that overlap with Llama Stack names, so we import
# fully-qualified names
import vllm.entrypoints.openai.protocol
import vllm.sampling_params
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams as VLLMSamplingParams
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
TextDelta,
ToolCallDelta,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseEvent,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
CompletionMessage,
CompletionResponse,
CompletionResponseStreamChunk,
EmbeddingsResponse,
EmbeddingTaskType,
GrammarResponseFormat,
Inference,
InterleavedContentItem,
JsonSchemaResponseFormat,
LogProbConfig,
Message,
ResponseFormat,
SamplingParams,
TextTruncation,
TokenLogProbs,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.models import Model
from llama_stack.log import get_logger
from llama_stack.models.llama import sku_list
from llama_stack.models.llama.datatypes import (
StopReason,
ToolCall,
ToolDefinition,
ToolPromptFormat,
TopKSamplingStrategy,
TopPSamplingStrategy,
)
from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.remote.inference.vllm.vllm import build_hf_repo_model_entries
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
ModelsProtocolPrivate,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAICompatCompletionChoice,
OpenAICompatCompletionResponse,
get_sampling_options,
process_chat_completion_response,
get_stop_reason,
process_chat_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
@ -50,188 +76,322 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
)
from .config import VLLMConfig
from .openai_utils import llama_stack_chat_completion_to_openai_chat_completion_dict
log = logging.getLogger(__name__)
# Map from Hugging Face model architecture name to appropriate tool parser.
# See vllm.entrypoints.openai.tool_parsers.ToolParserManager.tool_parsers for the full list of
# available parsers.
# TODO: Expand this list
CONFIG_TYPE_TO_TOOL_PARSER = {
"GraniteConfig": "granite",
"MllamaConfig": "llama3_json",
"LlamaConfig": "llama3_json",
}
DEFAULT_TOOL_PARSER = "pythonic"
def _random_uuid() -> str:
logger = get_logger(__name__, category="inference")
def _random_uuid_str() -> str:
return str(uuid.uuid4().hex)
def _response_format_to_guided_decoding_params(
response_format: Optional[ResponseFormat], # type: ignore
) -> vllm.sampling_params.GuidedDecodingParams:
"""
Translate constrained decoding parameters from Llama Stack's format to vLLM's format.
:param response_format: Llama Stack version of constrained decoding info. Can be ``None``,
indicating no constraints.
:returns: The equivalent dataclass object for the low-level inference layer of vLLM.
"""
if response_format is None:
# As of vLLM 0.6.3, the default constructor for GuidedDecodingParams() returns an invalid
# value that crashes the executor on some code paths. Use ``None`` instead.
return None
# Llama Stack currently implements fewer types of constrained decoding than vLLM does.
# Translate the types that exist and detect if Llama Stack adds new ones.
if isinstance(response_format, JsonSchemaResponseFormat):
return vllm.sampling_params.GuidedDecodingParams(json=response_format.json_schema)
elif isinstance(response_format, GrammarResponseFormat):
# BNF grammar.
# Llama Stack uses the parse tree of the grammar, while vLLM uses the string
# representation of the grammar.
raise TypeError(
"Constrained decoding with BNF grammars is not currently implemented, because the "
"reference implementation does not implement it."
)
else:
raise TypeError(f"ResponseFormat object is of unexpected subtype '{type(response_format)}'")
def _convert_sampling_params(
sampling_params: Optional[SamplingParams],
response_format: Optional[ResponseFormat], # type: ignore
log_prob_config: Optional[LogProbConfig],
) -> vllm.SamplingParams:
"""Convert sampling and constrained decoding configuration from Llama Stack's format to vLLM's
format."""
# In the absence of provided config values, use Llama Stack defaults as encoded in the Llama
# Stack dataclasses. These defaults are different from vLLM's defaults.
if sampling_params is None:
sampling_params = SamplingParams()
if log_prob_config is None:
log_prob_config = LogProbConfig()
if isinstance(sampling_params.strategy, TopKSamplingStrategy):
if sampling_params.strategy.top_k == 0:
# vLLM treats "k" differently for top-k sampling
vllm_top_k = -1
else:
vllm_top_k = sampling_params.strategy.top_k
else:
vllm_top_k = -1
if isinstance(sampling_params.strategy, TopPSamplingStrategy):
vllm_top_p = sampling_params.strategy.top_p
# Llama Stack only allows temperature with top-P.
vllm_temperature = sampling_params.strategy.temperature
else:
vllm_top_p = 1.0
vllm_temperature = 0.0
# vLLM allows top-p and top-k at the same time.
vllm_sampling_params = vllm.SamplingParams.from_optional(
max_tokens=(None if sampling_params.max_tokens == 0 else sampling_params.max_tokens),
temperature=vllm_temperature,
top_p=vllm_top_p,
top_k=vllm_top_k,
repetition_penalty=sampling_params.repetition_penalty,
guided_decoding=_response_format_to_guided_decoding_params(response_format),
logprobs=log_prob_config.top_k,
)
return vllm_sampling_params
class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
"""Inference implementation for vLLM."""
"""
vLLM-based inference model adapter for Llama Stack with support for multiple models.
Requires the configuration parameters documented in the :class:`VllmConfig2` class.
"""
config: VLLMConfig
register_helper: ModelRegistryHelper
model_ids: set[str]
resolved_model_id: str | None
engine: AsyncLLMEngine | None
chat: OpenAIServingChat | None
is_meta_llama_model: bool
def __init__(self, config: VLLMConfig):
self.config = config
logger.info(f"Config is: {self.config}")
self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
self.formatter = ChatFormat(Tokenizer.get_instance())
# The following are initialized when paths are bound to this provider
self.resolved_model_id = None
self.model_ids = set()
self.engine = None
self.chat = None
self.is_meta_llama_model = False
async def initialize(self):
log.info("Initializing vLLM inference provider.")
###########################################################################
# METHODS INHERITED FROM IMPLICIT BASE CLASS.
# TODO: Make this class inherit from the new base class ProviderBase once that class exists.
# Disable usage stats reporting. This would be a surprising thing for most
# people to find out was on by default.
# https://docs.vllm.ai/en/latest/serving/usage_stats.html
if "VLLM_NO_USAGE_STATS" not in os.environ:
os.environ["VLLM_NO_USAGE_STATS"] = "1"
async def initialize(self) -> None:
"""
Callback that is invoked through many levels of indirection during provider class
instantiation, sometime after when __init__() is called and before any model registration
methods or methods connected to a REST API are called.
model = resolve_model(self.config.model)
if model is None:
raise ValueError(f"Unknown model {self.config.model}")
It's not clear what assumptions the class can make about the platform's initialization
state here that can't be made during __init__(), and vLLM can't be started until we know
what model it's supposed to be serving, so nothing happens here currently.
"""
pass
if model.huggingface_repo is None:
raise ValueError(f"Model {self.config.model} needs a huggingface repo")
# TODO -- there are a ton of options supported here ...
engine_args = AsyncEngineArgs(
model=model.huggingface_repo,
tokenizer=model.huggingface_repo,
tensor_parallel_size=self.config.tensor_parallel_size,
enforce_eager=self.config.enforce_eager,
gpu_memory_utilization=self.config.gpu_memory_utilization,
guided_decoding_backend="lm-format-enforcer",
)
self.engine = AsyncLLMEngine.from_engine_args(engine_args)
async def shutdown(self):
"""Shut down the vLLM inference adapter."""
log.info("Shutting down vLLM inference provider.")
if self.engine:
async def shutdown(self) -> None:
logger.info(f"Shutting down inline vLLM inference provider {self}.")
if self.engine is not None:
self.engine.shutdown_background_loop()
self.engine = None
self.chat = None
self.model_ids = set()
self.resolved_model_id = None
###########################################################################
# METHODS INHERITED FROM ModelsProtocolPrivate INTERFACE
# Note that the return type of the superclass method is WRONG
async def register_model(self, model: Model) -> Model:
"""
Callback that is called when the server associates an inference endpoint
with an inference provider.
Callback that is called when the server associates an inference endpoint with an
inference provider.
:param model: Object that encapsulates parameters necessary for identifying
a specific LLM.
:param model: Object that encapsulates parameters necessary for identifying a specific
LLM.
:returns: The input ``Model`` object. It may or may not be permissible
to change fields before returning this object.
:returns: The input ``Model`` object. It may or may not be permissible to change fields
before returning this object.
"""
log.info(f"Registering model {model.identifier} with vLLM inference provider.")
# The current version of this provided is hard-coded to serve only
# the model specified in the YAML config file.
configured_model = resolve_model(self.config.model)
registered_model = resolve_model(model.model_id)
logger.debug(f"In register_model({model})")
# First attempt to interpret the model coordinates as a Llama model name
resolved_llama_model = sku_list.resolve_model(model.provider_model_id)
if resolved_llama_model is not None:
# Load from Hugging Face repo into default local cache dir
model_id_for_vllm = resolved_llama_model.huggingface_repo
# Detect a genuine Meta Llama model to trigger Meta-specific preprocessing.
# Don't set self.is_meta_llama_model until we actually load the model.
is_meta_llama_model = True
else: # if resolved_llama_model is None
# Not a Llama model name. Pass the model id through to vLLM's loader
model_id_for_vllm = model.provider_model_id
is_meta_llama_model = False
if self.resolved_model_id is not None:
if model_id_for_vllm != self.resolved_model_id:
raise ValueError(
f"Attempted to serve two LLMs (ids '{self.resolved_model_id}') and "
f"'{model_id_for_vllm}') from one copy of provider '{self}'. Use multiple "
f"copies of the provider instead."
)
else:
# Model already loaded
logger.info(
f"Requested id {model} resolves to {model_id_for_vllm}, which is already loaded. Continuing."
)
self.model_ids.add(model.model_id)
return model
logger.info(f"Requested id {model} resolves to {model_id_for_vllm}. Loading {model_id_for_vllm}.")
if is_meta_llama_model:
logger.info(f"Model {model_id_for_vllm} is a Meta Llama model.")
self.is_meta_llama_model = is_meta_llama_model
# If we get here, this is the first time registering a model.
# Preload so that the first inference request won't time out.
engine_args = AsyncEngineArgs(
model=model_id_for_vllm,
tokenizer=model_id_for_vllm,
tensor_parallel_size=self.config.tensor_parallel_size,
enforce_eager=self.config.enforce_eager,
gpu_memory_utilization=self.config.gpu_memory_utilization,
max_num_seqs=self.config.max_num_seqs,
max_model_len=self.config.max_model_len,
)
self.engine = AsyncLLMEngine.from_engine_args(engine_args)
# vLLM currently requires the user to specify the tool parser manually. To choose a tool
# parser, we need to determine what model architecture is being used. For now, we infer
# that information from what config class the model uses.
low_level_model_config = self.engine.engine.get_model_config()
hf_config = low_level_model_config.hf_config
hf_config_class_name = hf_config.__class__.__name__
if hf_config_class_name in CONFIG_TYPE_TO_TOOL_PARSER:
tool_parser = CONFIG_TYPE_TO_TOOL_PARSER[hf_config_class_name]
else:
# No info -- choose a default so we can at least attempt tool
# use.
tool_parser = DEFAULT_TOOL_PARSER
logger.debug(f"{hf_config_class_name=}")
logger.debug(f"{tool_parser=}")
# Wrap the lower-level engine in an OpenAI-compatible chat API
model_config = await self.engine.get_model_config()
self.chat = OpenAIServingChat(
engine_client=self.engine,
model_config=model_config,
models=OpenAIServingModels(
engine_client=self.engine,
model_config=model_config,
base_model_paths=[
# The layer below us will only see resolved model IDs
BaseModelPath(model_id_for_vllm, model_id_for_vllm)
],
),
response_role="assistant",
request_logger=None, # Use default logging
chat_template=None, # Use default template from model checkpoint
enable_auto_tools=True,
tool_parser=tool_parser,
chat_template_content_format="auto",
)
self.resolved_model_id = model_id_for_vllm
self.model_ids.add(model.model_id)
logger.info(f"Finished preloading model: {model_id_for_vllm}")
if configured_model.core_model_id != registered_model.core_model_id:
raise ValueError(
f"Requested model '{model.identifier}' is different from "
f"model '{self.config.model}' that this provider "
f"is configured to serve"
)
return model
def _sampling_params(self, sampling_params: SamplingParams) -> VLLMSamplingParams:
if sampling_params is None:
return VLLMSamplingParams(max_tokens=self.config.max_tokens)
options = get_sampling_options(sampling_params)
if "repeat_penalty" in options:
options["repetition_penalty"] = options["repeat_penalty"]
del options["repeat_penalty"]
return VLLMSamplingParams(**options)
async def unregister_model(self, model_id: str) -> None:
pass
"""
Callback that is called when the server removes an inference endpoint from an inference
provider.
:param model_id: The same external ID that the higher layers of the stack previously passed
to :func:`register_model()`
"""
if model_id not in self.model_ids:
raise ValueError(
f"Attempted to unregister model ID '{model_id}', but that ID is not registered to this provider."
)
self.model_ids.remove(model_id)
if len(self.model_ids) == 0:
# Last model was just unregistered. Shut down the connection to vLLM and free up
# resources.
# Note that this operation may cause in-flight chat completion requests on the
# now-unregistered model to return errors.
self.resolved_model_id = None
self.chat = None
self.engine.shutdown_background_loop()
self.engine = None
###########################################################################
# METHODS INHERITED FROM Inference INTERFACE
async def completion(
self,
model_id: str,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> CompletionResponse | CompletionResponseStreamChunk:
raise NotImplementedError("Completion not implemented for vLLM")
) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
if model_id not in self.model_ids:
raise ValueError(
f"This adapter is not registered to model id '{model_id}'. Registered IDs are: {self.model_ids}"
)
if not isinstance(content, str):
raise NotImplementedError("Multimodal input not currently supported")
if sampling_params is None:
sampling_params = SamplingParams()
async def chat_completion(
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> ChatCompletionResponse | ChatCompletionResponseStreamChunk:
assert self.engine is not None
converted_sampling_params = _convert_sampling_params(sampling_params, response_format, logprobs)
request = ChatCompletionRequest(
model=model_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
stream=stream,
logprobs=logprobs,
tool_config=tool_config,
)
logger.debug(f"{converted_sampling_params=}")
log.info("Sampling params: %s", sampling_params)
request_id = _random_uuid()
prompt = await chat_completion_request_to_prompt(request, self.config.model)
vllm_sampling_params = self._sampling_params(request.sampling_params)
results_generator = self.engine.generate(prompt, vllm_sampling_params, request_id)
if stream:
return self._stream_chat_completion(request, results_generator)
return self._streaming_completion(content, converted_sampling_params)
else:
return await self._nonstream_chat_completion(request, results_generator)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest, results_generator: AsyncGenerator
) -> ChatCompletionResponse:
outputs = [o async for o in results_generator]
final_output = outputs[-1]
assert final_output is not None
outputs = final_output.outputs
finish_reason = outputs[-1].stop_reason
choice = OpenAICompatCompletionChoice(
finish_reason=finish_reason,
text="".join([output.text for output in outputs]),
)
response = OpenAICompatCompletionResponse(
choices=[choice],
)
return process_chat_completion_response(response, request)
async def _stream_chat_completion(
self, request: ChatCompletionRequest, results_generator: AsyncGenerator
) -> AsyncGenerator:
tokenizer = Tokenizer.get_instance()
async def _generate_and_convert_to_openai_compat():
cur = []
async for chunk in results_generator:
if not chunk.outputs:
log.warning("Empty chunk received")
continue
output = chunk.outputs[-1]
new_tokens = output.token_ids[len(cur) :]
text = tokenizer.decode(new_tokens)
cur.extend(new_tokens)
choice = OpenAICompatCompletionChoice(
finish_reason=output.finish_reason,
text=text,
)
yield OpenAICompatCompletionResponse(
choices=[choice],
)
stream = _generate_and_convert_to_openai_compat()
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
streaming_result = None
async for _ in self._streaming_completion(content, converted_sampling_params):
pass
return CompletionResponse(
content=streaming_result.delta,
stop_reason=streaming_result.stop_reason,
logprobs=streaming_result.logprobs,
)
async def embeddings(
self,
@ -242,3 +402,391 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
raise NotImplementedError()
async def chat_completion(
self,
model_id: str,
messages: List[Message], # type: ignore
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None, # type: ignore
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> ChatCompletionResponse | ChatCompletionResponseStreamChunk:
sampling_params = sampling_params or SamplingParams()
if model_id not in self.model_ids:
raise ValueError(
f"This adapter is not registered to model id '{model_id}'. Registered IDs are: {self.model_ids}"
)
# Convert to Llama Stack internal format for consistency
request = ChatCompletionRequest(
model=self.resolved_model_id,
messages=messages,
sampling_params=sampling_params,
response_format=response_format,
tools=tools,
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream,
logprobs=logprobs,
)
if self.is_meta_llama_model:
# Bypass vLLM chat templating layer for Meta Llama models, because the
# templating layer in Llama Stack currently produces better results.
logger.debug(
f"Routing {self.resolved_model_id} chat completion through "
f"Llama Stack's templating layer instead of vLLM's."
)
return await self._chat_completion_for_meta_llama(request)
logger.debug(f"{self.resolved_model_id} is not a Meta Llama model")
# Arguments to the vLLM call must be packaged as a ChatCompletionRequest dataclass.
# Note that this dataclass has the same name as a similar dataclass in Llama Stack.
request_options = await llama_stack_chat_completion_to_openai_chat_completion_dict(request)
chat_completion_request = vllm.entrypoints.openai.protocol.ChatCompletionRequest(**request_options)
logger.debug(f"Converted request: {chat_completion_request}")
vllm_result = await self.chat.create_chat_completion(chat_completion_request)
logger.debug(f"Result from vLLM: {vllm_result}")
if isinstance(vllm_result, vllm.entrypoints.openai.protocol.ErrorResponse):
raise ValueError(f"Error from vLLM layer: {vllm_result}")
# Return type depends on "stream" argument
if stream:
if not isinstance(vllm_result, AsyncGenerator):
raise TypeError(f"Unexpected result type {type(vllm_result)} for streaming inference call")
# vLLM client returns a stream of strings, which need to be parsed.
# Stream comes in the form of an async generator.
return self._convert_streaming_results(vllm_result)
else:
if not isinstance(vllm_result, vllm.entrypoints.openai.protocol.ChatCompletionResponse):
raise TypeError(f"Unexpected result type {type(vllm_result)} for non-streaming inference call")
return self._convert_non_streaming_results(vllm_result)
###########################################################################
# INTERNAL METHODS
async def _streaming_completion(
self, content: str, sampling_params: vllm.SamplingParams
) -> AsyncIterator[CompletionResponseStreamChunk]:
"""Internal implementation of :func:`completion()` API for the streaming case. Assumes
that arguments have been validated upstream.
:param content: Must be a string
:param sampling_params: Paramters from public API's ``response_format``
and ``sampling_params`` arguments, converted to VLLM format
"""
# We run agains the vLLM generate() call directly instead of using the OpenAI-compatible
# layer, because doing so simplifies the code here.
# The vLLM engine requires a unique identifier for each call to generate()
request_id = _random_uuid_str()
# The vLLM generate() API is streaming-only and returns an async generator.
# The generator returns objects of type vllm.RequestOutput.
results_generator = self.engine.generate(content, sampling_params, request_id)
# Need to know the model's EOS token ID for the conversion code below.
# AsyncLLMEngine is a wrapper around LLMEngine, and the tokenizer is only available if
# we drill down to the LLMEngine inside the AsyncLLMEngine.
# Similarly, the tokenizer in an LLMEngine is a wrapper around a BaseTokenizerGroup,
# and we need to drill down to the Hugging Face tokenizer inside the BaseTokenizerGroup.
llm_engine = self.engine.engine
tokenizer_group = llm_engine.tokenizer
eos_token_id = tokenizer_group.tokenizer.eos_token_id
request_output: vllm.RequestOutput = None
async for request_output in results_generator:
# Check for weird inference failures
if request_output.outputs is None or len(request_output.outputs) == 0:
# This case also should never happen
raise ValueError("Inference produced empty result")
# If we get here, then request_output contains the final output of the generate() call.
# The result may include multiple alternate outputs, but Llama Stack APIs only allow
# us to return one.
output: vllm.CompletionOutput = request_output.outputs[0]
completion_string = output.text
# Convert logprobs from vLLM's format to Llama Stack's format
logprobs = [
TokenLogProbs(logprobs_by_token={v.decoded_token: v.logprob for _, v in logprob_dict.items()})
for logprob_dict in output.logprobs
]
# The final output chunk should be labeled with the reason that the overall generate()
# call completed.
logger.debug(f"{output.stop_reason=}; {type(output.stop_reason)=}")
if output.stop_reason is None:
stop_reason = None # Still going
elif output.stop_reason == "stop":
stop_reason = StopReason.end_of_turn
elif output.stop_reason == "length":
stop_reason = StopReason.out_of_tokens
elif isinstance(output.stop_reason, int):
# If the model config specifies multiple end-of-sequence tokens, then vLLM
# will return the token ID of the EOS token in the stop_reason field.
stop_reason = StopReason.end_of_turn
else:
raise ValueError(f"Unrecognized stop reason '{output.stop_reason}'")
# vLLM's protocol outputs the stop token, then sets end of message on the next step for
# some reason.
if request_output.outputs[-1].token_ids[-1] == eos_token_id:
stop_reason = StopReason.end_of_message
yield CompletionResponseStreamChunk(delta=completion_string, stop_reason=stop_reason, logprobs=logprobs)
# Llama Stack requires that the last chunk have a stop reason, but vLLM doesn't always
# provide one if it runs out of tokens.
if stop_reason is None:
yield CompletionResponseStreamChunk(
delta=completion_string,
stop_reason=StopReason.out_of_tokens,
logprobs=logprobs,
)
def _convert_non_streaming_results(
self, vllm_result: vllm.entrypoints.openai.protocol.ChatCompletionResponse
) -> ChatCompletionResponse:
"""
Subroutine to convert the non-streaming output of vLLM's OpenAI-compatible API into an
equivalent Llama Stack object.
The result from vLLM's non-streaming API is a dataclass with the same name as the Llama
Stack ChatCompletionResponse dataclass, but with more and different field names. We ignore
the fields that aren't currently present in the Llama Stack dataclass.
"""
# There may be multiple responses, but we can only pass through the first one.
if len(vllm_result.choices) == 0:
raise ValueError("Don't know how to convert response object without any responses")
vllm_message = vllm_result.choices[0].message
vllm_finish_reason = vllm_result.choices[0].finish_reason
converted_message = CompletionMessage(
role=vllm_message.role,
# Llama Stack API won't accept None for content field.
content=("" if vllm_message.content is None else vllm_message.content),
stop_reason=get_stop_reason(vllm_finish_reason),
tool_calls=[
ToolCall(
call_id=t.id,
tool_name=t.function.name,
# vLLM function args come back as a string. Llama Stack expects JSON.
arguments=json.loads(t.function.arguments),
)
for t in vllm_message.tool_calls
],
)
# TODO: Convert logprobs
logger.debug(f"Converted message: {converted_message}")
return ChatCompletionResponse(
completion_message=converted_message,
)
async def _chat_completion_for_meta_llama(
self, request: ChatCompletionRequest
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
"""
Subroutine that routes chat completions for Meta Llama models through Llama Stack's
chat template instead of using vLLM's version of that template. The Llama Stack version
of the chat template currently produces more reliable outputs.
Once vLLM's support for Meta Llama models has matured more, we should consider routing
Meta Llama requests through the vLLM chat completions API instead of using this method.
"""
formatter = ChatFormat(Tokenizer.get_instance())
# Note that this function call modifies `request` in place.
prompt = await chat_completion_request_to_prompt(request, self.resolved_model_id)
model_id = list(self.model_ids)[0] # Any model ID will do here
completion_response_or_iterator = await self.completion(
model_id=model_id,
content=prompt,
sampling_params=request.sampling_params,
response_format=request.response_format,
stream=request.stream,
logprobs=request.logprobs,
)
if request.stream:
if not isinstance(completion_response_or_iterator, AsyncIterator):
raise TypeError(
f"Received unexpected result type {type(completion_response_or_iterator)}for streaming request."
)
return self._chat_completion_for_meta_llama_streaming(completion_response_or_iterator, request)
# elsif not request.stream:
if not isinstance(completion_response_or_iterator, CompletionResponse):
raise TypeError(
f"Received unexpected result type {type(completion_response_or_iterator)}for non-streaming request."
)
completion_response: CompletionResponse = completion_response_or_iterator
raw_message = formatter.decode_assistant_message_from_content(
completion_response.content, completion_response.stop_reason
)
return ChatCompletionResponse(
completion_message=CompletionMessage(
content=raw_message.content,
stop_reason=raw_message.stop_reason,
tool_calls=raw_message.tool_calls,
),
logprobs=completion_response.logprobs,
)
async def _chat_completion_for_meta_llama_streaming(
self, results_iterator: AsyncIterator, request: ChatCompletionRequest
) -> AsyncIterator:
"""
Code from :func:`_chat_completion_for_meta_llama()` that needs to be a separate
method to keep asyncio happy.
"""
# Convert to OpenAI format, then use shared code to convert to Llama Stack format.
async def _generate_and_convert_to_openai_compat():
chunk: CompletionResponseStreamChunk # Make Pylance happy
last_text_len = 0
async for chunk in results_iterator:
if chunk.stop_reason == StopReason.end_of_turn:
finish_reason = "stop"
elif chunk.stop_reason == StopReason.end_of_message:
finish_reason = "eos"
elif chunk.stop_reason == StopReason.out_of_tokens:
finish_reason = "length"
else:
finish_reason = None
# Convert delta back to an actual delta
text_delta = chunk.delta[last_text_len:]
last_text_len = len(chunk.delta)
logger.debug(f"{text_delta=}; {finish_reason=}")
yield OpenAICompatCompletionResponse(
choices=[OpenAICompatCompletionChoice(finish_reason=finish_reason, text=text_delta)]
)
stream = _generate_and_convert_to_openai_compat()
async for chunk in process_chat_completion_stream_response(stream, request):
logger.debug(f"Returning chunk: {chunk}")
yield chunk
async def _convert_streaming_results(self, vllm_result: AsyncIterator) -> AsyncIterator:
"""
Subroutine that wraps the streaming outputs of vLLM's OpenAI-compatible
API into a second async iterator that returns Llama Stack objects.
:param vllm_result: Stream of strings that need to be parsed
"""
# Tool calls come in pieces, but Llama Stack expects them in bigger chunks. We build up
# those chunks and output them at the end.
# This data structure holds the current set of partial tool calls.
index_to_tool_call: Dict[int, Dict] = dict()
# The Llama Stack event stream must always start with a start event. Use an empty one to
# simplify logic below
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta=TextDelta(text=""),
stop_reason=None,
)
)
converted_stop_reason = None
async for chunk_str in vllm_result:
# Due to OpenAI compatibility, each event in the stream will start with "data: " and
# end with "\n\n".
_prefix = "data: "
_suffix = "\n\n"
if not chunk_str.startswith(_prefix) or not chunk_str.endswith(_suffix):
raise ValueError(f"Can't parse result string from vLLM: '{re.escape(chunk_str)}'")
# In between the "data: " and newlines is an event record
data_str = chunk_str[len(_prefix) : -len(_suffix)]
# The end of the stream is indicated with "[DONE]"
if data_str == "[DONE]":
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta=TextDelta(text=""),
stop_reason=converted_stop_reason,
)
)
return
# Anything that is not "[DONE]" should be a JSON record
parsed_chunk = json.loads(data_str)
logger.debug(f"Parsed JSON event to:\n{json.dumps(parsed_chunk, indent=2)}")
# The result may contain multiple completions, but Llama Stack APIs only support
# returning one.
first_choice = parsed_chunk["choices"][0]
converted_stop_reason = get_stop_reason(first_choice["finish_reason"])
delta_record = first_choice["delta"]
if "content" in delta_record:
# Text delta
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=TextDelta(text=delta_record["content"]),
stop_reason=converted_stop_reason,
)
)
elif "tool_calls" in delta_record:
# Tool call(s). Llama Stack APIs do not have a clear way to return partial tool
# calls, so buffer until we get a "tool calls" stop reason
for tc in delta_record["tool_calls"]:
index = tc["index"]
if index not in index_to_tool_call:
# First time this tool call is showing up
index_to_tool_call[index] = dict()
tool_call = index_to_tool_call[index]
if "id" in tc:
tool_call["call_id"] = tc["id"]
if "function" in tc:
if "name" in tc["function"]:
tool_call["tool_name"] = tc["function"]["name"]
if "arguments" in tc["function"]:
# Arguments comes in as pieces of a string
if "arguments_str" not in tool_call:
tool_call["arguments_str"] = ""
tool_call["arguments_str"] += tc["function"]["arguments"]
else:
raise ValueError(f"Don't know how to parse event delta: {delta_record}")
if first_choice["finish_reason"] == "tool_calls":
# Special OpenAI code for "tool calls complete".
# Output the buffered tool calls. Llama Stack requires a separate event per tool
# call.
for tool_call_record in index_to_tool_call.values():
# Arguments come in as a string. Parse the completed string.
tool_call_record["arguments"] = json.loads(tool_call_record["arguments_str"])
del tool_call_record["arguments_str"]
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(tool_call=tool_call_record, parse_status="succeeded"),
stop_reason=converted_stop_reason,
)
)
# If we get here, we've lost the connection with the vLLM event stream before it ended
# normally.
raise ValueError("vLLM event stream ended without [DONE] message.")

View file

@ -4,9 +4,9 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from typing import Any, Dict
from llama_stack.distribution.datatypes import Api, ProviderSpec
from llama_stack.distribution.datatypes import Api
from .config import TorchtunePostTrainingConfig
@ -15,7 +15,7 @@ from .config import TorchtunePostTrainingConfig
async def get_provider_impl(
config: TorchtunePostTrainingConfig,
deps: Dict[Api, ProviderSpec],
deps: Dict[Api, Any],
):
from .post_training import TorchtunePostTrainingImpl

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Literal, Optional
from typing import Any, Dict, Literal, Optional
from pydantic import BaseModel
@ -12,3 +12,9 @@ from pydantic import BaseModel
class TorchtunePostTrainingConfig(BaseModel):
torch_seed: Optional[int] = None
checkpoint_format: Optional[Literal["meta", "huggingface"]] = "meta"
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
return {
"checkpoint_format": "meta",
}

View file

@ -43,6 +43,9 @@ class TorchtunePostTrainingImpl:
self.jobs = {}
self.checkpoints_dict = {}
async def shutdown(self):
pass
async def supervised_fine_tune(
self,
job_uuid: str,

View file

@ -4,10 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from .config import CodeScannerConfig
async def get_provider_impl(config: CodeScannerConfig, deps):
async def get_provider_impl(config: CodeScannerConfig, deps: Dict[str, Any]):
from .code_scanner import MetaReferenceCodeScannerSafetyImpl
impl = MetaReferenceCodeScannerSafetyImpl(config, deps)

View file

@ -4,8 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from pydantic import BaseModel
class CodeScannerConfig(BaseModel):
pass
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
return {}

View file

@ -4,10 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from .config import LlamaGuardConfig
async def get_provider_impl(config: LlamaGuardConfig, deps):
async def get_provider_impl(config: LlamaGuardConfig, deps: Dict[str, Any]):
from .llama_guard import LlamaGuardSafetyImpl
assert isinstance(config, LlamaGuardConfig), f"Unexpected config type: {type(config)}"

View file

@ -4,10 +4,16 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List
from typing import Any, Dict, List
from pydantic import BaseModel
class LlamaGuardConfig(BaseModel):
excluded_categories: List[str] = []
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
return {
"excluded_categories": [],
}

View file

@ -4,10 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from .config import PromptGuardConfig # noqa: F401
async def get_provider_impl(config: PromptGuardConfig, deps):
async def get_provider_impl(config: PromptGuardConfig, deps: Dict[str, Any]):
from .prompt_guard import PromptGuardSafetyImpl
impl = PromptGuardSafetyImpl(config, deps)

View file

@ -5,6 +5,7 @@
# the root directory of this source tree.
from enum import Enum
from typing import Any, Dict
from pydantic import BaseModel, field_validator
@ -23,3 +24,9 @@ class PromptGuardConfig(BaseModel):
if v not in [t.value for t in PromptGuardType]:
raise ValueError(f"Unknown prompt guard type: {v}")
return v
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
return {
"guard_type": "injection",
}

View file

@ -3,16 +3,16 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from typing import Any, Dict
from llama_stack.distribution.datatypes import Api, ProviderSpec
from llama_stack.distribution.datatypes import Api
from .config import BasicScoringConfig
async def get_provider_impl(
config: BasicScoringConfig,
deps: Dict[Api, ProviderSpec],
deps: Dict[Api, Any],
):
from .scoring import BasicScoringImpl

View file

@ -3,7 +3,12 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from pydantic import BaseModel
class BasicScoringConfig(BaseModel): ...
class BasicScoringConfig(BaseModel):
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
return {}

View file

@ -23,10 +23,11 @@ from llama_stack.providers.utils.common.data_schema_validator import (
from .config import BasicScoringConfig
from .scoring_fn.equality_scoring_fn import EqualityScoringFn
from .scoring_fn.regex_parser_math_response_scoring_fn import RegexParserMathResponseScoringFn
from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn
from .scoring_fn.subset_of_scoring_fn import SubsetOfScoringFn
FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn, RegexParserScoringFn]
FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn, RegexParserScoringFn, RegexParserMathResponseScoringFn]
class BasicScoringImpl(

View file

@ -12,6 +12,7 @@ from llama_stack.apis.scoring_functions import (
)
MULTILINGUAL_ANSWER_REGEXES = [
r"The best answer is ",
r"Answer\s*:",
r"Answer\s*:", # Korean invisible character
r"উত্তর\s*:",

View file

@ -3,11 +3,11 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from typing import Any, Dict
from pydantic import BaseModel
from llama_stack.distribution.datatypes import Api, ProviderSpec
from llama_stack.distribution.datatypes import Api
from .config import BraintrustScoringConfig
@ -18,7 +18,7 @@ class BraintrustProviderDataValidator(BaseModel):
async def get_provider_impl(
config: BraintrustScoringConfig,
deps: Dict[Api, ProviderSpec],
deps: Dict[Api, Any],
):
from .braintrust import BraintrustScoringImpl

View file

@ -3,16 +3,16 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from typing import Any, Dict
from llama_stack.distribution.datatypes import Api, ProviderSpec
from llama_stack.distribution.datatypes import Api
from .config import LlmAsJudgeScoringConfig
async def get_provider_impl(
config: LlmAsJudgeScoringConfig,
deps: Dict[Api, ProviderSpec],
deps: Dict[Api, Any],
):
from .scoring import LlmAsJudgeScoringImpl

View file

@ -3,7 +3,12 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from pydantic import BaseModel
class LlmAsJudgeScoringConfig(BaseModel): ...
class LlmAsJudgeScoringConfig(BaseModel):
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
return {}

View file

@ -25,7 +25,7 @@ from llama_stack.providers.utils.common.data_schema_validator import (
from .config import LlmAsJudgeScoringConfig
from .scoring_fn.llm_as_judge_scoring_fn import LlmAsJudgeScoringFn
LLM_JUDGE_FNS = [LlmAsJudgeScoringFn]
LLM_JUDGE_FN = LlmAsJudgeScoringFn
class LlmAsJudgeScoringImpl(
@ -43,23 +43,17 @@ class LlmAsJudgeScoringImpl(
self.datasetio_api = datasetio_api
self.datasets_api = datasets_api
self.inference_api = inference_api
self.scoring_fn_id_impls = {}
async def initialize(self) -> None:
for fn in LLM_JUDGE_FNS:
impl = fn(inference_api=self.inference_api)
for fn_defs in impl.get_supported_scoring_fn_defs():
self.scoring_fn_id_impls[fn_defs.identifier] = impl
self.llm_as_judge_fn = impl
impl = LLM_JUDGE_FN(inference_api=self.inference_api)
self.llm_as_judge_fn = impl
async def shutdown(self) -> None: ...
async def list_scoring_functions(self) -> List[ScoringFn]:
scoring_fn_defs_list = [
fn_def for impl in self.scoring_fn_id_impls.values() for fn_def in impl.get_supported_scoring_fn_defs()
]
scoring_fn_defs_list = self.llm_as_judge_fn.get_supported_scoring_fn_defs()
for f in scoring_fn_defs_list:
for f in self.llm_as_judge_fn.get_supported_scoring_fn_defs():
assert f.identifier.startswith("llm-as-judge"), (
"All llm-as-judge scoring fn must have identifier prefixed with 'llm-as-judge'! "
)
@ -67,7 +61,7 @@ class LlmAsJudgeScoringImpl(
return scoring_fn_defs_list
async def register_scoring_function(self, function_def: ScoringFn) -> None:
raise NotImplementedError("Register scoring function not implemented yet")
self.llm_as_judge_fn.register_scoring_fn_def(function_def)
async def score_batch(
self,
@ -102,9 +96,7 @@ class LlmAsJudgeScoringImpl(
) -> ScoreResponse:
res = {}
for scoring_fn_id in scoring_functions.keys():
if scoring_fn_id not in self.scoring_fn_id_impls:
raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
scoring_fn = self.scoring_fn_id_impls[scoring_fn_id]
scoring_fn = self.llm_as_judge_fn
scoring_fn_params = scoring_functions.get(scoring_fn_id, None)
score_results = await scoring_fn.score(input_rows, scoring_fn_id, scoring_fn_params)
agg_results = await scoring_fn.aggregate(score_results, scoring_fn_id, scoring_fn_params)

View file

@ -6,7 +6,7 @@
import re
from typing import Any, Dict, Optional
from llama_stack.apis.inference.inference import Inference
from llama_stack.apis.inference.inference import Inference, UserMessage
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
@ -58,10 +58,9 @@ class LlmAsJudgeScoringFn(RegisteredBaseScoringFn):
judge_response = await self.inference_api.chat_completion(
model_id=fn_def.params.judge_model,
messages=[
{
"role": "user",
"content": judge_input_msg,
}
UserMessage(
content=judge_input_msg,
),
],
)
content = judge_response.completion_message.content

View file

@ -44,9 +44,9 @@ class TelemetryConfig(BaseModel):
return v
@classmethod
def sample_run_config(cls, __distro_dir__: str = "runtime", db_name: str = "trace_store.db") -> Dict[str, Any]:
def sample_run_config(cls, __distro_dir__: str, db_name: str = "trace_store.db") -> Dict[str, Any]:
return {
"service_name": "${env.OTEL_SERVICE_NAME:llama-stack}",
"sinks": "${env.TELEMETRY_SINKS:console,sqlite}",
"sqlite_db_path": "${env.SQLITE_DB_PATH:~/.llama/" + __distro_dir__ + "/" + db_name + "}",
"sqlite_db_path": "${env.SQLITE_DB_PATH:" + __distro_dir__ + "/" + db_name + "}",
}

View file

@ -73,6 +73,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
def __init__(self, config: TelemetryConfig, deps: Dict[str, Any]) -> None:
self.config = config
self.datasetio_api = deps.get(Api.datasetio)
self.meter = None
resource = Resource.create(
{
@ -171,6 +172,8 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
return _GLOBAL_STORAGE["gauges"][name]
def _log_metric(self, event: MetricEvent) -> None:
if self.meter is None:
return
if isinstance(event.value, int):
counter = self._get_or_create_counter(event.metric, event.unit)
counter.add(event.value, attributes=event.attributes)

View file

@ -1,17 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from .config import SampleConfig
async def get_adapter_impl(config: SampleConfig, _deps) -> Any:
from .sample import SampleTelemetryImpl
impl = SampleTelemetryImpl(config)
await impl.initialize()
return impl

View file

@ -1,12 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
class SampleConfig(BaseModel):
host: str = "localhost"
port: int = 9999

View file

@ -1,17 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.telemetry import Telemetry
from .config import SampleConfig
class SampleTelemetryImpl(Telemetry):
def __init__(self, config: SampleConfig):
self.config = config
async def initialize(self):
pass

View file

@ -4,12 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from .config import CodeInterpreterToolConfig
__all__ = ["CodeInterpreterToolConfig", "CodeInterpreterToolRuntimeImpl"]
async def get_provider_impl(config: CodeInterpreterToolConfig, _deps):
async def get_provider_impl(config: CodeInterpreterToolConfig, _deps: Dict[str, Any]):
from .code_interpreter import CodeInterpreterToolRuntimeImpl
impl = CodeInterpreterToolRuntimeImpl(config)

View file

@ -76,6 +76,7 @@ class CodeExecutionRequest:
only_last_cell_fail: bool = True
seed: int = 0
strip_fpaths_in_stderr: bool = True
use_bwrap: bool = True
class CodeExecutor:
@ -103,8 +104,6 @@ _set_seeds()\
script = "\n\n".join([seeds_prefix] + [CODE_ENV_PREFIX] + scripts)
with tempfile.TemporaryDirectory() as dpath:
bwrap_prefix = "bwrap " + generate_bwrap_command(bind_dirs=[dpath])
cmd = [*bwrap_prefix.split(), sys.executable, "-c", script]
code_fpath = os.path.join(dpath, "code.py")
with open(code_fpath, "w") as f:
f.write(script)
@ -118,6 +117,13 @@ _set_seeds()\
MPLBACKEND="module://matplotlib_custom_backend",
PYTHONPATH=f"{DIRNAME}:{python_path}",
)
if req.use_bwrap:
bwrap_prefix = "bwrap " + generate_bwrap_command(bind_dirs=[dpath])
cmd = [*bwrap_prefix.split(), sys.executable, "-c", script]
else:
cmd = [sys.executable, "-c", script]
stdout, stderr, returncode = do_subprocess(
cmd=cmd,
env=env,

View file

@ -6,6 +6,7 @@
import logging
import os
import tempfile
from typing import Any, Dict, List, Optional
@ -61,7 +62,9 @@ class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
script = kwargs["code"]
req = CodeExecutionRequest(scripts=[script])
# Use environment variable to control bwrap usage
force_disable_bwrap = os.environ.get("DISABLE_CODE_SANDBOX", "").lower() in ("1", "true", "yes")
req = CodeExecutionRequest(scripts=[script], use_bwrap=not force_disable_bwrap)
res = self.code_executor.execute(req)
pieces = [res["process_status"]]
for out_type in ["stdout", "stderr"]:

View file

@ -4,8 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from pydantic import BaseModel
class CodeInterpreterToolConfig(BaseModel):
pass
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
return {}

View file

@ -4,8 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from pydantic import BaseModel
class RagToolRuntimeConfig(BaseModel):
pass
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
return {}

View file

@ -4,14 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from typing import Any, Dict
from llama_stack.providers.datatypes import Api, ProviderSpec
from llama_stack.providers.datatypes import Api
from .config import ChromaVectorIOConfig
async def get_provider_impl(config: ChromaVectorIOConfig, deps: Dict[Api, ProviderSpec]):
async def get_provider_impl(config: ChromaVectorIOConfig, deps: Dict[Api, Any]):
from llama_stack.providers.remote.vector_io.chroma.chroma import (
ChromaVectorIOAdapter,
)

View file

@ -13,5 +13,5 @@ class ChromaVectorIOConfig(BaseModel):
db_path: str
@classmethod
def sample_config(cls) -> Dict[str, Any]:
return {"db_path": "{env.CHROMADB_PATH}"}
def sample_run_config(cls, db_path: str = "${env.CHROMADB_PATH}", **kwargs: Any) -> Dict[str, Any]:
return {"db_path": db_path}

View file

@ -4,14 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from typing import Any, Dict
from llama_stack.providers.datatypes import Api, ProviderSpec
from llama_stack.providers.datatypes import Api
from .config import FaissVectorIOConfig
async def get_provider_impl(config: FaissVectorIOConfig, deps: Dict[Api, ProviderSpec]):
async def get_provider_impl(config: FaissVectorIOConfig, deps: Dict[Api, Any]):
from .faiss import FaissVectorIOAdapter
assert isinstance(config, FaissVectorIOConfig), f"Unexpected config type: {type(config)}"

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import base64
import io
import json
@ -99,7 +100,7 @@ class FaissIndex(EmbeddingIndex):
await self._save_index()
async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
distances, indices = self.index.search(embedding.reshape(1, -1).astype(np.float32), k)
distances, indices = await asyncio.to_thread(self.index.search, embedding.reshape(1, -1).astype(np.float32), k)
chunks = []
scores = []

View file

@ -4,14 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from typing import Any, Dict
from llama_stack.providers.datatypes import Api, ProviderSpec
from llama_stack.providers.datatypes import Api
from .config import SQLiteVectorIOConfig
async def get_provider_impl(config: SQLiteVectorIOConfig, deps: Dict[Api, ProviderSpec]):
async def get_provider_impl(config: SQLiteVectorIOConfig, deps: Dict[Api, Any]):
from .sqlite_vec import SQLiteVecVectorIOAdapter
assert isinstance(config, SQLiteVectorIOConfig), f"Unexpected config type: {type(config)}"

View file

@ -15,5 +15,5 @@ class SQLiteVectorIOConfig(BaseModel):
@classmethod
def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]:
return {
"db_path": "${env.SQLITE_STORE_DIR:~/.llama/" + __distro_dir__ + "}/" + "sqlite_vec.db",
"db_path": "${env.SQLITE_STORE_DIR:" + __distro_dir__ + "}/" + "sqlite_vec.db",
}

View file

@ -7,11 +7,9 @@
from typing import List
from llama_stack.providers.datatypes import (
AdapterSpec,
Api,
InlineProviderSpec,
ProviderSpec,
remote_provider_spec,
)
from llama_stack.providers.utils.kvstore import kvstore_dependencies
@ -39,13 +37,4 @@ def available_providers() -> List[ProviderSpec]:
Api.tool_groups,
],
),
remote_provider_spec(
api=Api.agents,
adapter=AdapterSpec(
adapter_type="sample",
pip_packages=[],
module="llama_stack.providers.remote.agents.sample",
config_class="llama_stack.providers.remote.agents.sample.SampleConfig",
),
),
]

View file

@ -68,15 +68,6 @@ def available_providers() -> List[ProviderSpec]:
module="llama_stack.providers.inline.inference.sentence_transformers",
config_class="llama_stack.providers.inline.inference.sentence_transformers.config.SentenceTransformersInferenceConfig",
),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="sample",
pip_packages=[],
module="llama_stack.providers.remote.inference.sample",
config_class="llama_stack.providers.remote.inference.sample.SampleConfig",
),
),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(

View file

@ -27,27 +27,6 @@ def available_providers() -> List[ProviderSpec]:
module="llama_stack.providers.inline.safety.prompt_guard",
config_class="llama_stack.providers.inline.safety.prompt_guard.PromptGuardConfig",
),
InlineProviderSpec(
api=Api.safety,
provider_type="inline::meta-reference",
pip_packages=[
"transformers",
"torch --index-url https://download.pytorch.org/whl/cpu",
],
module="llama_stack.providers.inline.safety.meta_reference",
config_class="llama_stack.providers.inline.safety.meta_reference.SafetyConfig",
api_dependencies=[
Api.inference,
],
deprecation_error="""
Provider `inline::meta-reference` for API `safety` does not work with the latest Llama Stack.
- if you are using Llama Guard v3, please use the `inline::llama-guard` provider instead.
- if you are using Prompt Guard, please use the `inline::prompt-guard` provider instead.
- if you are using Code Scanner, please use the `inline::code-scanner` provider instead.
""",
),
InlineProviderSpec(
api=Api.safety,
provider_type="inline::llama-guard",
@ -67,15 +46,6 @@ Provider `inline::meta-reference` for API `safety` does not work with the latest
module="llama_stack.providers.inline.safety.code_scanner",
config_class="llama_stack.providers.inline.safety.code_scanner.CodeScannerConfig",
),
remote_provider_spec(
api=Api.safety,
adapter=AdapterSpec(
adapter_type="sample",
pip_packages=[],
module="llama_stack.providers.remote.safety.sample",
config_class="llama_stack.providers.remote.safety.sample.SampleConfig",
),
),
remote_provider_spec(
api=Api.safety,
adapter=AdapterSpec(

View file

@ -7,11 +7,9 @@
from typing import List
from llama_stack.providers.datatypes import (
AdapterSpec,
Api,
InlineProviderSpec,
ProviderSpec,
remote_provider_spec,
)
@ -28,13 +26,4 @@ def available_providers() -> List[ProviderSpec]:
module="llama_stack.providers.inline.telemetry.meta_reference",
config_class="llama_stack.providers.inline.telemetry.meta_reference.config.TelemetryConfig",
),
remote_provider_spec(
api=Api.telemetry,
adapter=AdapterSpec(
adapter_type="sample",
pip_packages=[],
module="llama_stack.providers.remote.telemetry.sample",
config_class="llama_stack.providers.remote.telemetry.sample.SampleConfig",
),
),
]

View file

@ -34,6 +34,8 @@ def available_providers() -> List[ProviderSpec]:
config_class="llama_stack.providers.inline.vector_io.faiss.FaissVectorIOConfig",
api_dependencies=[Api.inference],
),
# NOTE: sqlite-vec cannot be bundled into the container image because it does not have a
# source distribution and the wheels are not available for all platforms.
InlineProviderSpec(
api=Api.vector_io,
provider_type="inline::sqlite-vec",
@ -90,16 +92,6 @@ def available_providers() -> List[ProviderSpec]:
),
api_dependencies=[Api.inference],
),
remote_provider_spec(
api=Api.vector_io,
adapter=AdapterSpec(
adapter_type="sample",
pip_packages=[],
module="llama_stack.providers.remote.vector_io.sample",
config_class="llama_stack.providers.remote.vector_io.sample.SampleVectorIOConfig",
),
api_dependencies=[],
),
remote_provider_spec(
Api.vector_io,
AdapterSpec(
@ -110,4 +102,22 @@ def available_providers() -> List[ProviderSpec]:
),
api_dependencies=[Api.inference],
),
remote_provider_spec(
Api.vector_io,
AdapterSpec(
adapter_type="milvus",
pip_packages=["pymilvus"],
module="llama_stack.providers.remote.vector_io.milvus",
config_class="llama_stack.providers.remote.vector_io.milvus.MilvusVectorIOConfig",
),
api_dependencies=[Api.inference],
),
InlineProviderSpec(
api=Api.vector_io,
provider_type="inline::milvus",
pip_packages=["pymilvus"],
module="llama_stack.providers.inline.vector_io.milvus",
config_class="llama_stack.providers.inline.vector_io.milvus.MilvusVectorIOConfig",
api_dependencies=[Api.inference],
),
]

View file

@ -1,17 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from .config import SampleConfig
async def get_adapter_impl(config: SampleConfig, _deps) -> Any:
from .sample import SampleAgentsImpl
impl = SampleAgentsImpl(config)
await impl.initialize()
return impl

View file

@ -1,12 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
class SampleConfig(BaseModel):
host: str = "localhost"
port: int = 9999

View file

@ -1,17 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.agents import Agents
from .config import SampleConfig
class SampleAgentsImpl(Agents):
def __init__(self, config: SampleConfig):
self.config = config
async def initialize(self):
pass

View file

@ -3,9 +3,10 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from pydantic import BaseModel
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
from llama_stack.providers.utils.kvstore.config import (
KVStoreConfig,
SqliteKVStoreConfig,
@ -13,6 +14,13 @@ from llama_stack.providers.utils.kvstore.config import (
class HuggingfaceDatasetIOConfig(BaseModel):
kvstore: KVStoreConfig = SqliteKVStoreConfig(
db_path=(RUNTIME_BASE_DIR / "huggingface_datasetio.db").as_posix()
) # Uses SQLite config specific to HF storage
kvstore: KVStoreConfig
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
return {
"kvstore": SqliteKVStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="huggingface_datasetio.db",
)
}

View file

@ -72,7 +72,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
self,
model_id: str,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
@ -83,7 +83,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
@ -92,6 +92,8 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
model=model.provider_resource_id,

View file

@ -72,11 +72,13 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
self,
model_id: str,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
request = CompletionRequest(
model=model.provider_resource_id,
@ -112,7 +114,7 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = None,
@ -121,6 +123,8 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
model=model.provider_resource_id,

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from pydantic import BaseModel, Field
@ -20,3 +21,15 @@ class DatabricksImplConfig(BaseModel):
default=None,
description="The Databricks API token",
)
@classmethod
def sample_run_config(
cls,
url: str = "${env.DATABRICKS_URL}",
api_token: str = "${env.DATABRICKS_API_TOKEN}",
**kwargs: Any,
) -> Dict[str, Any]:
return {
"url": url,
"api_token": api_token,
}

View file

@ -71,7 +71,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
self,
model: str,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
@ -82,7 +82,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
self,
model: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
@ -91,6 +91,8 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
request = ChatCompletionRequest(
model=model,
messages=messages,

View file

@ -8,7 +8,6 @@ from typing import AsyncGenerator, List, Optional, Union
from fireworks.client import Fireworks
from llama_stack import logcat
from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
@ -33,6 +32,7 @@ from llama_stack.apis.inference import (
ToolPromptFormat,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
@ -55,6 +55,8 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import FireworksImplConfig
from .models import MODEL_ENTRIES
logger = get_logger(name=__name__, category="inference")
class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
def __init__(self, config: FireworksImplConfig) -> None:
@ -68,8 +70,9 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
pass
def _get_api_key(self) -> str:
if self.config.api_key is not None:
return self.config.api_key.get_secret_value()
config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
if config_api_key:
return config_api_key
else:
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.fireworks_api_key:
@ -86,11 +89,13 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
self,
model_id: str,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
request = CompletionRequest(
model=model.provider_resource_id,
@ -157,7 +162,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = None,
@ -166,6 +171,8 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
model=model.provider_resource_id,
@ -233,7 +240,8 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
"stream": request.stream,
**self._build_options(request.sampling_params, request.response_format, request.logprobs),
}
logcat.debug("inference", f"params to fireworks: {params}")
logger.debug(f"params to fireworks: {params}")
return params
async def embeddings(

View file

@ -24,10 +24,6 @@ MODEL_ENTRIES = [
"accounts/fireworks/models/llama-v3p1-405b-instruct",
CoreModelId.llama3_1_405b_instruct.value,
),
build_hf_repo_model_entry(
"accounts/fireworks/models/llama-v3p2-1b-instruct",
CoreModelId.llama3_2_1b_instruct.value,
),
build_hf_repo_model_entry(
"accounts/fireworks/models/llama-v3p2-3b-instruct",
CoreModelId.llama3_2_3b_instruct.value,

View file

@ -93,11 +93,13 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
self,
model_id: str,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
if sampling_params is None:
sampling_params = SamplingParams()
if content_has_media(content):
raise NotImplementedError("Media is not supported")
@ -188,7 +190,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
@ -197,6 +199,8 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
if sampling_params is None:
sampling_params = SamplingParams()
if tool_prompt_format:
warnings.warn("tool_prompt_format is not supported by NVIDIA NIM, ignoring", stacklevel=2)

View file

@ -4,13 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from typing import AsyncGenerator, List, Optional, Union
import httpx
from ollama import AsyncClient
from llama_stack import logcat
from llama_stack.apis.common.content_types import (
ImageContentItem,
InterleavedContent,
@ -35,6 +34,7 @@ from llama_stack.apis.inference import (
ToolPromptFormat,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
@ -59,7 +59,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .models import model_entries
log = logging.getLogger(__name__)
logger = get_logger(name=__name__, category="inference")
class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
@ -72,7 +72,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
return AsyncClient(host=self.url)
async def initialize(self) -> None:
log.info(f"checking connectivity to Ollama at `{self.url}`...")
logger.info(f"checking connectivity to Ollama at `{self.url}`...")
try:
await self.client.ps()
except httpx.ConnectError as e:
@ -90,11 +90,13 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
self,
model_id: str,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
request = CompletionRequest(
model=model.provider_resource_id,
@ -145,7 +147,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
@ -154,6 +156,8 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
model=model.provider_resource_id,
@ -210,7 +214,8 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
"options": sampling_options,
"stream": request.stream,
}
logcat.debug("inference", f"params to ollama: {params}")
logger.debug(f"params to ollama: {params}")
return params
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
@ -286,7 +291,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
async def register_model(self, model: Model) -> Model:
model = await self.register_helper.register_model(model)
if model.model_type == ModelType.embedding:
log.info(f"Pulling embedding model `{model.provider_resource_id}` if necessary...")
logger.info(f"Pulling embedding model `{model.provider_resource_id}` if necessary...")
await self.client.pull(model.provider_resource_id)
response = await self.client.list()
else:

View file

@ -4,12 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import AsyncGenerator, List, Optional
from typing import Any, AsyncGenerator, Dict, List, Optional
from llama_stack_client import LlamaStackClient
from llama_stack_client import AsyncLlamaStackClient
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import (
ChatCompletionResponse,
ChatCompletionResponseStreamChunk,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
@ -24,6 +26,7 @@ from llama_stack.apis.inference import (
ToolPromptFormat,
)
from llama_stack.apis.models import Model
from llama_stack.distribution.library_client import convert_pydantic_to_json_value, convert_to_pydantic
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from .config import PassthroughImplConfig
@ -46,7 +49,7 @@ class PassthroughInferenceAdapter(Inference):
async def register_model(self, model: Model) -> Model:
return model
def _get_client(self) -> LlamaStackClient:
def _get_client(self) -> AsyncLlamaStackClient:
passthrough_url = None
passthrough_api_key = None
provider_data = None
@ -71,7 +74,7 @@ class PassthroughInferenceAdapter(Inference):
)
passthrough_api_key = provider_data.passthrough_api_key
return LlamaStackClient(
return AsyncLlamaStackClient(
base_url=passthrough_url,
api_key=passthrough_api_key,
provider_data=provider_data,
@ -81,15 +84,17 @@ class PassthroughInferenceAdapter(Inference):
self,
model_id: str,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
client = self._get_client()
model = await self.model_store.get_model(model_id)
params = {
request_params = {
"model_id": model.provider_resource_id,
"content": content,
"sampling_params": sampling_params,
@ -98,16 +103,19 @@ class PassthroughInferenceAdapter(Inference):
"logprobs": logprobs,
}
params = {key: value for key, value in params.items() if value is not None}
request_params = {key: value for key, value in request_params.items() if value is not None}
# cast everything to json dict
json_params = self.cast_value_to_json_dict(request_params)
# only pass through the not None params
return client.inference.completion(**params)
return await client.inference.completion(**json_params)
async def chat_completion(
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
sampling_params: Optional[SamplingParams] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = None,
@ -116,10 +124,16 @@ class PassthroughInferenceAdapter(Inference):
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
client = self._get_client()
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
params = {
# TODO: revisit this remove tool_calls from messages logic
for message in messages:
if hasattr(message, "tool_calls"):
message.tool_calls = None
request_params = {
"model_id": model.provider_resource_id,
"messages": messages,
"sampling_params": sampling_params,
@ -131,10 +145,39 @@ class PassthroughInferenceAdapter(Inference):
"logprobs": logprobs,
}
params = {key: value for key, value in params.items() if value is not None}
# only pass through the not None params
return client.inference.chat_completion(**params)
request_params = {key: value for key, value in request_params.items() if value is not None}
# cast everything to json dict
json_params = self.cast_value_to_json_dict(request_params)
if stream:
return self._stream_chat_completion(json_params)
else:
return await self._nonstream_chat_completion(json_params)
async def _nonstream_chat_completion(self, json_params: Dict[str, Any]) -> ChatCompletionResponse:
client = self._get_client()
response = await client.inference.chat_completion(**json_params)
response = response.to_dict()
# temporary hack to remove the metrics from the response
response["metrics"] = []
return convert_to_pydantic(ChatCompletionResponse, response)
async def _stream_chat_completion(self, json_params: Dict[str, Any]) -> AsyncGenerator:
client = self._get_client()
stream_response = await client.inference.chat_completion(**json_params)
async for chunk in stream_response:
chunk = chunk.to_dict()
# temporary hack to remove the metrics from the response
chunk["metrics"] = []
chunk = convert_to_pydantic(ChatCompletionResponseStreamChunk, chunk)
yield chunk
async def embeddings(
self,
@ -147,10 +190,29 @@ class PassthroughInferenceAdapter(Inference):
client = self._get_client()
model = await self.model_store.get_model(model_id)
return client.inference.embeddings(
return await client.inference.embeddings(
model_id=model.provider_resource_id,
contents=contents,
text_truncation=text_truncation,
output_dimension=output_dimension,
task_type=task_type,
)
def cast_value_to_json_dict(self, request_params: Dict[str, Any]) -> Dict[str, Any]:
json_params = {}
for key, value in request_params.items():
json_input = convert_pydantic_to_json_value(value)
if isinstance(json_input, dict):
json_input = {k: v for k, v in json_input.items() if v is not None}
elif isinstance(json_input, list):
json_input = [x for x in json_input if x is not None]
new_input = []
for x in json_input:
if isinstance(x, dict):
x = {k: v for k, v in x.items() if v is not None}
new_input.append(x)
json_input = new_input
json_params[key] = json_input
return json_params

Some files were not shown because too many files have changed in this diff Show more