Fix precommit check after moving to ruff (#927)

Lint check in main branch is failing. This fixes the lint check after we
moved to ruff in https://github.com/meta-llama/llama-stack/pull/921. We
need to move to a `ruff.toml` file as well as fixing and ignoring some
additional checks.

Signed-off-by: Yuan Tang <terrytangyuan@gmail.com>
This commit is contained in:
Yuan Tang 2025-02-02 09:46:45 -05:00 committed by GitHub
parent 4773092dd1
commit 34ab7a3b6c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
217 changed files with 981 additions and 2681 deletions

View file

@ -76,9 +76,7 @@ async def global_exception_handler(request: Request, exc: Exception):
traceback.print_exception(exc)
http_exc = translate_exception(exc)
return JSONResponse(
status_code=http_exc.status_code, content={"error": {"detail": http_exc.detail}}
)
return JSONResponse(status_code=http_exc.status_code, content={"error": {"detail": http_exc.detail}})
def translate_exception(exc: Exception) -> Union[HTTPException, RequestValidationError]:
@ -178,9 +176,7 @@ def create_dynamic_typed_route(func: Any, method: str, route: str):
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
try:
if is_streaming:
return StreamingResponse(
sse_generator(func(**kwargs)), media_type="text/event-stream"
)
return StreamingResponse(sse_generator(func(**kwargs)), media_type="text/event-stream")
else:
value = func(**kwargs)
return await maybe_await(value)
@ -190,11 +186,7 @@ def create_dynamic_typed_route(func: Any, method: str, route: str):
sig = inspect.signature(func)
new_params = [
inspect.Parameter(
"request", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Request
)
]
new_params = [inspect.Parameter("request", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Request)]
new_params.extend(sig.parameters.values())
path_params = extract_path_params(route)
@ -202,15 +194,9 @@ def create_dynamic_typed_route(func: Any, method: str, route: str):
# Annotate parameters that are in the path with Path(...) and others with Body(...)
new_params = [new_params[0]] + [
(
param.replace(
annotation=Annotated[
param.annotation, FastapiPath(..., title=param.name)
]
)
param.replace(annotation=Annotated[param.annotation, FastapiPath(..., title=param.name)])
if param.name in path_params
else param.replace(
annotation=Annotated[param.annotation, Body(..., embed=True)]
)
else param.replace(annotation=Annotated[param.annotation, Body(..., embed=True)])
)
for param in new_params[1:]
]
@ -244,12 +230,8 @@ class ClientVersionMiddleware:
client_version = headers.get(b"x-llamastack-client-version", b"").decode()
if client_version:
try:
client_version_parts = tuple(
map(int, client_version.split(".")[:2])
)
server_version_parts = tuple(
map(int, self.server_version.split(".")[:2])
)
client_version_parts = tuple(map(int, client_version.split(".")[:2]))
server_version_parts = tuple(map(int, self.server_version.split(".")[:2]))
if client_version_parts != server_version_parts:
async def send_version_error(send):
@ -267,9 +249,7 @@ class ClientVersionMiddleware:
}
}
).encode()
await send(
{"type": "http.response.body", "body": error_msg}
)
await send({"type": "http.response.body", "body": error_msg})
return await send_version_error(send)
except (ValueError, IndexError):
@ -296,9 +276,7 @@ def main():
default=int(os.getenv("LLAMA_STACK_PORT", 8321)),
help="Port to listen on",
)
parser.add_argument(
"--disable-ipv6", action="store_true", help="Whether to disable IPv6 support"
)
parser.add_argument("--disable-ipv6", action="store_true", help="Whether to disable IPv6 support")
parser.add_argument(
"--env",
action="append",
@ -323,9 +301,7 @@ def main():
raise ValueError(f"Config file {config_file} does not exist")
print(f"Using config file: {config_file}")
elif args.template:
config_file = (
Path(REPO_ROOT) / "llama_stack" / "templates" / args.template / "run.yaml"
)
config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.template / "run.yaml"
if not config_file.exists():
raise ValueError(f"Template {args.template} does not exist")
print(f"Using template {args.template} config file: {config_file}")
@ -383,9 +359,7 @@ def main():
impl_method = getattr(impl, endpoint.name)
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore", category=UserWarning, module="pydantic._internal._fields"
)
warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._fields")
getattr(app, endpoint.method)(endpoint.route, response_model=None)(
create_dynamic_typed_route(
impl_method,
@ -416,9 +390,7 @@ def main():
def extract_path_params(route: str) -> List[str]:
segments = route.split("/")
params = [
seg[1:-1] for seg in segments if seg.startswith("{") and seg.endswith("}")
]
params = [seg[1:-1] for seg in segments if seg.startswith("{") and seg.endswith("}")]
return params