mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
models api client fix
This commit is contained in:
parent
8ad821a533
commit
24dbe448a3
2 changed files with 9 additions and 5 deletions
|
@@ -13,10 +13,10 @@ from typing import Any, Dict, List, Optional
|
|||
import fire
|
||||
import httpx
|
||||
|
||||
from llama_toolchain.core.datatypes import RemoteProviderConfig
|
||||
from llama_stack.distribution.datatypes import RemoteProviderConfig
|
||||
from termcolor import cprint
|
||||
|
||||
from .api import * # noqa: F403
|
||||
from .models import * # noqa: F403
|
||||
|
||||
|
||||
class ModelsClient(Models):
|
||||
|
|
|
@@ -35,9 +35,6 @@ from fastapi import Body, FastAPI, HTTPException, Request, Response
|
|||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from fastapi.routing import APIRoute
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from termcolor import cprint
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.providers.utils.telemetry.tracing import (
|
||||
end_trace,
|
||||
|
@@ -45,6 +42,9 @@ from llama_stack.providers.utils.telemetry.tracing import (
|
|||
SpanStatus,
|
||||
start_trace,
|
||||
)
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from termcolor import cprint
|
||||
from typing_extensions import Annotated
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
|
||||
from llama_stack.distribution.distribution import api_endpoints, api_providers
|
||||
|
@@ -331,6 +331,8 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
|
|||
with open(yaml_config, "r") as fp:
|
||||
config = StackRunConfig(**yaml.safe_load(fp))
|
||||
|
||||
print(config)
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
impls, specs = asyncio.run(resolve_impls(config.provider_map))
|
||||
|
@@ -340,6 +342,8 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
|
|||
all_endpoints = api_endpoints()
|
||||
|
||||
apis_to_serve = config.apis_to_serve or list(config.provider_map.keys())
|
||||
print(apis_to_serve)
|
||||
|
||||
for api_str in apis_to_serve:
|
||||
api = Api(api_str)
|
||||
endpoints = all_endpoints[api]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue