mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 20:14:13 +00:00
instantiate inference models
This commit is contained in:
parent
d2ec822b12
commit
7071c46422
6 changed files with 40 additions and 20 deletions
|
@ -8,7 +8,14 @@ from typing import Any, AsyncGenerator, Dict, List, Tuple
|
|||
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.apis.models import Models
|
||||
|
||||
from llama_stack.distribution.datatypes import GenericProviderConfig
|
||||
from llama_stack.distribution.distribution import api_providers
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_provider
|
||||
from llama_stack.providers.impls.builtin.models.models import BuiltinModelsImpl
|
||||
from llama_stack.providers.registry.inference import available_providers
|
||||
from termcolor import cprint
|
||||
|
||||
|
||||
class InferenceRouterImpl(Inference):
|
||||
|
@ -16,19 +23,36 @@ class InferenceRouterImpl(Inference):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
inner_impls: List[Tuple[str, Any]],
|
||||
deps: List[Api],
|
||||
models_api: Models,
|
||||
) -> None:
|
||||
self.inner_impls = inner_impls
|
||||
self.deps = deps
|
||||
print("INIT INFERENCE ROUTER!")
|
||||
|
||||
# self.providers = {}
|
||||
# for routing_key, provider_impl in inner_impls:
|
||||
# self.providers[routing_key] = provider_impl
|
||||
# map of model_id to provider impl
|
||||
self.providers = {}
|
||||
self.models_api = models_api
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
inference_providers = api_providers()[Api.inference]
|
||||
|
||||
models_list_response = await self.models_api.list_models()
|
||||
for model_spec in models_list_response.models_list:
|
||||
|
||||
if model_spec.api != Api.inference.value:
|
||||
continue
|
||||
|
||||
if model_spec.provider_id not in inference_providers:
|
||||
raise ValueError(
|
||||
f"provider_id {model_spec.provider_id} is not available for inference. Please check run.yaml config spec to define a valid provider"
|
||||
)
|
||||
impl = await instantiate_provider(
|
||||
inference_providers[model_spec.provider_id],
|
||||
deps=[],
|
||||
provider_config=GenericProviderConfig(
|
||||
provider_id=model_spec.provider_id,
|
||||
config=model_spec.provider_config,
|
||||
),
|
||||
)
|
||||
cprint(f"impl={impl}", "blue")
|
||||
# look up and initialize provider implementations for each model
|
||||
core_model_id = model_spec.llama_model_metadata.core_model_id
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue