From 5b2282afd452483143007f6216b27375ac62ffc5 Mon Sep 17 00:00:00 2001
From: Dinesh Yeduguru
Date: Tue, 12 Nov 2024 13:17:27 -0800
Subject: [PATCH] ollama and databricks

---
 .../remote/inference/databricks/databricks.py | 32 ++++++--
 .../remote/inference/ollama/ollama.py         | 74 ++++++++++++++----
 2 files changed, 78 insertions(+), 28 deletions(-)

diff --git a/llama_stack/providers/remote/inference/databricks/databricks.py b/llama_stack/providers/remote/inference/databricks/databricks.py
index 8e1f7693a..fedea0f86 100644
--- a/llama_stack/providers/remote/inference/databricks/databricks.py
+++ b/llama_stack/providers/remote/inference/databricks/databricks.py
@@ -6,6 +6,8 @@
 
 from typing import AsyncGenerator
 
+from llama_models.datatypes import CoreModelId
+
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import Message
 
@@ -15,7 +17,10 @@ from openai import OpenAI
 
 from llama_stack.apis.inference import *  # noqa: F403
-from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
+from llama_stack.providers.utils.inference.model_registry import (
+    ModelAlias,
+    ModelRegistryHelper,
+)
 from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,
     process_chat_completion_response,
@@ -28,16 +33,25 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 
 from .config import DatabricksImplConfig
 
-DATABRICKS_SUPPORTED_MODELS = {
-    "Llama3.1-70B-Instruct": "databricks-meta-llama-3-1-70b-instruct",
-    "Llama3.1-405B-Instruct": "databricks-meta-llama-3-1-405b-instruct",
-}
+model_aliases = [
+    ModelAlias(
+        provider_model_id="databricks-meta-llama-3-1-70b-instruct",
+        aliases=["Llama3.1-70B-Instruct"],
+        llama_model=CoreModelId.llama3_1_70b_instruct.value,
+    ),
+    ModelAlias(
+        provider_model_id="databricks-meta-llama-3-1-405b-instruct",
+        aliases=["Llama3.1-405B-Instruct"],
+        llama_model=CoreModelId.llama3_1_405b_instruct.value,
+    ),
+]
 
 
 class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
     def __init__(self, config: DatabricksImplConfig) -> None:
         ModelRegistryHelper.__init__(
-            self, provider_to_common_model_aliases_map=DATABRICKS_SUPPORTED_MODELS
+            self,
+            model_aliases=model_aliases,
         )
         self.config = config
         self.formatter = ChatFormat(Tokenizer.get_instance())
@@ -113,8 +127,10 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
     def _get_params(self, request: ChatCompletionRequest) -> dict:
         return {
-            "model": self.map_to_provider_model(request.model),
-            "prompt": chat_completion_request_to_prompt(request, self.formatter),
+            "model": request.model,
+            "prompt": chat_completion_request_to_prompt(
+                request, self.get_llama_model(request.model), self.formatter
+            ),
             "stream": request.stream,
             **get_sampling_options(request.sampling_params),
         }
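Both adapters replace a flat alias-to-provider-id dict with a list of ModelAlias records handed to ModelRegistryHelper, so a single entry ties together the id the backend serves, the canonical Llama descriptor, and any user-facing aliases. Below is a minimal sketch of how such a registry can index and resolve names. The ModelAlias fields mirror the diff, but MiniRegistry and its method bodies are assumptions for illustration, not ModelRegistryHelper's actual implementation:

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class ModelAlias:
    # Field names follow the diff; the real class lives in
    # llama_stack.providers.utils.inference.model_registry.
    provider_model_id: str  # id the provider backend serves
    llama_model: str  # canonical descriptor, i.e. CoreModelId.<...>.value
    aliases: List[str] = field(default_factory=list)


class MiniRegistry:
    """Hypothetical stand-in for ModelRegistryHelper's alias indexing."""

    def __init__(self, model_aliases: List[ModelAlias]) -> None:
        self._by_name: Dict[str, ModelAlias] = {}
        for entry in model_aliases:
            # The provider id and every alias all point at the same record.
            self._by_name[entry.provider_model_id] = entry
            for name in entry.aliases:
                self._by_name[name] = entry

    def get_provider_model_id(self, name: str) -> Optional[str]:
        entry = self._by_name.get(name)
        return entry.provider_model_id if entry else None

    def get_llama_model(self, provider_model_id: str) -> Optional[str]:
        entry = self._by_name.get(provider_model_id)
        return entry.llama_model if entry else None


registry = MiniRegistry(
    [
        ModelAlias(
            provider_model_id="databricks-meta-llama-3-1-70b-instruct",
            llama_model="Llama3.1-70B-Instruct",  # assumed CoreModelId value
            aliases=["Llama3.1-70B-Instruct"],
        )
    ]
)
print(registry.get_provider_model_id("Llama3.1-70B-Instruct"))
print(registry.get_llama_model("databricks-meta-llama-3-1-70b-instruct"))
```

Resolving in both directions is the point of the refactor: callers may pass an alias, the backend needs its own id, and prompt formatting later needs the Llama descriptor back.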
diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py
index f5750e0cf..bc80c7db2 100644
--- a/llama_stack/providers/remote/inference/ollama/ollama.py
+++ b/llama_stack/providers/remote/inference/ollama/ollama.py
@@ -7,13 +7,18 @@
 from typing import AsyncGenerator
 
 import httpx
+from llama_models.datatypes import CoreModelId
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import Message
 from llama_models.llama3.api.tokenizer import Tokenizer
-
 from ollama import AsyncClient
 
+from llama_stack.providers.utils.inference.model_registry import (
+    ModelAlias,
+    ModelRegistryHelper,
+)
+
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
 
@@ -33,19 +38,52 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     request_has_media,
 )
 
-OLLAMA_SUPPORTED_MODELS = {
-    "Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16",
-    "Llama3.1-70B-Instruct": "llama3.1:70b-instruct-fp16",
-    "Llama3.2-1B-Instruct": "llama3.2:1b-instruct-fp16",
-    "Llama3.2-3B-Instruct": "llama3.2:3b-instruct-fp16",
-    "Llama-Guard-3-8B": "llama-guard3:8b",
-    "Llama-Guard-3-1B": "llama-guard3:1b",
-    "Llama3.2-11B-Vision-Instruct": "x/llama3.2-vision:11b-instruct-fp16",
-}
+
+model_aliases = [
+    ModelAlias(
+        provider_model_id="llama3.1:8b-instruct-fp16",
+        aliases=["Llama3.1-8B-Instruct"],
+        llama_model=CoreModelId.llama3_1_8b_instruct.value,
+    ),
+    ModelAlias(
+        provider_model_id="llama3.1:70b-instruct-fp16",
+        aliases=["Llama3.1-70B-Instruct"],
+        llama_model=CoreModelId.llama3_1_70b_instruct.value,
+    ),
+    ModelAlias(
+        provider_model_id="llama3.2:1b-instruct-fp16",
+        aliases=["Llama3.2-1B-Instruct"],
+        llama_model=CoreModelId.llama3_2_1b_instruct.value,
+    ),
+    ModelAlias(
+        provider_model_id="llama3.2:3b-instruct-fp16",
+        aliases=["Llama3.2-3B-Instruct"],
+        llama_model=CoreModelId.llama3_2_3b_instruct.value,
+    ),
+    ModelAlias(
+        provider_model_id="llama-guard3:8b",
+        aliases=["Llama-Guard-3-8B"],
+        llama_model=CoreModelId.llama_guard_3_8b.value,
+    ),
+    ModelAlias(
+        provider_model_id="llama-guard3:1b",
+        aliases=["Llama-Guard-3-1B"],
+        llama_model=CoreModelId.llama_guard_3_1b.value,
+    ),
+    ModelAlias(
+        provider_model_id="x/llama3.2-vision:11b-instruct-fp16",
+        aliases=["Llama3.2-11B-Vision-Instruct"],
+        llama_model=CoreModelId.llama3_2_11b_vision_instruct.value,
+    ),
+]
 
 
-class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
+class OllamaInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPrivate):
     def __init__(self, url: str) -> None:
+        ModelRegistryHelper.__init__(
+            self,
+            model_aliases=model_aliases,
+        )
         self.url = url
         self.formatter = ChatFormat(Tokenizer.get_instance())
 
@@ -65,12 +103,8 @@
     async def shutdown(self) -> None:
         pass
 
-    async def register_model(self, model: Model) -> None:
-        if model.provider_resource_id not in OLLAMA_SUPPORTED_MODELS:
-            raise ValueError(
-                f"Model {model.provider_resource_id} is not supported by Ollama"
-            )
-
     async def list_models(self) -> List[Model]:
-        ollama_to_llama = {v: k for k, v in OLLAMA_SUPPORTED_MODELS.items()}
+        ollama_to_llama = {
+            alias.provider_model_id: alias.llama_model for alias in model_aliases
+        }
 
@@ -103,8 +135,9 @@
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
+        model = await self.model_store.get_model(model_id)
         request = CompletionRequest(
-            model=model_id,
+            model=model.provider_resource_id,
             content=content,
             sampling_params=sampling_params,
             stream=stream,
@@ -160,8 +193,9 @@
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
+        model = await self.model_store.get_model(model_id)
         request = ChatCompletionRequest(
-            model=model_id,
+            model=model.provider_resource_id,
             messages=messages,
             sampling_params=sampling_params,
             tools=tools or [],
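The completion() and chat_completion() hunks above share one pattern: resolve the caller's model_id through self.model_store before building the request, so the request's model field carries provider_resource_id (for Ollama, names like llama3.1:8b-instruct-fp16) rather than the user-facing identifier. A toy version of that resolution step follows; FakeModelStore and its get_model signature are assumptions standing in for the store the stack injects:

```python
import asyncio
from dataclasses import dataclass
from typing import List


@dataclass
class Model:
    identifier: str  # user-facing name or alias
    provider_resource_id: str  # id the Ollama daemon actually serves


class FakeModelStore:
    """Dict-backed stand-in for the injected model_store (interface assumed)."""

    def __init__(self, models: List[Model]) -> None:
        self._models = {m.identifier: m for m in models}

    async def get_model(self, model_id: str) -> Model:
        return self._models[model_id]


async def main() -> None:
    store = FakeModelStore(
        [Model("Llama3.1-8B-Instruct", "llama3.1:8b-instruct-fp16")]
    )
    # Mirrors the new first step of completion()/chat_completion():
    model = await store.get_model("Llama3.1-8B-Instruct")
    print(model.provider_resource_id)  # -> llama3.1:8b-instruct-fp16


asyncio.run(main())
```

The hunk below then handles the flip side: once request.model is a provider id, the Llama descriptor has to be recovered explicitly for prompt formatting.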
@@ -199,7 +233,7 @@
             else:
                 input_dict["raw"] = True
                 input_dict["prompt"] = chat_completion_request_to_prompt(
-                    request, self.formatter
+                    request, self.get_llama_model(request.model), self.formatter
                 )
         else:
             assert (
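Since request.model now holds a provider id, the prompt builder can no longer infer which Llama chat template to apply from it, which is why _get_params in both files passes self.get_llama_model(request.model) into chat_completion_request_to_prompt. Below is a self-contained sketch of that lookup and how it feeds prompt construction; the mapping table and format_prompt body are illustrative only, not the real chat_completion_request_to_prompt:

```python
# Illustrative only: maps provider ids back to canonical Llama descriptors,
# the job get_llama_model() performs via the registered ModelAlias entries.
PROVIDER_TO_LLAMA = {
    "llama3.1:8b-instruct-fp16": "Llama3.1-8B-Instruct",
    "x/llama3.2-vision:11b-instruct-fp16": "Llama3.2-11B-Vision-Instruct",
}


def format_prompt(request_model: str, user_message: str) -> dict:
    llama_model = PROVIDER_TO_LLAMA[request_model]  # get_llama_model(...)
    return {
        "model": request_model,  # provider id goes to the backend as-is
        # The descriptor decides which chat template wraps the message.
        "prompt": f"[template:{llama_model}] {user_message}",
        "raw": True,  # prompt is fully formatted client-side
    }


print(format_prompt("llama3.1:8b-instruct-fp16", "Hello!"))
```

Setting "raw": True matches the hunk above: the prompt arrives at Ollama already templated, so the daemon must not wrap it in its own chat template a second time.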