mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-01 16:24:44 +00:00
remove model lookup class
This commit is contained in:
parent
606df220f5
commit
1bb01f9346
2 changed files with 5 additions and 19 deletions
|
@ -538,7 +538,7 @@ Once the server is set up, we can test it with a client to verify it's working c
|
||||||
$ curl http://localhost:5000/inference/chat_completion \
|
$ curl http://localhost:5000/inference/chat_completion \
|
||||||
-H "Content-Type: application/json" \
|
-H "Content-Type: application/json" \
|
||||||
-d '{
|
-d '{
|
||||||
"model": "Llama3.1-8B-Instruct",
|
"model_id": "Llama3.1-8B-Instruct",
|
||||||
"messages": [
|
"messages": [
|
||||||
{"role": "system", "content": "You are a helpful assistant."},
|
{"role": "system", "content": "You are a helpful assistant."},
|
||||||
{"role": "user", "content": "Write me a 2 sentence poem about the moon"}
|
{"role": "user", "content": "Write me a 2 sentence poem about the moon"}
|
||||||
|
|
|
@ -15,7 +15,6 @@ ModelAlias = namedtuple("ModelAlias", ["provider_model_id", "aliases", "llama_mo
|
||||||
|
|
||||||
|
|
||||||
def get_huggingface_repo(model_descriptor: str) -> Optional[str]:
|
def get_huggingface_repo(model_descriptor: str) -> Optional[str]:
|
||||||
"""Get the Hugging Face repository for a given CoreModelId."""
|
|
||||||
for model in all_registered_models():
|
for model in all_registered_models():
|
||||||
if model.descriptor() == model_descriptor:
|
if model.descriptor() == model_descriptor:
|
||||||
return model.huggingface_repo
|
return model.huggingface_repo
|
||||||
|
@ -33,11 +32,8 @@ def build_model_alias(provider_model_id: str, model_descriptor: str) -> ModelAli
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ModelLookup:
|
class ModelRegistryHelper(ModelsProtocolPrivate):
|
||||||
def __init__(
|
def __init__(self, model_aliases: List[ModelAlias]):
|
||||||
self,
|
|
||||||
model_aliases: List[ModelAlias],
|
|
||||||
):
|
|
||||||
self.alias_to_provider_id_map = {}
|
self.alias_to_provider_id_map = {}
|
||||||
self.provider_id_to_llama_model_map = {}
|
self.provider_id_to_llama_model_map = {}
|
||||||
for alias_obj in model_aliases:
|
for alias_obj in model_aliases:
|
||||||
|
@ -57,22 +53,12 @@ class ModelLookup:
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown model: `{identifier}`")
|
raise ValueError(f"Unknown model: `{identifier}`")
|
||||||
|
|
||||||
|
|
||||||
class ModelRegistryHelper(ModelsProtocolPrivate):
|
|
||||||
|
|
||||||
def __init__(self, model_aliases: List[ModelAlias]):
|
|
||||||
self.model_lookup = ModelLookup(model_aliases)
|
|
||||||
|
|
||||||
def get_llama_model(self, provider_model_id: str) -> str:
|
def get_llama_model(self, provider_model_id: str) -> str:
|
||||||
return self.model_lookup.provider_id_to_llama_model_map[provider_model_id]
|
return self.provider_id_to_llama_model_map[provider_model_id]
|
||||||
|
|
||||||
async def register_model(self, model: Model) -> Model:
|
async def register_model(self, model: Model) -> Model:
|
||||||
provider_model_id = self.model_lookup.get_provider_model_id(
|
model.provider_resource_id = self.get_provider_model_id(
|
||||||
model.provider_resource_id
|
model.provider_resource_id
|
||||||
)
|
)
|
||||||
if not provider_model_id:
|
|
||||||
raise ValueError(f"Unknown model: `{model.provider_resource_id}`")
|
|
||||||
|
|
||||||
model.provider_resource_id = provider_model_id
|
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue