update provider and test

prithu-dasgupta 2024-10-04 15:53:18 -07:00
parent 7c30808167
commit 608e827d36
5 changed files with 34 additions and 56 deletions

View file

@@ -5,20 +5,12 @@
# the root directory of this source tree.
from .config import DatabricksImplConfig
from .databricks import InferenceEndpointAdapter, DatabricksAdapter
from .databricks import DatabricksInferenceAdapter
async def get_adapter_impl(config: DatabricksImplConfig, _deps):
assert isinstance(config, DatabricksImplConfig), f"Unexpected config type: {type(config)}"
if config.url is not None:
impl = DatabricksAdapter(config)
elif config.is_inference_endpoint():
impl = InferenceEndpointAdapter(config)
else:
raise ValueError(
"Invalid configuration. Specify either an URL or HF Inference Endpoint details (namespace and endpoint name)."
)
assert isinstance(
config, DatabricksImplConfig
), f"Unexpected config type: {type(config)}"
impl = DatabricksInferenceAdapter(config)
await impl.initialize()
return impl
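
For reference, a minimal sketch of how the simplified factory might be exercised from inside the provider package; the URL and token values below are placeholders, not real Databricks credentials:

# Hypothetical caller of the simplified factory (placeholder values).
from .config import DatabricksImplConfig
from . import get_adapter_impl


async def make_adapter():
    config = DatabricksImplConfig(
        url="https://example.cloud.databricks.com/serving-endpoints",
        api_token="dapi-example-token",
    )
    # The factory now always returns a DatabricksInferenceAdapter; there is
    # no longer any branching on URL vs. HF Inference Endpoint details.
    return await get_adapter_impl(config, _deps={})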

View file

@@ -12,11 +12,11 @@ from pydantic import BaseModel, Field
@json_schema_type
class DatabricksImplConfig(BaseModel):
url: Optional[str] = Field(
url: str = Field(
default=None,
description="The URL for the Databricks model serving endpoint",
)
api_token: Optional[str] = Field(
api_token: str = Field(
default=None,
description="The Databricks API token",
)
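
As a quick sketch, constructing the config now uses plain string fields rather than Optional ones; the values shown are placeholders:

from .config import DatabricksImplConfig

# Placeholder values; a real deployment points at an actual Databricks
# model serving endpoint and a valid API token.
config = DatabricksImplConfig(
    url="https://example.cloud.databricks.com/serving-endpoints",
    api_token="dapi-example-token",
)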

View file

@@ -6,23 +6,24 @@
from typing import AsyncGenerator
from openai import OpenAI
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message, StopReason
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import resolve_model
from openai import OpenAI
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.utils.inference.prepare_messages import prepare_messages
from llama_stack.providers.utils.inference.augment_messages import (
augment_messages_for_tools,
)
from .config import DatabricksImplConfig
DATABRICKS_SUPPORTED_MODELS = {
"Meta-Llama3.1-8B-Instruct": "databricks-meta-llama-3-1-8b-instruct",
"Meta-Llama3.1-70B-Instruct": "databricks-meta-llama-3-1-70b-instruct",
"Meta-Llama3.1-405B-Instruct": "databricks-meta-llama-3-1-405b-instruct",
"Llama3.1-70B-Instruct": "databricks-meta-llama-3-1-70b-instruct",
"Llama3.1-405B-Instruct": "databricks-meta-llama-3-1-405b-instruct",
}
@@ -35,8 +36,8 @@ class DatabricksInferenceAdapter(Inference):
@property
def client(self) -> OpenAI:
return OpenAI(
api_key=self.config.api_token,
base_url=self.config.url
base_url=self.config.url,
api_key=self.config.api_token
)
async def initialize(self) -> None:
@@ -45,6 +46,11 @@ class DatabricksInferenceAdapter(Inference):
async def shutdown(self) -> None:
pass
async def validate_routing_keys(self, routing_keys: list[str]) -> None:
# these are the model names the Llama Stack will use to route requests to this provider
# perform validation here if necessary
pass
async def completion(self, request: CompletionRequest) -> AsyncGenerator:
raise NotImplementedError()
@@ -91,7 +97,6 @@ class DatabricksInferenceAdapter(Inference):
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
# wrapper request to make it easier to pass around (internal only, not exposed to API)
request = ChatCompletionRequest(
model=model,
messages=messages,
@@ -103,40 +108,22 @@ class DatabricksInferenceAdapter(Inference):
logprobs=logprobs,
)
# accumulate sampling params and other options to pass to databricks
messages = augment_messages_for_tools(request)
options = self.get_databricks_chat_options(request)
databricks_model = self.resolve_databricks_model(request.model)
messages = prepare_messages(request)
model_input = self.formatter.encode_dialog_prompt(messages)
prompt = self.tokenizer.decode(model_input.tokens)
input_tokens = len(model_input.tokens)
max_new_tokens = min(
request.sampling_params.max_tokens or (self.max_tokens - input_tokens),
self.max_tokens - input_tokens - 1,
)
print(f"Calculated max_new_tokens: {max_new_tokens}")
assert (
request.model == self.model_name
), f"Model mismatch, expected {self.model_name}, got {request.model}"
if not request.stream:
# TODO: might need to add back an async here
r = self.client.chat.completions.create(
model=databricks_model,
messages=self._messages_to_databricks_messages(messages),
max_tokens=max_new_tokens,
stream=False,
**options,
)
stop_reason = None
if r.choices[0].finish_reason:
if (
r.choices[0].finish_reason == "stop"
or r.choices[0].finish_reason == "eos"
):
if r.choices[0].finish_reason == "stop":
stop_reason = StopReason.end_of_turn
elif r.choices[0].finish_reason == "length":
stop_reason = StopReason.out_of_tokens
@@ -163,15 +150,13 @@ class DatabricksInferenceAdapter(Inference):
for chunk in self.client.chat.completions.create(
model=databricks_model,
messages=self._messages_to_databricks_messages(messages),
max_tokens=max_new_tokens,
stream=True,
**options,
):
if chunk.choices[0].finish_reason:
if (
stop_reason is None and chunk.choices[0].finish_reason == "stop"
) or (
stop_reason is None and chunk.choices[0].finish_reason == "eos"
stop_reason is None
and chunk.choices[0].finish_reason == "stop"
):
stop_reason = StopReason.end_of_turn
elif (
@@ -181,7 +166,8 @@ class DatabricksInferenceAdapter(Inference):
stop_reason = StopReason.out_of_tokens
break
text = chunk.choices[0].message.content
text = chunk.choices[0].delta.content
if text is None:
continue
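
Taken together, the adapter now talks to Databricks through the OpenAI-compatible client shown above. A rough sketch of the equivalent raw call, with a placeholder endpoint and token, and the 70B mapping taken from DATABRICKS_SUPPORTED_MODELS:

from openai import OpenAI

# Placeholder endpoint and token; the adapter reads these from DatabricksImplConfig.
client = OpenAI(
    base_url="https://example.cloud.databricks.com/serving-endpoints",
    api_key="dapi-example-token",
)

# "Llama3.1-70B-Instruct" resolves to this Databricks model name via the
# DATABRICKS_SUPPORTED_MODELS mapping above.
response = client.chat.completions.create(
    model="databricks-meta-llama-3-1-70b-instruct",
    messages=[{"role": "user", "content": "Hello"}],
    stream=False,
)
print(response.choices[0].message.content)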

View file

@@ -68,7 +68,7 @@ def available_providers() -> List[ProviderSpec]:
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
adapter_id="databricks",
adapter_type="databricks",
pip_packages=[
"openai",
],