mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00

update provider and test

This commit is contained in:
parent 7c30808167
commit 608e827d36

5 changed files with 34 additions and 56 deletions

@@ -7,4 +7,4 @@ distribution_spec:
     safety: meta-reference
     agents: meta-reference
     telemetry: meta-reference
 image_type: conda

@@ -5,20 +5,12 @@
 # the root directory of this source tree.
 
 from .config import DatabricksImplConfig
-from .databricks import InferenceEndpointAdapter, DatabricksAdapter
+from .databricks import DatabricksInferenceAdapter
 
 
 async def get_adapter_impl(config: DatabricksImplConfig, _deps):
-    assert isinstance(config, DatabricksImplConfig), f"Unexpected config type: {type(config)}"
-    if config.url is not None:
-        impl = DatabricksAdapter(config)
-    elif config.is_inference_endpoint():
-        impl = InferenceEndpointAdapter(config)
-    else:
-        raise ValueError(
-            "Invalid configuration. Specify either an URL or HF Inference Endpoint details (namespace and endpoint name)."
-        )
+    assert isinstance(
+        config, DatabricksImplConfig
+    ), f"Unexpected config type: {type(config)}"
+    impl = DatabricksInferenceAdapter(config)
     await impl.initialize()
     return impl
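
With this change the HF Inference Endpoint branch is gone and construction is unconditional: every Databricks config yields a DatabricksInferenceAdapter. A minimal usage sketch, assuming the provider package lives at the usual llama_stack.providers.adapters.inference.databricks path; the endpoint URL and token are placeholders, not values from the commit:

    import asyncio

    from llama_stack.providers.adapters.inference.databricks import (
        DatabricksImplConfig,
        get_adapter_impl,
    )

    async def main() -> None:
        # Placeholder endpoint and token, for illustration only.
        config = DatabricksImplConfig(
            url="https://example.cloud.databricks.com/serving-endpoints",
            api_token="dapi-example-token",
        )
        impl = await get_adapter_impl(config, None)
        print(type(impl).__name__)  # DatabricksInferenceAdapter

    asyncio.run(main())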

@@ -12,11 +12,11 @@ from pydantic import BaseModel, Field
 
 @json_schema_type
 class DatabricksImplConfig(BaseModel):
-    url: Optional[str] = Field(
+    url: str = Field(
         default=None,
         description="The URL for the Databricks model serving endpoint",
     )
-    api_token: Optional[str] = Field(
+    api_token: str = Field(
         default=None,
         description="The Databricks API token",
     )
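
Both fields are now annotated str rather than Optional[str]: the serving-endpoint URL and the API token are treated as required. Note that default=None survives the change, and Pydantic does not validate defaults unless configured to, so the annotation documents intent more than it enforces it. What callers are now expected to supply, sketched with placeholder values (the module path is an assumption):

    from llama_stack.providers.adapters.inference.databricks.config import (
        DatabricksImplConfig,
    )

    # Placeholder values for illustration.
    config = DatabricksImplConfig(
        url="https://example.cloud.databricks.com/serving-endpoints",
        api_token="dapi-example-token",
    )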

@@ -6,23 +6,24 @@
 
 from typing import AsyncGenerator
 
+from openai import OpenAI
+
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import Message, StopReason
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.sku_list import resolve_model
 
-from openai import OpenAI
-
 from llama_stack.apis.inference import *  # noqa: F403
-from llama_stack.providers.utils.inference.prepare_messages import prepare_messages
+from llama_stack.providers.utils.inference.augment_messages import (
+    augment_messages_for_tools,
+)
 
 from .config import DatabricksImplConfig
 
 DATABRICKS_SUPPORTED_MODELS = {
-    "Meta-Llama3.1-8B-Instruct": "databricks-meta-llama-3-1-8b-instruct",
-    "Meta-Llama3.1-70B-Instruct": "databricks-meta-llama-3-1-70b-instruct",
-    "Meta-Llama3.1-405B-Instruct": "databricks-meta-llama-3-1-405b-instruct",
+    "Llama3.1-70B-Instruct": "databricks-meta-llama-3-1-70b-instruct",
+    "Llama3.1-405B-Instruct": "databricks-meta-llama-3-1-405b-instruct",
 }
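
The supported-model keys drop the Meta- prefix and the 8B entry. The adapter's resolve_databricks_model is referenced later but not shown in this diff; a hedged sketch of the lookup this map supports, using resolve_model from llama_models.sku_list as imported above:

    # Sketch only; the real resolve_databricks_model is not in this diff.
    from llama_models.sku_list import resolve_model

    DATABRICKS_SUPPORTED_MODELS = {
        "Llama3.1-70B-Instruct": "databricks-meta-llama-3-1-70b-instruct",
        "Llama3.1-405B-Instruct": "databricks-meta-llama-3-1-405b-instruct",
    }

    def resolve_databricks_model(model_name: str) -> str:
        model = resolve_model(model_name)
        assert model is not None, f"Unknown model: {model_name}"
        descriptor = model.descriptor()
        assert (
            descriptor in DATABRICKS_SUPPORTED_MODELS
        ), f"Model {descriptor} not supported by Databricks"
        return DATABRICKS_SUPPORTED_MODELS[descriptor]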

@@ -35,8 +36,8 @@ class DatabricksInferenceAdapter(Inference):
     @property
     def client(self) -> OpenAI:
         return OpenAI(
-            api_key=self.config.api_token,
-            base_url=self.config.url
+            base_url=self.config.url,
+            api_key=self.config.api_token
         )
 
     async def initialize(self) -> None:
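
Only the keyword order changes here, but this property is the crux of the adapter: Databricks model serving speaks the OpenAI chat-completions protocol, so the stock openai client works once base_url points at the workspace. A standalone sketch with placeholder endpoint, token, and model name:

    from openai import OpenAI

    # Placeholder endpoint and token, for illustration only.
    client = OpenAI(
        base_url="https://example.cloud.databricks.com/serving-endpoints",
        api_key="dapi-example-token",
    )
    r = client.chat.completions.create(
        model="databricks-meta-llama-3-1-70b-instruct",
        messages=[{"role": "user", "content": "Hello"}],
        stream=False,
    )
    print(r.choices[0].message.content)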

@@ -45,6 +46,11 @@ class DatabricksInferenceAdapter(Inference):
     async def shutdown(self) -> None:
         pass
 
+    async def validate_routing_keys(self, routing_keys: list[str]) -> None:
+        # these are the model names the Llama Stack will use to route requests to this provider
+        # perform validation here if necessary
+        pass
+
     async def completion(self, request: CompletionRequest) -> AsyncGenerator:
         raise NotImplementedError()
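
The new hook is deliberately a no-op. A stricter variant, not part of the commit, could reject routing keys that are not in the supported-model map:

    # Drop-in replacement for the method above (illustration only),
    # assuming the DATABRICKS_SUPPORTED_MODELS map defined earlier.
    DATABRICKS_SUPPORTED_MODELS = {
        "Llama3.1-70B-Instruct": "databricks-meta-llama-3-1-70b-instruct",
        "Llama3.1-405B-Instruct": "databricks-meta-llama-3-1-405b-instruct",
    }

    async def validate_routing_keys(self, routing_keys: list[str]) -> None:
        for key in routing_keys:
            if key not in DATABRICKS_SUPPORTED_MODELS:
                raise ValueError(
                    f"Routing key {key!r} is not a supported Databricks model"
                )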

@@ -91,7 +97,6 @@ class DatabricksInferenceAdapter(Inference):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
-        # wrapper request to make it easier to pass around (internal only, not exposed to API)
         request = ChatCompletionRequest(
             model=model,
             messages=messages,
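
For orientation, a hedged sketch of how a caller might drive chat_completion once the adapter is initialized; the keywords mirror the signature above, and UserMessage is the standard message type from llama_models:

    # Usage sketch, not from the commit; assumes `impl` was built via
    # get_adapter_impl as shown earlier.
    from llama_models.llama3.api.datatypes import UserMessage

    async def ask(impl) -> None:
        async for event in impl.chat_completion(
            model="Llama3.1-70B-Instruct",
            messages=[UserMessage(content="Say hello.")],
            stream=True,
        ):
            print(event)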

@@ -103,40 +108,22 @@ class DatabricksInferenceAdapter(Inference):
             logprobs=logprobs,
         )
 
-        # accumulate sampling params and other options to pass to databricks
+        messages = augment_messages_for_tools(request)
         options = self.get_databricks_chat_options(request)
         databricks_model = self.resolve_databricks_model(request.model)
-        messages = prepare_messages(request)
-        model_input = self.formatter.encode_dialog_prompt(messages)
-        prompt = self.tokenizer.decode(model_input.tokens)
-
-        input_tokens = len(model_input.tokens)
-        max_new_tokens = min(
-            request.sampling_params.max_tokens or (self.max_tokens - input_tokens),
-            self.max_tokens - input_tokens - 1,
-        )
-
-        print(f"Calculated max_new_tokens: {max_new_tokens}")
-
-        assert (
-            request.model == self.model_name
-        ), f"Model mismatch, expected {self.model_name}, got {request.model}"
 
         if not request.stream:
-            # TODO: might need to add back an async here
             r = self.client.chat.completions.create(
                 model=databricks_model,
                 messages=self._messages_to_databricks_messages(messages),
-                max_tokens=max_new_tokens,
                 stream=False,
                 **options,
             )
 
             stop_reason = None
             if r.choices[0].finish_reason:
-                if (
-                    r.choices[0].finish_reason == "stop"
-                    or r.choices[0].finish_reason == "eos"
-                ):
+                if r.choices[0].finish_reason == "stop":
                     stop_reason = StopReason.end_of_turn
                 elif r.choices[0].finish_reason == "length":
                     stop_reason = StopReason.out_of_tokens
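
Two things happen above: prompt assembly moves to augment_messages_for_tools, so the manual tokenize/decode and max_new_tokens arithmetic disappear, and the finish_reason check drops "eos", which is not part of the standard OpenAI finish_reason set ("stop" and "length" are). The same mapping recurs in the streaming branch below; one way it could be factored, as an illustration rather than anything in the commit:

    # Illustrative helper, not in the commit.
    from typing import Optional

    from llama_models.llama3.api.datatypes import StopReason

    def map_finish_reason(finish_reason: Optional[str]) -> Optional[StopReason]:
        if finish_reason == "stop":
            return StopReason.end_of_turn
        if finish_reason == "length":
            return StopReason.out_of_tokens
        return None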

@@ -163,15 +150,13 @@ class DatabricksInferenceAdapter(Inference):
             for chunk in self.client.chat.completions.create(
                 model=databricks_model,
                 messages=self._messages_to_databricks_messages(messages),
-                max_tokens=max_new_tokens,
                 stream=True,
                 **options,
             ):
                 if chunk.choices[0].finish_reason:
                     if (
-                        stop_reason is None and chunk.choices[0].finish_reason == "stop"
-                    ) or (
-                        stop_reason is None and chunk.choices[0].finish_reason == "eos"
+                        stop_reason is None
+                        and chunk.choices[0].finish_reason == "stop"
                     ):
                         stop_reason = StopReason.end_of_turn
                     elif (

@@ -181,7 +166,8 @@ class DatabricksInferenceAdapter(Inference):
                         stop_reason = StopReason.out_of_tokens
                     break
 
-                text = chunk.choices[0].message.content
+                text = chunk.choices[0].delta.content
 
                 if text is None:
                     continue
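
This is the substantive bug fix in the streaming path: with stream=True, the OpenAI client yields chunks whose choices carry a delta, not a full message, so the old attribute access could not work. A standalone sketch of consuming such a stream (placeholder endpoint, token, and model):

    from openai import OpenAI

    # Placeholder endpoint and token, for illustration only.
    client = OpenAI(
        base_url="https://example.cloud.databricks.com/serving-endpoints",
        api_key="dapi-example-token",
    )
    pieces = []
    for chunk in client.chat.completions.create(
        model="databricks-meta-llama-3-1-70b-instruct",
        messages=[{"role": "user", "content": "Hello"}],
        stream=True,
    ):
        text = chunk.choices[0].delta.content
        if text is None:  # e.g. a chunk carrying only finish_reason
            continue
        pieces.append(text)
    print("".join(pieces))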

@@ -268,4 +254,4 @@ class DatabricksInferenceAdapter(Inference):
                     delta="",
                     stop_reason=stop_reason,
                 )
             )

@@ -68,7 +68,7 @@ def available_providers() -> List[ProviderSpec]:
         remote_provider_spec(
             api=Api.inference,
             adapter=AdapterSpec(
-                adapter_id="databricks",
+                adapter_type="databricks",
                 pip_packages=[
                     "openai",
                 ],