update provider and test

This commit is contained in:
prithu-dasgupta 2024-10-04 15:53:18 -07:00
parent 7c30808167
commit 608e827d36
5 changed files with 34 additions and 56 deletions

View file

@ -7,4 +7,4 @@ distribution_spec:
safety: meta-reference safety: meta-reference
agents: meta-reference agents: meta-reference
telemetry: meta-reference telemetry: meta-reference
image_type: conda image_type: conda

View file

@ -5,20 +5,12 @@
# the root directory of this source tree. # the root directory of this source tree.
from .config import DatabricksImplConfig from .config import DatabricksImplConfig
from .databricks import InferenceEndpointAdapter, DatabricksAdapter from .databricks import DatabricksInferenceAdapter
async def get_adapter_impl(config: DatabricksImplConfig, _deps): async def get_adapter_impl(config: DatabricksImplConfig, _deps):
assert isinstance(config, DatabricksImplConfig), f"Unexpected config type: {type(config)}" assert isinstance(
config, DatabricksImplConfig
if config.url is not None: ), f"Unexpected config type: {type(config)}"
impl = DatabricksAdapter(config) impl = DatabricksInferenceAdapter(config)
elif config.is_inference_endpoint():
impl = InferenceEndpointAdapter(config)
else:
raise ValueError(
"Invalid configuration. Specify either an URL or HF Inference Endpoint details (namespace and endpoint name)."
)
await impl.initialize() await impl.initialize()
return impl return impl

View file

@ -12,11 +12,11 @@ from pydantic import BaseModel, Field
@json_schema_type @json_schema_type
class DatabricksImplConfig(BaseModel): class DatabricksImplConfig(BaseModel):
url: Optional[str] = Field( url: str = Field(
default=None, default=None,
description="The URL for the Databricks model serving endpoint", description="The URL for the Databricks model serving endpoint",
) )
api_token: Optional[str] = Field( api_token: str = Field(
default=None, default=None,
description="The Databricks API token", description="The Databricks API token",
) )

View file

@ -6,23 +6,24 @@
from typing import AsyncGenerator from typing import AsyncGenerator
from openai import OpenAI
from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message, StopReason from llama_models.llama3.api.datatypes import Message, StopReason
from llama_models.llama3.api.tokenizer import Tokenizer from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import resolve_model from llama_models.sku_list import resolve_model
from openai import OpenAI
from llama_stack.apis.inference import * # noqa: F403 from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.utils.inference.prepare_messages import prepare_messages from llama_stack.providers.utils.inference.augment_messages import (
augment_messages_for_tools,
)
from .config import DatabricksImplConfig from .config import DatabricksImplConfig
DATABRICKS_SUPPORTED_MODELS = { DATABRICKS_SUPPORTED_MODELS = {
"Meta-Llama3.1-8B-Instruct": "databricks-meta-llama-3-1-8b-instruct", "Llama3.1-70B-Instruct": "databricks-meta-llama-3-1-70b-instruct",
"Meta-Llama3.1-70B-Instruct": "databricks-meta-llama-3-1-70b-instruct", "Llama3.1-405B-Instruct": "databricks-meta-llama-3-1-405b-instruct",
"Meta-Llama3.1-405B-Instruct": "databricks-meta-llama-3-1-405b-instruct",
} }
@ -35,8 +36,8 @@ class DatabricksInferenceAdapter(Inference):
@property @property
def client(self) -> OpenAI: def client(self) -> OpenAI:
return OpenAI( return OpenAI(
api_key=self.config.api_token, base_url=self.config.url,
base_url=self.config.url api_key=self.config.api_token
) )
async def initialize(self) -> None: async def initialize(self) -> None:
@ -45,6 +46,11 @@ class DatabricksInferenceAdapter(Inference):
async def shutdown(self) -> None: async def shutdown(self) -> None:
pass pass
async def validate_routing_keys(self, routing_keys: list[str]) -> None:
# these are the model names the Llama Stack will use to route requests to this provider
# perform validation here if necessary
pass
async def completion(self, request: CompletionRequest) -> AsyncGenerator: async def completion(self, request: CompletionRequest) -> AsyncGenerator:
raise NotImplementedError() raise NotImplementedError()
@ -91,7 +97,6 @@ class DatabricksInferenceAdapter(Inference):
stream: Optional[bool] = False, stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None, logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator: ) -> AsyncGenerator:
# wrapper request to make it easier to pass around (internal only, not exposed to API)
request = ChatCompletionRequest( request = ChatCompletionRequest(
model=model, model=model,
messages=messages, messages=messages,
@ -103,40 +108,22 @@ class DatabricksInferenceAdapter(Inference):
logprobs=logprobs, logprobs=logprobs,
) )
# accumulate sampling params and other options to pass to databricks messages = augment_messages_for_tools(request)
options = self.get_databricks_chat_options(request) options = self.get_databricks_chat_options(request)
databricks_model = self.resolve_databricks_model(request.model) databricks_model = self.resolve_databricks_model(request.model)
messages = prepare_messages(request)
model_input = self.formatter.encode_dialog_prompt(messages)
prompt = self.tokenizer.decode(model_input.tokens)
input_tokens = len(model_input.tokens)
max_new_tokens = min(
request.sampling_params.max_tokens or (self.max_tokens - input_tokens),
self.max_tokens - input_tokens - 1,
)
print(f"Calculated max_new_tokens: {max_new_tokens}")
assert (
request.model == self.model_name
), f"Model mismatch, expected {self.model_name}, got {request.model}"
if not request.stream: if not request.stream:
# TODO: might need to add back an async here
r = self.client.chat.completions.create( r = self.client.chat.completions.create(
model=databricks_model, model=databricks_model,
messages=self._messages_to_databricks_messages(messages), messages=self._messages_to_databricks_messages(messages),
max_tokens=max_new_tokens,
stream=False, stream=False,
**options, **options,
) )
stop_reason = None stop_reason = None
if r.choices[0].finish_reason: if r.choices[0].finish_reason:
if ( if r.choices[0].finish_reason == "stop":
r.choices[0].finish_reason == "stop"
or r.choices[0].finish_reason == "eos"
):
stop_reason = StopReason.end_of_turn stop_reason = StopReason.end_of_turn
elif r.choices[0].finish_reason == "length": elif r.choices[0].finish_reason == "length":
stop_reason = StopReason.out_of_tokens stop_reason = StopReason.out_of_tokens
@ -163,15 +150,13 @@ class DatabricksInferenceAdapter(Inference):
for chunk in self.client.chat.completions.create( for chunk in self.client.chat.completions.create(
model=databricks_model, model=databricks_model,
messages=self._messages_to_databricks_messages(messages), messages=self._messages_to_databricks_messages(messages),
max_tokens=max_new_tokens,
stream=True, stream=True,
**options, **options,
): ):
if chunk.choices[0].finish_reason: if chunk.choices[0].finish_reason:
if ( if (
stop_reason is None and chunk.choices[0].finish_reason == "stop" stop_reason is None
) or ( and chunk.choices[0].finish_reason == "stop"
stop_reason is None and chunk.choices[0].finish_reason == "eos"
): ):
stop_reason = StopReason.end_of_turn stop_reason = StopReason.end_of_turn
elif ( elif (
@ -181,7 +166,8 @@ class DatabricksInferenceAdapter(Inference):
stop_reason = StopReason.out_of_tokens stop_reason = StopReason.out_of_tokens
break break
text = chunk.choices[0].message.content text = chunk.choices[0].delta.content
if text is None: if text is None:
continue continue
@ -268,4 +254,4 @@ class DatabricksInferenceAdapter(Inference):
delta="", delta="",
stop_reason=stop_reason, stop_reason=stop_reason,
) )
) )

View file

@ -68,7 +68,7 @@ def available_providers() -> List[ProviderSpec]:
remote_provider_spec( remote_provider_spec(
api=Api.inference, api=Api.inference,
adapter=AdapterSpec( adapter=AdapterSpec(
adapter_id="databricks", adapter_type="databricks",
pip_packages=[ pip_packages=[
"openai", "openai",
], ],