forked from phoenix-oss/llama-stack-mirror

commit a54d757ade ("merge")
197 changed files with 9392 additions and 3089 deletions
llama_stack/apis/agents/agents.py

@@ -36,7 +36,6 @@ from llama_stack.apis.inference import (
 )
 from llama_stack.apis.safety import SafetyViolation
 from llama_stack.apis.tools import ToolDef
 from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
 from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
 
 
@@ -189,13 +188,11 @@ class AgentToolGroupWithArgs(BaseModel):
     args: Dict[str, Any]
 
 
-AgentToolGroup = register_schema(
-    Union[
-        str,
-        AgentToolGroupWithArgs,
-    ],
-    name="AgentTool",
-)
+AgentToolGroup = Union[
+    str,
+    AgentToolGroupWithArgs,
+]
+register_schema(AgentToolGroup, name="AgentTool")
 
 
 class AgentConfigCommon(BaseModel):

@@ -312,20 +309,18 @@ class AgentTurnResponseTurnAwaitingInputPayload(BaseModel):
     turn: Turn
 
 
-AgentTurnResponseEventPayload = register_schema(
-    Annotated[
-        Union[
-            AgentTurnResponseStepStartPayload,
-            AgentTurnResponseStepProgressPayload,
-            AgentTurnResponseStepCompletePayload,
-            AgentTurnResponseTurnStartPayload,
-            AgentTurnResponseTurnCompletePayload,
-            AgentTurnResponseTurnAwaitingInputPayload,
-        ],
-        Field(discriminator="event_type"),
-    ],
-    name="AgentTurnResponseEventPayload",
-)
+AgentTurnResponseEventPayload = Annotated[
+    Union[
+        AgentTurnResponseStepStartPayload,
+        AgentTurnResponseStepProgressPayload,
+        AgentTurnResponseStepCompletePayload,
+        AgentTurnResponseTurnStartPayload,
+        AgentTurnResponseTurnCompletePayload,
+        AgentTurnResponseTurnAwaitingInputPayload,
+    ],
+    Field(discriminator="event_type"),
+]
+register_schema(AgentTurnResponseEventPayload, name="AgentTurnResponseEventPayload")
 
 
 @json_schema_type

@@ -387,7 +382,6 @@ class AgentStepResponse(BaseModel):
 
 
 @runtime_checkable
 @trace_protocol
 class Agents(Protocol):
     """Agents API for creating and interacting with agentic systems.
 
@@ -399,7 +393,7 @@ class Agents(Protocol):
     - Agents can also use Memory to retrieve information from knowledge bases. See the RAG Tool and Vector IO APIs for more details.
     """
 
-    @webmethod(route="/agents", method="POST")
+    @webmethod(route="/agents", method="POST", descriptive_name="create_agent")
     async def create_agent(
         self,
         agent_config: AgentConfig,

@@ -411,7 +405,9 @@ class Agents(Protocol):
         """
         ...
 
-    @webmethod(route="/agents/{agent_id}/session/{session_id}/turn", method="POST")
+    @webmethod(
+        route="/agents/{agent_id}/session/{session_id}/turn", method="POST", descriptive_name="create_agent_turn"
+    )
     async def create_agent_turn(
         self,
         agent_id: str,

@@ -443,6 +439,7 @@ class Agents(Protocol):
     @webmethod(
         route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume",
         method="POST",
+        descriptive_name="resume_agent_turn",
     )
     async def resume_agent_turn(
         self,

@@ -505,7 +502,7 @@ class Agents(Protocol):
         """
         ...
 
-    @webmethod(route="/agents/{agent_id}/session", method="POST")
+    @webmethod(route="/agents/{agent_id}/session", method="POST", descriptive_name="create_agent_session")
     async def create_agent_session(
         self,
         agent_id: str,
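A pattern note for the hunks above (and the similar ones in the files below): aliases previously assigned from the return value of register_schema(...) become plain Annotated aliases, with registration performed as a separate call afterwards, leaving the alias a real type that static checkers can follow. A minimal sketch of the before/after, with illustrative Cat/Dog models:

```python
from typing import Literal, Union

from pydantic import BaseModel, Field
from typing_extensions import Annotated

from llama_stack.schema_utils import register_schema


class Cat(BaseModel):
    type: Literal["cat"] = "cat"


class Dog(BaseModel):
    type: Literal["dog"] = "dog"


# Old pattern: the alias was whatever register_schema() returned, which
# type checkers could not treat as a type:
#   Pet = register_schema(Annotated[...], name="Pet")
# New pattern: the alias is a plain Annotated type, and registration is a
# separate side-effecting call, exactly as in the hunks above.
Pet = Annotated[Union[Cat, Dog], Field(discriminator="type")]
register_schema(Pet, name="Pet")
```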
llama_stack/apis/common/content_types.py

@@ -63,19 +63,15 @@ class TextContentItem(BaseModel):
 
 
 # other modalities can be added here
-InterleavedContentItem = register_schema(
-    Annotated[
-        Union[ImageContentItem, TextContentItem],
-        Field(discriminator="type"),
-    ],
-    name="InterleavedContentItem",
-)
+InterleavedContentItem = Annotated[
+    Union[ImageContentItem, TextContentItem],
+    Field(discriminator="type"),
+]
+register_schema(InterleavedContentItem, name="InterleavedContentItem")
 
 # accept a single "str" as a special case since it is common
-InterleavedContent = register_schema(
-    Union[str, InterleavedContentItem, List[InterleavedContentItem]],
-    name="InterleavedContent",
-)
+InterleavedContent = Union[str, InterleavedContentItem, List[InterleavedContentItem]]
+register_schema(InterleavedContent, name="InterleavedContent")
 
 
 @json_schema_type

@@ -109,10 +105,8 @@ class ToolCallDelta(BaseModel):
 
 
 # streaming completions send a stream of ContentDeltas
-ContentDelta = register_schema(
-    Annotated[
-        Union[TextDelta, ImageDelta, ToolCallDelta],
-        Field(discriminator="type"),
-    ],
-    name="ContentDelta",
-)
+ContentDelta = Annotated[
+    Union[TextDelta, ImageDelta, ToolCallDelta],
+    Field(discriminator="type"),
+]
+register_schema(ContentDelta, name="ContentDelta")
llama_stack/apis/common/type_system.py

@@ -72,24 +72,22 @@ class DialogType(BaseModel):
     type: Literal["dialog"] = "dialog"
 
 
-ParamType = register_schema(
-    Annotated[
-        Union[
-            StringType,
-            NumberType,
-            BooleanType,
-            ArrayType,
-            ObjectType,
-            JsonType,
-            UnionType,
-            ChatCompletionInputType,
-            CompletionInputType,
-            AgentTurnInputType,
-        ],
-        Field(discriminator="type"),
-    ],
-    name="ParamType",
-)
+ParamType = Annotated[
+    Union[
+        StringType,
+        NumberType,
+        BooleanType,
+        ArrayType,
+        ObjectType,
+        JsonType,
+        UnionType,
+        ChatCompletionInputType,
+        CompletionInputType,
+        AgentTurnInputType,
+    ],
+    Field(discriminator="type"),
+]
+register_schema(ParamType, name="ParamType")
 
 """
 # TODO: recursive definition of ParamType in these containers
llama_stack/apis/datasets/datasets.py

@@ -84,13 +84,11 @@ class RowsDataSource(BaseModel):
     rows: List[Dict[str, Any]]
 
 
-DataSource = register_schema(
-    Annotated[
-        Union[URIDataSource, RowsDataSource],
-        Field(discriminator="type"),
-    ],
-    name="DataSource",
-)
+DataSource = Annotated[
+    Union[URIDataSource, RowsDataSource],
+    Field(discriminator="type"),
+]
+register_schema(DataSource, name="DataSource")
 
 
 class CommonDatasetFields(BaseModel):

@@ -121,8 +119,6 @@ class Dataset(CommonDatasetFields, Resource):
 
 class DatasetInput(CommonDatasetFields, BaseModel):
     dataset_id: str
     provider_id: Optional[str] = None
     provider_dataset_id: Optional[str] = None
 
 
 class ListDatasetsResponse(BaseModel):
llama_stack/apis/eval/eval.py (new file, 144 lines)

@@ -0,0 +1,144 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any, Dict, List, Literal, Optional, Protocol, Union
+
+from pydantic import BaseModel, Field
+from typing_extensions import Annotated
+
+from llama_stack.apis.agents import AgentConfig
+from llama_stack.apis.common.job_types import Job
+from llama_stack.apis.inference import SamplingParams, SystemMessage
+from llama_stack.apis.scoring import ScoringResult
+from llama_stack.apis.scoring_functions import ScoringFnParams
+from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
+
+
+@json_schema_type
+class ModelCandidate(BaseModel):
+    """A model candidate for evaluation.
+
+    :param model: The model ID to evaluate.
+    :param sampling_params: The sampling parameters for the model.
+    :param system_message: (Optional) The system message providing instructions or context to the model.
+    """
+
+    type: Literal["model"] = "model"
+    model: str
+    sampling_params: SamplingParams
+    system_message: Optional[SystemMessage] = None
+
+
+@json_schema_type
+class AgentCandidate(BaseModel):
+    """An agent candidate for evaluation.
+
+    :param config: The configuration for the agent candidate.
+    """
+
+    type: Literal["agent"] = "agent"
+    config: AgentConfig
+
+
+EvalCandidate = Annotated[Union[ModelCandidate, AgentCandidate], Field(discriminator="type")]
+register_schema(EvalCandidate, name="EvalCandidate")
+
+
+@json_schema_type
+class BenchmarkConfig(BaseModel):
+    """A benchmark configuration for evaluation.
+
+    :param eval_candidate: The candidate to evaluate.
+    :param scoring_params: Map between scoring function id and parameters for each scoring function you want to run
+    :param num_examples: (Optional) The number of examples to evaluate. If not provided, all examples in the dataset will be evaluated
+    """
+
+    eval_candidate: EvalCandidate
+    scoring_params: Dict[str, ScoringFnParams] = Field(
+        description="Map between scoring function id and parameters for each scoring function you want to run",
+        default_factory=dict,
+    )
+    num_examples: Optional[int] = Field(
+        description="Number of examples to evaluate (useful for testing), if not provided, all examples in the dataset will be evaluated",
+        default=None,
+    )
+    # we could optionally add any specific dataset config here
+
+
+@json_schema_type
+class EvaluateResponse(BaseModel):
+    """The response from an evaluation.
+
+    :param generations: The generations from the evaluation.
+    :param scores: The scores from the evaluation.
+    """
+
+    generations: List[Dict[str, Any]]
+    # each key in the dict is a scoring function name
+    scores: Dict[str, ScoringResult]
+
+
+class Eval(Protocol):
+    """Llama Stack Evaluation API for running evaluations on model and agent candidates."""
+
+    @webmethod(route="/eval/benchmarks/{benchmark_id}/jobs", method="POST")
+    async def run_eval(
+        self,
+        benchmark_id: str,
+        benchmark_config: BenchmarkConfig,
+    ) -> Job:
+        """Run an evaluation on a benchmark.
+
+        :param benchmark_id: The ID of the benchmark to run the evaluation on.
+        :param benchmark_config: The configuration for the benchmark.
+        :return: The job that was created to run the evaluation.
+        """
+
+    @webmethod(route="/eval/benchmarks/{benchmark_id}/evaluations", method="POST")
+    async def evaluate_rows(
+        self,
+        benchmark_id: str,
+        input_rows: List[Dict[str, Any]],
+        scoring_functions: List[str],
+        benchmark_config: BenchmarkConfig,
+    ) -> EvaluateResponse:
+        """Evaluate a list of rows on a benchmark.
+
+        :param benchmark_id: The ID of the benchmark to run the evaluation on.
+        :param input_rows: The rows to evaluate.
+        :param scoring_functions: The scoring functions to use for the evaluation.
+        :param benchmark_config: The configuration for the benchmark.
+        :return: EvaluateResponse object containing generations and scores
+        """
+
+    @webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET")
+    async def job_status(self, benchmark_id: str, job_id: str) -> Job:
+        """Get the status of a job.
+
+        :param benchmark_id: The ID of the benchmark to run the evaluation on.
+        :param job_id: The ID of the job to get the status of.
+        :return: The status of the evaluation job.
+        """
+        ...
+
+    @webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="DELETE")
+    async def job_cancel(self, benchmark_id: str, job_id: str) -> None:
+        """Cancel a job.
+
+        :param benchmark_id: The ID of the benchmark to run the evaluation on.
+        :param job_id: The ID of the job to cancel.
+        """
+        ...
+
+    @webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result", method="GET")
+    async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse:
+        """Get the result of a job.
+
+        :param benchmark_id: The ID of the benchmark to run the evaluation on.
+        :param job_id: The ID of the job to get the result of.
+        :return: The result of the job.
+        """
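For orientation, a sketch of driving the new Eval protocol end to end; the model ID and benchmark ID are illustrative, and `eval_impl` stands in for any implementation of the protocol (neither is part of this diff):

```python
from llama_stack.apis.eval.eval import BenchmarkConfig, ModelCandidate
from llama_stack.apis.inference import SamplingParams

config = BenchmarkConfig(
    eval_candidate=ModelCandidate(
        model="meta-llama/Llama-3.1-8B-Instruct",  # assumed model ID
        sampling_params=SamplingParams(),
    ),
    num_examples=10,  # evaluate a small slice while testing
)

# An Eval implementation would then be driven roughly like this
# (Job carries the identifier used to poll and fetch results):
#   job = await eval_impl.run_eval(benchmark_id="my-benchmark", benchmark_config=config)
#   status = await eval_impl.job_status(benchmark_id="my-benchmark", job_id=job.job_id)
#   result = await eval_impl.job_result(benchmark_id="my-benchmark", job_id=job.job_id)
```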
llama_stack/apis/inference/inference.py

@@ -144,18 +144,16 @@ class CompletionMessage(BaseModel):
     tool_calls: Optional[List[ToolCall]] = Field(default_factory=list)
 
 
-Message = register_schema(
-    Annotated[
-        Union[
-            UserMessage,
-            SystemMessage,
-            ToolResponseMessage,
-            CompletionMessage,
-        ],
-        Field(discriminator="role"),
-    ],
-    name="Message",
-)
+Message = Annotated[
+    Union[
+        UserMessage,
+        SystemMessage,
+        ToolResponseMessage,
+        CompletionMessage,
+    ],
+    Field(discriminator="role"),
+]
+register_schema(Message, name="Message")
 
 
 @json_schema_type

@@ -263,13 +261,11 @@ class GrammarResponseFormat(BaseModel):
     bnf: Dict[str, Any]
 
 
-ResponseFormat = register_schema(
-    Annotated[
-        Union[JsonSchemaResponseFormat, GrammarResponseFormat],
-        Field(discriminator="type"),
-    ],
-    name="ResponseFormat",
-)
+ResponseFormat = Annotated[
+    Union[JsonSchemaResponseFormat, GrammarResponseFormat],
+    Field(discriminator="type"),
+]
+register_schema(ResponseFormat, name="ResponseFormat")
 
 
 # This is an internally used class
llama_stack/apis/inspect/inspect.py

@@ -24,17 +24,6 @@ class HealthInfo(BaseModel):
     # TODO: add a provider level status
 
 
-@json_schema_type
-class ProviderInfo(BaseModel):
-    api: str
-    provider_id: str
-    provider_type: str
-
-
-class ListProvidersResponse(BaseModel):
-    data: List[ProviderInfo]
-
-
 @json_schema_type
 class VersionInfo(BaseModel):
     version: str

@@ -46,9 +35,6 @@ class ListRoutesResponse(BaseModel):
 
 @runtime_checkable
 class Inspect(Protocol):
-    @webmethod(route="/inspect/providers", method="GET")
-    async def list_providers(self) -> ListProvidersResponse: ...
-
     @webmethod(route="/inspect/routes", method="GET")
     async def list_routes(self) -> ListRoutesResponse: ...
 
llama_stack/apis/post_training/post_training.py

@@ -6,7 +6,7 @@
 
 from datetime import datetime
 from enum import Enum
-from typing import Any, Dict, List, Literal, Optional, Protocol
+from typing import Any, Dict, List, Literal, Optional, Protocol, Union
 
 from pydantic import BaseModel, Field
 from typing_extensions import Annotated

@@ -88,10 +88,8 @@ class QATFinetuningConfig(BaseModel):
     group_size: int
 
 
-AlgorithmConfig = register_schema(
-    Annotated[LoraFinetuningConfig | QATFinetuningConfig, Field(discriminator="type")],
-    name="AlgorithmConfig",
-)
+AlgorithmConfig = Annotated[Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")]
+register_schema(AlgorithmConfig, name="AlgorithmConfig")
 
 
 @json_schema_type

@@ -184,7 +182,7 @@ class PostTraining(Protocol):
             description="Model descriptor from `llama model list`",
         ),
         checkpoint_dir: Optional[str] = None,
-        algorithm_config: Optional[LoraFinetuningConfig | QATFinetuningConfig] = None,
+        algorithm_config: Optional[AlgorithmConfig] = None,
     ) -> PostTrainingJob: ...
 
     @webmethod(route="/post-training/preference-optimize", method="POST")
llama_stack/apis/scoring_functions/scoring_functions.py (new file, 149 lines)

@@ -0,0 +1,149 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from enum import Enum
+from typing import (
+    Any,
+    Dict,
+    List,
+    Literal,
+    Optional,
+    Protocol,
+    Union,
+    runtime_checkable,
+)
+
+from pydantic import BaseModel, Field
+from typing_extensions import Annotated
+
+from llama_stack.apis.common.type_system import ParamType
+from llama_stack.apis.resource import Resource, ResourceType
+from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
+
+
+# Perhaps more structure can be imposed on these functions. Maybe they could be associated
+# with standard metrics so they can be rolled up?
+@json_schema_type
+class ScoringFnParamsType(Enum):
+    llm_as_judge = "llm_as_judge"
+    regex_parser = "regex_parser"
+    basic = "basic"
+
+
+@json_schema_type
+class AggregationFunctionType(Enum):
+    average = "average"
+    weighted_average = "weighted_average"
+    median = "median"
+    categorical_count = "categorical_count"
+    accuracy = "accuracy"
+
+
+@json_schema_type
+class LLMAsJudgeScoringFnParams(BaseModel):
+    type: Literal[ScoringFnParamsType.llm_as_judge.value] = ScoringFnParamsType.llm_as_judge.value
+    judge_model: str
+    prompt_template: Optional[str] = None
+    judge_score_regexes: Optional[List[str]] = Field(
+        description="Regexes to extract the answer from generated response",
+        default_factory=list,
+    )
+    aggregation_functions: Optional[List[AggregationFunctionType]] = Field(
+        description="Aggregation functions to apply to the scores of each row",
+        default_factory=list,
+    )
+
+
+@json_schema_type
+class RegexParserScoringFnParams(BaseModel):
+    type: Literal[ScoringFnParamsType.regex_parser.value] = ScoringFnParamsType.regex_parser.value
+    parsing_regexes: Optional[List[str]] = Field(
+        description="Regex to extract the answer from generated response",
+        default_factory=list,
+    )
+    aggregation_functions: Optional[List[AggregationFunctionType]] = Field(
+        description="Aggregation functions to apply to the scores of each row",
+        default_factory=list,
+    )
+
+
+@json_schema_type
+class BasicScoringFnParams(BaseModel):
+    type: Literal[ScoringFnParamsType.basic.value] = ScoringFnParamsType.basic.value
+    aggregation_functions: Optional[List[AggregationFunctionType]] = Field(
+        description="Aggregation functions to apply to the scores of each row",
+        default_factory=list,
+    )
+
+
+ScoringFnParams = Annotated[
+    Union[
+        LLMAsJudgeScoringFnParams,
+        RegexParserScoringFnParams,
+        BasicScoringFnParams,
+    ],
+    Field(discriminator="type"),
+]
+register_schema(ScoringFnParams, name="ScoringFnParams")
+
+
+class CommonScoringFnFields(BaseModel):
+    description: Optional[str] = None
+    metadata: Dict[str, Any] = Field(
+        default_factory=dict,
+        description="Any additional metadata for this definition",
+    )
+    return_type: ParamType = Field(
+        description="The return type of the deterministic function",
+    )
+    params: Optional[ScoringFnParams] = Field(
+        description="The parameters for the scoring function for benchmark eval, these can be overridden for app eval",
+        default=None,
+    )
+
+
+@json_schema_type
+class ScoringFn(CommonScoringFnFields, Resource):
+    type: Literal[ResourceType.scoring_function.value] = ResourceType.scoring_function.value
+
+    @property
+    def scoring_fn_id(self) -> str:
+        return self.identifier
+
+    @property
+    def provider_scoring_fn_id(self) -> str:
+        return self.provider_resource_id
+
+
+class ScoringFnInput(CommonScoringFnFields, BaseModel):
+    scoring_fn_id: str
+    provider_id: Optional[str] = None
+    provider_scoring_fn_id: Optional[str] = None
+
+
+class ListScoringFunctionsResponse(BaseModel):
+    data: List[ScoringFn]
+
+
+@runtime_checkable
+class ScoringFunctions(Protocol):
+    @webmethod(route="/scoring-functions", method="GET")
+    async def list_scoring_functions(self) -> ListScoringFunctionsResponse: ...
+
+    @webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="GET")
+    async def get_scoring_function(self, scoring_fn_id: str, /) -> ScoringFn: ...
+
+    @webmethod(route="/scoring-functions", method="POST")
+    async def register_scoring_function(
+        self,
+        scoring_fn_id: str,
+        description: str,
+        return_type: ParamType,
+        provider_scoring_fn_id: Optional[str] = None,
+        provider_id: Optional[str] = None,
+        params: Optional[ScoringFnParams] = None,
+    ) -> None: ...
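A quick illustration of the ScoringFnParams discriminated union above; the judge model ID and regex are illustrative values:

```python
from llama_stack.apis.scoring_functions.scoring_functions import (
    AggregationFunctionType,
    LLMAsJudgeScoringFnParams,
)

# The "type" discriminator defaults to "llm_as_judge", so pydantic can route
# a serialized dict to the right member of the ScoringFnParams union.
params = LLMAsJudgeScoringFnParams(
    judge_model="meta-llama/Llama-3.1-70B-Instruct",  # assumed judge model ID
    judge_score_regexes=[r"Score:\s*(\d+)"],
    aggregation_functions=[AggregationFunctionType.average],
)
```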
llama_stack/apis/telemetry/telemetry.py

@@ -146,16 +146,14 @@ class SpanEndPayload(BaseModel):
     status: SpanStatus
 
 
-StructuredLogPayload = register_schema(
-    Annotated[
-        Union[
-            SpanStartPayload,
-            SpanEndPayload,
-        ],
-        Field(discriminator="type"),
-    ],
-    name="StructuredLogPayload",
-)
+StructuredLogPayload = Annotated[
+    Union[
+        SpanStartPayload,
+        SpanEndPayload,
+    ],
+    Field(discriminator="type"),
+]
+register_schema(StructuredLogPayload, name="StructuredLogPayload")
 
 
 @json_schema_type

@@ -164,17 +162,15 @@ class StructuredLogEvent(EventCommon):
     payload: StructuredLogPayload
 
 
-Event = register_schema(
-    Annotated[
-        Union[
-            UnstructuredLogEvent,
-            MetricEvent,
-            StructuredLogEvent,
-        ],
-        Field(discriminator="type"),
-    ],
-    name="Event",
-)
+Event = Annotated[
+    Union[
+        UnstructuredLogEvent,
+        MetricEvent,
+        StructuredLogEvent,
+    ],
+    Field(discriminator="type"),
+]
+register_schema(Event, name="Event")
 
 
 @json_schema_type
llama_stack/apis/tools/rag_tool.py

@@ -17,6 +17,15 @@ from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
 
 @json_schema_type
 class RAGDocument(BaseModel):
+    """
+    A document to be used for document ingestion in the RAG Tool.
+
+    :param document_id: The unique identifier for the document.
+    :param content: The content of the document.
+    :param mime_type: The MIME type of the document.
+    :param metadata: Additional metadata for the document.
+    """
+
     document_id: str
     content: InterleavedContent | URL
     mime_type: str | None = None

@@ -49,16 +58,14 @@ class LLMRAGQueryGeneratorConfig(BaseModel):
     template: str
 
 
-RAGQueryGeneratorConfig = register_schema(
-    Annotated[
-        Union[
-            DefaultRAGQueryGeneratorConfig,
-            LLMRAGQueryGeneratorConfig,
-        ],
-        Field(discriminator="type"),
-    ],
-    name="RAGQueryGeneratorConfig",
-)
+RAGQueryGeneratorConfig = Annotated[
+    Union[
+        DefaultRAGQueryGeneratorConfig,
+        LLMRAGQueryGeneratorConfig,
+    ],
+    Field(discriminator="type"),
+]
+register_schema(RAGQueryGeneratorConfig, name="RAGQueryGeneratorConfig")
 
 
 @json_schema_type
llama_stack/apis/tools/tools.py

@@ -69,7 +69,7 @@ class ToolGroup(Resource):
 
 @json_schema_type
 class ToolInvocationResult(BaseModel):
-    content: InterleavedContent
+    content: Optional[InterleavedContent] = None
     error_message: Optional[str] = None
     error_code: Optional[int] = None
     metadata: Optional[Dict[str, Any]] = None

@@ -140,9 +140,9 @@ class SpecialToolGroup(Enum):
 @runtime_checkable
 @trace_protocol
 class ToolRuntime(Protocol):
-    tool_store: ToolStore
+    tool_store: ToolStore | None = None
 
-    rag_tool: RAGToolRuntime
+    rag_tool: RAGToolRuntime | None = None
 
     # TODO: This needs to be renamed once OPEN API generator name conflict issue is fixed.
     @webmethod(route="/tool-runtime/list-tools", method="GET")
llama_stack/apis/vector_io/vector_io.py

@@ -36,7 +36,7 @@ class VectorDBStore(Protocol):
 @runtime_checkable
 @trace_protocol
 class VectorIO(Protocol):
-    vector_db_store: VectorDBStore
+    vector_db_store: VectorDBStore | None = None
 
     # this will just block now until chunks are inserted, but it should
     # probably return a Job instance which can be polled for completion
llama_stack/cli/download.py

@@ -404,7 +404,7 @@ def _download_from_manifest(manifest_file: str, max_concurrent_downloads: int):
         d = json.load(f)
         manifest = Manifest(**d)
 
-    if datetime.now(timezone.utc) > manifest.expires_on:
+    if datetime.now(timezone.utc) > manifest.expires_on.astimezone(timezone.utc):
        raise ValueError(f"Manifest URLs have expired on {manifest.expires_on}")
 
     console = Console()
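The fix above guards the comparison against a timezone-naive `expires_on` value; comparing naive and aware datetimes raises, while `astimezone()` makes both sides aware. A standalone sketch with an illustrative timestamp:

```python
from datetime import datetime, timezone

naive_expiry = datetime(2030, 1, 1)  # illustrative naive timestamp

try:
    datetime.now(timezone.utc) > naive_expiry
except TypeError as e:
    # "can't compare offset-naive and offset-aware datetimes"
    print(e)

# astimezone() on a naive datetime assumes local time and attaches a zone,
# so the comparison is between two aware values and succeeds.
aware_expiry = naive_expiry.astimezone(timezone.utc)
print(datetime.now(timezone.utc) > aware_expiry)
```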
llama_stack/distribution/access_control.py (new file, 86 lines)

@@ -0,0 +1,86 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any, Dict, Optional
+
+from llama_stack.distribution.datatypes import AccessAttributes
+from llama_stack.log import get_logger
+
+logger = get_logger(__name__, category="core")
+
+
+def check_access(
+    obj_identifier: str,
+    obj_attributes: Optional[AccessAttributes],
+    user_attributes: Optional[Dict[str, Any]] = None,
+) -> bool:
+    """Check if the current user has access to the given object, based on access attributes.
+
+    Access control algorithm:
+    1. If the resource has no access_attributes, access is GRANTED to all authenticated users
+    2. If the user has no attributes, access is DENIED to any object with access_attributes defined
+    3. For each attribute category in the resource's access_attributes:
+       a. If the user lacks that category, access is DENIED
+       b. If the user has the category but none of the required values, access is DENIED
+       c. If the user has at least one matching value in each required category, access is GRANTED
+
+    Example:
+        # Resource requires:
+        access_attributes = AccessAttributes(
+            roles=["admin", "data-scientist"],
+            teams=["ml-team"]
+        )
+
+        # User has:
+        user_attributes = {
+            "roles": ["data-scientist", "engineer"],
+            "teams": ["ml-team", "infra-team"],
+            "projects": ["llama-3"]
+        }
+
+        # Result: Access GRANTED
+        # - User has the "data-scientist" role (matches one of the required roles)
+        # - AND user is part of the "ml-team" (matches the required team)
+        # - The extra "projects" attribute is ignored
+
+    Args:
+        obj_identifier: The identifier of the resource object to check access for
+        obj_attributes: The access attributes of the resource object
+        user_attributes: The attributes of the current user
+
+    Returns:
+        bool: True if access is granted, False if denied
+    """
+    # If object has no access attributes, allow access by default
+    if not obj_attributes:
+        return True
+
+    # If no user attributes, deny access to objects with access control
+    if not user_attributes:
+        return False
+
+    dict_attribs = obj_attributes.model_dump(exclude_none=True)
+    if not dict_attribs:
+        return True
+
+    # Check each attribute category (requires ALL categories to match)
+    # TODO: formalize this into a proper ABAC policy
+    for attr_key, required_values in dict_attribs.items():
+        user_values = user_attributes.get(attr_key, [])
+
+        if not user_values:
+            logger.debug(f"Access denied to {obj_identifier}: missing required attribute category '{attr_key}'")
+            return False
+
+        if not any(val in user_values for val in required_values):
+            logger.debug(
+                f"Access denied to {obj_identifier}: "
+                f"no match for attribute '{attr_key}', required one of {required_values}"
+            )
+            return False
+
+    logger.debug(f"Access granted to {obj_identifier}")
+    return True
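To make the matching rules concrete, a short sketch exercising check_access directly; the identifiers and attribute values are illustrative:

```python
from llama_stack.distribution.access_control import check_access
from llama_stack.distribution.datatypes import AccessAttributes

# Resource requires an admin-or-data-scientist role AND membership in ml-team.
attrs = AccessAttributes(roles=["admin", "data-scientist"], teams=["ml-team"])

# GRANTED: one matching role, and the required team (OR within a category,
# AND across categories).
assert check_access("m1", attrs, {"roles": ["data-scientist"], "teams": ["ml-team"]})

# DENIED: the role matches, but the "teams" category has no overlap.
assert not check_access("m1", attrs, {"roles": ["admin"], "teams": ["infra-team"]})

# GRANTED: resources without access_attributes are visible to everyone.
assert check_access("m2", None, None)
```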
llama_stack/distribution/build_container.sh

@@ -90,6 +90,7 @@ RUN apt-get update && apt-get install -y \
        procps psmisc lsof \
        traceroute \
        bubblewrap \
+       gcc \
        && rm -rf /var/lib/apt/lists/*
 
 ENV UV_SYSTEM_PYTHON=1

@@ -235,7 +236,7 @@ image_tag="$image_name:$version_tag"
 # Detect platform architecture
 ARCH=$(uname -m)
 if [ -n "$BUILD_PLATFORM" ]; then
-  CLI_ARGS+=("--platform $BUILD_PLATFORM")
+  CLI_ARGS+=("--platform" "$BUILD_PLATFORM")
 elif [ "$ARCH" = "arm64" ] || [ "$ARCH" = "aarch64" ]; then
   CLI_ARGS+=("--platform" "linux/arm64")
 elif [ "$ARCH" = "x86_64" ]; then
llama_stack/distribution/datatypes.py

@@ -13,6 +13,7 @@ from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Dataset, DatasetInput
 from llama_stack.apis.inference import Inference
 from llama_stack.apis.models import Model, ModelInput
+from llama_stack.apis.resource import Resource
 from llama_stack.apis.safety import Safety
 from llama_stack.apis.shields import Shield, ShieldInput
 from llama_stack.apis.tools import Tool, ToolGroup, ToolGroupInput, ToolRuntime

@@ -28,6 +29,115 @@ LLAMA_STACK_RUN_CONFIG_VERSION = "2"
 RoutingKey = Union[str, List[str]]
 
 
+class AccessAttributes(BaseModel):
+    """Structured representation of user attributes for access control.
+
+    This model defines a structured approach to representing user attributes
+    with common standard categories for access control.
+
+    Standard attribute categories include:
+    - roles: Role-based attributes (e.g., admin, data-scientist)
+    - teams: Team-based attributes (e.g., ml-team, infra-team)
+    - projects: Project access attributes (e.g., llama-3, customer-insights)
+    - namespaces: Namespace-based access control for resource isolation
+    """
+
+    # Standard attribute categories - the minimal set we need now
+    roles: Optional[List[str]] = Field(
+        default=None, description="Role-based attributes (e.g., 'admin', 'data-scientist', 'user')"
+    )
+
+    teams: Optional[List[str]] = Field(default=None, description="Team-based attributes (e.g., 'ml-team', 'nlp-team')")
+
+    projects: Optional[List[str]] = Field(
+        default=None, description="Project-based access attributes (e.g., 'llama-3', 'customer-insights')"
+    )
+
+    namespaces: Optional[List[str]] = Field(
+        default=None, description="Namespace-based access control for resource isolation"
+    )
+
+
+class ResourceWithACL(Resource):
+    """Extension of Resource that adds attribute-based access control capabilities.
+
+    This class adds an optional access_attributes field that allows fine-grained control
+    over which users can access each resource. When attributes are defined, a user must have
+    matching attributes to access the resource.
+
+    Attribute Matching Algorithm:
+    1. If a resource has no access_attributes (None or empty dict), it's visible to all authenticated users
+    2. Each key in access_attributes represents an attribute category (e.g., "roles", "teams", "projects")
+    3. The matching algorithm requires ALL categories to match (AND relationship between categories)
+    4. Within each category, ANY value match is sufficient (OR relationship within a category)
+
+    Examples:
+        # Resource visible to everyone (no access control)
+        model = Model(identifier="llama-2", ...)
+
+        # Resource visible only to admins
+        model = Model(
+            identifier="gpt-4",
+            access_attributes=AccessAttributes(roles=["admin"])
+        )
+
+        # Resource visible to data scientists on the ML team
+        model = Model(
+            identifier="private-model",
+            access_attributes=AccessAttributes(
+                roles=["data-scientist", "researcher"],
+                teams=["ml-team"]
+            )
+        )
+        # ^ User must have at least one of the roles AND be on the ml-team
+
+        # Resource visible to users with specific project access
+        vector_db = VectorDB(
+            identifier="customer-embeddings",
+            access_attributes=AccessAttributes(
+                projects=["customer-insights"],
+                namespaces=["confidential"]
+            )
+        )
+        # ^ User must have access to the customer-insights project AND have confidential namespace
+    """
+
+    access_attributes: Optional[AccessAttributes] = None
+
+
+# Use the extended Resource for all routable objects
+class ModelWithACL(Model, ResourceWithACL):
+    pass
+
+
+class ShieldWithACL(Shield, ResourceWithACL):
+    pass
+
+
+class VectorDBWithACL(VectorDB, ResourceWithACL):
+    pass
+
+
+class DatasetWithACL(Dataset, ResourceWithACL):
+    pass
+
+
+class ScoringFnWithACL(ScoringFn, ResourceWithACL):
+    pass
+
+
+class BenchmarkWithACL(Benchmark, ResourceWithACL):
+    pass
+
+
+class ToolWithACL(Tool, ResourceWithACL):
+    pass
+
+
+class ToolGroupWithACL(ToolGroup, ResourceWithACL):
+    pass
+
+
 RoutableObject = Union[
     Model,
     Shield,

@@ -41,13 +151,14 @@ RoutableObject = Union[
 
 RoutableObjectWithProvider = Annotated[
     Union[
-        Model,
-        Shield,
-        VectorDB,
-        Dataset,
-        Benchmark,
-        Tool,
-        ToolGroup,
+        ModelWithACL,
+        ShieldWithACL,
+        VectorDBWithACL,
+        DatasetWithACL,
+        ScoringFnWithACL,
+        BenchmarkWithACL,
+        ToolWithACL,
+        ToolGroupWithACL,
     ],
     Field(discriminator="type"),
 ]
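A small sketch of how these ACL-extended resources carry their attributes through (de)serialization, so the registry can persist and later enforce them; the identifiers are illustrative:

```python
from llama_stack.distribution.datatypes import AccessAttributes, ModelWithACL

model = ModelWithACL(
    identifier="private-model",  # illustrative IDs
    provider_id="inline::local",
    provider_resource_id="private-model",
    access_attributes=AccessAttributes(roles=["data-scientist"], teams=["ml-team"]),
)

# access_attributes round-trips with the rest of the resource.
dumped = model.model_dump_json()
restored = ModelWithACL.model_validate_json(dumped)
assert restored.access_attributes.roles == ["data-scientist"]
```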
llama_stack/distribution/inspect.py

@@ -11,9 +11,7 @@ from pydantic import BaseModel
 from llama_stack.apis.inspect import (
     HealthInfo,
     Inspect,
-    ListProvidersResponse,
     ListRoutesResponse,
-    ProviderInfo,
     RouteInfo,
     VersionInfo,
 )

@@ -39,24 +37,6 @@ class DistributionInspectImpl(Inspect):
     async def initialize(self) -> None:
         pass
 
-    async def list_providers(self) -> ListProvidersResponse:
-        run_config = self.config.run_config
-
-        ret = []
-        for api, providers in run_config.providers.items():
-            ret.extend(
-                [
-                    ProviderInfo(
-                        api=api,
-                        provider_id=p.provider_id,
-                        provider_type=p.provider_type,
-                    )
-                    for p in providers
-                ]
-            )
-
-        return ListProvidersResponse(data=ret)
-
     async def list_routes(self) -> ListRoutesResponse:
         run_config = self.config.run_config
 
llama_stack/distribution/library_client.py

@@ -9,7 +9,6 @@ import inspect
 import json
 import logging
 import os
-import re
 from concurrent.futures import ThreadPoolExecutor
 from enum import Enum
 from pathlib import Path

@@ -37,7 +36,10 @@ from llama_stack.distribution.request_headers import (
     request_provider_data_context,
 )
 from llama_stack.distribution.resolver import ProviderRegistry
-from llama_stack.distribution.server.endpoints import get_all_api_endpoints
+from llama_stack.distribution.server.endpoints import (
+    find_matching_endpoint,
+    initialize_endpoint_impls,
+)
 from llama_stack.distribution.stack import (
     construct_stack,
     get_stack_run_config_from_template,

@@ -232,31 +234,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         safe_config = redact_sensitive_fields(self.config.model_dump())
         console.print(yaml.dump(safe_config, indent=2))
 
-        endpoints = get_all_api_endpoints()
-        endpoint_impls = {}
-
-        def _convert_path_to_regex(path: str) -> str:
-            # Convert {param} to named capture groups
-            # handle {param:path} as well which allows for forward slashes in the param value
-            pattern = re.sub(
-                r"{(\w+)(?::path)?}",
-                lambda m: f"(?P<{m.group(1)}>{'[^/]+' if not m.group(0).endswith(':path') else '.+'})",
-                path,
-            )
-
-            return f"^{pattern}$"
-
-        for api, api_endpoints in endpoints.items():
-            if api not in self.impls:
-                continue
-            for endpoint in api_endpoints:
-                impl = self.impls[api]
-                func = getattr(impl, endpoint.name)
-                if endpoint.method not in endpoint_impls:
-                    endpoint_impls[endpoint.method] = {}
-                endpoint_impls[endpoint.method][_convert_path_to_regex(endpoint.route)] = func
-
-        self.endpoint_impls = endpoint_impls
+        self.endpoint_impls = initialize_endpoint_impls(self.impls)
         return True
 
     async def request(

@@ -290,32 +268,6 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         )
         return response
 
-    def _find_matching_endpoint(self, method: str, path: str) -> tuple[Any, dict]:
-        """Find the matching endpoint implementation for a given method and path.
-
-        Args:
-            method: HTTP method (GET, POST, etc.)
-            path: URL path to match against
-
-        Returns:
-            A tuple of (endpoint_function, path_params)
-
-        Raises:
-            ValueError: If no matching endpoint is found
-        """
-        impls = self.endpoint_impls.get(method)
-        if not impls:
-            raise ValueError(f"No endpoint found for {path}")
-
-        for regex, func in impls.items():
-            match = re.match(regex, path)
-            if match:
-                # Extract named groups from the regex match
-                path_params = match.groupdict()
-                return func, path_params
-
-        raise ValueError(f"No endpoint found for {path}")
-
     async def _call_non_streaming(
         self,
         *,

@@ -326,10 +278,10 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         body = options.params or {}
         body |= options.json_data or {}
 
-        matched_func, path_params = self._find_matching_endpoint(options.method, path)
+        matched_func, path_params, route = find_matching_endpoint(options.method, path, self.endpoint_impls)
         body |= path_params
         body = self._convert_body(path, options.method, body)
-        await start_trace(options.url, {"__location__": "library_client"})
+        await start_trace(route, {"__location__": "library_client"})
         try:
             result = await matched_func(**body)
         finally:

@@ -371,13 +323,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         path = options.url
         body = options.params or {}
         body |= options.json_data or {}
-        func, path_params = self._find_matching_endpoint(options.method, path)
+        func, path_params, route = find_matching_endpoint(options.method, path, self.endpoint_impls)
         body |= path_params
 
         body = self._convert_body(path, options.method, body)
 
         async def gen():
-            await start_trace(options.url, {"__location__": "library_client"})
+            await start_trace(route, {"__location__": "library_client"})
             try:
                 async for chunk in await func(**body):
                     data = json.dumps(convert_pydantic_to_json_value(chunk))

@@ -422,7 +374,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         if not body:
             return {}
 
-        func, _ = self._find_matching_endpoint(method, path)
+        func, _, _ = find_matching_endpoint(method, path, self.endpoint_impls)
         sig = inspect.signature(func)
 
         # Strip NOT_GIVENs to use the defaults in signature
llama_stack/distribution/request_headers.py

@@ -7,21 +7,26 @@
 import contextvars
 import json
 import logging
-from typing import Any, ContextManager, Dict, Optional
+from typing import Any, ContextManager, Dict, List, Optional
 
 from .utils.dynamic import instantiate_class_type
 
 log = logging.getLogger(__name__)
 
-# Context variable for request provider data
+# Context variable for request provider data and auth attributes
 PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None)
 
 
 class RequestProviderDataContext(ContextManager):
     """Context manager for request provider data"""
 
-    def __init__(self, provider_data: Optional[Dict[str, Any]] = None):
-        self.provider_data = provider_data
+    def __init__(
+        self, provider_data: Optional[Dict[str, Any]] = None, auth_attributes: Optional[Dict[str, List[str]]] = None
+    ):
+        self.provider_data = provider_data or {}
+        if auth_attributes:
+            self.provider_data["__auth_attributes"] = auth_attributes
+
         self.token = None
 
     def __enter__(self):

@@ -80,7 +85,17 @@ def parse_request_provider_data(headers: Dict[str, str]) -> Optional[Dict[str, Any]]:
     return None
 
 
-def request_provider_data_context(headers: Dict[str, str]) -> ContextManager:
-    """Context manager that sets request provider data from headers for the duration of the context"""
+def request_provider_data_context(
+    headers: Dict[str, str], auth_attributes: Optional[Dict[str, List[str]]] = None
+) -> ContextManager:
+    """Context manager that sets request provider data from headers and auth attributes for the duration of the context"""
     provider_data = parse_request_provider_data(headers)
-    return RequestProviderDataContext(provider_data)
+    return RequestProviderDataContext(provider_data, auth_attributes)
 
 
+def get_auth_attributes() -> Optional[Dict[str, List[str]]]:
+    """Helper to retrieve auth attributes from the provider data context"""
+    provider_data = PROVIDER_DATA_VAR.get()
+    if not provider_data:
+        return None
+    return provider_data.get("__auth_attributes")
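The new helper pair can be exercised directly; a minimal sketch, with illustrative attribute values standing in for what the auth middleware would supply:

```python
from llama_stack.distribution.request_headers import (
    get_auth_attributes,
    request_provider_data_context,
)

headers = {}  # no provider data in this example
auth_attributes = {"roles": ["admin"], "teams": ["ml-team"]}  # from the middleware

with request_provider_data_context(headers, auth_attributes):
    # Anywhere inside the request (e.g. a routing table method), the
    # attributes are recoverable from the context variable.
    assert get_auth_attributes() == {"roles": ["admin"], "teams": ["ml-team"]}
```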
llama_stack/distribution/routers/routing_tables.py

@@ -19,6 +19,8 @@ from llama_stack.apis.datasets import (
     DatasetType,
     DataSource,
     ListDatasetsResponse,
+    RowsDataSource,
+    URIDataSource,
 )
 from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType
 from llama_stack.apis.resource import ResourceType

@@ -32,11 +34,22 @@ from llama_stack.apis.tools import (
     ToolHost,
 )
 from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
+from llama_stack.distribution.access_control import check_access
 from llama_stack.distribution.datatypes import (
+    AccessAttributes,
+    BenchmarkWithACL,
+    DatasetWithACL,
+    ModelWithACL,
     RoutableObject,
     RoutableObjectWithProvider,
     RoutedProtocol,
+    ScoringFnWithACL,
+    ShieldWithACL,
+    ToolGroupWithACL,
+    ToolWithACL,
+    VectorDBWithACL,
 )
+from llama_stack.distribution.request_headers import get_auth_attributes
 from llama_stack.distribution.store import DistributionRegistry
 from llama_stack.providers.datatypes import Api, RoutingTable

@@ -165,6 +178,11 @@ class CommonRoutingTableImpl(RoutingTable):
         if not obj:
             return None
 
+        # Check if user has permission to access this object
+        if not check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes()):
+            logger.debug(f"Access denied to {type} '{identifier}' based on attribute mismatch")
+            return None
+
         return obj
 
     async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:

@@ -181,6 +199,13 @@ class CommonRoutingTableImpl(RoutingTable):
 
         p = self.impls_by_provider_id[obj.provider_id]
 
+        # If object supports access control but no attributes set, use creator's attributes
+        if not obj.access_attributes:
+            creator_attributes = get_auth_attributes()
+            if creator_attributes:
+                obj.access_attributes = AccessAttributes(**creator_attributes)
+                logger.info(f"Setting access attributes for {obj.type} '{obj.identifier}' based on creator's identity")
+
         registered_obj = await register_object_with_provider(obj, p)
         # TODO: This needs to be fixed for all APIs once they return the registered object
         if obj.type == ResourceType.model.value:

@@ -193,7 +218,17 @@ class CommonRoutingTableImpl(RoutingTable):
 
     async def get_all_with_type(self, type: str) -> List[RoutableObjectWithProvider]:
         objs = await self.dist_registry.get_all()
-        return [obj for obj in objs if obj.type == type]
+        filtered_objs = [obj for obj in objs if obj.type == type]
+
+        # Apply attribute-based access control filtering
+        if filtered_objs:
+            filtered_objs = [
+                obj
+                for obj in filtered_objs
+                if check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes())
+            ]
+
+        return filtered_objs
 
 
 class ModelsRoutingTable(CommonRoutingTableImpl, Models):

@@ -230,7 +265,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
             model_type = ModelType.llm
         if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
             raise ValueError("Embedding model must have an embedding dimension in its metadata")
-        model = Model(
+        model = ModelWithACL(
             identifier=model_id,
             provider_resource_id=provider_model_id,
             provider_id=provider_id,

@@ -276,7 +311,7 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
         )
         if params is None:
             params = {}
-        shield = Shield(
+        shield = ShieldWithACL(
             identifier=shield_id,
             provider_resource_id=provider_shield_id,
             provider_id=provider_id,

@@ -330,7 +365,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
             "embedding_model": embedding_model,
             "embedding_dimension": model.metadata["embedding_dimension"],
         }
-        vector_db = TypeAdapter(VectorDB).validate_python(vector_db_data)
+        vector_db = TypeAdapter(VectorDBWithACL).validate_python(vector_db_data)
         await self.register_object(vector_db)
         return vector_db
 

@@ -358,6 +393,12 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
         metadata: Optional[Dict[str, Any]] = None,
         dataset_id: Optional[str] = None,
     ) -> Dataset:
+        if isinstance(source, dict):
+            if source["type"] == "uri":
+                source = URIDataSource.parse_obj(source)
+            elif source["type"] == "rows":
+                source = RowsDataSource.parse_obj(source)
+
         if not dataset_id:
             dataset_id = f"dataset-{str(uuid.uuid4())}"
 

@@ -378,7 +419,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
         if metadata is None:
             metadata = {}
 
-        dataset = Dataset(
+        dataset = DatasetWithACL(
             identifier=dataset_id,
             provider_resource_id=provider_dataset_id,
             provider_id=provider_id,

@@ -429,7 +470,7 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
             raise ValueError("No evaluation providers available. Please configure an evaluation provider.")
         provider_id = list(self.impls_by_provider_id.keys())[0]
 
-        benchmark = Benchmark(
+        benchmark = BenchmarkWithACL(
             identifier=benchmark_id,
             dataset_id=dataset_id,
             grader_ids=grader_ids,

@@ -473,7 +514,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
 
         for tool_def in tool_defs:
             tools.append(
-                Tool(
+                ToolWithACL(
                     identifier=tool_def.name,
                     toolgroup_id=toolgroup_id,
                     description=tool_def.description or "",

@@ -498,7 +539,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
         await self.register_object(tool)
 
         await self.dist_registry.register(
-            ToolGroup(
+            ToolGroupWithACL(
                 identifier=toolgroup_id,
                 provider_id=provider_id,
                 provider_resource_id=toolgroup_id,

@@ -511,7 +552,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
         tool_group = await self.get_tool_group(toolgroup_id)
         if tool_group is None:
             raise ValueError(f"Tool group {toolgroup_id} not found")
-        tools = await self.list_tools(toolgroup_id).data
+        tools = (await self.list_tools(toolgroup_id)).data
         for tool in tools:
             await self.unregister_object(tool)
         await self.unregister_object(tool_group)
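The last hunk above is a pure bug fix: `await self.list_tools(...).data` reads `.data` off the coroutine object before awaiting it, because attribute access binds tighter than `await`. A standalone sketch of the failure and the fix, with a stand-in coroutine rather than the real routing table:

```python
import asyncio
from types import SimpleNamespace


async def list_tools():
    return SimpleNamespace(data=[1, 2, 3])


async def main():
    # Buggy form: ".data" is evaluated on the coroutine object itself and
    # raises AttributeError (Python also warns the coroutine was never awaited).
    try:
        await list_tools().data
    except AttributeError as e:
        print(f"bug reproduced: {e}")

    # Fixed form: parenthesize so the coroutine is awaited first.
    tools = (await list_tools()).data
    print(tools)


asyncio.run(main())
```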
llama_stack/distribution/server/auth.py

@@ -5,16 +5,118 @@
 # the root directory of this source tree.
 
 import json
+from typing import Dict, List, Optional
 from urllib.parse import parse_qs
 
 import httpx
+from pydantic import BaseModel, Field
 
+from llama_stack.distribution.datatypes import AccessAttributes
 from llama_stack.log import get_logger
 
 logger = get_logger(name=__name__, category="auth")
 
 
+class AuthRequestContext(BaseModel):
+    path: str = Field(description="The path of the request being authenticated")
+
+    headers: Dict[str, str] = Field(description="HTTP headers from the original request (excluding Authorization)")
+
+    params: Dict[str, List[str]] = Field(
+        description="Query parameters from the original request, parsed as dictionary of lists"
+    )
+
+
+class AuthRequest(BaseModel):
+    api_key: str = Field(description="The API key extracted from the Authorization header")
+
+    request: AuthRequestContext = Field(description="Context information about the request being authenticated")
+
+
+class AuthResponse(BaseModel):
+    """The format of the authentication response from the auth endpoint."""
+
+    access_attributes: Optional[AccessAttributes] = Field(
+        default=None,
+        description="""
+        Structured user attributes for attribute-based access control.
+
+        These attributes determine which resources the user can access.
+        The model provides standard categories like "roles", "teams", "projects", and "namespaces".
+        Each attribute category contains a list of values that the user has for that category.
+        During access control checks, these values are compared against resource requirements.
+
+        Example with standard categories:
+        ```json
+        {
+            "roles": ["admin", "data-scientist"],
+            "teams": ["ml-team"],
+            "projects": ["llama-3"],
+            "namespaces": ["research"]
+        }
+        ```
+        """,
+    )
+
+    message: Optional[str] = Field(
+        default=None, description="Optional message providing additional context about the authentication result."
+    )
+
+
 class AuthenticationMiddleware:
+    """Middleware that authenticates requests using an external auth endpoint.
+
+    This middleware:
+    1. Extracts the Bearer token from the Authorization header
+    2. Sends it to the configured auth endpoint along with request details
+    3. Validates the response and extracts user attributes
+    4. Makes these attributes available to the route handlers for access control
+
+    Authentication Request Format:
+    ```json
+    {
+        "api_key": "the-api-key-extracted-from-auth-header",
+        "request": {
+            "path": "/models/list",
+            "headers": {
+                "content-type": "application/json",
+                "user-agent": "..."
+                // All headers except Authorization
+            },
+            "params": {
+                "limit": ["100"],
+                "offset": ["0"]
+                // Query parameters as key -> list of values
+            }
+        }
+    }
+    ```
+
+    Expected Auth Endpoint Response Format:
+    ```json
+    {
+        "access_attributes": {    // Structured attribute format
+            "roles": ["admin", "user"],
+            "teams": ["ml-team", "nlp-team"],
+            "projects": ["llama-3", "project-x"],
+            "namespaces": ["research"]
+        },
+        "message": "Optional message about auth result"
+    }
+    ```
+
+    Attribute-Based Access Control:
+    The attributes returned by the auth endpoint are used to determine which
+    resources the user can access. Resources can specify required attributes
+    using the access_attributes field. For a user to access a resource:
+
+    1. All attribute categories specified in the resource must be present in the user's attributes
+    2. For each category, the user must have at least one matching value
+
+    If the auth endpoint doesn't return any attributes, the user will only be able to
+    access resources that don't have access_attributes defined.
+    """
+
     def __init__(self, app, auth_endpoint):
         self.app = app
         self.auth_endpoint = auth_endpoint

@@ -32,25 +134,57 @@ class AuthenticationMiddleware:
         path = scope.get("path", "")
         request_headers = {k.decode(): v.decode() for k, v in headers.items()}
 
         # Remove sensitive headers
         if "authorization" in request_headers:
             del request_headers["authorization"]
 
         query_string = scope.get("query_string", b"").decode()
         params = parse_qs(query_string)
 
-        auth_data = {
-            "api_key": api_key,
-            "request": {
-                "path": path,
-                "headers": request_headers,
-                "params": params,
-            },
-        }
+        # Build the auth request model
+        auth_request = AuthRequest(
+            api_key=api_key,
+            request=AuthRequestContext(
+                path=path,
+                headers=request_headers,
+                params=params,
+            ),
+        )
 
         # Validate with authentication endpoint
         try:
             async with httpx.AsyncClient() as client:
-                response = await client.post(self.auth_endpoint, json=auth_data)
+                response = await client.post(
+                    self.auth_endpoint,
+                    json=auth_request.model_dump(),
+                    timeout=10.0,  # Add a reasonable timeout
+                )
                 if response.status_code != 200:
                     logger.warning(f"Authentication failed: {response.status_code}")
                     return await self._send_auth_error(send, "Authentication failed")
 
+                # Parse and validate the auth response
+                try:
+                    response_data = response.json()
+                    auth_response = AuthResponse(**response_data)
+
+                    # Store attributes in request scope for access control
+                    if auth_response.access_attributes:
+                        user_attributes = auth_response.access_attributes.model_dump(exclude_none=True)
+                    else:
+                        logger.warning("No access attributes, setting namespace to api_key by default")
+                        user_attributes = {
+                            "namespaces": [api_key],
+                        }
+
+                    scope["user_attributes"] = user_attributes
+                    logger.debug(f"Authentication successful: {len(user_attributes)} attributes")
+                except Exception:
+                    logger.exception("Error parsing authentication response")
+                    return await self._send_auth_error(send, "Invalid authentication response format")
+        except httpx.TimeoutException:
+            logger.exception("Authentication request timed out")
+            return await self._send_auth_error(send, "Authentication service timeout")
         except Exception:
             logger.exception("Error during authentication")
             return await self._send_auth_error(send, "Authentication service error")
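Given the documented contract, a matching external auth service might look like the following sketch; the FastAPI framework choice, route name, and token table are illustrative assumptions, not part of this commit:

```python
from typing import Dict, List

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

app = FastAPI()

# Illustrative token table; a real service would consult an identity provider.
KNOWN_KEYS: Dict[str, Dict[str, List[str]]] = {
    "secret-key-1": {"roles": ["admin"], "teams": ["ml-team"]},
}


class AuthRequestContext(BaseModel):
    path: str
    headers: Dict[str, str]
    params: Dict[str, List[str]]


class AuthRequest(BaseModel):
    api_key: str
    request: AuthRequestContext


@app.post("/validate")
def validate(req: AuthRequest) -> dict:
    attributes = KNOWN_KEYS.get(req.api_key)
    if attributes is None:
        # Any non-200 response makes the middleware reject the request.
        raise HTTPException(status_code=401, detail="unknown api key")
    # Mirrors the documented response shape: access_attributes + optional message.
    return {"access_attributes": attributes, "message": "ok"}
```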
@@ -5,6 +5,7 @@
# the root directory of this source tree.

import inspect
import re
from typing import Dict, List

from pydantic import BaseModel

@@ -19,6 +20,7 @@ class ApiEndpoint(BaseModel):
route: str
method: str
name: str
descriptive_name: str | None = None


def toolgroup_protocol_map():
@@ -58,8 +60,69 @@ def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
method = "delete"
else:
method = "post"
endpoints.append(ApiEndpoint(route=route, method=method, name=name))
endpoints.append(
ApiEndpoint(route=route, method=method, name=name, descriptive_name=webmethod.descriptive_name)
)

apis[api] = endpoints

return apis


def initialize_endpoint_impls(impls):
endpoints = get_all_api_endpoints()
endpoint_impls = {}

def _convert_path_to_regex(path: str) -> str:
    # Convert {param} to named capture groups
    # handle {param:path} as well which allows for forward slashes in the param value
    pattern = re.sub(
        r"{(\w+)(?::path)?}",
        lambda m: f"(?P<{m.group(1)}>{'.+' if ':path' in m.group(0) else '[^/]+'})",
        path,
    )

    return f"^{pattern}$"

for api, api_endpoints in endpoints.items():
if api not in impls:
continue
for endpoint in api_endpoints:
impl = impls[api]
func = getattr(impl, endpoint.name)
if endpoint.method not in endpoint_impls:
endpoint_impls[endpoint.method] = {}
endpoint_impls[endpoint.method][_convert_path_to_regex(endpoint.route)] = (
func,
endpoint.descriptive_name or endpoint.route,
)

return endpoint_impls


def find_matching_endpoint(method, path, endpoint_impls):
"""Find the matching endpoint implementation for a given method and path.

Args:
method: HTTP method (GET, POST, etc.)
path: URL path to match against
endpoint_impls: A dictionary of endpoint implementations

Returns:
A tuple of (endpoint_function, path_params, descriptive_name)

Raises:
ValueError: If no matching endpoint is found
"""
impls = endpoint_impls.get(method.lower())
if not impls:
raise ValueError(f"No endpoint found for {path}")

for regex, (func, descriptive_name) in impls.items():
match = re.match(regex, path)
if match:
# Extract named groups from the regex match
path_params = match.groupdict()
return func, path_params, descriptive_name

raise ValueError(f"No endpoint found for {path}")
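A quick illustration of the path matching above; the route and parameter values are made up for the example:

# _convert_path_to_regex("/agents/{agent_id}/session/{session_id}/turn")
#   -> "^/agents/(?P<agent_id>[^/]+)/session/(?P<session_id>[^/]+)/turn$"
# _convert_path_to_regex("/files/{key:path}")
#   -> "^/files/(?P<key>.+)$"   # ':path' params may span '/'
func, path_params, trace_path = find_matching_endpoint(
    "POST", "/agents/abc/session/xyz/turn", endpoint_impls
)
# path_params == {"agent_id": "abc", "session_id": "xyz"}
# trace_path is the descriptive_name when one was registered
# (e.g. "create_agent_turn"), otherwise the raw route.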
@@ -32,6 +32,10 @@ from llama_stack.distribution.request_headers import (
request_provider_data_context,
)
from llama_stack.distribution.resolver import InvalidProviderError
from llama_stack.distribution.server.endpoints import (
find_matching_endpoint,
initialize_endpoint_impls,
)
from llama_stack.distribution.stack import (
construct_stack,
redact_sensitive_fields,

@@ -179,8 +183,11 @@ async def sse_generator(event_gen):

def create_dynamic_typed_route(func: Any, method: str, route: str):
async def endpoint(request: Request, **kwargs):
# Use context manager for request provider data
with request_provider_data_context(request.headers):
# Get auth attributes from the request scope
user_attributes = request.scope.get("user_attributes", {})

# Use context manager with both provider data and auth attributes
with request_provider_data_context(request.headers, user_attributes):
is_streaming = is_streaming_request(func.__name__, request, **kwargs)

try:

@@ -219,14 +226,30 @@ def create_dynamic_typed_route(func: Any, method: str, route: str):


class TracingMiddleware:
def __init__(self, app):
def __init__(self, app, impls):
self.app = app
self.impls = impls

async def __call__(self, scope, receive, send):
path = scope.get("path", "")
await start_trace(path, {"__location__": "server"})
try:
if scope.get("type") == "lifespan":
return await self.app(scope, receive, send)

path = scope.get("path", "")
if not hasattr(self, "endpoint_impls"):
self.endpoint_impls = initialize_endpoint_impls(self.impls)
_, _, trace_path = find_matching_endpoint(scope.get("method", "GET"), path, self.endpoint_impls)

trace_context = await start_trace(trace_path, {"__location__": "server", "raw_path": path})

async def send_with_trace_id(message):
if message["type"] == "http.response.start":
headers = message.get("headers", [])
headers.append([b"x-trace-id", str(trace_context.trace_id).encode()])
message["headers"] = headers
await send(message)

try:
return await self.app(scope, receive, send_with_trace_id)
finally:
await end_trace()

@@ -348,7 +371,6 @@ def main():
logger.info(yaml.dump(safe_config, indent=2))

app = FastAPI(lifespan=lifespan)
app.add_middleware(TracingMiddleware)
if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
app.add_middleware(ClientVersionMiddleware)

@@ -366,7 +388,7 @@ def main():
if Api.telemetry in impls:
setup_logger(impls[Api.telemetry])
else:
setup_logger(TelemetryAdapter(TelemetryConfig()))
setup_logger(TelemetryAdapter(TelemetryConfig(), {}))

all_endpoints = get_all_api_endpoints()

@@ -412,6 +434,7 @@ def main():
app.exception_handler(Exception)(global_exception_handler)

app.__llama_stack_impls__ = impls
app.add_middleware(TracingMiddleware, impls=impls)

import uvicorn
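Since TracingMiddleware now returns the trace id on every response, a client can correlate its own calls with server traces. A minimal sketch; the base URL, port, and route are assumptions for illustration:

import httpx

resp = httpx.get("http://localhost:8321/v1/models")
trace_id = resp.headers.get("x-trace-id")  # httpx headers are case-insensitive
print(f"server trace for this request: {trace_id}")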
@@ -12,9 +12,12 @@ import pydantic

from llama_stack.distribution.datatypes import KVStoreConfig, RoutableObjectWithProvider
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig

logger = get_logger(__name__, category="core")


class DistributionRegistry(Protocol):
async def get_all(self) -> List[RoutableObjectWithProvider]: ...

@@ -47,8 +50,13 @@ def _parse_registry_values(values: List[str]) -> List[RoutableObjectWithProvider
"""Utility function to parse registry values into RoutableObjectWithProvider objects."""
all_objects = []
for value in values:
obj = pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(value)
all_objects.append(obj)
try:
obj = pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(value)
all_objects.append(obj)
except pydantic.ValidationError as e:
logger.error(f"Error parsing registry value, raw value: {value}. Error: {e}")
continue

return all_objects


@@ -73,7 +81,11 @@ class DiskDistributionRegistry(DistributionRegistry):
if not json_str:
return None

return pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(json_str)
try:
return pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(json_str)
except pydantic.ValidationError as e:
logger.error(f"Error parsing registry value for {type}:{identifier}, raw value: {json_str}. Error: {e}")
return None

async def update(self, obj: RoutableObjectWithProvider) -> None:
await self.kvstore.set(
@@ -5,9 +5,7 @@
# the root directory of this source tree.

import streamlit as st
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.event_logger import EventLogger
from llama_stack_client.types.shared.document import Document
from llama_stack_client import Agent, AgentEventLogger, RAGDocument

from llama_stack.distribution.ui.modules.api import llama_stack_api
from llama_stack.distribution.ui.modules.utils import data_url_from_file

@@ -35,7 +33,7 @@ def rag_chat_page():
)
if st.button("Create Vector Database"):
documents = [
Document(
RAGDocument(
document_id=uploaded_file.name,
content=data_url_from_file(uploaded_file),
)

@@ -167,7 +165,7 @@ def rag_chat_page():
message_placeholder = st.empty()
full_response = ""
retrieval_response = ""
for log in EventLogger().log(response):
for log in AgentEventLogger().log(response):
log.print()
if log.role == "tool_execution":
retrieval_response += log.content.replace("====", "").strip()
@@ -47,7 +47,14 @@ RecursiveType = Union[Primitive, List[Primitive], Dict[str, Primitive]]
class ToolCall(BaseModel):
call_id: str
tool_name: Union[BuiltinTool, str]
arguments: Dict[str, RecursiveType]
# Plan is to deprecate the Dict in favor of a JSON string
# that is parsed on the client side instead of trying to manage
# the recursive type here.
# Making this a union so that the client side can start prepping for this change.
# Eventually, we will remove both the Dict and arguments_json field,
# and arguments will just be a str
arguments: Union[str, Dict[str, RecursiveType]]
arguments_json: Optional[str] = None

@field_validator("tool_name", mode="before")
@classmethod
@@ -179,13 +186,11 @@ class TopKSamplingStrategy(BaseModel):
top_k: int = Field(..., ge=1)


SamplingStrategy = register_schema(
Annotated[
Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy],
Field(discriminator="type"),
],
name="SamplingStrategy",
)
SamplingStrategy = Annotated[
Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy],
Field(discriminator="type"),
]
register_schema(SamplingStrategy, name="SamplingStrategy")


@json_schema_type
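The assign-then-register form keeps SamplingStrategy a plain Annotated type alias that type checkers and pydantic can consume directly. A quick sketch of parsing the discriminated union; the field values are assumptions for illustration:

from pydantic import TypeAdapter

raw = {"type": "top_p", "temperature": 0.7, "top_p": 0.9}
strategy = TypeAdapter(SamplingStrategy).validate_python(raw)
# The "type" discriminator selects TopPSamplingStrategy here.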
@@ -12,6 +12,7 @@
# the top-level of this source tree.

import io
import json
import uuid
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

@@ -203,9 +204,10 @@ class ChatFormat:
# This code tries to handle that case
if tool_name in BuiltinTool.__members__:
tool_name = BuiltinTool[tool_name]
tool_arguments = {
"query": list(tool_arguments.values())[0],
}
if isinstance(tool_arguments, dict):
tool_arguments = {
"query": list(tool_arguments.values())[0],
}
else:
builtin_tool_info = ToolUtils.maybe_extract_builtin_tool_call(content)
if builtin_tool_info is not None:

@@ -229,6 +231,7 @@ class ChatFormat:
call_id=call_id,
tool_name=tool_name,
arguments=tool_arguments,
arguments_json=json.dumps(tool_arguments),
)
)
content = ""
@@ -244,6 +244,7 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase):  # noqa: N801
template_str = textwrap.dedent(
"""
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
For a boolean parameter, be sure to use `True` or `False` (capitalized) for the value.
You SHOULD NOT include any other text in the response.

Here is a list of functions in JSON format that you can invoke.
@@ -11,11 +11,8 @@
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.

from llama_stack.models.llama.datatypes import (
BuiltinTool,
StopReason,
ToolCall,
)
from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall

from .prompt_templates import (
BuiltinToolGenerator,
@@ -15,8 +15,11 @@ import json
import re
from typing import Optional, Tuple

from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat

logger = get_logger(name=__name__, category="inference")

BUILTIN_TOOL_PATTERN = r'\b(?P<tool_name>\w+)\.call\(query="(?P<query>[^"]*)"\)'
CUSTOM_TOOL_CALL_PATTERN = re.compile(r"<function=(?P<function_name>[^}]+)>(?P<args>{.*?})")

@@ -92,7 +95,15 @@ def parse_python_list_for_function_calls(input_string):

# Extract keyword arguments
for keyword in node.keywords:
function_args[keyword.arg] = ast.literal_eval(keyword.value)
try:
function_args[keyword.arg] = ast.literal_eval(keyword.value)
except ValueError as e:
logger.error(
f"Error parsing tool call argument '{keyword.arg}': {e}, full input string: '{input_string}'"
)
raise ValueError(
f"Error parsing tool call argument '{keyword.arg}', full input string: '{input_string}'"
) from e

result.append((function_name, function_args))
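For context, a sketch of what parse_python_list_for_function_calls consumes and yields; the example values are mine:

text = '[get_weather(city="San Francisco", days=3)]'
calls = parse_python_list_for_function_calls(text)
# calls == [("get_weather", {"city": "San Francisco", "days": 3})]
# With the change above, a non-literal argument such as days=foo() raises a
# descriptive ValueError instead of propagating a bare ast.literal_eval failure.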
@@ -180,25 +180,29 @@ class ChatAgent(ShieldRunnerMixin):
return messages

async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
await self._initialize_tools(request.toolgroups)
async with tracing.span("create_and_execute_turn") as span:
span = tracing.get_current_span()
if span:
span.set_attribute("session_id", request.session_id)
span.set_attribute("agent_id", self.agent_id)
span.set_attribute("request", request.model_dump_json())
turn_id = str(uuid.uuid4())
span.set_attribute("turn_id", turn_id)
async for chunk in self._run_turn(request, turn_id):
yield chunk

await self._initialize_tools(request.toolgroups)
async for chunk in self._run_turn(request, turn_id):
yield chunk

async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator:
await self._initialize_tools()
async with tracing.span("resume_turn") as span:
span = tracing.get_current_span()
if span:
span.set_attribute("agent_id", self.agent_id)
span.set_attribute("session_id", request.session_id)
span.set_attribute("turn_id", request.turn_id)
span.set_attribute("request", request.model_dump_json())
async for chunk in self._run_turn(request):
yield chunk
span.set_attribute("turn_id", request.turn_id)

await self._initialize_tools()
async for chunk in self._run_turn(request):
yield chunk

async def _run_turn(
self,
@@ -13,6 +13,9 @@ from typing import List, Optional
from pydantic import BaseModel

from llama_stack.apis.agents import ToolExecutionStep, Turn
from llama_stack.distribution.access_control import check_access
from llama_stack.distribution.datatypes import AccessAttributes
from llama_stack.distribution.request_headers import get_auth_attributes
from llama_stack.providers.utils.kvstore import KVStore

log = logging.getLogger(__name__)

@@ -24,6 +27,7 @@ class AgentSessionInfo(BaseModel):
# TODO: is this used anywhere?
vector_db_id: Optional[str] = None
started_at: datetime
access_attributes: Optional[AccessAttributes] = None


class AgentPersistence:

@@ -33,11 +37,18 @@ class AgentPersistence:

async def create_session(self, name: str) -> str:
session_id = str(uuid.uuid4())

# Get current user's auth attributes for new sessions
auth_attributes = get_auth_attributes()
access_attributes = AccessAttributes(**auth_attributes) if auth_attributes else None

session_info = AgentSessionInfo(
session_id=session_id,
session_name=name,
started_at=datetime.now(timezone.utc),
access_attributes=access_attributes,
)

await self.kvstore.set(
key=f"session:{self.agent_id}:{session_id}",
value=session_info.model_dump_json(),

@@ -51,12 +62,34 @@ class AgentPersistence:
if not value:
return None

return AgentSessionInfo(**json.loads(value))
session_info = AgentSessionInfo(**json.loads(value))

# Check access to session
if not self._check_session_access(session_info):
return None

return session_info

def _check_session_access(self, session_info: AgentSessionInfo) -> bool:
"""Check if current user has access to the session."""
# Handle backward compatibility for old sessions without access control
if not hasattr(session_info, "access_attributes"):
return True

return check_access(session_info.session_id, session_info.access_attributes, get_auth_attributes())

async def get_session_if_accessible(self, session_id: str) -> Optional[AgentSessionInfo]:
"""Get session info if the user has access to it. For internal use by sub-session methods."""
session_info = await self.get_session_info(session_id)
if not session_info:
return None

return session_info

async def add_vector_db_to_session(self, session_id: str, vector_db_id: str):
session_info = await self.get_session_info(session_id)
session_info = await self.get_session_if_accessible(session_id)
if session_info is None:
raise ValueError(f"Session {session_id} not found")
raise ValueError(f"Session {session_id} not found or access denied")

session_info.vector_db_id = vector_db_id
await self.kvstore.set(

@@ -65,12 +98,18 @@ class AgentPersistence:
)

async def add_turn_to_session(self, session_id: str, turn: Turn):
if not await self.get_session_if_accessible(session_id):
raise ValueError(f"Session {session_id} not found or access denied")

await self.kvstore.set(
key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}",
value=turn.model_dump_json(),
)

async def get_session_turns(self, session_id: str) -> List[Turn]:
if not await self.get_session_if_accessible(session_id):
raise ValueError(f"Session {session_id} not found or access denied")

values = await self.kvstore.range(
start_key=f"session:{self.agent_id}:{session_id}:",
end_key=f"session:{self.agent_id}:{session_id}:\xff\xff\xff\xff",

@@ -87,6 +126,9 @@ class AgentPersistence:
return turns

async def get_session_turn(self, session_id: str, turn_id: str) -> Optional[Turn]:
if not await self.get_session_if_accessible(session_id):
raise ValueError(f"Session {session_id} not found or access denied")

value = await self.kvstore.get(
key=f"session:{self.agent_id}:{session_id}:{turn_id}",
)

@@ -95,24 +137,36 @@ class AgentPersistence:
return Turn(**json.loads(value))

async def set_in_progress_tool_call_step(self, session_id: str, turn_id: str, step: ToolExecutionStep):
if not await self.get_session_if_accessible(session_id):
raise ValueError(f"Session {session_id} not found or access denied")

await self.kvstore.set(
key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
value=step.model_dump_json(),
)

async def get_in_progress_tool_call_step(self, session_id: str, turn_id: str) -> Optional[ToolExecutionStep]:
if not await self.get_session_if_accessible(session_id):
return None

value = await self.kvstore.get(
key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
)
return ToolExecutionStep(**json.loads(value)) if value else None

async def set_num_infer_iters_in_turn(self, session_id: str, turn_id: str, num_infer_iters: int):
if not await self.get_session_if_accessible(session_id):
raise ValueError(f"Session {session_id} not found or access denied")

await self.kvstore.set(
key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
value=str(num_infer_iters),
)

async def get_num_infer_iters_in_turn(self, session_id: str, turn_id: str) -> Optional[int]:
if not await self.get_session_if_accessible(session_id):
return None

value = await self.kvstore.get(
key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
)
@@ -35,12 +35,12 @@ class PandasDataframeDataset:
else:
return self.df.iloc[idx].to_dict()

def load(self) -> None:
async def load(self) -> None:
if self.df is not None:
return

if self.dataset_def.source.type == "uri":
self.df = get_dataframe_from_uri(self.dataset_def.source.uri)
self.df = await get_dataframe_from_uri(self.dataset_def.source.uri)
elif self.dataset_def.source.type == "rows":
self.df = pandas.DataFrame(self.dataset_def.source.rows)
else:

@@ -95,7 +95,7 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
) -> IterrowsResponse:
dataset_def = self.dataset_infos[dataset_id]
dataset_impl = PandasDataframeDataset(dataset_def)
dataset_impl.load()
await dataset_impl.load()

start_index = start_index or 0

@@ -114,7 +114,7 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
dataset_def = self.dataset_infos[dataset_id]
dataset_impl = PandasDataframeDataset(dataset_def)
dataset_impl.load()
await dataset_impl.load()

new_rows_df = pandas.DataFrame(rows)
dataset_impl.df = pandas.concat([dataset_impl.df, new_rows_df], ignore_index=True)
@@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List

from tqdm import tqdm

@@ -20,8 +20,8 @@ from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
from llama_stack.providers.utils.common.data_schema_validator import ColumnName
from llama_stack.providers.utils.kvstore import kvstore_impl

from .....apis.common.job_types import Job
from .....apis.eval.eval import BenchmarkConfig, Eval, EvaluateResponse, JobStatus
from .....apis.common.job_types import Job, JobStatus
from .....apis.eval.eval import BenchmarkConfig, Eval, EvaluateResponse
from .config import MetaReferenceEvalConfig

EVAL_TASKS_PREFIX = "benchmarks:"

@@ -101,7 +101,7 @@ class MetaReferenceEvalImpl(
# need job scheduler queue (ray/celery) w/ jobs api
job_id = str(len(self.jobs))
self.jobs[job_id] = res
return Job(job_id=job_id)
return Job(job_id=job_id, status=JobStatus.completed)

async def _run_agent_generation(
self, input_rows: List[Dict[str, Any]], benchmark_config: BenchmarkConfig

@@ -215,17 +215,18 @@ class MetaReferenceEvalImpl(

return EvaluateResponse(generations=generations, scores=score_response.results)

async def job_status(self, benchmark_id: str, job_id: str) -> Optional[JobStatus]:
async def job_status(self, benchmark_id: str, job_id: str) -> Job:
if job_id in self.jobs:
return JobStatus.completed
return Job(job_id=job_id, status=JobStatus.completed)

return None
raise ValueError(f"Job {job_id} not found")

async def job_cancel(self, benchmark_id: str, job_id: str) -> None:
raise NotImplementedError("Job cancel is not implemented yet")

async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse:
status = await self.job_status(benchmark_id, job_id)
job = await self.job_status(benchmark_id, job_id)
status = job.status
if not status or status != JobStatus.completed:
raise ValueError(f"Job is not completed, Status: {status.value}")

@@ -582,6 +582,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
tool_name=t.function.name,
# vLLM function args come back as a string. Llama Stack expects JSON.
arguments=json.loads(t.function.arguments),
arguments_json=t.function.arguments,
)
for t in vllm_message.tool_calls
],
@@ -23,7 +23,9 @@ from llama_stack.providers.utils.common.data_schema_validator import (

from .config import BasicScoringConfig
from .scoring_fn.bfcl_scoring_fn import BFCLScoringFn
from .scoring_fn.docvqa_scoring_fn import DocVQAScoringFn
from .scoring_fn.equality_scoring_fn import EqualityScoringFn
from .scoring_fn.ifeval_scoring_fn import IfEvalScoringFn
from .scoring_fn.regex_parser_math_response_scoring_fn import (
RegexParserMathResponseScoringFn,
)

@@ -36,6 +38,8 @@ FIXED_FNS = [
RegexParserScoringFn,
RegexParserMathResponseScoringFn,
BFCLScoringFn,
IfEvalScoringFn,
DocVQAScoringFn,
]

@@ -0,0 +1,240 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
import re
from typing import Any, Dict, Optional

from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn

from .fn_defs.docvqa import docvqa

CONTRACTIONS = {
    "aint": "ain't",
    "arent": "aren't",
    "cant": "can't",
    "couldve": "could've",
    "couldnt": "couldn't",
    "couldn'tve": "couldn't've",
    "couldnt've": "couldn't've",
    "didnt": "didn't",
    "doesnt": "doesn't",
    "dont": "don't",
    "hadnt": "hadn't",
    "hadnt've": "hadn't've",
    "hadn'tve": "hadn't've",
    "hasnt": "hasn't",
    "havent": "haven't",
    "hed": "he'd",
    "hed've": "he'd've",
    "he'dve": "he'd've",
    "hes": "he's",
    "howd": "how'd",
    "howll": "how'll",
    "hows": "how's",
    "Id've": "I'd've",
    "I'dve": "I'd've",
    "Im": "I'm",
    "Ive": "I've",
    "isnt": "isn't",
    "itd": "it'd",
    "itd've": "it'd've",
    "it'dve": "it'd've",
    "itll": "it'll",
    "let's": "let's",
    "maam": "ma'am",
    "mightnt": "mightn't",
    "mightnt've": "mightn't've",
    "mightn'tve": "mightn't've",
    "mightve": "might've",
    "mustnt": "mustn't",
    "mustve": "must've",
    "neednt": "needn't",
    "notve": "not've",
    "oclock": "o'clock",
    "oughtnt": "oughtn't",
    "ow's'at": "'ow's'at",
    "'ows'at": "'ow's'at",
    "'ow'sat": "'ow's'at",
    "shant": "shan't",
    "shed've": "she'd've",
    "she'dve": "she'd've",
    "she's": "she's",
    "shouldve": "should've",
    "shouldnt": "shouldn't",
    "shouldnt've": "shouldn't've",
    "shouldn'tve": "shouldn't've",
    "somebody'd": "somebodyd",
    "somebodyd've": "somebody'd've",
    "somebody'dve": "somebody'd've",
    "somebodyll": "somebody'll",
    "somebodys": "somebody's",
    "someoned": "someone'd",
    "someoned've": "someone'd've",
    "someone'dve": "someone'd've",
    "someonell": "someone'll",
    "someones": "someone's",
    "somethingd": "something'd",
    "somethingd've": "something'd've",
    "something'dve": "something'd've",
    "somethingll": "something'll",
    "thats": "that's",
    "thered": "there'd",
    "thered've": "there'd've",
    "there'dve": "there'd've",
    "therere": "there're",
    "theres": "there's",
    "theyd": "they'd",
    "theyd've": "they'd've",
    "they'dve": "they'd've",
    "theyll": "they'll",
    "theyre": "they're",
    "theyve": "they've",
    "twas": "'twas",
    "wasnt": "wasn't",
    "wed've": "we'd've",
    "we'dve": "we'd've",
    "weve": "we've",
    "werent": "weren't",
    "whatll": "what'll",
    "whatre": "what're",
    "whats": "what's",
    "whatve": "what've",
    "whens": "when's",
    "whered": "where'd",
    "wheres": "where's",
    "whereve": "where've",
    "whod": "who'd",
    "whod've": "who'd've",
    "who'dve": "who'd've",
    "wholl": "who'll",
    "whos": "who's",
    "whove": "who've",
    "whyll": "why'll",
    "whyre": "why're",
    "whys": "why's",
    "wont": "won't",
    "wouldve": "would've",
    "wouldnt": "wouldn't",
    "wouldnt've": "wouldn't've",
    "wouldn'tve": "wouldn't've",
    "yall": "y'all",
    "yall'll": "y'all'll",
    "y'allll": "y'all'll",
    "yall'd've": "y'all'd've",
    "y'alld've": "y'all'd've",
    "y'all'dve": "y'all'd've",
    "youd": "you'd",
    "youd've": "you'd've",
    "you'dve": "you'd've",
    "youll": "you'll",
    "youre": "you're",
    "youve": "you've",
    "1st": "first",
    "2nd": "second",
    "3rd": "third",
}
NUMBERS = {
    "none": "0",
    "zero": "0",
    "one": "1",
    "two": "2",
    "three": "3",
    "four": "4",
    "five": "5",
    "six": "6",
    "seven": "7",
    "eight": "8",
    "nine": "9",
    "ten": "10",
}
ARTICLES = [
    "a",
    "an",
    "the",
    "to",
    "in",
    "from",
    "by",
]  # Contains a bit more than just articles, but we want to get rid of these elements influencing the accuracy
PERIOD_STRIP = re.compile(r"(?!<=\d)(\.)(?!\d)")
COMMA_STRIP = re.compile(r"(\d)(\,)(\d)")
PUNCTUATION = [
    ";",
    r"/",
    "[",
    "]",
    '"',
    "{",
    "}",
    "(",
    ")",
    "=",
    "+",
    "\\",
    "_",
    "-",
    ">",
    "<",
    "@",
    "`",
    ",",
    "?",
    "!",
]


def normalize_answer(s: str) -> str:
    # process punctuation
    for p in PUNCTUATION:
        if (p + " " in s or " " + p in s) or (re.search(COMMA_STRIP, s) is not None):
            s = s.replace(p, "")
        else:
            s = s.replace(p, " ")
    # Note: re flags belong in re.compile, not sub(); a flag passed here would be
    # silently treated as a substitution count.
    s = PERIOD_STRIP.sub("", s)

    # process digits and articles
    temp_text = s.lower().split()
    out_text = []
    for word in temp_text:
        word = NUMBERS.get(word, word)  # .get avoids mutating the module-level dict
        if word not in ARTICLES:
            out_text.append(word)

    # standardize contractions
    for word_id, word in enumerate(out_text):
        if word in CONTRACTIONS:
            out_text[word_id] = CONTRACTIONS[word]
    return " ".join(out_text)


class DocVQAScoringFn(RegisteredBaseScoringFn):
    """
    DocVQA matches the generated answer against several allowed choices;
    both sides are normalized first so that trivial differences are not
    penalized.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.supported_fn_defs_registry = {
            docvqa.identifier: docvqa,
        }

    async def score_row(
        self,
        input_row: Dict[str, Any],
        scoring_fn_identifier: Optional[str] = "docvqa",
        scoring_params: Optional[ScoringFnParams] = None,
    ) -> ScoringResultRow:
        expected_answers = json.loads(input_row["expected_answer"])
        generated_answer = input_row["generated_answer"]
        score = 1.0 if normalize_answer(generated_answer) in [normalize_answer(s) for s in expected_answers] else 0.0
        return {
            "score": score,
        }
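To make the normalization concrete, a few illustrative inputs and outputs (examples mine, not from the source):

normalize_answer("Its 3 dogs!")   # -> "its 3 dogs"    (punctuation dropped)
normalize_answer("three dogs")    # -> "3 dogs"        (number words mapped)
normalize_answer("the dog")       # -> "dog"           (articles removed)
normalize_answer("dont know")     # -> "don't know"    (contraction standardized)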
@@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
    AggregationFunctionType,
    BasicScoringFnParams,
    ScoringFn,
)

docvqa = ScoringFn(
    identifier="basic::docvqa",
    description="DocVQA Visual Question & Answer scoring function",
    return_type=NumberType(),
    provider_id="basic",
    provider_resource_id="docvqa",
    params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.accuracy]),
)
@@ -0,0 +1,23 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
    AggregationFunctionType,
    BasicScoringFnParams,
    ScoringFn,
)

ifeval = ScoringFn(
    identifier="basic::ifeval",
    description="Evaluates instruction-following capacity by checking how many instructions can be followed in each example",
    return_type=NumberType(),
    provider_id="basic",
    provider_resource_id="ifeval",
    params=BasicScoringFnParams(
        aggregation_functions=[AggregationFunctionType.weighted_average],
    ),
)
@@ -0,0 +1,80 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any, Dict, Optional

from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn

from .fn_defs.ifeval import (
    ifeval,
)


class IfEvalScoringFn(RegisteredBaseScoringFn):
    """
    A scoring function for the Instruction-Following Eval (IFEval) benchmark.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.supported_fn_defs_registry = {
            ifeval.identifier: ifeval,
        }

    async def score_row(
        self,
        input_row: Dict[str, Any],
        scoring_fn_identifier: Optional[str] = None,
        scoring_params: Optional[ScoringFnParams] = None,
    ) -> ScoringResultRow:
        from ..utils.ifeval_utils import INSTRUCTION_DICT, INSTRUCTION_LIST

        assert scoring_fn_identifier is not None, "Scoring function identifier not found."
        fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
        if scoring_params is not None:
            fn_def.params = scoring_params

        instruction_list = input_row["instruction_id_list"]
        generated_answer = input_row["generated_answer"].strip()

        is_following_list = []
        results = dict(
            {k + "_correct": 0.0 for k in INSTRUCTION_LIST},
            **{k + "_total": 0.0 for k in INSTRUCTION_LIST},
        )

        for index, instruction_id in enumerate(instruction_list):
            instruction_cls = INSTRUCTION_DICT[instruction_id]
            instruction = instruction_cls(instruction_id)
            results[instruction_id + "_total"] += 1.0
            results[instruction_id.split(":")[0] + "_total"] += 1.0

            clean_input_row = {k: v for k, v in input_row["kwargs"][index].items() if v is not None}
            print(clean_input_row)
            instruction.build_description(**clean_input_row)
            args = instruction.get_instruction_args()
            if args and "prompt" in args:
                instruction.build_description(prompt=input_row["prompt"])

            if generated_answer and instruction.check_following(generated_answer):
                is_following_list.append(True)
                results[instruction_id + "_correct"] += 1.0
                results[instruction_id.split(":")[0] + "_correct"] += 1.0
            else:
                is_following_list.append(False)

        if len(is_following_list) == 0:
            return {
                "score": 0.0,
                "weight": 0.0,
            }

        return {
            "score": float(sum(is_following_list)) / float(len(is_following_list)),
            "weight": float(len(is_following_list)),
        }
llama_stack/providers/inline/scoring/basic/utils/ifeval_utils.py (new file, 3319 lines)
File diff suppressed because it is too large
@@ -6,12 +6,14 @@

from typing import Any, Dict

from llama_stack.distribution.datatypes import Api

from .config import TelemetryConfig, TelemetrySink

__all__ = ["TelemetryConfig", "TelemetrySink"]


async def get_provider_impl(config: TelemetryConfig, deps: Dict[str, Any]):
async def get_provider_impl(config: TelemetryConfig, deps: Dict[Api, Any]):
from .telemetry import TelemetryAdapter

impl = TelemetryAdapter(config, deps)
@@ -13,19 +13,20 @@ from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR


class TelemetrySink(str, Enum):
OTEL = "otel"
OTEL_TRACE = "otel_trace"
OTEL_METRIC = "otel_metric"
SQLITE = "sqlite"
CONSOLE = "console"


class TelemetryConfig(BaseModel):
otel_endpoint: str = Field(
otel_trace_endpoint: str = Field(
default="http://localhost:4318/v1/traces",
description="The OpenTelemetry collector endpoint URL",
description="The OpenTelemetry collector endpoint URL for traces",
)
service_name: str = Field(
default="llama-stack",
description="The service name to use for telemetry",
otel_metric_endpoint: str = Field(
default="http://localhost:4318/v1/metrics",
description="The OpenTelemetry collector endpoint URL for metrics",
)
sinks: List[TelemetrySink] = Field(
default=[TelemetrySink.CONSOLE, TelemetrySink.SQLITE],

@@ -46,7 +47,6 @@ class TelemetryConfig(BaseModel):
@classmethod
def sample_run_config(cls, __distro_dir__: str, db_name: str = "trace_store.db") -> Dict[str, Any]:
return {
"service_name": "${env.OTEL_SERVICE_NAME:llama-stack}",
"sinks": "${env.TELEMETRY_SINKS:console,sqlite}",
"sqlite_db_path": "${env.SQLITE_DB_PATH:" + __distro_dir__ + "/" + db_name + "}",
}
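With the sink split, a telemetry provider config that exports traces and metrics separately might look like this. A sketch only: the field names come from TelemetryConfig above, but the surrounding run-config structure and the sqlite path are assumptions:

telemetry_config = {
    "sinks": "otel_trace,otel_metric,sqlite",
    "otel_trace_endpoint": "http://localhost:4318/v1/traces",
    "otel_metric_endpoint": "http://localhost:4318/v1/metrics",
    "sqlite_db_path": "~/.llama/distributions/demo/trace_store.db",
}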
@@ -101,6 +101,6 @@ class ConsoleSpanProcessor(SpanProcessor):
"""Shutdown the processor."""
pass

def force_flush(self, timeout_millis: float = None) -> bool:
def force_flush(self, timeout_millis: float | None = None) -> bool:
"""Force flush any pending spans."""
return True
@@ -12,6 +12,7 @@ from datetime import datetime, timezone

from opentelemetry.sdk.trace import SpanProcessor
from opentelemetry.trace import Span
from opentelemetry.trace.span import format_span_id, format_trace_id


class SQLiteSpanProcessor(SpanProcessor):

@@ -100,14 +101,14 @@ class SQLiteSpanProcessor(SpanProcessor):
conn = self._get_connection()
cursor = conn.cursor()

trace_id = format(span.get_span_context().trace_id, "032x")
span_id = format(span.get_span_context().span_id, "016x")
trace_id = format_trace_id(span.get_span_context().trace_id)
span_id = format_span_id(span.get_span_context().span_id)
service_name = span.resource.attributes.get("service.name", "unknown")

parent_span_id = None
parent_context = span.parent
if parent_context:
parent_span_id = format(parent_context.span_id, "016x")
parent_span_id = format_span_id(parent_context.span_id)

# Insert into traces
cursor.execute(

@@ -123,7 +124,7 @@ class SQLiteSpanProcessor(SpanProcessor):
(
trace_id,
service_name,
(span_id if not parent_span_id else None),
(span_id if span.attributes.get("__root_span__") == "true" else None),
datetime.fromtimestamp(span.start_time / 1e9, timezone.utc).isoformat(),
datetime.fromtimestamp(span.end_time / 1e9, timezone.utc).isoformat(),
),
@@ -44,7 +44,7 @@ from llama_stack.providers.utils.telemetry.sqlite_trace_store import SQLiteTrace

from .config import TelemetryConfig, TelemetrySink

_GLOBAL_STORAGE = {
_GLOBAL_STORAGE: dict[str, dict[str | int, Any]] = {
"active_spans": {},
"counters": {},
"gauges": {},

@@ -54,30 +54,21 @@ _global_lock = threading.Lock()
_TRACER_PROVIDER = None


def string_to_trace_id(s: str) -> int:
# Convert the string to bytes and then to an integer
return int.from_bytes(s.encode(), byteorder="big", signed=False)


def string_to_span_id(s: str) -> int:
# Use only the first 8 bytes (64 bits) for span ID
return int.from_bytes(s.encode()[:8], byteorder="big", signed=False)


def is_tracing_enabled(tracer):
with tracer.start_as_current_span("check_tracing") as span:
return span.is_recording()


class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
def __init__(self, config: TelemetryConfig, deps: Dict[str, Any]) -> None:
def __init__(self, config: TelemetryConfig, deps: Dict[Api, Any]) -> None:
self.config = config
self.datasetio_api = deps.get(Api.datasetio)
self.meter = None

resource = Resource.create(
{
ResourceAttributes.SERVICE_NAME: self.config.service_name,
# service name is always the same, use zero-width space to avoid clutter
ResourceAttributes.SERVICE_NAME: "\u200b",
}
)

@@ -91,15 +82,16 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
provider = TracerProvider(resource=resource)
trace.set_tracer_provider(provider)
_TRACER_PROVIDER = provider
if TelemetrySink.OTEL in self.config.sinks:
otlp_exporter = OTLPSpanExporter(
endpoint=self.config.otel_endpoint,
if TelemetrySink.OTEL_TRACE in self.config.sinks:
span_exporter = OTLPSpanExporter(
endpoint=self.config.otel_trace_endpoint,
)
span_processor = BatchSpanProcessor(otlp_exporter)
span_processor = BatchSpanProcessor(span_exporter)
trace.get_tracer_provider().add_span_processor(span_processor)
if TelemetrySink.OTEL_METRIC in self.config.sinks:
metric_reader = PeriodicExportingMetricReader(
OTLPMetricExporter(
endpoint=self.config.otel_endpoint,
endpoint=self.config.otel_metric_endpoint,
)
)
metric_provider = MeterProvider(resource=resource, metric_readers=[metric_reader])

@@ -109,7 +101,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
if TelemetrySink.CONSOLE in self.config.sinks:
trace.get_tracer_provider().add_span_processor(ConsoleSpanProcessor())

if TelemetrySink.OTEL in self.config.sinks:
if TelemetrySink.OTEL_METRIC in self.config.sinks:
self.meter = metrics.get_meter(__name__)
if TelemetrySink.SQLITE in self.config.sinks:
self.trace_store = SQLiteTraceStore(self.config.sqlite_db_path)

@@ -135,7 +127,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
def _log_unstructured(self, event: UnstructuredLogEvent, ttl_seconds: int) -> None:
with self._lock:
# Use global storage instead of instance storage
span_id = string_to_span_id(event.span_id)
span_id = event.span_id
span = _GLOBAL_STORAGE["active_spans"].get(span_id)

if span:

@@ -146,7 +138,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
"message": event.message,
"severity": event.severity.value,
"__ttl__": ttl_seconds,
**event.attributes,
**(event.attributes or {}),
},
timestamp=timestamp_ns,
)

@@ -154,6 +146,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
print(f"Warning: No active span found for span_id {span_id}. Dropping event: {event}")

def _get_or_create_counter(self, name: str, unit: str) -> metrics.Counter:
assert self.meter is not None
if name not in _GLOBAL_STORAGE["counters"]:
_GLOBAL_STORAGE["counters"][name] = self.meter.create_counter(
name=name,

@@ -163,6 +156,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
return _GLOBAL_STORAGE["counters"][name]

def _get_or_create_gauge(self, name: str, unit: str) -> metrics.ObservableGauge:
assert self.meter is not None
if name not in _GLOBAL_STORAGE["gauges"]:
_GLOBAL_STORAGE["gauges"][name] = self.meter.create_gauge(
name=name,

@@ -182,6 +176,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
up_down_counter.add(event.value, attributes=event.attributes)

def _get_or_create_up_down_counter(self, name: str, unit: str) -> metrics.UpDownCounter:
assert self.meter is not None
if name not in _GLOBAL_STORAGE["up_down_counters"]:
_GLOBAL_STORAGE["up_down_counters"][name] = self.meter.create_up_down_counter(
name=name,

@@ -192,8 +187,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):

def _log_structured(self, event: StructuredLogEvent, ttl_seconds: int) -> None:
with self._lock:
span_id = string_to_span_id(event.span_id)
trace_id = string_to_trace_id(event.trace_id)
span_id = int(event.span_id, 16)
tracer = trace.get_tracer(__name__)
if event.attributes is None:
event.attributes = {}

@@ -204,14 +198,23 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
if span_id in _GLOBAL_STORAGE["active_spans"]:
return

parent_span = None
context = None
if event.payload.parent_span_id:
parent_span_id = string_to_span_id(event.payload.parent_span_id)
parent_span_id = int(event.payload.parent_span_id, 16)
parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id)

context = trace.Context(trace_id=trace_id)
if parent_span:
context = trace.set_span_in_context(parent_span, context)
context = trace.set_span_in_context(parent_span)
else:
context = trace.set_span_in_context(
trace.NonRecordingSpan(
trace.SpanContext(
trace_id=int(event.trace_id, 16),
span_id=span_id,
is_remote=False,
trace_flags=trace.TraceFlags(trace.TraceFlags.SAMPLED),
)
)
)
event.attributes["__root_span__"] = "true"

span = tracer.start_span(
name=event.payload.name,
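The move from string_to_span_id to int(..., 16) works because span and trace ids are now carried as the hex strings OpenTelemetry itself produces, so the conversion is a lossless round trip. A quick sketch (the id values are arbitrary examples):

from opentelemetry.trace.span import format_span_id, format_trace_id

span_id_hex = format_span_id(0x00F067AA0BA902B7)    # -> "00f067aa0ba902b7"
assert int(span_id_hex, 16) == 0x00F067AA0BA902B7   # lossless round trip
trace_id_hex = format_trace_id(0x4BF92F3577B34DA6A3CE929D0E0E4736)
assert int(trace_id_hex, 16) == 0x4BF92F3577B34DA6A3CE929D0E0E4736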
@@ -69,7 +69,7 @@ def popen_not_allowed(*args, **kwargs):
)


_subprocess.Popen = popen_not_allowed
_subprocess.Popen = popen_not_allowed  # type: ignore


import atexit as _atexit

@@ -104,7 +104,7 @@ def _open_connections():
return _NETWORK_CONNECTIONS


_builtins._open_connections = _open_connections
_builtins._open_connections = _open_connections  # type: ignore


@_atexit.register
@@ -161,9 +161,9 @@ _set_seeds()\
def process_matplotlib_response(response, matplotlib_dump_dir: str):
image_data = response["image_data"]
# Convert the base64 string to a bytes object
images = [base64.b64decode(d["image_base64"]) for d in image_data]
images_raw = [base64.b64decode(d["image_base64"]) for d in image_data]
# Create a list of PIL images from the bytes objects
images = [Image.open(BytesIO(img)) for img in images]
images = [Image.open(BytesIO(img)) for img in images_raw]
# Create a list of image paths
image_paths = []
for i, img in enumerate(images):
@@ -11,7 +11,7 @@ from llama_stack.providers.datatypes import Api
from .config import RagToolRuntimeConfig


async def get_provider_impl(config: RagToolRuntimeConfig, deps: Dict[str, Any]):
async def get_provider_impl(config: RagToolRuntimeConfig, deps: Dict[Api, Any]):
from .memory import MemoryToolRuntimeImpl

impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference])
@@ -15,6 +15,7 @@ from pydantic import TypeAdapter
from llama_stack.apis.common.content_types import (
URL,
InterleavedContent,
InterleavedContentItem,
TextContentItem,
)
from llama_stack.apis.inference import Inference

@@ -23,6 +24,7 @@ from llama_stack.apis.tools import (
RAGQueryConfig,
RAGQueryResult,
RAGToolRuntime,
Tool,
ToolDef,
ToolInvocationResult,
ToolParameter,

@@ -62,6 +64,12 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
async def shutdown(self):
pass

async def register_tool(self, tool: Tool) -> None:
pass

async def unregister_tool(self, tool_id: str) -> None:
return

async def insert(
self,
documents: List[RAGDocument],

@@ -121,11 +129,11 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
return RAGQueryResult(content=None)

# sort by score
chunks, scores = zip(*sorted(zip(chunks, scores, strict=False), key=lambda x: x[1], reverse=True), strict=False)
chunks, scores = zip(*sorted(zip(chunks, scores, strict=False), key=lambda x: x[1], reverse=True), strict=False)  # type: ignore
chunks = chunks[: query_config.max_chunks]

tokens = 0
picked = [
picked: list[InterleavedContentItem] = [
TextContentItem(
text=f"knowledge_search tool found {len(chunks)} chunks:\nBEGIN of knowledge_search tool results.\n"
)
@@ -15,11 +15,13 @@ import faiss
import numpy as np
from numpy.typing import NDArray

from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference.inference import Inference
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.vector_store import (
EmbeddingIndex,
VectorDBWithIndex,

@@ -35,16 +37,14 @@ FAISS_INDEX_PREFIX = f"faiss_index:{VERSION}::"


class FaissIndex(EmbeddingIndex):
chunk_by_index: Dict[int, str]

def __init__(self, dimension: int, kvstore=None, bank_id: str = None):
def __init__(self, dimension: int, kvstore: KVStore | None = None, bank_id: str | None = None):
self.index = faiss.IndexFlatL2(dimension)
self.chunk_by_index = {}
self.chunk_by_index: dict[int, Chunk] = {}
self.kvstore = kvstore
self.bank_id = bank_id

@classmethod
async def create(cls, dimension: int, kvstore=None, bank_id: str = None):
async def create(cls, dimension: int, kvstore: KVStore | None = None, bank_id: str | None = None):
instance = cls(dimension, kvstore, bank_id)
await instance.initialize()
return instance

@@ -114,11 +114,11 @@ class FaissIndex(EmbeddingIndex):


class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
def __init__(self, config: FaissVectorIOConfig, inference_api: Api.inference) -> None:
def __init__(self, config: FaissVectorIOConfig, inference_api: Inference) -> None:
self.config = config
self.inference_api = inference_api
self.cache = {}
self.kvstore = None
self.cache: dict[str, VectorDBWithIndex] = {}
self.kvstore: KVStore | None = None

async def initialize(self) -> None:
self.kvstore = await kvstore_impl(self.config.kvstore)

@@ -144,6 +144,8 @@ class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
self,
vector_db: VectorDB,
) -> None:
assert self.kvstore is not None

key = f"{VECTOR_DBS_PREFIX}{vector_db.identifier}"
await self.kvstore.set(
key=key,

@@ -161,6 +163,8 @@ class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
return [i.vector_db for i in self.cache.values()]

async def unregister_vector_db(self, vector_db_id: str) -> None:
assert self.kvstore is not None

if vector_db_id not in self.cache:
logger.warning(f"Vector DB {vector_db_id} not found")
return
@@ -15,9 +15,10 @@ import numpy as np
import sqlite_vec
from numpy.typing import NDArray

from llama_stack.apis.inference.inference import Inference
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
from llama_stack.providers.utils.memory.vector_store import EmbeddingIndex, VectorDBWithIndex

logger = logging.getLogger(__name__)

@@ -78,6 +79,8 @@ class SQLiteVecIndex(EmbeddingIndex):
        embedding (serialized to raw bytes) into the virtual table using the assigned rowid.
        If any insert fails, the transaction is rolled back to maintain consistency.
        """
        assert all(isinstance(chunk.content, str) for chunk in chunks), "SQLiteVecIndex only supports text chunks"

        cur = self.connection.cursor()
        try:
            # Start transaction

@@ -89,6 +92,7 @@ class SQLiteVecIndex(EmbeddingIndex):
                metadata_data = [
                    (generate_chunk_id(chunk.metadata["document_id"], chunk.content), chunk.model_dump_json())
                    for chunk in batch_chunks
                    if isinstance(chunk.content, str)
                ]
                # Insert metadata (ON CONFLICT to avoid duplicates)
                cur.executemany(

@@ -103,6 +107,7 @@ class SQLiteVecIndex(EmbeddingIndex):
                embedding_data = [
                    (generate_chunk_id(chunk.metadata["document_id"], chunk.content), serialize_vector(emb.tolist()))
                    for chunk, emb in zip(batch_chunks, batch_embeddings, strict=True)
                    if isinstance(chunk.content, str)
                ]
                # Insert embeddings in batch
                cur.executemany(f"INSERT INTO {self.vector_table} (id, embedding) VALUES (?, ?);", embedding_data)

@@ -154,7 +159,7 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
    and creates a cache of VectorDBWithIndex instances (each wrapping a SQLiteVecIndex).
    """

    def __init__(self, config, inference_api: Api.inference) -> None:
    def __init__(self, config, inference_api: Inference) -> None:
        self.config = config
        self.inference_api = inference_api
        self.cache: Dict[str, VectorDBWithIndex] = {}
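A standalone sketch of the insert pattern the docstring describes, using the stdlib `sqlite3` with an illustrative table name (not the provider's actual schema): batch the rows with `executemany` inside one explicit transaction, and roll back on any failure so metadata and embeddings stay consistent.

```python
import sqlite3

# isolation_level=None puts sqlite3 in autocommit mode so the explicit
# BEGIN/COMMIT/ROLLBACK below control the transaction directly.
conn = sqlite3.connect(":memory:", isolation_level=None)
conn.execute("CREATE TABLE chunk_metadata (id TEXT PRIMARY KEY, chunk TEXT)")

rows = [("id-1", '{"content": "a"}'), ("id-2", '{"content": "b"}')]
cur = conn.cursor()
try:
    cur.execute("BEGIN")
    # ON CONFLICT keeps re-inserts of the same chunk id from failing the batch.
    cur.executemany(
        "INSERT INTO chunk_metadata (id, chunk) VALUES (?, ?) "
        "ON CONFLICT(id) DO UPDATE SET chunk = excluded.chunk",
        rows,
    )
    cur.execute("COMMIT")
except sqlite3.Error:
    cur.execute("ROLLBACK")  # undo the whole batch on any error
    raise
```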
llama_stack/providers/registry/eval.py (new file, 29 lines)
@@ -0,0 +1,29 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import List

from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec


def available_providers() -> List[ProviderSpec]:
    return [
        InlineProviderSpec(
            api=Api.eval,
            provider_type="inline::meta-reference",
            pip_packages=["tree_sitter", "pythainlp", "langdetect", "emoji", "nltk"],
            module="llama_stack.providers.inline.eval.meta_reference",
            config_class="llama_stack.providers.inline.eval.meta_reference.MetaReferenceEvalConfig",
            api_dependencies=[
                Api.datasetio,
                Api.datasets,
                Api.scoring,
                Api.inference,
                Api.agents,
            ],
        ),
    ]
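For context, a self-contained sketch (stand-in dataclass and enum, not the real `llama_stack` types or resolver) of how a registry list like `available_providers()` above is typically consumed: collect the specs and index them by API.

```python
from dataclasses import dataclass, field
from enum import Enum


class Api(Enum):  # stand-in for llama_stack.providers.datatypes.Api
    eval = "eval"
    inference = "inference"


@dataclass
class ProviderSpec:  # stand-in carrying only the fields used here
    api: Api
    provider_type: str
    module: str
    pip_packages: list[str] = field(default_factory=list)


def available_providers() -> list[ProviderSpec]:
    return [
        ProviderSpec(
            api=Api.eval,
            provider_type="inline::meta-reference",
            module="llama_stack.providers.inline.eval.meta_reference",
        )
    ]


# Index the specs by API so the stack can resolve a provider for each API.
by_api = {spec.api: spec for spec in available_providers()}
print(by_api[Api.eval].provider_type)  # -> inline::meta-reference
```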
@@ -21,7 +21,7 @@ def available_providers() -> List[ProviderSpec]:
            api=Api.safety,
            provider_type="inline::prompt-guard",
            pip_packages=[
                "transformers",
                "transformers[accelerate]",
                "torch --index-url https://download.pytorch.org/whl/cpu",
            ],
            module="llama_stack.providers.inline.safety.prompt_guard",
@@ -42,9 +42,7 @@ from llama_stack.models.llama.datatypes import (
    TopKSamplingStrategy,
    TopPSamplingStrategy,
)
from llama_stack.providers.utils.inference.model_registry import (
    ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (
    process_chat_completion_stream_response,
)

@@ -293,14 +291,12 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
        if not tool_calls:
            return []

        for call in tool_calls:
            call_function_arguments = json.loads(call.function.arguments)

        compitable_tool_calls = [
            ToolCall(
                call_id=call.id,
                tool_name=call.function.name,
                arguments=call_function_arguments,
                arguments=json.loads(call.function.arguments),
                arguments_json=call.function.arguments,
            )
            for call in tool_calls
        ]
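The loop removed in this hunk (and in the matching vLLM hunk below) re-bound a single `call_function_arguments` variable on every iteration, so when a response contained multiple tool calls, every converted `ToolCall` was built with the last call's arguments. A minimal repro with plain dicts, independent of the Llama Stack types:

```python
import json

tool_calls = [
    {"id": "1", "name": "add", "arguments": '{"x": 1}'},
    {"id": "2", "name": "mul", "arguments": '{"x": 9}'},
]

# Buggy: args is reassigned on each pass, so every result gets {"x": 9}.
for call in tool_calls:
    args = json.loads(call["arguments"])
buggy = [{"id": c["id"], "arguments": args} for c in tool_calls]

# Fixed: parse per call, inside the comprehension.
fixed = [{"id": c["id"], "arguments": json.loads(c["arguments"])} for c in tool_calls]

assert buggy[0]["arguments"] == {"x": 9}  # wrong pairing
assert fixed[0]["arguments"] == {"x": 1}  # correct pairing
```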
@@ -90,15 +90,12 @@ def _convert_to_vllm_tool_calls_in_response(
    if not tool_calls:
        return []

    call_function_arguments = None
    for call in tool_calls:
        call_function_arguments = json.loads(call.function.arguments)

    return [
        ToolCall(
            call_id=call.id,
            tool_name=call.function.name,
            arguments=call_function_arguments,
            arguments=json.loads(call.function.arguments),
            arguments_json=call.function.arguments,
        )
        for call in tool_calls
    ]

@@ -183,6 +180,7 @@ async def _process_vllm_chat_completion_stream_response(
                            call_id=tool_call_buf.call_id,
                            tool_name=tool_call_buf.tool_name,
                            arguments=args,
                            arguments_json=args_str,
                        ),
                        parse_status=ToolCallParseStatus.succeeded,
                    ),
@@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
import base64
import io
from urllib.parse import unquote

@@ -13,12 +14,15 @@ import pandas
from llama_stack.providers.utils.memory.vector_store import parse_data_url


def get_dataframe_from_uri(uri: str):
async def get_dataframe_from_uri(uri: str):
    df = None
    if uri.endswith(".csv"):
        df = pandas.read_csv(uri)
        # Moving to its own thread to avoid IO from blocking the event loop.
        # This isn't ideal, as it moves more than just the IO to a new thread,
        # but it is as close as we can easily get.
        df = await asyncio.to_thread(pandas.read_csv, uri)
    elif uri.endswith(".xlsx"):
        df = pandas.read_excel(uri)
        df = await asyncio.to_thread(pandas.read_excel, uri)
    elif uri.startswith("data:"):
        parts = parse_data_url(uri)
        data = parts["data"]
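A runnable sketch of the `asyncio.to_thread` pattern adopted above, with `time.sleep` standing in for the blocking pandas readers: the blocking call runs on a worker thread while the event loop stays free to service other tasks.

```python
import asyncio
import time


def blocking_read() -> str:
    time.sleep(0.5)  # stands in for pandas.read_csv / pandas.read_excel
    return "dataframe"


async def main() -> None:
    result, _ = await asyncio.gather(
        asyncio.to_thread(blocking_read),  # runs off the event loop
        asyncio.sleep(0.1),                # still makes progress concurrently
    )
    print(result)


asyncio.run(main())
```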
@@ -529,7 +529,11 @@ async def convert_message_to_openai_dict_new(
) -> Union[str, Iterable[OpenAIChatCompletionContentPartParam]]:
    async def impl(
        content_: InterleavedContent,
    ) -> Union[str, OpenAIChatCompletionContentPartParam, List[OpenAIChatCompletionContentPartParam]]:
    ) -> Union[
        str,
        OpenAIChatCompletionContentPartParam,
        List[OpenAIChatCompletionContentPartParam],
    ]:
        # Llama Stack and OpenAI spec match for str and text input
        if isinstance(content_, str):
            return content_

@@ -570,7 +574,7 @@ async def convert_message_to_openai_dict_new(
                OpenAIChatCompletionMessageToolCall(
                    id=tool.call_id,
                    function=OpenAIFunction(
                        name=tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value,
                        name=(tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value),
                        arguments=json.dumps(tool.arguments),
                    ),
                    type="function",

@@ -609,6 +613,7 @@ def convert_tool_call(
            call_id=tool_call.id,
            tool_name=tool_call.function.name,
            arguments=json.loads(tool_call.function.arguments),
            arguments_json=tool_call.function.arguments,
        )
    except Exception:
        return UnparseableToolCall(

@@ -759,6 +764,7 @@ def _convert_openai_tool_calls(
            call_id=call.id,
            tool_name=call.function.name,
            arguments=json.loads(call.function.arguments),
            arguments_json=call.function.arguments,
        )
        for call in tool_calls
    ]

@@ -890,7 +896,8 @@ async def convert_openai_chat_completion_stream(
            # ChatCompletionResponseEvent only supports one per stream
            if len(choice.delta.tool_calls) > 1:
                warnings.warn(
                    "multiple tool calls found in a single delta, using the first, ignoring the rest", stacklevel=2
                    "multiple tool calls found in a single delta, using the first, ignoring the rest",
                    stacklevel=2,
                )

            if not enable_incremental_tool_calls:

@@ -971,6 +978,7 @@ async def convert_openai_chat_completion_stream(
                        call_id=buffer["call_id"],
                        tool_name=buffer["name"],
                        arguments=arguments,
                        arguments_json=buffer["arguments"],
                    )
                    yield ChatCompletionResponseStreamChunk(
                        event=ChatCompletionResponseEvent(
@@ -28,6 +28,17 @@ def aggregate_average(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
    }


def aggregate_weighted_average(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
    return {
        "weighted_average": sum(
            result["score"] * result["weight"]
            for result in scoring_results
            if result["score"] is not None and result["weight"] is not None
        )
        / sum(result["weight"] for result in scoring_results if result["weight"] is not None),
    }


def aggregate_categorical_count(
    scoring_results: List[ScoringResultRow],
) -> Dict[str, Any]:

@@ -46,6 +57,7 @@ def aggregate_median(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
AGGREGATION_FUNCTIONS = {
    AggregationFunctionType.accuracy: aggregate_accuracy,
    AggregationFunctionType.average: aggregate_average,
    AggregationFunctionType.weighted_average: aggregate_weighted_average,
    AggregationFunctionType.categorical_count: aggregate_categorical_count,
    AggregationFunctionType.median: aggregate_median,
}
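A worked example of the `aggregate_weighted_average` formula added above: `sum(score * weight) / sum(weight)`. Note the asymmetry as written in the hunk: the numerator skips rows missing either field, while the denominator only skips rows missing `weight`, so a row with a weight but no score still dilutes the average.

```python
rows = [
    {"score": 1.0, "weight": 3},
    {"score": 0.5, "weight": 1},
    {"score": None, "weight": 4},  # counted in the denominator only
]

num = sum(
    r["score"] * r["weight"]
    for r in rows
    if r["score"] is not None and r["weight"] is not None
)
den = sum(r["weight"] for r in rows if r["weight"] is not None)

print(num / den)  # (1.0*3 + 0.5*1) / (3 + 1 + 4) = 3.5 / 8 = 0.4375
```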
@@ -13,7 +13,7 @@ from llama_stack.apis.telemetry import QueryCondition, QuerySpansResponse, Span
class TelemetryDatasetMixin:
    """Mixin class that provides dataset-related functionality for telemetry providers."""

    datasetio_api: DatasetIO
    datasetio_api: DatasetIO | None

    async def save_spans_to_dataset(
        self,
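With the annotation loosened to `DatasetIO | None`, the mixin's methods need a None guard before touching the API. A hypothetical sketch of that shape (the placeholder type and error message are illustrative, not from the source):

```python
class TelemetryDatasetMixinSketch:
    # None when no DatasetIO provider is configured for this stack.
    datasetio_api: object | None = None

    async def save_spans_to_dataset(self) -> None:
        if self.datasetio_api is None:
            raise ValueError("DatasetIO API is not available")
        # ... proceed with self.datasetio_api, now known to be non-None
```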
@@ -5,12 +5,11 @@
# the root directory of this source tree.

import asyncio
import base64
import contextvars
import logging
import queue
import random
import threading
import uuid
from datetime import datetime, timezone
from functools import wraps
from typing import Any, Callable, Dict, List, Optional

@@ -31,11 +30,44 @@ from llama_stack.providers.utils.telemetry.trace_protocol import serialize_value
logger = get_logger(__name__, category="core")


def generate_short_uuid(len: int = 8):
    full_uuid = uuid.uuid4()
    uuid_bytes = full_uuid.bytes
    encoded = base64.urlsafe_b64encode(uuid_bytes)
    return encoded.rstrip(b"=").decode("ascii")[:len]
INVALID_SPAN_ID = 0x0000000000000000
INVALID_TRACE_ID = 0x00000000000000000000000000000000


def trace_id_to_str(trace_id: int) -> str:
    """Convenience trace ID formatting method
    Args:
        trace_id: Trace ID int

    Returns:
        The trace ID as 32-byte hexadecimal string
    """
    return format(trace_id, "032x")


def span_id_to_str(span_id: int) -> str:
    """Convenience span ID formatting method
    Args:
        span_id: Span ID int

    Returns:
        The span ID as 16-byte hexadecimal string
    """
    return format(span_id, "016x")


def generate_span_id() -> str:
    span_id = random.getrandbits(64)
    while span_id == INVALID_SPAN_ID:
        span_id = random.getrandbits(64)
    return span_id_to_str(span_id)


def generate_trace_id() -> str:
    trace_id = random.getrandbits(128)
    while trace_id == INVALID_TRACE_ID:
        trace_id = random.getrandbits(128)
    return trace_id_to_str(trace_id)


CURRENT_TRACE_CONTEXT = contextvars.ContextVar("trace_context", default=None)

@@ -83,7 +115,7 @@ class TraceContext:
    def push_span(self, name: str, attributes: Dict[str, Any] = None) -> Span:
        current_span = self.get_current_span()
        span = Span(
            span_id=generate_short_uuid(),
            span_id=generate_span_id(),
            trace_id=self.trace_id,
            name=name,
            start_time=datetime.now(timezone.utc),

@@ -143,7 +175,7 @@ async def start_trace(name: str, attributes: Dict[str, Any] = None) -> TraceContext:
        logger.debug("No Telemetry implementation set. Skipping trace initialization...")
        return

    trace_id = generate_short_uuid(16)
    trace_id = generate_trace_id()
    context = TraceContext(BACKGROUND_LOGGER, trace_id)
    context.push_span(name, {"__root__": True, **(attributes or {})})
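A quick demonstration of what the new ID helpers produce: 128-bit trace IDs render as 32 hex characters and 64-bit span IDs as 16, matching the OpenTelemetry wire format, with the all-zero values reserved as invalid (hence the regeneration loops above).

```python
import random

span_id = random.getrandbits(64)
trace_id = random.getrandbits(128)

print(format(span_id, "016x"))   # e.g. '9f86d081884c7d65'
print(format(trace_id, "032x"))  # e.g. '4bf92f3577b34da6a3ce929d0e0e4736'

# Zero-padded to fixed width; the all-zero IDs themselves are invalid,
# which is why the generators re-roll until they get a nonzero value.
assert len(format(0, "032x")) == 32
```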
@@ -5,7 +5,7 @@
# the root directory of this source tree.

from dataclasses import dataclass
from typing import Any, Callable, List, Optional, Protocol, TypeVar
from typing import Any, Callable, List, Optional, TypeVar

from .strong_typing.schema import json_schema_type, register_schema  # noqa: F401

@@ -18,13 +18,11 @@ class WebMethod:
    response_examples: Optional[List[Any]] = None
    method: Optional[str] = None
    raw_bytes_request_body: Optional[bool] = False
    # A descriptive name of the corresponding span created by tracing
    descriptive_name: Optional[str] = None


class HasWebMethod(Protocol):
    __webmethod__: WebMethod


T = TypeVar("T", bound=HasWebMethod)  # Bound T to classes that match this protocol
T = TypeVar("T", bound=Callable[..., Any])


def webmethod(

@@ -34,6 +32,7 @@ def webmethod(
    request_examples: Optional[List[Any]] = None,
    response_examples: Optional[List[Any]] = None,
    raw_bytes_request_body: Optional[bool] = False,
    descriptive_name: Optional[str] = None,
) -> Callable[[T], T]:
    """
    Decorator that supplies additional metadata to an endpoint operation function.

@@ -44,15 +43,16 @@ def webmethod(
    :param response_examples: Sample responses that the operation might produce. Pass a list of objects, not JSON.
    """

    def wrap(cls: T) -> T:
        cls.__webmethod__ = WebMethod(
    def wrap(func: T) -> T:
        func.__webmethod__ = WebMethod(  # type: ignore
            route=route,
            method=method,
            public=public or False,
            request_examples=request_examples,
            response_examples=response_examples,
            raw_bytes_request_body=raw_bytes_request_body,
            descriptive_name=descriptive_name,
        )
        return cls
        return func

    return wrap
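A compact sketch of the reworked decorator's behavior: it now wraps functions rather than protocol classes, stamping a `WebMethod` record onto the function for tooling (routing, tracing) to read back. The `WebMethod` fields below are trimmed to the ones used here, and the route mirrors the `create_agent` endpoint seen earlier in this diff.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional, TypeVar

T = TypeVar("T", bound=Callable[..., Any])


@dataclass
class WebMethod:  # trimmed stand-in for the full dataclass
    route: Optional[str] = None
    method: Optional[str] = None
    descriptive_name: Optional[str] = None


def webmethod(route=None, method=None, descriptive_name=None) -> Callable[[T], T]:
    def wrap(func: T) -> T:
        # Attach metadata directly to the function object.
        func.__webmethod__ = WebMethod(  # type: ignore[attr-defined]
            route=route, method=method, descriptive_name=descriptive_name
        )
        return func

    return wrap


@webmethod(route="/agents", method="POST", descriptive_name="create_agent")
async def create_agent() -> None: ...


print(create_agent.__webmethod__.descriptive_name)  # -> create_agent
```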
@@ -17,6 +17,7 @@ import enum
import functools
import inspect
import json
import types
import typing
import uuid
from copy import deepcopy

@@ -455,7 +456,7 @@ class JsonSchemaGenerator:
                "maxItems": len(args),
                "prefixItems": [self.type_to_schema(member_type) for member_type in args],
            }
        elif origin_type is Union:
        elif origin_type in (Union, types.UnionType):
            discriminator = None
            if typing.get_origin(data_type) is Annotated:
                discriminator = typing.get_args(data_type)[1].discriminator
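Why the schema generator now checks both origins: on Python 3.10+, PEP 604 unions written with `|` report `types.UnionType` from `typing.get_origin`, while `typing.Union[...]` reports `typing.Union`, and the two are distinct objects.

```python
import types
import typing

assert typing.get_origin(typing.Union[int, str]) is typing.Union
assert typing.get_origin(int | str) is types.UnionType
assert typing.get_origin(int | str) is not typing.Union

# The membership test used in the hunk accepts either spelling of a union.
origin = typing.get_origin(int | str)
print(origin in (typing.Union, types.UnionType))  # -> True
```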
@@ -78,7 +78,7 @@ def get_distribution_template() -> DistributionTemplate:
        },
        run_config_env_vars={
            "LLAMA_STACK_PORT": (
                "5001",
                "8321",
                "Port for the Llama Stack distribution server",
            ),
        },
@@ -47,9 +47,10 @@ You can do this via Conda (build code) or Docker which has a pre-built image.
This method allows you to get started quickly without having to build the distribution code.

```bash
LLAMA_STACK_PORT=5001
LLAMA_STACK_PORT=8321
docker run \
  -it \
  --pull always \
  -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
  llamastack/distribution-{{ name }} \
  --port $LLAMA_STACK_PORT \
@@ -37,7 +37,6 @@ providers:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
      sinks: ${env.TELEMETRY_SINKS:console,sqlite}
      sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/bedrock/trace_store.db}
  datasetio:
@@ -102,7 +102,7 @@ def get_distribution_template() -> DistributionTemplate:
        },
        run_config_env_vars={
            "LLAMA_STACK_PORT": (
                "5001",
                "8321",
                "Port for the Llama Stack distribution server",
            ),
            "CEREBRAS_API_KEY": (
@@ -39,9 +39,10 @@ You can do this via Conda (build code) or Docker which has a pre-built image.
This method allows you to get started quickly without having to build the distribution code.

```bash
LLAMA_STACK_PORT=5001
LLAMA_STACK_PORT=8321
docker run \
  -it \
  --pull always \
  -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
  -v ./run.yaml:/root/my-run.yaml \
  llamastack/distribution-{{ name }} \

@@ -55,6 +56,6 @@ docker run \
```bash
llama stack build --template cerebras --image-type conda
llama stack run ./run.yaml \
  --port 5001 \
  --port 8321 \
  --env CEREBRAS_API_KEY=$CEREBRAS_API_KEY
```
@@ -58,7 +58,6 @@ providers:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
      sinks: ${env.TELEMETRY_SINKS:console,sqlite}
      sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/cerebras/trace_store.db}
  tool_runtime:
@@ -108,7 +108,7 @@ def get_distribution_template() -> DistributionTemplate:
        },
        run_config_env_vars={
            "LLAMA_STACK_PORT": (
                "5001",
                "8321",
                "Port for the Llama Stack distribution server",
            ),
            "FIREWORKS_API_KEY": (
@@ -40,7 +40,6 @@ providers:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
      sinks: ${env.TELEMETRY_SINKS:console,sqlite}
      sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/ci-tests/trace_store.db}
  datasetio:
@@ -43,6 +43,7 @@ export CUDA_VISIBLE_DEVICES=0
export LLAMA_STACK_PORT=8321

docker run --rm -it \
  --pull always \
  --network host \
  -v $HOME/.cache/huggingface:/data \
  -e HF_TOKEN=$HF_TOKEN \

@@ -66,6 +67,7 @@ export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
export CUDA_VISIBLE_DEVICES=1

docker run --rm -it \
  --pull always \
  --network host \
  -v $HOME/.cache/huggingface:/data \
  -e HF_TOKEN=$HF_TOKEN \

@@ -108,6 +110,7 @@ This method allows you to get started quickly without having to build the distribution code.

```bash
docker run -it \
  --pull always \
  --network host \
  -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
  -v $HOME/.llama:/root/.llama \

@@ -135,6 +138,7 @@ export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B

docker run \
  -it \
  --pull always \
  -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
  -v $HOME/.llama:/root/.llama \
  -v ./llama_stack/templates/tgi/run-with-safety.yaml:/root/my-run.yaml \
@@ -43,7 +43,6 @@ providers:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
      sinks: ${env.TELEMETRY_SINKS:console,sqlite}
      sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/dell/trace_store.db}
  datasetio:

@@ -39,7 +39,6 @@ providers:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
      sinks: ${env.TELEMETRY_SINKS:console,sqlite}
      sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/dell/trace_store.db}
  datasetio:
@@ -184,7 +184,7 @@ def get_distribution_template() -> DistributionTemplate:
        },
        run_config_env_vars={
            "LLAMA_STACK_PORT": (
                "5001",
                "8321",
                "Port for the Llama Stack distribution server",
            ),
            "FIREWORKS_API_KEY": (
@@ -69,7 +69,6 @@ providers:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
      sinks: ${env.TELEMETRY_SINKS:console,sqlite}
      sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/dev/trace_store.db}
  datasetio:
@@ -28,7 +28,11 @@ providers:
  eval:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config: {}
    config:
      kvstore:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/meta_reference_eval.db
  scoring:
  - provider_id: basic
    provider_type: inline::basic

@@ -40,7 +44,11 @@ providers:
  datasetio:
  - provider_id: localfs
    provider_type: inline::localfs
    config: {}
    config:
      kvstore:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/experimental-post-training}/localfs_datasetio.db
  telemetry:
  - provider_id: meta-reference
    provider_type: inline::meta-reference

@@ -58,7 +66,7 @@ providers:
    persistence_store:
      type: sqlite
      namespace: null
      db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/agents_store.db
      db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/experimental-post-training}/agents_store.db
  safety:
  - provider_id: llama-guard
    provider_type: inline::llama-guard

@@ -70,7 +78,7 @@ providers:
    kvstore:
      type: sqlite
      namespace: null
      db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/faiss_store.db
      db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/experimental-post-training}/faiss_store.db
  tool_runtime:
  - provider_id: brave-search
    provider_type: remote::brave-search

@@ -82,7 +90,7 @@ providers:
metadata_store:
  namespace: null
  type: sqlite
  db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/registry.db
  db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/experimental-post-training}/registry.db
models: []
shields: []
vector_dbs: []
@@ -49,9 +49,10 @@ You can do this via Conda (build code) or Docker which has a pre-built image.
This method allows you to get started quickly without having to build the distribution code.

```bash
LLAMA_STACK_PORT=5001
LLAMA_STACK_PORT=8321
docker run \
  -it \
  --pull always \
  -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
  llamastack/distribution-{{ name }} \
  --port $LLAMA_STACK_PORT \
@@ -160,7 +160,7 @@ def get_distribution_template() -> DistributionTemplate:
        },
        run_config_env_vars={
            "LLAMA_STACK_PORT": (
                "5001",
                "8321",
                "Port for the Llama Stack distribution server",
            ),
            "FIREWORKS_API_KEY": (
@@ -48,7 +48,6 @@ providers:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
      sinks: ${env.TELEMETRY_SINKS:console,sqlite}
      sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/fireworks/trace_store.db}
  datasetio:

@@ -43,7 +43,6 @@ providers:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
      sinks: ${env.TELEMETRY_SINKS:console,sqlite}
      sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/fireworks/trace_store.db}
  datasetio:
@@ -49,9 +49,10 @@ You can do this via Conda (build code) or Docker which has a pre-built image.
This method allows you to get started quickly without having to build the distribution code.

```bash
LLAMA_STACK_PORT=5001
LLAMA_STACK_PORT=8321
docker run \
  -it \
  --pull always \
  -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
  llamastack/distribution-{{ name }} \
  --port $LLAMA_STACK_PORT \
@@ -95,7 +95,7 @@ def get_distribution_template() -> DistributionTemplate:
        },
        run_config_env_vars={
            "LLAMASTACK_PORT": (
                "5001",
                "8321",
                "Port for the Llama Stack distribution server",
            ),
            "GROQ_API_KEY": (
@@ -43,7 +43,6 @@ providers:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
      sinks: ${env.TELEMETRY_SINKS:console,sqlite}
      sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/groq/trace_store.db}
  datasetio:
@@ -125,7 +125,7 @@ def get_distribution_template() -> DistributionTemplate:
        },
        run_config_env_vars={
            "LLAMA_STACK_PORT": (
                "5001",
                "8321",
                "Port for the Llama Stack distribution server",
            ),
            "HF_API_TOKEN": (
@@ -48,7 +48,6 @@ providers:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
      sinks: ${env.TELEMETRY_SINKS:console,sqlite}
      sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/hf-endpoint/trace_store.db}
  datasetio:

@@ -43,7 +43,6 @@ providers:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
      sinks: ${env.TELEMETRY_SINKS:console,sqlite}
      sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/hf-endpoint/trace_store.db}
  datasetio:
@@ -126,7 +126,7 @@ def get_distribution_template() -> DistributionTemplate:
        },
        run_config_env_vars={
            "LLAMA_STACK_PORT": (
                "5001",
                "8321",
                "Port for the Llama Stack distribution server",
            ),
            "HF_API_TOKEN": (
@@ -48,7 +48,6 @@ providers:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
      sinks: ${env.TELEMETRY_SINKS:console,sqlite}
      sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/hf-serverless/trace_store.db}
  datasetio:

@@ -43,7 +43,6 @@ providers:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
      sinks: ${env.TELEMETRY_SINKS:console,sqlite}
      sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/hf-serverless/trace_store.db}
  datasetio:
@@ -65,9 +65,10 @@ You can do this via Conda (build code) or Docker which has a pre-built image.
This method allows you to get started quickly without having to build the distribution code.

```bash
LLAMA_STACK_PORT=5001
LLAMA_STACK_PORT=8321
docker run \
  -it \
  --pull always \
  -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
  -v ~/.llama:/root/.llama \
  llamastack/distribution-{{ name }} \

@@ -80,6 +81,7 @@ If you are using Llama Stack Safety / Shield APIs, use:

```bash
docker run \
  -it \
  --pull always \
  -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
  -v ~/.llama:/root/.llama \
  llamastack/distribution-{{ name }} \

@@ -95,7 +97,7 @@ Make sure you have done `uv pip install llama-stack` and have the Llama Stack CLI available.

```bash
llama stack build --template {{ name }} --image-type conda
llama stack run distributions/{{ name }}/run.yaml \
  --port 5001 \
  --port 8321 \
  --env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
```

@@ -103,7 +105,7 @@ If you are using Llama Stack Safety / Shield APIs, use:

```bash
llama stack run distributions/{{ name }}/run-with-safety.yaml \
  --port 5001 \
  --port 8321 \
  --env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
  --env SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
```
@@ -132,7 +132,7 @@ def get_distribution_template() -> DistributionTemplate:
        },
        run_config_env_vars={
            "LLAMA_STACK_PORT": (
                "5001",
                "8321",
                "Port for the Llama Stack distribution server",
            ),
            "INFERENCE_MODEL": (
@@ -50,7 +50,6 @@ providers:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
      sinks: ${env.TELEMETRY_SINKS:console,sqlite}
      sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/meta-reference-gpu/trace_store.db}
  datasetio:

@@ -44,7 +44,6 @@ providers:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
      sinks: ${env.TELEMETRY_SINKS:console,sqlite}
      sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/meta-reference-gpu/trace_store.db}
  datasetio:
@@ -67,9 +67,10 @@ You can do this via Conda (build code) or Docker which has a pre-built image.
This method allows you to get started quickly without having to build the distribution code.

```bash
LLAMA_STACK_PORT=5001
LLAMA_STACK_PORT=8321
docker run \
  -it \
  --pull always \
  -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
  -v ~/.llama:/root/.llama \
  llamastack/distribution-{{ name }} \

@@ -82,6 +83,7 @@ If you are using Llama Stack Safety / Shield APIs, use:

```bash
docker run \
  -it \
  --pull always \
  -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
  -v ~/.llama:/root/.llama \
  llamastack/distribution-{{ name }} \
@@ -98,7 +98,7 @@ def get_distribution_template() -> DistributionTemplate:
        },
        run_config_env_vars={
            "LLAMA_STACK_PORT": (
                "5001",
                "8321",
                "Port for the Llama Stack distribution server",
            ),
            "INFERENCE_MODEL": (
@@ -46,7 +46,6 @@ providers:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
      sinks: ${env.TELEMETRY_SINKS:console,sqlite}
      sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/meta-reference-quantized-gpu/trace_store.db}
  datasetio:
@@ -39,9 +39,10 @@ You can do this via Conda (build code) or Docker which has a pre-built image.
This method allows you to get started quickly without having to build the distribution code.

```bash
LLAMA_STACK_PORT=5001
LLAMA_STACK_PORT=8321
docker run \
  -it \
  --pull always \
  -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
  -v ./run.yaml:/root/my-run.yaml \
  llamastack/distribution-{{ name }} \

@@ -55,7 +56,7 @@ docker run \

```bash
llama stack build --template nvidia --image-type conda
llama stack run ./run.yaml \
  --port 5001 \
  --port 8321 \
  --env NVIDIA_API_KEY=$NVIDIA_API_KEY
  --env INFERENCE_MODEL=$INFERENCE_MODEL
```
@@ -46,7 +46,6 @@ providers:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
      sinks: ${env.TELEMETRY_SINKS:console,sqlite}
      sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/nvidia/trace_store.db}
  datasetio:
Some files were not shown because too many files have changed in this diff.