Fix precommit check after moving to ruff (#927)

Lint check in main branch is failing. This fixes the lint check after we
moved to ruff in https://github.com/meta-llama/llama-stack/pull/921. We
need to move to a `ruff.toml` file as well as fixing and ignoring some
additional checks.

Signed-off-by: Yuan Tang <terrytangyuan@gmail.com>
This commit is contained in:
Yuan Tang 2025-02-02 09:46:45 -05:00 committed by GitHub
parent 4773092dd1
commit 34ab7a3b6c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
217 changed files with 981 additions and 2681 deletions

View file

@ -86,9 +86,7 @@ class ShieldCallStep(StepCommon):
@json_schema_type
class MemoryRetrievalStep(StepCommon):
step_type: Literal[StepType.memory_retrieval.value] = (
StepType.memory_retrieval.value
)
step_type: Literal[StepType.memory_retrieval.value] = StepType.memory_retrieval.value
vector_db_ids: str
inserted_context: InterleavedContent
@ -184,9 +182,7 @@ class AgentTurnResponseEventType(Enum):
@json_schema_type
class AgentTurnResponseStepStartPayload(BaseModel):
event_type: Literal[AgentTurnResponseEventType.step_start.value] = (
AgentTurnResponseEventType.step_start.value
)
event_type: Literal[AgentTurnResponseEventType.step_start.value] = AgentTurnResponseEventType.step_start.value
step_type: StepType
step_id: str
metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
@ -194,9 +190,7 @@ class AgentTurnResponseStepStartPayload(BaseModel):
@json_schema_type
class AgentTurnResponseStepCompletePayload(BaseModel):
event_type: Literal[AgentTurnResponseEventType.step_complete.value] = (
AgentTurnResponseEventType.step_complete.value
)
event_type: Literal[AgentTurnResponseEventType.step_complete.value] = AgentTurnResponseEventType.step_complete.value
step_type: StepType
step_id: str
step_details: Step
@ -206,9 +200,7 @@ class AgentTurnResponseStepCompletePayload(BaseModel):
class AgentTurnResponseStepProgressPayload(BaseModel):
model_config = ConfigDict(protected_namespaces=())
event_type: Literal[AgentTurnResponseEventType.step_progress.value] = (
AgentTurnResponseEventType.step_progress.value
)
event_type: Literal[AgentTurnResponseEventType.step_progress.value] = AgentTurnResponseEventType.step_progress.value
step_type: StepType
step_id: str
@ -217,17 +209,13 @@ class AgentTurnResponseStepProgressPayload(BaseModel):
@json_schema_type
class AgentTurnResponseTurnStartPayload(BaseModel):
event_type: Literal[AgentTurnResponseEventType.turn_start.value] = (
AgentTurnResponseEventType.turn_start.value
)
event_type: Literal[AgentTurnResponseEventType.turn_start.value] = AgentTurnResponseEventType.turn_start.value
turn_id: str
@json_schema_type
class AgentTurnResponseTurnCompletePayload(BaseModel):
event_type: Literal[AgentTurnResponseEventType.turn_complete.value] = (
AgentTurnResponseEventType.turn_complete.value
)
event_type: Literal[AgentTurnResponseEventType.turn_complete.value] = AgentTurnResponseEventType.turn_complete.value
turn: Turn
@ -329,9 +317,7 @@ class Agents(Protocol):
toolgroups: Optional[List[AgentToolGroup]] = None,
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}", method="GET"
)
@webmethod(route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}", method="GET")
async def get_agents_turn(
self,
agent_id: str,

View file

@ -63,9 +63,7 @@ class EventLogger:
if isinstance(chunk, ToolResponseMessage):
yield (
chunk,
LogEvent(
role="CustomTool", content=chunk.content, color="grey"
),
LogEvent(role="CustomTool", content=chunk.content, color="grey"),
)
continue
@ -81,17 +79,12 @@ class EventLogger:
step_type = event.payload.step_type
# handle safety
if (
step_type == StepType.shield_call
and event_type == EventType.step_complete.value
):
if step_type == StepType.shield_call and event_type == EventType.step_complete.value:
violation = event.payload.step_details.violation
if not violation:
yield (
event,
LogEvent(
role=step_type, content="No Violation", color="magenta"
),
LogEvent(role=step_type, content="No Violation", color="magenta"),
)
else:
yield (
@ -110,9 +103,7 @@ class EventLogger:
# TODO: Currently this event is never received
yield (
event,
LogEvent(
role=step_type, content="", end="", color="yellow"
),
LogEvent(role=step_type, content="", end="", color="yellow"),
)
elif event_type == EventType.step_progress.value:
# HACK: if previous was not step/event was not inference's step_progress
@ -125,9 +116,7 @@ class EventLogger:
):
yield (
event,
LogEvent(
role=step_type, content="", end="", color="yellow"
),
LogEvent(role=step_type, content="", end="", color="yellow"),
)
delta = event.payload.delta
@ -161,9 +150,7 @@ class EventLogger:
if event_type == EventType.step_complete.value:
response = event.payload.step_details.model_response
if response.tool_calls:
content = ToolUtils.encode_tool_call(
response.tool_calls[0], tool_prompt_format
)
content = ToolUtils.encode_tool_call(response.tool_calls[0], tool_prompt_format)
else:
content = response.content
yield (
@ -202,10 +189,7 @@ class EventLogger:
),
)
if (
step_type == StepType.memory_retrieval
and event_type == EventType.step_complete.value
):
if step_type == StepType.memory_retrieval and event_type == EventType.step_complete.value:
details = event.payload.step_details
inserted_context = interleaved_content_as_str(details.inserted_context)
content = f"fetched {len(inserted_context)} bytes from {details.vector_db_ids}"

View file

@ -39,6 +39,4 @@ class DatasetIO(Protocol):
) -> PaginatedRowsResult: ...
@webmethod(route="/datasetio/rows", method="POST")
async def append_rows(
self, dataset_id: str, rows: List[Dict[str, Any]]
) -> None: ...
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: ...

View file

@ -63,9 +63,7 @@ class AppEvalTaskConfig(BaseModel):
EvalTaskConfig = register_schema(
Annotated[
Union[BenchmarkEvalTaskConfig, AppEvalTaskConfig], Field(discriminator="type")
],
Annotated[Union[BenchmarkEvalTaskConfig, AppEvalTaskConfig], Field(discriminator="type")],
name="EvalTaskConfig",
)

View file

@ -245,9 +245,7 @@ class JsonSchemaResponseFormat(BaseModel):
:param json_schema: The JSON schema the response should conform to. In a Python SDK, this is often a `pydantic` model.
"""
type: Literal[ResponseFormatType.json_schema.value] = (
ResponseFormatType.json_schema.value
)
type: Literal[ResponseFormatType.json_schema.value] = ResponseFormatType.json_schema.value
json_schema: Dict[str, Any]
@ -406,9 +404,7 @@ class Inference(Protocol):
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[
ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
]:
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
"""Generate a chat completion for the given messages using the specified model.
:param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.

View file

