Fix precommit check after moving to ruff (#927)

The lint check on the main branch is failing. This fixes the lint check after we
moved to ruff in https://github.com/meta-llama/llama-stack/pull/921. We
need to move to a `ruff.toml` file as well as fix and ignore some
additional checks.

Signed-off-by: Yuan Tang <terrytangyuan@gmail.com>
This commit is contained in:
Yuan Tang 2025-02-02 09:46:45 -05:00 committed by GitHub
parent 4773092dd1
commit 34ab7a3b6c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
217 changed files with 981 additions and 2681 deletions

View file

@ -15,9 +15,7 @@ from llama_stack.providers.utils.bedrock.refreshable_boto_session import (
)
def create_bedrock_client(
config: BedrockBaseConfig, service_name: str = "bedrock-runtime"
) -> BaseClient:
def create_bedrock_client(config: BedrockBaseConfig, service_name: str = "bedrock-runtime") -> BaseClient:
"""Creates a boto3 client for Bedrock services with the given configuration.
Args:

View file

@ -28,8 +28,7 @@ class BedrockBaseConfig(BaseModel):
)
profile_name: Optional[str] = Field(
default=None,
description="The profile name that contains credentials to use."
"Default use environment variable: AWS_PROFILE",
description="The profile name that contains credentials to use.Default use environment variable: AWS_PROFILE",
)
total_max_attempts: Optional[int] = Field(
default=None,

View file

@ -68,9 +68,7 @@ class RefreshableBotoSession:
# if sts_arn is given, get credential by assuming the given role
if self.sts_arn:
sts_client = session.client(
service_name="sts", region_name=self.region_name
)
sts_client = session.client(service_name="sts", region_name=self.region_name)
response = sts_client.assume_role(
RoleArn=self.sts_arn,
RoleSessionName=self.session_name,

View file

@ -68,9 +68,7 @@ def validate_dataset_schema(
expected_schemas: List[Dict[str, Any]],
):
if dataset_schema not in expected_schemas:
raise ValueError(
f"Dataset {dataset_schema} does not have a correct input schema in {expected_schemas}"
)
raise ValueError(f"Dataset {dataset_schema} does not have a correct input schema in {expected_schemas}")
def validate_row_schema(
@ -81,6 +79,4 @@ def validate_row_schema(
if all(key in input_row for key in schema):
return
raise ValueError(
f"Input row {input_row} does not match any of the expected schemas in {expected_schemas}"
)
raise ValueError(f"Input row {input_row} does not match any of the expected schemas in {expected_schemas}")

View file

@ -27,13 +27,10 @@ def supported_inference_models() -> List[Model]:
m
for m in all_registered_models()
if (
m.model_family
in {ModelFamily.llama3_1, ModelFamily.llama3_2, ModelFamily.llama3_3}
m.model_family in {ModelFamily.llama3_1, ModelFamily.llama3_2, ModelFamily.llama3_3}
or is_supported_safety_model(m)
)
]
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR = {
m.huggingface_repo: m.descriptor() for m in all_registered_models()
}
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR = {m.huggingface_repo: m.descriptor() for m in all_registered_models()}

View file

@ -28,9 +28,7 @@ class SentenceTransformerEmbeddingMixin:
contents: List[InterleavedContent],
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)
embedding_model = self._load_sentence_transformer_model(
model.provider_resource_id
)
embedding_model = self._load_sentence_transformer_model(model.provider_resource_id)
embeddings = embedding_model.encode(contents)
return EmbeddingsResponse(embeddings=embeddings)

View file

@ -36,9 +36,7 @@ def build_model_alias(provider_model_id: str, model_descriptor: str) -> ModelAli
)
def build_model_alias_with_just_provider_model_id(
provider_model_id: str, model_descriptor: str
) -> ModelAlias:
def build_model_alias_with_just_provider_model_id(provider_model_id: str, model_descriptor: str) -> ModelAlias:
return ModelAlias(
provider_model_id=provider_model_id,
aliases=[],
@ -54,16 +52,10 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
for alias in alias_obj.aliases:
self.alias_to_provider_id_map[alias] = alias_obj.provider_model_id
# also add a mapping from provider model id to itself for easy lookup
self.alias_to_provider_id_map[alias_obj.provider_model_id] = (
alias_obj.provider_model_id
)
self.alias_to_provider_id_map[alias_obj.provider_model_id] = alias_obj.provider_model_id
# ensure we can go from llama model to provider model id
self.alias_to_provider_id_map[alias_obj.llama_model] = (
alias_obj.provider_model_id
)
self.provider_id_to_llama_model_map[alias_obj.provider_model_id] = (
alias_obj.llama_model
)
self.alias_to_provider_id_map[alias_obj.llama_model] = alias_obj.provider_model_id
self.provider_id_to_llama_model_map[alias_obj.provider_model_id] = alias_obj.llama_model
def get_provider_model_id(self, identifier: str) -> str:
if identifier in self.alias_to_provider_id_map:
@ -82,9 +74,7 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
# embedding models are always registered by their provider model id and does not need to be mapped to a llama model
provider_resource_id = model.provider_resource_id
else:
provider_resource_id = self.get_provider_model_id(
model.provider_resource_id
)
provider_resource_id = self.get_provider_model_id(model.provider_resource_id)
if provider_resource_id:
model.provider_resource_id = provider_resource_id
else:
@ -100,18 +90,13 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
f"Provider model id '{model.provider_resource_id}' is already registered to a different llama model: '{existing_llama_model}'"
)
else:
if (
model.metadata["llama_model"]
not in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR
):
if model.metadata["llama_model"] not in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR:
raise ValueError(
f"Invalid llama_model '{model.metadata['llama_model']}' specified in metadata. "
f"Must be one of: {', '.join(ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR.keys())}"
)
self.provider_id_to_llama_model_map[model.provider_resource_id] = (
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[
model.metadata["llama_model"]
]
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[model.metadata["llama_model"]]
)
return model

View file

@ -135,9 +135,7 @@ def convert_openai_completion_logprobs(
return None
def convert_openai_completion_logprobs_stream(
text: str, logprobs: Optional[Union[float, OpenAICompatLogprobs]]
):
def convert_openai_completion_logprobs_stream(text: str, logprobs: Optional[Union[float, OpenAICompatLogprobs]]):
if logprobs is None:
return None
if isinstance(logprobs, float):
@ -148,9 +146,7 @@ def convert_openai_completion_logprobs_stream(
return None
def process_completion_response(
response: OpenAICompatCompletionResponse, formatter: ChatFormat
) -> CompletionResponse:
def process_completion_response(response: OpenAICompatCompletionResponse, formatter: ChatFormat) -> CompletionResponse:
choice = response.choices[0]
# drop suffix <eot_id> if present and return stop reason as end of turn
if choice.text.endswith("<|eot_id|>"):
@ -341,17 +337,13 @@ async def process_chat_completion_stream_response(
)
async def convert_message_to_openai_dict(
message: Message, download: bool = False
) -> dict:
async def convert_message_to_openai_dict(message: Message, download: bool = False) -> dict:
async def _convert_content(content) -> dict:
if isinstance(content, ImageContentItem):
return {
"type": "image_url",
"image_url": {
"url": await convert_image_content_to_url(
content, download=download
),
"url": await convert_image_content_to_url(content, download=download),
},
}
else:

View file

@ -119,9 +119,7 @@ async def interleaved_content_convert_to_raw(
if image.url.uri.startswith("data"):
match = re.match(r"data:image/(\w+);base64,(.+)", image.url.uri)
if not match:
raise ValueError(
f"Invalid data URL format, {image.url.uri[:40]}..."
)
raise ValueError(f"Invalid data URL format, {image.url.uri[:40]}...")
_, image_data = match.groups()
data = base64.b64decode(image_data)
elif image.url.uri.startswith("file://"):
@ -201,19 +199,13 @@ async def convert_image_content_to_url(
content, format = await localize_image_content(media)
if include_format:
return f"data:image/{format};base64," + base64.b64encode(content).decode(
"utf-8"
)
return f"data:image/{format};base64," + base64.b64encode(content).decode("utf-8")
else:
return base64.b64encode(content).decode("utf-8")
async def completion_request_to_prompt(
request: CompletionRequest, formatter: ChatFormat
) -> str:
content = augment_content_with_response_format_prompt(
request.response_format, request.content
)
async def completion_request_to_prompt(request: CompletionRequest, formatter: ChatFormat) -> str:
content = augment_content_with_response_format_prompt(request.response_format, request.content)
request.content = content
request = await convert_request_to_raw(request)
model_input = formatter.encode_content(request.content)
@ -223,9 +215,7 @@ async def completion_request_to_prompt(
async def completion_request_to_prompt_model_input_info(
request: CompletionRequest, formatter: ChatFormat
) -> Tuple[str, int]:
content = augment_content_with_response_format_prompt(
request.response_format, request.content
)
content = augment_content_with_response_format_prompt(request.response_format, request.content)
request.content = content
request = await convert_request_to_raw(request)
model_input = formatter.encode_content(request.content)
@ -288,8 +278,7 @@ def chat_completion_request_to_messages(
return request.messages
if model.model_family == ModelFamily.llama3_1 or (
model.model_family == ModelFamily.llama3_2
and is_multimodal(model.core_model_id)
model.model_family == ModelFamily.llama3_2 and is_multimodal(model.core_model_id)
):
# llama3.1 and llama3.2 multimodal models follow the same tool prompt format
messages = augment_messages_for_tools_llama_3_1(request)
@ -327,9 +316,7 @@ def augment_messages_for_tools_llama_3_1(
if existing_messages[0].role == Role.system.value:
existing_system_message = existing_messages.pop(0)
assert (
existing_messages[0].role != Role.system.value
), "Should only have 1 system message"
assert existing_messages[0].role != Role.system.value, "Should only have 1 system message"
messages = []
@ -361,9 +348,7 @@ def augment_messages_for_tools_llama_3_1(
if isinstance(existing_system_message.content, str):
sys_content += _process(existing_system_message.content)
elif isinstance(existing_system_message.content, list):
sys_content += "\n".join(
[_process(c) for c in existing_system_message.content]
)
sys_content += "\n".join([_process(c) for c in existing_system_message.content])
messages.append(SystemMessage(content=sys_content))
@ -397,9 +382,7 @@ def augment_messages_for_tools_llama_3_2(
if existing_messages[0].role == Role.system.value:
existing_system_message = existing_messages.pop(0)
assert (
existing_messages[0].role != Role.system.value
), "Should only have 1 system message"
assert existing_messages[0].role != Role.system.value, "Should only have 1 system message"
messages = []
sys_content = ""
@ -422,9 +405,7 @@ def augment_messages_for_tools_llama_3_2(
if custom_tools:
fmt = request.tool_prompt_format or ToolPromptFormat.python_list
if fmt != ToolPromptFormat.python_list:
raise ValueError(
f"Non supported ToolPromptFormat {request.tool_prompt_format}"
)
raise ValueError(f"Non supported ToolPromptFormat {request.tool_prompt_format}")
tool_gen = PythonListCustomToolGenerator()
tool_template = tool_gen.gen(custom_tools)
@ -433,9 +414,7 @@ def augment_messages_for_tools_llama_3_2(
sys_content += "\n"
if existing_system_message:
sys_content += interleaved_content_as_str(
existing_system_message.content, sep="\n"
)
sys_content += interleaved_content_as_str(existing_system_message.content, sep="\n")
messages.append(SystemMessage(content=sys_content))

View file

@ -10,9 +10,7 @@ from typing import List, Optional, Protocol
class KVStore(Protocol):
# TODO: make the value type bytes instead of str
async def set(
self, key: str, value: str, expiration: Optional[datetime] = None
) -> None: ...
async def set(self, key: str, value: str, expiration: Optional[datetime] = None) -> None: ...
async def get(self, key: str) -> Optional[str]: ...

View file

@ -54,16 +54,11 @@ class SqliteKVStoreConfig(CommonConfig):
)
@classmethod
def sample_run_config(
cls, __distro_dir__: str = "runtime", db_name: str = "kvstore.db"
):
def sample_run_config(cls, __distro_dir__: str = "runtime", db_name: str = "kvstore.db"):
return {
"type": "sqlite",
"namespace": None,
"db_path": "${env.SQLITE_STORE_DIR:~/.llama/"
+ __distro_dir__
+ "}/"
+ db_name,
"db_path": "${env.SQLITE_STORE_DIR:~/.llama/" + __distro_dir__ + "}/" + db_name,
}

View file

@ -28,11 +28,7 @@ class InmemoryKVStoreImpl(KVStore):
self._store[key] = value
async def range(self, start_key: str, end_key: str) -> List[str]:
return [
self._store[key]
for key in self._store.keys()
if key >= start_key and key < end_key
]
return [self._store[key] for key in self._store.keys() if key >= start_key and key < end_key]
async def kvstore_impl(config: KVStoreConfig) -> KVStore:

View file

@ -46,7 +46,6 @@ class PostgresKVStoreImpl(KVStore):
"""
)
except Exception as e:
log.exception("Could not connect to PostgreSQL database server")
raise RuntimeError("Could not connect to PostgreSQL database server") from e
@ -55,9 +54,7 @@ class PostgresKVStoreImpl(KVStore):
return key
return f"{self.config.namespace}:{key}"
async def set(
self, key: str, value: str, expiration: Optional[datetime] = None
) -> None:
async def set(self, key: str, value: str, expiration: Optional[datetime] = None) -> None:
key = self._namespaced_key(key)
self.cursor.execute(
f"""

View file

@ -25,9 +25,7 @@ class RedisKVStoreImpl(KVStore):
return key
return f"{self.config.namespace}:{key}"
async def set(
self, key: str, value: str, expiration: Optional[datetime] = None
) -> None:
async def set(self, key: str, value: str, expiration: Optional[datetime] = None) -> None:
key = self._namespaced_key(key)
await self.redis.set(key, value)
if expiration:
@ -66,9 +64,7 @@ class RedisKVStoreImpl(KVStore):
if matching_keys:
values = await self.redis.mget(matching_keys)
return [
value.decode("utf-8") if isinstance(value, bytes) else value
for value in values
if value is not None
value.decode("utf-8") if isinstance(value, bytes) else value for value in values if value is not None
]
return []

View file

@ -34,9 +34,7 @@ class SqliteKVStoreImpl(KVStore):
)
await db.commit()
async def set(
self, key: str, value: str, expiration: Optional[datetime] = None
) -> None:
async def set(self, key: str, value: str, expiration: Optional[datetime] = None) -> None:
async with aiosqlite.connect(self.db_path) as db:
await db.execute(
f"INSERT OR REPLACE INTO {self.table_name} (key, value, expiration) VALUES (?, ?, ?)",
@ -46,9 +44,7 @@ class SqliteKVStoreImpl(KVStore):
async def get(self, key: str) -> Optional[str]:
async with aiosqlite.connect(self.db_path) as db:
async with db.execute(
f"SELECT value, expiration FROM {self.table_name} WHERE key = ?", (key,)
) as cursor:
async with db.execute(f"SELECT value, expiration FROM {self.table_name} WHERE key = ?", (key,)) as cursor:
row = await cursor.fetchone()
if row is None:
return None

View file

@ -141,9 +141,7 @@ async def content_from_doc(doc: RAGDocument) -> str:
return interleaved_content_as_str(doc.content)
def make_overlapped_chunks(
document_id: str, text: str, window_len: int, overlap_len: int
) -> List[Chunk]:
def make_overlapped_chunks(document_id: str, text: str, window_len: int, overlap_len: int) -> List[Chunk]:
tokenizer = Tokenizer.get_instance()
tokens = tokenizer.encode(text, bos=False, eos=False)
@ -171,9 +169,7 @@ class EmbeddingIndex(ABC):
raise NotImplementedError()
@abstractmethod
async def query(
self, embedding: NDArray, k: int, score_threshold: float
) -> QueryChunksResponse:
async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
raise NotImplementedError()
@abstractmethod
@ -209,8 +205,6 @@ class VectorDBWithIndex:
score_threshold = params.get("score_threshold", 0.0)
query_str = interleaved_content_as_str(query)
embeddings_response = await self.inference_api.embeddings(
self.vector_db.embedding_model, [query_str]
)
embeddings_response = await self.inference_api.embeddings(self.vector_db.embedding_model, [query_str])
query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32)
return await self.index.query(query_vector, k, score_threshold)

View file

@ -23,9 +23,7 @@ def aggregate_accuracy(scoring_results: List[ScoringResultRow]) -> Dict[str, Any
def aggregate_average(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
return {
"average": sum(
result["score"] for result in scoring_results if result["score"] is not None
)
"average": sum(result["score"] for result in scoring_results if result["score"] is not None)
/ len([_ for _ in scoring_results if _["score"] is not None]),
}

View file

@ -70,9 +70,7 @@ class RegisteredBaseScoringFn(BaseScoringFn):
def register_scoring_fn_def(self, scoring_fn: ScoringFn) -> None:
if scoring_fn.identifier in self.supported_fn_defs_registry:
raise ValueError(
f"Scoring function def with identifier {scoring_fn.identifier} already exists."
)
raise ValueError(f"Scoring function def with identifier {scoring_fn.identifier} already exists.")
self.supported_fn_defs_registry[scoring_fn.identifier] = scoring_fn
@abstractmethod
@ -98,11 +96,7 @@ class RegisteredBaseScoringFn(BaseScoringFn):
params.aggregation_functions = scoring_params.aggregation_functions
aggregation_functions = []
if (
params
and hasattr(params, "aggregation_functions")
and params.aggregation_functions
):
if params and hasattr(params, "aggregation_functions") and params.aggregation_functions:
aggregation_functions.extend(params.aggregation_functions)
return aggregate_metrics(scoring_results, aggregation_functions)
@ -112,7 +106,4 @@ class RegisteredBaseScoringFn(BaseScoringFn):
scoring_fn_identifier: Optional[str] = None,
scoring_params: Optional[ScoringFnParams] = None,
) -> List[ScoringResultRow]:
return [
await self.score_row(input_row, scoring_fn_identifier, scoring_params)
for input_row in input_rows
]
return [await self.score_row(input_row, scoring_fn_identifier, scoring_params) for input_row in input_rows]

View file

@ -64,8 +64,7 @@ class TelemetryDatasetMixin:
for span in spans_by_id_resp.data.values():
if span.attributes and all(
attr in span.attributes and span.attributes[attr] is not None
for attr in attributes_to_return
attr in span.attributes and span.attributes[attr] is not None for attr in attributes_to_return
):
spans.append(
Span(

View file

@ -118,10 +118,7 @@ class SQLiteTraceStore(TraceStore):
# Build the attributes selection
attributes_select = "s.attributes"
if attributes_to_return:
json_object = ", ".join(
f"'{key}', json_extract(s.attributes, '$.{key}')"
for key in attributes_to_return
)
json_object = ", ".join(f"'{key}', json_extract(s.attributes, '$.{key}')" for key in attributes_to_return)
attributes_select = f"json_object({json_object})"
# SQLite CTE query with filtered attributes

View file

@ -45,16 +45,12 @@ def trace_protocol(cls: Type[T]) -> Type[T]:
def create_span_context(self: Any, *args: Any, **kwargs: Any) -> tuple:
class_name = self.__class__.__name__
method_name = method.__name__
span_type = (
"async_generator" if is_async_gen else "async" if is_async else "sync"
)
span_type = "async_generator" if is_async_gen else "async" if is_async else "sync"
sig = inspect.signature(method)
param_names = list(sig.parameters.keys())[1:] # Skip 'self'
combined_args = {}
for i, arg in enumerate(args):
param_name = (
param_names[i] if i < len(param_names) else f"position_{i + 1}"
)
param_name = param_names[i] if i < len(param_names) else f"position_{i + 1}"
combined_args[param_name] = serialize_value(arg)
for k, v in kwargs.items():
combined_args[str(k)] = serialize_value(v)
@ -70,14 +66,10 @@ def trace_protocol(cls: Type[T]) -> Type[T]:
return class_name, method_name, span_attributes
@wraps(method)
async def async_gen_wrapper(
self: Any, *args: Any, **kwargs: Any
) -> AsyncGenerator:
async def async_gen_wrapper(self: Any, *args: Any, **kwargs: Any) -> AsyncGenerator:
from llama_stack.providers.utils.telemetry import tracing
class_name, method_name, span_attributes = create_span_context(
self, *args, **kwargs
)
class_name, method_name, span_attributes = create_span_context(self, *args, **kwargs)
with tracing.span(f"{class_name}.{method_name}", span_attributes) as span:
try:
@ -92,9 +84,7 @@ def trace_protocol(cls: Type[T]) -> Type[T]:
async def async_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
from llama_stack.providers.utils.telemetry import tracing
class_name, method_name, span_attributes = create_span_context(
self, *args, **kwargs
)
class_name, method_name, span_attributes = create_span_context(self, *args, **kwargs)
with tracing.span(f"{class_name}.{method_name}", span_attributes) as span:
try:
@ -109,9 +99,7 @@ def trace_protocol(cls: Type[T]) -> Type[T]:
def sync_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
from llama_stack.providers.utils.telemetry import tracing
class_name, method_name, span_attributes = create_span_context(
self, *args, **kwargs
)
class_name, method_name, span_attributes = create_span_context(self, *args, **kwargs)
with tracing.span(f"{class_name}.{method_name}", span_attributes) as span:
try: