Fix precommit check after moving to ruff (#927)

Lint check in main branch is failing. This fixes the lint check after we
moved to ruff in https://github.com/meta-llama/llama-stack/pull/921. We
need to move to a `ruff.toml` file as well as fixing and ignoring some
additional checks.

Signed-off-by: Yuan Tang <terrytangyuan@gmail.com>
This commit is contained in:
Yuan Tang 2025-02-02 09:46:45 -05:00 committed by GitHub
parent 4773092dd1
commit 34ab7a3b6c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
217 changed files with 981 additions and 2681 deletions

View file

@ -80,9 +80,7 @@ class MetaReferenceInferenceImpl(
async def load_model(self, model_id, llama_model) -> None:
log.info(f"Loading model `{model_id}`")
if self.config.create_distributed_process_group:
self.generator = LlamaModelParallelGenerator(
self.config, model_id, llama_model
)
self.generator = LlamaModelParallelGenerator(self.config, model_id, llama_model)
self.generator.start()
else:
self.generator = Llama.build(self.config, model_id, llama_model)
@ -100,9 +98,7 @@ class MetaReferenceInferenceImpl(
"No avaible model yet, please register your requested model or add your model in the resouces first"
)
elif request.model != self.model_id:
raise RuntimeError(
f"Model mismatch: request model: {request.model} != loaded model: {self.model_id}"
)
raise RuntimeError(f"Model mismatch: request model: {request.model} != loaded model: {self.model_id}")
async def unregister_model(self, model_id: str) -> None:
pass
@ -184,13 +180,7 @@ class MetaReferenceInferenceImpl(
if request.logprobs:
assert len(token_result.logprobs) == 1
logprobs = [
TokenLogProbs(
logprobs_by_token={
token_result.text: token_result.logprobs[0]
}
)
]
logprobs = [TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]})]
yield CompletionResponseStreamChunk(
delta=text,
@ -212,9 +202,7 @@ class MetaReferenceInferenceImpl(
for x in impl():
yield x
async def _nonstream_completion(
self, request: CompletionRequest
) -> CompletionResponse:
async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
def impl():
tokens = []
logprobs = []
@ -231,13 +219,7 @@ class MetaReferenceInferenceImpl(
if request.logprobs:
assert len(token_result.logprobs) == 1
logprobs.append(
TokenLogProbs(
logprobs_by_token={
token_result.text: token_result.logprobs[0]
}
)
)
logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
if stop_reason is None:
stop_reason = StopReason.out_of_tokens
@ -289,9 +271,7 @@ class MetaReferenceInferenceImpl(
self.check_model(request)
# augment and rewrite messages depending on the model
request.messages = chat_completion_request_to_messages(
request, self.llama_model.core_model_id.value
)
request.messages = chat_completion_request_to_messages(request, self.llama_model.core_model_id.value)
# download media and convert to raw content so we can send it to the model
request = await convert_request_to_raw(request)
@ -304,9 +284,7 @@ class MetaReferenceInferenceImpl(
else:
return await self._nonstream_chat_completion(request)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest
) -> ChatCompletionResponse:
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
def impl():
tokens = []
logprobs = []
@ -323,20 +301,12 @@ class MetaReferenceInferenceImpl(
if request.logprobs:
assert len(token_result.logprobs) == 1
logprobs.append(
TokenLogProbs(
logprobs_by_token={
token_result.text: token_result.logprobs[0]
}
)
)
logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
if stop_reason is None:
stop_reason = StopReason.out_of_tokens
raw_message = self.generator.formatter.decode_assistant_message(
tokens, stop_reason
)
raw_message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
return ChatCompletionResponse(
completion_message=CompletionMessage(
content=raw_message.content,
@ -352,9 +322,7 @@ class MetaReferenceInferenceImpl(
else:
return impl()
async def _stream_chat_completion(
self, request: ChatCompletionRequest
) -> AsyncGenerator:
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
def impl():
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
@ -405,13 +373,7 @@ class MetaReferenceInferenceImpl(
if request.logprobs:
assert len(token_result.logprobs) == 1
logprobs.append(
TokenLogProbs(
logprobs_by_token={
token_result.text: token_result.logprobs[0]
}
)
)
logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
@ -424,9 +386,7 @@ class MetaReferenceInferenceImpl(
if stop_reason is None:
stop_reason = StopReason.out_of_tokens
message = self.generator.formatter.decode_assistant_message(
tokens, stop_reason
)
message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
parsed_tool_calls = len(message.tool_calls) > 0
if ipython and not parsed_tool_calls: