diff --git a/llama_stack/apis/models/clients.py b/llama_stack/apis/models/clients.py index 85009eb3d..1ff7d5414 100644 --- a/llama_stack/apis/models/clients.py +++ b/llama_stack/apis/models/clients.py @@ -13,10 +13,10 @@ from typing import Any, Dict, List, Optional import fire import httpx -from llama_toolchain.core.datatypes import RemoteProviderConfig +from llama_stack.distribution.datatypes import RemoteProviderConfig from termcolor import cprint -from .api import * # noqa: F403 +from .models import * # noqa: F403 class ModelsClient(Models): diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 16d24cad5..68c197edb 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -35,9 +35,6 @@ from fastapi import Body, FastAPI, HTTPException, Request, Response from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse, StreamingResponse from fastapi.routing import APIRoute -from pydantic import BaseModel, ValidationError -from termcolor import cprint -from typing_extensions import Annotated from llama_stack.providers.utils.telemetry.tracing import ( end_trace, @@ -45,6 +42,9 @@ from llama_stack.providers.utils.telemetry.tracing import ( SpanStatus, start_trace, ) +from pydantic import BaseModel, ValidationError +from termcolor import cprint +from typing_extensions import Annotated from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.distribution.distribution import api_endpoints, api_providers @@ -331,6 +331,8 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False): with open(yaml_config, "r") as fp: config = StackRunConfig(**yaml.safe_load(fp)) + print(config) + app = FastAPI() impls, specs = asyncio.run(resolve_impls(config.provider_map)) @@ -340,6 +342,8 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False): all_endpoints = api_endpoints() apis_to_serve = config.apis_to_serve or list(config.provider_map.keys()) + print(apis_to_serve) + for api_str in apis_to_serve: api = Api(api_str) endpoints = all_endpoints[api]