mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
Added support for structured output in the API and added a reference implementation for meta-reference. A few notes: * Two formats are specified in the API: JSON schema and EBNF-based grammar. * The implementation only supports JSON for now. We use lm-format-enforcer to provide the implementation right now, but may change this, especially because BNF grammars aren't supported by that library. Fireworks has support for structured output, and Together has limited support for it too. Subsequent PRs will add these changes. We would like all our inference providers to provide structured output for Llama models, since it is an extremely important and highly sought-after need for developers.
127 lines
4.2 KiB
Python
127 lines
4.2 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
from typing import AsyncGenerator

from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message
from llama_models.llama3.api.tokenizer import Tokenizer

from openai import AsyncOpenAI, OpenAI

from llama_stack.apis.inference import *  # noqa: F403

from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (
    get_sampling_options,
    process_chat_completion_response,
    process_chat_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
    chat_completion_request_to_prompt,
)

from .config import DatabricksImplConfig
|
|
|
|
|
|
# Mapping from llama-stack model names (keys) to the model identifiers sent
# to Databricks (values). The adapter hands this table to ModelRegistryHelper
# as `stack_to_provider_models_map`, and the mapped value becomes the `model`
# field of each outgoing request.
DATABRICKS_SUPPORTED_MODELS = dict(
    [
        ("Llama3.1-70B-Instruct", "databricks-meta-llama-3-1-70b-instruct"),
        ("Llama3.1-405B-Instruct", "databricks-meta-llama-3-1-405b-instruct"),
    ]
)
|
|
|
|
|
|
class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
    """Inference provider that runs Llama chat requests on a Databricks endpoint.

    Chat requests are rendered locally into a raw Llama prompt (via
    ``ChatFormat``). The prompt is sent through the OpenAI-compatible
    ``completions`` API at ``config.url``. The result is converted back into a
    llama-stack response by the shared ``openai_compat`` helpers.

    Plain ``completion`` and ``embeddings`` are not implemented.
    """

    def __init__(self, config: DatabricksImplConfig) -> None:
        ModelRegistryHelper.__init__(
            self, stack_to_provider_models_map=DATABRICKS_SUPPORTED_MODELS
        )
        self.config = config
        # Shared by prompt construction (_get_params) and response decoding
        # (process_chat_completion_response / ..._stream_response).
        self.formatter = ChatFormat(Tokenizer.get_instance())

    async def initialize(self) -> None:
        """No setup is needed; a client is created for each chat request."""
        return

    async def shutdown(self) -> None:
        """Nothing to release; this adapter holds no long-lived client."""
        pass

    async def completion(
        self,
        model: str,
        content: InterleavedTextMedia,
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> AsyncGenerator:
        """Not supported by this adapter.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError()

    async def chat_completion(
        self,
        model: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> AsyncGenerator:
        """Run a chat completion against the configured Databricks endpoint.

        Returns:
            When ``stream`` is truthy, an async generator of stream chunks.
            Otherwise, the ``ChatCompletionResponse`` for the request.

        Raises:
            NotImplementedError: if ``response_format`` is given. This adapter
                sends an unconstrained raw prompt, so it cannot honour it.
        """
        if response_format is not None:
            # Nothing downstream (_get_params) translates a response format
            # into provider parameters. Silently ignoring it would hand the
            # caller free-form text when structured output was requested.
            raise NotImplementedError(
                "response_format is not supported by the Databricks inference adapter"
            )

        request = ChatCompletionRequest(
            model=model,
            messages=messages,
            sampling_params=sampling_params,
            tools=tools or [],
            tool_choice=tool_choice,
            tool_prompt_format=tool_prompt_format,
            stream=stream,
            logprobs=logprobs,
        )

        # Use the async client so network I/O does not block the event loop.
        client = AsyncOpenAI(base_url=self.config.url, api_key=self.config.api_token)
        if stream:
            return self._stream_chat_completion(request, client)
        return await self._nonstream_chat_completion(request, client)

    async def _nonstream_chat_completion(
        self, request: ChatCompletionRequest, client: AsyncOpenAI
    ) -> ChatCompletionResponse:
        """Issue one completions call and decode it into a chat response."""
        params = self._get_params(request)
        r = await client.completions.create(**params)
        return process_chat_completion_response(r, self.formatter)

    async def _stream_chat_completion(
        self, request: ChatCompletionRequest, client: AsyncOpenAI
    ) -> AsyncGenerator:
        """Issue a streaming completions call and yield decoded chat chunks."""
        params = self._get_params(request)
        # On this path params["stream"] is truthy, so the async client returns
        # an async iterator of chunks. The shared helper consumes it directly
        # with `async for`.
        stream = await client.completions.create(**params)
        async for chunk in process_chat_completion_stream_response(
            stream, self.formatter
        ):
            yield chunk

    def _get_params(self, request: ChatCompletionRequest) -> dict:
        """Build the keyword arguments for ``client.completions.create``.

        The stack model name is mapped to the provider's model id. The chat
        messages (and tools) are rendered into a single raw prompt string.
        Sampling options are merged in by ``get_sampling_options``.
        """
        return {
            "model": self.map_to_provider_model(request.model),
            "prompt": chat_completion_request_to_prompt(request, self.formatter),
            "stream": request.stream,
            **get_sampling_options(request),
        }

    async def embeddings(
        self,
        model: str,
        contents: List[InterleavedTextMedia],
    ) -> EmbeddingsResponse:
        """Not supported by this adapter.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError()
|