Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-29 15:23:51 +00:00)
Fixed model name; use routing_key to get model
Commit 0edf24b227 (parent 8ed548b18e)
3 changed files with 19 additions and 8 deletions
```diff
@@ -156,6 +156,11 @@ async def instantiate_provider(
         assert isinstance(provider_config, GenericProviderConfig)
         config_type = instantiate_class_type(provider_spec.config_class)
         config = config_type(**provider_config.config)
+
+        if hasattr(provider_config, "routing_key"):
+            routing_key = provider_config.routing_key
+            deps["routing_key"] = routing_key
+
         args = [config, deps]
     elif isinstance(provider_spec, AutoRoutedProviderSpec):
         method = "get_auto_router_impl"
```
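For reference, a minimal standalone sketch of the hand-off this hunk introduces: when the provider config carries a `routing_key`, the resolver forwards it to the adapter factory through the `deps` dict. `FakeProviderConfig` and `build_deps` are illustrative stand-ins, not llama_stack types.

```python
# Minimal sketch of the deps hand-off added above; the names here
# (FakeProviderConfig, build_deps) are stand-ins, not real llama_stack types.
from dataclasses import dataclass
from typing import Any, Dict


@dataclass
class FakeProviderConfig:
    config: Dict[str, Any]
    routing_key: str = ""  # assumed: mirrors a provider config that carries a routing key


def build_deps(provider_config: FakeProviderConfig) -> Dict[str, Any]:
    deps: Dict[str, Any] = {}
    # Same pattern as the hunk: forward the routing key only when the config has one.
    if hasattr(provider_config, "routing_key") and provider_config.routing_key:
        deps["routing_key"] = provider_config.routing_key
    return deps


if __name__ == "__main__":
    cfg = FakeProviderConfig(
        config={"url": "http://localhost:11434"},
        routing_key="Llama3.1-8B-Instruct",
    )
    print(build_deps(cfg))  # {'routing_key': 'Llama3.1-8B-Instruct'}
```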
```diff
@@ -15,5 +15,12 @@ async def get_adapter_impl(config: RemoteProviderConfig, _deps):
     from .ollama import OllamaInferenceAdapter
 
     impl = OllamaInferenceAdapter(config.url)
-    await impl.initialize()
+
+    routing_key = _deps.get("routing_key")
+    if not routing_key:
+        raise ValueError(
+            "Routing key is required for the Ollama adapter but was not found."
+        )
+
+    await impl.initialize(routing_key)
     return impl
```
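The factory now fails fast when the resolver did not supply a routing key. An isolated sketch of that guard; the `"routing_key"` deps key matches the resolver hunk above, everything else is illustrative.

```python
# Isolated sketch of the fail-fast guard in get_adapter_impl; "routing_key" is
# the deps key set by the resolver hunk above, the rest is illustrative.
from typing import Any, Dict


def require_routing_key(_deps: Dict[str, Any]) -> str:
    routing_key = _deps.get("routing_key")
    if not routing_key:
        raise ValueError(
            "Routing key is required for the Ollama adapter but was not found."
        )
    return routing_key


if __name__ == "__main__":
    print(require_routing_key({"routing_key": "Llama3.1-8B-Instruct"}))
    try:
        require_routing_key({})
    except ValueError as e:
        print(f"missing key -> {e}")
```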
```diff
@@ -23,7 +23,7 @@ from llama_stack.providers.utils.inference.routable import RoutableProviderForMo
 # TODO: Eventually this will move to the llama cli model list command
 # mapping of Model SKUs to ollama models
 OLLAMA_SUPPORTED_SKUS = {
     "Llama-Guard-3-1B": "xe/llamaguard3:latest",
     "Llama-Guard-3-8B": "xe/llamaguard3:latest",
     "llama_guard": "xe/llamaguard3:latest",
     "Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16",
     "Llama3.1-70B-Instruct": "llama3.1:70b-instruct-fp16",
```
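`map_to_provider_model` (from `RoutableProviderForModels`) is what turns the routing key into one of these ollama tags; its implementation is not part of this diff, so the lookup below is only an approximation of that behavior.

```python
# Approximate sketch of the SKU lookup that map_to_provider_model performs
# against a table like OLLAMA_SUPPORTED_SKUS; the real RoutableProviderForModels
# code is not shown in this diff, so treat this as an assumption.
OLLAMA_SUPPORTED_SKUS = {
    "Llama-Guard-3-1B": "xe/llamaguard3:latest",
    "Llama-Guard-3-8B": "xe/llamaguard3:latest",
    "llama_guard": "xe/llamaguard3:latest",
    "Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16",
    "Llama3.1-70B-Instruct": "llama3.1:70b-instruct-fp16",
}


def map_to_provider_model(routing_key: str) -> str:
    if routing_key not in OLLAMA_SUPPORTED_SKUS:
        raise ValueError(f"Unknown model SKU {routing_key!r}")
    return OLLAMA_SUPPORTED_SKUS[routing_key]


if __name__ == "__main__":
    print(map_to_provider_model("Llama3.1-8B-Instruct"))  # llama3.1:8b-instruct-fp16
```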
```diff
@@ -40,18 +40,18 @@ class OllamaInferenceAdapter(Inference, RoutableProviderForModels):
         self.url = url
         tokenizer = Tokenizer.get_instance()
         self.formatter = ChatFormat(tokenizer)
-        self.model = "Llama3.2-1B-Instruct"
 
     @property
     def client(self) -> AsyncClient:
         return AsyncClient(host=self.url)
 
-    async def initialize(self) -> None:
+    async def initialize(self, routing_key: str) -> None:
         print("Initializing Ollama, checking connectivity to server...")
         try:
             await self.client.ps()
-            print(f"Connected to Ollama server. Pre-downloading {self.model}...")
-            await self.predownload_models(ollama_model=self.model)
+            ollama_model = self.map_to_provider_model(routing_key)
+            print(f"Connected to Ollama server. Pre-downloading {ollama_model}...")
+            await self.predownload_models(ollama_model)
         except httpx.ConnectError as e:
             raise RuntimeError(
                 "Ollama Server is not running, start it using `ollama serve` in a separate terminal"
```
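Putting the pieces together, `initialize()` now resolves the ollama tag from the routing key before pre-downloading. A condensed, hedged sketch of that flow: `predownload_models`'s body is not in this diff, so the `pull()` call below is an assumption about what it does; `AsyncClient` and the `httpx.ConnectError` handling mirror the adapter code above.

```python
# Condensed sketch of the new initialize() flow. predownload_models is not shown
# in this diff, so client.pull() below is an assumed equivalent; AsyncClient and
# the httpx.ConnectError handling mirror the adapter code above.
import httpx
from ollama import AsyncClient


class OllamaInitSketch:
    def __init__(self, url: str, sku_map: dict) -> None:
        self.url = url
        self.sku_map = sku_map

    @property
    def client(self) -> AsyncClient:
        return AsyncClient(host=self.url)

    async def initialize(self, routing_key: str) -> None:
        try:
            await self.client.ps()  # connectivity check, as in the hunk above
            ollama_model = self.sku_map[routing_key]  # stands in for map_to_provider_model
            print(f"Connected to Ollama server. Pre-downloading {ollama_model}...")
            await self.client.pull(ollama_model)  # assumed body of predownload_models
        except httpx.ConnectError as e:
            raise RuntimeError(
                "Ollama Server is not running, start it using `ollama serve` in a separate terminal"
            ) from e
```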
```diff
@@ -140,8 +140,7 @@ class OllamaInferenceAdapter(Inference, RoutableProviderForModels):
         options = self.get_ollama_chat_options(request)
         ollama_model = self.map_to_provider_model(request.model)
 
-        if ollama_model != self.model:
-            self.predownload_models(ollama_model)
+        self.predownload_models(ollama_model)
 
         if not request.stream:
             r = await self.client.chat(
```
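With the hard-coded `self.model` gone, every chat request resolves `request.model` through the same SKU table before calling `client.chat()`. A rough sketch of that per-request path; the trimmed SKU table and the message shape are illustrative, not llama_stack types.

```python
# Rough sketch of the per-request resolution the last hunk settles on; the SKU
# table is trimmed and the message shape is illustrative, not llama_stack types.
import asyncio
from ollama import AsyncClient

OLLAMA_SUPPORTED_SKUS = {"Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16"}


async def chat_once(url: str, sku: str, prompt: str) -> str:
    ollama_model = OLLAMA_SUPPORTED_SKUS[sku]  # stands in for map_to_provider_model
    client = AsyncClient(host=url)
    r = await client.chat(
        model=ollama_model,
        messages=[{"role": "user", "content": prompt}],
    )
    return r["message"]["content"]


if __name__ == "__main__":
    # Requires a local `ollama serve` with the model available.
    print(asyncio.run(chat_once("http://localhost:11434", "Llama3.1-8B-Instruct", "Say hi.")))
```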