models api client fix

This commit is contained in:
Xi Yan 2024-09-18 14:52:37 -07:00
parent 8ad821a533
commit 24dbe448a3
2 changed files with 9 additions and 5 deletions

View file

@@ -13,10 +13,10 @@ from typing import Any, Dict, List, Optional
import fire
import httpx
from llama_toolchain.core.datatypes import RemoteProviderConfig
from llama_stack.distribution.datatypes import RemoteProviderConfig
from termcolor import cprint
from .api import * # noqa: F403
from .models import * # noqa: F403
class ModelsClient(Models):

View file

@@ -35,9 +35,6 @@ from fastapi import Body, FastAPI, HTTPException, Request, Response
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.routing import APIRoute
from pydantic import BaseModel, ValidationError
from termcolor import cprint
from typing_extensions import Annotated
from llama_stack.providers.utils.telemetry.tracing import (
end_trace,
@@ -45,6 +42,9 @@ from llama_stack.providers.utils.telemetry.tracing import (
SpanStatus,
start_trace,
)
from pydantic import BaseModel, ValidationError
from termcolor import cprint
from typing_extensions import Annotated
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.distribution.distribution import api_endpoints, api_providers
@@ -331,6 +331,8 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
with open(yaml_config, "r") as fp:
config = StackRunConfig(**yaml.safe_load(fp))
print(config)
app = FastAPI()
impls, specs = asyncio.run(resolve_impls(config.provider_map))
@@ -340,6 +342,8 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
all_endpoints = api_endpoints()
apis_to_serve = config.apis_to_serve or list(config.provider_map.keys())
print(apis_to_serve)
for api_str in apis_to_serve:
api = Api(api_str)
endpoints = all_endpoints[api]