forked from phoenix-oss/llama-stack-mirror
merge
This commit is contained in:
commit
a54d757ade
197 changed files with 9392 additions and 3089 deletions
|
@ -36,7 +36,6 @@ from llama_stack.apis.inference import (
|
|||
)
|
||||
from llama_stack.apis.safety import SafetyViolation
|
||||
from llama_stack.apis.tools import ToolDef
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||
|
||||
|
||||
|
@ -189,13 +188,11 @@ class AgentToolGroupWithArgs(BaseModel):
|
|||
args: Dict[str, Any]
|
||||
|
||||
|
||||
AgentToolGroup = register_schema(
|
||||
Union[
|
||||
str,
|
||||
AgentToolGroupWithArgs,
|
||||
],
|
||||
name="AgentTool",
|
||||
)
|
||||
AgentToolGroup = Union[
|
||||
str,
|
||||
AgentToolGroupWithArgs,
|
||||
]
|
||||
register_schema(AgentToolGroup, name="AgentTool")
|
||||
|
||||
|
||||
class AgentConfigCommon(BaseModel):
|
||||
|
@ -312,20 +309,18 @@ class AgentTurnResponseTurnAwaitingInputPayload(BaseModel):
|
|||
turn: Turn
|
||||
|
||||
|
||||
AgentTurnResponseEventPayload = register_schema(
|
||||
Annotated[
|
||||
Union[
|
||||
AgentTurnResponseStepStartPayload,
|
||||
AgentTurnResponseStepProgressPayload,
|
||||
AgentTurnResponseStepCompletePayload,
|
||||
AgentTurnResponseTurnStartPayload,
|
||||
AgentTurnResponseTurnCompletePayload,
|
||||
AgentTurnResponseTurnAwaitingInputPayload,
|
||||
],
|
||||
Field(discriminator="event_type"),
|
||||
AgentTurnResponseEventPayload = Annotated[
|
||||
Union[
|
||||
AgentTurnResponseStepStartPayload,
|
||||
AgentTurnResponseStepProgressPayload,
|
||||
AgentTurnResponseStepCompletePayload,
|
||||
AgentTurnResponseTurnStartPayload,
|
||||
AgentTurnResponseTurnCompletePayload,
|
||||
AgentTurnResponseTurnAwaitingInputPayload,
|
||||
],
|
||||
name="AgentTurnResponseEventPayload",
|
||||
)
|
||||
Field(discriminator="event_type"),
|
||||
]
|
||||
register_schema(AgentTurnResponseEventPayload, name="AgentTurnResponseEventPayload")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -387,7 +382,6 @@ class AgentStepResponse(BaseModel):
|
|||
|
||||
|
||||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class Agents(Protocol):
|
||||
"""Agents API for creating and interacting with agentic systems.
|
||||
|
||||
|
@ -399,7 +393,7 @@ class Agents(Protocol):
|
|||
- Agents can also use Memory to retrieve information from knowledge bases. See the RAG Tool and Vector IO APIs for more details.
|
||||
"""
|
||||
|
||||
@webmethod(route="/agents", method="POST")
|
||||
@webmethod(route="/agents", method="POST", descriptive_name="create_agent")
|
||||
async def create_agent(
|
||||
self,
|
||||
agent_config: AgentConfig,
|
||||
|
@ -411,7 +405,9 @@ class Agents(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/agents/{agent_id}/session/{session_id}/turn", method="POST")
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}/session/{session_id}/turn", method="POST", descriptive_name="create_agent_turn"
|
||||
)
|
||||
async def create_agent_turn(
|
||||
self,
|
||||
agent_id: str,
|
||||
|
@ -443,6 +439,7 @@ class Agents(Protocol):
|
|||
@webmethod(
|
||||
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume",
|
||||
method="POST",
|
||||
descriptive_name="resume_agent_turn",
|
||||
)
|
||||
async def resume_agent_turn(
|
||||
self,
|
||||
|
@ -505,7 +502,7 @@ class Agents(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/agents/{agent_id}/session", method="POST")
|
||||
@webmethod(route="/agents/{agent_id}/session", method="POST", descriptive_name="create_agent_session")
|
||||
async def create_agent_session(
|
||||
self,
|
||||
agent_id: str,
|
||||
|
|
|
@ -63,19 +63,15 @@ class TextContentItem(BaseModel):
|
|||
|
||||
|
||||
# other modalities can be added here
|
||||
InterleavedContentItem = register_schema(
|
||||
Annotated[
|
||||
Union[ImageContentItem, TextContentItem],
|
||||
Field(discriminator="type"),
|
||||
],
|
||||
name="InterleavedContentItem",
|
||||
)
|
||||
InterleavedContentItem = Annotated[
|
||||
Union[ImageContentItem, TextContentItem],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(InterleavedContentItem, name="InterleavedContentItem")
|
||||
|
||||
# accept a single "str" as a special case since it is common
|
||||
InterleavedContent = register_schema(
|
||||
Union[str, InterleavedContentItem, List[InterleavedContentItem]],
|
||||
name="InterleavedContent",
|
||||
)
|
||||
InterleavedContent = Union[str, InterleavedContentItem, List[InterleavedContentItem]]
|
||||
register_schema(InterleavedContent, name="InterleavedContent")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -109,10 +105,8 @@ class ToolCallDelta(BaseModel):
|
|||
|
||||
|
||||
# streaming completions send a stream of ContentDeltas
|
||||
ContentDelta = register_schema(
|
||||
Annotated[
|
||||
Union[TextDelta, ImageDelta, ToolCallDelta],
|
||||
Field(discriminator="type"),
|
||||
],
|
||||
name="ContentDelta",
|
||||
)
|
||||
ContentDelta = Annotated[
|
||||
Union[TextDelta, ImageDelta, ToolCallDelta],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(ContentDelta, name="ContentDelta")
|
||||
|
|
|
@ -72,24 +72,22 @@ class DialogType(BaseModel):
|
|||
type: Literal["dialog"] = "dialog"
|
||||
|
||||
|
||||
ParamType = register_schema(
|
||||
Annotated[
|
||||
Union[
|
||||
StringType,
|
||||
NumberType,
|
||||
BooleanType,
|
||||
ArrayType,
|
||||
ObjectType,
|
||||
JsonType,
|
||||
UnionType,
|
||||
ChatCompletionInputType,
|
||||
CompletionInputType,
|
||||
AgentTurnInputType,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
ParamType = Annotated[
|
||||
Union[
|
||||
StringType,
|
||||
NumberType,
|
||||
BooleanType,
|
||||
ArrayType,
|
||||
ObjectType,
|
||||
JsonType,
|
||||
UnionType,
|
||||
ChatCompletionInputType,
|
||||
CompletionInputType,
|
||||
AgentTurnInputType,
|
||||
],
|
||||
name="ParamType",
|
||||
)
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(ParamType, name="ParamType")
|
||||
|
||||
"""
|
||||
# TODO: recursive definition of ParamType in these containers
|
||||
|
|
|
@ -84,13 +84,11 @@ class RowsDataSource(BaseModel):
|
|||
rows: List[Dict[str, Any]]
|
||||
|
||||
|
||||
DataSource = register_schema(
|
||||
Annotated[
|
||||
Union[URIDataSource, RowsDataSource],
|
||||
Field(discriminator="type"),
|
||||
],
|
||||
name="DataSource",
|
||||
)
|
||||
DataSource = Annotated[
|
||||
Union[URIDataSource, RowsDataSource],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(DataSource, name="DataSource")
|
||||
|
||||
|
||||
class CommonDatasetFields(BaseModel):
|
||||
|
@ -121,8 +119,6 @@ class Dataset(CommonDatasetFields, Resource):
|
|||
|
||||
class DatasetInput(CommonDatasetFields, BaseModel):
|
||||
dataset_id: str
|
||||
provider_id: Optional[str] = None
|
||||
provider_dataset_id: Optional[str] = None
|
||||
|
||||
|
||||
class ListDatasetsResponse(BaseModel):
|
||||
|
|
144
llama_stack/apis/eval/eval.py
Normal file
144
llama_stack/apis/eval/eval.py
Normal file
|
@ -0,0 +1,144 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, List, Literal, Optional, Protocol, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.apis.agents import AgentConfig
|
||||
from llama_stack.apis.common.job_types import Job
|
||||
from llama_stack.apis.inference import SamplingParams, SystemMessage
|
||||
from llama_stack.apis.scoring import ScoringResult
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ModelCandidate(BaseModel):
|
||||
"""A model candidate for evaluation.
|
||||
|
||||
:param model: The model ID to evaluate.
|
||||
:param sampling_params: The sampling parameters for the model.
|
||||
:param system_message: (Optional) The system message providing instructions or context to the model.
|
||||
"""
|
||||
|
||||
type: Literal["model"] = "model"
|
||||
model: str
|
||||
sampling_params: SamplingParams
|
||||
system_message: Optional[SystemMessage] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentCandidate(BaseModel):
|
||||
"""An agent candidate for evaluation.
|
||||
|
||||
:param config: The configuration for the agent candidate.
|
||||
"""
|
||||
|
||||
type: Literal["agent"] = "agent"
|
||||
config: AgentConfig
|
||||
|
||||
|
||||
EvalCandidate = Annotated[Union[ModelCandidate, AgentCandidate], Field(discriminator="type")]
|
||||
register_schema(EvalCandidate, name="EvalCandidate")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class BenchmarkConfig(BaseModel):
|
||||
"""A benchmark configuration for evaluation.
|
||||
|
||||
:param eval_candidate: The candidate to evaluate.
|
||||
:param scoring_params: Map between scoring function id and parameters for each scoring function you want to run
|
||||
:param num_examples: (Optional) The number of examples to evaluate. If not provided, all examples in the dataset will be evaluated
|
||||
"""
|
||||
|
||||
eval_candidate: EvalCandidate
|
||||
scoring_params: Dict[str, ScoringFnParams] = Field(
|
||||
description="Map between scoring function id and parameters for each scoring function you want to run",
|
||||
default_factory=dict,
|
||||
)
|
||||
num_examples: Optional[int] = Field(
|
||||
description="Number of examples to evaluate (useful for testing), if not provided, all examples in the dataset will be evaluated",
|
||||
default=None,
|
||||
)
|
||||
# we could optionally add any specific dataset config here
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class EvaluateResponse(BaseModel):
|
||||
"""The response from an evaluation.
|
||||
|
||||
:param generations: The generations from the evaluation.
|
||||
:param scores: The scores from the evaluation.
|
||||
"""
|
||||
|
||||
generations: List[Dict[str, Any]]
|
||||
# each key in the dict is a scoring function name
|
||||
scores: Dict[str, ScoringResult]
|
||||
|
||||
|
||||
class Eval(Protocol):
|
||||
"""Llama Stack Evaluation API for running evaluations on model and agent candidates."""
|
||||
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs", method="POST")
|
||||
async def run_eval(
|
||||
self,
|
||||
benchmark_id: str,
|
||||
benchmark_config: BenchmarkConfig,
|
||||
) -> Job:
|
||||
"""Run an evaluation on a benchmark.
|
||||
|
||||
:param benchmark_id: The ID of the benchmark to run the evaluation on.
|
||||
:param benchmark_config: The configuration for the benchmark.
|
||||
:return: The job that was created to run the evaluation.
|
||||
"""
|
||||
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}/evaluations", method="POST")
|
||||
async def evaluate_rows(
|
||||
self,
|
||||
benchmark_id: str,
|
||||
input_rows: List[Dict[str, Any]],
|
||||
scoring_functions: List[str],
|
||||
benchmark_config: BenchmarkConfig,
|
||||
) -> EvaluateResponse:
|
||||
"""Evaluate a list of rows on a benchmark.
|
||||
|
||||
:param benchmark_id: The ID of the benchmark to run the evaluation on.
|
||||
:param input_rows: The rows to evaluate.
|
||||
:param scoring_functions: The scoring functions to use for the evaluation.
|
||||
:param benchmark_config: The configuration for the benchmark.
|
||||
:return: EvaluateResponse object containing generations and scores
|
||||
"""
|
||||
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET")
|
||||
async def job_status(self, benchmark_id: str, job_id: str) -> Job:
|
||||
"""Get the status of a job.
|
||||
|
||||
:param benchmark_id: The ID of the benchmark to run the evaluation on.
|
||||
:param job_id: The ID of the job to get the status of.
|
||||
:return: The status of the evaluation job.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="DELETE")
|
||||
async def job_cancel(self, benchmark_id: str, job_id: str) -> None:
|
||||
"""Cancel a job.
|
||||
|
||||
:param benchmark_id: The ID of the benchmark to run the evaluation on.
|
||||
:param job_id: The ID of the job to cancel.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result", method="GET")
|
||||
async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse:
|
||||
"""Get the result of a job.
|
||||
|
||||
:param benchmark_id: The ID of the benchmark to run the evaluation on.
|
||||
:param job_id: The ID of the job to get the result of.
|
||||
:return: The result of the job.
|
||||
"""
|
||||
|
|
@ -144,18 +144,16 @@ class CompletionMessage(BaseModel):
|
|||
tool_calls: Optional[List[ToolCall]] = Field(default_factory=list)
|
||||
|
||||
|
||||
Message = register_schema(
|
||||
Annotated[
|
||||
Union[
|
||||
UserMessage,
|
||||
SystemMessage,
|
||||
ToolResponseMessage,
|
||||
CompletionMessage,
|
||||
],
|
||||
Field(discriminator="role"),
|
||||
Message = Annotated[
|
||||
Union[
|
||||
UserMessage,
|
||||
SystemMessage,
|
||||
ToolResponseMessage,
|
||||
CompletionMessage,
|
||||
],
|
||||
name="Message",
|
||||
)
|
||||
Field(discriminator="role"),
|
||||
]
|
||||
register_schema(Message, name="Message")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -263,13 +261,11 @@ class GrammarResponseFormat(BaseModel):
|
|||
bnf: Dict[str, Any]
|
||||
|
||||
|
||||
ResponseFormat = register_schema(
|
||||
Annotated[
|
||||
Union[JsonSchemaResponseFormat, GrammarResponseFormat],
|
||||
Field(discriminator="type"),
|
||||
],
|
||||
name="ResponseFormat",
|
||||
)
|
||||
ResponseFormat = Annotated[
|
||||
Union[JsonSchemaResponseFormat, GrammarResponseFormat],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(ResponseFormat, name="ResponseFormat")
|
||||
|
||||
|
||||
# This is an internally used class
|
||||
|
|
|
@ -24,17 +24,6 @@ class HealthInfo(BaseModel):
|
|||
# TODO: add a provider level status
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ProviderInfo(BaseModel):
|
||||
api: str
|
||||
provider_id: str
|
||||
provider_type: str
|
||||
|
||||
|
||||
class ListProvidersResponse(BaseModel):
|
||||
data: List[ProviderInfo]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class VersionInfo(BaseModel):
|
||||
version: str
|
||||
|
@ -46,9 +35,6 @@ class ListRoutesResponse(BaseModel):
|
|||
|
||||
@runtime_checkable
|
||||
class Inspect(Protocol):
|
||||
@webmethod(route="/inspect/providers", method="GET")
|
||||
async def list_providers(self) -> ListProvidersResponse: ...
|
||||
|
||||
@webmethod(route="/inspect/routes", method="GET")
|
||||
async def list_routes(self) -> ListRoutesResponse: ...
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Literal, Optional, Protocol
|
||||
from typing import Any, Dict, List, Literal, Optional, Protocol, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Annotated
|
||||
|
@ -88,10 +88,8 @@ class QATFinetuningConfig(BaseModel):
|
|||
group_size: int
|
||||
|
||||
|
||||
AlgorithmConfig = register_schema(
|
||||
Annotated[LoraFinetuningConfig | QATFinetuningConfig, Field(discriminator="type")],
|
||||
name="AlgorithmConfig",
|
||||
)
|
||||
AlgorithmConfig = Annotated[Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")]
|
||||
register_schema(AlgorithmConfig, name="AlgorithmConfig")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -184,7 +182,7 @@ class PostTraining(Protocol):
|
|||
description="Model descriptor from `llama model list`",
|
||||
),
|
||||
checkpoint_dir: Optional[str] = None,
|
||||
algorithm_config: Optional[LoraFinetuningConfig | QATFinetuningConfig] = None,
|
||||
algorithm_config: Optional[AlgorithmConfig] = None,
|
||||
) -> PostTrainingJob: ...
|
||||
|
||||
@webmethod(route="/post-training/preference-optimize", method="POST")
|
||||
|
|
149
llama_stack/apis/scoring_functions/scoring_functions.py
Normal file
149
llama_stack/apis/scoring_functions/scoring_functions.py
Normal file
|
@ -0,0 +1,149 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Protocol,
|
||||
Union,
|
||||
runtime_checkable,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.apis.common.type_system import ParamType
|
||||
from llama_stack.apis.resource import Resource, ResourceType
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||
|
||||
|
||||
# Perhaps more structure can be imposed on these functions. Maybe they could be associated
|
||||
# with standard metrics so they can be rolled up?
|
||||
@json_schema_type
|
||||
class ScoringFnParamsType(Enum):
|
||||
llm_as_judge = "llm_as_judge"
|
||||
regex_parser = "regex_parser"
|
||||
basic = "basic"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AggregationFunctionType(Enum):
|
||||
average = "average"
|
||||
weighted_average = "weighted_average"
|
||||
median = "median"
|
||||
categorical_count = "categorical_count"
|
||||
accuracy = "accuracy"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class LLMAsJudgeScoringFnParams(BaseModel):
|
||||
type: Literal[ScoringFnParamsType.llm_as_judge.value] = ScoringFnParamsType.llm_as_judge.value
|
||||
judge_model: str
|
||||
prompt_template: Optional[str] = None
|
||||
judge_score_regexes: Optional[List[str]] = Field(
|
||||
description="Regexes to extract the answer from generated response",
|
||||
default_factory=list,
|
||||
)
|
||||
aggregation_functions: Optional[List[AggregationFunctionType]] = Field(
|
||||
description="Aggregation functions to apply to the scores of each row",
|
||||
default_factory=list,
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RegexParserScoringFnParams(BaseModel):
|
||||
type: Literal[ScoringFnParamsType.regex_parser.value] = ScoringFnParamsType.regex_parser.value
|
||||
parsing_regexes: Optional[List[str]] = Field(
|
||||
description="Regex to extract the answer from generated response",
|
||||
default_factory=list,
|
||||
)
|
||||
aggregation_functions: Optional[List[AggregationFunctionType]] = Field(
|
||||
description="Aggregation functions to apply to the scores of each row",
|
||||
default_factory=list,
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class BasicScoringFnParams(BaseModel):
|
||||
type: Literal[ScoringFnParamsType.basic.value] = ScoringFnParamsType.basic.value
|
||||
aggregation_functions: Optional[List[AggregationFunctionType]] = Field(
|
||||
description="Aggregation functions to apply to the scores of each row",
|
||||
default_factory=list,
|
||||
)
|
||||
|
||||
|
||||
ScoringFnParams = Annotated[
|
||||
Union[
|
||||
LLMAsJudgeScoringFnParams,
|
||||
RegexParserScoringFnParams,
|
||||
BasicScoringFnParams,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(ScoringFnParams, name="ScoringFnParams")
|
||||
|
||||
|
||||
class CommonScoringFnFields(BaseModel):
|
||||
description: Optional[str] = None
|
||||
metadata: Dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Any additional metadata for this definition",
|
||||
)
|
||||
return_type: ParamType = Field(
|
||||
description="The return type of the deterministic function",
|
||||
)
|
||||
params: Optional[ScoringFnParams] = Field(
|
||||
description="The parameters for the scoring function for benchmark eval, these can be overridden for app eval",
|
||||
default=None,
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ScoringFn(CommonScoringFnFields, Resource):
|
||||
type: Literal[ResourceType.scoring_function.value] = ResourceType.scoring_function.value
|
||||
|
||||
@property
|
||||
def scoring_fn_id(self) -> str:
|
||||
return self.identifier
|
||||
|
||||
@property
|
||||
def provider_scoring_fn_id(self) -> str:
|
||||
return self.provider_resource_id
|
||||
|
||||
|
||||
class ScoringFnInput(CommonScoringFnFields, BaseModel):
|
||||
scoring_fn_id: str
|
||||
provider_id: Optional[str] = None
|
||||
provider_scoring_fn_id: Optional[str] = None
|
||||
|
||||
|
||||
class ListScoringFunctionsResponse(BaseModel):
|
||||
data: List[ScoringFn]
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class ScoringFunctions(Protocol):
|
||||
@webmethod(route="/scoring-functions", method="GET")
|
||||
async def list_scoring_functions(self) -> ListScoringFunctionsResponse: ...
|
||||
|
||||
@webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="GET")
|
||||
async def get_scoring_function(self, scoring_fn_id: str, /) -> ScoringFn: ...
|
||||
|
||||
@webmethod(route="/scoring-functions", method="POST")
|
||||
async def register_scoring_function(
|
||||
self,
|
||||
scoring_fn_id: str,
|
||||
description: str,
|
||||
return_type: ParamType,
|
||||
provider_scoring_fn_id: Optional[str] = None,
|
||||
provider_id: Optional[str] = None,
|
||||
params: Optional[ScoringFnParams] = None,
|
||||
) -> None: ...
|
||||
|
|
@ -146,16 +146,14 @@ class SpanEndPayload(BaseModel):
|
|||
status: SpanStatus
|
||||
|
||||
|
||||
StructuredLogPayload = register_schema(
|
||||
Annotated[
|
||||
Union[
|
||||
SpanStartPayload,
|
||||
SpanEndPayload,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
StructuredLogPayload = Annotated[
|
||||
Union[
|
||||
SpanStartPayload,
|
||||
SpanEndPayload,
|
||||
],
|
||||
name="StructuredLogPayload",
|
||||
)
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(StructuredLogPayload, name="StructuredLogPayload")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -164,17 +162,15 @@ class StructuredLogEvent(EventCommon):
|
|||
payload: StructuredLogPayload
|
||||
|
||||
|
||||
Event = register_schema(
|
||||
Annotated[
|
||||
Union[
|
||||
UnstructuredLogEvent,
|
||||
MetricEvent,
|
||||
StructuredLogEvent,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
Event = Annotated[
|
||||
Union[
|
||||
UnstructuredLogEvent,
|
||||
MetricEvent,
|
||||
StructuredLogEvent,
|
||||
],
|
||||
name="Event",
|
||||
)
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(Event, name="Event")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
|
@ -17,6 +17,15 @@ from llama_stack.schema_utils import json_schema_type, register_schema, webmetho
|
|||
|
||||
@json_schema_type
|
||||
class RAGDocument(BaseModel):
|
||||
"""
|
||||
A document to be used for document ingestion in the RAG Tool.
|
||||
|
||||
:param document_id: The unique identifier for the document.
|
||||
:param content: The content of the document.
|
||||
:param mime_type: The MIME type of the document.
|
||||
:param metadata: Additional metadata for the document.
|
||||
"""
|
||||
|
||||
document_id: str
|
||||
content: InterleavedContent | URL
|
||||
mime_type: str | None = None
|
||||
|
@ -49,16 +58,14 @@ class LLMRAGQueryGeneratorConfig(BaseModel):
|
|||
template: str
|
||||
|
||||
|
||||
RAGQueryGeneratorConfig = register_schema(
|
||||
Annotated[
|
||||
Union[
|
||||
DefaultRAGQueryGeneratorConfig,
|
||||
LLMRAGQueryGeneratorConfig,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
RAGQueryGeneratorConfig = Annotated[
|
||||
Union[
|
||||
DefaultRAGQueryGeneratorConfig,
|
||||
LLMRAGQueryGeneratorConfig,
|
||||
],
|
||||
name="RAGQueryGeneratorConfig",
|
||||
)
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(RAGQueryGeneratorConfig, name="RAGQueryGeneratorConfig")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
|
@ -69,7 +69,7 @@ class ToolGroup(Resource):
|
|||
|
||||
@json_schema_type
|
||||
class ToolInvocationResult(BaseModel):
|
||||
content: InterleavedContent
|
||||
content: Optional[InterleavedContent] = None
|
||||
error_message: Optional[str] = None
|
||||
error_code: Optional[int] = None
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
|
@ -140,9 +140,9 @@ class SpecialToolGroup(Enum):
|
|||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class ToolRuntime(Protocol):
|
||||
tool_store: ToolStore
|
||||
tool_store: ToolStore | None = None
|
||||
|
||||
rag_tool: RAGToolRuntime
|
||||
rag_tool: RAGToolRuntime | None = None
|
||||
|
||||
# TODO: This needs to be renamed once the OpenAPI generator name conflict issue is fixed.
|
||||
@webmethod(route="/tool-runtime/list-tools", method="GET")
|
||||
|
|
|
@ -36,7 +36,7 @@ class VectorDBStore(Protocol):
|
|||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class VectorIO(Protocol):
|
||||
vector_db_store: VectorDBStore
|
||||
vector_db_store: VectorDBStore | None = None
|
||||
|
||||
# this will just block now until chunks are inserted, but it should
|
||||
# probably return a Job instance which can be polled for completion
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue