fixes and linting

Hardik Shah 2025-03-28 18:33:36 -07:00
parent 021dd0d35d
commit 5251d2422d
8 changed files with 149 additions and 345 deletions

View file

@@ -56,9 +56,7 @@ from .models import MODEL_ENTRIES
 logger = get_logger(name=__name__, category="inference")

-class FireworksInferenceAdapter(
-    ModelRegistryHelper, Inference, NeedsRequestProviderData
-):
+class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
     def __init__(self, config: FireworksImplConfig) -> None:
         ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
         self.config = config
@@ -70,9 +68,7 @@ class FireworksInferenceAdapter(
         pass

     def _get_api_key(self) -> str:
-        config_api_key = (
-            self.config.api_key.get_secret_value() if self.config.api_key else None
-        )
+        config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
         if config_api_key:
             return config_api_key
         else:
@@ -112,9 +108,7 @@ class FireworksInferenceAdapter(
         else:
             return await self._nonstream_completion(request)

-    async def _nonstream_completion(
-        self, request: CompletionRequest
-    ) -> CompletionResponse:
+    async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
         params = await self._get_params(request)
         r = await self._get_client().completion.acreate(**params)
         return process_completion_response(r)
@@ -194,9 +188,7 @@ class FireworksInferenceAdapter(
         else:
             return await self._nonstream_chat_completion(request)

-    async def _nonstream_chat_completion(
-        self, request: ChatCompletionRequest
-    ) -> ChatCompletionResponse:
+    async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
         params = await self._get_params(request)
         if "messages" in params:
             r = await self._get_client().chat.completions.acreate(**params)
@@ -204,9 +196,7 @@ class FireworksInferenceAdapter(
             r = await self._get_client().completion.acreate(**params)
         return process_chat_completion_response(r, request)

-    async def _stream_chat_completion(
-        self, request: ChatCompletionRequest
-    ) -> AsyncGenerator:
+    async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)

         async def _to_async_generator():
@@ -221,9 +211,7 @@ class FireworksInferenceAdapter(
         async for chunk in process_chat_completion_stream_response(stream, request):
             yield chunk

-    async def _get_params(
-        self, request: Union[ChatCompletionRequest, CompletionRequest]
-    ) -> dict:
+    async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
         input_dict = {}
         media_present = request_has_media(request)

@@ -231,17 +219,12 @@ class FireworksInferenceAdapter(
         if isinstance(request, ChatCompletionRequest):
             if media_present or not llama_model:
                 input_dict["messages"] = [
-                    await convert_message_to_openai_dict(m, download=True)
-                    for m in request.messages
+                    await convert_message_to_openai_dict(m, download=True) for m in request.messages
                 ]
             else:
-                input_dict["prompt"] = await chat_completion_request_to_prompt(
-                    request, llama_model
-                )
+                input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
         else:
-            assert (
-                not media_present
-            ), "Fireworks does not support media for Completion requests"
+            assert not media_present, "Fireworks does not support media for Completion requests"
             input_dict["prompt"] = await completion_request_to_prompt(request)

         # Fireworks always prepends with BOS
@@ -253,9 +236,7 @@ class FireworksInferenceAdapter(
             "model": request.model,
             **input_dict,
             "stream": request.stream,
-            **self._build_options(
-                request.sampling_params, request.response_format, request.logprobs
-            ),
+            **self._build_options(request.sampling_params, request.response_format, request.logprobs),
         }
         logger.debug(f"params to fireworks: {params}")
@@ -274,9 +255,9 @@ class FireworksInferenceAdapter(
         kwargs = {}
         if model.metadata.get("embedding_dimension"):
             kwargs["dimensions"] = model.metadata.get("embedding_dimension")
-        assert all(
-            not content_has_media(content) for content in contents
-        ), "Fireworks does not support media for embeddings"
+        assert all(not content_has_media(content) for content in contents), (
+            "Fireworks does not support media for embeddings"
+        )
         response = self._get_client().embeddings.create(
             model=model.provider_resource_id,
             input=[interleaved_content_as_str(content) for content in contents],
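
One detail worth noting in the embeddings hunk above: the reformatted assert wraps only the failure message in parentheses, not the condition and message together, so the check still behaves as a normal two-part assert. A small illustrative sketch (the items variable is made up for the example):

    items = []

    # Parenthesizing only the message keeps this a normal two-part assert:
    # the condition is still evaluated, and the string is only the failure message.
    assert len(items) == 0, (
        "expected no items"
    )

    # Wrapping condition and message together would instead assert on a 2-tuple,
    # which is always truthy, so the check could never fire:
    #     assert (len(items) == 0, "expected no items")  # CPython warns: assertion is always true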

View file

@@ -5,10 +5,9 @@
 # the root directory of this source tree.

-import logging
 from typing import AsyncGenerator, List, Optional

-from huggingface_hub import AsyncInferenceClient, HfApi, InferenceClient
+from huggingface_hub import AsyncInferenceClient, HfApi

 from llama_stack.apis.common.content_types import (
     InterleavedContent,
@@ -16,12 +15,10 @@ from llama_stack.apis.common.content_types import (
 )
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
-    ChatCompletionResponse,
     CompletionRequest,
     EmbeddingsResponse,
     EmbeddingTaskType,
     Inference,
-    JsonSchemaResponseFormat,
     LogProbConfig,
     Message,
     ResponseFormat,
@@ -38,26 +35,20 @@ from llama_stack.log import get_logger
 from llama_stack.models.llama.sku_list import all_registered_models
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.model_registry import (
-    build_hf_repo_model_entry,
     ModelRegistryHelper,
+    build_hf_repo_model_entry,
 )
 from llama_stack.providers.utils.inference.openai_compat import (
-    convert_chat_completion_request_to_openai_params,
-    convert_completion_request_to_openai_params,
-    convert_message_to_openai_dict_new,
-    convert_openai_chat_completion_choice,
-    convert_openai_chat_completion_stream,
-    convert_tooldef_to_openai_tool,
-    get_sampling_options,
     OpenAICompatCompletionChoice,
     OpenAICompatCompletionResponse,
-    process_chat_completion_response,
-    process_chat_completion_stream_response,
+    convert_chat_completion_request_to_openai_params,
+    convert_openai_chat_completion_choice,
+    convert_openai_chat_completion_stream,
+    get_sampling_options,
     process_completion_response,
     process_completion_stream_response,
 )
 from llama_stack.providers.utils.inference.prompt_adapter import (
-    chat_completion_request_to_model_input_info,
     completion_request_to_prompt_model_input_info,
 )
@@ -85,9 +76,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
     def __init__(self) -> None:
         self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
         self.huggingface_repo_to_llama_model_id = {
-            model.huggingface_repo: model.descriptor()
-            for model in all_registered_models()
-            if model.huggingface_repo
+            model.huggingface_repo: model.descriptor() for model in all_registered_models() if model.huggingface_repo
         }

     async def shutdown(self) -> None:
@@ -114,7 +103,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
         if response_format:
-            raise ValueError(f"TGI does not support Response Format for completions.")
+            raise ValueError("TGI does not support Response Format for completions.")
         if sampling_params is None:
             sampling_params = SamplingParams()
@@ -166,17 +155,13 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         return options

     async def _get_params_for_completion(self, request: CompletionRequest) -> dict:
-        prompt, input_tokens = await completion_request_to_prompt_model_input_info(
-            request
-        )
+        prompt, input_tokens = await completion_request_to_prompt_model_input_info(request)

         return dict(
             prompt=prompt,
             stream=request.stream,
             details=True,
-            max_new_tokens=self._get_max_new_tokens(
-                request.sampling_params, input_tokens
-            ),
+            max_new_tokens=self._get_max_new_tokens(request.sampling_params, input_tokens),
             stop_sequences=["<|eom_id|>", "<|eot_id|>"],
             **self._build_options(request.sampling_params, request.response_format),
         )
@@ -185,16 +170,14 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         params = await self._get_params_for_completion(request)

         async def _generate_and_convert_to_openai_compat():
-            s = self.client.text_generation(**params)
-            for chunk in s:
+            s = await self.client.text_generation(**params)
+            async for chunk in s:
                 token_result = chunk.token
                 finish_reason = None
                 if chunk.details:
                     finish_reason = chunk.details.finish_reason

-                choice = OpenAICompatCompletionChoice(
-                    text=token_result.text, finish_reason=finish_reason
-                )
+                choice = OpenAICompatCompletionChoice(text=token_result.text, finish_reason=finish_reason)
                 yield OpenAICompatCompletionResponse(
                     choices=[choice],
                 )
@@ -205,7 +188,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
     async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator:
         params = await self._get_params_for_completion(request)
-        r = self.client.text_generation(**params)
+        r = await self.client.text_generation(**params)

         choice = OpenAICompatCompletionChoice(
             finish_reason=r.details.finish_reason,
@@ -234,9 +217,6 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         if sampling_params is None:
             sampling_params = SamplingParams()
         model = await self.model_store.get_model(model_id)
-        from rich.pretty import pprint
-
-        pprint(messages)
         request = ChatCompletionRequest(
             model=model.provider_resource_id,
             messages=messages,
@@ -250,18 +230,9 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         params = await convert_chat_completion_request_to_openai_params(request)

-        import json
-
-        # print(json.dumps(params, indent=2))
-        pprint(params)
-        response = self.client.chat.completions.create(**params)
+        response = await self.client.chat.completions.create(**params)
         if stream:
-            return convert_openai_chat_completion_stream(
-                response, enable_incremental_tool_calls=True
-            )
+            return convert_openai_chat_completion_stream(response, enable_incremental_tool_calls=True)
         else:
             return convert_openai_chat_completion_choice(response.choices[0])
@@ -281,18 +252,16 @@ class TGIAdapter(_HfAdapter):
         logger.info(f"Initializing TGI client with url={config.url}")
-        # unfortunately, the TGI async client does not work well with proxies
-        # so using sync client for now instead
-        self.client = InferenceClient(model=f"{config.url}")
-        endpoint_info = self.client.get_endpoint_info()
+        self.client = AsyncInferenceClient(model=f"{config.url}")
+        endpoint_info = await self.client.get_endpoint_info()
         self.max_tokens = endpoint_info["max_total_tokens"]
         self.model_id = endpoint_info["model_id"]


 class InferenceAPIAdapter(_HfAdapter):
     async def initialize(self, config: InferenceAPIImplConfig) -> None:
-        self.client = AsyncInferenceClient(
-            model=config.huggingface_repo, token=config.api_token.get_secret_value()
-        )
+        self.client = AsyncInferenceClient(model=config.huggingface_repo, token=config.api_token.get_secret_value())
         endpoint_info = await self.client.get_endpoint_info()
         self.max_tokens = endpoint_info["max_total_tokens"]
         self.model_id = endpoint_info["model_id"]
@@ -310,6 +279,4 @@ class InferenceEndpointAdapter(_HfAdapter):
         # Initialize the adapter
         self.client = endpoint.async_client
         self.model_id = endpoint.repository
-        self.max_tokens = int(
-            endpoint.raw["model"]["image"]["custom"]["env"]["MAX_TOTAL_TOKENS"]
-        )
+        self.max_tokens = int(endpoint.raw["model"]["image"]["custom"]["env"]["MAX_TOTAL_TOKENS"])
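
The TGIAdapter hunks above switch back from the synchronous InferenceClient to huggingface_hub's AsyncInferenceClient, so text_generation and get_endpoint_info are awaited and streaming output is consumed with async for. A minimal usage sketch of that pattern, assuming a TGI server is reachable at the placeholder URL below:

    import asyncio

    from huggingface_hub import AsyncInferenceClient


    async def main() -> None:
        # Placeholder endpoint; substitute a real TGI server URL.
        client = AsyncInferenceClient(model="http://localhost:8080")

        # Non-streaming: the call itself must now be awaited.
        out = await client.text_generation("Hello", max_new_tokens=16, details=True)
        print(out.generated_text)

        # Streaming: awaiting with stream=True returns an async iterator of chunks,
        # each carrying a .token and, on the final chunk, .details.finish_reason.
        stream = await client.text_generation("Hello", max_new_tokens=16, details=True, stream=True)
        async for chunk in stream:
            print(chunk.token.text, end="")


    asyncio.run(main())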

View file

@@ -19,7 +19,6 @@ from llama_stack.apis.inference import (
     EmbeddingsResponse,
     EmbeddingTaskType,
     Inference,
-    JsonSchemaResponseFormat,
     LogProbConfig,
     Message,
     ResponseFormat,
@@ -36,11 +35,8 @@ from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 from llama_stack.providers.utils.inference.openai_compat import (
     convert_chat_completion_request_to_openai_params,
-    convert_message_to_openai_dict_new,
     convert_openai_chat_completion_choice,
     convert_openai_chat_completion_stream,
-    convert_tooldef_to_openai_tool,
-    get_sampling_options,
 )
 from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
@@ -54,9 +50,7 @@ class LiteLLMOpenAIMixin(
     Inference,
     NeedsRequestProviderData,
 ):
-    def __init__(
-        self, model_entries, api_key_from_config: str, provider_data_api_key_field: str
-    ):
+    def __init__(self, model_entries, api_key_from_config: str, provider_data_api_key_field: str):
         ModelRegistryHelper.__init__(self, model_entries)
         self.api_key_from_config = api_key_from_config
         self.provider_data_api_key_field = provider_data_api_key_field
@@ -96,9 +90,7 @@ class LiteLLMOpenAIMixin(
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
-    ) -> Union[
-        ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
-    ]:
+    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
         if sampling_params is None:
             sampling_params = SamplingParams()
         model = await self.model_store.get_model(model_id)

View file

@@ -6,7 +6,49 @@
 import json
 import logging
 import warnings
-from typing import AsyncGenerator, Dict, Generator, Iterable, List, Optional, Union
+from typing import AsyncGenerator, Dict, Iterable, List, Optional, Union
+
+from openai import AsyncStream
+from openai.types.chat import (
+    ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage,
+)
+from openai.types.chat import (
+    ChatCompletionChunk as OpenAIChatCompletionChunk,
+)
+from openai.types.chat import (
+    ChatCompletionContentPartImageParam as OpenAIChatCompletionContentPartImageParam,
+)
+from openai.types.chat import (
+    ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam,
+)
+from openai.types.chat import (
+    ChatCompletionContentPartTextParam as OpenAIChatCompletionContentPartTextParam,
+)
+from openai.types.chat import (
+    ChatCompletionMessageParam as OpenAIChatCompletionMessage,
+)
+from openai.types.chat import (
+    ChatCompletionMessageToolCall as OpenAIChatCompletionMessageToolCall,
+)
+from openai.types.chat import (
+    ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage,
+)
+from openai.types.chat import (
+    ChatCompletionToolMessageParam as OpenAIChatCompletionToolMessage,
+)
+from openai.types.chat import (
+    ChatCompletionUserMessageParam as OpenAIChatCompletionUserMessage,
+)
+from openai.types.chat.chat_completion import (
+    Choice as OpenAIChoice,
+)
+from openai.types.chat.chat_completion import (
+    ChoiceLogprobs as OpenAIChoiceLogprobs,  # same as chat_completion_chunk ChoiceLogprobs
+)
+from openai.types.chat.chat_completion_content_part_image_param import (
+    ImageURL as OpenAIImageURL,
+)
+from pydantic import BaseModel

 from llama_stack.apis.common.content_types import (
     ImageContentItem,
@@ -23,7 +65,6 @@ from llama_stack.apis.inference import (
     ChatCompletionResponseEventType,
     ChatCompletionResponseStreamChunk,
     CompletionMessage,
-    CompletionRequest,
     CompletionResponse,
     CompletionResponseStreamChunk,
     JsonSchemaResponseFormat,
@@ -49,32 +90,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     decode_assistant_message,
 )

-from openai import AsyncStream, Stream
-from openai.types.chat import (
-    ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage,
-    ChatCompletionChunk as OpenAIChatCompletionChunk,
-    ChatCompletionContentPartImageParam as OpenAIChatCompletionContentPartImageParam,
-    ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam,
-    ChatCompletionContentPartTextParam as OpenAIChatCompletionContentPartTextParam,
-    ChatCompletionMessageParam as OpenAIChatCompletionMessage,
-    ChatCompletionMessageToolCall as OpenAIChatCompletionMessageToolCall,
-    ChatCompletionMessageToolCallParam as OpenAIChatCompletionMessageToolCallParam,
-    ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage,
-    ChatCompletionToolMessageParam as OpenAIChatCompletionToolMessage,
-    ChatCompletionUserMessageParam as OpenAIChatCompletionUserMessage,
-)
-from openai.types.chat.chat_completion import (
-    Choice as OpenAIChoice,
-    ChoiceLogprobs as OpenAIChoiceLogprobs,  # same as chat_completion_chunk ChoiceLogprobs
-)
-from openai.types.chat.chat_completion_content_part_image_param import (
-    ImageURL as OpenAIImageURL,
-)
-from openai.types.chat.chat_completion_message_tool_call_param import (
-    Function as OpenAIFunction,
-)
-from pydantic import BaseModel

 logger = logging.getLogger(__name__)
@@ -169,16 +184,12 @@ def convert_openai_completion_logprobs(
     if logprobs.tokens and logprobs.token_logprobs:
         return [
             TokenLogProbs(logprobs_by_token={token: token_lp})
-            for token, token_lp in zip(
-                logprobs.tokens, logprobs.token_logprobs, strict=False
-            )
+            for token, token_lp in zip(logprobs.tokens, logprobs.token_logprobs, strict=False)
         ]
     return None


-def convert_openai_completion_logprobs_stream(
-    text: str, logprobs: Optional[Union[float, OpenAICompatLogprobs]]
-):
+def convert_openai_completion_logprobs_stream(text: str, logprobs: Optional[Union[float, OpenAICompatLogprobs]]):
     if logprobs is None:
         return None
     if isinstance(logprobs, float):
@@ -223,9 +234,7 @@ def process_chat_completion_response(
         if not choice.message or not choice.message.tool_calls:
             raise ValueError("Tool calls are not present in the response")

-        tool_calls = [
-            convert_tool_call(tool_call) for tool_call in choice.message.tool_calls
-        ]
+        tool_calls = [convert_tool_call(tool_call) for tool_call in choice.message.tool_calls]
         if any(isinstance(tool_call, UnparseableToolCall) for tool_call in tool_calls):
             # If we couldn't parse a tool call, jsonify the tool calls and return them
             return ChatCompletionResponse(
@@ -249,9 +258,7 @@ def process_chat_completion_response(
     # TODO: This does not work well with tool calls for vLLM remote provider
     #   Ref: https://github.com/meta-llama/llama-stack/issues/1058
-    raw_message = decode_assistant_message(
-        text_from_choice(choice), get_stop_reason(choice.finish_reason)
-    )
+    raw_message = decode_assistant_message(text_from_choice(choice), get_stop_reason(choice.finish_reason))

     # NOTE: If we do not set tools in chat-completion request, we should not
     # expect the ToolCall in the response. Instead, we should return the raw
@@ -452,17 +459,13 @@ async def process_chat_completion_stream_response(
     )


-async def convert_message_to_openai_dict(
-    message: Message, download: bool = False
-) -> dict:
+async def convert_message_to_openai_dict(message: Message, download: bool = False) -> dict:
     async def _convert_content(content) -> dict:
         if isinstance(content, ImageContentItem):
             return {
                 "type": "image_url",
                 "image_url": {
-                    "url": await convert_image_content_to_url(
-                        content, download=download
-                    ),
+                    "url": await convert_image_content_to_url(content, download=download),
                 },
             }
         else:
@@ -541,9 +544,7 @@ async def convert_message_to_openai_dict_new(
         elif isinstance(content_, ImageContentItem):
             return OpenAIChatCompletionContentPartImageParam(
                 type="image_url",
-                image_url=OpenAIImageURL(
-                    url=await convert_image_content_to_url(content_)
-                ),
+                image_url=OpenAIImageURL(url=await convert_image_content_to_url(content_)),
             )
         elif isinstance(content_, list):
             return [await impl(item) for item in content_]
@@ -587,9 +588,7 @@ async def convert_message_to_openai_dict_new(
                     "id": tool.call_id,
                     "function": {
                         "name": (
-                            tool.tool_name
-                            if not isinstance(tool.tool_name, BuiltinTool)
-                            else tool.tool_name.value
+                            tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value
                         ),
                         "arguments": json.dumps(tool.arguments),
                     },
@@ -709,11 +708,7 @@ def convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
         properties = parameters["properties"]
         required = []
         for param_name, param in tool.parameters.items():
-            properties[param_name] = {
-                "type": PYTHON_TYPE_TO_LITELLM_TYPE.get(
-                    param.param_type, param.param_type
-                )
-            }
+            properties[param_name] = {"type": PYTHON_TYPE_TO_LITELLM_TYPE.get(param.param_type, param.param_type)}
             if param.description:
                 properties[param_name].update(description=param.description)
             if param.default:
@@ -834,11 +829,7 @@ def _convert_openai_logprobs(
         return None

     return [
-        TokenLogProbs(
-            logprobs_by_token={
-                logprobs.token: logprobs.logprob for logprobs in content.top_logprobs
-            }
-        )
+        TokenLogProbs(logprobs_by_token={logprobs.token: logprobs.logprob for logprobs in content.top_logprobs})
         for content in logprobs.content
     ]
@@ -876,17 +867,14 @@ def convert_openai_chat_completion_choice(
         end_of_message = "end_of_message"
         out_of_tokens = "out_of_tokens"
     """
-    assert (
-        hasattr(choice, "message") and choice.message
-    ), "error in server response: message not found"
-    assert (
-        hasattr(choice, "finish_reason") and choice.finish_reason
-    ), "error in server response: finish_reason not found"
+    assert hasattr(choice, "message") and choice.message, "error in server response: message not found"
+    assert hasattr(choice, "finish_reason") and choice.finish_reason, (
+        "error in server response: finish_reason not found"
+    )

     return ChatCompletionResponse(
         completion_message=CompletionMessage(
-            content=choice.message.content
-            or "",  # CompletionMessage content is not optional
+            content=choice.message.content or "",  # CompletionMessage content is not optional
             stop_reason=_convert_openai_finish_reason(choice.finish_reason),
             tool_calls=_convert_openai_tool_calls(choice.message.tool_calls),
         ),
@@ -895,9 +883,7 @@ def convert_openai_chat_completion_choice(

 async def convert_openai_chat_completion_stream(
-    stream: Union[
-        AsyncStream[OpenAIChatCompletionChunk], Stream[OpenAIChatCompletionChunk]
-    ],
+    stream: AsyncStream[OpenAIChatCompletionChunk],
     enable_incremental_tool_calls: bool,
 ) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
     """
@@ -905,14 +891,6 @@ async def convert_openai_chat_completion_stream(
     of ChatCompletionResponseStreamChunk.
     """

-    async def yield_from_stream(stream):
-        if isinstance(stream, AsyncGenerator):
-            async for chunk in stream:
-                yield chunk
-        elif isinstance(stream, Generator):
-            for chunk in stream:
-                yield chunk
-
     yield ChatCompletionResponseStreamChunk(
         event=ChatCompletionResponseEvent(
             event_type=ChatCompletionResponseEventType.start,
@@ -924,7 +902,7 @@ async def convert_openai_chat_completion_stream(
     stop_reason = None
     tool_call_idx_to_buffer = {}

-    async for chunk in yield_from_stream(stream):
+    async for chunk in stream:
         choice = chunk.choices[0]  # assuming only one choice per chunk

         # we assume there's only one finish_reason in the stream
@@ -1092,7 +1070,7 @@ async def convert_openai_chat_completion_stream(
                                 stop_reason=stop_reason,
                             )
                         )
-                    except json.JSONDecodeError as e:
+                    except json.JSONDecodeError:
                         yield ChatCompletionResponseStreamChunk(
                             event=ChatCompletionResponseEvent(
                                 event_type=ChatCompletionResponseEventType.progress,
@@ -1137,7 +1115,7 @@ async def convert_openai_chat_completion_stream(
                         stop_reason=stop_reason,
                     )
                 )
-            except (KeyError, json.JSONDecodeError) as e:
+            except (KeyError, json.JSONDecodeError):
                 yield ChatCompletionResponseStreamChunk(
                     event=ChatCompletionResponseEvent(
                         event_type=ChatCompletionResponseEventType.progress,
@@ -1158,26 +1136,6 @@ async def convert_openai_chat_completion_stream(
     )


-async def convert_completion_request_to_openai_params(
-    request: CompletionRequest,
-) -> dict:
-    """
-    Convert a ChatCompletionRequest to an OpenAI API-compatible dictionary.
-    """
-    input_dict = {}
-    if request.logprobs:
-        input_dict["logprobs"] = request.logprobs.top_k
-
-    return {
-        "model": request.model,
-        "prompt": request.content,
-        **input_dict,
-        "stream": request.stream,
-        **get_sampling_options(request.sampling_params),
-        "n": 1,
-    }
-
-
 async def convert_chat_completion_request_to_openai_params(
     request: ChatCompletionRequest,
 ) -> dict:
@@ -1186,14 +1144,10 @@ async def convert_chat_completion_request_to_openai_params(
     """
     input_dict = {}

-    input_dict["messages"] = [
-        await convert_message_to_openai_dict_new(m) for m in request.messages
-    ]
+    input_dict["messages"] = [await convert_message_to_openai_dict_new(m) for m in request.messages]
     if fmt := request.response_format:
         if not isinstance(fmt, JsonSchemaResponseFormat):
-            raise ValueError(
-                f"Unsupported response format: {type(fmt)}. Only JsonSchemaResponseFormat is supported."
-            )
+            raise ValueError(f"Unsupported response format: {type(fmt)}. Only JsonSchemaResponseFormat is supported.")

         fmt = fmt.json_schema
         name = fmt["title"]
@@ -1217,9 +1171,7 @@ async def convert_chat_completion_request_to_openai_params(
     }
     if request.tools:
-        input_dict["tools"] = [
-            convert_tooldef_to_openai_tool(tool) for tool in request.tools
-        ]
+        input_dict["tools"] = [convert_tooldef_to_openai_tool(tool) for tool in request.tools]
     if request.tool_config.tool_choice:
         input_dict["tool_choice"] = (
             request.tool_config.tool_choice.value
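
With the sync Stream branch and the yield_from_stream helper removed above, convert_openai_chat_completion_stream now only handles AsyncStream[ChatCompletionChunk]. A short sketch of how such a stream is produced with the async OpenAI client and consumed; the model name and prompt here are placeholders, not part of this change:

    import asyncio

    from openai import AsyncOpenAI


    async def main() -> None:
        # Placeholder client/model; any OpenAI-compatible endpoint behaves the same way.
        client = AsyncOpenAI()

        # stream=True returns an AsyncStream[ChatCompletionChunk], the only shape the
        # converter now accepts and iterates with `async for chunk in stream`.
        stream = await client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[{"role": "user", "content": "Say hi"}],
            stream=True,
        )
        async for chunk in stream:
            if chunk.choices and chunk.choices[0].delta.content:
                print(chunk.choices[0].delta.content, end="")


    asyncio.run(main())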

View file

@@ -12,6 +12,7 @@ import re
 from typing import List, Optional, Tuple, Union

 import httpx
+from PIL import Image as PIL_Image

 from llama_stack.apis.common.content_types import (
     ImageContentItem,
@@ -33,7 +34,6 @@ from llama_stack.apis.inference import (
 )
 from llama_stack.log import get_logger
 from llama_stack.models.llama.datatypes import (
-    is_multimodal,
     ModelFamily,
     RawContent,
     RawContentItem,
@@ -43,6 +43,7 @@ from llama_stack.models.llama.datatypes import (
     Role,
     StopReason,
     ToolPromptFormat,
+    is_multimodal,
 )
 from llama_stack.models.llama.llama3.chat_format import ChatFormat
 from llama_stack.models.llama.llama3.prompt_templates import (
@@ -55,7 +56,6 @@ from llama_stack.models.llama.llama3.prompt_templates import (
 from llama_stack.models.llama.llama3.tokenizer import Tokenizer
 from llama_stack.models.llama.sku_list import resolve_model
 from llama_stack.providers.utils.inference import supported_inference_models
-from PIL import Image as PIL_Image

 log = get_logger(name=__name__, category="inference")
@@ -129,9 +129,7 @@ async def interleaved_content_convert_to_raw(
             if image.url.uri.startswith("data"):
                 match = re.match(r"data:image/(\w+);base64,(.+)", image.url.uri)
                 if not match:
-                    raise ValueError(
-                        f"Invalid data URL format, {image.url.uri[:40]}..."
-                    )
+                    raise ValueError(f"Invalid data URL format, {image.url.uri[:40]}...")
                 _, image_data = match.groups()
                 data = base64.b64decode(image_data)
             elif image.url.uri.startswith("file://"):
@@ -211,17 +209,13 @@ async def convert_image_content_to_url(
     content, format = await localize_image_content(media)
     if include_format:
-        return f"data:image/{format};base64," + base64.b64encode(content).decode(
-            "utf-8"
-        )
+        return f"data:image/{format};base64," + base64.b64encode(content).decode("utf-8")
     else:
         return base64.b64encode(content).decode("utf-8")


 async def completion_request_to_prompt(request: CompletionRequest) -> str:
-    content = augment_content_with_response_format_prompt(
-        request.response_format, request.content
-    )
+    content = augment_content_with_response_format_prompt(request.response_format, request.content)
     request.content = content
     request = await convert_request_to_raw(request)
@@ -233,9 +227,7 @@ async def completion_request_to_prompt(request: CompletionRequest) -> str:
 async def completion_request_to_prompt_model_input_info(
     request: CompletionRequest,
 ) -> Tuple[str, int]:
-    content = augment_content_with_response_format_prompt(
-        request.response_format, request.content
-    )
+    content = augment_content_with_response_format_prompt(request.response_format, request.content)
     request.content = content
     request = await convert_request_to_raw(request)
@@ -256,9 +248,7 @@ def augment_content_with_response_format_prompt(response_format, content):
     return content


-async def chat_completion_request_to_prompt(
-    request: ChatCompletionRequest, llama_model: str
-) -> str:
+async def chat_completion_request_to_prompt(request: ChatCompletionRequest, llama_model: str) -> str:
     messages = chat_completion_request_to_messages(request, llama_model)
     request.messages = messages
     request = await convert_request_to_raw(request)
@@ -266,8 +256,7 @@ async def chat_completion_request_to_prompt(
     formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
     model_input = formatter.encode_dialog_prompt(
         request.messages,
-        tool_prompt_format=request.tool_config.tool_prompt_format
-        or get_default_tool_prompt_format(llama_model),
+        tool_prompt_format=request.tool_config.tool_prompt_format or get_default_tool_prompt_format(llama_model),
     )
     return formatter.tokenizer.decode(model_input.tokens)
@@ -282,8 +271,7 @@ async def chat_completion_request_to_model_input_info(
     formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
     model_input = formatter.encode_dialog_prompt(
         request.messages,
-        tool_prompt_format=request.tool_config.tool_prompt_format
-        or get_default_tool_prompt_format(llama_model),
+        tool_prompt_format=request.tool_config.tool_prompt_format or get_default_tool_prompt_format(llama_model),
     )
     tokens = []
     for t in model_input.tokens:
@@ -318,8 +306,7 @@ def chat_completion_request_to_messages(
         return request.messages

     if model.model_family == ModelFamily.llama3_1 or (
-        model.model_family == ModelFamily.llama3_2
-        and is_multimodal(model.core_model_id)
+        model.model_family == ModelFamily.llama3_2 and is_multimodal(model.core_model_id)
     ):
         # llama3.1 and llama3.2 multimodal models follow the same tool prompt format
         messages = augment_messages_for_tools_llama_3_1(request)
@@ -355,9 +342,7 @@ def augment_messages_for_tools_llama_3_1(
     if existing_messages[0].role == Role.system.value:
         existing_system_message = existing_messages.pop(0)

-    assert (
-        existing_messages[0].role != Role.system.value
-    ), "Should only have 1 system message"
+    assert existing_messages[0].role != Role.system.value, "Should only have 1 system message"

     messages = []
@@ -389,13 +374,9 @@ def augment_messages_for_tools_llama_3_1(
         if isinstance(existing_system_message.content, str):
             sys_content += _process(existing_system_message.content)
         elif isinstance(existing_system_message.content, list):
-            sys_content += "\n".join(
-                [_process(c) for c in existing_system_message.content]
-            )
+            sys_content += "\n".join([_process(c) for c in existing_system_message.content])

-    tool_choice_prompt = _get_tool_choice_prompt(
-        request.tool_config.tool_choice, request.tools
-    )
+    tool_choice_prompt = _get_tool_choice_prompt(request.tool_config.tool_choice, request.tools)
     if tool_choice_prompt:
         sys_content += "\n" + tool_choice_prompt
@@ -429,9 +410,7 @@ def augment_messages_for_tools_llama_3_2(
     if existing_messages[0].role == Role.system.value:
         existing_system_message = existing_messages.pop(0)

-    assert (
-        existing_messages[0].role != Role.system.value
-    ), "Should only have 1 system message"
+    assert existing_messages[0].role != Role.system.value, "Should only have 1 system message"

     sys_content = ""
     custom_tools, builtin_tools = [], []
@@ -452,16 +431,10 @@ def augment_messages_for_tools_llama_3_2(
     if custom_tools:
         fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.python_list
         if fmt != ToolPromptFormat.python_list:
-            raise ValueError(
-                f"Non supported ToolPromptFormat {request.tool_config.tool_prompt_format}"
-            )
+            raise ValueError(f"Non supported ToolPromptFormat {request.tool_config.tool_prompt_format}")

         system_prompt = None
-        if (
-            existing_system_message
-            and request.tool_config.system_message_behavior
-            == SystemMessageBehavior.replace
-        ):
+        if existing_system_message and request.tool_config.system_message_behavior == SystemMessageBehavior.replace:
             system_prompt = existing_system_message.content

         tool_template = PythonListCustomToolGenerator().gen(custom_tools, system_prompt)
@@ -470,16 +443,11 @@ def augment_messages_for_tools_llama_3_2(
         sys_content += "\n"

     if existing_system_message and (
-        request.tool_config.system_message_behavior == SystemMessageBehavior.append
-        or not custom_tools
+        request.tool_config.system_message_behavior == SystemMessageBehavior.append or not custom_tools
     ):
-        sys_content += interleaved_content_as_str(
-            existing_system_message.content, sep="\n"
-        )
+        sys_content += interleaved_content_as_str(existing_system_message.content, sep="\n")

-    tool_choice_prompt = _get_tool_choice_prompt(
-        request.tool_config.tool_choice, request.tools
-    )
+    tool_choice_prompt = _get_tool_choice_prompt(request.tool_config.tool_choice, request.tools)
     if tool_choice_prompt:
         sys_content += "\n" + tool_choice_prompt
@@ -487,15 +455,11 @@ def augment_messages_for_tools_llama_3_2(
     return messages


-def _get_tool_choice_prompt(
-    tool_choice: ToolChoice | str, tools: List[ToolDefinition]
-) -> str:
+def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: List[ToolDefinition]) -> str:
     if tool_choice == ToolChoice.auto:
         return ""
     elif tool_choice == ToolChoice.required:
-        return (
-            "You MUST use one of the provided functions/tools to answer the user query."
-        )
+        return "You MUST use one of the provided functions/tools to answer the user query."
     elif tool_choice == ToolChoice.none:
         # tools are already not passed in
         return ""
@@ -507,14 +471,11 @@ def _get_tool_choice_prompt(
 def get_default_tool_prompt_format(model: str) -> ToolPromptFormat:
     llama_model = resolve_model(model)
     if llama_model is None:
-        log.warning(
-            f"Could not resolve model {model}, defaulting to json tool prompt format"
-        )
+        log.warning(f"Could not resolve model {model}, defaulting to json tool prompt format")
         return ToolPromptFormat.json

     if llama_model.model_family == ModelFamily.llama3_1 or (
-        llama_model.model_family == ModelFamily.llama3_2
-        and is_multimodal(llama_model.core_model_id)
+        llama_model.model_family == ModelFamily.llama3_2 and is_multimodal(llama_model.core_model_id)
     ):
         # llama3.1 and llama3.2 multimodal models follow the same tool prompt format
         return ToolPromptFormat.json
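
The interleaved_content_convert_to_raw hunk above keeps the same data-URL handling, just collapsed onto one line. For reference, a self-contained sketch of that parse-and-decode step (the image payload below is made up):

    import base64
    import re

    # A made-up 4-byte payload standing in for real image data.
    sample_uri = "data:image/png;base64," + base64.b64encode(b"\x89PNG").decode("utf-8")

    match = re.match(r"data:image/(\w+);base64,(.+)", sample_uri)
    if not match:
        raise ValueError(f"Invalid data URL format, {sample_uri[:40]}...")

    image_format, image_data = match.groups()
    data = base64.b64decode(image_data)
    print(image_format, len(data))  # -> png 4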

View file

@@ -120,16 +120,13 @@ def client_with_models(
     judge_model_id,
 ):
     client = llama_stack_client
-    from rich.pretty import pprint

     providers = [p for p in client.providers.list() if p.api == "inference"]
-    pprint(providers)
     assert len(providers) > 0, "No inference providers found"
     inference_providers = [p.provider_id for p in providers if p.provider_type != "inline::sentence-transformers"]

     model_ids = {m.identifier for m in client.models.list()}
     model_ids.update(m.provider_resource_id for m in client.models.list())
-    pprint(model_ids)

     if text_model_id and text_model_id not in model_ids:
         client.models.register(model_id=text_model_id, provider_id=inference_providers[0])

View file

@@ -8,9 +8,9 @@
 import os

 import pytest
+from pydantic import BaseModel

 from llama_stack.models.llama.sku_list import resolve_model
-from pydantic import BaseModel

 from ..test_cases.test_case import TestCase
@@ -29,9 +29,7 @@ def skip_if_model_doesnt_support_completion(client_with_models, model_id):
         "remote::gemini",
         "remote::groq",
     ):
-        pytest.skip(
-            f"Model {model_id} hosted by {provider.provider_type} doesn't support completion"
-        )
+        pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support completion")


 def get_llama_model(client_with_models, model_id):
@@ -112,9 +110,7 @@ def test_text_completion_streaming(client_with_models, text_model_id, test_case)
         "inference:completion:stop_sequence",
     ],
 )
-def test_text_completion_stop_sequence(
-    client_with_models, text_model_id, inference_provider_type, test_case
-):
+def test_text_completion_stop_sequence(client_with_models, text_model_id, inference_provider_type, test_case):
     skip_if_model_doesnt_support_completion(client_with_models, text_model_id)
     # This is only supported/tested for remote vLLM: https://github.com/meta-llama/llama-stack/issues/1771
     if inference_provider_type != "remote::vllm":
@@ -141,9 +137,7 @@ def test_text_completion_stop_sequence(
         "inference:completion:log_probs",
     ],
 )
-def test_text_completion_log_probs_non_streaming(
-    client_with_models, text_model_id, inference_provider_type, test_case
-):
+def test_text_completion_log_probs_non_streaming(client_with_models, text_model_id, inference_provider_type, test_case):
     skip_if_model_doesnt_support_completion(client_with_models, text_model_id)
     if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K:
         pytest.xfail(f"{inference_provider_type} doesn't support log probs yet")
@@ -162,9 +156,7 @@ def test_text_completion_log_probs_non_streaming(
         },
     )
     assert response.logprobs, "Logprobs should not be empty"
-    assert (
-        1 <= len(response.logprobs) <= 5
-    )  # each token has 1 logprob and here max_tokens=5
+    assert 1 <= len(response.logprobs) <= 5  # each token has 1 logprob and here max_tokens=5
     assert all(len(logprob.logprobs_by_token) == 1 for logprob in response.logprobs)
@@ -174,9 +166,7 @@ def test_text_completion_log_probs_non_streaming(
         "inference:completion:log_probs",
     ],
 )
-def test_text_completion_log_probs_streaming(
-    client_with_models, text_model_id, inference_provider_type, test_case
-):
+def test_text_completion_log_probs_streaming(client_with_models, text_model_id, inference_provider_type, test_case):
     skip_if_model_doesnt_support_completion(client_with_models, text_model_id)
     if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K:
         pytest.xfail(f"{inference_provider_type} doesn't support log probs yet")
@@ -198,9 +188,7 @@ def test_text_completion_log_probs_streaming(
     for chunk in streamed_content:
         if chunk.delta:  # if there's a token, we expect logprobs
             assert chunk.logprobs, "Logprobs should not be empty"
-            assert all(
-                len(logprob.logprobs_by_token) == 1 for logprob in chunk.logprobs
-            )
+            assert all(len(logprob.logprobs_by_token) == 1 for logprob in chunk.logprobs)
         else:  # no token, no logprobs
             assert not chunk.logprobs, "Logprobs should be empty"
@@ -211,13 +199,9 @@ def test_text_completion_log_probs_streaming(
         "inference:completion:structured_output",
     ],
 )
-def test_text_completion_structured_output(
-    client_with_models, text_model_id, test_case, inference_provider_type
-):
+def test_text_completion_structured_output(client_with_models, text_model_id, test_case, inference_provider_type):
     if inference_provider_type == "remote::tgi":
-        pytest.xfail(
-            f"{inference_provider_type} doesn't support structured outputs yet"
-        )
+        pytest.xfail(f"{inference_provider_type} doesn't support structured outputs yet")
     skip_if_model_doesnt_support_completion(client_with_models, text_model_id)

     class AnswerFormat(BaseModel):
@@ -254,9 +238,7 @@ def test_text_completion_structured_output(
         "inference:chat_completion:non_streaming_02",
     ],
 )
-def test_text_chat_completion_non_streaming(
-    client_with_models, text_model_id, test_case
-):
+def test_text_chat_completion_non_streaming(client_with_models, text_model_id, test_case):
     tc = TestCase(test_case)
     question = tc["question"]
     expected = tc["expected"]
@@ -282,18 +264,15 @@ def test_text_chat_completion_non_streaming(
         "inference:chat_completion:ttft",
     ],
 )
-def test_text_chat_completion_first_token_profiling(
-    client_with_models, text_model_id, test_case
-):
+def test_text_chat_completion_first_token_profiling(client_with_models, text_model_id, test_case):
     tc = TestCase(test_case)

     messages = tc["messages"]
-    if os.environ.get(
-        "DEBUG_TTFT"
-    ):  # debugging print number of tokens in input, ideally around 800
-        from llama_stack.apis.inference import Message
+    if os.environ.get("DEBUG_TTFT"):  # debugging print number of tokens in input, ideally around 800
         from pydantic import TypeAdapter
+
+        from llama_stack.apis.inference import Message

         tokenizer, formatter = get_llama_tokenizer()
         typed_messages = [TypeAdapter(Message).validate_python(m) for m in messages]
         encoded = formatter.encode_dialog_prompt(typed_messages, None)
@@ -307,9 +286,7 @@ def test_text_chat_completion_first_token_profiling(
     message_content = response.completion_message.content.lower().strip()
     assert len(message_content) > 0

-    if os.environ.get(
-        "DEBUG_TTFT"
-    ):  # debugging print number of tokens in response, ideally around 150
+    if os.environ.get("DEBUG_TTFT"):  # debugging print number of tokens in response, ideally around 150
         tokenizer, formatter = get_llama_tokenizer()
         encoded = formatter.encode_content(message_content)
         raise ValueError(len(encoded.tokens) if encoded and encoded.tokens else 0)
@@ -332,9 +309,7 @@ def test_text_chat_completion_streaming(client_with_models, text_model_id, test_
         messages=[{"role": "user", "content": question}],
         stream=True,
     )
-    streamed_content = [
-        str(chunk.event.delta.text.lower().strip()) for chunk in response
-    ]
+    streamed_content = [str(chunk.event.delta.text.lower().strip()) for chunk in response]
     assert len(streamed_content) > 0
     assert expected.lower() in "".join(streamed_content)
@@ -345,9 +320,7 @@ def test_text_chat_completion_streaming(client_with_models, text_model_id, test_
         "inference:chat_completion:tool_calling",
     ],
 )
-def test_text_chat_completion_with_tool_calling_and_non_streaming(
-    client_with_models, text_model_id, test_case
-):
+def test_text_chat_completion_with_tool_calling_and_non_streaming(client_with_models, text_model_id, test_case):
     tc = TestCase(test_case)

     response = client_with_models.inference.chat_completion(
@@ -361,10 +334,7 @@ def test_text_chat_completion_with_tool_calling_and_non_streaming(
     assert response.completion_message.role == "assistant"

     assert len(response.completion_message.tool_calls) == 1
-    assert (
-        response.completion_message.tool_calls[0].tool_name
-        == tc["tools"][0]["tool_name"]
-    )
+    assert response.completion_message.tool_calls[0].tool_name == tc["tools"][0]["tool_name"]
     assert response.completion_message.tool_calls[0].arguments == tc["expected"]
@@ -387,9 +357,7 @@ def extract_tool_invocation_content(response):
         "inference:chat_completion:tool_calling",
     ],
 )
-def test_text_chat_completion_with_tool_calling_and_streaming(
-    client_with_models, text_model_id, test_case
-):
+def test_text_chat_completion_with_tool_calling_and_streaming(client_with_models, text_model_id, test_case):
     tc = TestCase(test_case)

     response = client_with_models.inference.chat_completion(
@@ -415,9 +383,7 @@ def test_text_chat_completion_with_tool_choice_required(
     client_with_models, text_model_id, test_case, inference_provider_type
 ):
     if inference_provider_type == "remote::tgi":
-        pytest.xfail(
-            f"{inference_provider_type} doesn't support tool_choice 'required' parameter yet"
-        )
+        pytest.xfail(f"{inference_provider_type} doesn't support tool_choice 'required' parameter yet")

     tc = TestCase(test_case)
@@ -442,9 +408,7 @@ def test_text_chat_completion_with_tool_choice_required(
         "inference:chat_completion:tool_calling",
     ],
 )
-def test_text_chat_completion_with_tool_choice_none(
-    client_with_models, text_model_id, test_case
-):
+def test_text_chat_completion_with_tool_choice_none(client_with_models, text_model_id, test_case):
     tc = TestCase(test_case)

     response = client_with_models.inference.chat_completion(
@@ -464,13 +428,9 @@ def test_text_chat_completion_with_tool_choice_none(
         "inference:chat_completion:structured_output",
     ],
 )
-def test_text_chat_completion_structured_output(
-    client_with_models, text_model_id, test_case, inference_provider_type
-):
+def test_text_chat_completion_structured_output(client_with_models, text_model_id, test_case, inference_provider_type):
     if inference_provider_type == "remote::tgi":
-        pytest.xfail(
-            f"{inference_provider_type} doesn't support structured outputs yet"
-        )
+        pytest.xfail(f"{inference_provider_type} doesn't support structured outputs yet")

     class NBAStats(BaseModel):
         year_for_draft: int

View file

@@ -27,9 +27,7 @@ def base64_image_url(base64_image_data, image_path):
         return f"data:image/{image_path.suffix[1:]};base64,{base64_image_data}"


-# @pytest.mark.xfail(
-#     reason="This test is failing because the image is not being downloaded correctly."
-# )
+@pytest.mark.xfail(reason="This test is failing because the image is not being downloaded correctly.")
 def test_image_chat_completion_non_streaming(client_with_models, vision_model_id):
     message = {
         "role": "user",
@@ -58,9 +56,7 @@ def test_image_chat_completion_non_streaming(client_with_models, vision_model_id
     assert any(expected in message_content for expected in {"dog", "puppy", "pup"})


-# @pytest.mark.xfail(
-#     reason="This test is failing because the image is not being downloaded correctly."
-# )
+@pytest.mark.xfail(reason="This test is failing because the image is not being downloaded correctly.")
 def test_image_chat_completion_streaming(client_with_models, vision_model_id):
     message = {
         "role": "user",
@@ -92,9 +88,7 @@ def test_image_chat_completion_streaming(client_with_models, vision_model_id):
 @pytest.mark.parametrize("type_", ["url"])
-def test_image_chat_completion_base64(
-    client_with_models, vision_model_id, base64_image_data, base64_image_url, type_
-):
+def test_image_chat_completion_base64(client_with_models, vision_model_id, base64_image_data, base64_image_url, type_):
     image_spec = {
         "url": {
             "type": "image",