@ -89,9 +89,7 @@ class QATFinetuningConfig(BaseModel):
AlgorithmConfig = register_schema(
Annotated[
Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")
],
Annotated[Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")],
name="AlgorithmConfig",
)
@ -204,14 +202,10 @@ class PostTraining(Protocol):
async def get_training_jobs(self) -> ListPostTrainingJobsResponse: ...
@webmethod(route="/post-training/job/status", method="GET")
async def get_training_job_status(
self, job_uuid: str
) -> Optional[PostTrainingJobStatusResponse]: ...
async def get_training_job_status(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]: ...
@webmethod(route="/post-training/job/cancel", method="POST")
async def cancel_training_job(self, job_uuid: str) -> None: ...
@webmethod(route="/post-training/job/artifacts", method="GET")
async def get_training_job_artifacts(
self, job_uuid: str
) -> Optional[PostTrainingJobArtifactsResponse]: ...
async def get_training_job_artifacts(self, job_uuid: str) -> Optional[PostTrainingJobArtifactsResponse]: ...

View file

@ -23,9 +23,7 @@ class ResourceType(Enum):
class Resource(BaseModel):
"""Base class for all Llama Stack resources"""
identifier: str = Field(
description="Unique identifier for this resource in llama stack"
)
identifier: str = Field(description="Unique identifier for this resource in llama stack")
provider_resource_id: str = Field(
description="Unique identifier for this resource in the provider",
@ -34,6 +32,4 @@ class Resource(BaseModel):
provider_id: str = Field(description="ID of the provider that owns this resource")
type: ResourceType = Field(
description="Type of resource (e.g. 'model', 'shield', 'vector_db', etc.)"
)
type: ResourceType = Field(description="Type of resource (e.g. 'model', 'shield', 'vector_db', etc.)")

View file

@ -43,9 +43,7 @@ class AggregationFunctionType(Enum):
@json_schema_type
class LLMAsJudgeScoringFnParams(BaseModel):
type: Literal[ScoringFnParamsType.llm_as_judge.value] = (
ScoringFnParamsType.llm_as_judge.value
)
type: Literal[ScoringFnParamsType.llm_as_judge.value] = ScoringFnParamsType.llm_as_judge.value
judge_model: str
prompt_template: Optional[str] = None
judge_score_regexes: Optional[List[str]] = Field(
@ -60,9 +58,7 @@ class LLMAsJudgeScoringFnParams(BaseModel):
@json_schema_type
class RegexParserScoringFnParams(BaseModel):
type: Literal[ScoringFnParamsType.regex_parser.value] = (
ScoringFnParamsType.regex_parser.value
)
type: Literal[ScoringFnParamsType.regex_parser.value] = ScoringFnParamsType.regex_parser.value
parsing_regexes: Optional[List[str]] = Field(
description="Regex to extract the answer from generated response",
default_factory=list,
@ -112,9 +108,7 @@ class CommonScoringFnFields(BaseModel):
@json_schema_type
class ScoringFn(CommonScoringFnFields, Resource):
type: Literal[ResourceType.scoring_function.value] = (
ResourceType.scoring_function.value
)
type: Literal[ResourceType.scoring_function.value] = ResourceType.scoring_function.value
@property
def scoring_fn_id(self) -> str:
@ -141,9 +135,7 @@ class ScoringFunctions(Protocol):
async def list_scoring_functions(self) -> ListScoringFunctionsResponse: ...
@webmethod(route="/scoring-functions/{scoring_fn_id}", method="GET")
async def get_scoring_function(
self, scoring_fn_id: str, /
) -> Optional[ScoringFn]: ...
async def get_scoring_function(self, scoring_fn_id: str, /) -> Optional[ScoringFn]: ...
@webmethod(route="/scoring-functions", method="POST")
async def register_scoring_function(

View file

@ -102,9 +102,7 @@ class StructuredLogType(Enum):
@json_schema_type
class SpanStartPayload(BaseModel):
type: Literal[StructuredLogType.SPAN_START.value] = (
StructuredLogType.SPAN_START.value
)
type: Literal[StructuredLogType.SPAN_START.value] = StructuredLogType.SPAN_START.value
name: str
parent_span_id: Optional[str] = None
@ -190,9 +188,7 @@ class QuerySpanTreeResponse(BaseModel):
@runtime_checkable
class Telemetry(Protocol):
@webmethod(route="/telemetry/events", method="POST")
async def log_event(
self, event: Event, ttl_seconds: int = DEFAULT_TTL_DAYS * 86400
) -> None: ...
async def log_event(self, event: Event, ttl_seconds: int = DEFAULT_TTL_DAYS * 86400) -> None: ...
@webmethod(route="/telemetry/traces", method="GET")
async def query_traces(

View file

@ -64,9 +64,7 @@ RAGQueryGeneratorConfig = register_schema(
class RAGQueryConfig(BaseModel):
# This config defines how a query is generated using the messages
# for memory bank retrieval.
query_generator_config: RAGQueryGeneratorConfig = Field(
default=DefaultRAGQueryGeneratorConfig()
)
query_generator_config: RAGQueryGeneratorConfig = Field(default=DefaultRAGQueryGeneratorConfig())
max_tokens_in_context: int = 4096
max_chunks: int = 5

View file

@ -150,8 +150,6 @@ class ToolRuntime(Protocol):
) -> List[ToolDef]: ...
@webmethod(route="/tool-runtime/invoke", method="POST")
async def invoke_tool(
self, tool_name: str, kwargs: Dict[str, Any]
) -> ToolInvocationResult:
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
"""Run a tool with the given arguments"""
...