instantiate inference models

This commit is contained in:
Xi Yan 2024-09-19 21:33:35 -07:00
parent d2ec822b12
commit 7071c46422
6 changed files with 40 additions and 20 deletions

View file

@ -18,8 +18,6 @@ async def get_provider_impl(config: BuiltinImplConfig, deps: Dict[Api, ProviderS
config, BuiltinImplConfig
), f"Unexpected config type: {type(config)}"
print(config)
impl = BuiltinModelsImpl(config)
await impl.initialize()
return impl

View file

@ -25,7 +25,6 @@ class BuiltinModelsImpl(Models):
config: BuiltinImplConfig,
) -> None:
self.config = config
cprint(self.config, "red")
self.models = {
entry.core_model_id: ModelSpec(
llama_model_metadata=resolve_model(entry.core_model_id),

View file

@ -9,9 +9,9 @@ from typing import Any, List, Tuple
from llama_stack.distribution.datatypes import Api
async def get_router_impl(inner_impls: List[Tuple[str, Any]], deps: List[Api]):
async def get_router_impl(models_api: Api):
from .inference import InferenceRouterImpl
impl = InferenceRouterImpl(inner_impls, deps)
impl = InferenceRouterImpl(models_api)
await impl.initialize()
return impl

View file

@ -8,7 +8,14 @@ from typing import Any, AsyncGenerator, Dict, List, Tuple
from llama_stack.distribution.datatypes import Api
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.models import Models
from llama_stack.distribution.datatypes import GenericProviderConfig
from llama_stack.distribution.distribution import api_providers
from llama_stack.distribution.utils.dynamic import instantiate_provider
from llama_stack.providers.impls.builtin.models.models import BuiltinModelsImpl
from llama_stack.providers.registry.inference import available_providers
from termcolor import cprint
class InferenceRouterImpl(Inference):
@ -16,19 +23,36 @@ class InferenceRouterImpl(Inference):
def __init__(
self,
inner_impls: List[Tuple[str, Any]],
deps: List[Api],
models_api: Models,
) -> None:
self.inner_impls = inner_impls
self.deps = deps
print("INIT INFERENCE ROUTER!")
# self.providers = {}
# for routing_key, provider_impl in inner_impls:
# self.providers[routing_key] = provider_impl
# map of model_id to provider impl
self.providers = {}
self.models_api = models_api
async def initialize(self) -> None:
pass
inference_providers = api_providers()[Api.inference]
models_list_response = await self.models_api.list_models()
for model_spec in models_list_response.models_list:
if model_spec.api != Api.inference.value:
continue
if model_spec.provider_id not in inference_providers:
raise ValueError(
f"provider_id {model_spec.provider_id} is not available for inference. Please check run.yaml config spec to define a valid provider"
)
impl = await instantiate_provider(
inference_providers[model_spec.provider_id],
deps=[],
provider_config=GenericProviderConfig(
provider_id=model_spec.provider_id,
config=model_spec.provider_config,
),
)
cprint(f"impl={impl}", "blue")
# look up and initialize provider implementations for each model
core_model_id = model_spec.llama_model_metadata.core_model_id
async def shutdown(self) -> None:
pass