Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-06-27 18:50:41 +00:00
chore: more mypy fixes (#2029)
# What does this PR do?

Mainly tried to cover the entire llama_stack/apis directory; we only have one left. Some of the mypy excludes were just no-ops.

Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in: commit 1a529705da (parent feb9eb8b0d)
27 changed files with 581 additions and 166 deletions
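For readers skimming the diff, here is a minimal, self-contained sketch of the two patterns this commit applies repeatedly across llama_stack/apis: the `StrEnum` backport paired with `Literal[...]` over enum members (instead of `.value`), and `default_factory=lambda: []` / `lambda: {}` for optional container fields. `ExampleType` and `ExampleStep` below are hypothetical stand-ins for classes such as `StepType` and `InferenceStep`; this sketch is illustrative and not part of the commit.

```python
# Illustrative sketch only: ExampleType/ExampleStep are hypothetical and just
# mirror the shape of the models touched in this diff.
import sys
from enum import Enum
from typing import Any, Literal

from pydantic import BaseModel, Field

# Pattern 1: enum.StrEnum exists only on Python >= 3.11, so older versions
# fall back to an equivalent str + Enum subclass.
if sys.version_info >= (3, 11):
    from enum import StrEnum
else:

    class StrEnum(str, Enum):
        """Backport of StrEnum for Python 3.10 and below."""


class ExampleType(StrEnum):
    inference = "inference"
    tool_execution = "tool_execution"


class ExampleStep(BaseModel):
    # With a StrEnum, Literal can name the member itself rather than
    # `ExampleType.inference.value`, which mypy can verify; the member is
    # still a plain str at runtime.
    step_type: Literal[ExampleType.inference] = ExampleType.inference

    # Pattern 2: a zero-argument lambda as default_factory, as used across
    # the diff in place of the bare `list`/`dict` constructors.
    tags: list[str] | None = Field(default_factory=lambda: [])
    metadata: dict[str, Any] | None = Field(default_factory=lambda: {})


if __name__ == "__main__":
    step = ExampleStep()
    print(step.model_dump(mode="json"))  # step_type serializes as "inference"
```

Because a `StrEnum` member is itself a `str`, the serialized value stays the plain string, so the wire format should be unaffected by switching the `Literal` defaults off `.value`.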
@@ -4,6 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+import sys
 from collections.abc import AsyncIterator
 from datetime import datetime
 from enum import Enum
@@ -35,6 +36,14 @@ from .openai_responses import (
     OpenAIResponseObjectStream,
 )

+# TODO: use enum.StrEnum when we drop support for python 3.10
+if sys.version_info >= (3, 11):
+    from enum import StrEnum
+else:
+
+    class StrEnum(str, Enum):
+        """Backport of StrEnum for Python 3.10 and below."""
+

 class Attachment(BaseModel):
     """An attachment to an agent turn.
@@ -73,7 +82,7 @@ class StepCommon(BaseModel):
     completed_at: datetime | None = None


-class StepType(Enum):
+class StepType(StrEnum):
     """Type of the step in an agent turn.

     :cvar inference: The step is an inference step that calls an LLM.
@@ -97,7 +106,7 @@ class InferenceStep(StepCommon):

     model_config = ConfigDict(protected_namespaces=())

-    step_type: Literal[StepType.inference.value] = StepType.inference.value
+    step_type: Literal[StepType.inference] = StepType.inference
     model_response: CompletionMessage


@@ -109,7 +118,7 @@ class ToolExecutionStep(StepCommon):
     :param tool_responses: The tool responses from the tool calls.
     """

-    step_type: Literal[StepType.tool_execution.value] = StepType.tool_execution.value
+    step_type: Literal[StepType.tool_execution] = StepType.tool_execution
     tool_calls: list[ToolCall]
     tool_responses: list[ToolResponse]

@@ -121,7 +130,7 @@ class ShieldCallStep(StepCommon):
     :param violation: The violation from the shield call.
     """

-    step_type: Literal[StepType.shield_call.value] = StepType.shield_call.value
+    step_type: Literal[StepType.shield_call] = StepType.shield_call
     violation: SafetyViolation | None


@@ -133,7 +142,7 @@ class MemoryRetrievalStep(StepCommon):
     :param inserted_context: The context retrieved from the vector databases.
     """

-    step_type: Literal[StepType.memory_retrieval.value] = StepType.memory_retrieval.value
+    step_type: Literal[StepType.memory_retrieval] = StepType.memory_retrieval
     # TODO: should this be List[str]?
     vector_db_ids: str
     inserted_context: InterleavedContent
@@ -154,7 +163,7 @@ class Turn(BaseModel):
     input_messages: list[UserMessage | ToolResponseMessage]
     steps: list[Step]
     output_message: CompletionMessage
-    output_attachments: list[Attachment] | None = Field(default_factory=list)
+    output_attachments: list[Attachment] | None = Field(default_factory=lambda: [])

     started_at: datetime
     completed_at: datetime | None = None
@@ -182,10 +191,10 @@ register_schema(AgentToolGroup, name="AgentTool")
 class AgentConfigCommon(BaseModel):
     sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)

-    input_shields: list[str] | None = Field(default_factory=list)
-    output_shields: list[str] | None = Field(default_factory=list)
-    toolgroups: list[AgentToolGroup] | None = Field(default_factory=list)
-    client_tools: list[ToolDef] | None = Field(default_factory=list)
+    input_shields: list[str] | None = Field(default_factory=lambda: [])
+    output_shields: list[str] | None = Field(default_factory=lambda: [])
+    toolgroups: list[AgentToolGroup] | None = Field(default_factory=lambda: [])
+    client_tools: list[ToolDef] | None = Field(default_factory=lambda: [])
     tool_choice: ToolChoice | None = Field(default=None, deprecated="use tool_config instead")
     tool_prompt_format: ToolPromptFormat | None = Field(default=None, deprecated="use tool_config instead")
     tool_config: ToolConfig | None = Field(default=None)
@@ -246,7 +255,7 @@ class AgentConfigOverridablePerTurn(AgentConfigCommon):
     instructions: str | None = None


-class AgentTurnResponseEventType(Enum):
+class AgentTurnResponseEventType(StrEnum):
     step_start = "step_start"
     step_complete = "step_complete"
     step_progress = "step_progress"
@@ -258,15 +267,15 @@ class AgentTurnResponseEventType(Enum):

 @json_schema_type
 class AgentTurnResponseStepStartPayload(BaseModel):
-    event_type: Literal[AgentTurnResponseEventType.step_start.value] = AgentTurnResponseEventType.step_start.value
+    event_type: Literal[AgentTurnResponseEventType.step_start] = AgentTurnResponseEventType.step_start
     step_type: StepType
     step_id: str
-    metadata: dict[str, Any] | None = Field(default_factory=dict)
+    metadata: dict[str, Any] | None = Field(default_factory=lambda: {})


 @json_schema_type
 class AgentTurnResponseStepCompletePayload(BaseModel):
-    event_type: Literal[AgentTurnResponseEventType.step_complete.value] = AgentTurnResponseEventType.step_complete.value
+    event_type: Literal[AgentTurnResponseEventType.step_complete] = AgentTurnResponseEventType.step_complete
     step_type: StepType
     step_id: str
     step_details: Step
@@ -276,7 +285,7 @@ class AgentTurnResponseStepCompletePayload(BaseModel):
 class AgentTurnResponseStepProgressPayload(BaseModel):
     model_config = ConfigDict(protected_namespaces=())

-    event_type: Literal[AgentTurnResponseEventType.step_progress.value] = AgentTurnResponseEventType.step_progress.value
+    event_type: Literal[AgentTurnResponseEventType.step_progress] = AgentTurnResponseEventType.step_progress
     step_type: StepType
     step_id: str

@@ -285,21 +294,19 @@ class AgentTurnResponseStepProgressPayload(BaseModel):

 @json_schema_type
 class AgentTurnResponseTurnStartPayload(BaseModel):
-    event_type: Literal[AgentTurnResponseEventType.turn_start.value] = AgentTurnResponseEventType.turn_start.value
+    event_type: Literal[AgentTurnResponseEventType.turn_start] = AgentTurnResponseEventType.turn_start
     turn_id: str


 @json_schema_type
 class AgentTurnResponseTurnCompletePayload(BaseModel):
-    event_type: Literal[AgentTurnResponseEventType.turn_complete.value] = AgentTurnResponseEventType.turn_complete.value
+    event_type: Literal[AgentTurnResponseEventType.turn_complete] = AgentTurnResponseEventType.turn_complete
     turn: Turn


 @json_schema_type
 class AgentTurnResponseTurnAwaitingInputPayload(BaseModel):
-    event_type: Literal[AgentTurnResponseEventType.turn_awaiting_input.value] = (
-        AgentTurnResponseEventType.turn_awaiting_input.value
-    )
+    event_type: Literal[AgentTurnResponseEventType.turn_awaiting_input] = AgentTurnResponseEventType.turn_awaiting_input
     turn: Turn


@@ -341,7 +348,7 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
     messages: list[UserMessage | ToolResponseMessage]

     documents: list[Document] | None = None
-    toolgroups: list[AgentToolGroup] | None = None
+    toolgroups: list[AgentToolGroup] | None = Field(default_factory=lambda: [])

     stream: bool | None = False
     tool_config: ToolConfig | None = None

@@ -22,14 +22,14 @@ class CommonBenchmarkFields(BaseModel):

 @json_schema_type
 class Benchmark(CommonBenchmarkFields, Resource):
-    type: Literal[ResourceType.benchmark.value] = ResourceType.benchmark.value
+    type: Literal[ResourceType.benchmark] = ResourceType.benchmark

     @property
     def benchmark_id(self) -> str:
         return self.identifier

     @property
-    def provider_benchmark_id(self) -> str:
+    def provider_benchmark_id(self) -> str | None:
         return self.provider_resource_id


@@ -28,7 +28,7 @@ class _URLOrData(BaseModel):

     url: URL | None = None
     # data is a base64 encoded string, hint with contentEncoding=base64
-    data: str | None = Field(contentEncoding="base64", default=None)
+    data: str | None = Field(default=None, json_schema_extra={"contentEncoding": "base64"})

     @model_validator(mode="before")
     @classmethod

@@ -106,14 +106,14 @@ class CommonDatasetFields(BaseModel):

 @json_schema_type
 class Dataset(CommonDatasetFields, Resource):
-    type: Literal[ResourceType.dataset.value] = ResourceType.dataset.value
+    type: Literal[ResourceType.dataset] = ResourceType.dataset

     @property
     def dataset_id(self) -> str:
         return self.identifier

     @property
-    def provider_dataset_id(self) -> str:
+    def provider_dataset_id(self) -> str | None:
         return self.provider_resource_id


@@ -4,6 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+import sys
 from collections.abc import AsyncIterator
 from enum import Enum
 from typing import (
@@ -35,6 +36,16 @@ register_schema(ToolCall)
 register_schema(ToolParamDefinition)
 register_schema(ToolDefinition)

+# TODO: use enum.StrEnum when we drop support for python 3.10
+if sys.version_info >= (3, 11):
+    from enum import StrEnum
+else:
+
+    class StrEnum(str, Enum):
+        """Backport of StrEnum for Python 3.10 and below."""
+
+        pass
+

 @json_schema_type
 class GreedySamplingStrategy(BaseModel):
@@ -187,7 +198,7 @@ class CompletionMessage(BaseModel):
     role: Literal["assistant"] = "assistant"
     content: InterleavedContent
     stop_reason: StopReason
-    tool_calls: list[ToolCall] | None = Field(default_factory=list)
+    tool_calls: list[ToolCall] | None = Field(default_factory=lambda: [])


 Message = Annotated[
@@ -267,7 +278,7 @@ class ChatCompletionResponseEvent(BaseModel):
     stop_reason: StopReason | None = None


-class ResponseFormatType(Enum):
+class ResponseFormatType(StrEnum):
     """Types of formats for structured (guided) decoding.

     :cvar json_schema: Response should conform to a JSON schema. In a Python SDK, this is often a `pydantic` model.
@@ -286,7 +297,7 @@ class JsonSchemaResponseFormat(BaseModel):
     :param json_schema: The JSON schema the response should conform to. In a Python SDK, this is often a `pydantic` model.
     """

-    type: Literal[ResponseFormatType.json_schema.value] = ResponseFormatType.json_schema.value
+    type: Literal[ResponseFormatType.json_schema] = ResponseFormatType.json_schema
     json_schema: dict[str, Any]


@@ -298,7 +309,7 @@ class GrammarResponseFormat(BaseModel):
     :param bnf: The BNF grammar specification the response should conform to
     """

-    type: Literal[ResponseFormatType.grammar.value] = ResponseFormatType.grammar.value
+    type: Literal[ResponseFormatType.grammar] = ResponseFormatType.grammar
     bnf: dict[str, Any]


@@ -394,7 +405,7 @@ class ChatCompletionRequest(BaseModel):
     messages: list[Message]
     sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)

-    tools: list[ToolDefinition] | None = Field(default_factory=list)
+    tools: list[ToolDefinition] | None = Field(default_factory=lambda: [])
     tool_config: ToolConfig | None = Field(default_factory=ToolConfig)

     response_format: ResponseFormat | None = None
@@ -567,14 +578,14 @@ class OpenAIResponseFormatText(BaseModel):
 @json_schema_type
 class OpenAIJSONSchema(TypedDict, total=False):
     name: str
-    description: str | None = None
-    strict: bool | None = None
+    description: str | None
+    strict: bool | None

     # Pydantic BaseModel cannot be used with a schema param, since it already
     # has one. And, we don't want to alias here because then have to handle
     # that alias when converting to OpenAI params. So, to support schema,
     # we use a TypedDict.
-    schema: dict[str, Any] | None = None
+    schema: dict[str, Any] | None


 @json_schema_type

@@ -29,14 +29,14 @@ class ModelType(str, Enum):

 @json_schema_type
 class Model(CommonModelFields, Resource):
-    type: Literal[ResourceType.model.value] = ResourceType.model.value
+    type: Literal[ResourceType.model] = ResourceType.model

     @property
     def model_id(self) -> str:
         return self.identifier

     @property
-    def provider_model_id(self) -> str:
+    def provider_model_id(self) -> str | None:
         return self.provider_resource_id

     model_config = ConfigDict(protected_namespaces=())

@@ -4,12 +4,23 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+import sys
 from enum import Enum

 from pydantic import BaseModel, Field

+# TODO: use enum.StrEnum when we drop support for python 3.10
+if sys.version_info >= (3, 11):
+    from enum import StrEnum
+else:
+
+    class StrEnum(str, Enum):
+        """Backport of StrEnum for Python 3.10 and below."""
+
+        pass
+

-class ResourceType(Enum):
+class ResourceType(StrEnum):
     model = "model"
     shield = "shield"
     vector_db = "vector_db"
@@ -25,9 +36,9 @@ class Resource(BaseModel):

     identifier: str = Field(description="Unique identifier for this resource in llama stack")

-    provider_resource_id: str = Field(
-        description="Unique identifier for this resource in the provider",
+    provider_resource_id: str | None = Field(
+        default=None,
+        description="Unique identifier for this resource in the provider",
     )

     provider_id: str = Field(description="ID of the provider that owns this resource")

@@ -53,5 +53,5 @@ class Safety(Protocol):
         self,
         shield_id: str,
         messages: list[Message],
-        params: dict[str, Any] = None,
+        params: dict[str, Any],
    ) -> RunShieldResponse: ...

@@ -4,6 +4,8 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+# TODO: use enum.StrEnum when we drop support for python 3.10
+import sys
 from enum import Enum
 from typing import (
     Annotated,
@@ -19,18 +21,27 @@ from llama_stack.apis.common.type_system import ParamType
 from llama_stack.apis.resource import Resource, ResourceType
 from llama_stack.schema_utils import json_schema_type, register_schema, webmethod

+if sys.version_info >= (3, 11):
+    from enum import StrEnum
+else:
+
+    class StrEnum(str, Enum):
+        """Backport of StrEnum for Python 3.10 and below."""
+
+        pass
+

 # Perhaps more structure can be imposed on these functions. Maybe they could be associated
 # with standard metrics so they can be rolled up?
 @json_schema_type
-class ScoringFnParamsType(Enum):
+class ScoringFnParamsType(StrEnum):
     llm_as_judge = "llm_as_judge"
     regex_parser = "regex_parser"
     basic = "basic"


 @json_schema_type
-class AggregationFunctionType(Enum):
+class AggregationFunctionType(StrEnum):
     average = "average"
     weighted_average = "weighted_average"
     median = "median"
@@ -40,36 +51,36 @@ class AggregationFunctionType(Enum):

 @json_schema_type
 class LLMAsJudgeScoringFnParams(BaseModel):
-    type: Literal[ScoringFnParamsType.llm_as_judge.value] = ScoringFnParamsType.llm_as_judge.value
+    type: Literal[ScoringFnParamsType.llm_as_judge] = ScoringFnParamsType.llm_as_judge
     judge_model: str
     prompt_template: str | None = None
-    judge_score_regexes: list[str] | None = Field(
+    judge_score_regexes: list[str] = Field(
         description="Regexes to extract the answer from generated response",
-        default_factory=list,
+        default_factory=lambda: [],
     )
-    aggregation_functions: list[AggregationFunctionType] | None = Field(
+    aggregation_functions: list[AggregationFunctionType] = Field(
         description="Aggregation functions to apply to the scores of each row",
-        default_factory=list,
+        default_factory=lambda: [],
     )


 @json_schema_type
 class RegexParserScoringFnParams(BaseModel):
-    type: Literal[ScoringFnParamsType.regex_parser.value] = ScoringFnParamsType.regex_parser.value
-    parsing_regexes: list[str] | None = Field(
+    type: Literal[ScoringFnParamsType.regex_parser] = ScoringFnParamsType.regex_parser
+    parsing_regexes: list[str] = Field(
         description="Regex to extract the answer from generated response",
-        default_factory=list,
+        default_factory=lambda: [],
     )
-    aggregation_functions: list[AggregationFunctionType] | None = Field(
+    aggregation_functions: list[AggregationFunctionType] = Field(
         description="Aggregation functions to apply to the scores of each row",
-        default_factory=list,
+        default_factory=lambda: [],
     )


 @json_schema_type
 class BasicScoringFnParams(BaseModel):
-    type: Literal[ScoringFnParamsType.basic.value] = ScoringFnParamsType.basic.value
-    aggregation_functions: list[AggregationFunctionType] | None = Field(
+    type: Literal[ScoringFnParamsType.basic] = ScoringFnParamsType.basic
+    aggregation_functions: list[AggregationFunctionType] = Field(
         description="Aggregation functions to apply to the scores of each row",
         default_factory=list,
     )
@@ -99,14 +110,14 @@ class CommonScoringFnFields(BaseModel):

 @json_schema_type
 class ScoringFn(CommonScoringFnFields, Resource):
-    type: Literal[ResourceType.scoring_function.value] = ResourceType.scoring_function.value
+    type: Literal[ResourceType.scoring_function] = ResourceType.scoring_function

     @property
     def scoring_fn_id(self) -> str:
         return self.identifier

     @property
-    def provider_scoring_fn_id(self) -> str:
+    def provider_scoring_fn_id(self) -> str | None:
         return self.provider_resource_id


@@ -21,14 +21,14 @@ class CommonShieldFields(BaseModel):
 class Shield(CommonShieldFields, Resource):
     """A safety shield resource that can be used to check content"""

-    type: Literal[ResourceType.shield.value] = ResourceType.shield.value
+    type: Literal[ResourceType.shield] = ResourceType.shield

     @property
     def shield_id(self) -> str:
         return self.identifier

     @property
-    def provider_shield_id(self) -> str:
+    def provider_shield_id(self) -> str | None:
         return self.provider_resource_id


@@ -37,7 +37,7 @@ class Span(BaseModel):
     name: str
     start_time: datetime
     end_time: datetime | None = None
-    attributes: dict[str, Any] | None = Field(default_factory=dict)
+    attributes: dict[str, Any] | None = Field(default_factory=lambda: {})

     def set_attribute(self, key: str, value: Any):
         if self.attributes is None:
@@ -74,19 +74,19 @@ class EventCommon(BaseModel):
     trace_id: str
     span_id: str
     timestamp: datetime
-    attributes: dict[str, Primitive] | None = Field(default_factory=dict)
+    attributes: dict[str, Primitive] | None = Field(default_factory=lambda: {})


 @json_schema_type
 class UnstructuredLogEvent(EventCommon):
-    type: Literal[EventType.UNSTRUCTURED_LOG.value] = EventType.UNSTRUCTURED_LOG.value
+    type: Literal[EventType.UNSTRUCTURED_LOG] = EventType.UNSTRUCTURED_LOG
     message: str
     severity: LogSeverity


 @json_schema_type
 class MetricEvent(EventCommon):
-    type: Literal[EventType.METRIC.value] = EventType.METRIC.value
+    type: Literal[EventType.METRIC] = EventType.METRIC
     metric: str  # this would be an enum
     value: int | float
     unit: str
@@ -131,14 +131,14 @@ class StructuredLogType(Enum):

 @json_schema_type
 class SpanStartPayload(BaseModel):
-    type: Literal[StructuredLogType.SPAN_START.value] = StructuredLogType.SPAN_START.value
+    type: Literal[StructuredLogType.SPAN_START] = StructuredLogType.SPAN_START
     name: str
     parent_span_id: str | None = None


 @json_schema_type
 class SpanEndPayload(BaseModel):
-    type: Literal[StructuredLogType.SPAN_END.value] = StructuredLogType.SPAN_END.value
+    type: Literal[StructuredLogType.SPAN_END] = StructuredLogType.SPAN_END
     status: SpanStatus


@@ -151,7 +151,7 @@ register_schema(StructuredLogPayload, name="StructuredLogPayload")

 @json_schema_type
 class StructuredLogEvent(EventCommon):
-    type: Literal[EventType.STRUCTURED_LOG.value] = EventType.STRUCTURED_LOG.value
+    type: Literal[EventType.STRUCTURED_LOG] = EventType.STRUCTURED_LOG
     payload: StructuredLogPayload


@@ -36,7 +36,7 @@ class ToolHost(Enum):

 @json_schema_type
 class Tool(Resource):
-    type: Literal[ResourceType.tool.value] = ResourceType.tool.value
+    type: Literal[ResourceType.tool] = ResourceType.tool
     toolgroup_id: str
     tool_host: ToolHost
     description: str
@@ -62,7 +62,7 @@ class ToolGroupInput(BaseModel):

 @json_schema_type
 class ToolGroup(Resource):
-    type: Literal[ResourceType.tool_group.value] = ResourceType.tool_group.value
+    type: Literal[ResourceType.tool_group] = ResourceType.tool_group
     mcp_endpoint: URL | None = None
     args: dict[str, Any] | None = None


@@ -15,7 +15,7 @@ from llama_stack.schema_utils import json_schema_type, webmethod

 @json_schema_type
 class VectorDB(Resource):
-    type: Literal[ResourceType.vector_db.value] = ResourceType.vector_db.value
+    type: Literal[ResourceType.vector_db] = ResourceType.vector_db

     embedding_model: str
     embedding_dimension: int
@@ -25,7 +25,7 @@ class VectorDB(Resource):
         return self.identifier

     @property
-    def provider_vector_db_id(self) -> str:
+    def provider_vector_db_id(self) -> str | None:
         return self.provider_resource_id
