remove model lookup class

This commit is contained in:
Dinesh Yeduguru 2024-11-12 20:00:48 -08:00
parent 606df220f5
commit 1bb01f9346
2 changed files with 5 additions and 19 deletions

View file

@ -538,7 +538,7 @@ Once the server is set up, we can test it with a client to verify it's working c
$ curl http://localhost:5000/inference/chat_completion \
-H "Content-Type: application/json" \
-d '{
"model": "Llama3.1-8B-Instruct",
"model_id": "Llama3.1-8B-Instruct",
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Write me a 2 sentence poem about the moon"}

View file

@ -15,7 +15,6 @@ ModelAlias = namedtuple("ModelAlias", ["provider_model_id", "aliases", "llama_mo
def get_huggingface_repo(model_descriptor: str) -> Optional[str]:
"""Get the Hugging Face repository for a given CoreModelId."""
for model in all_registered_models():
if model.descriptor() == model_descriptor:
return model.huggingface_repo
@ -33,11 +32,8 @@ def build_model_alias(provider_model_id: str, model_descriptor: str) -> ModelAli
)
class ModelLookup:
def __init__(
self,
model_aliases: List[ModelAlias],
):
class ModelRegistryHelper(ModelsProtocolPrivate):
def __init__(self, model_aliases: List[ModelAlias]):
self.alias_to_provider_id_map = {}
self.provider_id_to_llama_model_map = {}
for alias_obj in model_aliases:
@ -57,22 +53,12 @@ class ModelLookup:
else:
raise ValueError(f"Unknown model: `{identifier}`")
class ModelRegistryHelper(ModelsProtocolPrivate):
def __init__(self, model_aliases: List[ModelAlias]):
self.model_lookup = ModelLookup(model_aliases)
def get_llama_model(self, provider_model_id: str) -> str:
return self.model_lookup.provider_id_to_llama_model_map[provider_model_id]
return self.provider_id_to_llama_model_map[provider_model_id]
async def register_model(self, model: Model) -> Model:
provider_model_id = self.model_lookup.get_provider_model_id(
model.provider_resource_id = self.get_provider_model_id(
model.provider_resource_id
)
if not provider_model_id:
raise ValueError(f"Unknown model: `{model.provider_resource_id}`")
model.provider_resource_id = provider_model_id
return model