mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
update provider and test
This commit is contained in:
parent
7c30808167
commit
608e827d36
5 changed files with 34 additions and 56 deletions
|
@ -7,4 +7,4 @@ distribution_spec:
|
|||
safety: meta-reference
|
||||
agents: meta-reference
|
||||
telemetry: meta-reference
|
||||
image_type: conda
|
||||
image_type: conda
|
|
@ -5,20 +5,12 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from .config import DatabricksImplConfig
|
||||
from .databricks import InferenceEndpointAdapter, DatabricksAdapter
|
||||
|
||||
from .databricks import DatabricksInferenceAdapter
|
||||
|
||||
async def get_adapter_impl(config: DatabricksImplConfig, _deps):
|
||||
assert isinstance(config, DatabricksImplConfig), f"Unexpected config type: {type(config)}"
|
||||
|
||||
if config.url is not None:
|
||||
impl = DatabricksAdapter(config)
|
||||
elif config.is_inference_endpoint():
|
||||
impl = InferenceEndpointAdapter(config)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Invalid configuration. Specify either an URL or HF Inference Endpoint details (namespace and endpoint name)."
|
||||
)
|
||||
|
||||
assert isinstance(
|
||||
config, DatabricksImplConfig
|
||||
), f"Unexpected config type: {type(config)}"
|
||||
impl = DatabricksInferenceAdapter(config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
return impl
|
|
@ -12,11 +12,11 @@ from pydantic import BaseModel, Field
|
|||
|
||||
@json_schema_type
|
||||
class DatabricksImplConfig(BaseModel):
|
||||
url: Optional[str] = Field(
|
||||
url: str = Field(
|
||||
default=None,
|
||||
description="The URL for the Databricks model serving endpoint",
|
||||
)
|
||||
api_token: Optional[str] = Field(
|
||||
api_token: str = Field(
|
||||
default=None,
|
||||
description="The Databricks API token",
|
||||
)
|
||||
)
|
|
@ -6,23 +6,24 @@
|
|||
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
|
||||
from llama_models.llama3.api.datatypes import Message, StopReason
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
from llama_models.sku_list import resolve_model
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.providers.utils.inference.prepare_messages import prepare_messages
|
||||
from llama_stack.providers.utils.inference.augment_messages import (
|
||||
augment_messages_for_tools,
|
||||
)
|
||||
|
||||
from .config import DatabricksImplConfig
|
||||
|
||||
DATABRICKS_SUPPORTED_MODELS = {
|
||||
"Meta-Llama3.1-8B-Instruct": "databricks-meta-llama-3-1-8b-instruct",
|
||||
"Meta-Llama3.1-70B-Instruct": "databricks-meta-llama-3-1-70b-instruct",
|
||||
"Meta-Llama3.1-405B-Instruct": "databricks-meta-llama-3-1-405b-instruct",
|
||||
"Llama3.1-70B-Instruct": "databricks-meta-llama-3-1-70b-instruct",
|
||||
"Llama3.1-405B-Instruct": "databricks-meta-llama-3-1-405b-instruct",
|
||||
}
|
||||
|
||||
|
||||
|
@ -35,8 +36,8 @@ class DatabricksInferenceAdapter(Inference):
|
|||
@property
|
||||
def client(self) -> OpenAI:
|
||||
return OpenAI(
|
||||
api_key=self.config.api_token,
|
||||
base_url=self.config.url
|
||||
base_url=self.config.url,
|
||||
api_key=self.config.api_token
|
||||
)
|
||||
|
||||
async def initialize(self) -> None:
|
||||
|
@ -45,6 +46,11 @@ class DatabricksInferenceAdapter(Inference):
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def validate_routing_keys(self, routing_keys: list[str]) -> None:
|
||||
# these are the model names the Llama Stack will use to route requests to this provider
|
||||
# perform validation here if necessary
|
||||
pass
|
||||
|
||||
async def completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
@ -91,7 +97,6 @@ class DatabricksInferenceAdapter(Inference):
|
|||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
# wrapper request to make it easier to pass around (internal only, not exposed to API)
|
||||
request = ChatCompletionRequest(
|
||||
model=model,
|
||||
messages=messages,
|
||||
|
@ -103,40 +108,22 @@ class DatabricksInferenceAdapter(Inference):
|
|||
logprobs=logprobs,
|
||||
)
|
||||
|
||||
# accumulate sampling params and other options to pass to databricks
|
||||
messages = augment_messages_for_tools(request)
|
||||
options = self.get_databricks_chat_options(request)
|
||||
databricks_model = self.resolve_databricks_model(request.model)
|
||||
messages = prepare_messages(request)
|
||||
model_input = self.formatter.encode_dialog_prompt(messages)
|
||||
prompt = self.tokenizer.decode(model_input.tokens)
|
||||
|
||||
input_tokens = len(model_input.tokens)
|
||||
max_new_tokens = min(
|
||||
request.sampling_params.max_tokens or (self.max_tokens - input_tokens),
|
||||
self.max_tokens - input_tokens - 1,
|
||||
)
|
||||
|
||||
print(f"Calculated max_new_tokens: {max_new_tokens}")
|
||||
|
||||
assert (
|
||||
request.model == self.model_name
|
||||
), f"Model mismatch, expected {self.model_name}, got {request.model}"
|
||||
|
||||
if not request.stream:
|
||||
# TODO: might need to add back an async here
|
||||
|
||||
r = self.client.chat.completions.create(
|
||||
model=databricks_model,
|
||||
messages=self._messages_to_databricks_messages(messages),
|
||||
max_tokens=max_new_tokens,
|
||||
stream=False,
|
||||
**options,
|
||||
)
|
||||
|
||||
stop_reason = None
|
||||
if r.choices[0].finish_reason:
|
||||
if (
|
||||
r.choices[0].finish_reason == "stop"
|
||||
or r.choices[0].finish_reason == "eos"
|
||||
):
|
||||
if r.choices[0].finish_reason == "stop":
|
||||
stop_reason = StopReason.end_of_turn
|
||||
elif r.choices[0].finish_reason == "length":
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
|
@ -163,15 +150,13 @@ class DatabricksInferenceAdapter(Inference):
|
|||
for chunk in self.client.chat.completions.create(
|
||||
model=databricks_model,
|
||||
messages=self._messages_to_databricks_messages(messages),
|
||||
max_tokens=max_new_tokens,
|
||||
stream=True,
|
||||
**options,
|
||||
):
|
||||
if chunk.choices[0].finish_reason:
|
||||
if (
|
||||
stop_reason is None and chunk.choices[0].finish_reason == "stop"
|
||||
) or (
|
||||
stop_reason is None and chunk.choices[0].finish_reason == "eos"
|
||||
stop_reason is None
|
||||
and chunk.choices[0].finish_reason == "stop"
|
||||
):
|
||||
stop_reason = StopReason.end_of_turn
|
||||
elif (
|
||||
|
@ -181,7 +166,8 @@ class DatabricksInferenceAdapter(Inference):
|
|||
stop_reason = StopReason.out_of_tokens
|
||||
break
|
||||
|
||||
text = chunk.choices[0].message.content
|
||||
text = chunk.choices[0].delta.content
|
||||
|
||||
if text is None:
|
||||
continue
|
||||
|
||||
|
@ -268,4 +254,4 @@ class DatabricksInferenceAdapter(Inference):
|
|||
delta="",
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
)
|
||||
)
|
|
@ -68,7 +68,7 @@ def available_providers() -> List[ProviderSpec]:
|
|||
remote_provider_spec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_id="databricks",
|
||||
adapter_type="databricks",
|
||||
pip_packages=[
|
||||
"openai",
|
||||
],
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue