fix imports

This commit is contained in:
Dinesh Yeduguru 2024-11-01 14:38:20 -07:00 committed by Dinesh Yeduguru
parent a3064ca6fc
commit b61730ef6b
4 changed files with 13 additions and 12 deletions

View file

@ -11,7 +11,6 @@ from typing import Any, Dict, List, Set
from llama_stack.providers.datatypes import * # noqa: F403 from llama_stack.providers.datatypes import * # noqa: F403
from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.distribution.datatypes import * # noqa: F403
import llama_stack.distribution.store as distribution_store
from llama_stack.apis.agents import Agents from llama_stack.apis.agents import Agents
from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets from llama_stack.apis.datasets import Datasets
@ -27,6 +26,7 @@ from llama_stack.apis.scoring_functions import ScoringFunctions
from llama_stack.apis.shields import Shields from llama_stack.apis.shields import Shields
from llama_stack.apis.telemetry import Telemetry from llama_stack.apis.telemetry import Telemetry
from llama_stack.distribution.distribution import builtin_automatically_routed_apis from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.store import Registry as DistributionRegistry
from llama_stack.distribution.utils.dynamic import instantiate_class_type from llama_stack.distribution.utils.dynamic import instantiate_class_type
@ -68,7 +68,7 @@ class ProviderWithSpec(Provider):
async def resolve_impls( async def resolve_impls(
run_config: StackRunConfig, run_config: StackRunConfig,
provider_registry: Dict[Api, Dict[str, ProviderSpec]], provider_registry: Dict[Api, Dict[str, ProviderSpec]],
dist_registry: distribution_store.Registry, dist_registry: DistributionRegistry,
) -> Dict[Api, Any]: ) -> Dict[Api, Any]:
""" """
Does two things: Does two things:
@ -241,7 +241,7 @@ async def instantiate_provider(
provider: ProviderWithSpec, provider: ProviderWithSpec,
deps: Dict[str, Any], deps: Dict[str, Any],
inner_impls: Dict[str, Any], inner_impls: Dict[str, Any],
dist_registry: distribution_store.Registry, dist_registry: DistributionRegistry,
): ):
protocols = api_protocol_map() protocols = api_protocol_map()
additional_protocols = additional_protocols_map() additional_protocols = additional_protocols_map()

View file

@ -7,7 +7,8 @@
from typing import Any from typing import Any
from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.distribution.datatypes import * # noqa: F403
import llama_stack.distribution.store as distribution_store
from llama_stack.distribution.store import Registry as DistributionRegistry
from .routing_tables import ( from .routing_tables import (
DatasetsRoutingTable, DatasetsRoutingTable,
@ -22,7 +23,7 @@ async def get_routing_table_impl(
api: Api, api: Api,
impls_by_provider_id: Dict[str, RoutedProtocol], impls_by_provider_id: Dict[str, RoutedProtocol],
_deps, _deps,
dist_registry: distribution_store.Registry, dist_registry: DistributionRegistry,
) -> Any: ) -> Any:
api_to_tables = { api_to_tables = {
"memory_banks": MemoryBanksRoutingTable, "memory_banks": MemoryBanksRoutingTable,

View file

@ -13,8 +13,8 @@ from llama_stack.apis.shields import * # noqa: F403
from llama_stack.apis.memory_banks import * # noqa: F403 from llama_stack.apis.memory_banks import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403 from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.distribution.store import Registry as DistributionRegistry
from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.distribution.datatypes import * # noqa: F403
import llama_stack.distribution.store as distribution_store
def get_impl_api(p: Any) -> Api: def get_impl_api(p: Any) -> Api:
@ -53,7 +53,7 @@ class CommonRoutingTableImpl(RoutingTable):
def __init__( def __init__(
self, self,
impls_by_provider_id: Dict[str, RoutedProtocol], impls_by_provider_id: Dict[str, RoutedProtocol],
dist_registry: distribution_store.Registry, dist_registry: DistributionRegistry,
) -> None: ) -> None:
self.impls_by_provider_id = impls_by_provider_id self.impls_by_provider_id = impls_by_provider_id
self.dist_registry = dist_registry self.dist_registry = dist_registry

View file

@ -22,9 +22,6 @@ import yaml
from fastapi import Body, FastAPI, HTTPException, Request, Response from fastapi import Body, FastAPI, HTTPException, Request, Response
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, StreamingResponse from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, ValidationError
from termcolor import cprint
from typing_extensions import Annotated
from llama_stack.distribution.distribution import ( from llama_stack.distribution.distribution import (
builtin_automatically_routed_apis, builtin_automatically_routed_apis,
@ -39,10 +36,13 @@ from llama_stack.providers.utils.telemetry.tracing import (
SpanStatus, SpanStatus,
start_trace, start_trace,
) )
from pydantic import BaseModel, ValidationError
from termcolor import cprint
from typing_extensions import Annotated
from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.distribution.datatypes import * # noqa: F403
import llama_stack.distribution.store as distribution_store
from llama_stack.distribution.request_headers import set_request_provider_data from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.resolver import resolve_impls from llama_stack.distribution.resolver import resolve_impls
from llama_stack.distribution.store import DiskRegistry
from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig
from .endpoints import get_all_api_endpoints from .endpoints import get_all_api_endpoints
@ -292,7 +292,7 @@ def main(
) )
) )
dist_registry = distribution_store.DiskRegistry(dist_kvstore) dist_registry = DiskRegistry(dist_kvstore)
impls = asyncio.run(resolve_impls(config, get_provider_registry(), dist_registry)) impls = asyncio.run(resolve_impls(config, get_provider_registry(), dist_registry))
if Api.telemetry in impls: if Api.telemetry in impls: