diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py index ca4dc59f7..8023e2836 100644 --- a/llama_stack/providers/registry/inference.py +++ b/llama_stack/providers/registry/inference.py @@ -106,24 +106,6 @@ def available_providers() -> List[ProviderSpec]: config_class="llama_stack.providers.remote.inference.tgi.TGIImplConfig", ), ), - remote_provider_spec( - api=Api.inference, - adapter=AdapterSpec( - adapter_type="hf::serverless", - pip_packages=["huggingface_hub", "aiohttp"], - module="llama_stack.providers.remote.inference.tgi", - config_class="llama_stack.providers.remote.inference.tgi.InferenceAPIImplConfig", - ), - ), - remote_provider_spec( - api=Api.inference, - adapter=AdapterSpec( - adapter_type="hf::endpoint", - pip_packages=["huggingface_hub", "aiohttp"], - module="llama_stack.providers.remote.inference.tgi", - config_class="llama_stack.providers.remote.inference.tgi.InferenceEndpointImplConfig", - ), - ), remote_provider_spec( api=Api.inference, adapter=AdapterSpec( diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py index 4acbe43f8..c9d0c0da9 100644 --- a/llama_stack/providers/remote/inference/fireworks/fireworks.py +++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py @@ -33,9 +33,7 @@ from llama_stack.apis.inference import ( ) from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.log import get_logger -from llama_stack.providers.utils.inference.model_registry import ( - ModelRegistryHelper, -) +from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper from llama_stack.providers.utils.inference.openai_compat import ( convert_message_to_openai_dict, get_sampling_options, @@ -58,7 +56,9 @@ from .models import MODEL_ENTRIES logger = get_logger(name=__name__, category="inference") -class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData): +class FireworksInferenceAdapter( + ModelRegistryHelper, Inference, NeedsRequestProviderData +): def __init__(self, config: FireworksImplConfig) -> None: ModelRegistryHelper.__init__(self, MODEL_ENTRIES) self.config = config @@ -70,7 +70,9 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv pass def _get_api_key(self) -> str: - config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None + config_api_key = ( + self.config.api_key.get_secret_value() if self.config.api_key else None + ) if config_api_key: return config_api_key else: @@ -110,7 +112,9 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv else: return await self._nonstream_completion(request) - async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse: + async def _nonstream_completion( + self, request: CompletionRequest + ) -> CompletionResponse: params = await self._get_params(request) r = await self._get_client().completion.acreate(**params) return process_completion_response(r) @@ -190,7 +194,9 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv else: return await self._nonstream_chat_completion(request) - async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: + async def _nonstream_chat_completion( + self, request: ChatCompletionRequest + ) -> ChatCompletionResponse: params = await self._get_params(request) if "messages" in params: r 
= await self._get_client().chat.completions.acreate(**params) @@ -198,7 +204,9 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv r = await self._get_client().completion.acreate(**params) return process_chat_completion_response(r, request) - async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator: + async def _stream_chat_completion( + self, request: ChatCompletionRequest + ) -> AsyncGenerator: params = await self._get_params(request) async def _to_async_generator(): @@ -213,7 +221,9 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv async for chunk in process_chat_completion_stream_response(stream, request): yield chunk - async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict: + async def _get_params( + self, request: Union[ChatCompletionRequest, CompletionRequest] + ) -> dict: input_dict = {} media_present = request_has_media(request) @@ -221,12 +231,17 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv if isinstance(request, ChatCompletionRequest): if media_present or not llama_model: input_dict["messages"] = [ - await convert_message_to_openai_dict(m, download=True) for m in request.messages + await convert_message_to_openai_dict(m, download=True) + for m in request.messages ] else: - input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model) + input_dict["prompt"] = await chat_completion_request_to_prompt( + request, llama_model + ) else: - assert not media_present, "Fireworks does not support media for Completion requests" + assert ( + not media_present + ), "Fireworks does not support media for Completion requests" input_dict["prompt"] = await completion_request_to_prompt(request) # Fireworks always prepends with BOS @@ -238,7 +253,9 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv "model": request.model, **input_dict, "stream": request.stream, - **self._build_options(request.sampling_params, request.response_format, request.logprobs), + **self._build_options( + request.sampling_params, request.response_format, request.logprobs + ), } logger.debug(f"params to fireworks: {params}") @@ -257,9 +274,9 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv kwargs = {} if model.metadata.get("embedding_dimension"): kwargs["dimensions"] = model.metadata.get("embedding_dimension") - assert all(not content_has_media(content) for content in contents), ( - "Fireworks does not support media for embeddings" - ) + assert all( + not content_has_media(content) for content in contents + ), "Fireworks does not support media for embeddings" response = self._get_client().embeddings.create( model=model.provider_resource_id, input=[interleaved_content_as_str(content) for content in contents], diff --git a/llama_stack/providers/remote/inference/tgi/tgi.py b/llama_stack/providers/remote/inference/tgi/tgi.py index 757085fb1..5e0ec88e8 100644 --- a/llama_stack/providers/remote/inference/tgi/tgi.py +++ b/llama_stack/providers/remote/inference/tgi/tgi.py @@ -8,7 +8,7 @@ import logging from typing import AsyncGenerator, List, Optional -from huggingface_hub import AsyncInferenceClient, HfApi +from huggingface_hub import AsyncInferenceClient, HfApi, InferenceClient from llama_stack.apis.common.content_types import ( InterleavedContent, @@ -21,6 +21,7 @@ from llama_stack.apis.inference import ( EmbeddingsResponse, EmbeddingTaskType, Inference, + 
JsonSchemaResponseFormat, LogProbConfig, Message, ResponseFormat, @@ -33,16 +34,23 @@ from llama_stack.apis.inference import ( ToolPromptFormat, ) from llama_stack.apis.models import Model +from llama_stack.log import get_logger from llama_stack.models.llama.sku_list import all_registered_models from llama_stack.providers.datatypes import ModelsProtocolPrivate from llama_stack.providers.utils.inference.model_registry import ( - ModelRegistryHelper, build_hf_repo_model_entry, + ModelRegistryHelper, ) from llama_stack.providers.utils.inference.openai_compat import ( + convert_chat_completion_request_to_openai_params, + convert_completion_request_to_openai_params, + convert_message_to_openai_dict_new, + convert_openai_chat_completion_choice, + convert_openai_chat_completion_stream, + convert_tooldef_to_openai_tool, + get_sampling_options, OpenAICompatCompletionChoice, OpenAICompatCompletionResponse, - get_sampling_options, process_chat_completion_response, process_chat_completion_stream_response, process_completion_response, @@ -55,7 +63,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig -log = logging.getLogger(__name__) +logger = get_logger(name=__name__, category="inference") def build_hf_repo_model_entries(): @@ -77,7 +85,9 @@ class _HfAdapter(Inference, ModelsProtocolPrivate): def __init__(self) -> None: self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries()) self.huggingface_repo_to_llama_model_id = { - model.huggingface_repo: model.descriptor() for model in all_registered_models() if model.huggingface_repo + model.huggingface_repo: model.descriptor() + for model in all_registered_models() + if model.huggingface_repo } async def shutdown(self) -> None: @@ -103,6 +113,9 @@ class _HfAdapter(Inference, ModelsProtocolPrivate): stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, ) -> AsyncGenerator: + if response_format: + raise ValueError(f"TGI does not support Response Format for completions.") + if sampling_params is None: sampling_params = SamplingParams() model = await self.model_store.get_model(model_id) @@ -153,13 +166,17 @@ class _HfAdapter(Inference, ModelsProtocolPrivate): return options async def _get_params_for_completion(self, request: CompletionRequest) -> dict: - prompt, input_tokens = await completion_request_to_prompt_model_input_info(request) + prompt, input_tokens = await completion_request_to_prompt_model_input_info( + request + ) return dict( prompt=prompt, stream=request.stream, details=True, - max_new_tokens=self._get_max_new_tokens(request.sampling_params, input_tokens), + max_new_tokens=self._get_max_new_tokens( + request.sampling_params, input_tokens + ), stop_sequences=["<|eom_id|>", "<|eot_id|>"], **self._build_options(request.sampling_params, request.response_format), ) @@ -168,14 +185,16 @@ class _HfAdapter(Inference, ModelsProtocolPrivate): params = await self._get_params_for_completion(request) async def _generate_and_convert_to_openai_compat(): - s = await self.client.text_generation(**params) - async for chunk in s: + s = self.client.text_generation(**params) + for chunk in s: token_result = chunk.token finish_reason = None if chunk.details: finish_reason = chunk.details.finish_reason - choice = OpenAICompatCompletionChoice(text=token_result.text, finish_reason=finish_reason) + choice = OpenAICompatCompletionChoice( + text=token_result.text, finish_reason=finish_reason + ) yield OpenAICompatCompletionResponse( 
choices=[choice], ) @@ -186,7 +205,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate): async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator: params = await self._get_params_for_completion(request) - r = await self.client.text_generation(**params) + r = self.client.text_generation(**params) choice = OpenAICompatCompletionChoice( finish_reason=r.details.finish_reason, @@ -215,6 +234,9 @@ class _HfAdapter(Inference, ModelsProtocolPrivate): if sampling_params is None: sampling_params = SamplingParams() model = await self.model_store.get_model(model_id) + from rich.pretty import pprint + + pprint(messages) request = ChatCompletionRequest( model=model.provider_resource_id, messages=messages, @@ -226,53 +248,22 @@ class _HfAdapter(Inference, ModelsProtocolPrivate): tool_config=tool_config, ) + params = await convert_chat_completion_request_to_openai_params(request) + + import json + + # print(json.dumps(params, indent=2)) + + pprint(params) + + response = self.client.chat.completions.create(**params) + if stream: - return self._stream_chat_completion(request) + return convert_openai_chat_completion_stream( + response, enable_incremental_tool_calls=True + ) else: - return await self._nonstream_chat_completion(request) - - async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: - params = await self._get_params(request) - r = await self.client.text_generation(**params) - - choice = OpenAICompatCompletionChoice( - finish_reason=r.details.finish_reason, - text="".join(t.text for t in r.details.tokens), - ) - response = OpenAICompatCompletionResponse( - choices=[choice], - ) - return process_chat_completion_response(response, request) - - async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator: - params = await self._get_params(request) - - async def _generate_and_convert_to_openai_compat(): - s = await self.client.text_generation(**params) - async for chunk in s: - token_result = chunk.token - - choice = OpenAICompatCompletionChoice(text=token_result.text) - yield OpenAICompatCompletionResponse( - choices=[choice], - ) - - stream = _generate_and_convert_to_openai_compat() - async for chunk in process_chat_completion_stream_response(stream, request): - yield chunk - - async def _get_params(self, request: ChatCompletionRequest) -> dict: - prompt, input_tokens = await chat_completion_request_to_model_input_info( - request, self.register_helper.get_llama_model(request.model) - ) - return dict( - prompt=prompt, - stream=request.stream, - details=True, - max_new_tokens=self._get_max_new_tokens(request.sampling_params, input_tokens), - stop_sequences=["<|eom_id|>", "<|eot_id|>"], - **self._build_options(request.sampling_params, request.response_format), - ) + return convert_openai_chat_completion_choice(response.choices[0]) async def embeddings( self, @@ -287,18 +278,21 @@ class _HfAdapter(Inference, ModelsProtocolPrivate): class TGIAdapter(_HfAdapter): async def initialize(self, config: TGIImplConfig) -> None: - log.info(f"Initializing TGI client with url={config.url}") - self.client = AsyncInferenceClient( - model=config.url, - ) - endpoint_info = await self.client.get_endpoint_info() + logger.info(f"Initializing TGI client with url={config.url}") + # unfortunately, the TGI async client does not work well with proxies + # so using sync client for now instead + self.client = InferenceClient(model=f"{config.url}") + + endpoint_info = self.client.get_endpoint_info() self.max_tokens = 
endpoint_info["max_total_tokens"] self.model_id = endpoint_info["model_id"] class InferenceAPIAdapter(_HfAdapter): async def initialize(self, config: InferenceAPIImplConfig) -> None: - self.client = AsyncInferenceClient(model=config.huggingface_repo, token=config.api_token.get_secret_value()) + self.client = AsyncInferenceClient( + model=config.huggingface_repo, token=config.api_token.get_secret_value() + ) endpoint_info = await self.client.get_endpoint_info() self.max_tokens = endpoint_info["max_total_tokens"] self.model_id = endpoint_info["model_id"] @@ -316,4 +310,6 @@ class InferenceEndpointAdapter(_HfAdapter): # Initialize the adapter self.client = endpoint.async_client self.model_id = endpoint.repository - self.max_tokens = int(endpoint.raw["model"]["image"]["custom"]["env"]["MAX_TOTAL_TOKENS"]) + self.max_tokens = int( + endpoint.raw["model"]["image"]["custom"]["env"]["MAX_TOTAL_TOKENS"] + ) diff --git a/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/llama_stack/providers/utils/inference/litellm_openai_mixin.py index f99883990..8e9825107 100644 --- a/llama_stack/providers/utils/inference/litellm_openai_mixin.py +++ b/llama_stack/providers/utils/inference/litellm_openai_mixin.py @@ -33,10 +33,9 @@ from llama_stack.apis.inference import ( from llama_stack.apis.models.models import Model from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.log import get_logger -from llama_stack.providers.utils.inference.model_registry import ( - ModelRegistryHelper, -) +from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper from llama_stack.providers.utils.inference.openai_compat import ( + convert_chat_completion_request_to_openai_params, convert_message_to_openai_dict_new, convert_openai_chat_completion_choice, convert_openai_chat_completion_stream, @@ -55,7 +54,9 @@ class LiteLLMOpenAIMixin( Inference, NeedsRequestProviderData, ): - def __init__(self, model_entries, api_key_from_config: str, provider_data_api_key_field: str): + def __init__( + self, model_entries, api_key_from_config: str, provider_data_api_key_field: str + ): ModelRegistryHelper.__init__(self, model_entries) self.api_key_from_config = api_key_from_config self.provider_data_api_key_field = provider_data_api_key_field @@ -95,7 +96,9 @@ class LiteLLMOpenAIMixin( stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, tool_config: Optional[ToolConfig] = None, - ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]: + ) -> Union[ + ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk] + ]: if sampling_params is None: sampling_params = SamplingParams() model = await self.model_store.get_model(model_id) @@ -110,7 +113,17 @@ class LiteLLMOpenAIMixin( tool_config=tool_config, ) - params = await self._get_params(request) + params = await convert_chat_completion_request_to_openai_params(request) + + # add api_key to params if available + provider_data = self.get_request_provider_data() + key_field = self.provider_data_api_key_field + if provider_data and getattr(provider_data, key_field, None): + api_key = getattr(provider_data, key_field) + else: + api_key = self.api_key_from_config + params["api_key"] = api_key + logger.debug(f"params to litellm (openai compat): {params}") # unfortunately, we need to use synchronous litellm.completion here because litellm # caches various httpx.client objects in a non-eventloop aware manner @@ -132,87 +145,6 @@ class LiteLLMOpenAIMixin( ): yield chunk 
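For reference, a minimal sketch of the adapter pattern this refactor converges on (it mirrors what the TGI adapter and the LiteLLM mixin now do): build OpenAI-compatible kwargs from the llama-stack request with the shared helper, call the wrapped OpenAI-compatible client, then convert the result back into llama-stack response types. The `client` argument and `_example_chat_completion` name are hypothetical stand-ins for whichever SDK a given adapter wraps; the converter functions are the ones added in `llama_stack.providers.utils.inference.openai_compat` by this change.

from llama_stack.providers.utils.inference.openai_compat import (
    convert_chat_completion_request_to_openai_params,
    convert_openai_chat_completion_choice,
    convert_openai_chat_completion_stream,
)


async def _example_chat_completion(client, request):
    # Translate the llama-stack ChatCompletionRequest into OpenAI-compatible kwargs
    # (messages, tools, response_format, sampling options, stream flag).
    params = await convert_chat_completion_request_to_openai_params(request)
    # Call whatever OpenAI-compatible client the adapter wraps; `client` here is a
    # placeholder for e.g. the huggingface_hub InferenceClient or a litellm call.
    response = client.chat.completions.create(**params)
    if request.stream:
        # Streaming: adapt the chunk stream back into ChatCompletionResponseStreamChunk events.
        return convert_openai_chat_completion_stream(
            response, enable_incremental_tool_calls=True
        )
    # Non-streaming: convert the first choice back into a ChatCompletionResponse.
    return convert_openai_chat_completion_choice(response.choices[0])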
- def _add_additional_properties_recursive(self, schema): - """ - Recursively add additionalProperties: False to all object schemas - """ - if isinstance(schema, dict): - if schema.get("type") == "object": - schema["additionalProperties"] = False - - # Add required field with all property keys if properties exist - if "properties" in schema and schema["properties"]: - schema["required"] = list(schema["properties"].keys()) - - if "properties" in schema: - for prop_schema in schema["properties"].values(): - self._add_additional_properties_recursive(prop_schema) - - for key in ["anyOf", "allOf", "oneOf"]: - if key in schema: - for sub_schema in schema[key]: - self._add_additional_properties_recursive(sub_schema) - - if "not" in schema: - self._add_additional_properties_recursive(schema["not"]) - - # Handle $defs/$ref - if "$defs" in schema: - for def_schema in schema["$defs"].values(): - self._add_additional_properties_recursive(def_schema) - - return schema - - async def _get_params(self, request: ChatCompletionRequest) -> dict: - input_dict = {} - - input_dict["messages"] = [await convert_message_to_openai_dict_new(m) for m in request.messages] - if fmt := request.response_format: - if not isinstance(fmt, JsonSchemaResponseFormat): - raise ValueError( - f"Unsupported response format: {type(fmt)}. Only JsonSchemaResponseFormat is supported." - ) - - fmt = fmt.json_schema - name = fmt["title"] - del fmt["title"] - fmt["additionalProperties"] = False - - # Apply additionalProperties: False recursively to all objects - fmt = self._add_additional_properties_recursive(fmt) - - input_dict["response_format"] = { - "type": "json_schema", - "json_schema": { - "name": name, - "schema": fmt, - "strict": True, - }, - } - if request.tools: - input_dict["tools"] = [convert_tooldef_to_openai_tool(tool) for tool in request.tools] - if request.tool_config.tool_choice: - input_dict["tool_choice"] = ( - request.tool_config.tool_choice.value - if isinstance(request.tool_config.tool_choice, ToolChoice) - else request.tool_config.tool_choice - ) - - provider_data = self.get_request_provider_data() - key_field = self.provider_data_api_key_field - if provider_data and getattr(provider_data, key_field, None): - api_key = getattr(provider_data, key_field) - else: - api_key = self.api_key_from_config - - return { - "model": request.model, - "api_key": api_key, - **input_dict, - "stream": request.stream, - **get_sampling_options(request.sampling_params), - } - async def embeddings( self, model_id: str, diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py index 07976e811..2dce64675 100644 --- a/llama_stack/providers/utils/inference/openai_compat.py +++ b/llama_stack/providers/utils/inference/openai_compat.py @@ -6,55 +6,7 @@ import json import logging import warnings -from typing import AsyncGenerator, Dict, Iterable, List, Optional, Union - -from openai import AsyncStream -from openai.types.chat import ( - ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage, -) -from openai.types.chat import ( - ChatCompletionChunk as OpenAIChatCompletionChunk, -) -from openai.types.chat import ( - ChatCompletionContentPartImageParam as OpenAIChatCompletionContentPartImageParam, -) -from openai.types.chat import ( - ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam, -) -from openai.types.chat import ( - ChatCompletionContentPartTextParam as OpenAIChatCompletionContentPartTextParam, -) -from openai.types.chat import ( - 
ChatCompletionMessageParam as OpenAIChatCompletionMessage, -) -from openai.types.chat import ( - ChatCompletionMessageToolCall, -) -from openai.types.chat import ( - ChatCompletionMessageToolCallParam as OpenAIChatCompletionMessageToolCall, -) -from openai.types.chat import ( - ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage, -) -from openai.types.chat import ( - ChatCompletionToolMessageParam as OpenAIChatCompletionToolMessage, -) -from openai.types.chat import ( - ChatCompletionUserMessageParam as OpenAIChatCompletionUserMessage, -) -from openai.types.chat.chat_completion import ( - Choice as OpenAIChoice, -) -from openai.types.chat.chat_completion import ( - ChoiceLogprobs as OpenAIChoiceLogprobs, # same as chat_completion_chunk ChoiceLogprobs -) -from openai.types.chat.chat_completion_content_part_image_param import ( - ImageURL as OpenAIImageURL, -) -from openai.types.chat.chat_completion_message_tool_call_param import ( - Function as OpenAIFunction, -) -from pydantic import BaseModel +from typing import AsyncGenerator, Dict, Generator, Iterable, List, Optional, Union from llama_stack.apis.common.content_types import ( ImageContentItem, @@ -71,11 +23,14 @@ from llama_stack.apis.inference import ( ChatCompletionResponseEventType, ChatCompletionResponseStreamChunk, CompletionMessage, + CompletionRequest, CompletionResponse, CompletionResponseStreamChunk, + JsonSchemaResponseFormat, Message, SystemMessage, TokenLogProbs, + ToolChoice, ToolResponseMessage, UserMessage, ) @@ -94,6 +49,32 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( decode_assistant_message, ) +from openai import AsyncStream, Stream +from openai.types.chat import ( + ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage, + ChatCompletionChunk as OpenAIChatCompletionChunk, + ChatCompletionContentPartImageParam as OpenAIChatCompletionContentPartImageParam, + ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam, + ChatCompletionContentPartTextParam as OpenAIChatCompletionContentPartTextParam, + ChatCompletionMessageParam as OpenAIChatCompletionMessage, + ChatCompletionMessageToolCall as OpenAIChatCompletionMessageToolCall, + ChatCompletionMessageToolCallParam as OpenAIChatCompletionMessageToolCallParam, + ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage, + ChatCompletionToolMessageParam as OpenAIChatCompletionToolMessage, + ChatCompletionUserMessageParam as OpenAIChatCompletionUserMessage, +) +from openai.types.chat.chat_completion import ( + Choice as OpenAIChoice, + ChoiceLogprobs as OpenAIChoiceLogprobs, # same as chat_completion_chunk ChoiceLogprobs +) +from openai.types.chat.chat_completion_content_part_image_param import ( + ImageURL as OpenAIImageURL, +) +from openai.types.chat.chat_completion_message_tool_call_param import ( + Function as OpenAIFunction, +) +from pydantic import BaseModel + logger = logging.getLogger(__name__) @@ -188,12 +169,16 @@ def convert_openai_completion_logprobs( if logprobs.tokens and logprobs.token_logprobs: return [ TokenLogProbs(logprobs_by_token={token: token_lp}) - for token, token_lp in zip(logprobs.tokens, logprobs.token_logprobs, strict=False) + for token, token_lp in zip( + logprobs.tokens, logprobs.token_logprobs, strict=False + ) ] return None -def convert_openai_completion_logprobs_stream(text: str, logprobs: Optional[Union[float, OpenAICompatLogprobs]]): +def convert_openai_completion_logprobs_stream( + text: str, logprobs: Optional[Union[float, OpenAICompatLogprobs]] +): if 
logprobs is None: return None if isinstance(logprobs, float): @@ -238,7 +223,9 @@ def process_chat_completion_response( if not choice.message or not choice.message.tool_calls: raise ValueError("Tool calls are not present in the response") - tool_calls = [convert_tool_call(tool_call) for tool_call in choice.message.tool_calls] + tool_calls = [ + convert_tool_call(tool_call) for tool_call in choice.message.tool_calls + ] if any(isinstance(tool_call, UnparseableToolCall) for tool_call in tool_calls): # If we couldn't parse a tool call, jsonify the tool calls and return them return ChatCompletionResponse( @@ -262,7 +249,9 @@ def process_chat_completion_response( # TODO: This does not work well with tool calls for vLLM remote provider # Ref: https://github.com/meta-llama/llama-stack/issues/1058 - raw_message = decode_assistant_message(text_from_choice(choice), get_stop_reason(choice.finish_reason)) + raw_message = decode_assistant_message( + text_from_choice(choice), get_stop_reason(choice.finish_reason) + ) # NOTE: If we do not set tools in chat-completion request, we should not # expect the ToolCall in the response. Instead, we should return the raw @@ -463,13 +452,17 @@ async def process_chat_completion_stream_response( ) -async def convert_message_to_openai_dict(message: Message, download: bool = False) -> dict: +async def convert_message_to_openai_dict( + message: Message, download: bool = False +) -> dict: async def _convert_content(content) -> dict: if isinstance(content, ImageContentItem): return { "type": "image_url", "image_url": { - "url": await convert_image_content_to_url(content, download=download), + "url": await convert_image_content_to_url( + content, download=download + ), }, } else: @@ -548,7 +541,9 @@ async def convert_message_to_openai_dict_new( elif isinstance(content_, ImageContentItem): return OpenAIChatCompletionContentPartImageParam( type="image_url", - image_url=OpenAIImageURL(url=await convert_image_content_to_url(content_)), + image_url=OpenAIImageURL( + url=await convert_image_content_to_url(content_) + ), ) elif isinstance(content_, list): return [await impl(item) for item in content_] @@ -574,14 +569,32 @@ async def convert_message_to_openai_dict_new( role="assistant", content=await _convert_message_content(message.content), tool_calls=[ - OpenAIChatCompletionMessageToolCall( - id=tool.call_id, - function=OpenAIFunction( - name=(tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value), - arguments=json.dumps(tool.arguments), - ), - type="function", - ) + # OpenAIChatCompletionMessageToolCall( + # id=tool.call_id, + # function=OpenAIFunction( + # name=( + # tool.tool_name + # if not isinstance(tool.tool_name, BuiltinTool) + # else tool.tool_name.value + # ), + # arguments=json.dumps(tool.arguments), + # ), + # type="function", + # ) + # using a dict instead of OpenAIChatCompletionMessageToolCall object + # as it fails to get json encoded + { + "id": tool.call_id, + "function": { + "name": ( + tool.tool_name + if not isinstance(tool.tool_name, BuiltinTool) + else tool.tool_name.value + ), + "arguments": json.dumps(tool.arguments), + }, + "type": "function", + } for tool in message.tool_calls ] or None, @@ -604,7 +617,7 @@ async def convert_message_to_openai_dict_new( def convert_tool_call( - tool_call: ChatCompletionMessageToolCall, + tool_call: OpenAIChatCompletionMessageToolCall, ) -> Union[ToolCall, UnparseableToolCall]: """ Convert a ChatCompletionMessageToolCall tool call to either a @@ -696,7 +709,11 @@ def 
convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict: properties = parameters["properties"] required = [] for param_name, param in tool.parameters.items(): - properties[param_name] = {"type": PYTHON_TYPE_TO_LITELLM_TYPE.get(param.param_type, param.param_type)} + properties[param_name] = { + "type": PYTHON_TYPE_TO_LITELLM_TYPE.get( + param.param_type, param.param_type + ) + } if param.description: properties[param_name].update(description=param.description) if param.default: @@ -762,15 +779,30 @@ def _convert_openai_tool_calls( if not tool_calls: return [] # CompletionMessage tool_calls is not optional - return [ - ToolCall( - call_id=call.id, - tool_name=call.function.name, - arguments=json.loads(call.function.arguments), - arguments_json=call.function.arguments, + ls_tool_calls = [] + for call in tool_calls: + args = call.function.arguments + # TGI is sending a dict instead of a json string + # While OpenAI spec expects a json string + if isinstance(args, str): + arguments = json.loads(args) + arguments_json = args + elif isinstance(args, dict): + arguments = args + arguments_json = json.dumps(args) + else: + raise ValueError(f"Unsupported arguments type: {type(args)}") + + ls_tool_calls.append( + ToolCall( + call_id=call.id, + tool_name=call.function.name, + arguments=arguments, + arguments_json=arguments_json, + ) ) - for call in tool_calls - ] + + return ls_tool_calls def _convert_openai_logprobs( @@ -802,7 +834,11 @@ def _convert_openai_logprobs( return None return [ - TokenLogProbs(logprobs_by_token={logprobs.token: logprobs.logprob for logprobs in content.top_logprobs}) + TokenLogProbs( + logprobs_by_token={ + logprobs.token: logprobs.logprob for logprobs in content.top_logprobs + } + ) for content in logprobs.content ] @@ -840,14 +876,17 @@ def convert_openai_chat_completion_choice( end_of_message = "end_of_message" out_of_tokens = "out_of_tokens" """ - assert hasattr(choice, "message") and choice.message, "error in server response: message not found" - assert hasattr(choice, "finish_reason") and choice.finish_reason, ( - "error in server response: finish_reason not found" - ) + assert ( + hasattr(choice, "message") and choice.message + ), "error in server response: message not found" + assert ( + hasattr(choice, "finish_reason") and choice.finish_reason + ), "error in server response: finish_reason not found" return ChatCompletionResponse( completion_message=CompletionMessage( - content=choice.message.content or "", # CompletionMessage content is not optional + content=choice.message.content + or "", # CompletionMessage content is not optional stop_reason=_convert_openai_finish_reason(choice.finish_reason), tool_calls=_convert_openai_tool_calls(choice.message.tool_calls), ), @@ -856,13 +895,24 @@ def convert_openai_chat_completion_choice( async def convert_openai_chat_completion_stream( - stream: AsyncStream[OpenAIChatCompletionChunk], + stream: Union[ + AsyncStream[OpenAIChatCompletionChunk], Stream[OpenAIChatCompletionChunk] + ], enable_incremental_tool_calls: bool, ) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]: """ Convert a stream of OpenAI chat completion chunks into a stream of ChatCompletionResponseStreamChunk. 
""" + + async def yield_from_stream(stream): + if isinstance(stream, AsyncGenerator): + async for chunk in stream: + yield chunk + elif isinstance(stream, Generator): + for chunk in stream: + yield chunk + yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( event_type=ChatCompletionResponseEventType.start, @@ -874,7 +924,7 @@ async def convert_openai_chat_completion_stream( stop_reason = None tool_call_idx_to_buffer = {} - async for chunk in stream: + async for chunk in yield_from_stream(stream): choice = chunk.choices[0] # assuming only one choice per chunk # we assume there's only one finish_reason in the stream @@ -916,12 +966,60 @@ async def convert_openai_chat_completion_stream( ) ) else: - for tool_call in choice.delta.tool_calls: - idx = tool_call.index if hasattr(tool_call, "index") else 0 + if isinstance(choice.delta.tool_calls, list): + tool_calls = choice.delta.tool_calls + for tool_call in tool_calls: + idx = tool_call.index if hasattr(tool_call, "index") else 0 + + if idx not in tool_call_idx_to_buffer: + tool_call_idx_to_buffer[idx] = { + "call_id": tool_call.id, + "name": None, + "arguments": "", + "content": "", + } + + buffer = tool_call_idx_to_buffer[idx] + + if tool_call.function: + if tool_call.function.name: + buffer["name"] = tool_call.function.name + delta = f"{buffer['name']}(" + buffer["content"] += delta + + if tool_call.function.arguments: + delta = tool_call.function.arguments + buffer["arguments"] += delta + buffer["content"] += delta + + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=event_type, + delta=ToolCallDelta( + tool_call=delta, + parse_status=ToolCallParseStatus.in_progress, + ), + logprobs=_convert_openai_logprobs(logprobs), + ) + ) + # TGI streams a non-openai compat response + elif isinstance(choice.delta.tool_calls, dict): + # tool_calls is a dict of the format + # { + # 'index': 0, + # 'id': '', + # 'type': 'function', + # 'function': { + # 'name': None, + # 'arguments': '{"' + # } + # } + tool_call = choice.delta.tool_calls + idx = tool_call["index"] if "index" in tool_call else 0 if idx not in tool_call_idx_to_buffer: tool_call_idx_to_buffer[idx] = { - "call_id": tool_call.id, + "call_id": tool_call["id"], "name": None, "arguments": "", "content": "", @@ -929,14 +1027,15 @@ async def convert_openai_chat_completion_stream( buffer = tool_call_idx_to_buffer[idx] - if tool_call.function: - if tool_call.function.name: - buffer["name"] = tool_call.function.name + if "function" in tool_call: + function = tool_call["function"] + if function["name"]: + buffer["name"] = function["name"] delta = f"{buffer['name']}(" buffer["content"] += delta - if tool_call.function.arguments: - delta = tool_call.function.arguments + if function["arguments"]: + delta = function["arguments"] buffer["arguments"] += delta buffer["content"] += delta @@ -994,7 +1093,6 @@ async def convert_openai_chat_completion_stream( ) ) except json.JSONDecodeError as e: - print(f"Failed to parse arguments: {e}") yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( event_type=ChatCompletionResponseEventType.progress, @@ -1005,6 +1103,51 @@ async def convert_openai_chat_completion_stream( stop_reason=stop_reason, ) ) + else: + # name is None but we have the arguments contain the entire function call + # example response where arguments is --> + # '{"function": {"_name": "get_weather", "location": "San Francisco, CA"}}<|eot_id|>' + # - parse the arguments + # - build try to build ToolCall and return it 
or return the content as is + + if buffer["arguments"]: + arguments = buffer["arguments"] + # remove the eot_id and eom_id from the arguments + if arguments.endswith("<|eom_id|>"): + arguments = arguments[: -len("<|eom_id|>")] + if arguments.endswith("<|eot_id|>"): + arguments = arguments[: -len("<|eot_id|>")] + + arguments = json.loads(arguments) + try: + tool_name = arguments["function"].pop("_name", None) + parsed_tool_call = ToolCall( + call_id=buffer["call_id"], + tool_name=tool_name, + arguments=arguments["function"], + arguments_json=json.dumps(arguments["function"]), + ) + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.progress, + delta=ToolCallDelta( + tool_call=parsed_tool_call, + parse_status=ToolCallParseStatus.succeeded, + ), + stop_reason=stop_reason, + ) + ) + except (KeyError, json.JSONDecodeError) as e: + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.progress, + delta=ToolCallDelta( + tool_call=buffer["content"], + parse_status=ToolCallParseStatus.failed, + ), + stop_reason=stop_reason, + ) + ) yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( @@ -1013,3 +1156,113 @@ async def convert_openai_chat_completion_stream( stop_reason=stop_reason, ) ) + + +async def convert_completion_request_to_openai_params( + request: CompletionRequest, +) -> dict: + """ + Convert a ChatCompletionRequest to an OpenAI API-compatible dictionary. + """ + input_dict = {} + if request.logprobs: + input_dict["logprobs"] = request.logprobs.top_k + + return { + "model": request.model, + "prompt": request.content, + **input_dict, + "stream": request.stream, + **get_sampling_options(request.sampling_params), + "n": 1, + } + + +async def convert_chat_completion_request_to_openai_params( + request: ChatCompletionRequest, +) -> dict: + """ + Convert a ChatCompletionRequest to an OpenAI chat completion request. + """ + input_dict = {} + + input_dict["messages"] = [ + await convert_message_to_openai_dict_new(m) for m in request.messages + ] + if fmt := request.response_format: + if not isinstance(fmt, JsonSchemaResponseFormat): + raise ValueError( + f"Unsupported response format: {type(fmt)}. Only JsonSchemaResponseFormat is supported." 
+ ) + + fmt = fmt.json_schema + name = fmt["title"] + del fmt["title"] + fmt["additionalProperties"] = False + + # Apply additionalProperties: False recursively to all objects + fmt = _add_additional_properties_recursive(fmt) + + from rich.pretty import pprint + + pprint(fmt) + + input_dict["response_format"] = { + "type": "json_schema", + "json_schema": { + "name": name, + "schema": fmt, + "strict": True, + }, + } + + if request.tools: + input_dict["tools"] = [ + convert_tooldef_to_openai_tool(tool) for tool in request.tools + ] + if request.tool_config.tool_choice: + input_dict["tool_choice"] = ( + request.tool_config.tool_choice.value + if isinstance(request.tool_config.tool_choice, ToolChoice) + else request.tool_config.tool_choice + ) + + return { + "model": request.model, + **input_dict, + "stream": request.stream, + **get_sampling_options(request.sampling_params), + "n": 1, + } + + +def _add_additional_properties_recursive(schema): + """ + Recursively add `additionalProperties: False` to all object schemas + """ + if isinstance(schema, dict): + if schema.get("type") == "object": + schema["additionalProperties"] = False + + # Add required field with all property keys if properties exist + if "properties" in schema and schema["properties"]: + schema["required"] = list(schema["properties"].keys()) + + if "properties" in schema: + for prop_schema in schema["properties"].values(): + _add_additional_properties_recursive(prop_schema) + + for key in ["anyOf", "allOf", "oneOf"]: + if key in schema: + for sub_schema in schema[key]: + _add_additional_properties_recursive(sub_schema) + + if "not" in schema: + _add_additional_properties_recursive(schema["not"]) + + # Handle $defs/$ref + if "$defs" in schema: + for def_schema in schema["$defs"].values(): + _add_additional_properties_recursive(def_schema) + + return schema diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py index 1edf445c0..f0d8f2b72 100644 --- a/llama_stack/providers/utils/inference/prompt_adapter.py +++ b/llama_stack/providers/utils/inference/prompt_adapter.py @@ -12,7 +12,6 @@ import re from typing import List, Optional, Tuple, Union import httpx -from PIL import Image as PIL_Image from llama_stack.apis.common.content_types import ( ImageContentItem, @@ -34,6 +33,7 @@ from llama_stack.apis.inference import ( ) from llama_stack.log import get_logger from llama_stack.models.llama.datatypes import ( + is_multimodal, ModelFamily, RawContent, RawContentItem, @@ -43,7 +43,6 @@ from llama_stack.models.llama.datatypes import ( Role, StopReason, ToolPromptFormat, - is_multimodal, ) from llama_stack.models.llama.llama3.chat_format import ChatFormat from llama_stack.models.llama.llama3.prompt_templates import ( @@ -56,6 +55,7 @@ from llama_stack.models.llama.llama3.prompt_templates import ( from llama_stack.models.llama.llama3.tokenizer import Tokenizer from llama_stack.models.llama.sku_list import resolve_model from llama_stack.providers.utils.inference import supported_inference_models +from PIL import Image as PIL_Image log = get_logger(name=__name__, category="inference") @@ -129,7 +129,9 @@ async def interleaved_content_convert_to_raw( if image.url.uri.startswith("data"): match = re.match(r"data:image/(\w+);base64,(.+)", image.url.uri) if not match: - raise ValueError(f"Invalid data URL format, {image.url.uri[:40]}...") + raise ValueError( + f"Invalid data URL format, {image.url.uri[:40]}..." 
+ ) _, image_data = match.groups() data = base64.b64decode(image_data) elif image.url.uri.startswith("file://"): @@ -209,13 +211,17 @@ async def convert_image_content_to_url( content, format = await localize_image_content(media) if include_format: - return f"data:image/{format};base64," + base64.b64encode(content).decode("utf-8") + return f"data:image/{format};base64," + base64.b64encode(content).decode( + "utf-8" + ) else: return base64.b64encode(content).decode("utf-8") async def completion_request_to_prompt(request: CompletionRequest) -> str: - content = augment_content_with_response_format_prompt(request.response_format, request.content) + content = augment_content_with_response_format_prompt( + request.response_format, request.content + ) request.content = content request = await convert_request_to_raw(request) @@ -224,8 +230,12 @@ async def completion_request_to_prompt(request: CompletionRequest) -> str: return formatter.tokenizer.decode(model_input.tokens) -async def completion_request_to_prompt_model_input_info(request: CompletionRequest) -> Tuple[str, int]: - content = augment_content_with_response_format_prompt(request.response_format, request.content) +async def completion_request_to_prompt_model_input_info( + request: CompletionRequest, +) -> Tuple[str, int]: + content = augment_content_with_response_format_prompt( + request.response_format, request.content + ) request.content = content request = await convert_request_to_raw(request) @@ -246,7 +256,9 @@ def augment_content_with_response_format_prompt(response_format, content): return content -async def chat_completion_request_to_prompt(request: ChatCompletionRequest, llama_model: str) -> str: +async def chat_completion_request_to_prompt( + request: ChatCompletionRequest, llama_model: str +) -> str: messages = chat_completion_request_to_messages(request, llama_model) request.messages = messages request = await convert_request_to_raw(request) @@ -254,7 +266,8 @@ async def chat_completion_request_to_prompt(request: ChatCompletionRequest, llam formatter = ChatFormat(tokenizer=Tokenizer.get_instance()) model_input = formatter.encode_dialog_prompt( request.messages, - tool_prompt_format=request.tool_config.tool_prompt_format or get_default_tool_prompt_format(llama_model), + tool_prompt_format=request.tool_config.tool_prompt_format + or get_default_tool_prompt_format(llama_model), ) return formatter.tokenizer.decode(model_input.tokens) @@ -269,10 +282,17 @@ async def chat_completion_request_to_model_input_info( formatter = ChatFormat(tokenizer=Tokenizer.get_instance()) model_input = formatter.encode_dialog_prompt( request.messages, - tool_prompt_format=request.tool_config.tool_prompt_format or get_default_tool_prompt_format(llama_model), + tool_prompt_format=request.tool_config.tool_prompt_format + or get_default_tool_prompt_format(llama_model), ) + tokens = [] + for t in model_input.tokens: + if t == 128256: + tokens.append(formatter.vision_token) + else: + tokens.append(t) return ( - formatter.tokenizer.decode(model_input.tokens), + formatter.tokenizer.decode(tokens), len(model_input.tokens), ) @@ -298,7 +318,8 @@ def chat_completion_request_to_messages( return request.messages if model.model_family == ModelFamily.llama3_1 or ( - model.model_family == ModelFamily.llama3_2 and is_multimodal(model.core_model_id) + model.model_family == ModelFamily.llama3_2 + and is_multimodal(model.core_model_id) ): # llama3.1 and llama3.2 multimodal models follow the same tool prompt format messages = augment_messages_for_tools_llama_3_1(request) @@ 
-334,7 +355,9 @@ def augment_messages_for_tools_llama_3_1( if existing_messages[0].role == Role.system.value: existing_system_message = existing_messages.pop(0) - assert existing_messages[0].role != Role.system.value, "Should only have 1 system message" + assert ( + existing_messages[0].role != Role.system.value + ), "Should only have 1 system message" messages = [] @@ -366,9 +389,13 @@ def augment_messages_for_tools_llama_3_1( if isinstance(existing_system_message.content, str): sys_content += _process(existing_system_message.content) elif isinstance(existing_system_message.content, list): - sys_content += "\n".join([_process(c) for c in existing_system_message.content]) + sys_content += "\n".join( + [_process(c) for c in existing_system_message.content] + ) - tool_choice_prompt = _get_tool_choice_prompt(request.tool_config.tool_choice, request.tools) + tool_choice_prompt = _get_tool_choice_prompt( + request.tool_config.tool_choice, request.tools + ) if tool_choice_prompt: sys_content += "\n" + tool_choice_prompt @@ -402,7 +429,9 @@ def augment_messages_for_tools_llama_3_2( if existing_messages[0].role == Role.system.value: existing_system_message = existing_messages.pop(0) - assert existing_messages[0].role != Role.system.value, "Should only have 1 system message" + assert ( + existing_messages[0].role != Role.system.value + ), "Should only have 1 system message" sys_content = "" custom_tools, builtin_tools = [], [] @@ -423,10 +452,16 @@ def augment_messages_for_tools_llama_3_2( if custom_tools: fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.python_list if fmt != ToolPromptFormat.python_list: - raise ValueError(f"Non supported ToolPromptFormat {request.tool_config.tool_prompt_format}") + raise ValueError( + f"Non supported ToolPromptFormat {request.tool_config.tool_prompt_format}" + ) system_prompt = None - if existing_system_message and request.tool_config.system_message_behavior == SystemMessageBehavior.replace: + if ( + existing_system_message + and request.tool_config.system_message_behavior + == SystemMessageBehavior.replace + ): system_prompt = existing_system_message.content tool_template = PythonListCustomToolGenerator().gen(custom_tools, system_prompt) @@ -435,11 +470,16 @@ def augment_messages_for_tools_llama_3_2( sys_content += "\n" if existing_system_message and ( - request.tool_config.system_message_behavior == SystemMessageBehavior.append or not custom_tools + request.tool_config.system_message_behavior == SystemMessageBehavior.append + or not custom_tools ): - sys_content += interleaved_content_as_str(existing_system_message.content, sep="\n") + sys_content += interleaved_content_as_str( + existing_system_message.content, sep="\n" + ) - tool_choice_prompt = _get_tool_choice_prompt(request.tool_config.tool_choice, request.tools) + tool_choice_prompt = _get_tool_choice_prompt( + request.tool_config.tool_choice, request.tools + ) if tool_choice_prompt: sys_content += "\n" + tool_choice_prompt @@ -447,11 +487,15 @@ def augment_messages_for_tools_llama_3_2( return messages -def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: List[ToolDefinition]) -> str: +def _get_tool_choice_prompt( + tool_choice: ToolChoice | str, tools: List[ToolDefinition] +) -> str: if tool_choice == ToolChoice.auto: return "" elif tool_choice == ToolChoice.required: - return "You MUST use one of the provided functions/tools to answer the user query." + return ( + "You MUST use one of the provided functions/tools to answer the user query." 
+ ) elif tool_choice == ToolChoice.none: # tools are already not passed in return "" @@ -463,11 +507,14 @@ def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: List[ToolDefin def get_default_tool_prompt_format(model: str) -> ToolPromptFormat: llama_model = resolve_model(model) if llama_model is None: - log.warning(f"Could not resolve model {model}, defaulting to json tool prompt format") + log.warning( + f"Could not resolve model {model}, defaulting to json tool prompt format" + ) return ToolPromptFormat.json if llama_model.model_family == ModelFamily.llama3_1 or ( - llama_model.model_family == ModelFamily.llama3_2 and is_multimodal(llama_model.core_model_id) + llama_model.model_family == ModelFamily.llama3_2 + and is_multimodal(llama_model.core_model_id) ): # llama3.1 and llama3.2 multimodal models follow the same tool prompt format return ToolPromptFormat.json diff --git a/tests/integration/fixtures/common.py b/tests/integration/fixtures/common.py index 1878c9e88..477e24ae5 100644 --- a/tests/integration/fixtures/common.py +++ b/tests/integration/fixtures/common.py @@ -120,13 +120,16 @@ def client_with_models( judge_model_id, ): client = llama_stack_client + from rich.pretty import pprint providers = [p for p in client.providers.list() if p.api == "inference"] + pprint(providers) assert len(providers) > 0, "No inference providers found" inference_providers = [p.provider_id for p in providers if p.provider_type != "inline::sentence-transformers"] model_ids = {m.identifier for m in client.models.list()} model_ids.update(m.provider_resource_id for m in client.models.list()) + pprint(model_ids) if text_model_id and text_model_id not in model_ids: client.models.register(model_id=text_model_id, provider_id=inference_providers[0]) diff --git a/tests/integration/inference/test_text_inference.py b/tests/integration/inference/test_text_inference.py index f558254e5..83c86aaee 100644 --- a/tests/integration/inference/test_text_inference.py +++ b/tests/integration/inference/test_text_inference.py @@ -8,9 +8,9 @@ import os import pytest -from pydantic import BaseModel from llama_stack.models.llama.sku_list import resolve_model +from pydantic import BaseModel from ..test_cases.test_case import TestCase @@ -23,8 +23,15 @@ def skip_if_model_doesnt_support_completion(client_with_models, model_id): provider_id = models[model_id].provider_id providers = {p.provider_id: p for p in client_with_models.providers.list()} provider = providers[provider_id] - if provider.provider_type in ("remote::openai", "remote::anthropic", "remote::gemini", "remote::groq"): - pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support completion") + if provider.provider_type in ( + "remote::openai", + "remote::anthropic", + "remote::gemini", + "remote::groq", + ): + pytest.skip( + f"Model {model_id} hosted by {provider.provider_type} doesn't support completion" + ) def get_llama_model(client_with_models, model_id): @@ -105,7 +112,9 @@ def test_text_completion_streaming(client_with_models, text_model_id, test_case) "inference:completion:stop_sequence", ], ) -def test_text_completion_stop_sequence(client_with_models, text_model_id, inference_provider_type, test_case): +def test_text_completion_stop_sequence( + client_with_models, text_model_id, inference_provider_type, test_case +): skip_if_model_doesnt_support_completion(client_with_models, text_model_id) # This is only supported/tested for remote vLLM: https://github.com/meta-llama/llama-stack/issues/1771 if inference_provider_type != 
"remote::vllm": @@ -132,7 +141,9 @@ def test_text_completion_stop_sequence(client_with_models, text_model_id, infere "inference:completion:log_probs", ], ) -def test_text_completion_log_probs_non_streaming(client_with_models, text_model_id, inference_provider_type, test_case): +def test_text_completion_log_probs_non_streaming( + client_with_models, text_model_id, inference_provider_type, test_case +): skip_if_model_doesnt_support_completion(client_with_models, text_model_id) if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K: pytest.xfail(f"{inference_provider_type} doesn't support log probs yet") @@ -151,7 +162,9 @@ def test_text_completion_log_probs_non_streaming(client_with_models, text_model_ }, ) assert response.logprobs, "Logprobs should not be empty" - assert 1 <= len(response.logprobs) <= 5 # each token has 1 logprob and here max_tokens=5 + assert ( + 1 <= len(response.logprobs) <= 5 + ) # each token has 1 logprob and here max_tokens=5 assert all(len(logprob.logprobs_by_token) == 1 for logprob in response.logprobs) @@ -161,7 +174,9 @@ def test_text_completion_log_probs_non_streaming(client_with_models, text_model_ "inference:completion:log_probs", ], ) -def test_text_completion_log_probs_streaming(client_with_models, text_model_id, inference_provider_type, test_case): +def test_text_completion_log_probs_streaming( + client_with_models, text_model_id, inference_provider_type, test_case +): skip_if_model_doesnt_support_completion(client_with_models, text_model_id) if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K: pytest.xfail(f"{inference_provider_type} doesn't support log probs yet") @@ -183,7 +198,9 @@ def test_text_completion_log_probs_streaming(client_with_models, text_model_id, for chunk in streamed_content: if chunk.delta: # if there's a token, we expect logprobs assert chunk.logprobs, "Logprobs should not be empty" - assert all(len(logprob.logprobs_by_token) == 1 for logprob in chunk.logprobs) + assert all( + len(logprob.logprobs_by_token) == 1 for logprob in chunk.logprobs + ) else: # no token, no logprobs assert not chunk.logprobs, "Logprobs should be empty" @@ -194,7 +211,13 @@ def test_text_completion_log_probs_streaming(client_with_models, text_model_id, "inference:completion:structured_output", ], ) -def test_text_completion_structured_output(client_with_models, text_model_id, test_case): +def test_text_completion_structured_output( + client_with_models, text_model_id, test_case, inference_provider_type +): + if inference_provider_type == "remote::tgi": + pytest.xfail( + f"{inference_provider_type} doesn't support structured outputs yet" + ) skip_if_model_doesnt_support_completion(client_with_models, text_model_id) class AnswerFormat(BaseModel): @@ -231,7 +254,9 @@ def test_text_completion_structured_output(client_with_models, text_model_id, te "inference:chat_completion:non_streaming_02", ], ) -def test_text_chat_completion_non_streaming(client_with_models, text_model_id, test_case): +def test_text_chat_completion_non_streaming( + client_with_models, text_model_id, test_case +): tc = TestCase(test_case) question = tc["question"] expected = tc["expected"] @@ -257,14 +282,17 @@ def test_text_chat_completion_non_streaming(client_with_models, text_model_id, t "inference:chat_completion:ttft", ], ) -def test_text_chat_completion_first_token_profiling(client_with_models, text_model_id, test_case): +def test_text_chat_completion_first_token_profiling( + client_with_models, text_model_id, test_case +): tc = TestCase(test_case) messages = tc["messages"] - if 
os.environ.get("DEBUG_TTFT"): # debugging print number of tokens in input, ideally around 800 - from pydantic import TypeAdapter - + if os.environ.get( + "DEBUG_TTFT" + ): # debugging print number of tokens in input, ideally around 800 from llama_stack.apis.inference import Message + from pydantic import TypeAdapter tokenizer, formatter = get_llama_tokenizer() typed_messages = [TypeAdapter(Message).validate_python(m) for m in messages] @@ -279,7 +307,9 @@ def test_text_chat_completion_first_token_profiling(client_with_models, text_mod message_content = response.completion_message.content.lower().strip() assert len(message_content) > 0 - if os.environ.get("DEBUG_TTFT"): # debugging print number of tokens in response, ideally around 150 + if os.environ.get( + "DEBUG_TTFT" + ): # debugging print number of tokens in response, ideally around 150 tokenizer, formatter = get_llama_tokenizer() encoded = formatter.encode_content(message_content) raise ValueError(len(encoded.tokens) if encoded and encoded.tokens else 0) @@ -302,7 +332,9 @@ def test_text_chat_completion_streaming(client_with_models, text_model_id, test_ messages=[{"role": "user", "content": question}], stream=True, ) - streamed_content = [str(chunk.event.delta.text.lower().strip()) for chunk in response] + streamed_content = [ + str(chunk.event.delta.text.lower().strip()) for chunk in response + ] assert len(streamed_content) > 0 assert expected.lower() in "".join(streamed_content) @@ -313,7 +345,9 @@ def test_text_chat_completion_streaming(client_with_models, text_model_id, test_ "inference:chat_completion:tool_calling", ], ) -def test_text_chat_completion_with_tool_calling_and_non_streaming(client_with_models, text_model_id, test_case): +def test_text_chat_completion_with_tool_calling_and_non_streaming( + client_with_models, text_model_id, test_case +): tc = TestCase(test_case) response = client_with_models.inference.chat_completion( @@ -327,7 +361,10 @@ def test_text_chat_completion_with_tool_calling_and_non_streaming(client_with_mo assert response.completion_message.role == "assistant" assert len(response.completion_message.tool_calls) == 1 - assert response.completion_message.tool_calls[0].tool_name == tc["tools"][0]["tool_name"] + assert ( + response.completion_message.tool_calls[0].tool_name + == tc["tools"][0]["tool_name"] + ) assert response.completion_message.tool_calls[0].arguments == tc["expected"] @@ -350,7 +387,9 @@ def extract_tool_invocation_content(response): "inference:chat_completion:tool_calling", ], ) -def test_text_chat_completion_with_tool_calling_and_streaming(client_with_models, text_model_id, test_case): +def test_text_chat_completion_with_tool_calling_and_streaming( + client_with_models, text_model_id, test_case +): tc = TestCase(test_case) response = client_with_models.inference.chat_completion( @@ -372,7 +411,14 @@ def test_text_chat_completion_with_tool_calling_and_streaming(client_with_models "inference:chat_completion:tool_calling", ], ) -def test_text_chat_completion_with_tool_choice_required(client_with_models, text_model_id, test_case): +def test_text_chat_completion_with_tool_choice_required( + client_with_models, text_model_id, test_case, inference_provider_type +): + if inference_provider_type == "remote::tgi": + pytest.xfail( + f"{inference_provider_type} doesn't support tool_choice 'required' parameter yet" + ) + tc = TestCase(test_case) response = client_with_models.inference.chat_completion( @@ -396,7 +442,9 @@ def test_text_chat_completion_with_tool_choice_required(client_with_models, text 
"inference:chat_completion:tool_calling", ], ) -def test_text_chat_completion_with_tool_choice_none(client_with_models, text_model_id, test_case): +def test_text_chat_completion_with_tool_choice_none( + client_with_models, text_model_id, test_case +): tc = TestCase(test_case) response = client_with_models.inference.chat_completion( @@ -416,7 +464,14 @@ def test_text_chat_completion_with_tool_choice_none(client_with_models, text_mod "inference:chat_completion:structured_output", ], ) -def test_text_chat_completion_structured_output(client_with_models, text_model_id, test_case): +def test_text_chat_completion_structured_output( + client_with_models, text_model_id, test_case, inference_provider_type +): + if inference_provider_type == "remote::tgi": + pytest.xfail( + f"{inference_provider_type} doesn't support structured outputs yet" + ) + class NBAStats(BaseModel): year_for_draft: int num_seasons_in_nba: int diff --git a/tests/integration/inference/test_vision_inference.py b/tests/integration/inference/test_vision_inference.py index 9f6fb0478..7f0f935a7 100644 --- a/tests/integration/inference/test_vision_inference.py +++ b/tests/integration/inference/test_vision_inference.py @@ -27,7 +27,9 @@ def base64_image_url(base64_image_data, image_path): return f"data:image/{image_path.suffix[1:]};base64,{base64_image_data}" -@pytest.mark.xfail(reason="This test is failing because the image is not being downloaded correctly.") +# @pytest.mark.xfail( +# reason="This test is failing because the image is not being downloaded correctly." +# ) def test_image_chat_completion_non_streaming(client_with_models, vision_model_id): message = { "role": "user", @@ -56,7 +58,9 @@ def test_image_chat_completion_non_streaming(client_with_models, vision_model_id assert any(expected in message_content for expected in {"dog", "puppy", "pup"}) -@pytest.mark.xfail(reason="This test is failing because the image is not being downloaded correctly.") +# @pytest.mark.xfail( +# reason="This test is failing because the image is not being downloaded correctly." +# ) def test_image_chat_completion_streaming(client_with_models, vision_model_id): message = { "role": "user", @@ -87,8 +91,10 @@ def test_image_chat_completion_streaming(client_with_models, vision_model_id): assert any(expected in streamed_content for expected in {"dog", "puppy", "pup"}) -@pytest.mark.parametrize("type_", ["url", "data"]) -def test_image_chat_completion_base64(client_with_models, vision_model_id, base64_image_data, base64_image_url, type_): +@pytest.mark.parametrize("type_", ["url"]) +def test_image_chat_completion_base64( + client_with_models, vision_model_id, base64_image_data, base64_image_url, type_ +): image_spec = { "url": { "type": "image",