Fixed model name; use routing_key to get model

Frieda (Jingying) Huang 2024-10-04 21:53:54 -04:00
parent 8ed548b18e
commit 0edf24b227
3 changed files with 19 additions and 8 deletions


@@ -156,6 +156,11 @@ async def instantiate_provider(
         assert isinstance(provider_config, GenericProviderConfig)
         config_type = instantiate_class_type(provider_spec.config_class)
         config = config_type(**provider_config.config)
+
+        if hasattr(provider_config, "routing_key"):
+            routing_key = provider_config.routing_key
+            deps["routing_key"] = routing_key
+
         args = [config, deps]
     elif isinstance(provider_spec, AutoRoutedProviderSpec):
         method = "get_auto_router_impl"
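
Here `deps` is the dictionary of resolved dependencies handed to the provider factory, so stashing `routing_key` in it is what makes the key visible to the Ollama adapter factory in the next file. A minimal, self-contained sketch of that hand-off (ProviderConfigSketch is a hypothetical stand-in for the real config class that carries a routing_key):

from dataclasses import dataclass
from typing import Any, Dict

@dataclass
class ProviderConfigSketch:
    # Hypothetical stand-in for a provider config that carries a routing_key.
    config: Dict[str, Any]
    routing_key: str = ""

provider_config = ProviderConfigSketch(
    config={"url": "http://localhost:11434"},
    routing_key="Llama3.1-8B-Instruct",
)

deps: Dict[str, Any] = {}
if hasattr(provider_config, "routing_key") and provider_config.routing_key:
    # Same idea as the hunk above: expose the routing key to downstream factories.
    deps["routing_key"] = provider_config.routing_key

print(deps)  # {'routing_key': 'Llama3.1-8B-Instruct'}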


@@ -15,5 +15,12 @@ async def get_adapter_impl(config: RemoteProviderConfig, _deps):
     from .ollama import OllamaInferenceAdapter
 
     impl = OllamaInferenceAdapter(config.url)
-    await impl.initialize()
+
+    routing_key = _deps.get("routing_key")
+    if not routing_key:
+        raise ValueError(
+            "Routing key is required for the Ollama adapter but was not found."
+        )
+
+    await impl.initialize(routing_key)
     return impl
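
The factory now fails fast when the routing key never reached `_deps`. A hypothetical caller, sketched without the real llama_stack wiring (the stub adapter below only mimics the `initialize(routing_key)` signature introduced above):

import asyncio

class _StubOllamaAdapter:
    # Stand-in for OllamaInferenceAdapter; only the signature matters here.
    def __init__(self, url: str) -> None:
        self.url = url

    async def initialize(self, routing_key: str) -> None:
        print(f"would resolve and pre-download the model for {routing_key!r}")

async def get_adapter_impl_sketch(url: str, _deps: dict):
    routing_key = _deps.get("routing_key")
    if not routing_key:
        raise ValueError("Routing key is required for the Ollama adapter but was not found.")
    impl = _StubOllamaAdapter(url)
    await impl.initialize(routing_key)
    return impl

asyncio.run(
    get_adapter_impl_sketch("http://localhost:11434", {"routing_key": "Llama3.1-8B-Instruct"})
)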


@@ -23,7 +23,7 @@ from llama_stack.providers.utils.inference.routable import RoutableProviderForModels
 # TODO: Eventually this will move to the llama cli model list command
 # mapping of Model SKUs to ollama models
 OLLAMA_SUPPORTED_SKUS = {
-    "Llama-Guard-3-1B": "xe/llamaguard3:latest",
+    "Llama-Guard-3-8B": "xe/llamaguard3:latest",
     "llama_guard": "xe/llamaguard3:latest",
     "Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16",
     "Llama3.1-70B-Instruct": "llama3.1:70b-instruct-fp16",
@@ -40,18 +40,18 @@ class OllamaInferenceAdapter(Inference, RoutableProviderForModels):
         self.url = url
         tokenizer = Tokenizer.get_instance()
         self.formatter = ChatFormat(tokenizer)
-        self.model = "Llama3.2-1B-Instruct"
 
     @property
     def client(self) -> AsyncClient:
         return AsyncClient(host=self.url)
 
-    async def initialize(self) -> None:
+    async def initialize(self, routing_key: str) -> None:
         print("Initializing Ollama, checking connectivity to server...")
         try:
             await self.client.ps()
-            print(f"Connected to Ollama server. Pre-downloading {self.model}...")
-            await self.predownload_models(ollama_model=self.model)
+            ollama_model = self.map_to_provider_model(routing_key)
+            print(f"Connected to Ollama server. Pre-downloading {ollama_model}...")
+            await self.predownload_models(ollama_model)
         except httpx.ConnectError as e:
             raise RuntimeError(
                 "Ollama Server is not running, start it using `ollama serve` in a separate terminal"
@@ -140,8 +140,7 @@ class OllamaInferenceAdapter(Inference, RoutableProviderForModels):
         options = self.get_ollama_chat_options(request)
         ollama_model = self.map_to_provider_model(request.model)
 
-        if ollama_model != self.model:
-            self.predownload_models(ollama_model)
+        self.predownload_models(ollama_model)
 
         if not request.stream:
             r = await self.client.chat(
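
With the hard-coded self.model gone, the adapter resolves the Ollama tag from the routing key (or the request's model) through map_to_provider_model on every call. A hedged sketch of what that lookup amounts to, using the SKU map from this file (the real method is inherited from RoutableProviderForModels, so the function name below is only illustrative):

# Hedged sketch of the lookup performed by map_to_provider_model.
OLLAMA_SUPPORTED_SKUS = {
    "Llama-Guard-3-8B": "xe/llamaguard3:latest",
    "Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16",
    "Llama3.1-70B-Instruct": "llama3.1:70b-instruct-fp16",
}

def map_to_provider_model_sketch(routing_key: str) -> str:
    if routing_key not in OLLAMA_SUPPORTED_SKUS:
        raise ValueError(f"Unknown model: {routing_key}")
    return OLLAMA_SUPPORTED_SKUS[routing_key]

print(map_to_provider_model_sketch("Llama3.1-8B-Instruct"))  # llama3.1:8b-instruct-fp16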