forked from phoenix-oss/llama-stack-mirror
Fix precommit check after moving to ruff (#927)
Lint check in main branch is failing. This fixes the lint check after we moved to ruff in https://github.com/meta-llama/llama-stack/pull/921. We need to move to a `ruff.toml` file as well as fixing and ignoring some additional checks. Signed-off-by: Yuan Tang <terrytangyuan@gmail.com>
This commit is contained in:
parent
4773092dd1
commit
34ab7a3b6c
217 changed files with 981 additions and 2681 deletions
|
@ -76,9 +76,7 @@ async def global_exception_handler(request: Request, exc: Exception):
|
|||
traceback.print_exception(exc)
|
||||
http_exc = translate_exception(exc)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=http_exc.status_code, content={"error": {"detail": http_exc.detail}}
|
||||
)
|
||||
return JSONResponse(status_code=http_exc.status_code, content={"error": {"detail": http_exc.detail}})
|
||||
|
||||
|
||||
def translate_exception(exc: Exception) -> Union[HTTPException, RequestValidationError]:
|
||||
|
@ -178,9 +176,7 @@ def create_dynamic_typed_route(func: Any, method: str, route: str):
|
|||
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
|
||||
try:
|
||||
if is_streaming:
|
||||
return StreamingResponse(
|
||||
sse_generator(func(**kwargs)), media_type="text/event-stream"
|
||||
)
|
||||
return StreamingResponse(sse_generator(func(**kwargs)), media_type="text/event-stream")
|
||||
else:
|
||||
value = func(**kwargs)
|
||||
return await maybe_await(value)
|
||||
|
@ -190,11 +186,7 @@ def create_dynamic_typed_route(func: Any, method: str, route: str):
|
|||
|
||||
sig = inspect.signature(func)
|
||||
|
||||
new_params = [
|
||||
inspect.Parameter(
|
||||
"request", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Request
|
||||
)
|
||||
]
|
||||
new_params = [inspect.Parameter("request", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Request)]
|
||||
new_params.extend(sig.parameters.values())
|
||||
|
||||
path_params = extract_path_params(route)
|
||||
|
@ -202,15 +194,9 @@ def create_dynamic_typed_route(func: Any, method: str, route: str):
|
|||
# Annotate parameters that are in the path with Path(...) and others with Body(...)
|
||||
new_params = [new_params[0]] + [
|
||||
(
|
||||
param.replace(
|
||||
annotation=Annotated[
|
||||
param.annotation, FastapiPath(..., title=param.name)
|
||||
]
|
||||
)
|
||||
param.replace(annotation=Annotated[param.annotation, FastapiPath(..., title=param.name)])
|
||||
if param.name in path_params
|
||||
else param.replace(
|
||||
annotation=Annotated[param.annotation, Body(..., embed=True)]
|
||||
)
|
||||
else param.replace(annotation=Annotated[param.annotation, Body(..., embed=True)])
|
||||
)
|
||||
for param in new_params[1:]
|
||||
]
|
||||
|
@ -244,12 +230,8 @@ class ClientVersionMiddleware:
|
|||
client_version = headers.get(b"x-llamastack-client-version", b"").decode()
|
||||
if client_version:
|
||||
try:
|
||||
client_version_parts = tuple(
|
||||
map(int, client_version.split(".")[:2])
|
||||
)
|
||||
server_version_parts = tuple(
|
||||
map(int, self.server_version.split(".")[:2])
|
||||
)
|
||||
client_version_parts = tuple(map(int, client_version.split(".")[:2]))
|
||||
server_version_parts = tuple(map(int, self.server_version.split(".")[:2]))
|
||||
if client_version_parts != server_version_parts:
|
||||
|
||||
async def send_version_error(send):
|
||||
|
@ -267,9 +249,7 @@ class ClientVersionMiddleware:
|
|||
}
|
||||
}
|
||||
).encode()
|
||||
await send(
|
||||
{"type": "http.response.body", "body": error_msg}
|
||||
)
|
||||
await send({"type": "http.response.body", "body": error_msg})
|
||||
|
||||
return await send_version_error(send)
|
||||
except (ValueError, IndexError):
|
||||
|
@ -296,9 +276,7 @@ def main():
|
|||
default=int(os.getenv("LLAMA_STACK_PORT", 8321)),
|
||||
help="Port to listen on",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disable-ipv6", action="store_true", help="Whether to disable IPv6 support"
|
||||
)
|
||||
parser.add_argument("--disable-ipv6", action="store_true", help="Whether to disable IPv6 support")
|
||||
parser.add_argument(
|
||||
"--env",
|
||||
action="append",
|
||||
|
@ -323,9 +301,7 @@ def main():
|
|||
raise ValueError(f"Config file {config_file} does not exist")
|
||||
print(f"Using config file: {config_file}")
|
||||
elif args.template:
|
||||
config_file = (
|
||||
Path(REPO_ROOT) / "llama_stack" / "templates" / args.template / "run.yaml"
|
||||
)
|
||||
config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.template / "run.yaml"
|
||||
if not config_file.exists():
|
||||
raise ValueError(f"Template {args.template} does not exist")
|
||||
print(f"Using template {args.template} config file: {config_file}")
|
||||
|
@ -383,9 +359,7 @@ def main():
|
|||
impl_method = getattr(impl, endpoint.name)
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings(
|
||||
"ignore", category=UserWarning, module="pydantic._internal._fields"
|
||||
)
|
||||
warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._fields")
|
||||
getattr(app, endpoint.method)(endpoint.route, response_model=None)(
|
||||
create_dynamic_typed_route(
|
||||
impl_method,
|
||||
|
@ -416,9 +390,7 @@ def main():
|
|||
|
||||
def extract_path_params(route: str) -> List[str]:
|
||||
segments = route.split("/")
|
||||
params = [
|
||||
seg[1:-1] for seg in segments if seg.startswith("{") and seg.endswith("}")
|
||||
]
|
||||
params = [seg[1:-1] for seg in segments if seg.startswith("{") and seg.endswith("}")]
|
||||
return params
|
||||
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue