diff --git a/llama_stack/providers/adapters/inference/fireworks/fireworks.py b/llama_stack/providers/adapters/inference/fireworks/fireworks.py index 654cd345c..ce57480a0 100644 --- a/llama_stack/providers/adapters/inference/fireworks/fireworks.py +++ b/llama_stack/providers/adapters/inference/fireworks/fireworks.py @@ -10,14 +10,19 @@ from fireworks.client import Fireworks from llama_models.llama3.api.chat_format import ChatFormat -from llama_models.llama3.api.datatypes import Message, StopReason +from llama_models.llama3.api.datatypes import Message from llama_models.llama3.api.tokenizer import Tokenizer -from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper - from llama_stack.apis.inference import * # noqa: F403 + from llama_stack.providers.utils.inference.augment_messages import ( - augment_messages_for_tools, + chat_completion_request_to_prompt, +) +from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper +from llama_stack.providers.utils.inference.openai_compat import ( + get_sampling_options, + process_chat_completion_response, + process_chat_completion_stream_response, ) from .config import FireworksImplConfig @@ -38,12 +43,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference): self, stack_to_provider_models_map=FIREWORKS_SUPPORTED_MODELS ) self.config = config - self.tokenizer = Tokenizer.get_instance() - self.formatter = ChatFormat(self.tokenizer) - - @property - def client(self) -> Fireworks: - return Fireworks(api_key=self.config.api_key) + self.formatter = ChatFormat(Tokenizer.get_instance()) async def initialize(self) -> None: return @@ -51,7 +51,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference): async def shutdown(self) -> None: pass - async def completion( + def completion( self, model: str, content: InterleavedTextMedia, @@ -61,16 +61,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference): ) -> AsyncGenerator: raise NotImplementedError() - def get_fireworks_chat_options(self, request: ChatCompletionRequest) -> dict: - options = {} - if request.sampling_params is not None: - for attr in {"temperature", "top_p", "top_k", "max_tokens"}: - if getattr(request.sampling_params, attr): - options[attr] = getattr(request.sampling_params, attr) - - return options - - async def chat_completion( + def chat_completion( self, model: str, messages: List[Message], @@ -92,154 +83,41 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference): logprobs=logprobs, ) - messages = augment_messages_for_tools(request) - model_input = self.formatter.encode_dialog_prompt(messages) - prompt = self.tokenizer.decode(model_input.tokens) + client = Fireworks(api_key=self.config.api_key) + if stream: + return self._stream_chat_completion(request, client) + else: + return self._nonstream_chat_completion(request, client) + + async def _nonstream_chat_completion( + self, request: ChatCompletionRequest, client: Fireworks + ) -> ChatCompletionResponse: + params = self._get_params(request) + r = await client.completion.acreate(**params) + return process_chat_completion_response(request, r, self.formatter) + + async def _stream_chat_completion( + self, request: ChatCompletionRequest, client: Fireworks + ) -> AsyncGenerator: + params = self._get_params(request) + + stream = client.completion.acreate(**params) + async for chunk in process_chat_completion_stream_response( + request, stream, self.formatter + ): + yield chunk + + def _get_params(self, request: ChatCompletionRequest) -> dict: + prompt = 
chat_completion_request_to_prompt(request, self.formatter) # Fireworks always prepends with BOS if prompt.startswith("<|begin_of_text|>"): prompt = prompt[len("<|begin_of_text|>") :] - # accumulate sampling params and other options to pass to fireworks - options = self.get_fireworks_chat_options(request) + options = get_sampling_options(request) options.setdefault("max_tokens", 512) - - fireworks_model = self.map_to_provider_model(request.model) - - if not request.stream: - r = await self.client.completion.acreate( - model=fireworks_model, - prompt=prompt, - stream=False, - **options, - ) - stop_reason = None - if r.choices[0].finish_reason: - if r.choices[0].finish_reason == "stop": - stop_reason = StopReason.end_of_turn - elif r.choices[0].finish_reason == "length": - stop_reason = StopReason.out_of_tokens - - completion_message = self.formatter.decode_assistant_message_from_content( - r.choices[0].text, stop_reason - ) - - yield ChatCompletionResponse( - completion_message=completion_message, - logprobs=None, - ) - else: - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.start, - delta="", - ) - ) - - buffer = "" - ipython = False - stop_reason = None - - async for chunk in self.client.completion.acreate( - model=fireworks_model, - prompt=prompt, - stream=True, - **options, - ): - if chunk.choices[0].finish_reason: - if stop_reason is None and chunk.choices[0].finish_reason == "stop": - stop_reason = StopReason.end_of_turn - elif ( - stop_reason is None - and chunk.choices[0].finish_reason == "length" - ): - stop_reason = StopReason.out_of_tokens - break - - text = chunk.choices[0].text - if text is None: - continue - - # check if its a tool call ( aka starts with <|python_tag|> ) - if not ipython and text.startswith("<|python_tag|>"): - ipython = True - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=ToolCallDelta( - content="", - parse_status=ToolCallParseStatus.started, - ), - ) - ) - buffer += text - continue - - if ipython: - if text == "<|eot_id|>": - stop_reason = StopReason.end_of_turn - text = "" - continue - elif text == "<|eom_id|>": - stop_reason = StopReason.end_of_message - text = "" - continue - - buffer += text - delta = ToolCallDelta( - content=text, - parse_status=ToolCallParseStatus.in_progress, - ) - - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=delta, - stop_reason=stop_reason, - ) - ) - else: - buffer += text - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=text, - stop_reason=stop_reason, - ) - ) - - # parse tool calls and report errors - message = self.formatter.decode_assistant_message_from_content( - buffer, stop_reason - ) - parsed_tool_calls = len(message.tool_calls) > 0 - if ipython and not parsed_tool_calls: - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=ToolCallDelta( - content="", - parse_status=ToolCallParseStatus.failure, - ), - stop_reason=stop_reason, - ) - ) - - for tool_call in message.tool_calls: - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=ToolCallDelta( - content=tool_call, - 
parse_status=ToolCallParseStatus.success, - ), - stop_reason=stop_reason, - ) - ) - - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.complete, - delta="", - stop_reason=stop_reason, - ) - ) + return { + "model": self.map_to_provider_model(request.model), + "prompt": prompt, + "stream": request.stream, + **options, + } diff --git a/llama_stack/providers/adapters/inference/ollama/ollama.py b/llama_stack/providers/adapters/inference/ollama/ollama.py index 80d2ad4c8..86d72ca7f 100644 --- a/llama_stack/providers/adapters/inference/ollama/ollama.py +++ b/llama_stack/providers/adapters/inference/ollama/ollama.py @@ -9,17 +9,22 @@ from typing import AsyncGenerator import httpx from llama_models.llama3.api.chat_format import ChatFormat -from llama_models.llama3.api.datatypes import Message, StopReason +from llama_models.llama3.api.datatypes import Message from llama_models.llama3.api.tokenizer import Tokenizer from ollama import AsyncClient from llama_stack.apis.inference import * # noqa: F403 from llama_stack.providers.utils.inference.augment_messages import ( - augment_messages_for_tools, + chat_completion_request_to_prompt, +) +from llama_stack.providers.utils.inference.openai_compat import ( + get_sampling_options, + OpenAICompatCompletionChoice, + OpenAICompatCompletionResponse, + process_chat_completion_response, + process_chat_completion_stream_response, ) -from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper - OLLAMA_SUPPORTED_MODELS = { "Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16", @@ -30,14 +35,10 @@ OLLAMA_SUPPORTED_MODELS = { } -class OllamaInferenceAdapter(ModelRegistryHelper, Inference): +class OllamaInferenceAdapter(Inference): def __init__(self, url: str) -> None: - ModelRegistryHelper.__init__( - self, stack_to_provider_models_map=OLLAMA_SUPPORTED_MODELS - ) self.url = url - self.tokenizer = Tokenizer.get_instance() - self.formatter = ChatFormat(self.tokenizer) + self.formatter = ChatFormat(Tokenizer.get_instance()) @property def client(self) -> AsyncClient: @@ -55,6 +56,28 @@ class OllamaInferenceAdapter(ModelRegistryHelper, Inference): async def shutdown(self) -> None: pass + async def register_model(self, model: ModelDef) -> None: + if model.identifier not in OLLAMA_SUPPORTED_MODELS: + raise ValueError( + f"Unsupported model {model.identifier}. 
Supported models: {OLLAMA_SUPPORTED_MODELS.keys()}" + ) + + ollama_model = OLLAMA_SUPPORTED_MODELS[model.identifier] + res = await self.client.ps() + need_model_pull = True + for r in res["models"]: + if ollama_model == r["model"]: + need_model_pull = False + break + + print(f"Ollama model `{ollama_model}` needs pull -> {need_model_pull}") + if need_model_pull: + print(f"Pulling model: {ollama_model}") + status = await self.client.pull(ollama_model) + assert ( + status["status"] == "success" + ), f"Failed to pull model {self.model} in ollama" + def completion( self, model: str, @@ -65,20 +88,6 @@ class OllamaInferenceAdapter(ModelRegistryHelper, Inference): ) -> AsyncGenerator: raise NotImplementedError() - def get_ollama_chat_options(self, request: ChatCompletionRequest) -> dict: - options = {} - if request.sampling_params is not None: - for attr in {"temperature", "top_p", "top_k", "max_tokens"}: - if getattr(request.sampling_params, attr): - options[attr] = getattr(request.sampling_params, attr) - if ( - request.sampling_params.repetition_penalty is not None - and request.sampling_params.repetition_penalty != 1.0 - ): - options["repeat_penalty"] = request.sampling_params.repetition_penalty - - return options - def chat_completion( self, model: str, @@ -90,22 +99,6 @@ class OllamaInferenceAdapter(ModelRegistryHelper, Inference): stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, ) -> AsyncGenerator: - ollama_model = self.map_to_provider_model(model) - - res = await self.client.ps() - need_model_pull = True - for r in res["models"]: - if ollama_model == r["model"]: - need_model_pull = False - break - - if need_model_pull: - print(f"Pulling model: {ollama_model}") - status = await self.client.pull(ollama_model) - assert ( - status["status"] == "success" - ), f"Failed to pull model {self.model} in ollama" - request = ChatCompletionRequest( model=model, messages=messages, @@ -116,24 +109,16 @@ class OllamaInferenceAdapter(ModelRegistryHelper, Inference): stream=stream, logprobs=logprobs, ) - if stream: return self._stream_chat_completion(request) else: return self._nonstream_chat_completion(request) def _get_params(self, request: ChatCompletionRequest) -> dict: - messages = augment_messages_for_tools(request) - model_input = self.formatter.encode_dialog_prompt(messages) - prompt = self.tokenizer.decode(model_input.tokens) - - # accumulate sampling params and other options to pass to ollama - options = self.get_ollama_chat_options(request) - return { - "model": self.map_to_provider_model(request.model), - "prompt": prompt, - "options": options, + "model": OLLAMA_SUPPORTED_MODELS[request.model], + "prompt": chat_completion_request_to_prompt(request, self.formatter), + "options": get_sampling_options(request), "raw": True, "stream": request.stream, } @@ -143,129 +128,35 @@ class OllamaInferenceAdapter(ModelRegistryHelper, Inference): ) -> ChatCompletionResponse: params = self._get_params(request) r = await self.client.generate(**params) - stop_reason = None - if r["done"]: - if r["done_reason"] == "stop": - stop_reason = StopReason.end_of_turn - elif r["done_reason"] == "length": - stop_reason = StopReason.out_of_tokens + assert isinstance(r, dict) - completion_message = self.formatter.decode_assistant_message_from_content( - r["response"], stop_reason + choice = OpenAICompatCompletionChoice( + finish_reason=r["done_reason"] if r["done"] else None, + text=r["response"], ) - return ChatCompletionResponse( - completion_message=completion_message, - logprobs=None, + response 
= OpenAICompatCompletionResponse( + choices=[choice], ) + return process_chat_completion_response(request, response, self.formatter) async def _stream_chat_completion( self, request: ChatCompletionRequest ) -> AsyncGenerator: params = self._get_params(request) - stream = await self.client.generate(**params) - - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.start, - delta="", - ) - ) - - buffer = "" - ipython = False - stop_reason = None - - async for chunk in stream: - if chunk["done"]: - if stop_reason is None and chunk["done_reason"] == "stop": - stop_reason = StopReason.end_of_turn - elif stop_reason is None and chunk["done_reason"] == "length": - stop_reason = StopReason.out_of_tokens - break - - text = chunk["response"] - # check if its a tool call ( aka starts with <|python_tag|> ) - if not ipython and text.startswith("<|python_tag|>"): - ipython = True - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=ToolCallDelta( - content="", - parse_status=ToolCallParseStatus.started, - ), - ) + async def _generate_and_convert_to_openai_compat(): + s = await self.client.generate(**params) + async for chunk in s: + choice = OpenAICompatCompletionChoice( + finish_reason=chunk["done_reason"] if chunk["done"] else None, + text=chunk["response"], ) - buffer += text - continue - - if ipython: - if text == "<|eot_id|>": - stop_reason = StopReason.end_of_turn - text = "" - continue - elif text == "<|eom_id|>": - stop_reason = StopReason.end_of_message - text = "" - continue - - buffer += text - delta = ToolCallDelta( - content=text, - parse_status=ToolCallParseStatus.in_progress, + yield OpenAICompatCompletionResponse( + choices=[choice], ) - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=delta, - stop_reason=stop_reason, - ) - ) - else: - buffer += text - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=text, - stop_reason=stop_reason, - ) - ) - - # parse tool calls and report errors - message = self.formatter.decode_assistant_message_from_content( - buffer, stop_reason - ) - parsed_tool_calls = len(message.tool_calls) > 0 - if ipython and not parsed_tool_calls: - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=ToolCallDelta( - content="", - parse_status=ToolCallParseStatus.failure, - ), - stop_reason=stop_reason, - ) - ) - - for tool_call in message.tool_calls: - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=ToolCallDelta( - content=tool_call, - parse_status=ToolCallParseStatus.success, - ), - stop_reason=stop_reason, - ) - ) - - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.complete, - delta="", - stop_reason=stop_reason, - ) - ) + stream = _generate_and_convert_to_openai_compat() + async for chunk in process_chat_completion_stream_response( + request, stream, self.formatter + ): + yield chunk diff --git a/llama_stack/providers/adapters/inference/together/together.py b/llama_stack/providers/adapters/inference/together/together.py index 5326d83d4..d9a9ae491 100644 --- 
a/llama_stack/providers/adapters/inference/together/together.py +++ b/llama_stack/providers/adapters/inference/together/together.py @@ -8,7 +8,7 @@ from typing import AsyncGenerator from llama_models.llama3.api.chat_format import ChatFormat -from llama_models.llama3.api.datatypes import Message, StopReason +from llama_models.llama3.api.datatypes import Message from llama_models.llama3.api.tokenizer import Tokenizer from together import Together @@ -16,9 +16,14 @@ from together import Together from llama_stack.apis.inference import * # noqa: F403 from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.providers.utils.inference.augment_messages import ( - augment_messages_for_tools, + chat_completion_request_to_prompt, ) from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper +from llama_stack.providers.utils.inference.openai_compat import ( + get_sampling_options, + process_chat_completion_response, + process_chat_completion_stream_response, +) from .config import TogetherImplConfig @@ -41,8 +46,7 @@ class TogetherInferenceAdapter( self, stack_to_provider_models_map=TOGETHER_SUPPORTED_MODELS ) self.config = config - self.tokenizer = Tokenizer.get_instance() - self.formatter = ChatFormat(self.tokenizer) + self.formatter = ChatFormat(Tokenizer.get_instance()) @property def client(self) -> Together: @@ -64,16 +68,7 @@ class TogetherInferenceAdapter( ) -> AsyncGenerator: raise NotImplementedError() - def get_together_chat_options(self, request: ChatCompletionRequest) -> dict: - options = {} - if request.sampling_params is not None: - for attr in {"temperature", "top_p", "top_k", "max_tokens"}: - if getattr(request.sampling_params, attr): - options[attr] = getattr(request.sampling_params, attr) - - return options - - async def chat_completion( + def chat_completion( self, model: str, messages: List[Message], @@ -84,7 +79,6 @@ class TogetherInferenceAdapter( stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, ) -> AsyncGenerator: - together_api_key = None if self.config.api_key is not None: together_api_key = self.config.api_key @@ -109,148 +103,39 @@ class TogetherInferenceAdapter( logprobs=logprobs, ) - # accumulate sampling params and other options to pass to together - options = self.get_together_chat_options(request) - together_model = self.map_to_provider_model(request.model) - messages = augment_messages_for_tools(request) - model_input = self.formatter.encode_dialog_prompt(messages) - prompt = self.tokenizer.decode(model_input.tokens) - - if not request.stream: - # TODO: might need to add back an async here - r = client.completions.create( - model=together_model, - prompt=prompt, - stream=False, - **options, - ) - stop_reason = None - choice = r.choices[0] - if choice.finish_reason: - if choice.finish_reason in ["stop", "eos"]: - stop_reason = StopReason.end_of_turn - stop_reason = StopReason.end_of_turn - elif choice.finish_reason == "length": - stop_reason = StopReason.out_of_tokens - - completion_message = self.formatter.decode_assistant_message_from_content( - choice.text, stop_reason - ) - yield ChatCompletionResponse( - completion_message=completion_message, - logprobs=None, - ) + if stream: + return self._stream_chat_completion(request, client) else: - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.start, - delta="", - ) - ) + return self._nonstream_chat_completion(request, client) - buffer = "" - ipython = False - 
stop_reason = None + async def _nonstream_chat_completion( + self, request: ChatCompletionRequest, client: Together + ) -> ChatCompletionResponse: + params = self._get_params(request) + r = client.completions.create(**params) + return process_chat_completion_response(request, r, self.formatter) - for chunk in client.completions.create( - model=together_model, - prompt=prompt, - stream=True, - **options, - ): - choice = chunk.choices[0] - if finish_reason := choice.finish_reason: - if stop_reason is None and finish_reason in ["stop", "eos"]: - stop_reason = StopReason.end_of_turn - elif stop_reason is None and finish_reason == "length": - stop_reason = StopReason.out_of_tokens - break + async def _stream_chat_completion( + self, request: ChatCompletionRequest, client: Together + ) -> AsyncGenerator: + params = self._get_params(request) - text = choice.delta.content - if text is None: - continue + # if we shift to TogetherAsyncClient, we won't need this wrapper + async def _to_async_generator(): + s = client.completions.create(**params) + for chunk in s: + yield chunk - # check if its a tool call ( aka starts with <|python_tag|> ) - if not ipython and text.startswith("<|python_tag|>"): - ipython = True - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=ToolCallDelta( - content="", - parse_status=ToolCallParseStatus.started, - ), - ) - ) - buffer += text - continue + stream = _to_async_generator() + async for chunk in process_chat_completion_stream_response( + request, stream, self.formatter + ): + yield chunk - if ipython: - if text == "<|eot_id|>": - stop_reason = StopReason.end_of_turn - text = "" - continue - elif text == "<|eom_id|>": - stop_reason = StopReason.end_of_message - text = "" - continue - - buffer += text - delta = ToolCallDelta( - content=text, - parse_status=ToolCallParseStatus.in_progress, - ) - - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=delta, - stop_reason=stop_reason, - ) - ) - else: - buffer += text - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=text, - stop_reason=stop_reason, - ) - ) - - # parse tool calls and report errors - message = self.formatter.decode_assistant_message_from_content( - buffer, stop_reason - ) - parsed_tool_calls = len(message.tool_calls) > 0 - if ipython and not parsed_tool_calls: - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=ToolCallDelta( - content="", - parse_status=ToolCallParseStatus.failure, - ), - stop_reason=stop_reason, - ) - ) - - for tool_call in message.tool_calls: - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=ToolCallDelta( - content=tool_call, - parse_status=ToolCallParseStatus.success, - ), - stop_reason=stop_reason, - ) - ) - - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.complete, - delta="", - stop_reason=stop_reason, - ) - ) + def _get_params(self, request: ChatCompletionRequest) -> dict: + return { + "model": self.map_to_provider_model(request.model), + "prompt": chat_completion_request_to_prompt(request, self.formatter), + "stream": request.stream, + 
**get_sampling_options(request), + } diff --git a/llama_stack/providers/tests/inference/test_inference.py b/llama_stack/providers/tests/inference/test_inference.py index 38b9ff860..6b12a54e6 100644 --- a/llama_stack/providers/tests/inference/test_inference.py +++ b/llama_stack/providers/tests/inference/test_inference.py @@ -55,7 +55,7 @@ def get_expected_stop_reason(model: str): @pytest_asyncio.fixture( scope="session", params=[ - {"model": Llama_8B}, + # {"model": Llama_8B}, {"model": Llama_3B}, ], ids=lambda d: d["model"], @@ -112,20 +112,16 @@ def sample_tool_definition(): @pytest.mark.asyncio async def test_chat_completion_non_streaming(inference_settings, sample_messages): inference_impl = inference_settings["impl"] - response = [ - r - async for r in inference_impl.chat_completion( - messages=sample_messages, - stream=False, - **inference_settings["common_params"], - ) - ] + response = await inference_impl.chat_completion( + messages=sample_messages, + stream=False, + **inference_settings["common_params"], + ) - assert len(response) == 1 - assert isinstance(response[0], ChatCompletionResponse) - assert response[0].completion_message.role == "assistant" - assert isinstance(response[0].completion_message.content, str) - assert len(response[0].completion_message.content) > 0 + assert isinstance(response, ChatCompletionResponse) + assert response.completion_message.role == "assistant" + assert isinstance(response.completion_message.content, str) + assert len(response.completion_message.content) > 0 @pytest.mark.asyncio @@ -166,20 +162,16 @@ async def test_chat_completion_with_tool_calling( ) ] - response = [ - r - async for r in inference_impl.chat_completion( - messages=messages, - tools=[sample_tool_definition], - stream=False, - **inference_settings["common_params"], - ) - ] + response = await inference_impl.chat_completion( + messages=messages, + tools=[sample_tool_definition], + stream=False, + **inference_settings["common_params"], + ) - assert len(response) == 1 - assert isinstance(response[0], ChatCompletionResponse) + assert isinstance(response, ChatCompletionResponse) - message = response[0].completion_message + message = response.completion_message # This is not supported in most providers :/ they don't return eom_id / eot_id # stop_reason = get_expected_stop_reason(inference_settings["common_params"]["model"]) diff --git a/llama_stack/providers/utils/inference/augment_messages.py b/llama_stack/providers/utils/inference/augment_messages.py index 613a39525..a69b80d7b 100644 --- a/llama_stack/providers/utils/inference/augment_messages.py +++ b/llama_stack/providers/utils/inference/augment_messages.py @@ -3,6 +3,7 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
+from llama_models.llama3.api.chat_format import ChatFormat from termcolor import cprint from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.apis.inference import * # noqa: F403 @@ -19,6 +20,14 @@ from llama_models.sku_list import resolve_model from llama_stack.providers.utils.inference import supported_inference_models +def chat_completion_request_to_prompt( + request: ChatCompletionRequest, formatter: ChatFormat +) -> str: + messages = augment_messages_for_tools(request) + model_input = formatter.encode_dialog_prompt(messages) + return formatter.tokenizer.decode(model_input.tokens) + + def augment_messages_for_tools(request: ChatCompletionRequest) -> List[Message]: """Reads chat completion request and augments the messages to handle tools. For eg. for llama_3_1, add system message with the appropriate tools or @@ -48,7 +57,6 @@ def augment_messages_for_tools(request: ChatCompletionRequest) -> List[Message]: def augment_messages_for_tools_llama_3_1( request: ChatCompletionRequest, ) -> List[Message]: - assert request.tool_choice == ToolChoice.auto, "Only `ToolChoice.auto` supported" existing_messages = request.messages diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py new file mode 100644 index 000000000..a39002976 --- /dev/null +++ b/llama_stack/providers/utils/inference/openai_compat.py @@ -0,0 +1,187 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import AsyncGenerator, Optional + +from llama_models.llama3.api.chat_format import ChatFormat + +from llama_models.llama3.api.datatypes import StopReason + +from llama_stack.apis.inference import * # noqa: F403 + +from pydantic import BaseModel + + +class OpenAICompatCompletionChoiceDelta(BaseModel): + content: str + + +class OpenAICompatCompletionChoice(BaseModel): + finish_reason: Optional[str] = None + text: Optional[str] = None + delta: Optional[OpenAICompatCompletionChoiceDelta] = None + + +class OpenAICompatCompletionResponse(BaseModel): + choices: List[OpenAICompatCompletionChoice] + + +def get_sampling_options(request: ChatCompletionRequest) -> dict: + options = {} + if params := request.sampling_params: + for attr in {"temperature", "top_p", "top_k", "max_tokens"}: + if getattr(params, attr): + options[attr] = getattr(params, attr) + + if params.repetition_penalty is not None and params.repetition_penalty != 1.0: + options["repeat_penalty"] = params.repetition_penalty + + return options + + +def text_from_choice(choice) -> str: + if hasattr(choice, "delta") and choice.delta: + return choice.delta.content + + return choice.text + + +def process_chat_completion_response( + request: ChatCompletionRequest, + response: OpenAICompatCompletionResponse, + formatter: ChatFormat, +) -> ChatCompletionResponse: + choice = response.choices[0] + + stop_reason = None + if reason := choice.finish_reason: + if reason in ["stop", "eos"]: + stop_reason = StopReason.end_of_turn + elif reason == "length": + stop_reason = StopReason.out_of_tokens + + if stop_reason is None: + stop_reason = StopReason.out_of_tokens + + completion_message = formatter.decode_assistant_message_from_content( + text_from_choice(choice), stop_reason + ) + return ChatCompletionResponse( + completion_message=completion_message, + logprobs=None, + ) + + +async def process_chat_completion_stream_response( + 
request: ChatCompletionRequest, + stream: AsyncGenerator[OpenAICompatCompletionResponse, None], + formatter: ChatFormat, +) -> AsyncGenerator: + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.start, + delta="", + ) + ) + + buffer = "" + ipython = False + stop_reason = None + + async for chunk in stream: + choice = chunk.choices[0] + finish_reason = choice.finish_reason + + if finish_reason: + if stop_reason is None and finish_reason in ["stop", "eos"]: + stop_reason = StopReason.end_of_turn + elif stop_reason is None and finish_reason == "length": + stop_reason = StopReason.out_of_tokens + break + + text = text_from_choice(choice) + # check if its a tool call ( aka starts with <|python_tag|> ) + if not ipython and text.startswith("<|python_tag|>"): + ipython = True + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.progress, + delta=ToolCallDelta( + content="", + parse_status=ToolCallParseStatus.started, + ), + ) + ) + buffer += text + continue + + if ipython: + if text == "<|eot_id|>": + stop_reason = StopReason.end_of_turn + text = "" + continue + elif text == "<|eom_id|>": + stop_reason = StopReason.end_of_message + text = "" + continue + + buffer += text + delta = ToolCallDelta( + content=text, + parse_status=ToolCallParseStatus.in_progress, + ) + + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.progress, + delta=delta, + stop_reason=stop_reason, + ) + ) + else: + buffer += text + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.progress, + delta=text, + stop_reason=stop_reason, + ) + ) + + # parse tool calls and report errors + message = formatter.decode_assistant_message_from_content(buffer, stop_reason) + parsed_tool_calls = len(message.tool_calls) > 0 + if ipython and not parsed_tool_calls: + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.progress, + delta=ToolCallDelta( + content="", + parse_status=ToolCallParseStatus.failure, + ), + stop_reason=stop_reason, + ) + ) + + for tool_call in message.tool_calls: + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.progress, + delta=ToolCallDelta( + content=tool_call, + parse_status=ToolCallParseStatus.success, + ), + stop_reason=stop_reason, + ) + ) + + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.complete, + delta="", + stop_reason=stop_reason, + ) + )
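
For reference, a minimal sketch (not part of this patch) of how another OpenAI-compatible provider adapter could compose the shared helpers introduced above: chat_completion_request_to_prompt, get_sampling_options, the OpenAICompatCompletion* models, and process_chat_completion_response, shown here for a non-streaming path. The provider call and its response shape are hypothetical stand-ins; the wiring mirrors the Fireworks, Ollama, and Together adapters in this change.

    from llama_models.llama3.api.chat_format import ChatFormat
    from llama_models.llama3.api.tokenizer import Tokenizer

    from llama_stack.apis.inference import ChatCompletionRequest, ChatCompletionResponse
    from llama_stack.providers.utils.inference.augment_messages import (
        chat_completion_request_to_prompt,
    )
    from llama_stack.providers.utils.inference.openai_compat import (
        OpenAICompatCompletionChoice,
        OpenAICompatCompletionResponse,
        get_sampling_options,
        process_chat_completion_response,
    )


    async def fake_provider_generate(prompt: str, **options) -> dict:
        # Hypothetical stand-in for a real provider SDK call; returns a
        # completions-style dict with generated text and a finish reason.
        return {"text": "Hello from the stub provider.", "finish_reason": "stop"}


    async def example_nonstream_chat_completion(
        request: ChatCompletionRequest,
    ) -> ChatCompletionResponse:
        formatter = ChatFormat(Tokenizer.get_instance())

        # Render the request into a raw prompt plus OpenAI-style sampling options.
        prompt = chat_completion_request_to_prompt(request, formatter)
        options = get_sampling_options(request)

        raw = await fake_provider_generate(prompt=prompt, **options)

        # Normalize into the OpenAI-compat shape expected by the shared helpers.
        response = OpenAICompatCompletionResponse(
            choices=[
                OpenAICompatCompletionChoice(
                    finish_reason=raw.get("finish_reason"),
                    text=raw.get("text", ""),
                )
            ]
        )
        return process_chat_completion_response(request, response, formatter)

The streaming path would be analogous: wrap the provider's chunk iterator so that each chunk is yielded as an OpenAICompatCompletionResponse, then delegate to process_chat_completion_stream_response, as the Ollama adapter does above.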