forked from phoenix-oss/llama-stack-mirror
model registration in ollama and vllm check against the available models in the provider (#446)
tests: pytest -v -s -m "ollama" llama_stack/providers/tests/inference/test_text_inference.py pytest -v -s -m vllm_remote llama_stack/providers/tests/inference/test_text_inference.py --env VLLM_URL="http://localhost:9798/v1" ---------
This commit is contained in:
parent
7f6ac2fbd7
commit
787e2034b7
4 changed files with 73 additions and 14 deletions
|
@ -71,12 +71,9 @@ model_aliases = [
|
|||
]
|
||||
|
||||
|
||||
class OllamaInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPrivate):
|
||||
class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
||||
def __init__(self, url: str) -> None:
|
||||
ModelRegistryHelper.__init__(
|
||||
self,
|
||||
model_aliases=model_aliases,
|
||||
)
|
||||
self.register_helper = ModelRegistryHelper(model_aliases)
|
||||
self.url = url
|
||||
self.formatter = ChatFormat(Tokenizer.get_instance())
|
||||
|
||||
|
@ -203,7 +200,9 @@ class OllamaInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPriva
|
|||
else:
|
||||
input_dict["raw"] = True
|
||||
input_dict["prompt"] = chat_completion_request_to_prompt(
|
||||
request, self.get_llama_model(request.model), self.formatter
|
||||
request,
|
||||
self.register_helper.get_llama_model(request.model),
|
||||
self.formatter,
|
||||
)
|
||||
else:
|
||||
assert (
|
||||
|
@ -282,6 +281,18 @@ class OllamaInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPriva
|
|||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
model = await self.register_helper.register_model(model)
|
||||
models = await self.client.ps()
|
||||
available_models = [m["model"] for m in models["models"]]
|
||||
if model.provider_resource_id not in available_models:
|
||||
raise ValueError(
|
||||
f"Model '{model.provider_resource_id}' is not available in Ollama. "
|
||||
f"Available models: {', '.join(available_models)}"
|
||||
)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
async def convert_message_to_dict_for_ollama(message: Message) -> List[dict]:
|
||||
async def _convert_content(content) -> dict:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue