diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
index da0d0fe4e..d7930550d 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
@@ -25,6 +25,8 @@
 from llama_stack.apis.memory import *  # noqa: F403
 from llama_stack.apis.memory_banks import *  # noqa: F403
 from llama_stack.apis.safety import *  # noqa: F403
+from llama_stack.apis.common.content_types import InterleavedContent, TextContentItem
+
 from llama_stack.providers.utils.kvstore import KVStore
 from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
 from llama_stack.providers.utils.telemetry import tracing
@@ -778,7 +780,11 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessage:
         else:
             raise ValueError(f"Unsupported URL {url}")
 
-        content.append(f'# There is a file accessible to you at "{filepath}"\n')
+        content.append(
+            TextContentItem(
+                text=f'# There is a file accessible to you at "{filepath}"\n'
+            )
+        )
 
     return ToolResponseMessage(
         call_id="",
diff --git a/llama_stack/providers/inline/inference/meta_reference/generation.py b/llama_stack/providers/inline/inference/meta_reference/generation.py
index 1daae2307..5ea7e1ad5 100644
--- a/llama_stack/providers/inline/inference/meta_reference/generation.py
+++ b/llama_stack/providers/inline/inference/meta_reference/generation.py
@@ -25,7 +25,6 @@ from fairscale.nn.model_parallel.initialize import (
 )
 from llama_models.llama3.api.args import ModelArgs
 from llama_models.llama3.api.chat_format import ChatFormat, LLMInput
-from llama_models.llama3.api.datatypes import RawContent, RawMessage
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.llama3.reference_impl.model import Transformer
 from llama_models.llama3.reference_impl.multimodal.model import (
@@ -39,6 +38,10 @@ from llama_stack.apis.inference import *  # noqa: F403
 from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
 
 from llama_stack.distribution.utils.model_utils import model_local_dir
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    ChatCompletionRequestWithRawContent,
+    CompletionRequestWithRawContent,
+)
 
 from .config import (
     Fp8QuantizationConfig,
@@ -50,14 +53,6 @@ from .config import (
 
 log = logging.getLogger(__name__)
 
-
-class ChatCompletionRequestWithRawContent(ChatCompletionRequest):
-    messages: List[RawMessage]
-
-
-class CompletionRequestWithRawContent(CompletionRequest):
-    content: RawContent
-
 
 def model_checkpoint_dir(model) -> str:
     checkpoint_dir = Path(model_local_dir(model.descriptor()))
diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py
index 4c4e7cb82..92d96ab65 100644
--- a/llama_stack/providers/inline/inference/meta_reference/inference.py
+++ b/llama_stack/providers/inline/inference/meta_reference/inference.py
@@ -12,7 +12,6 @@ from typing import AsyncGenerator, List, Optional, Union
 
 from llama_models.datatypes import Model
 from llama_models.llama3.api.datatypes import (
-    RawMessage,
     SamplingParams,
     StopReason,
     ToolDefinition,
@@ -53,14 +52,10 @@ from llama_stack.providers.utils.inference.model_registry import (
 from llama_stack.providers.utils.inference.prompt_adapter import (
     augment_content_with_response_format_prompt,
     chat_completion_request_to_messages,
-    interleaved_content_convert_to_raw,
+    convert_request_to_raw,
 )
 from .config import MetaReferenceInferenceConfig
-from .generation import (
-    ChatCompletionRequestWithRawContent,
-    CompletionRequestWithRawContent,
-    Llama,
-)
+from .generation import Llama
 from .model_parallel import LlamaModelParallelGenerator
 
 log = logging.getLogger(__name__)
@@ -450,20 +445,3 @@ class MetaReferenceInferenceImpl(
         else:
             for x in impl():
                 yield x
-
-
-async def convert_request_to_raw(
-    request: Union[ChatCompletionRequest, CompletionRequest],
-) -> Union[ChatCompletionRequestWithRawContent, CompletionRequestWithRawContent]:
-    if isinstance(request, ChatCompletionRequest):
-        messages = []
-        for m in request.messages:
-            content = await interleaved_content_convert_to_raw(m.content)
-            d = m.model_dump()
-            d["content"] = content
-            messages.append(RawMessage(**d))
-        request.messages = messages
-    else:
-        request.content = await interleaved_content_convert_to_raw(request.content)
-
-    return request
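The two meta-reference hunks above relocate `ChatCompletionRequestWithRawContent`, `CompletionRequestWithRawContent`, and `convert_request_to_raw` into the shared `prompt_adapter` module (the last file in this diff), so every provider can reuse them. The point of "raw" content is that all I/O has already happened by the time the generator runs. A minimal sketch of what that buys, assuming the `llama_models` tokenizer APIs imported in `generation.py` above:

```python
# Sketch only: with RawMessage content already resolved to text/bytes,
# rendering a dialog prompt is pure CPU work (no awaits, no URL fetches),
# which is what the synchronous model-parallel generator requires.
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import RawMessage
from llama_models.llama3.api.tokenizer import Tokenizer

formatter = ChatFormat(Tokenizer.get_instance())
model_input = formatter.encode_dialog_prompt(
    [RawMessage(role="user", content="Say hello.")]
)
print(len(model_input.tokens))  # size of the rendered prompt in tokens
```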
diff --git a/llama_stack/providers/inline/inference/vllm/vllm.py b/llama_stack/providers/inline/inference/vllm/vllm.py
index e4165ff98..c5925774b 100644
--- a/llama_stack/providers/inline/inference/vllm/vllm.py
+++ b/llama_stack/providers/inline/inference/vllm/vllm.py
@@ -120,15 +120,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> CompletionResponse | CompletionResponseStreamChunk:
-        log.info("vLLM completion")
-        messages = [UserMessage(content=content)]
-        return self.chat_completion(
-            model=model_id,
-            messages=messages,
-            sampling_params=sampling_params,
-            stream=stream,
-            logprobs=logprobs,
-        )
+        raise NotImplementedError("Completion not implemented for vLLM")
 
     async def chat_completion(
         self,
@@ -142,8 +134,6 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> ChatCompletionResponse | ChatCompletionResponseStreamChunk:
-        log.info("vLLM chat completion")
-
         assert self.engine is not None
 
         request = ChatCompletionRequest(
@@ -160,7 +150,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         log.info("Sampling params: %s", sampling_params)
 
         request_id = _random_uuid()
-        prompt = chat_completion_request_to_prompt(request, self.formatter)
+        prompt = await chat_completion_request_to_prompt(request, self.formatter)
         vllm_sampling_params = self._sampling_params(request.sampling_params)
         results_generator = self.engine.generate(
             prompt, vllm_sampling_params, request_id
diff --git a/llama_stack/providers/remote/inference/cerebras/cerebras.py b/llama_stack/providers/remote/inference/cerebras/cerebras.py
index 65733dfcd..5a9fef22a 100644
--- a/llama_stack/providers/remote/inference/cerebras/cerebras.py
+++ b/llama_stack/providers/remote/inference/cerebras/cerebras.py
@@ -94,14 +94,14 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
     async def _nonstream_completion(
         self, request: CompletionRequest
     ) -> CompletionResponse:
-        params = self._get_params(request)
+        params = await self._get_params(request)
 
         r = await self.client.completions.create(**params)
 
         return process_completion_response(r, self.formatter)
 
     async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
-        params = self._get_params(request)
+        params = await self._get_params(request)
 
         stream = await self.client.completions.create(**params)
 
@@ -141,7 +141,7 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
     async def _nonstream_chat_completion(
         self, request: CompletionRequest
     ) -> CompletionResponse:
-        params = self._get_params(request)
+        params = await self._get_params(request)
 
         r = await self.client.completions.create(**params)
 
@@ -150,7 +150,7 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
     async def _stream_chat_completion(
         self, request: CompletionRequest
     ) -> AsyncGenerator:
-        params = self._get_params(request)
+        params = await self._get_params(request)
 
         stream = await self.client.completions.create(**params)
 
@@ -159,7 +159,7 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
         ):
             yield chunk
 
-    def _get_params(
+    async def _get_params(
         self, request: Union[ChatCompletionRequest, CompletionRequest]
     ) -> dict:
         if request.sampling_params and request.sampling_params.top_k:
@@ -167,11 +167,11 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
 
         prompt = ""
         if isinstance(request, ChatCompletionRequest):
-            prompt = chat_completion_request_to_prompt(
+            prompt = await chat_completion_request_to_prompt(
                 request, self.get_llama_model(request.model), self.formatter
            )
         elif isinstance(request, CompletionRequest):
-            prompt = completion_request_to_prompt(request, self.formatter)
+            prompt = await completion_request_to_prompt(request, self.formatter)
         else:
             raise ValueError(f"Unknown request type {type(request)}")
 
diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py
index bb3ee67ec..d9ef57b15 100644
--- a/llama_stack/providers/remote/inference/fireworks/fireworks.py
+++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py
@@ -241,14 +241,16 @@ class FireworksInferenceAdapter(
                     await convert_message_to_openai_dict(m) for m in request.messages
                 ]
             else:
-                input_dict["prompt"] = chat_completion_request_to_prompt(
+                input_dict["prompt"] = await chat_completion_request_to_prompt(
                     request, self.get_llama_model(request.model), self.formatter
                 )
         else:
             assert (
                 not media_present
             ), "Fireworks does not support media for Completion requests"
-            input_dict["prompt"] = completion_request_to_prompt(request, self.formatter)
+            input_dict["prompt"] = await completion_request_to_prompt(
+                request, self.formatter
+            )
 
         # Fireworks always prepends with BOS
         if "prompt" in input_dict:
diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py
index 2f51f1299..bf55c5ad2 100644
--- a/llama_stack/providers/remote/inference/ollama/ollama.py
+++ b/llama_stack/providers/remote/inference/ollama/ollama.py
@@ -243,7 +243,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
             ]
         else:
             input_dict["raw"] = True
-            input_dict["prompt"] = chat_completion_request_to_prompt(
+            input_dict["prompt"] = await chat_completion_request_to_prompt(
                 request,
                 self.register_helper.get_llama_model(request.model),
                 self.formatter,
@@ -252,7 +252,9 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
             assert (
                 not media_present
             ), "Ollama does not support media for Completion requests"
-            input_dict["prompt"] = completion_request_to_prompt(request, self.formatter)
+            input_dict["prompt"] = await completion_request_to_prompt(
+                request, self.formatter
+            )
             input_dict["raw"] = True
 
         return {
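Every remote adapter in this diff gets the same mechanical change: `_get_params` becomes a coroutine because the prompt helpers now await `convert_request_to_raw`, which may download media referenced by URL. A minimal sketch of the call shape the adapters converge on, using a hypothetical `ExampleAdapter` (the helper and request types are the real ones; the model id is illustrative):

```python
# Hypothetical adapter method illustrating the pattern in these hunks;
# chat_completion_request_to_prompt is the helper made async in
# prompt_adapter.py at the end of this diff.
from llama_models.llama3.api.chat_format import ChatFormat

from llama_stack.apis.inference import ChatCompletionRequest
from llama_stack.providers.utils.inference.prompt_adapter import (
    chat_completion_request_to_prompt,
)


class ExampleAdapter:
    formatter: ChatFormat  # assumed to be set up in the adapter's __init__

    async def _get_params(self, request: ChatCompletionRequest) -> dict:
        # Awaiting is now required: rendering the prompt may fetch image
        # URLs while converting interleaved content to raw content.
        prompt = await chat_completion_request_to_prompt(
            request, "Llama3.1-8B-Instruct", self.formatter  # illustrative model id
        )
        return {"prompt": prompt, "stream": request.stream}
```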
diff --git a/llama_stack/providers/remote/inference/tgi/tgi.py b/llama_stack/providers/remote/inference/tgi/tgi.py
index f82bb2c77..5cc476fd7 100644
--- a/llama_stack/providers/remote/inference/tgi/tgi.py
+++ b/llama_stack/providers/remote/inference/tgi/tgi.py
@@ -130,8 +130,8 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
 
         return options
 
-    def _get_params_for_completion(self, request: CompletionRequest) -> dict:
-        prompt, input_tokens = completion_request_to_prompt_model_input_info(
+    async def _get_params_for_completion(self, request: CompletionRequest) -> dict:
+        prompt, input_tokens = await completion_request_to_prompt_model_input_info(
             request, self.formatter
         )
 
@@ -147,7 +147,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         )
 
     async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
-        params = self._get_params_for_completion(request)
+        params = await self._get_params_for_completion(request)
 
         async def _generate_and_convert_to_openai_compat():
             s = await self.client.text_generation(**params)
@@ -169,7 +169,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
             yield chunk
 
     async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator:
-        params = self._get_params_for_completion(request)
+        params = await self._get_params_for_completion(request)
         r = await self.client.text_generation(**params)
 
         choice = OpenAICompatCompletionChoice(
@@ -216,7 +216,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest
     ) -> ChatCompletionResponse:
-        params = self._get_params(request)
+        params = await self._get_params(request)
         r = await self.client.text_generation(**params)
 
         choice = OpenAICompatCompletionChoice(
@@ -231,7 +231,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
     async def _stream_chat_completion(
         self, request: ChatCompletionRequest
     ) -> AsyncGenerator:
-        params = self._get_params(request)
+        params = await self._get_params(request)
 
         async def _generate_and_convert_to_openai_compat():
             s = await self.client.text_generation(**params)
@@ -249,8 +249,8 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         ):
             yield chunk
 
-    def _get_params(self, request: ChatCompletionRequest) -> dict:
-        prompt, input_tokens = chat_completion_request_to_model_input_info(
+    async def _get_params(self, request: ChatCompletionRequest) -> dict:
+        prompt, input_tokens = await chat_completion_request_to_model_input_info(
             request, self.register_helper.get_llama_model(request.model), self.formatter
         )
         return dict(
diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py
index b2e6e06ba..e12a2cc0a 100644
--- a/llama_stack/providers/remote/inference/together/together.py
+++ b/llama_stack/providers/remote/inference/together/together.py
@@ -233,14 +233,16 @@ class TogetherInferenceAdapter(
                     await convert_message_to_openai_dict(m) for m in request.messages
                 ]
             else:
-                input_dict["prompt"] = chat_completion_request_to_prompt(
+                input_dict["prompt"] = await chat_completion_request_to_prompt(
                     request, self.get_llama_model(request.model), self.formatter
                 )
         else:
             assert (
                 not media_present
             ), "Together does not support media for Completion requests"
-            input_dict["prompt"] = completion_request_to_prompt(request, self.formatter)
+            input_dict["prompt"] = await completion_request_to_prompt(
+                request, self.formatter
+            )
 
         return {
             "model": request.model,
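The TGI hunks differ from the others only in using the `*_model_input_info` variants, which also return the prompt's token count so the adapter can budget output tokens against the context window. A sketch of that use, with a hypothetical `max_total_tokens` bound (the subtraction mirrors the idea, not the adapter's exact arithmetic):

```python
# Sketch: the *_model_input_info helpers return the rendered prompt plus
# its token count; completion_request_to_prompt_model_input_info is the
# real helper, the budgeting function and default bound are illustrative.
from llama_models.llama3.api.chat_format import ChatFormat

from llama_stack.apis.inference import CompletionRequest
from llama_stack.providers.utils.inference.prompt_adapter import (
    completion_request_to_prompt_model_input_info,
)


async def budget_new_tokens(
    request: CompletionRequest, formatter: ChatFormat, max_total_tokens: int = 4096
) -> int:
    _prompt, input_tokens = await completion_request_to_prompt_model_input_info(
        request, formatter
    )
    # Tokens already spent on the prompt cannot also be generated.
    return max_total_tokens - input_tokens
```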
diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index 12392ea50..7250d901f 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -77,7 +77,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
-        raise NotImplementedError()
+        raise NotImplementedError("Completion not implemented for vLLM")
 
     async def chat_completion(
         self,
@@ -167,7 +167,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
                     for m in request.messages
                 ]
             else:
-                input_dict["prompt"] = chat_completion_request_to_prompt(
+                input_dict["prompt"] = await chat_completion_request_to_prompt(
                     request,
                     self.register_helper.get_llama_model(request.model),
                     self.formatter,
@@ -176,7 +176,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             assert (
                 not media_present
             ), "Together does not support media for Completion requests"
-            input_dict["prompt"] = completion_request_to_prompt(
+            input_dict["prompt"] = await completion_request_to_prompt(
                 request,
                 self.register_helper.get_llama_model(request.model),
                 self.formatter,
diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py
index 42aa987c3..9f034e801 100644
--- a/llama_stack/providers/utils/inference/prompt_adapter.py
+++ b/llama_stack/providers/utils/inference/prompt_adapter.py
@@ -20,6 +20,7 @@ from llama_models.llama3.api.datatypes import (
     RawContent,
     RawContentItem,
     RawMediaItem,
+    RawMessage,
     RawTextItem,
     Role,
     ToolPromptFormat,
@@ -58,6 +59,14 @@ from llama_stack.providers.utils.inference import supported_inference_models
 log = logging.getLogger(__name__)
 
 
+class ChatCompletionRequestWithRawContent(ChatCompletionRequest):
+    messages: List[RawMessage]
+
+
+class CompletionRequestWithRawContent(CompletionRequest):
+    content: RawContent
+
+
 def interleaved_content_as_str(content: InterleavedContent, sep: str = " ") -> str:
     def _process(c) -> str:
         if isinstance(c, str):
@@ -75,6 +84,23 @@ def interleaved_content_as_str(content: InterleavedContent, sep: str = " ") -> str:
     return _process(content)
 
 
+async def convert_request_to_raw(
+    request: Union[ChatCompletionRequest, CompletionRequest],
+) -> Union[ChatCompletionRequestWithRawContent, CompletionRequestWithRawContent]:
+    if isinstance(request, ChatCompletionRequest):
+        messages = []
+        for m in request.messages:
+            content = await interleaved_content_convert_to_raw(m.content)
+            d = m.model_dump()
+            d["content"] = content
+            messages.append(RawMessage(**d))
+        request.messages = messages
+    else:
+        request.content = await interleaved_content_convert_to_raw(request.content)
+
+    return request
+
+
 async def interleaved_content_convert_to_raw(
     content: InterleavedContent,
 ) -> RawContent:
@@ -169,23 +195,27 @@ async def convert_image_content_to_url(
         return base64.b64encode(content).decode("utf-8")
 
 
-def completion_request_to_prompt(
+async def completion_request_to_prompt(
     request: CompletionRequest, formatter: ChatFormat
 ) -> str:
     content = augment_content_with_response_format_prompt(
         request.response_format, request.content
     )
-    model_input = formatter.encode_content(content)
+    request.content = content
+    request = await convert_request_to_raw(request)
+    model_input = formatter.encode_content(request.content)
     return formatter.tokenizer.decode(model_input.tokens)
 
 
-def completion_request_to_prompt_model_input_info(
+async def completion_request_to_prompt_model_input_info(
     request: CompletionRequest, formatter: ChatFormat
 ) -> Tuple[str, int]:
     content = augment_content_with_response_format_prompt(
         request.response_format, request.content
     )
-    model_input = formatter.encode_content(content)
+    request.content = content
+    request = await convert_request_to_raw(request)
+    model_input = formatter.encode_content(request.content)
     return (formatter.tokenizer.decode(model_input.tokens), len(model_input.tokens))
 
 
@@ -199,19 +229,23 @@ def augment_content_with_response_format_prompt(response_format, content):
     return content
 
 
-def chat_completion_request_to_prompt(
+async def chat_completion_request_to_prompt(
     request: ChatCompletionRequest, llama_model: str, formatter: ChatFormat
 ) -> str:
     messages = chat_completion_request_to_messages(request, llama_model)
-    model_input = formatter.encode_dialog_prompt(messages)
+    request.messages = messages
+    request = await convert_request_to_raw(request)
+    model_input = formatter.encode_dialog_prompt(request.messages)
     return formatter.tokenizer.decode(model_input.tokens)
 
 
-def chat_completion_request_to_model_input_info(
+async def chat_completion_request_to_model_input_info(
     request: ChatCompletionRequest, llama_model: str, formatter: ChatFormat
 ) -> Tuple[str, int]:
     messages = chat_completion_request_to_messages(request, llama_model)
-    model_input = formatter.encode_dialog_prompt(messages)
+    request.messages = messages
+    request = await convert_request_to_raw(request)
+    model_input = formatter.encode_dialog_prompt(request.messages)
     return (
         formatter.tokenizer.decode(model_input.tokens),
         len(model_input.tokens),
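Finally, a minimal usage sketch of the relocated `convert_request_to_raw`, assuming `ChatCompletionRequest` and `UserMessage` are importable from `llama_stack.apis.inference` as in the adapters above (the model id is illustrative):

```python
import asyncio

from llama_models.llama3.api.datatypes import RawMessage

from llama_stack.apis.inference import ChatCompletionRequest, UserMessage
from llama_stack.providers.utils.inference.prompt_adapter import (
    convert_request_to_raw,
)


async def main() -> None:
    request = ChatCompletionRequest(
        model="Llama3.1-8B-Instruct",  # illustrative model id
        messages=[UserMessage(content="Hello!")],
    )
    # Message content is rewritten into RawContent; any media URLs would be
    # fetched here, which is why the helper (and all its callers) are async.
    request = await convert_request_to_raw(request)
    assert isinstance(request.messages[0], RawMessage)


asyncio.run(main())
```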