mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-02 20:40:36 +00:00
Fix precommit check after moving to ruff (#927)
Lint check in main branch is failing. This fixes the lint check after we moved to ruff in https://github.com/meta-llama/llama-stack/pull/921. We need to move to a `ruff.toml` file as well as fixing and ignoring some additional checks. Signed-off-by: Yuan Tang <terrytangyuan@gmail.com>
This commit is contained in:
parent
4773092dd1
commit
34ab7a3b6c
217 changed files with 981 additions and 2681 deletions
|
@ -27,13 +27,10 @@ def supported_inference_models() -> List[Model]:
|
|||
m
|
||||
for m in all_registered_models()
|
||||
if (
|
||||
m.model_family
|
||||
in {ModelFamily.llama3_1, ModelFamily.llama3_2, ModelFamily.llama3_3}
|
||||
m.model_family in {ModelFamily.llama3_1, ModelFamily.llama3_2, ModelFamily.llama3_3}
|
||||
or is_supported_safety_model(m)
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR = {
|
||||
m.huggingface_repo: m.descriptor() for m in all_registered_models()
|
||||
}
|
||||
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR = {m.huggingface_repo: m.descriptor() for m in all_registered_models()}
|
||||
|
|
|
@ -28,9 +28,7 @@ class SentenceTransformerEmbeddingMixin:
|
|||
contents: List[InterleavedContent],
|
||||
) -> EmbeddingsResponse:
|
||||
model = await self.model_store.get_model(model_id)
|
||||
embedding_model = self._load_sentence_transformer_model(
|
||||
model.provider_resource_id
|
||||
)
|
||||
embedding_model = self._load_sentence_transformer_model(model.provider_resource_id)
|
||||
embeddings = embedding_model.encode(contents)
|
||||
return EmbeddingsResponse(embeddings=embeddings)
|
||||
|
||||
|
|
|
@ -36,9 +36,7 @@ def build_model_alias(provider_model_id: str, model_descriptor: str) -> ModelAli
|
|||
)
|
||||
|
||||
|
||||
def build_model_alias_with_just_provider_model_id(
|
||||
provider_model_id: str, model_descriptor: str
|
||||
) -> ModelAlias:
|
||||
def build_model_alias_with_just_provider_model_id(provider_model_id: str, model_descriptor: str) -> ModelAlias:
|
||||
return ModelAlias(
|
||||
provider_model_id=provider_model_id,
|
||||
aliases=[],
|
||||
|
@ -54,16 +52,10 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
|
|||
for alias in alias_obj.aliases:
|
||||
self.alias_to_provider_id_map[alias] = alias_obj.provider_model_id
|
||||
# also add a mapping from provider model id to itself for easy lookup
|
||||
self.alias_to_provider_id_map[alias_obj.provider_model_id] = (
|
||||
alias_obj.provider_model_id
|
||||
)
|
||||
self.alias_to_provider_id_map[alias_obj.provider_model_id] = alias_obj.provider_model_id
|
||||
# ensure we can go from llama model to provider model id
|
||||
self.alias_to_provider_id_map[alias_obj.llama_model] = (
|
||||
alias_obj.provider_model_id
|
||||
)
|
||||
self.provider_id_to_llama_model_map[alias_obj.provider_model_id] = (
|
||||
alias_obj.llama_model
|
||||
)
|
||||
self.alias_to_provider_id_map[alias_obj.llama_model] = alias_obj.provider_model_id
|
||||
self.provider_id_to_llama_model_map[alias_obj.provider_model_id] = alias_obj.llama_model
|
||||
|
||||
def get_provider_model_id(self, identifier: str) -> str:
|
||||
if identifier in self.alias_to_provider_id_map:
|
||||
|
@ -82,9 +74,7 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
|
|||
# embedding models are always registered by their provider model id and does not need to be mapped to a llama model
|
||||
provider_resource_id = model.provider_resource_id
|
||||
else:
|
||||
provider_resource_id = self.get_provider_model_id(
|
||||
model.provider_resource_id
|
||||
)
|
||||
provider_resource_id = self.get_provider_model_id(model.provider_resource_id)
|
||||
if provider_resource_id:
|
||||
model.provider_resource_id = provider_resource_id
|
||||
else:
|
||||
|
@ -100,18 +90,13 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
|
|||
f"Provider model id '{model.provider_resource_id}' is already registered to a different llama model: '{existing_llama_model}'"
|
||||
)
|
||||
else:
|
||||
if (
|
||||
model.metadata["llama_model"]
|
||||
not in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR
|
||||
):
|
||||
if model.metadata["llama_model"] not in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR:
|
||||
raise ValueError(
|
||||
f"Invalid llama_model '{model.metadata['llama_model']}' specified in metadata. "
|
||||
f"Must be one of: {', '.join(ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR.keys())}"
|
||||
)
|
||||
self.provider_id_to_llama_model_map[model.provider_resource_id] = (
|
||||
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[
|
||||
model.metadata["llama_model"]
|
||||
]
|
||||
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[model.metadata["llama_model"]]
|
||||
)
|
||||
|
||||
return model
|
||||
|
|
|
@ -135,9 +135,7 @@ def convert_openai_completion_logprobs(
|
|||
return None
|
||||
|
||||
|
||||
def convert_openai_completion_logprobs_stream(
|
||||
text: str, logprobs: Optional[Union[float, OpenAICompatLogprobs]]
|
||||
):
|
||||
def convert_openai_completion_logprobs_stream(text: str, logprobs: Optional[Union[float, OpenAICompatLogprobs]]):
|
||||
if logprobs is None:
|
||||
return None
|
||||
if isinstance(logprobs, float):
|
||||
|
@ -148,9 +146,7 @@ def convert_openai_completion_logprobs_stream(
|
|||
return None
|
||||
|
||||
|
||||
def process_completion_response(
|
||||
response: OpenAICompatCompletionResponse, formatter: ChatFormat
|
||||
) -> CompletionResponse:
|
||||
def process_completion_response(response: OpenAICompatCompletionResponse, formatter: ChatFormat) -> CompletionResponse:
|
||||
choice = response.choices[0]
|
||||
# drop suffix <eot_id> if present and return stop reason as end of turn
|
||||
if choice.text.endswith("<|eot_id|>"):
|
||||
|
@ -341,17 +337,13 @@ async def process_chat_completion_stream_response(
|
|||
)
|
||||
|
||||
|
||||
async def convert_message_to_openai_dict(
|
||||
message: Message, download: bool = False
|
||||
) -> dict:
|
||||
async def convert_message_to_openai_dict(message: Message, download: bool = False) -> dict:
|
||||
async def _convert_content(content) -> dict:
|
||||
if isinstance(content, ImageContentItem):
|
||||
return {
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": await convert_image_content_to_url(
|
||||
content, download=download
|
||||
),
|
||||
"url": await convert_image_content_to_url(content, download=download),
|
||||
},
|
||||
}
|
||||
else:
|
||||
|
|
|
@ -119,9 +119,7 @@ async def interleaved_content_convert_to_raw(
|
|||
if image.url.uri.startswith("data"):
|
||||
match = re.match(r"data:image/(\w+);base64,(.+)", image.url.uri)
|
||||
if not match:
|
||||
raise ValueError(
|
||||
f"Invalid data URL format, {image.url.uri[:40]}..."
|
||||
)
|
||||
raise ValueError(f"Invalid data URL format, {image.url.uri[:40]}...")
|
||||
_, image_data = match.groups()
|
||||
data = base64.b64decode(image_data)
|
||||
elif image.url.uri.startswith("file://"):
|
||||
|
@ -201,19 +199,13 @@ async def convert_image_content_to_url(
|
|||
|
||||
content, format = await localize_image_content(media)
|
||||
if include_format:
|
||||
return f"data:image/{format};base64," + base64.b64encode(content).decode(
|
||||
"utf-8"
|
||||
)
|
||||
return f"data:image/{format};base64," + base64.b64encode(content).decode("utf-8")
|
||||
else:
|
||||
return base64.b64encode(content).decode("utf-8")
|
||||
|
||||
|
||||
async def completion_request_to_prompt(
|
||||
request: CompletionRequest, formatter: ChatFormat
|
||||
) -> str:
|
||||
content = augment_content_with_response_format_prompt(
|
||||
request.response_format, request.content
|
||||
)
|
||||
async def completion_request_to_prompt(request: CompletionRequest, formatter: ChatFormat) -> str:
|
||||
content = augment_content_with_response_format_prompt(request.response_format, request.content)
|
||||
request.content = content
|
||||
request = await convert_request_to_raw(request)
|
||||
model_input = formatter.encode_content(request.content)
|
||||
|
@ -223,9 +215,7 @@ async def completion_request_to_prompt(
|
|||
async def completion_request_to_prompt_model_input_info(
|
||||
request: CompletionRequest, formatter: ChatFormat
|
||||
) -> Tuple[str, int]:
|
||||
content = augment_content_with_response_format_prompt(
|
||||
request.response_format, request.content
|
||||
)
|
||||
content = augment_content_with_response_format_prompt(request.response_format, request.content)
|
||||
request.content = content
|
||||
request = await convert_request_to_raw(request)
|
||||
model_input = formatter.encode_content(request.content)
|
||||
|
@ -288,8 +278,7 @@ def chat_completion_request_to_messages(
|
|||
return request.messages
|
||||
|
||||
if model.model_family == ModelFamily.llama3_1 or (
|
||||
model.model_family == ModelFamily.llama3_2
|
||||
and is_multimodal(model.core_model_id)
|
||||
model.model_family == ModelFamily.llama3_2 and is_multimodal(model.core_model_id)
|
||||
):
|
||||
# llama3.1 and llama3.2 multimodal models follow the same tool prompt format
|
||||
messages = augment_messages_for_tools_llama_3_1(request)
|
||||
|
@ -327,9 +316,7 @@ def augment_messages_for_tools_llama_3_1(
|
|||
if existing_messages[0].role == Role.system.value:
|
||||
existing_system_message = existing_messages.pop(0)
|
||||
|
||||
assert (
|
||||
existing_messages[0].role != Role.system.value
|
||||
), "Should only have 1 system message"
|
||||
assert existing_messages[0].role != Role.system.value, "Should only have 1 system message"
|
||||
|
||||
messages = []
|
||||
|
||||
|
@ -361,9 +348,7 @@ def augment_messages_for_tools_llama_3_1(
|
|||
if isinstance(existing_system_message.content, str):
|
||||
sys_content += _process(existing_system_message.content)
|
||||
elif isinstance(existing_system_message.content, list):
|
||||
sys_content += "\n".join(
|
||||
[_process(c) for c in existing_system_message.content]
|
||||
)
|
||||
sys_content += "\n".join([_process(c) for c in existing_system_message.content])
|
||||
|
||||
messages.append(SystemMessage(content=sys_content))
|
||||
|
||||
|
@ -397,9 +382,7 @@ def augment_messages_for_tools_llama_3_2(
|
|||
if existing_messages[0].role == Role.system.value:
|
||||
existing_system_message = existing_messages.pop(0)
|
||||
|
||||
assert (
|
||||
existing_messages[0].role != Role.system.value
|
||||
), "Should only have 1 system message"
|
||||
assert existing_messages[0].role != Role.system.value, "Should only have 1 system message"
|
||||
|
||||
messages = []
|
||||
sys_content = ""
|
||||
|
@ -422,9 +405,7 @@ def augment_messages_for_tools_llama_3_2(
|
|||
if custom_tools:
|
||||
fmt = request.tool_prompt_format or ToolPromptFormat.python_list
|
||||
if fmt != ToolPromptFormat.python_list:
|
||||
raise ValueError(
|
||||
f"Non supported ToolPromptFormat {request.tool_prompt_format}"
|
||||
)
|
||||
raise ValueError(f"Non supported ToolPromptFormat {request.tool_prompt_format}")
|
||||
|
||||
tool_gen = PythonListCustomToolGenerator()
|
||||
tool_template = tool_gen.gen(custom_tools)
|
||||
|
@ -433,9 +414,7 @@ def augment_messages_for_tools_llama_3_2(
|
|||
sys_content += "\n"
|
||||
|
||||
if existing_system_message:
|
||||
sys_content += interleaved_content_as_str(
|
||||
existing_system_message.content, sep="\n"
|
||||
)
|
||||
sys_content += interleaved_content_as_str(existing_system_message.content, sep="\n")
|
||||
|
||||
messages.append(SystemMessage(content=sys_content))
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue