Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-08-05 18:22:41 +00:00
Fix register_model protocol to return Model
Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
parent 58a040a805
commit 6df600f014
5 changed files with 8 additions and 5 deletions
@@ -21,7 +21,7 @@ from llama_stack.schema_utils import json_schema_type


 class ModelsProtocolPrivate(Protocol):
-    async def register_model(self, model: Model) -> None: ...
+    async def register_model(self, model: Model) -> Model: ...

     async def unregister_model(self, model_id: str) -> None: ...
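To illustrate the contract change in the hunk above, here is a minimal, self-contained sketch of a provider that satisfies the revised protocol. The simplified Model dataclass, the ToyInferenceProvider class, and the main() driver are stand-ins invented for this example, not code from the repository.

# Sketch only: the Model dataclass and ToyInferenceProvider below are
# illustrative stand-ins, not llama-stack code.
import asyncio
from dataclasses import dataclass
from typing import Protocol


@dataclass
class Model:
    identifier: str
    provider_resource_id: str


class ModelsProtocolPrivate(Protocol):
    async def register_model(self, model: Model) -> Model: ...

    async def unregister_model(self, model_id: str) -> None: ...


class ToyInferenceProvider:
    """Hypothetical provider that satisfies the updated protocol."""

    async def register_model(self, model: Model) -> Model:
        # A real provider may normalize or enrich the model here; returning it
        # lets the caller see the provider-resolved version instead of None.
        return model

    async def unregister_model(self, model_id: str) -> None:
        pass


async def main() -> None:
    provider: ModelsProtocolPrivate = ToyInferenceProvider()
    registered = await provider.register_model(
        Model(identifier="toy-model", provider_resource_id="toy-model")
    )
    print(registered)


if __name__ == "__main__":
    asyncio.run(main())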
@@ -43,7 +43,7 @@ class SentenceTransformersInferenceImpl(
     async def shutdown(self) -> None:
         pass

-    async def register_model(self, model: Model) -> None:
+    async def register_model(self, model: Model) -> Model:
         return model

     async def unregister_model(self, model_id: str) -> None:
@@ -300,7 +300,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):

         return EmbeddingsResponse(embeddings=embeddings)

-    async def register_model(self, model: Model):
+    async def register_model(self, model: Model) -> Model:
         model = await self.register_helper.register_model(model)
         if model.model_type == ModelType.embedding:
             logger.info(f"Pulling embedding model `{model.provider_resource_id}` if necessary...")
@@ -314,6 +314,8 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
                 f"Model '{model.provider_resource_id}' is not available in Ollama. Available models: {', '.join(available_models)}"
             )

+        return model
+

 async def convert_message_to_openai_dict_for_ollama(message: Message) -> List[dict]:
     async def _convert_content(content) -> dict:
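The two Ollama hunks above resolve the model through register_helper, validate it against what the server actually serves, and now hand the resolved Model back to the caller. The caller side is not part of this diff; the sketch below uses a hypothetical EchoProvider and InMemoryModelRegistry, with a simplified Model, only to show why returning the resolved object is useful.

# Hypothetical caller-side sketch: EchoProvider and InMemoryModelRegistry are
# not part of this diff; they only illustrate consuming the returned Model.
import asyncio
from dataclasses import dataclass


@dataclass
class Model:
    identifier: str
    provider_resource_id: str


class EchoProvider:
    """Stand-in provider that rewrites the resource id during registration."""

    async def register_model(self, model: Model) -> Model:
        # Mimic a helper that resolves an alias to a concrete provider id.
        return Model(
            identifier=model.identifier,
            provider_resource_id=f"resolved/{model.provider_resource_id}",
        )


class InMemoryModelRegistry:
    def __init__(self) -> None:
        self._models: dict[str, Model] = {}

    async def register(self, provider: EchoProvider, model: Model) -> Model:
        # Keep whatever the provider returns, since register_model may have
        # rewritten provider_resource_id or other fields.
        resolved = await provider.register_model(model)
        self._models[resolved.identifier] = resolved
        return resolved


if __name__ == "__main__":
    registry = InMemoryModelRegistry()
    out = asyncio.run(registry.register(EchoProvider(), Model("m1", "alias-name")))
    print(out.provider_resource_id)  # -> resolved/alias-name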
@@ -83,7 +83,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
     async def shutdown(self) -> None:
         pass

-    async def register_model(self, model: Model) -> None:
+    async def register_model(self, model: Model) -> Model:
         model = await self.register_helper.register_model(model)
         if model.provider_resource_id != self.model_id:
             raise ValueError(
@@ -346,7 +346,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         async for chunk in process_completion_stream_response(stream):
             yield chunk

-    async def register_model(self, model: Model) -> None:
+    async def register_model(self, model: Model) -> Model:
         assert self.client is not None
         model = await self.register_helper.register_model(model)
         res = await self.client.models.list()
@@ -356,6 +356,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
                 f"Model {model.provider_resource_id} is not being served by vLLM. "
                 f"Available models: {', '.join(available_models)}"
             )
+        return model

     async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
         options = get_sampling_options(request.sampling_params)
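A hypothetical pytest-style check, not part of this commit, for the new contract across providers: register_model should return a Model rather than None. FakeProvider and the simplified Model are invented for the example.

# Hypothetical test sketch (not from the repository): verifies that
# register_model hands the Model back to the caller instead of returning None.
import asyncio
from dataclasses import dataclass


@dataclass
class Model:
    identifier: str
    provider_resource_id: str


class FakeProvider:
    async def register_model(self, model: Model) -> Model:
        return model


def test_register_model_returns_model() -> None:
    model = Model(identifier="m", provider_resource_id="m")
    returned = asyncio.run(FakeProvider().register_model(model))
    assert isinstance(returned, Model)
    assert returned is model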