diff --git a/llama_stack/providers/adapters/inference/databricks/databricks.py b/llama_stack/providers/adapters/inference/databricks/databricks.py
index 6d106ccf1..f318e6180 100644
--- a/llama_stack/providers/adapters/inference/databricks/databricks.py
+++ b/llama_stack/providers/adapters/inference/databricks/databricks.py
@@ -8,16 +8,22 @@ from typing import AsyncGenerator
 
 from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.datatypes import Message, StopReason
+from llama_models.llama3.api.datatypes import Message
 from llama_models.llama3.api.tokenizer import Tokenizer
 
 from openai import OpenAI
 
 from llama_stack.apis.inference import *  # noqa: F403
+
 from llama_stack.providers.utils.inference.augment_messages import (
-    augment_messages_for_tools,
+    chat_completion_request_to_prompt,
 )
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
+from llama_stack.providers.utils.inference.openai_compat import (
+    get_sampling_options,
+    process_chat_completion_response,
+    process_chat_completion_stream_response,
+)
 
 from .config import DatabricksImplConfig
 
@@ -34,12 +40,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
             self, stack_to_provider_models_map=DATABRICKS_SUPPORTED_MODELS
         )
         self.config = config
-        tokenizer = Tokenizer.get_instance()
-        self.formatter = ChatFormat(tokenizer)
-
-    @property
-    def client(self) -> OpenAI:
-        return OpenAI(base_url=self.config.url, api_key=self.config.api_token)
+        self.formatter = ChatFormat(Tokenizer.get_instance())
 
     async def initialize(self) -> None:
         return
@@ -47,35 +48,10 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
     async def shutdown(self) -> None:
         pass
 
-    async def validate_routing_keys(self, routing_keys: list[str]) -> None:
-        # these are the model names the Llama Stack will use to route requests to this provider
-        # perform validation here if necessary
-        pass
-
-    async def completion(self, request: CompletionRequest) -> AsyncGenerator:
+    def completion(self, request: CompletionRequest) -> AsyncGenerator:
        raise NotImplementedError()
 
-    def _messages_to_databricks_messages(self, messages: list[Message]) -> list:
-        databricks_messages = []
-        for message in messages:
-            if message.role == "ipython":
-                role = "tool"
-            else:
-                role = message.role
-            databricks_messages.append({"role": role, "content": message.content})
-
-        return databricks_messages
-
-    def get_databricks_chat_options(self, request: ChatCompletionRequest) -> dict:
-        options = {}
-        if request.sampling_params is not None:
-            for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
-                if getattr(request.sampling_params, attr):
-                    options[attr] = getattr(request.sampling_params, attr)
-
-        return options
-
-    async def chat_completion(
+    def chat_completion(
         self,
         model: str,
         messages: List[Message],
@@ -97,146 +73,39 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
             logprobs=logprobs,
         )
 
-        messages = augment_messages_for_tools(request)
-        options = self.get_databricks_chat_options(request)
-        databricks_model = self.map_to_provider_model(request.model)
-
-        if not request.stream:
-            r = self.client.chat.completions.create(
-                model=databricks_model,
-                messages=self._messages_to_databricks_messages(messages),
-                stream=False,
-                **options,
-            )
-
-            stop_reason = None
-            if r.choices[0].finish_reason:
-                if r.choices[0].finish_reason == "stop":
-                    stop_reason = StopReason.end_of_turn
-                elif r.choices[0].finish_reason == "length":
-                    stop_reason = StopReason.out_of_tokens
-
-            completion_message = self.formatter.decode_assistant_message_from_content(
-                r.choices[0].message.content, stop_reason
-            )
-            yield ChatCompletionResponse(
-                completion_message=completion_message,
-                logprobs=None,
-            )
+        client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
+        if stream:
+            return self._stream_chat_completion(request, client)
         else:
-            yield ChatCompletionResponseStreamChunk(
-                event=ChatCompletionResponseEvent(
-                    event_type=ChatCompletionResponseEventType.start,
-                    delta="",
-                )
-            )
+            return self._nonstream_chat_completion(request, client)
 
-            buffer = ""
-            ipython = False
-            stop_reason = None
+    async def _nonstream_chat_completion(
+        self, request: ChatCompletionRequest, client: OpenAI
+    ) -> ChatCompletionResponse:
+        params = self._get_params(request)
+        r = client.completions.create(**params)
+        return process_chat_completion_response(request, r, self.formatter)
 
-            for chunk in self.client.chat.completions.create(
-                model=databricks_model,
-                messages=self._messages_to_databricks_messages(messages),
-                stream=True,
-                **options,
-            ):
-                if chunk.choices[0].finish_reason:
-                    if stop_reason is None and chunk.choices[0].finish_reason == "stop":
-                        stop_reason = StopReason.end_of_turn
-                    elif (
-                        stop_reason is None
-                        and chunk.choices[0].finish_reason == "length"
-                    ):
-                        stop_reason = StopReason.out_of_tokens
-                    break
+    async def _stream_chat_completion(
+        self, request: ChatCompletionRequest, client: OpenAI
+    ) -> AsyncGenerator:
+        params = self._get_params(request)
 
-                text = chunk.choices[0].delta.content
+        async def _to_async_generator():
+            s = client.completions.create(**params)
+            for chunk in s:
+                yield chunk
 
-                if text is None:
-                    continue
+        stream = _to_async_generator()
+        async for chunk in process_chat_completion_stream_response(
+            request, stream, self.formatter
+        ):
+            yield chunk
 
-                # check if its a tool call ( aka starts with <|python_tag|> )
-                if not ipython and text.startswith("<|python_tag|>"):
-                    ipython = True
-                    yield ChatCompletionResponseStreamChunk(
-                        event=ChatCompletionResponseEvent(
-                            event_type=ChatCompletionResponseEventType.progress,
-                            delta=ToolCallDelta(
-                                content="",
-                                parse_status=ToolCallParseStatus.started,
-                            ),
-                        )
-                    )
-                    buffer += text
-                    continue
-
-                if ipython:
-                    if text == "<|eot_id|>":
-                        stop_reason = StopReason.end_of_turn
-                        text = ""
-                        continue
-                    elif text == "<|eom_id|>":
-                        stop_reason = StopReason.end_of_message
-                        text = ""
-                        continue
-
-                    buffer += text
-                    delta = ToolCallDelta(
-                        content=text,
-                        parse_status=ToolCallParseStatus.in_progress,
-                    )
-
-                    yield ChatCompletionResponseStreamChunk(
-                        event=ChatCompletionResponseEvent(
-                            event_type=ChatCompletionResponseEventType.progress,
-                            delta=delta,
-                            stop_reason=stop_reason,
-                        )
-                    )
-                else:
-                    buffer += text
-                    yield ChatCompletionResponseStreamChunk(
-                        event=ChatCompletionResponseEvent(
-                            event_type=ChatCompletionResponseEventType.progress,
-                            delta=text,
-                            stop_reason=stop_reason,
-                        )
-                    )
-
-            # parse tool calls and report errors
-            message = self.formatter.decode_assistant_message_from_content(
-                buffer, stop_reason
-            )
-            parsed_tool_calls = len(message.tool_calls) > 0
-            if ipython and not parsed_tool_calls:
-                yield ChatCompletionResponseStreamChunk(
-                    event=ChatCompletionResponseEvent(
-                        event_type=ChatCompletionResponseEventType.progress,
-                        delta=ToolCallDelta(
-                            content="",
-                            parse_status=ToolCallParseStatus.failure,
-                        ),
-                        stop_reason=stop_reason,
-                    )
-                )
-
-            for tool_call in message.tool_calls:
-                yield ChatCompletionResponseStreamChunk(
-                    event=ChatCompletionResponseEvent(
-                        event_type=ChatCompletionResponseEventType.progress,
-                        delta=ToolCallDelta(
-                            content=tool_call,
-                            parse_status=ToolCallParseStatus.success,
-                        ),
-                        stop_reason=stop_reason,
-                    )
-                )
-
-            yield ChatCompletionResponseStreamChunk(
-                event=ChatCompletionResponseEvent(
-                    event_type=ChatCompletionResponseEventType.complete,
-                    delta="",
-                    stop_reason=stop_reason,
-                )
-            )
+    def _get_params(self, request: ChatCompletionRequest) -> dict:
+        return {
+            "model": self.map_to_provider_model(request.model),
+            "prompt": chat_completion_request_to_prompt(request, self.formatter),
+            "stream": request.stream,
+            **get_sampling_options(request),
+        }
diff --git a/llama_stack/providers/adapters/inference/tgi/tgi.py b/llama_stack/providers/adapters/inference/tgi/tgi.py
index 0ad20edd6..bd05f98bb 100644
--- a/llama_stack/providers/adapters/inference/tgi/tgi.py
+++ b/llama_stack/providers/adapters/inference/tgi/tgi.py
@@ -10,13 +10,19 @@ from typing import AsyncGenerator
 
 from huggingface_hub import AsyncInferenceClient, HfApi
 from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.datatypes import StopReason
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.sku_list import resolve_model
 
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.providers.utils.inference.augment_messages import (
-    augment_messages_for_tools,
+    chat_completion_request_to_model_input_info,
+)
+from llama_stack.providers.utils.inference.openai_compat import (
+    get_sampling_options,
+    OpenAICompatCompletionChoice,
+    OpenAICompatCompletionResponse,
+    process_chat_completion_response,
+    process_chat_completion_stream_response,
 )
 
 from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
@@ -30,8 +36,7 @@ class _HfAdapter(Inference):
     model_id: str
 
     def __init__(self) -> None:
-        self.tokenizer = Tokenizer.get_instance()
-        self.formatter = ChatFormat(self.tokenizer)
+        self.formatter = ChatFormat(Tokenizer.get_instance())
 
     async def register_model(self, model: ModelDef) -> None:
         resolved_model = resolve_model(model.identifier)
@@ -49,7 +54,7 @@ class _HfAdapter(Inference):
     async def shutdown(self) -> None:
         pass
 
-    async def completion(
+    def completion(
         self,
         model: str,
         content: InterleavedTextMedia,
@@ -59,16 +64,7 @@ class _HfAdapter(Inference):
     ) -> AsyncGenerator:
         raise NotImplementedError()
 
-    def get_chat_options(self, request: ChatCompletionRequest) -> dict:
-        options = {}
-        if request.sampling_params is not None:
-            for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
-                if getattr(request.sampling_params, attr):
-                    options[attr] = getattr(request.sampling_params, attr)
-
-        return options
-
-    async def chat_completion(
+    def chat_completion(
         self,
         model: str,
         messages: List[Message],
@@ -90,145 +86,64 @@ class _HfAdapter(Inference):
             logprobs=logprobs,
         )
 
-        messages = augment_messages_for_tools(request)
-        model_input = self.formatter.encode_dialog_prompt(messages)
-        prompt = self.tokenizer.decode(model_input.tokens)
+        if stream:
+            return self._stream_chat_completion(request)
+        else:
+            return self._nonstream_chat_completion(request)
 
-        input_tokens = len(model_input.tokens)
+    async def _nonstream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> ChatCompletionResponse:
+        params = self._get_params(request)
+        r = await self.client.text_generation(**params)
+
+        choice = OpenAICompatCompletionChoice(
+            finish_reason=r.details.finish_reason,
+            text="".join(t.text for t in r.details.tokens),
+        )
+        response = OpenAICompatCompletionResponse(
+            choices=[choice],
+        )
+        return process_chat_completion_response(request, response, self.formatter)
+
+    async def _stream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> AsyncGenerator:
+        params = self._get_params(request)
+
+        async def _generate_and_convert_to_openai_compat():
+            s = await self.client.text_generation(**params)
+            async for chunk in s:
+                token_result = chunk.token
+
+                choice = OpenAICompatCompletionChoice(text=token_result.text)
+                yield OpenAICompatCompletionResponse(
+                    choices=[choice],
+                )
+
+        stream = _generate_and_convert_to_openai_compat()
+        async for chunk in process_chat_completion_stream_response(
+            request, stream, self.formatter
+        ):
+            yield chunk
+
+    def _get_params(self, request: ChatCompletionRequest) -> dict:
+        prompt, input_tokens = chat_completion_request_to_model_input_info(
+            request, self.formatter
+        )
         max_new_tokens = min(
             request.sampling_params.max_tokens or (self.max_tokens - input_tokens),
             self.max_tokens - input_tokens - 1,
         )
-
-        options = self.get_chat_options(request)
-        if not request.stream:
-            response = await self.client.text_generation(
-                prompt=prompt,
-                stream=False,
-                details=True,
-                max_new_tokens=max_new_tokens,
-                stop_sequences=["<|eom_id|>", "<|eot_id|>"],
-                **options,
-            )
-            stop_reason = None
-            if response.details.finish_reason:
-                if response.details.finish_reason in ["stop", "eos_token"]:
-                    stop_reason = StopReason.end_of_turn
-                elif response.details.finish_reason == "length":
-                    stop_reason = StopReason.out_of_tokens
-
-            generated_text = "".join(t.text for t in response.details.tokens)
-            completion_message = self.formatter.decode_assistant_message_from_content(
-                generated_text,
-                stop_reason,
-            )
-            yield ChatCompletionResponse(
-                completion_message=completion_message,
-                logprobs=None,
-            )
-
-        else:
-            yield ChatCompletionResponseStreamChunk(
-                event=ChatCompletionResponseEvent(
-                    event_type=ChatCompletionResponseEventType.start,
-                    delta="",
-                )
-            )
-            buffer = ""
-            ipython = False
-            stop_reason = None
-            tokens = []
-
-            async for response in await self.client.text_generation(
-                prompt=prompt,
-                stream=True,
-                details=True,
-                max_new_tokens=max_new_tokens,
-                stop_sequences=["<|eom_id|>", "<|eot_id|>"],
-                **options,
-            ):
-                token_result = response.token
-
-                buffer += token_result.text
-                tokens.append(token_result.id)
-
-                if not ipython and buffer.startswith("<|python_tag|>"):
-                    ipython = True
-                    yield ChatCompletionResponseStreamChunk(
-                        event=ChatCompletionResponseEvent(
-                            event_type=ChatCompletionResponseEventType.progress,
-                            delta=ToolCallDelta(
-                                content="",
-                                parse_status=ToolCallParseStatus.started,
-                            ),
-                        )
-                    )
-                    buffer = buffer[len("<|python_tag|>") :]
-                    continue
-
-                if token_result.text == "<|eot_id|>":
-                    stop_reason = StopReason.end_of_turn
-                    text = ""
-                elif token_result.text == "<|eom_id|>":
-                    stop_reason = StopReason.end_of_message
-                    text = ""
-                else:
-                    text = token_result.text
-
-                if ipython:
-                    delta = ToolCallDelta(
-                        content=text,
-                        parse_status=ToolCallParseStatus.in_progress,
-                    )
-                else:
-                    delta = text
-
-                if stop_reason is None:
-                    yield ChatCompletionResponseStreamChunk(
-                        event=ChatCompletionResponseEvent(
-                            event_type=ChatCompletionResponseEventType.progress,
-                            delta=delta,
-                            stop_reason=stop_reason,
-                        )
-                    )
-
-            if stop_reason is None:
-                stop_reason = StopReason.out_of_tokens
-
-            # parse tool calls and report errors
-            message = self.formatter.decode_assistant_message(tokens, stop_reason)
-            parsed_tool_calls = len(message.tool_calls) > 0
-            if ipython and not parsed_tool_calls:
-                yield ChatCompletionResponseStreamChunk(
-                    event=ChatCompletionResponseEvent(
-                        event_type=ChatCompletionResponseEventType.progress,
-                        delta=ToolCallDelta(
-                            content="",
-                            parse_status=ToolCallParseStatus.failure,
-                        ),
-                        stop_reason=stop_reason,
-                    )
-                )
-
-            for tool_call in message.tool_calls:
-                yield ChatCompletionResponseStreamChunk(
-                    event=ChatCompletionResponseEvent(
-                        event_type=ChatCompletionResponseEventType.progress,
-                        delta=ToolCallDelta(
-                            content=tool_call,
-                            parse_status=ToolCallParseStatus.success,
-                        ),
-                        stop_reason=stop_reason,
-                    )
-                )
-
-            yield ChatCompletionResponseStreamChunk(
-                event=ChatCompletionResponseEvent(
-                    event_type=ChatCompletionResponseEventType.complete,
-                    delta="",
-                    stop_reason=stop_reason,
-                )
-            )
+        options = get_sampling_options(request)
+        return dict(
+            prompt=prompt,
+            stream=request.stream,
+            details=True,
+            max_new_tokens=max_new_tokens,
+            stop_sequences=["<|eom_id|>", "<|eot_id|>"],
+            **options,
+        )
 
 
 class TGIAdapter(_HfAdapter):
diff --git a/llama_stack/providers/adapters/inference/together/together.py b/llama_stack/providers/adapters/inference/together/together.py
index d9a9ae491..adea696fb 100644
--- a/llama_stack/providers/adapters/inference/together/together.py
+++ b/llama_stack/providers/adapters/inference/together/together.py
@@ -48,10 +48,6 @@ class TogetherInferenceAdapter(
         self.config = config
         self.formatter = ChatFormat(Tokenizer.get_instance())
 
-    @property
-    def client(self) -> Together:
-        return Together(api_key=self.config.api_key)
-
     async def initialize(self) -> None:
         return
 
@@ -91,7 +87,6 @@ class TogetherInferenceAdapter(
             together_api_key = provider_data.together_api_key
 
         client = Together(api_key=together_api_key)
-        # wrapper request to make it easier to pass around (internal only, not exposed to API)
         request = ChatCompletionRequest(
             model=model,
             messages=messages,
diff --git a/llama_stack/providers/tests/inference/test_inference.py b/llama_stack/providers/tests/inference/test_inference.py
index 6b12a54e6..107a534d5 100644
--- a/llama_stack/providers/tests/inference/test_inference.py
+++ b/llama_stack/providers/tests/inference/test_inference.py
@@ -55,8 +55,8 @@ def get_expected_stop_reason(model: str):
 @pytest_asyncio.fixture(
     scope="session",
     params=[
-        # {"model": Llama_8B},
-        {"model": Llama_3B},
+        {"model": Llama_8B},
+        # {"model": Llama_3B},
     ],
     ids=lambda d: d["model"],
 )
diff --git a/llama_stack/providers/utils/inference/augment_messages.py b/llama_stack/providers/utils/inference/augment_messages.py
index a69b80d7b..8f59b5295 100644
--- a/llama_stack/providers/utils/inference/augment_messages.py
+++ b/llama_stack/providers/utils/inference/augment_messages.py
@@ -3,8 +3,11 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+from typing import Tuple
+
 from llama_models.llama3.api.chat_format import ChatFormat
 from termcolor import cprint
+
 from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_models.datatypes import ModelFamily
@@ -28,6 +31,17 @@ def chat_completion_request_to_prompt(
     return formatter.tokenizer.decode(model_input.tokens)
 
 
+def chat_completion_request_to_model_input_info(
+    request: ChatCompletionRequest, formatter: ChatFormat
+) -> Tuple[str, int]:
+    messages = augment_messages_for_tools(request)
+    model_input = formatter.encode_dialog_prompt(messages)
+    return (
+        formatter.tokenizer.decode(model_input.tokens),
+        len(model_input.tokens),
+    )
+
+
 def augment_messages_for_tools(request: ChatCompletionRequest) -> List[Message]:
     """Reads chat completion request and augments the messages to handle tools.
     For eg. for llama_3_1, add system message with the appropriate tools or
diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py
index a39002976..118880b29 100644
--- a/llama_stack/providers/utils/inference/openai_compat.py
+++ b/llama_stack/providers/utils/inference/openai_compat.py
@@ -60,6 +60,8 @@ def process_chat_completion_response(
     if reason := choice.finish_reason:
         if reason in ["stop", "eos"]:
             stop_reason = StopReason.end_of_turn
+        elif reason == "eom":
+            stop_reason = StopReason.end_of_message
         elif reason == "length":
             stop_reason = StopReason.out_of_tokens
 
@@ -96,7 +98,7 @@ async def process_chat_completion_stream_response(
         finish_reason = choice.finish_reason
         if finish_reason:
-            if stop_reason is None and finish_reason in ["stop", "eos"]:
+            if stop_reason is None and finish_reason in ["stop", "eos", "eos_token"]:
                 stop_reason = StopReason.end_of_turn
             elif stop_reason is None and finish_reason == "length":
                 stop_reason = StopReason.out_of_tokens
 
@@ -118,16 +120,16 @@ async def process_chat_completion_stream_response(
             buffer += text
             continue
 
-        if ipython:
-            if text == "<|eot_id|>":
-                stop_reason = StopReason.end_of_turn
-                text = ""
-                continue
-            elif text == "<|eom_id|>":
-                stop_reason = StopReason.end_of_message
-                text = ""
-                continue
+        if text == "<|eot_id|>":
+            stop_reason = StopReason.end_of_turn
+            text = ""
+            continue
+        elif text == "<|eom_id|>":
+            stop_reason = StopReason.end_of_message
+            text = ""
+            continue
 
+        if ipython:
             buffer += text
             delta = ToolCallDelta(
                 content=text,