Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-07-30 07:39:38 +00:00
commit 24dbe448a3 (parent 8ad821a533)

    models api client fix

2 changed files with 9 additions and 5 deletions
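In short: the models API client's imports are updated to the new llama_stack package paths (the old llama_toolchain path and the .api wildcard import are replaced), and the server entrypoint gains two debug print calls so the loaded run config and the list of APIs being served are visible at startup.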
First changed file (the ModelsClient module):

@@ -13,10 +13,10 @@ from typing import Any, Dict, List, Optional
 import fire
 import httpx

-from llama_toolchain.core.datatypes import RemoteProviderConfig
+from llama_stack.distribution.datatypes import RemoteProviderConfig
 from termcolor import cprint

-from .api import * # noqa: F403
+from .models import * # noqa: F403


 class ModelsClient(Models):
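For orientation, here is a minimal, hypothetical sketch of how a remote client like this one is typically wired up with the fire and httpx imports shown above. The list_models method, the /models/list route, and the host/port defaults are illustrative assumptions, not taken from this commit:

# Hypothetical usage sketch; the class body and route are stand-ins,
# not the real ModelsClient from this diff.
import asyncio

import fire
import httpx


class ModelsClient:
    def __init__(self, base_url: str):
        self.base_url = base_url

    async def list_models(self) -> dict:
        # Call the remote Models API over HTTP, mirroring how the
        # stack's remote clients wrap httpx.
        async with httpx.AsyncClient() as client:
            response = await client.get(f"{self.base_url}/models/list")
            response.raise_for_status()
            return response.json()


def run_main(host: str = "localhost", port: int = 5000):
    client = ModelsClient(f"http://{host}:{port}")
    print(asyncio.run(client.list_models()))


if __name__ == "__main__":
    fire.Fire(run_main)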
Second changed file (the stack server entrypoint):

@@ -35,9 +35,6 @@ from fastapi import Body, FastAPI, HTTPException, Request, Response
 from fastapi.exceptions import RequestValidationError
 from fastapi.responses import JSONResponse, StreamingResponse
 from fastapi.routing import APIRoute
-from pydantic import BaseModel, ValidationError
-from termcolor import cprint
-from typing_extensions import Annotated

 from llama_stack.providers.utils.telemetry.tracing import (
     end_trace,
@@ -45,6 +42,9 @@ from llama_stack.providers.utils.telemetry.tracing import (
     SpanStatus,
     start_trace,
 )
+from pydantic import BaseModel, ValidationError
+from termcolor import cprint
+from typing_extensions import Annotated
 from llama_stack.distribution.datatypes import * # noqa: F403

 from llama_stack.distribution.distribution import api_endpoints, api_providers
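Note that these two hunks cancel out: the pydantic, termcolor, and typing_extensions imports deleted above the telemetry block reappear unchanged below it, so this part of the diff is a pure reordering with no behavioral effect.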
@@ -331,6 +331,8 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
     with open(yaml_config, "r") as fp:
         config = StackRunConfig(**yaml.safe_load(fp))

+    print(config)
+
     app = FastAPI()

     impls, specs = asyncio.run(resolve_impls(config.provider_map))
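To see what the new print(config) surfaces, here is a minimal, self-contained sketch of the load-and-validate pattern above. The real StackRunConfig lives in llama_stack.distribution.datatypes; the two fields below are simplified stand-ins based only on the attributes this diff touches (apis_to_serve, provider_map):

# Minimal sketch of the config-loading pattern; field types are
# simplified stand-ins, not the real StackRunConfig definition.
from typing import Any, Dict, List, Optional

import yaml
from pydantic import BaseModel


class StackRunConfig(BaseModel):
    apis_to_serve: Optional[List[str]] = None
    provider_map: Dict[str, Any] = {}


raw = """
apis_to_serve:
  - models
provider_map:
  models:
    provider_id: remote
"""

config = StackRunConfig(**yaml.safe_load(raw))
print(config)  # pydantic models print their fields, which is what the new line shows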
@@ -340,6 +342,8 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
     all_endpoints = api_endpoints()

     apis_to_serve = config.apis_to_serve or list(config.provider_map.keys())
+    print(apis_to_serve)
+
     for api_str in apis_to_serve:
         api = Api(api_str)
         endpoints = all_endpoints[api]
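The apis_to_serve fallback that the new print exposes is worth spelling out: when the YAML omits apis_to_serve (or leaves it null or empty), the server serves every API that has a configured provider. A quick illustration, reusing the hypothetical StackRunConfig stand-in from the previous sketch:

# Fallback behavior of `config.apis_to_serve or list(config.provider_map.keys())`,
# shown with the same stand-in config class as above.
from typing import Any, Dict, List, Optional

from pydantic import BaseModel


class StackRunConfig(BaseModel):
    apis_to_serve: Optional[List[str]] = None
    provider_map: Dict[str, Any] = {}


explicit = StackRunConfig(apis_to_serve=["models"], provider_map={"models": {}, "safety": {}})
implicit = StackRunConfig(provider_map={"models": {}, "safety": {}})

# An explicit list wins; otherwise every API with a configured provider is served.
print(explicit.apis_to_serve or list(explicit.provider_map.keys()))  # ['models']
print(implicit.apis_to_serve or list(implicit.provider_map.keys()))  # ['models', 'safety']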