commit 71219b4937
parent 5b2282afd4
Author: Dinesh Yeduguru
Date:   2024-11-12 13:23:02 -08:00

@@ -20,7 +20,7 @@ from llama_stack.providers.utils.inference.model_registry import (
 )
 from llama_stack.apis.inference import *  # noqa: F403
-from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
+from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,
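
(Note: the `Model` import goes away because the hand-rolled `list_models()` removed below was its only user in this file; model listing presumably now flows through the `ModelRegistryHelper` mixin the class already inherits.)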
@@ -103,29 +103,6 @@ class OllamaInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPrivate
     async def shutdown(self) -> None:
         pass
 
-    async def list_models(self) -> List[Model]:
-        ollama_to_llama = {v: k for k, v in OLLAMA_SUPPORTED_MODELS.items()}
-
-        ret = []
-        res = await self.client.ps()
-        for r in res["models"]:
-            if r["model"] not in ollama_to_llama:
-                print(f"Ollama is running a model unknown to Llama Stack: {r['model']}")
-                continue
-
-            llama_model = ollama_to_llama[r["model"]]
-            print(f"Found model {llama_model} in Ollama")
-            ret.append(
-                Model(
-                    identifier=llama_model,
-                    metadata={
-                        "ollama_model": r["model"],
-                    },
-                )
-            )
-
-        return ret
-
     async def completion(
         self,
         model_id: str,
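
For reference, here is a minimal standalone sketch of what the removed `list_models()` did, written against the `ollama` Python client and assuming the dict-style `ps()` response the removed code expected. `OLLAMA_SUPPORTED_MODELS` is an illustrative one-entry stand-in for the adapter's static table, and `list_running_llama_models` is a hypothetical name:

```python
import asyncio

from ollama import AsyncClient

# Illustrative stand-in for the static mapping this commit phases out:
# {llama-stack identifier: ollama-native model name}.
OLLAMA_SUPPORTED_MODELS = {
    "Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16",
}


async def list_running_llama_models() -> list[str]:
    # Invert the table so Ollama names map back to Llama Stack
    # identifiers, exactly as the removed method did.
    ollama_to_llama = {v: k for k, v in OLLAMA_SUPPORTED_MODELS.items()}
    found = []
    res = await AsyncClient().ps()  # models currently loaded in Ollama
    for r in res["models"]:
        if r["model"] not in ollama_to_llama:
            continue  # running in Ollama, but unknown to the table
        found.append(ollama_to_llama[r["model"]])
    return found


if __name__ == "__main__":
    print(asyncio.run(list_running_llama_models()))
```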
@@ -243,7 +220,7 @@ class OllamaInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPrivate
             input_dict["raw"] = True
 
         return {
-            "model": OLLAMA_SUPPORTED_MODELS[request.model],
+            "model": request.model,
             **input_dict,
             "options": sampling_options,
             "stream": request.stream,