mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00
chore: remove impl_getter function
We already have an impl at this point, no need to validate this again. Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
parent
95e9455335
commit
234eaf4709
1 changed files with 31 additions and 38 deletions
|
|
@ -44,7 +44,7 @@ from llama_stack.core.request_headers import (
|
||||||
request_provider_data_context,
|
request_provider_data_context,
|
||||||
user_from_scope,
|
user_from_scope,
|
||||||
)
|
)
|
||||||
from llama_stack.core.server.fastapi_router_registry import build_router
|
from llama_stack.core.server.fastapi_router_registry import build_router, has_router
|
||||||
from llama_stack.core.server.routes import get_all_api_routes
|
from llama_stack.core.server.routes import get_all_api_routes
|
||||||
from llama_stack.core.stack import (
|
from llama_stack.core.stack import (
|
||||||
Stack,
|
Stack,
|
||||||
|
|
@ -465,51 +465,44 @@ def create_app() -> StackApp:
|
||||||
apis_to_serve.add("prompts")
|
apis_to_serve.add("prompts")
|
||||||
apis_to_serve.add("conversations")
|
apis_to_serve.add("conversations")
|
||||||
|
|
||||||
def impl_getter(api: Api) -> Any:
|
|
||||||
"""Get the implementation for a given API."""
|
|
||||||
try:
|
|
||||||
return impls[api]
|
|
||||||
except KeyError as e:
|
|
||||||
raise ValueError(f"Could not find provider implementation for {api} API") from e
|
|
||||||
|
|
||||||
for api_str in apis_to_serve:
|
for api_str in apis_to_serve:
|
||||||
api = Api(api_str)
|
api = Api(api_str)
|
||||||
|
|
||||||
# Try to discover and use a router factory from the API package
|
# Try to discover and use a router factory from the API package
|
||||||
router = build_router(api, impl_getter)
|
if has_router(api):
|
||||||
if router:
|
impl = impls[api]
|
||||||
app.include_router(router)
|
router = build_router(api, impl)
|
||||||
logger.debug(f"Registered router for {api} API")
|
if router:
|
||||||
else:
|
app.include_router(router)
|
||||||
# Fall back to old webmethod-based route discovery until the migration is complete
|
logger.debug(f"Registered router for {api} API")
|
||||||
routes = all_routes[api]
|
continue
|
||||||
try:
|
|
||||||
impl = impls[api]
|
|
||||||
except KeyError as e:
|
|
||||||
raise ValueError(f"Could not find provider implementation for {api} API") from e
|
|
||||||
|
|
||||||
for route, _ in routes:
|
# Fall back to old webmethod-based route discovery until the migration is complete
|
||||||
if not hasattr(impl, route.name):
|
impl = impls[api]
|
||||||
# ideally this should be a typing violation already
|
|
||||||
raise ValueError(f"Could not find method {route.name} on {impl}!")
|
|
||||||
|
|
||||||
impl_method = getattr(impl, route.name)
|
routes = all_routes[api]
|
||||||
# Filter out HEAD method since it's automatically handled by FastAPI for GET routes
|
for route, _ in routes:
|
||||||
available_methods = [m for m in route.methods if m != "HEAD"]
|
if not hasattr(impl, route.name):
|
||||||
if not available_methods:
|
# ideally this should be a typing violation already
|
||||||
raise ValueError(f"No methods found for {route.name} on {impl}")
|
raise ValueError(f"Could not find method {route.name} on {impl}!")
|
||||||
method = available_methods[0]
|
|
||||||
logger.debug(f"{method} {route.path}")
|
|
||||||
|
|
||||||
with warnings.catch_warnings():
|
impl_method = getattr(impl, route.name)
|
||||||
warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._fields")
|
# Filter out HEAD method since it's automatically handled by FastAPI for GET routes
|
||||||
getattr(app, method.lower())(route.path, response_model=None)(
|
available_methods = [m for m in route.methods if m != "HEAD"]
|
||||||
create_dynamic_typed_route(
|
if not available_methods:
|
||||||
impl_method,
|
raise ValueError(f"No methods found for {route.name} on {impl}")
|
||||||
method.lower(),
|
method = available_methods[0]
|
||||||
route.path,
|
logger.debug(f"{method} {route.path}")
|
||||||
)
|
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._fields")
|
||||||
|
getattr(app, method.lower())(route.path, response_model=None)(
|
||||||
|
create_dynamic_typed_route(
|
||||||
|
impl_method,
|
||||||
|
method.lower(),
|
||||||
|
route.path,
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
logger.debug(f"serving APIs: {apis_to_serve}")
|
logger.debug(f"serving APIs: {apis_to_serve}")
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue