Fix precommit check after moving to ruff (#927)

The lint check on the main branch is failing. This fixes it after we
moved to ruff in https://github.com/meta-llama/llama-stack/pull/921. We
need to move to a `ruff.toml` file, as well as fix and ignore some
additional checks.

Signed-off-by: Yuan Tang <terrytangyuan@gmail.com>
Yuan Tang 2025-02-02 09:46:45 -05:00 committed by GitHub
parent 4773092dd1
commit 34ab7a3b6c
217 changed files with 981 additions and 2681 deletions
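
The new `ruff.toml` itself does not appear in this excerpt; the hunks below only show the resulting reformatting (wrapped expressions collapsed onto single, longer lines). As a rough illustration of the kind of config the commit message describes, a minimal `ruff.toml` might look like the sketch below. The line length and rule codes are assumptions for illustration, not the repository's actual settings:

    # Illustrative ruff.toml sketch only; the real file in the repo may differ.
    line-length = 120

    [lint]
    # Assumed rule families: pycodestyle errors (E), pyflakes (F), isort (I).
    select = ["E", "F", "I"]
    # "Ignoring some additional checks" could look like this; E501 (line too long)
    # is a common candidate, but the actual ignore list is not shown in this excerpt.
    ignore = ["E501"]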


@@ -12,9 +12,7 @@ from .fixtures import INFERENCE_FIXTURES
 def pytest_configure(config):
     for model in ["llama_8b", "llama_3b", "llama_vision"]:
-        config.addinivalue_line(
-            "markers", f"{model}: mark test to run only with the given model"
-        )
+        config.addinivalue_line("markers", f"{model}: mark test to run only with the given model")
     for fixture_name in INFERENCE_FIXTURES:
         config.addinivalue_line(
@@ -24,12 +22,8 @@ def pytest_configure(config):
 MODEL_PARAMS = [
-    pytest.param(
-        "meta-llama/Llama-3.1-8B-Instruct", marks=pytest.mark.llama_8b, id="llama_8b"
-    ),
-    pytest.param(
-        "meta-llama/Llama-3.2-3B-Instruct", marks=pytest.mark.llama_3b, id="llama_3b"
-    ),
+    pytest.param("meta-llama/Llama-3.1-8B-Instruct", marks=pytest.mark.llama_8b, id="llama_8b"),
+    pytest.param("meta-llama/Llama-3.2-3B-Instruct", marks=pytest.mark.llama_3b, id="llama_3b"),
 ]
 VISION_MODEL_PARAMS = [
@@ -49,9 +43,7 @@ def pytest_generate_tests(metafunc):
         params = []
         inference_models = getattr(test_config, "inference_models", [])
         for model in inference_models:
-            if ("Vision" in cls_name and "Vision" in model) or (
-                "Vision" not in cls_name and "Vision" not in model
-            ):
+            if ("Vision" in cls_name and "Vision" in model) or ("Vision" not in cls_name and "Vision" not in model):
                 params.append(pytest.param(model, id=model))
         if not params:
@@ -74,10 +66,7 @@ def pytest_generate_tests(metafunc):
         fixtures = [stack.values[0]["inference"] for stack in filtered_stacks]
         if test_config:
             if custom_fixtures := [
-                (
-                    scenario.fixture_combo_id
-                    or scenario.provider_fixtures.get("inference")
-                )
+                (scenario.fixture_combo_id or scenario.provider_fixtures.get("inference"))
                 for scenario in test_config.scenarios
            ]:
                 fixtures = custom_fixtures


@@ -47,9 +47,7 @@ def inference_remote() -> ProviderFixture:
 @pytest.fixture(scope="session")
 def inference_meta_reference(inference_model) -> ProviderFixture:
-    inference_model = (
-        [inference_model] if isinstance(inference_model, str) else inference_model
-    )
+    inference_model = [inference_model] if isinstance(inference_model, str) else inference_model
     # If embedding dimension is set, use the 8B model for testing
     if os.getenv("EMBEDDING_DIMENSION"):
         inference_model = ["meta-llama/Llama-3.1-8B-Instruct"]
@@ -88,9 +86,7 @@ def inference_cerebras() -> ProviderFixture:
 @pytest.fixture(scope="session")
 def inference_ollama(inference_model) -> ProviderFixture:
-    inference_model = (
-        [inference_model] if isinstance(inference_model, str) else inference_model
-    )
+    inference_model = [inference_model] if isinstance(inference_model, str) else inference_model
     if inference_model and "Llama3.1-8B-Instruct" in inference_model:
         pytest.skip("Ollama only supports Llama3.2-3B-Instruct for testing")
@@ -99,9 +95,7 @@ def inference_ollama(inference_model) -> ProviderFixture:
             Provider(
                 provider_id="ollama",
                 provider_type="remote::ollama",
-                config=OllamaImplConfig(
-                    host="localhost", port=os.getenv("OLLAMA_PORT", 11434)
-                ).model_dump(),
+                config=OllamaImplConfig(host="localhost", port=os.getenv("OLLAMA_PORT", 11434)).model_dump(),
             )
         ],
     )
@@ -109,9 +103,7 @@ def inference_ollama(inference_model) -> ProviderFixture:
 @pytest_asyncio.fixture(scope="session")
 def inference_vllm(inference_model) -> ProviderFixture:
-    inference_model = (
-        [inference_model] if isinstance(inference_model, str) else inference_model
-    )
+    inference_model = [inference_model] if isinstance(inference_model, str) else inference_model
     return ProviderFixture(
         providers=[
             Provider(


@@ -162,9 +162,7 @@ class TestConvertChatCompletionRequest:
     def test_includes_stratgy(self):
         request = self._dummy_chat_completion_request()
-        request.sampling_params.strategy = TopPSamplingStrategy(
-            temperature=0.5, top_p=0.95
-        )
+        request.sampling_params.strategy = TopPSamplingStrategy(temperature=0.5, top_p=0.95)
         converted = convert_chat_completion_request(request)
@@ -375,9 +373,7 @@ class TestConvertNonStreamChatCompletionResponse:
         choices=[
             Choice(
                 index=0,
-                message=ChatCompletionMessage(
-                    role="assistant", content="Hello World"
-                ),
+                message=ChatCompletionMessage(role="assistant", content="Hello World"),
                 finish_reason="stop",
             )
         ],


@@ -29,11 +29,7 @@ class TestEmbeddings:
         assert isinstance(response, EmbeddingsResponse)
         assert len(response.embeddings) > 0
         assert all(isinstance(embedding, list) for embedding in response.embeddings)
-        assert all(
-            isinstance(value, float)
-            for embedding in response.embeddings
-            for value in embedding
-        )
+        assert all(isinstance(value, float) for embedding in response.embeddings for value in embedding)
     @pytest.mark.asyncio
     async def test_batch_embeddings(self, inference_model, inference_stack):
@@ -53,11 +49,7 @@ class TestEmbeddings:
         assert isinstance(response, EmbeddingsResponse)
         assert len(response.embeddings) == len(texts)
         assert all(isinstance(embedding, list) for embedding in response.embeddings)
-        assert all(
-            isinstance(value, float)
-            for embedding in response.embeddings
-            for value in embedding
-        )
+        assert all(isinstance(value, float) for embedding in response.embeddings for value in embedding)
         embedding_dim = len(response.embeddings[0])
         assert all(len(embedding) == embedding_dim for embedding in response.embeddings)


@@ -44,11 +44,7 @@ from .utils import group_chunks
 def get_expected_stop_reason(model: str):
-    return (
-        StopReason.end_of_message
-        if ("Llama3.1" in model or "Llama-3.1" in model)
-        else StopReason.end_of_turn
-    )
+    return StopReason.end_of_message if ("Llama3.1" in model or "Llama-3.1" in model) else StopReason.end_of_turn
 @pytest.fixture
@@ -179,13 +175,9 @@ class TestInference:
             1 <= len(chunks) <= 6
         )  # why 6 and not 5? the response may have an extra closing chunk, e.g. for usage or stop_reason
         for chunk in chunks:
-            if (
-                chunk.delta.type == "text" and chunk.delta.text
-            ):  # if there's a token, we expect logprobs
+            if chunk.delta.type == "text" and chunk.delta.text:  # if there's a token, we expect logprobs
                 assert chunk.logprobs, "Logprobs should not be empty"
-                assert all(
-                    len(logprob.logprobs_by_token) == 3 for logprob in chunk.logprobs
-                )
+                assert all(len(logprob.logprobs_by_token) == 3 for logprob in chunk.logprobs)
             else:  # no token, no logprobs
                 assert not chunk.logprobs, "Logprobs should be empty"
@@ -236,9 +228,7 @@ class TestInference:
         assert len(response.completion_message.content) > 0
     @pytest.mark.asyncio(loop_scope="session")
-    async def test_structured_output(
-        self, inference_model, inference_stack, common_params
-    ):
+    async def test_structured_output(self, inference_model, inference_stack, common_params):
         inference_impl, _ = inference_stack
         class AnswerFormat(BaseModel):
@@ -295,9 +285,7 @@ class TestInference:
         AnswerFormat.model_validate_json(response.completion_message.content)
     @pytest.mark.asyncio(loop_scope="session")
-    async def test_chat_completion_streaming(
-        self, inference_model, inference_stack, common_params, sample_messages
-    ):
+    async def test_chat_completion_streaming(self, inference_model, inference_stack, common_params, sample_messages):
         inference_impl, _ = inference_stack
         response = [
             r
@@ -310,9 +298,7 @@ class TestInference:
         ]
         assert len(response) > 0
-        assert all(
-            isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response
-        )
+        assert all(isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response)
         grouped = group_chunks(response)
         assert len(grouped[ChatCompletionResponseEventType.start]) == 1
         assert len(grouped[ChatCompletionResponseEventType.progress]) > 0
@@ -387,9 +373,7 @@ class TestInference:
             )
         ]
         assert len(response) > 0
-        assert all(
-            isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response
-        )
+        assert all(isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response)
         grouped = group_chunks(response)
         assert len(grouped[ChatCompletionResponseEventType.start]) == 1
         assert len(grouped[ChatCompletionResponseEventType.progress]) > 0
@@ -404,13 +388,10 @@ class TestInference:
         if "Llama3.1" in inference_model:
             assert all(
-                chunk.event.delta.type == "tool_call"
-                for chunk in grouped[ChatCompletionResponseEventType.progress]
+                chunk.event.delta.type == "tool_call" for chunk in grouped[ChatCompletionResponseEventType.progress]
             )
             first = grouped[ChatCompletionResponseEventType.progress][0]
-            if not isinstance(
-                first.event.delta.tool_call, ToolCall
-            ):  # first chunk may contain entire call
+            if not isinstance(first.event.delta.tool_call, ToolCall):  # first chunk may contain entire call
                 assert first.event.delta.parse_status == ToolCallParseStatus.started
         last = grouped[ChatCompletionResponseEventType.progress][-1]


@@ -73,9 +73,7 @@ class TestVisionModelInference:
             assert expected_string in response.completion_message.content
     @pytest.mark.asyncio
-    async def test_vision_chat_completion_streaming(
-        self, inference_model, inference_stack
-    ):
+    async def test_vision_chat_completion_streaming(self, inference_model, inference_stack):
         inference_impl, _ = inference_stack
         images = [
@@ -100,9 +98,7 @@ class TestVisionModelInference:
                 UserMessage(
                     content=[
                         image,
-                        TextContentItem(
-                            text="Describe this image in two sentences."
-                        ),
+                        TextContentItem(text="Describe this image in two sentences."),
                     ]
                 ),
             ],
@@ -112,18 +108,12 @@ class TestVisionModelInference:
         ]
         assert len(response) > 0
-        assert all(
-            isinstance(chunk, ChatCompletionResponseStreamChunk)
-            for chunk in response
-        )
+        assert all(isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response)
         grouped = group_chunks(response)
         assert len(grouped[ChatCompletionResponseEventType.start]) == 1
         assert len(grouped[ChatCompletionResponseEventType.progress]) > 0
         assert len(grouped[ChatCompletionResponseEventType.complete]) == 1
-        content = "".join(
-            chunk.event.delta.text
-            for chunk in grouped[ChatCompletionResponseEventType.progress]
-        )
+        content = "".join(chunk.event.delta.text for chunk in grouped[ChatCompletionResponseEventType.progress])
         for expected_string in expected_strings:
             assert expected_string in content


@@ -10,7 +10,5 @@ import itertools
 def group_chunks(response):
     return {
         event_type: list(group)
-        for event_type, group in itertools.groupby(
-            response, key=lambda chunk: chunk.event.event_type
-        )
+        for event_type, group in itertools.groupby(response, key=lambda chunk: chunk.event.event_type)
     }