models api client fix

This commit is contained in:
Xi Yan 2024-09-18 14:52:37 -07:00
parent 8ad821a533
commit 24dbe448a3
2 changed files with 9 additions and 5 deletions

View file

@ -13,10 +13,10 @@ from typing import Any, Dict, List, Optional
import fire import fire
import httpx import httpx
from llama_toolchain.core.datatypes import RemoteProviderConfig from llama_stack.distribution.datatypes import RemoteProviderConfig
from termcolor import cprint from termcolor import cprint
from .api import * # noqa: F403 from .models import * # noqa: F403
class ModelsClient(Models): class ModelsClient(Models):

View file

@ -35,9 +35,6 @@ from fastapi import Body, FastAPI, HTTPException, Request, Response
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, StreamingResponse from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.routing import APIRoute from fastapi.routing import APIRoute
from pydantic import BaseModel, ValidationError
from termcolor import cprint
from typing_extensions import Annotated
from llama_stack.providers.utils.telemetry.tracing import ( from llama_stack.providers.utils.telemetry.tracing import (
end_trace, end_trace,
@ -45,6 +42,9 @@ from llama_stack.providers.utils.telemetry.tracing import (
SpanStatus, SpanStatus,
start_trace, start_trace,
) )
from pydantic import BaseModel, ValidationError
from termcolor import cprint
from typing_extensions import Annotated
from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.distribution.distribution import api_endpoints, api_providers from llama_stack.distribution.distribution import api_endpoints, api_providers
@ -331,6 +331,8 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
with open(yaml_config, "r") as fp: with open(yaml_config, "r") as fp:
config = StackRunConfig(**yaml.safe_load(fp)) config = StackRunConfig(**yaml.safe_load(fp))
print(config)
app = FastAPI() app = FastAPI()
impls, specs = asyncio.run(resolve_impls(config.provider_map)) impls, specs = asyncio.run(resolve_impls(config.provider_map))
@ -340,6 +342,8 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
all_endpoints = api_endpoints() all_endpoints = api_endpoints()
apis_to_serve = config.apis_to_serve or list(config.provider_map.keys()) apis_to_serve = config.apis_to_serve or list(config.provider_map.keys())
print(apis_to_serve)
for api_str in apis_to_serve: for api_str in apis_to_serve:
api = Api(api_str) api = Api(api_str)
endpoints = all_endpoints[api] endpoints = all_endpoints[api]