Fix precommit check after moving to ruff (#927)

The lint check on the main branch is failing. This fixes it after the move to ruff in
https://github.com/meta-llama/llama-stack/pull/921. We need to move the configuration
to a `ruff.toml` file, as well as fix some violations and ignore some additional checks.

Signed-off-by: Yuan Tang <terrytangyuan@gmail.com>
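
For illustration only, a minimal `ruff.toml` along these lines would cover the two pieces described above. The 120-character line length is inferred from the reformatted long lines in the diff below; the selected and ignored rule codes are placeholders, not the exact set committed here:

```toml
# Hypothetical ruff.toml sketch -- rule codes and line length are illustrative,
# not necessarily the values actually committed in this change.
line-length = 120

[lint]
# Baseline pycodestyle (E), pyflakes (F), and isort (I) rules.
select = ["E", "F", "I"]
# Checks to ignore rather than fix, e.g. lambda assignment (E731).
ignore = ["E731"]
```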
Yuan Tang, 2025-02-02 09:46:45 -05:00 (committed by GitHub)
parent 4773092dd1
commit 34ab7a3b6c
217 changed files with 981 additions and 2681 deletions
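
For context on the "precommit check" in the title: ruff is normally wired into pre-commit via the astral-sh/ruff-pre-commit hooks. A rough sketch of such a `.pre-commit-config.yaml` entry follows; the pinned `rev` is a placeholder, and whether this repo enables both the lint and format hooks is an assumption:

```yaml
# Hypothetical .pre-commit-config.yaml snippet; rev is a placeholder version.
repos:
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.9.4
    hooks:
      - id: ruff          # linter (add `args: [--fix]` to autofix on commit)
      - id: ruff-format   # formatter
```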


@@ -44,11 +44,7 @@ from .utils import group_chunks
 def get_expected_stop_reason(model: str):
-    return (
-        StopReason.end_of_message
-        if ("Llama3.1" in model or "Llama-3.1" in model)
-        else StopReason.end_of_turn
-    )
+    return StopReason.end_of_message if ("Llama3.1" in model or "Llama-3.1" in model) else StopReason.end_of_turn

 @pytest.fixture
@@ -179,13 +175,9 @@ class TestInference:
             1 <= len(chunks) <= 6
         )  # why 6 and not 5? the response may have an extra closing chunk, e.g. for usage or stop_reason
         for chunk in chunks:
-            if (
-                chunk.delta.type == "text" and chunk.delta.text
-            ):  # if there's a token, we expect logprobs
+            if chunk.delta.type == "text" and chunk.delta.text:  # if there's a token, we expect logprobs
                 assert chunk.logprobs, "Logprobs should not be empty"
-                assert all(
-                    len(logprob.logprobs_by_token) == 3 for logprob in chunk.logprobs
-                )
+                assert all(len(logprob.logprobs_by_token) == 3 for logprob in chunk.logprobs)
             else:  # no token, no logprobs
                 assert not chunk.logprobs, "Logprobs should be empty"
@@ -236,9 +228,7 @@ class TestInference:
         assert len(response.completion_message.content) > 0

     @pytest.mark.asyncio(loop_scope="session")
-    async def test_structured_output(
-        self, inference_model, inference_stack, common_params
-    ):
+    async def test_structured_output(self, inference_model, inference_stack, common_params):
         inference_impl, _ = inference_stack

         class AnswerFormat(BaseModel):
@@ -295,9 +285,7 @@ class TestInference:
         AnswerFormat.model_validate_json(response.completion_message.content)

     @pytest.mark.asyncio(loop_scope="session")
-    async def test_chat_completion_streaming(
-        self, inference_model, inference_stack, common_params, sample_messages
-    ):
+    async def test_chat_completion_streaming(self, inference_model, inference_stack, common_params, sample_messages):
         inference_impl, _ = inference_stack
         response = [
             r
@@ -310,9 +298,7 @@ class TestInference:
         ]
         assert len(response) > 0
-        assert all(
-            isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response
-        )
+        assert all(isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response)
         grouped = group_chunks(response)
         assert len(grouped[ChatCompletionResponseEventType.start]) == 1
         assert len(grouped[ChatCompletionResponseEventType.progress]) > 0
@@ -387,9 +373,7 @@ class TestInference:
             )
         ]
         assert len(response) > 0
-        assert all(
-            isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response
-        )
+        assert all(isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response)
         grouped = group_chunks(response)
         assert len(grouped[ChatCompletionResponseEventType.start]) == 1
         assert len(grouped[ChatCompletionResponseEventType.progress]) > 0
@@ -404,13 +388,10 @@ class TestInference:
         if "Llama3.1" in inference_model:
             assert all(
-                chunk.event.delta.type == "tool_call"
-                for chunk in grouped[ChatCompletionResponseEventType.progress]
+                chunk.event.delta.type == "tool_call" for chunk in grouped[ChatCompletionResponseEventType.progress]
             )
         first = grouped[ChatCompletionResponseEventType.progress][0]
-        if not isinstance(
-            first.event.delta.tool_call, ToolCall
-        ):  # first chunk may contain entire call
+        if not isinstance(first.event.delta.tool_call, ToolCall):  # first chunk may contain entire call
             assert first.event.delta.parse_status == ToolCallParseStatus.started
         last = grouped[ChatCompletionResponseEventType.progress][-1]