From 24dbe448a3535341c5434f8087c934d102e82486 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Wed, 18 Sep 2024 14:52:37 -0700 Subject: [PATCH] models api client fix --- llama_stack/apis/models/clients.py | 4 ++-- llama_stack/distribution/server/server.py | 10 +++++++--- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/llama_stack/apis/models/clients.py b/llama_stack/apis/models/clients.py index 85009eb3d..1ff7d5414 100644 --- a/llama_stack/apis/models/clients.py +++ b/llama_stack/apis/models/clients.py @@ -13,10 +13,10 @@ from typing import Any, Dict, List, Optional import fire import httpx -from llama_toolchain.core.datatypes import RemoteProviderConfig +from llama_stack.distribution.datatypes import RemoteProviderConfig from termcolor import cprint -from .api import * # noqa: F403 +from .models import * # noqa: F403 class ModelsClient(Models): diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 16d24cad5..68c197edb 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -35,9 +35,6 @@ from fastapi import Body, FastAPI, HTTPException, Request, Response from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse, StreamingResponse from fastapi.routing import APIRoute -from pydantic import BaseModel, ValidationError -from termcolor import cprint -from typing_extensions import Annotated from llama_stack.providers.utils.telemetry.tracing import ( end_trace, @@ -45,6 +42,9 @@ from llama_stack.providers.utils.telemetry.tracing import ( SpanStatus, start_trace, ) +from pydantic import BaseModel, ValidationError +from termcolor import cprint +from typing_extensions import Annotated from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.distribution.distribution import api_endpoints, api_providers @@ -331,6 +331,8 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False): with open(yaml_config, "r") as fp: config = StackRunConfig(**yaml.safe_load(fp)) + print(config) + app = FastAPI() impls, specs = asyncio.run(resolve_impls(config.provider_map)) @@ -340,6 +342,8 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False): all_endpoints = api_endpoints() apis_to_serve = config.apis_to_serve or list(config.provider_map.keys()) + print(apis_to_serve) + for api_str in apis_to_serve: api = Api(api_str) endpoints = all_endpoints[api]