address feedback

Dinesh Yeduguru 2024-11-13 13:02:45 -08:00
parent 96b1bafcde
commit 7e4765c45b
3 changed files with 44 additions and 8 deletions


@@ -47,7 +47,7 @@ def build_model_aliases():
 class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
-        self.model_register_helper = ModelRegistryHelper(build_model_aliases())
+        self.register_helper = ModelRegistryHelper(build_model_aliases())
         self.config = config
         self.formatter = ChatFormat(Tokenizer.get_instance())
         self.client = None
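
For orientation, a minimal sketch of the renamed helper's role in this adapter. The constructor call mirrors the hunk above; the import path is an assumption and may differ in this tree:

# Sketch only; the import path below is assumed, not taken from this diff.
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper

register_helper = ModelRegistryHelper(build_model_aliases())  # built once in __init__
# Used later in this file (see the hunks below):
#   await register_helper.register_model(model)      # validate a Model against the alias table
#   register_helper.get_llama_model(request.model)   # map a provider model ID to its Llama model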
@@ -129,12 +129,13 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             yield chunk
 
     async def register_model(self, model: Model) -> Model:
-        model = await self.model_register_helper.register_model(model)
+        model = await self.register_helper.register_model(model)
         res = self.client.models.list()
         available_models = [m.id for m in res]
         if model.provider_resource_id not in available_models:
             raise ValueError(
-                f"Model {model.provider_resource_id} is not being served by vLLM"
+                f"Model {model.provider_resource_id} is not being served by vLLM. "
+                f"Available models: {', '.join(available_models)}"
             )
         return model
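
A standalone, runnable sketch of the availability check this hunk improves, using a hypothetical model list to show the new error message:

# Hypothetical inputs; the check and message mirror the hunk above.
available_models = ["meta-llama/Llama-3.1-8B-Instruct"]  # e.g. [m.id for m in client.models.list()]
provider_resource_id = "not-served-model"

if provider_resource_id not in available_models:
    raise ValueError(
        f"Model {provider_resource_id} is not being served by vLLM. "
        f"Available models: {', '.join(available_models)}"
    )
# ValueError: Model not-served-model is not being served by vLLM. Available models: meta-llama/Llama-3.1-8B-Instruct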
@@ -157,7 +158,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             else:
                 input_dict["prompt"] = chat_completion_request_to_prompt(
                     request,
-                    self.model_register_helper.get_llama_model(request.model),
+                    self.register_helper.get_llama_model(request.model),
                     self.formatter,
                 )
         else:
@@ -166,7 +167,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             ), "Together does not support media for Completion requests"
             input_dict["prompt"] = completion_request_to_prompt(
                 request,
-                self.model_register_helper.get_llama_model(request.model),
+                self.register_helper.get_llama_model(request.model),
                 self.formatter,
             )
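
Condensed, the last two hunks make both request paths resolve the Llama model through the renamed helper before building a prompt. A simplified sketch; the branching flag is a hypothetical stand-in for the surrounding isinstance check, and the media-handling branches are omitted:

# Simplified view of the prompt-building branches updated above (not a drop-in excerpt).
def build_prompt(request, register_helper, formatter, is_chat_request):
    llama_model = register_helper.get_llama_model(request.model)
    if is_chat_request:
        return chat_completion_request_to_prompt(request, llama_model, formatter)
    return completion_request_to_prompt(request, llama_model, formatter)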