Fix precommit check after moving to ruff (#927)

The lint check on the main branch is failing. This fixes the lint check after
we moved to ruff in https://github.com/meta-llama/llama-stack/pull/921. We
need to move to a `ruff.toml` file, as well as fix and ignore some additional
checks.
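
For context, a minimal sketch of the kind of `ruff.toml` this moves to. The line
length and rule codes below are illustrative assumptions for the sketch, not
necessarily the exact values used in this PR:

```toml
# ruff.toml -- illustrative sketch only; line length and rule codes are assumptions
line-length = 120

[lint]
# rule families to check (pycodestyle errors, pyflakes, import sorting)
select = ["E", "F", "I"]
# example of ignoring a check instead of fixing it: leave long lines to the formatter
ignore = ["E501"]
```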

Signed-off-by: Yuan Tang <terrytangyuan@gmail.com>
Yuan Tang 2025-02-02 09:46:45 -05:00 committed by GitHub
parent 4773092dd1
commit 34ab7a3b6c
217 changed files with 981 additions and 2681 deletions


@@ -27,13 +27,10 @@ def supported_inference_models() -> List[Model]:
         m
         for m in all_registered_models()
         if (
-            m.model_family
-            in {ModelFamily.llama3_1, ModelFamily.llama3_2, ModelFamily.llama3_3}
+            m.model_family in {ModelFamily.llama3_1, ModelFamily.llama3_2, ModelFamily.llama3_3}
             or is_supported_safety_model(m)
         )
     ]
-ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR = {
-    m.huggingface_repo: m.descriptor() for m in all_registered_models()
-}
+ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR = {m.huggingface_repo: m.descriptor() for m in all_registered_models()}


@@ -28,9 +28,7 @@ class SentenceTransformerEmbeddingMixin:
         contents: List[InterleavedContent],
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
-        embedding_model = self._load_sentence_transformer_model(
-            model.provider_resource_id
-        )
+        embedding_model = self._load_sentence_transformer_model(model.provider_resource_id)
         embeddings = embedding_model.encode(contents)
         return EmbeddingsResponse(embeddings=embeddings)


@@ -36,9 +36,7 @@ def build_model_alias(provider_model_id: str, model_descriptor: str) -> ModelAli
     )
-def build_model_alias_with_just_provider_model_id(
-    provider_model_id: str, model_descriptor: str
-) -> ModelAlias:
+def build_model_alias_with_just_provider_model_id(provider_model_id: str, model_descriptor: str) -> ModelAlias:
     return ModelAlias(
         provider_model_id=provider_model_id,
         aliases=[],
@@ -54,16 +52,10 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
             for alias in alias_obj.aliases:
                 self.alias_to_provider_id_map[alias] = alias_obj.provider_model_id
             # also add a mapping from provider model id to itself for easy lookup
-            self.alias_to_provider_id_map[alias_obj.provider_model_id] = (
-                alias_obj.provider_model_id
-            )
+            self.alias_to_provider_id_map[alias_obj.provider_model_id] = alias_obj.provider_model_id
             # ensure we can go from llama model to provider model id
-            self.alias_to_provider_id_map[alias_obj.llama_model] = (
-                alias_obj.provider_model_id
-            )
-            self.provider_id_to_llama_model_map[alias_obj.provider_model_id] = (
-                alias_obj.llama_model
-            )
+            self.alias_to_provider_id_map[alias_obj.llama_model] = alias_obj.provider_model_id
+            self.provider_id_to_llama_model_map[alias_obj.provider_model_id] = alias_obj.llama_model
     def get_provider_model_id(self, identifier: str) -> str:
         if identifier in self.alias_to_provider_id_map:
@@ -82,9 +74,7 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
             # embedding models are always registered by their provider model id and does not need to be mapped to a llama model
             provider_resource_id = model.provider_resource_id
         else:
-            provider_resource_id = self.get_provider_model_id(
-                model.provider_resource_id
-            )
+            provider_resource_id = self.get_provider_model_id(model.provider_resource_id)
         if provider_resource_id:
             model.provider_resource_id = provider_resource_id
         else:
@@ -100,18 +90,13 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
                         f"Provider model id '{model.provider_resource_id}' is already registered to a different llama model: '{existing_llama_model}'"
                     )
             else:
-                if (
-                    model.metadata["llama_model"]
-                    not in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR
-                ):
+                if model.metadata["llama_model"] not in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR:
                     raise ValueError(
                         f"Invalid llama_model '{model.metadata['llama_model']}' specified in metadata. "
                         f"Must be one of: {', '.join(ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR.keys())}"
                     )
                 self.provider_id_to_llama_model_map[model.provider_resource_id] = (
-                    ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[
-                        model.metadata["llama_model"]
-                    ]
+                    ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[model.metadata["llama_model"]]
                 )
         return model


@@ -135,9 +135,7 @@ def convert_openai_completion_logprobs(
     return None
-def convert_openai_completion_logprobs_stream(
-    text: str, logprobs: Optional[Union[float, OpenAICompatLogprobs]]
-):
+def convert_openai_completion_logprobs_stream(text: str, logprobs: Optional[Union[float, OpenAICompatLogprobs]]):
     if logprobs is None:
         return None
     if isinstance(logprobs, float):
@@ -148,9 +146,7 @@ def convert_openai_completion_logprobs_stream(
     return None
-def process_completion_response(
-    response: OpenAICompatCompletionResponse, formatter: ChatFormat
-) -> CompletionResponse:
+def process_completion_response(response: OpenAICompatCompletionResponse, formatter: ChatFormat) -> CompletionResponse:
     choice = response.choices[0]
     # drop suffix <eot_id> if present and return stop reason as end of turn
     if choice.text.endswith("<|eot_id|>"):
@@ -341,17 +337,13 @@ async def process_chat_completion_stream_response(
     )
-async def convert_message_to_openai_dict(
-    message: Message, download: bool = False
-) -> dict:
+async def convert_message_to_openai_dict(message: Message, download: bool = False) -> dict:
     async def _convert_content(content) -> dict:
         if isinstance(content, ImageContentItem):
             return {
                 "type": "image_url",
                 "image_url": {
-                    "url": await convert_image_content_to_url(
-                        content, download=download
-                    ),
+                    "url": await convert_image_content_to_url(content, download=download),
                 },
             }
         else:


@@ -119,9 +119,7 @@ async def interleaved_content_convert_to_raw(
             if image.url.uri.startswith("data"):
                 match = re.match(r"data:image/(\w+);base64,(.+)", image.url.uri)
                 if not match:
-                    raise ValueError(
-                        f"Invalid data URL format, {image.url.uri[:40]}..."
-                    )
+                    raise ValueError(f"Invalid data URL format, {image.url.uri[:40]}...")
                 _, image_data = match.groups()
                 data = base64.b64decode(image_data)
             elif image.url.uri.startswith("file://"):
@@ -201,19 +199,13 @@ async def convert_image_content_to_url(
     content, format = await localize_image_content(media)
     if include_format:
-        return f"data:image/{format};base64," + base64.b64encode(content).decode(
-            "utf-8"
-        )
+        return f"data:image/{format};base64," + base64.b64encode(content).decode("utf-8")
     else:
         return base64.b64encode(content).decode("utf-8")
-async def completion_request_to_prompt(
-    request: CompletionRequest, formatter: ChatFormat
-) -> str:
-    content = augment_content_with_response_format_prompt(
-        request.response_format, request.content
-    )
+async def completion_request_to_prompt(request: CompletionRequest, formatter: ChatFormat) -> str:
+    content = augment_content_with_response_format_prompt(request.response_format, request.content)
     request.content = content
     request = await convert_request_to_raw(request)
     model_input = formatter.encode_content(request.content)
@@ -223,9 +215,7 @@ async def completion_request_to_prompt(
 async def completion_request_to_prompt_model_input_info(
     request: CompletionRequest, formatter: ChatFormat
 ) -> Tuple[str, int]:
-    content = augment_content_with_response_format_prompt(
-        request.response_format, request.content
-    )
+    content = augment_content_with_response_format_prompt(request.response_format, request.content)
     request.content = content
     request = await convert_request_to_raw(request)
     model_input = formatter.encode_content(request.content)
@@ -288,8 +278,7 @@ def chat_completion_request_to_messages(
         return request.messages
     if model.model_family == ModelFamily.llama3_1 or (
-        model.model_family == ModelFamily.llama3_2
-        and is_multimodal(model.core_model_id)
+        model.model_family == ModelFamily.llama3_2 and is_multimodal(model.core_model_id)
     ):
         # llama3.1 and llama3.2 multimodal models follow the same tool prompt format
         messages = augment_messages_for_tools_llama_3_1(request)
@@ -327,9 +316,7 @@ def augment_messages_for_tools_llama_3_1(
     if existing_messages[0].role == Role.system.value:
         existing_system_message = existing_messages.pop(0)
-    assert (
-        existing_messages[0].role != Role.system.value
-    ), "Should only have 1 system message"
+    assert existing_messages[0].role != Role.system.value, "Should only have 1 system message"
     messages = []
@@ -361,9 +348,7 @@ def augment_messages_for_tools_llama_3_1(
         if isinstance(existing_system_message.content, str):
             sys_content += _process(existing_system_message.content)
         elif isinstance(existing_system_message.content, list):
-            sys_content += "\n".join(
-                [_process(c) for c in existing_system_message.content]
-            )
+            sys_content += "\n".join([_process(c) for c in existing_system_message.content])
     messages.append(SystemMessage(content=sys_content))
@@ -397,9 +382,7 @@ def augment_messages_for_tools_llama_3_2(
     if existing_messages[0].role == Role.system.value:
         existing_system_message = existing_messages.pop(0)
-    assert (
-        existing_messages[0].role != Role.system.value
-    ), "Should only have 1 system message"
+    assert existing_messages[0].role != Role.system.value, "Should only have 1 system message"
     messages = []
     sys_content = ""
@@ -422,9 +405,7 @@ def augment_messages_for_tools_llama_3_2(
     if custom_tools:
         fmt = request.tool_prompt_format or ToolPromptFormat.python_list
         if fmt != ToolPromptFormat.python_list:
-            raise ValueError(
-                f"Non supported ToolPromptFormat {request.tool_prompt_format}"
-            )
+            raise ValueError(f"Non supported ToolPromptFormat {request.tool_prompt_format}")
         tool_gen = PythonListCustomToolGenerator()
         tool_template = tool_gen.gen(custom_tools)
@@ -433,9 +414,7 @@ def augment_messages_for_tools_llama_3_2(
     sys_content += "\n"
     if existing_system_message:
-        sys_content += interleaved_content_as_str(
-            existing_system_message.content, sep="\n"
-        )
+        sys_content += interleaved_content_as_str(existing_system_message.content, sep="\n")
     messages.append(SystemMessage(content=sys_content))