Fixed model name; use routing_key to get model

Frieda (Jingying) Huang 2024-10-04 21:53:54 -04:00
parent 8ed548b18e
commit 0edf24b227
3 changed files with 19 additions and 8 deletions


@@ -156,6 +156,11 @@ async def instantiate_provider(
         assert isinstance(provider_config, GenericProviderConfig)
         config_type = instantiate_class_type(provider_spec.config_class)
         config = config_type(**provider_config.config)
+
+        if hasattr(provider_config, "routing_key"):
+            routing_key = provider_config.routing_key
+            deps["routing_key"] = routing_key
+
         args = [config, deps]
     elif isinstance(provider_spec, AutoRoutedProviderSpec):
         method = "get_auto_router_impl"


@@ -15,5 +15,12 @@ async def get_adapter_impl(config: RemoteProviderConfig, _deps):
     from .ollama import OllamaInferenceAdapter
 
     impl = OllamaInferenceAdapter(config.url)
-    await impl.initialize()
+
+    routing_key = _deps.get("routing_key")
+    if not routing_key:
+        raise ValueError(
+            "Routing key is required for the Ollama adapter but was not found."
+        )
+
+    await impl.initialize(routing_key)
     return impl
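A usage sketch for the new factory contract; the import paths and the RemoteProviderConfig constructor arguments are assumptions about the package layout at the time, while the _deps["routing_key"] requirement and the ValueError come from the change.

import asyncio

from llama_stack.distribution.datatypes import RemoteProviderConfig  # assumed path
from llama_stack.providers.adapters.inference.ollama import get_adapter_impl  # assumed path

async def main():
    # Normally populated by instantiate_provider (first hunk above).
    deps = {"routing_key": "Llama3.1-8B-Instruct"}
    config = RemoteProviderConfig(host="localhost", port=11434)  # field names assumed

    impl = await get_adapter_impl(config, deps)  # initializes and pre-downloads the mapped model

    # Passing an empty deps dict now fails fast with the ValueError above,
    # instead of initializing an adapter that has no model to serve.

asyncio.run(main())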


@@ -23,7 +23,7 @@ from llama_stack.providers.utils.inference.routable import RoutableProviderForModels
 # TODO: Eventually this will move to the llama cli model list command
 # mapping of Model SKUs to ollama models
 OLLAMA_SUPPORTED_SKUS = {
-    "Llama-Guard-3-1B": "xe/llamaguard3:latest",
+    "Llama-Guard-3-8B": "xe/llamaguard3:latest",
     "llama_guard": "xe/llamaguard3:latest",
     "Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16",
     "Llama3.1-70B-Instruct": "llama3.1:70b-instruct-fp16",
@@ -40,18 +40,18 @@ class OllamaInferenceAdapter(Inference, RoutableProviderForModels):
         self.url = url
         tokenizer = Tokenizer.get_instance()
         self.formatter = ChatFormat(tokenizer)
-        self.model = "Llama3.2-1B-Instruct"
 
     @property
     def client(self) -> AsyncClient:
         return AsyncClient(host=self.url)
 
-    async def initialize(self) -> None:
+    async def initialize(self, routing_key: str) -> None:
         print("Initializing Ollama, checking connectivity to server...")
         try:
             await self.client.ps()
-            print(f"Connected to Ollama server. Pre-downloading {self.model}...")
-            await self.predownload_models(ollama_model=self.model)
+            ollama_model = self.map_to_provider_model(routing_key)
+            print(f"Connected to Ollama server. Pre-downloading {ollama_model}...")
+            await self.predownload_models(ollama_model)
         except httpx.ConnectError as e:
             raise RuntimeError(
                 "Ollama Server is not running, start it using `ollama serve` in a separate terminal"
@@ -140,8 +140,7 @@ class OllamaInferenceAdapter(Inference, RoutableProviderForModels):
         options = self.get_ollama_chat_options(request)
         ollama_model = self.map_to_provider_model(request.model)
 
-        if ollama_model != self.model:
-            self.predownload_models(ollama_model)
+        self.predownload_models(ollama_model)
 
         if not request.stream:
             r = await self.client.chat(
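In the chat path, the per-request model is now mapped and pre-downloaded unconditionally, since the self.model default it was compared against is gone. Roughly, the call being set up in the truncated line above looks like this against the ollama client; the messages and return handling here are illustrative, and the exact arguments chat_completion passes are not shown in this diff.

from ollama import AsyncClient

async def chat_once(url: str, ollama_model: str) -> str:
    # e.g. ollama_model = "llama3.1:8b-instruct-fp16", as mapped from the request's SKU
    client = AsyncClient(host=url)
    r = await client.chat(
        model=ollama_model,
        messages=[{"role": "user", "content": "Hello!"}],
        stream=False,
    )
    return r["message"]["content"]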