fix imports

This commit is contained in:
Dinesh Yeduguru 2024-11-01 14:38:20 -07:00 committed by Dinesh Yeduguru
parent a3064ca6fc
commit b61730ef6b
4 changed files with 13 additions and 12 deletions

View file

@ -11,7 +11,6 @@ from typing import Any, Dict, List, Set
from llama_stack.providers.datatypes import * # noqa: F403
from llama_stack.distribution.datatypes import * # noqa: F403
import llama_stack.distribution.store as distribution_store
from llama_stack.apis.agents import Agents
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
@ -27,6 +26,7 @@ from llama_stack.apis.scoring_functions import ScoringFunctions
from llama_stack.apis.shields import Shields
from llama_stack.apis.telemetry import Telemetry
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.store import Registry as DistributionRegistry
from llama_stack.distribution.utils.dynamic import instantiate_class_type
@ -68,7 +68,7 @@ class ProviderWithSpec(Provider):
async def resolve_impls(
run_config: StackRunConfig,
provider_registry: Dict[Api, Dict[str, ProviderSpec]],
dist_registry: distribution_store.Registry,
dist_registry: DistributionRegistry,
) -> Dict[Api, Any]:
"""
Does two things:
@ -241,7 +241,7 @@ async def instantiate_provider(
provider: ProviderWithSpec,
deps: Dict[str, Any],
inner_impls: Dict[str, Any],
dist_registry: distribution_store.Registry,
dist_registry: DistributionRegistry,
):
protocols = api_protocol_map()
additional_protocols = additional_protocols_map()

View file

@ -7,7 +7,8 @@
from typing import Any
from llama_stack.distribution.datatypes import * # noqa: F403
import llama_stack.distribution.store as distribution_store
from llama_stack.distribution.store import Registry as DistributionRegistry
from .routing_tables import (
DatasetsRoutingTable,
@ -22,7 +23,7 @@ async def get_routing_table_impl(
api: Api,
impls_by_provider_id: Dict[str, RoutedProtocol],
_deps,
dist_registry: distribution_store.Registry,
dist_registry: DistributionRegistry,
) -> Any:
api_to_tables = {
"memory_banks": MemoryBanksRoutingTable,

View file

@ -13,8 +13,8 @@ from llama_stack.apis.shields import * # noqa: F403
from llama_stack.apis.memory_banks import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.distribution.store import Registry as DistributionRegistry
from llama_stack.distribution.datatypes import * # noqa: F403
import llama_stack.distribution.store as distribution_store
def get_impl_api(p: Any) -> Api:
@ -53,7 +53,7 @@ class CommonRoutingTableImpl(RoutingTable):
def __init__(
self,
impls_by_provider_id: Dict[str, RoutedProtocol],
dist_registry: distribution_store.Registry,
dist_registry: DistributionRegistry,
) -> None:
self.impls_by_provider_id = impls_by_provider_id
self.dist_registry = dist_registry

View file

@ -22,9 +22,6 @@ import yaml
from fastapi import Body, FastAPI, HTTPException, Request, Response
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, ValidationError
from termcolor import cprint
from typing_extensions import Annotated
from llama_stack.distribution.distribution import (
builtin_automatically_routed_apis,
@ -39,10 +36,13 @@ from llama_stack.providers.utils.telemetry.tracing import (
SpanStatus,
start_trace,
)
from pydantic import BaseModel, ValidationError
from termcolor import cprint
from typing_extensions import Annotated
from llama_stack.distribution.datatypes import * # noqa: F403
import llama_stack.distribution.store as distribution_store
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.resolver import resolve_impls
from llama_stack.distribution.store import DiskRegistry
from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig
from .endpoints import get_all_api_endpoints
@ -292,7 +292,7 @@ def main(
)
)
dist_registry = distribution_store.DiskRegistry(dist_kvstore)
dist_registry = DiskRegistry(dist_kvstore)
impls = asyncio.run(resolve_impls(config, get_provider_registry(), dist_registry))
if Api.telemetry in impls: