diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py
index c9d0c0da9..0a9e4f5a7 100644
--- a/llama_stack/providers/remote/inference/fireworks/fireworks.py
+++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py
@@ -56,9 +56,7 @@ from .models import MODEL_ENTRIES
 logger = get_logger(name=__name__, category="inference")
-class FireworksInferenceAdapter(
-    ModelRegistryHelper, Inference, NeedsRequestProviderData
-):
+class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
     def __init__(self, config: FireworksImplConfig) -> None:
         ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
         self.config = config
@@ -70,9 +68,7 @@ class FireworksInferenceAdapter(
         pass
     def _get_api_key(self) -> str:
-        config_api_key = (
-            self.config.api_key.get_secret_value() if self.config.api_key else None
-        )
+        config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
         if config_api_key:
             return config_api_key
         else:
@@ -112,9 +108,7 @@ class FireworksInferenceAdapter(
         else:
             return await self._nonstream_completion(request)
-    async def _nonstream_completion(
-        self, request: CompletionRequest
-    ) -> CompletionResponse:
+    async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
         params = await self._get_params(request)
         r = await self._get_client().completion.acreate(**params)
         return process_completion_response(r)
@@ -194,9 +188,7 @@ class FireworksInferenceAdapter(
         else:
             return await self._nonstream_chat_completion(request)
-    async def _nonstream_chat_completion(
-        self, request: ChatCompletionRequest
-    ) -> ChatCompletionResponse:
+    async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
         params = await self._get_params(request)
         if "messages" in params:
             r = await self._get_client().chat.completions.acreate(**params)
@@ -204,9 +196,7 @@ class FireworksInferenceAdapter(
             r = await self._get_client().completion.acreate(**params)
         return process_chat_completion_response(r, request)
-    async def _stream_chat_completion(
-        self, request: ChatCompletionRequest
-    ) -> AsyncGenerator:
+    async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
         async def _to_async_generator():
@@ -221,9 +211,7 @@ class FireworksInferenceAdapter(
         async for chunk in process_chat_completion_stream_response(stream, request):
             yield chunk
-    async def _get_params(
-        self, request: Union[ChatCompletionRequest, CompletionRequest]
-    ) -> dict:
+    async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
         input_dict = {}
         media_present = request_has_media(request)
@@ -231,17 +219,12 @@ class FireworksInferenceAdapter(
         if isinstance(request, ChatCompletionRequest):
             if media_present or not llama_model:
                 input_dict["messages"] = [
-                    await convert_message_to_openai_dict(m, download=True)
-                    for m in request.messages
+                    await convert_message_to_openai_dict(m, download=True) for m in request.messages
                 ]
             else:
-                input_dict["prompt"] = await chat_completion_request_to_prompt(
-                    request, llama_model
-                )
+                input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
         else:
-            assert (
-                not media_present
-            ), "Fireworks does not support media for Completion requests"
+            assert not media_present, "Fireworks does not support media for Completion requests"
             input_dict["prompt"] = await completion_request_to_prompt(request)
         # Fireworks always prepends with BOS
@@ -253,9 +236,7 @@ class FireworksInferenceAdapter(
             "model": request.model,
             **input_dict,
             "stream": request.stream,
-            **self._build_options(
-                request.sampling_params, request.response_format, request.logprobs
-            ),
+            **self._build_options(request.sampling_params, request.response_format, request.logprobs),
         }
         logger.debug(f"params to fireworks: {params}")
@@ -274,9 +255,9 @@ class FireworksInferenceAdapter(
         kwargs = {}
         if model.metadata.get("embedding_dimension"):
             kwargs["dimensions"] = model.metadata.get("embedding_dimension")
-        assert all(
-            not content_has_media(content) for content in contents
-        ), "Fireworks does not support media for embeddings"
+        assert all(not content_has_media(content) for content in contents), (
+            "Fireworks does not support media for embeddings"
+        )
         response = self._get_client().embeddings.create(
             model=model.provider_resource_id,
             input=[interleaved_content_as_str(content) for content in contents],
diff --git a/llama_stack/providers/remote/inference/tgi/tgi.py b/llama_stack/providers/remote/inference/tgi/tgi.py
index 5e0ec88e8..0ef936925 100644
--- a/llama_stack/providers/remote/inference/tgi/tgi.py
+++ b/llama_stack/providers/remote/inference/tgi/tgi.py
@@ -5,10 +5,9 @@
 # the root directory of this source tree.
-import logging
 from typing import AsyncGenerator, List, Optional
-from huggingface_hub import AsyncInferenceClient, HfApi, InferenceClient
+from huggingface_hub import AsyncInferenceClient, HfApi
 from llama_stack.apis.common.content_types import (
     InterleavedContent,
@@ -16,12 +15,10 @@ from llama_stack.apis.common.content_types import (
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
-    ChatCompletionResponse,
     CompletionRequest,
     EmbeddingsResponse,
     EmbeddingTaskType,
     Inference,
-    JsonSchemaResponseFormat,
     LogProbConfig,
     Message,
     ResponseFormat,
@@ -38,26 +35,20 @@ from llama_stack.log import get_logger
 from llama_stack.models.llama.sku_list import all_registered_models
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.model_registry import (
-    build_hf_repo_model_entry,
     ModelRegistryHelper,
+    build_hf_repo_model_entry,
 )
 from llama_stack.providers.utils.inference.openai_compat import (
-    convert_chat_completion_request_to_openai_params,
-    convert_completion_request_to_openai_params,
-    convert_message_to_openai_dict_new,
-    convert_openai_chat_completion_choice,
-    convert_openai_chat_completion_stream,
-    convert_tooldef_to_openai_tool,
-    get_sampling_options,
     OpenAICompatCompletionChoice,
     OpenAICompatCompletionResponse,
-    process_chat_completion_response,
-    process_chat_completion_stream_response,
+    convert_chat_completion_request_to_openai_params,
+    convert_openai_chat_completion_choice,
+    convert_openai_chat_completion_stream,
+    get_sampling_options,
     process_completion_response,
     process_completion_stream_response,
 )
 from llama_stack.providers.utils.inference.prompt_adapter import (
-    chat_completion_request_to_model_input_info,
     completion_request_to_prompt_model_input_info,
 )
@@ -85,9 +76,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
     def __init__(self) -> None:
         self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
         self.huggingface_repo_to_llama_model_id = {
-            model.huggingface_repo: model.descriptor()
-            for model in all_registered_models()
-            if model.huggingface_repo
+            model.huggingface_repo: model.descriptor() for model in all_registered_models() if model.huggingface_repo
} async def shutdown(self) -> None: @@ -114,7 +103,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate): logprobs: Optional[LogProbConfig] = None, ) -> AsyncGenerator: if response_format: - raise ValueError(f"TGI does not support Response Format for completions.") + raise ValueError("TGI does not support Response Format for completions.") if sampling_params is None: sampling_params = SamplingParams() @@ -166,17 +155,13 @@ class _HfAdapter(Inference, ModelsProtocolPrivate): return options async def _get_params_for_completion(self, request: CompletionRequest) -> dict: - prompt, input_tokens = await completion_request_to_prompt_model_input_info( - request - ) + prompt, input_tokens = await completion_request_to_prompt_model_input_info(request) return dict( prompt=prompt, stream=request.stream, details=True, - max_new_tokens=self._get_max_new_tokens( - request.sampling_params, input_tokens - ), + max_new_tokens=self._get_max_new_tokens(request.sampling_params, input_tokens), stop_sequences=["<|eom_id|>", "<|eot_id|>"], **self._build_options(request.sampling_params, request.response_format), ) @@ -185,16 +170,14 @@ class _HfAdapter(Inference, ModelsProtocolPrivate): params = await self._get_params_for_completion(request) async def _generate_and_convert_to_openai_compat(): - s = self.client.text_generation(**params) - for chunk in s: + s = await self.client.text_generation(**params) + async for chunk in s: token_result = chunk.token finish_reason = None if chunk.details: finish_reason = chunk.details.finish_reason - choice = OpenAICompatCompletionChoice( - text=token_result.text, finish_reason=finish_reason - ) + choice = OpenAICompatCompletionChoice(text=token_result.text, finish_reason=finish_reason) yield OpenAICompatCompletionResponse( choices=[choice], ) @@ -205,7 +188,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate): async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator: params = await self._get_params_for_completion(request) - r = self.client.text_generation(**params) + r = await self.client.text_generation(**params) choice = OpenAICompatCompletionChoice( finish_reason=r.details.finish_reason, @@ -234,9 +217,6 @@ class _HfAdapter(Inference, ModelsProtocolPrivate): if sampling_params is None: sampling_params = SamplingParams() model = await self.model_store.get_model(model_id) - from rich.pretty import pprint - - pprint(messages) request = ChatCompletionRequest( model=model.provider_resource_id, messages=messages, @@ -250,18 +230,9 @@ class _HfAdapter(Inference, ModelsProtocolPrivate): params = await convert_chat_completion_request_to_openai_params(request) - import json - - # print(json.dumps(params, indent=2)) - - pprint(params) - - response = self.client.chat.completions.create(**params) - + response = await self.client.chat.completions.create(**params) if stream: - return convert_openai_chat_completion_stream( - response, enable_incremental_tool_calls=True - ) + return convert_openai_chat_completion_stream(response, enable_incremental_tool_calls=True) else: return convert_openai_chat_completion_choice(response.choices[0]) @@ -281,18 +252,16 @@ class TGIAdapter(_HfAdapter): logger.info(f"Initializing TGI client with url={config.url}") # unfortunately, the TGI async client does not work well with proxies # so using sync client for now instead - self.client = InferenceClient(model=f"{config.url}") + self.client = AsyncInferenceClient(model=f"{config.url}") - endpoint_info = self.client.get_endpoint_info() + endpoint_info = await 
self.client.get_endpoint_info() self.max_tokens = endpoint_info["max_total_tokens"] self.model_id = endpoint_info["model_id"] class InferenceAPIAdapter(_HfAdapter): async def initialize(self, config: InferenceAPIImplConfig) -> None: - self.client = AsyncInferenceClient( - model=config.huggingface_repo, token=config.api_token.get_secret_value() - ) + self.client = AsyncInferenceClient(model=config.huggingface_repo, token=config.api_token.get_secret_value()) endpoint_info = await self.client.get_endpoint_info() self.max_tokens = endpoint_info["max_total_tokens"] self.model_id = endpoint_info["model_id"] @@ -310,6 +279,4 @@ class InferenceEndpointAdapter(_HfAdapter): # Initialize the adapter self.client = endpoint.async_client self.model_id = endpoint.repository - self.max_tokens = int( - endpoint.raw["model"]["image"]["custom"]["env"]["MAX_TOTAL_TOKENS"] - ) + self.max_tokens = int(endpoint.raw["model"]["image"]["custom"]["env"]["MAX_TOTAL_TOKENS"]) diff --git a/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/llama_stack/providers/utils/inference/litellm_openai_mixin.py index 8e9825107..a6d2ade4c 100644 --- a/llama_stack/providers/utils/inference/litellm_openai_mixin.py +++ b/llama_stack/providers/utils/inference/litellm_openai_mixin.py @@ -19,7 +19,6 @@ from llama_stack.apis.inference import ( EmbeddingsResponse, EmbeddingTaskType, Inference, - JsonSchemaResponseFormat, LogProbConfig, Message, ResponseFormat, @@ -36,11 +35,8 @@ from llama_stack.log import get_logger from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper from llama_stack.providers.utils.inference.openai_compat import ( convert_chat_completion_request_to_openai_params, - convert_message_to_openai_dict_new, convert_openai_chat_completion_choice, convert_openai_chat_completion_stream, - convert_tooldef_to_openai_tool, - get_sampling_options, ) from llama_stack.providers.utils.inference.prompt_adapter import ( interleaved_content_as_str, @@ -54,9 +50,7 @@ class LiteLLMOpenAIMixin( Inference, NeedsRequestProviderData, ): - def __init__( - self, model_entries, api_key_from_config: str, provider_data_api_key_field: str - ): + def __init__(self, model_entries, api_key_from_config: str, provider_data_api_key_field: str): ModelRegistryHelper.__init__(self, model_entries) self.api_key_from_config = api_key_from_config self.provider_data_api_key_field = provider_data_api_key_field @@ -96,9 +90,7 @@ class LiteLLMOpenAIMixin( stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, tool_config: Optional[ToolConfig] = None, - ) -> Union[ - ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk] - ]: + ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]: if sampling_params is None: sampling_params = SamplingParams() model = await self.model_store.get_model(model_id) diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py index 2dce64675..66de06c6f 100644 --- a/llama_stack/providers/utils/inference/openai_compat.py +++ b/llama_stack/providers/utils/inference/openai_compat.py @@ -6,7 +6,49 @@ import json import logging import warnings -from typing import AsyncGenerator, Dict, Generator, Iterable, List, Optional, Union +from typing import AsyncGenerator, Dict, Iterable, List, Optional, Union + +from openai import AsyncStream +from openai.types.chat import ( + ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage, +) +from openai.types.chat 
import ( + ChatCompletionChunk as OpenAIChatCompletionChunk, +) +from openai.types.chat import ( + ChatCompletionContentPartImageParam as OpenAIChatCompletionContentPartImageParam, +) +from openai.types.chat import ( + ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam, +) +from openai.types.chat import ( + ChatCompletionContentPartTextParam as OpenAIChatCompletionContentPartTextParam, +) +from openai.types.chat import ( + ChatCompletionMessageParam as OpenAIChatCompletionMessage, +) +from openai.types.chat import ( + ChatCompletionMessageToolCall as OpenAIChatCompletionMessageToolCall, +) +from openai.types.chat import ( + ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage, +) +from openai.types.chat import ( + ChatCompletionToolMessageParam as OpenAIChatCompletionToolMessage, +) +from openai.types.chat import ( + ChatCompletionUserMessageParam as OpenAIChatCompletionUserMessage, +) +from openai.types.chat.chat_completion import ( + Choice as OpenAIChoice, +) +from openai.types.chat.chat_completion import ( + ChoiceLogprobs as OpenAIChoiceLogprobs, # same as chat_completion_chunk ChoiceLogprobs +) +from openai.types.chat.chat_completion_content_part_image_param import ( + ImageURL as OpenAIImageURL, +) +from pydantic import BaseModel from llama_stack.apis.common.content_types import ( ImageContentItem, @@ -23,7 +65,6 @@ from llama_stack.apis.inference import ( ChatCompletionResponseEventType, ChatCompletionResponseStreamChunk, CompletionMessage, - CompletionRequest, CompletionResponse, CompletionResponseStreamChunk, JsonSchemaResponseFormat, @@ -49,32 +90,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( decode_assistant_message, ) -from openai import AsyncStream, Stream -from openai.types.chat import ( - ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage, - ChatCompletionChunk as OpenAIChatCompletionChunk, - ChatCompletionContentPartImageParam as OpenAIChatCompletionContentPartImageParam, - ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam, - ChatCompletionContentPartTextParam as OpenAIChatCompletionContentPartTextParam, - ChatCompletionMessageParam as OpenAIChatCompletionMessage, - ChatCompletionMessageToolCall as OpenAIChatCompletionMessageToolCall, - ChatCompletionMessageToolCallParam as OpenAIChatCompletionMessageToolCallParam, - ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage, - ChatCompletionToolMessageParam as OpenAIChatCompletionToolMessage, - ChatCompletionUserMessageParam as OpenAIChatCompletionUserMessage, -) -from openai.types.chat.chat_completion import ( - Choice as OpenAIChoice, - ChoiceLogprobs as OpenAIChoiceLogprobs, # same as chat_completion_chunk ChoiceLogprobs -) -from openai.types.chat.chat_completion_content_part_image_param import ( - ImageURL as OpenAIImageURL, -) -from openai.types.chat.chat_completion_message_tool_call_param import ( - Function as OpenAIFunction, -) -from pydantic import BaseModel - logger = logging.getLogger(__name__) @@ -169,16 +184,12 @@ def convert_openai_completion_logprobs( if logprobs.tokens and logprobs.token_logprobs: return [ TokenLogProbs(logprobs_by_token={token: token_lp}) - for token, token_lp in zip( - logprobs.tokens, logprobs.token_logprobs, strict=False - ) + for token, token_lp in zip(logprobs.tokens, logprobs.token_logprobs, strict=False) ] return None -def convert_openai_completion_logprobs_stream( - text: str, logprobs: Optional[Union[float, OpenAICompatLogprobs]] -): +def 
convert_openai_completion_logprobs_stream(text: str, logprobs: Optional[Union[float, OpenAICompatLogprobs]]): if logprobs is None: return None if isinstance(logprobs, float): @@ -223,9 +234,7 @@ def process_chat_completion_response( if not choice.message or not choice.message.tool_calls: raise ValueError("Tool calls are not present in the response") - tool_calls = [ - convert_tool_call(tool_call) for tool_call in choice.message.tool_calls - ] + tool_calls = [convert_tool_call(tool_call) for tool_call in choice.message.tool_calls] if any(isinstance(tool_call, UnparseableToolCall) for tool_call in tool_calls): # If we couldn't parse a tool call, jsonify the tool calls and return them return ChatCompletionResponse( @@ -249,9 +258,7 @@ def process_chat_completion_response( # TODO: This does not work well with tool calls for vLLM remote provider # Ref: https://github.com/meta-llama/llama-stack/issues/1058 - raw_message = decode_assistant_message( - text_from_choice(choice), get_stop_reason(choice.finish_reason) - ) + raw_message = decode_assistant_message(text_from_choice(choice), get_stop_reason(choice.finish_reason)) # NOTE: If we do not set tools in chat-completion request, we should not # expect the ToolCall in the response. Instead, we should return the raw @@ -452,17 +459,13 @@ async def process_chat_completion_stream_response( ) -async def convert_message_to_openai_dict( - message: Message, download: bool = False -) -> dict: +async def convert_message_to_openai_dict(message: Message, download: bool = False) -> dict: async def _convert_content(content) -> dict: if isinstance(content, ImageContentItem): return { "type": "image_url", "image_url": { - "url": await convert_image_content_to_url( - content, download=download - ), + "url": await convert_image_content_to_url(content, download=download), }, } else: @@ -541,9 +544,7 @@ async def convert_message_to_openai_dict_new( elif isinstance(content_, ImageContentItem): return OpenAIChatCompletionContentPartImageParam( type="image_url", - image_url=OpenAIImageURL( - url=await convert_image_content_to_url(content_) - ), + image_url=OpenAIImageURL(url=await convert_image_content_to_url(content_)), ) elif isinstance(content_, list): return [await impl(item) for item in content_] @@ -587,9 +588,7 @@ async def convert_message_to_openai_dict_new( "id": tool.call_id, "function": { "name": ( - tool.tool_name - if not isinstance(tool.tool_name, BuiltinTool) - else tool.tool_name.value + tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value ), "arguments": json.dumps(tool.arguments), }, @@ -709,11 +708,7 @@ def convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict: properties = parameters["properties"] required = [] for param_name, param in tool.parameters.items(): - properties[param_name] = { - "type": PYTHON_TYPE_TO_LITELLM_TYPE.get( - param.param_type, param.param_type - ) - } + properties[param_name] = {"type": PYTHON_TYPE_TO_LITELLM_TYPE.get(param.param_type, param.param_type)} if param.description: properties[param_name].update(description=param.description) if param.default: @@ -834,11 +829,7 @@ def _convert_openai_logprobs( return None return [ - TokenLogProbs( - logprobs_by_token={ - logprobs.token: logprobs.logprob for logprobs in content.top_logprobs - } - ) + TokenLogProbs(logprobs_by_token={logprobs.token: logprobs.logprob for logprobs in content.top_logprobs}) for content in logprobs.content ] @@ -876,17 +867,14 @@ def convert_openai_chat_completion_choice( end_of_message = "end_of_message" 
out_of_tokens = "out_of_tokens" """ - assert ( - hasattr(choice, "message") and choice.message - ), "error in server response: message not found" - assert ( - hasattr(choice, "finish_reason") and choice.finish_reason - ), "error in server response: finish_reason not found" + assert hasattr(choice, "message") and choice.message, "error in server response: message not found" + assert hasattr(choice, "finish_reason") and choice.finish_reason, ( + "error in server response: finish_reason not found" + ) return ChatCompletionResponse( completion_message=CompletionMessage( - content=choice.message.content - or "", # CompletionMessage content is not optional + content=choice.message.content or "", # CompletionMessage content is not optional stop_reason=_convert_openai_finish_reason(choice.finish_reason), tool_calls=_convert_openai_tool_calls(choice.message.tool_calls), ), @@ -895,9 +883,7 @@ def convert_openai_chat_completion_choice( async def convert_openai_chat_completion_stream( - stream: Union[ - AsyncStream[OpenAIChatCompletionChunk], Stream[OpenAIChatCompletionChunk] - ], + stream: AsyncStream[OpenAIChatCompletionChunk], enable_incremental_tool_calls: bool, ) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]: """ @@ -905,14 +891,6 @@ async def convert_openai_chat_completion_stream( of ChatCompletionResponseStreamChunk. """ - async def yield_from_stream(stream): - if isinstance(stream, AsyncGenerator): - async for chunk in stream: - yield chunk - elif isinstance(stream, Generator): - for chunk in stream: - yield chunk - yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( event_type=ChatCompletionResponseEventType.start, @@ -924,7 +902,7 @@ async def convert_openai_chat_completion_stream( stop_reason = None tool_call_idx_to_buffer = {} - async for chunk in yield_from_stream(stream): + async for chunk in stream: choice = chunk.choices[0] # assuming only one choice per chunk # we assume there's only one finish_reason in the stream @@ -1092,7 +1070,7 @@ async def convert_openai_chat_completion_stream( stop_reason=stop_reason, ) ) - except json.JSONDecodeError as e: + except json.JSONDecodeError: yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( event_type=ChatCompletionResponseEventType.progress, @@ -1137,7 +1115,7 @@ async def convert_openai_chat_completion_stream( stop_reason=stop_reason, ) ) - except (KeyError, json.JSONDecodeError) as e: + except (KeyError, json.JSONDecodeError): yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( event_type=ChatCompletionResponseEventType.progress, @@ -1158,26 +1136,6 @@ async def convert_openai_chat_completion_stream( ) -async def convert_completion_request_to_openai_params( - request: CompletionRequest, -) -> dict: - """ - Convert a ChatCompletionRequest to an OpenAI API-compatible dictionary. 
- """ - input_dict = {} - if request.logprobs: - input_dict["logprobs"] = request.logprobs.top_k - - return { - "model": request.model, - "prompt": request.content, - **input_dict, - "stream": request.stream, - **get_sampling_options(request.sampling_params), - "n": 1, - } - - async def convert_chat_completion_request_to_openai_params( request: ChatCompletionRequest, ) -> dict: @@ -1186,14 +1144,10 @@ async def convert_chat_completion_request_to_openai_params( """ input_dict = {} - input_dict["messages"] = [ - await convert_message_to_openai_dict_new(m) for m in request.messages - ] + input_dict["messages"] = [await convert_message_to_openai_dict_new(m) for m in request.messages] if fmt := request.response_format: if not isinstance(fmt, JsonSchemaResponseFormat): - raise ValueError( - f"Unsupported response format: {type(fmt)}. Only JsonSchemaResponseFormat is supported." - ) + raise ValueError(f"Unsupported response format: {type(fmt)}. Only JsonSchemaResponseFormat is supported.") fmt = fmt.json_schema name = fmt["title"] @@ -1217,9 +1171,7 @@ async def convert_chat_completion_request_to_openai_params( } if request.tools: - input_dict["tools"] = [ - convert_tooldef_to_openai_tool(tool) for tool in request.tools - ] + input_dict["tools"] = [convert_tooldef_to_openai_tool(tool) for tool in request.tools] if request.tool_config.tool_choice: input_dict["tool_choice"] = ( request.tool_config.tool_choice.value diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py index f0d8f2b72..1ba658ba7 100644 --- a/llama_stack/providers/utils/inference/prompt_adapter.py +++ b/llama_stack/providers/utils/inference/prompt_adapter.py @@ -12,6 +12,7 @@ import re from typing import List, Optional, Tuple, Union import httpx +from PIL import Image as PIL_Image from llama_stack.apis.common.content_types import ( ImageContentItem, @@ -33,7 +34,6 @@ from llama_stack.apis.inference import ( ) from llama_stack.log import get_logger from llama_stack.models.llama.datatypes import ( - is_multimodal, ModelFamily, RawContent, RawContentItem, @@ -43,6 +43,7 @@ from llama_stack.models.llama.datatypes import ( Role, StopReason, ToolPromptFormat, + is_multimodal, ) from llama_stack.models.llama.llama3.chat_format import ChatFormat from llama_stack.models.llama.llama3.prompt_templates import ( @@ -55,7 +56,6 @@ from llama_stack.models.llama.llama3.prompt_templates import ( from llama_stack.models.llama.llama3.tokenizer import Tokenizer from llama_stack.models.llama.sku_list import resolve_model from llama_stack.providers.utils.inference import supported_inference_models -from PIL import Image as PIL_Image log = get_logger(name=__name__, category="inference") @@ -129,9 +129,7 @@ async def interleaved_content_convert_to_raw( if image.url.uri.startswith("data"): match = re.match(r"data:image/(\w+);base64,(.+)", image.url.uri) if not match: - raise ValueError( - f"Invalid data URL format, {image.url.uri[:40]}..." 
- ) + raise ValueError(f"Invalid data URL format, {image.url.uri[:40]}...") _, image_data = match.groups() data = base64.b64decode(image_data) elif image.url.uri.startswith("file://"): @@ -211,17 +209,13 @@ async def convert_image_content_to_url( content, format = await localize_image_content(media) if include_format: - return f"data:image/{format};base64," + base64.b64encode(content).decode( - "utf-8" - ) + return f"data:image/{format};base64," + base64.b64encode(content).decode("utf-8") else: return base64.b64encode(content).decode("utf-8") async def completion_request_to_prompt(request: CompletionRequest) -> str: - content = augment_content_with_response_format_prompt( - request.response_format, request.content - ) + content = augment_content_with_response_format_prompt(request.response_format, request.content) request.content = content request = await convert_request_to_raw(request) @@ -233,9 +227,7 @@ async def completion_request_to_prompt(request: CompletionRequest) -> str: async def completion_request_to_prompt_model_input_info( request: CompletionRequest, ) -> Tuple[str, int]: - content = augment_content_with_response_format_prompt( - request.response_format, request.content - ) + content = augment_content_with_response_format_prompt(request.response_format, request.content) request.content = content request = await convert_request_to_raw(request) @@ -256,9 +248,7 @@ def augment_content_with_response_format_prompt(response_format, content): return content -async def chat_completion_request_to_prompt( - request: ChatCompletionRequest, llama_model: str -) -> str: +async def chat_completion_request_to_prompt(request: ChatCompletionRequest, llama_model: str) -> str: messages = chat_completion_request_to_messages(request, llama_model) request.messages = messages request = await convert_request_to_raw(request) @@ -266,8 +256,7 @@ async def chat_completion_request_to_prompt( formatter = ChatFormat(tokenizer=Tokenizer.get_instance()) model_input = formatter.encode_dialog_prompt( request.messages, - tool_prompt_format=request.tool_config.tool_prompt_format - or get_default_tool_prompt_format(llama_model), + tool_prompt_format=request.tool_config.tool_prompt_format or get_default_tool_prompt_format(llama_model), ) return formatter.tokenizer.decode(model_input.tokens) @@ -282,8 +271,7 @@ async def chat_completion_request_to_model_input_info( formatter = ChatFormat(tokenizer=Tokenizer.get_instance()) model_input = formatter.encode_dialog_prompt( request.messages, - tool_prompt_format=request.tool_config.tool_prompt_format - or get_default_tool_prompt_format(llama_model), + tool_prompt_format=request.tool_config.tool_prompt_format or get_default_tool_prompt_format(llama_model), ) tokens = [] for t in model_input.tokens: @@ -318,8 +306,7 @@ def chat_completion_request_to_messages( return request.messages if model.model_family == ModelFamily.llama3_1 or ( - model.model_family == ModelFamily.llama3_2 - and is_multimodal(model.core_model_id) + model.model_family == ModelFamily.llama3_2 and is_multimodal(model.core_model_id) ): # llama3.1 and llama3.2 multimodal models follow the same tool prompt format messages = augment_messages_for_tools_llama_3_1(request) @@ -355,9 +342,7 @@ def augment_messages_for_tools_llama_3_1( if existing_messages[0].role == Role.system.value: existing_system_message = existing_messages.pop(0) - assert ( - existing_messages[0].role != Role.system.value - ), "Should only have 1 system message" + assert existing_messages[0].role != Role.system.value, "Should only have 1 
system message" messages = [] @@ -389,13 +374,9 @@ def augment_messages_for_tools_llama_3_1( if isinstance(existing_system_message.content, str): sys_content += _process(existing_system_message.content) elif isinstance(existing_system_message.content, list): - sys_content += "\n".join( - [_process(c) for c in existing_system_message.content] - ) + sys_content += "\n".join([_process(c) for c in existing_system_message.content]) - tool_choice_prompt = _get_tool_choice_prompt( - request.tool_config.tool_choice, request.tools - ) + tool_choice_prompt = _get_tool_choice_prompt(request.tool_config.tool_choice, request.tools) if tool_choice_prompt: sys_content += "\n" + tool_choice_prompt @@ -429,9 +410,7 @@ def augment_messages_for_tools_llama_3_2( if existing_messages[0].role == Role.system.value: existing_system_message = existing_messages.pop(0) - assert ( - existing_messages[0].role != Role.system.value - ), "Should only have 1 system message" + assert existing_messages[0].role != Role.system.value, "Should only have 1 system message" sys_content = "" custom_tools, builtin_tools = [], [] @@ -452,16 +431,10 @@ def augment_messages_for_tools_llama_3_2( if custom_tools: fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.python_list if fmt != ToolPromptFormat.python_list: - raise ValueError( - f"Non supported ToolPromptFormat {request.tool_config.tool_prompt_format}" - ) + raise ValueError(f"Non supported ToolPromptFormat {request.tool_config.tool_prompt_format}") system_prompt = None - if ( - existing_system_message - and request.tool_config.system_message_behavior - == SystemMessageBehavior.replace - ): + if existing_system_message and request.tool_config.system_message_behavior == SystemMessageBehavior.replace: system_prompt = existing_system_message.content tool_template = PythonListCustomToolGenerator().gen(custom_tools, system_prompt) @@ -470,16 +443,11 @@ def augment_messages_for_tools_llama_3_2( sys_content += "\n" if existing_system_message and ( - request.tool_config.system_message_behavior == SystemMessageBehavior.append - or not custom_tools + request.tool_config.system_message_behavior == SystemMessageBehavior.append or not custom_tools ): - sys_content += interleaved_content_as_str( - existing_system_message.content, sep="\n" - ) + sys_content += interleaved_content_as_str(existing_system_message.content, sep="\n") - tool_choice_prompt = _get_tool_choice_prompt( - request.tool_config.tool_choice, request.tools - ) + tool_choice_prompt = _get_tool_choice_prompt(request.tool_config.tool_choice, request.tools) if tool_choice_prompt: sys_content += "\n" + tool_choice_prompt @@ -487,15 +455,11 @@ def augment_messages_for_tools_llama_3_2( return messages -def _get_tool_choice_prompt( - tool_choice: ToolChoice | str, tools: List[ToolDefinition] -) -> str: +def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: List[ToolDefinition]) -> str: if tool_choice == ToolChoice.auto: return "" elif tool_choice == ToolChoice.required: - return ( - "You MUST use one of the provided functions/tools to answer the user query." - ) + return "You MUST use one of the provided functions/tools to answer the user query." 
     elif tool_choice == ToolChoice.none:
         # tools are already not passed in
         return ""
@@ -507,14 +471,11 @@ def get_default_tool_prompt_format(model: str) -> ToolPromptFormat:
     llama_model = resolve_model(model)
     if llama_model is None:
-        log.warning(
-            f"Could not resolve model {model}, defaulting to json tool prompt format"
-        )
+        log.warning(f"Could not resolve model {model}, defaulting to json tool prompt format")
         return ToolPromptFormat.json
     if llama_model.model_family == ModelFamily.llama3_1 or (
-        llama_model.model_family == ModelFamily.llama3_2
-        and is_multimodal(llama_model.core_model_id)
+        llama_model.model_family == ModelFamily.llama3_2 and is_multimodal(llama_model.core_model_id)
     ):
         # llama3.1 and llama3.2 multimodal models follow the same tool prompt format
         return ToolPromptFormat.json
diff --git a/tests/integration/fixtures/common.py b/tests/integration/fixtures/common.py
index 477e24ae5..1878c9e88 100644
--- a/tests/integration/fixtures/common.py
+++ b/tests/integration/fixtures/common.py
@@ -120,16 +120,13 @@ def client_with_models(
     judge_model_id,
 ):
     client = llama_stack_client
-    from rich.pretty import pprint
     providers = [p for p in client.providers.list() if p.api == "inference"]
-    pprint(providers)
     assert len(providers) > 0, "No inference providers found"
     inference_providers = [p.provider_id for p in providers if p.provider_type != "inline::sentence-transformers"]
     model_ids = {m.identifier for m in client.models.list()}
     model_ids.update(m.provider_resource_id for m in client.models.list())
-    pprint(model_ids)
     if text_model_id and text_model_id not in model_ids:
         client.models.register(model_id=text_model_id, provider_id=inference_providers[0])
diff --git a/tests/integration/inference/test_text_inference.py b/tests/integration/inference/test_text_inference.py
index 83c86aaee..70fddb2b6 100644
--- a/tests/integration/inference/test_text_inference.py
+++ b/tests/integration/inference/test_text_inference.py
@@ -8,9 +8,9 @@ import os
 import pytest
+from pydantic import BaseModel
 from llama_stack.models.llama.sku_list import resolve_model
-from pydantic import BaseModel
 from ..test_cases.test_case import TestCase
@@ -29,9 +29,7 @@ def skip_if_model_doesnt_support_completion(client_with_models, model_id):
         "remote::gemini",
         "remote::groq",
     ):
-        pytest.skip(
-            f"Model {model_id} hosted by {provider.provider_type} doesn't support completion"
-        )
+        pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support completion")
 def get_llama_model(client_with_models, model_id):
@@ -112,9 +110,7 @@ def test_text_completion_streaming(client_with_models, text_model_id, test_case)
         "inference:completion:stop_sequence",
     ],
 )
-def test_text_completion_stop_sequence(
-    client_with_models, text_model_id, inference_provider_type, test_case
-):
+def test_text_completion_stop_sequence(client_with_models, text_model_id, inference_provider_type, test_case):
     skip_if_model_doesnt_support_completion(client_with_models, text_model_id)
     # This is only supported/tested for remote vLLM: https://github.com/meta-llama/llama-stack/issues/1771
     if inference_provider_type != "remote::vllm":
@@ -141,9 +137,7 @@ def test_text_completion_stop_sequence(
         "inference:completion:log_probs",
     ],
 )
-def test_text_completion_log_probs_non_streaming(
-    client_with_models, text_model_id, inference_provider_type, test_case
-):
+def test_text_completion_log_probs_non_streaming(client_with_models, text_model_id, inference_provider_type, test_case):
skip_if_model_doesnt_support_completion(client_with_models, text_model_id) if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K: pytest.xfail(f"{inference_provider_type} doesn't support log probs yet") @@ -162,9 +156,7 @@ def test_text_completion_log_probs_non_streaming( }, ) assert response.logprobs, "Logprobs should not be empty" - assert ( - 1 <= len(response.logprobs) <= 5 - ) # each token has 1 logprob and here max_tokens=5 + assert 1 <= len(response.logprobs) <= 5 # each token has 1 logprob and here max_tokens=5 assert all(len(logprob.logprobs_by_token) == 1 for logprob in response.logprobs) @@ -174,9 +166,7 @@ def test_text_completion_log_probs_non_streaming( "inference:completion:log_probs", ], ) -def test_text_completion_log_probs_streaming( - client_with_models, text_model_id, inference_provider_type, test_case -): +def test_text_completion_log_probs_streaming(client_with_models, text_model_id, inference_provider_type, test_case): skip_if_model_doesnt_support_completion(client_with_models, text_model_id) if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K: pytest.xfail(f"{inference_provider_type} doesn't support log probs yet") @@ -198,9 +188,7 @@ def test_text_completion_log_probs_streaming( for chunk in streamed_content: if chunk.delta: # if there's a token, we expect logprobs assert chunk.logprobs, "Logprobs should not be empty" - assert all( - len(logprob.logprobs_by_token) == 1 for logprob in chunk.logprobs - ) + assert all(len(logprob.logprobs_by_token) == 1 for logprob in chunk.logprobs) else: # no token, no logprobs assert not chunk.logprobs, "Logprobs should be empty" @@ -211,13 +199,9 @@ def test_text_completion_log_probs_streaming( "inference:completion:structured_output", ], ) -def test_text_completion_structured_output( - client_with_models, text_model_id, test_case, inference_provider_type -): +def test_text_completion_structured_output(client_with_models, text_model_id, test_case, inference_provider_type): if inference_provider_type == "remote::tgi": - pytest.xfail( - f"{inference_provider_type} doesn't support structured outputs yet" - ) + pytest.xfail(f"{inference_provider_type} doesn't support structured outputs yet") skip_if_model_doesnt_support_completion(client_with_models, text_model_id) class AnswerFormat(BaseModel): @@ -254,9 +238,7 @@ def test_text_completion_structured_output( "inference:chat_completion:non_streaming_02", ], ) -def test_text_chat_completion_non_streaming( - client_with_models, text_model_id, test_case -): +def test_text_chat_completion_non_streaming(client_with_models, text_model_id, test_case): tc = TestCase(test_case) question = tc["question"] expected = tc["expected"] @@ -282,18 +264,15 @@ def test_text_chat_completion_non_streaming( "inference:chat_completion:ttft", ], ) -def test_text_chat_completion_first_token_profiling( - client_with_models, text_model_id, test_case -): +def test_text_chat_completion_first_token_profiling(client_with_models, text_model_id, test_case): tc = TestCase(test_case) messages = tc["messages"] - if os.environ.get( - "DEBUG_TTFT" - ): # debugging print number of tokens in input, ideally around 800 - from llama_stack.apis.inference import Message + if os.environ.get("DEBUG_TTFT"): # debugging print number of tokens in input, ideally around 800 from pydantic import TypeAdapter + from llama_stack.apis.inference import Message + tokenizer, formatter = get_llama_tokenizer() typed_messages = [TypeAdapter(Message).validate_python(m) for m in messages] encoded = 
formatter.encode_dialog_prompt(typed_messages, None) @@ -307,9 +286,7 @@ def test_text_chat_completion_first_token_profiling( message_content = response.completion_message.content.lower().strip() assert len(message_content) > 0 - if os.environ.get( - "DEBUG_TTFT" - ): # debugging print number of tokens in response, ideally around 150 + if os.environ.get("DEBUG_TTFT"): # debugging print number of tokens in response, ideally around 150 tokenizer, formatter = get_llama_tokenizer() encoded = formatter.encode_content(message_content) raise ValueError(len(encoded.tokens) if encoded and encoded.tokens else 0) @@ -332,9 +309,7 @@ def test_text_chat_completion_streaming(client_with_models, text_model_id, test_ messages=[{"role": "user", "content": question}], stream=True, ) - streamed_content = [ - str(chunk.event.delta.text.lower().strip()) for chunk in response - ] + streamed_content = [str(chunk.event.delta.text.lower().strip()) for chunk in response] assert len(streamed_content) > 0 assert expected.lower() in "".join(streamed_content) @@ -345,9 +320,7 @@ def test_text_chat_completion_streaming(client_with_models, text_model_id, test_ "inference:chat_completion:tool_calling", ], ) -def test_text_chat_completion_with_tool_calling_and_non_streaming( - client_with_models, text_model_id, test_case -): +def test_text_chat_completion_with_tool_calling_and_non_streaming(client_with_models, text_model_id, test_case): tc = TestCase(test_case) response = client_with_models.inference.chat_completion( @@ -361,10 +334,7 @@ def test_text_chat_completion_with_tool_calling_and_non_streaming( assert response.completion_message.role == "assistant" assert len(response.completion_message.tool_calls) == 1 - assert ( - response.completion_message.tool_calls[0].tool_name - == tc["tools"][0]["tool_name"] - ) + assert response.completion_message.tool_calls[0].tool_name == tc["tools"][0]["tool_name"] assert response.completion_message.tool_calls[0].arguments == tc["expected"] @@ -387,9 +357,7 @@ def extract_tool_invocation_content(response): "inference:chat_completion:tool_calling", ], ) -def test_text_chat_completion_with_tool_calling_and_streaming( - client_with_models, text_model_id, test_case -): +def test_text_chat_completion_with_tool_calling_and_streaming(client_with_models, text_model_id, test_case): tc = TestCase(test_case) response = client_with_models.inference.chat_completion( @@ -415,9 +383,7 @@ def test_text_chat_completion_with_tool_choice_required( client_with_models, text_model_id, test_case, inference_provider_type ): if inference_provider_type == "remote::tgi": - pytest.xfail( - f"{inference_provider_type} doesn't support tool_choice 'required' parameter yet" - ) + pytest.xfail(f"{inference_provider_type} doesn't support tool_choice 'required' parameter yet") tc = TestCase(test_case) @@ -442,9 +408,7 @@ def test_text_chat_completion_with_tool_choice_required( "inference:chat_completion:tool_calling", ], ) -def test_text_chat_completion_with_tool_choice_none( - client_with_models, text_model_id, test_case -): +def test_text_chat_completion_with_tool_choice_none(client_with_models, text_model_id, test_case): tc = TestCase(test_case) response = client_with_models.inference.chat_completion( @@ -464,13 +428,9 @@ def test_text_chat_completion_with_tool_choice_none( "inference:chat_completion:structured_output", ], ) -def test_text_chat_completion_structured_output( - client_with_models, text_model_id, test_case, inference_provider_type -): +def test_text_chat_completion_structured_output(client_with_models, 
text_model_id, test_case, inference_provider_type):
     if inference_provider_type == "remote::tgi":
-        pytest.xfail(
-            f"{inference_provider_type} doesn't support structured outputs yet"
-        )
+        pytest.xfail(f"{inference_provider_type} doesn't support structured outputs yet")
     class NBAStats(BaseModel):
         year_for_draft: int
diff --git a/tests/integration/inference/test_vision_inference.py b/tests/integration/inference/test_vision_inference.py
index 7f0f935a7..70068aa29 100644
--- a/tests/integration/inference/test_vision_inference.py
+++ b/tests/integration/inference/test_vision_inference.py
@@ -27,9 +27,7 @@ def base64_image_url(base64_image_data, image_path):
     return f"data:image/{image_path.suffix[1:]};base64,{base64_image_data}"
-# @pytest.mark.xfail(
-#     reason="This test is failing because the image is not being downloaded correctly."
-# )
+@pytest.mark.xfail(reason="This test is failing because the image is not being downloaded correctly.")
 def test_image_chat_completion_non_streaming(client_with_models, vision_model_id):
     message = {
         "role": "user",
@@ -58,9 +56,7 @@ def test_image_chat_completion_non_streaming(client_with_models, vision_model_id
     assert any(expected in message_content for expected in {"dog", "puppy", "pup"})
-# @pytest.mark.xfail(
-#     reason="This test is failing because the image is not being downloaded correctly."
-# )
+@pytest.mark.xfail(reason="This test is failing because the image is not being downloaded correctly.")
 def test_image_chat_completion_streaming(client_with_models, vision_model_id):
     message = {
         "role": "user",
@@ -92,9 +88,7 @@ def test_image_chat_completion_streaming(client_with_models, vision_model_id):
 @pytest.mark.parametrize("type_", ["url"])
-def test_image_chat_completion_base64(
-    client_with_models, vision_model_id, base64_image_data, base64_image_url, type_
-):
+def test_image_chat_completion_base64(client_with_models, vision_model_id, base64_image_data, base64_image_url, type_):
     image_spec = {
         "url": {
             "type": "image",