From 608e827d3670ba41ee648b5c81e2a12d2de2f2e1 Mon Sep 17 00:00:00 2001
From: prithu-dasgupta
Date: Fri, 4 Oct 2024 15:53:18 -0700
Subject: [PATCH] update provider and test

---
 .../templates/local-databricks-build.yaml     |  2 +-
 .../adapters/inference/databricks/__init__.py | 20 ++-----
 .../adapters/inference/databricks/config.py   |  6 +-
 .../inference/databricks/databricks.py        | 60 +++++++------------
 llama_stack/providers/registry/inference.py   |  2 +-
 5 files changed, 34 insertions(+), 56 deletions(-)

diff --git a/llama_stack/distribution/templates/local-databricks-build.yaml b/llama_stack/distribution/templates/local-databricks-build.yaml
index 8d5c543df..754af7668 100644
--- a/llama_stack/distribution/templates/local-databricks-build.yaml
+++ b/llama_stack/distribution/templates/local-databricks-build.yaml
@@ -7,4 +7,4 @@ distribution_spec:
     safety: meta-reference
     agents: meta-reference
     telemetry: meta-reference
-image_type: conda
+image_type: conda
\ No newline at end of file
diff --git a/llama_stack/providers/adapters/inference/databricks/__init__.py b/llama_stack/providers/adapters/inference/databricks/__init__.py
index c00d9c28c..097579d25 100644
--- a/llama_stack/providers/adapters/inference/databricks/__init__.py
+++ b/llama_stack/providers/adapters/inference/databricks/__init__.py
@@ -5,20 +5,12 @@
 # the root directory of this source tree.
 
 from .config import DatabricksImplConfig
-from .databricks import InferenceEndpointAdapter, DatabricksAdapter
-
+from .databricks import DatabricksInferenceAdapter
 
 async def get_adapter_impl(config: DatabricksImplConfig, _deps):
-    assert isinstance(config, DatabricksImplConfig), f"Unexpected config type: {type(config)}"
-
-    if config.url is not None:
-        impl = DatabricksAdapter(config)
-    elif config.is_inference_endpoint():
-        impl = InferenceEndpointAdapter(config)
-    else:
-        raise ValueError(
-            "Invalid configuration. Specify either an URL or HF Inference Endpoint details (namespace and endpoint name)."
-        )
-
+    assert isinstance(
+        config, DatabricksImplConfig
+    ), f"Unexpected config type: {type(config)}"
+    impl = DatabricksInferenceAdapter(config)
     await impl.initialize()
-    return impl
+    return impl
\ No newline at end of file
diff --git a/llama_stack/providers/adapters/inference/databricks/config.py b/llama_stack/providers/adapters/inference/databricks/config.py
index 652c4ff17..927bb474c 100644
--- a/llama_stack/providers/adapters/inference/databricks/config.py
+++ b/llama_stack/providers/adapters/inference/databricks/config.py
@@ -12,11 +12,11 @@ from pydantic import BaseModel, Field
 
 @json_schema_type
 class DatabricksImplConfig(BaseModel):
-    url: Optional[str] = Field(
+    url: str = Field(
         default=None,
         description="The URL for the Databricks model serving endpoint",
     )
-    api_token: Optional[str] = Field(
+    api_token: str = Field(
         default=None,
         description="The Databricks API token",
-    )
+    )
\ No newline at end of file
diff --git a/llama_stack/providers/adapters/inference/databricks/databricks.py b/llama_stack/providers/adapters/inference/databricks/databricks.py
index 213330902..eeffb938d 100644
--- a/llama_stack/providers/adapters/inference/databricks/databricks.py
+++ b/llama_stack/providers/adapters/inference/databricks/databricks.py
@@ -6,23 +6,24 @@
 
 from typing import AsyncGenerator
 
+from openai import OpenAI
+
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import Message, StopReason
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.sku_list import resolve_model
 
-from openai import OpenAI
-
 from llama_stack.apis.inference import *  # noqa: F403
-from llama_stack.providers.utils.inference.prepare_messages import prepare_messages
+from llama_stack.providers.utils.inference.augment_messages import (
+    augment_messages_for_tools,
+)
 
 from .config import DatabricksImplConfig
 
 
 DATABRICKS_SUPPORTED_MODELS = {
-    "Meta-Llama3.1-8B-Instruct": "databricks-meta-llama-3-1-8b-instruct",
-    "Meta-Llama3.1-70B-Instruct": "databricks-meta-llama-3-1-70b-instruct",
-    "Meta-Llama3.1-405B-Instruct": "databricks-meta-llama-3-1-405b-instruct",
+    "Llama3.1-70B-Instruct": "databricks-meta-llama-3-1-70b-instruct",
+    "Llama3.1-405B-Instruct": "databricks-meta-llama-3-1-405b-instruct",
 }
 
 
@@ -35,8 +36,8 @@ class DatabricksInferenceAdapter(Inference):
     @property
     def client(self) -> OpenAI:
         return OpenAI(
-            api_key=self.config.api_token,
-            base_url=self.config.url
+            base_url=self.config.url,
+            api_key=self.config.api_token
        )
 
     async def initialize(self) -> None:
@@ -45,6 +46,11 @@ class DatabricksInferenceAdapter(Inference):
     async def shutdown(self) -> None:
         pass
 
+    async def validate_routing_keys(self, routing_keys: list[str]) -> None:
+        # these are the model names the Llama Stack will use to route requests to this provider
+        # perform validation here if necessary
+        pass
+
     async def completion(self, request: CompletionRequest) -> AsyncGenerator:
         raise NotImplementedError()
 
@@ -91,7 +97,6 @@ class DatabricksInferenceAdapter(Inference):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
-        # wrapper request to make it easier to pass around (internal only, not exposed to API)
         request = ChatCompletionRequest(
             model=model,
             messages=messages,
@@ -103,40 +108,22 @@ class DatabricksInferenceAdapter(Inference):
             logprobs=logprobs,
         )
 
-        # accumulate sampling params and other options to pass to databricks
+        messages = augment_messages_for_tools(request)
         options = self.get_databricks_chat_options(request)
         databricks_model = self.resolve_databricks_model(request.model)
-        messages = prepare_messages(request)
-        model_input = self.formatter.encode_dialog_prompt(messages)
-        prompt = self.tokenizer.decode(model_input.tokens)
-
-        input_tokens = len(model_input.tokens)
-        max_new_tokens = min(
-            request.sampling_params.max_tokens or (self.max_tokens - input_tokens),
-            self.max_tokens - input_tokens - 1,
-        )
-
-        print(f"Calculated max_new_tokens: {max_new_tokens}")
-
-        assert (
-            request.model == self.model_name
-        ), f"Model mismatch, expected {self.model_name}, got {request.model}"
 
         if not request.stream:
-            # TODO: might need to add back an async here
+
             r = self.client.chat.completions.create(
                 model=databricks_model,
                 messages=self._messages_to_databricks_messages(messages),
-                max_tokens=max_new_tokens,
                 stream=False,
                 **options,
             )
+
             stop_reason = None
             if r.choices[0].finish_reason:
-                if (
-                    r.choices[0].finish_reason == "stop"
-                    or r.choices[0].finish_reason == "eos"
-                ):
+                if r.choices[0].finish_reason == "stop":
                     stop_reason = StopReason.end_of_turn
                 elif r.choices[0].finish_reason == "length":
                     stop_reason = StopReason.out_of_tokens
@@ -163,15 +150,13 @@ class DatabricksInferenceAdapter(Inference):
         for chunk in self.client.chat.completions.create(
             model=databricks_model,
             messages=self._messages_to_databricks_messages(messages),
-            max_tokens=max_new_tokens,
             stream=True,
             **options,
         ):
             if chunk.choices[0].finish_reason:
                 if (
-                    stop_reason is None and chunk.choices[0].finish_reason == "stop"
-                ) or (
-                    stop_reason is None and chunk.choices[0].finish_reason == "eos"
+                    stop_reason is None
+                    and chunk.choices[0].finish_reason == "stop"
                 ):
                     stop_reason = StopReason.end_of_turn
                 elif (
@@ -181,7 +166,8 @@ class DatabricksInferenceAdapter(Inference):
                     stop_reason = StopReason.out_of_tokens
                 break
 
-            text = chunk.choices[0].message.content
+            text = chunk.choices[0].delta.content
+
             if text is None:
                 continue
 
@@ -268,4 +254,4 @@ class DatabricksInferenceAdapter(Inference):
                     delta="",
                     stop_reason=stop_reason,
                 )
-            )
+            )
\ No newline at end of file
diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py
index 9002a23af..2b39c0a53 100644
--- a/llama_stack/providers/registry/inference.py
+++ b/llama_stack/providers/registry/inference.py
@@ -68,7 +68,7 @@ def available_providers() -> List[ProviderSpec]:
         remote_provider_spec(
             api=Api.inference,
             adapter=AdapterSpec(
-                adapter_id="databricks",
+                adapter_type="databricks",
                 pip_packages=[
                     "openai",
                 ],
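
A minimal smoke-test sketch (not part of the patch) for driving the updated adapter end to end. It assumes the workspace serving-endpoint URL and API token are supplied through the illustrative DATABRICKS_URL and DATABRICKS_TOKEN environment variables, that the llama_stack and llama_models packages touched by this diff are importable, and that the model name is one of the keys in DATABRICKS_SUPPORTED_MODELS.

import asyncio
import os

from llama_models.llama3.api.datatypes import UserMessage

from llama_stack.providers.adapters.inference.databricks import get_adapter_impl
from llama_stack.providers.adapters.inference.databricks.config import (
    DatabricksImplConfig,
)


async def main() -> None:
    # Both fields are now required strings; the env var names here are illustrative.
    config = DatabricksImplConfig(
        url=os.environ["DATABRICKS_URL"],
        api_token=os.environ["DATABRICKS_TOKEN"],
    )
    impl = await get_adapter_impl(config, {})

    # Non-streaming chat completion against a model listed in DATABRICKS_SUPPORTED_MODELS;
    # chat_completion is an async generator, so iterate even when stream=False.
    async for response in impl.chat_completion(
        model="Llama3.1-70B-Instruct",
        messages=[UserMessage(content="Hello from the Databricks adapter")],
        stream=False,
    ):
        print(response)


if __name__ == "__main__":
    asyncio.run(main())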