routers wip

This commit is contained in:
Xi Yan 2024-09-19 08:32:47 -07:00
parent ef0e717bd0
commit f3ff3a3001
4 changed files with 246 additions and 155 deletions

View file

@ -16,6 +16,9 @@ from llama_models.datatypes import CoreModelId, Model
from llama_models.sku_list import resolve_model
from llama_stack.apis.inference import Inference
from llama_stack.apis.safety import Safety
from llama_stack.providers.adapters.inference.ollama.ollama import (
OllamaInferenceAdapter,
)
from llama_stack.providers.impls.meta_reference.inference.inference import (
MetaReferenceInferenceImpl,
@ -23,6 +26,7 @@ from llama_stack.providers.impls.meta_reference.inference.inference import (
from llama_stack.providers.impls.meta_reference.safety.safety import (
MetaReferenceSafetyImpl,
)
from llama_stack.providers.routers.inference.inference import InferenceRouterImpl
from .config import MetaReferenceImplConfig
@ -39,7 +43,7 @@ class MetaReferenceModelsImpl(Models):
self.safety_api = safety_api
self.models_list = []
model = get_model_id_from_api(self.inference_api)
# model = get_model_id_from_api(self.inference_api)
# TODO, make the inference route provider and use router provider to do the lookup dynamically
if isinstance(
@ -56,6 +60,25 @@ class MetaReferenceModelsImpl(Models):
)
)
if isinstance(
self.inference_api,
OllamaInferenceAdapter,
):
self.models_list.append(
ModelSpec(
providers_spec={
"inference": [{"provider_type": "remote::ollama"}],
},
)
)
if isinstance(
self.inference_api,
InferenceRouterImpl,
):
print("Found router")
print(self.inference_api.providers)
if isinstance(
self.safety_api,
MetaReferenceSafetyImpl,