Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-28 19:04:19 +00:00)
Fix precommit check after moving to ruff (#927)
The lint check on the main branch is failing. This fixes the lint check after the move to ruff in https://github.com/meta-llama/llama-stack/pull/921. We need to move to a `ruff.toml` file, as well as fix and ignore some additional checks. Signed-off-by: Yuan Tang <terrytangyuan@gmail.com>
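Since the commit message mentions moving to a `ruff.toml` file, here is a minimal sketch of what such a file can look like. The line length and rule selections below are illustrative assumptions, not the actual contents of the file added in this commit:

```toml
# Hypothetical ruff.toml sketch -- the actual selections in this commit may differ.
line-length = 120

[lint]
# Enable pycodestyle errors (E), Pyflakes (F), and import sorting (I).
select = ["E", "F", "I"]
# Ignore line-too-long; the formatter handles line wrapping.
ignore = ["E501"]
```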
Parent: 4773092dd1
Commit: 34ab7a3b6c
217 changed files with 981 additions and 2681 deletions
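The failing check the title refers to is the pre-commit lint hook. For context, a hedged sketch of how ruff is commonly wired into `.pre-commit-config.yaml` via the ruff-pre-commit hooks; the pinned rev and exact hook set are assumptions, not necessarily this repo's configuration:

```yaml
# Illustrative pre-commit wiring for ruff; not necessarily this repo's exact config.
repos:
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.9.4  # assumed version; pin to whatever the repo actually uses
    hooks:
      - id: ruff
        args: [--fix]
      - id: ruff-format
```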
@@ -44,11 +44,7 @@ from .utils import group_chunks
 
 
 def get_expected_stop_reason(model: str):
-    return (
-        StopReason.end_of_message
-        if ("Llama3.1" in model or "Llama-3.1" in model)
-        else StopReason.end_of_turn
-    )
+    return StopReason.end_of_message if ("Llama3.1" in model or "Llama-3.1" in model) else StopReason.end_of_turn
 
 
 @pytest.fixture
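Pulled out of the hunk above for readability, a self-contained version of the reformatted helper; `StopReason` is stubbed as an `Enum` here so the snippet runs outside the llama-stack test suite, and the model names in the asserts are illustrative:

```python
# Standalone sketch of get_expected_stop_reason after the ruff reformat.
# StopReason is a stand-in stub; the real type lives in the llama-stack codebase.
from enum import Enum


class StopReason(Enum):
    end_of_message = "end_of_message"
    end_of_turn = "end_of_turn"


def get_expected_stop_reason(model: str):
    return StopReason.end_of_message if ("Llama3.1" in model or "Llama-3.1" in model) else StopReason.end_of_turn


# Llama 3.1 models are expected to stop with end_of_message; others with end_of_turn.
assert get_expected_stop_reason("Llama3.1-8B-Instruct") is StopReason.end_of_message
assert get_expected_stop_reason("Llama-2-7b") is StopReason.end_of_turn
```

The change is behavior-preserving: the formatter collapses the parenthesized conditional expression onto one line because it fits within the configured line length.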
@@ -179,13 +175,9 @@ class TestInference:
             1 <= len(chunks) <= 6
         )  # why 6 and not 5? the response may have an extra closing chunk, e.g. for usage or stop_reason
         for chunk in chunks:
-            if (
-                chunk.delta.type == "text" and chunk.delta.text
-            ):  # if there's a token, we expect logprobs
+            if chunk.delta.type == "text" and chunk.delta.text:  # if there's a token, we expect logprobs
                 assert chunk.logprobs, "Logprobs should not be empty"
-                assert all(
-                    len(logprob.logprobs_by_token) == 3 for logprob in chunk.logprobs
-                )
+                assert all(len(logprob.logprobs_by_token) == 3 for logprob in chunk.logprobs)
             else:  # no token, no logprobs
                 assert not chunk.logprobs, "Logprobs should be empty"
 
@@ -236,9 +228,7 @@ class TestInference:
         assert len(response.completion_message.content) > 0
 
     @pytest.mark.asyncio(loop_scope="session")
-    async def test_structured_output(
-        self, inference_model, inference_stack, common_params
-    ):
+    async def test_structured_output(self, inference_model, inference_stack, common_params):
         inference_impl, _ = inference_stack
 
         class AnswerFormat(BaseModel):
@@ -295,9 +285,7 @@ class TestInference:
         AnswerFormat.model_validate_json(response.completion_message.content)
 
     @pytest.mark.asyncio(loop_scope="session")
-    async def test_chat_completion_streaming(
-        self, inference_model, inference_stack, common_params, sample_messages
-    ):
+    async def test_chat_completion_streaming(self, inference_model, inference_stack, common_params, sample_messages):
         inference_impl, _ = inference_stack
         response = [
             r
@@ -310,9 +298,7 @@ class TestInference:
         ]
 
         assert len(response) > 0
-        assert all(
-            isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response
-        )
+        assert all(isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response)
         grouped = group_chunks(response)
         assert len(grouped[ChatCompletionResponseEventType.start]) == 1
         assert len(grouped[ChatCompletionResponseEventType.progress]) > 0
@@ -387,9 +373,7 @@ class TestInference:
             )
         ]
         assert len(response) > 0
-        assert all(
-            isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response
-        )
+        assert all(isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response)
         grouped = group_chunks(response)
         assert len(grouped[ChatCompletionResponseEventType.start]) == 1
         assert len(grouped[ChatCompletionResponseEventType.progress]) > 0
@@ -404,13 +388,10 @@ class TestInference:
 
         if "Llama3.1" in inference_model:
             assert all(
-                chunk.event.delta.type == "tool_call"
-                for chunk in grouped[ChatCompletionResponseEventType.progress]
+                chunk.event.delta.type == "tool_call" for chunk in grouped[ChatCompletionResponseEventType.progress]
             )
         first = grouped[ChatCompletionResponseEventType.progress][0]
-        if not isinstance(
-            first.event.delta.tool_call, ToolCall
-        ):  # first chunk may contain entire call
+        if not isinstance(first.event.delta.tool_call, ToolCall):  # first chunk may contain entire call
             assert first.event.delta.parse_status == ToolCallParseStatus.started
 
         last = grouped[ChatCompletionResponseEventType.progress][-1]