This commit is contained in:
Ashwin Bharambe 2024-10-06 17:20:33 -07:00 committed by Ashwin Bharambe
parent 91e0063593
commit 1550187cd8

View file

@ -31,7 +31,6 @@ class MetaReferenceInferenceImpl(Inference):
if model is None: if model is None:
raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`") raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`")
self.model = model self.model = model
self.registered_model_defs = []
# verify that the checkpoint actually is for this model lol # verify that the checkpoint actually is for this model lol
async def initialize(self) -> None: async def initialize(self) -> None:
@ -39,24 +38,10 @@ class MetaReferenceInferenceImpl(Inference):
self.generator.start() self.generator.start()
async def register_model(self, model: ModelDef) -> None:
    """Register ``model`` as the model served by this implementation.

    Does nothing when a model with the same identifier is already
    registered. Otherwise the identifier must equal
    ``self.model.descriptor()``; on success the registry is replaced by a
    single-element list holding ``model``.

    Raises:
        RuntimeError: if ``model.identifier`` differs from
            ``self.model.descriptor()``.
    """
    # The already-registered check runs first, so a repeated registration
    # returns before the mismatch check is reached.
    existing = await self.get_model(model.identifier)
    if existing is not None:
        return

    if model.identifier != self.model.descriptor():
        raise RuntimeError(
            f"Model mismatch: {model.identifier} != {self.model.descriptor()}"
        )

    # Assignment (not append): the registry holds at most this one entry.
    self.registered_model_defs = [model]
async def list_models(self) -> List[ModelDef]:
    """Return the list of model definitions registered so far.

    The stored list object itself is returned, not a copy.
    """
    return self.registered_model_defs
async def get_model(self, identifier: str) -> Optional[ModelDef]:
    """Return the registered model whose identifier equals ``identifier``.

    Returns the first match in ``self.registered_model_defs``, or ``None``
    when no registered model has that identifier.
    """
    return next(
        (
            candidate
            for candidate in self.registered_model_defs
            if candidate.identifier == identifier
        ),
        None,
    )
async def shutdown(self) -> None:
    """Stop the generator held in ``self.generator``."""
    self.generator.stop()