Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-06 02:32:40 +00:00)

commit 5251d2422d (parent 021dd0d35d)

    fixes and linting

8 changed files with 149 additions and 345 deletions

@@ -56,9 +56,7 @@ from .models import MODEL_ENTRIES
 logger = get_logger(name=__name__, category="inference")
 
 
-class FireworksInferenceAdapter(
-    ModelRegistryHelper, Inference, NeedsRequestProviderData
-):
+class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
     def __init__(self, config: FireworksImplConfig) -> None:
         ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
         self.config = config
@@ -70,9 +68,7 @@ class FireworksInferenceAdapter(
         pass
 
     def _get_api_key(self) -> str:
-        config_api_key = (
-            self.config.api_key.get_secret_value() if self.config.api_key else None
-        )
+        config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
         if config_api_key:
             return config_api_key
         else:
@@ -112,9 +108,7 @@ class FireworksInferenceAdapter(
         else:
             return await self._nonstream_completion(request)
 
-    async def _nonstream_completion(
-        self, request: CompletionRequest
-    ) -> CompletionResponse:
+    async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
         params = await self._get_params(request)
         r = await self._get_client().completion.acreate(**params)
         return process_completion_response(r)
@@ -194,9 +188,7 @@ class FireworksInferenceAdapter(
         else:
             return await self._nonstream_chat_completion(request)
 
-    async def _nonstream_chat_completion(
-        self, request: ChatCompletionRequest
-    ) -> ChatCompletionResponse:
+    async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
         params = await self._get_params(request)
         if "messages" in params:
             r = await self._get_client().chat.completions.acreate(**params)
@@ -204,9 +196,7 @@ class FireworksInferenceAdapter(
             r = await self._get_client().completion.acreate(**params)
         return process_chat_completion_response(r, request)
 
-    async def _stream_chat_completion(
-        self, request: ChatCompletionRequest
-    ) -> AsyncGenerator:
+    async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
 
         async def _to_async_generator():
@@ -221,9 +211,7 @@ class FireworksInferenceAdapter(
         async for chunk in process_chat_completion_stream_response(stream, request):
             yield chunk
 
-    async def _get_params(
-        self, request: Union[ChatCompletionRequest, CompletionRequest]
-    ) -> dict:
+    async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
         input_dict = {}
         media_present = request_has_media(request)
 
@@ -231,17 +219,12 @@ class FireworksInferenceAdapter(
         if isinstance(request, ChatCompletionRequest):
             if media_present or not llama_model:
                 input_dict["messages"] = [
-                    await convert_message_to_openai_dict(m, download=True)
-                    for m in request.messages
+                    await convert_message_to_openai_dict(m, download=True) for m in request.messages
                 ]
             else:
-                input_dict["prompt"] = await chat_completion_request_to_prompt(
-                    request, llama_model
-                )
+                input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
         else:
-            assert (
-                not media_present
-            ), "Fireworks does not support media for Completion requests"
+            assert not media_present, "Fireworks does not support media for Completion requests"
             input_dict["prompt"] = await completion_request_to_prompt(request)
 
         # Fireworks always prepends with BOS
@@ -253,9 +236,7 @@ class FireworksInferenceAdapter(
             "model": request.model,
             **input_dict,
             "stream": request.stream,
-            **self._build_options(
-                request.sampling_params, request.response_format, request.logprobs
-            ),
+            **self._build_options(request.sampling_params, request.response_format, request.logprobs),
         }
         logger.debug(f"params to fireworks: {params}")
 
@@ -274,9 +255,9 @@ class FireworksInferenceAdapter(
         kwargs = {}
         if model.metadata.get("embedding_dimension"):
            kwargs["dimensions"] = model.metadata.get("embedding_dimension")
-        assert all(
-            not content_has_media(content) for content in contents
-        ), "Fireworks does not support media for embeddings"
+        assert all(not content_has_media(content) for content in contents), (
+            "Fireworks does not support media for embeddings"
+        )
         response = self._get_client().embeddings.create(
             model=model.provider_resource_id,
             input=[interleaved_content_as_str(content) for content in contents],

@@ -5,10 +5,9 @@
 # the root directory of this source tree.
 
 
-import logging
 from typing import AsyncGenerator, List, Optional
 
-from huggingface_hub import AsyncInferenceClient, HfApi, InferenceClient
+from huggingface_hub import AsyncInferenceClient, HfApi
 
 from llama_stack.apis.common.content_types import (
     InterleavedContent,
@@ -16,12 +15,10 @@ from llama_stack.apis.common.content_types import (
 )
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
-    ChatCompletionResponse,
     CompletionRequest,
     EmbeddingsResponse,
     EmbeddingTaskType,
     Inference,
-    JsonSchemaResponseFormat,
     LogProbConfig,
     Message,
     ResponseFormat,
@@ -38,26 +35,20 @@ from llama_stack.log import get_logger
 from llama_stack.models.llama.sku_list import all_registered_models
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.model_registry import (
-    build_hf_repo_model_entry,
     ModelRegistryHelper,
+    build_hf_repo_model_entry,
 )
 from llama_stack.providers.utils.inference.openai_compat import (
-    convert_chat_completion_request_to_openai_params,
-    convert_completion_request_to_openai_params,
-    convert_message_to_openai_dict_new,
-    convert_openai_chat_completion_choice,
-    convert_openai_chat_completion_stream,
-    convert_tooldef_to_openai_tool,
-    get_sampling_options,
     OpenAICompatCompletionChoice,
     OpenAICompatCompletionResponse,
-    process_chat_completion_response,
-    process_chat_completion_stream_response,
+    convert_chat_completion_request_to_openai_params,
+    convert_openai_chat_completion_choice,
+    convert_openai_chat_completion_stream,
+    get_sampling_options,
     process_completion_response,
     process_completion_stream_response,
 )
 from llama_stack.providers.utils.inference.prompt_adapter import (
-    chat_completion_request_to_model_input_info,
     completion_request_to_prompt_model_input_info,
 )
 
@@ -85,9 +76,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
     def __init__(self) -> None:
         self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
         self.huggingface_repo_to_llama_model_id = {
-            model.huggingface_repo: model.descriptor()
-            for model in all_registered_models()
-            if model.huggingface_repo
+            model.huggingface_repo: model.descriptor() for model in all_registered_models() if model.huggingface_repo
         }
 
     async def shutdown(self) -> None:
@@ -114,7 +103,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
         if response_format:
-            raise ValueError(f"TGI does not support Response Format for completions.")
+            raise ValueError("TGI does not support Response Format for completions.")
 
         if sampling_params is None:
             sampling_params = SamplingParams()
@@ -166,17 +155,13 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         return options
 
     async def _get_params_for_completion(self, request: CompletionRequest) -> dict:
-        prompt, input_tokens = await completion_request_to_prompt_model_input_info(
-            request
-        )
+        prompt, input_tokens = await completion_request_to_prompt_model_input_info(request)
 
         return dict(
             prompt=prompt,
             stream=request.stream,
             details=True,
-            max_new_tokens=self._get_max_new_tokens(
-                request.sampling_params, input_tokens
-            ),
+            max_new_tokens=self._get_max_new_tokens(request.sampling_params, input_tokens),
             stop_sequences=["<|eom_id|>", "<|eot_id|>"],
             **self._build_options(request.sampling_params, request.response_format),
         )
@@ -185,16 +170,14 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         params = await self._get_params_for_completion(request)
 
         async def _generate_and_convert_to_openai_compat():
-            s = self.client.text_generation(**params)
-            for chunk in s:
+            s = await self.client.text_generation(**params)
+            async for chunk in s:
                 token_result = chunk.token
                 finish_reason = None
                 if chunk.details:
                     finish_reason = chunk.details.finish_reason
 
-                choice = OpenAICompatCompletionChoice(
-                    text=token_result.text, finish_reason=finish_reason
-                )
+                choice = OpenAICompatCompletionChoice(text=token_result.text, finish_reason=finish_reason)
                 yield OpenAICompatCompletionResponse(
                     choices=[choice],
                 )
@@ -205,7 +188,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
 
     async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator:
         params = await self._get_params_for_completion(request)
-        r = self.client.text_generation(**params)
+        r = await self.client.text_generation(**params)
 
         choice = OpenAICompatCompletionChoice(
             finish_reason=r.details.finish_reason,
@@ -234,9 +217,6 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         if sampling_params is None:
             sampling_params = SamplingParams()
         model = await self.model_store.get_model(model_id)
-        from rich.pretty import pprint
-
-        pprint(messages)
         request = ChatCompletionRequest(
             model=model.provider_resource_id,
             messages=messages,
@@ -250,18 +230,9 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
 
         params = await convert_chat_completion_request_to_openai_params(request)
 
-        import json
-
-        # print(json.dumps(params, indent=2))
-
-        pprint(params)
-
-        response = self.client.chat.completions.create(**params)
+        response = await self.client.chat.completions.create(**params)
 
         if stream:
-            return convert_openai_chat_completion_stream(
-                response, enable_incremental_tool_calls=True
-            )
+            return convert_openai_chat_completion_stream(response, enable_incremental_tool_calls=True)
         else:
             return convert_openai_chat_completion_choice(response.choices[0])
-
@@ -281,18 +252,16 @@ class TGIAdapter(_HfAdapter):
         logger.info(f"Initializing TGI client with url={config.url}")
         # unfortunately, the TGI async client does not work well with proxies
         # so using sync client for now instead
-        self.client = InferenceClient(model=f"{config.url}")
+        self.client = AsyncInferenceClient(model=f"{config.url}")
 
-        endpoint_info = self.client.get_endpoint_info()
+        endpoint_info = await self.client.get_endpoint_info()
         self.max_tokens = endpoint_info["max_total_tokens"]
         self.model_id = endpoint_info["model_id"]
 
 
 class InferenceAPIAdapter(_HfAdapter):
     async def initialize(self, config: InferenceAPIImplConfig) -> None:
-        self.client = AsyncInferenceClient(
-            model=config.huggingface_repo, token=config.api_token.get_secret_value()
-        )
+        self.client = AsyncInferenceClient(model=config.huggingface_repo, token=config.api_token.get_secret_value())
         endpoint_info = await self.client.get_endpoint_info()
         self.max_tokens = endpoint_info["max_total_tokens"]
         self.model_id = endpoint_info["model_id"]
@@ -310,6 +279,4 @@ class InferenceEndpointAdapter(_HfAdapter):
         # Initialize the adapter
         self.client = endpoint.async_client
         self.model_id = endpoint.repository
-        self.max_tokens = int(
-            endpoint.raw["model"]["image"]["custom"]["env"]["MAX_TOTAL_TOKENS"]
-        )
+        self.max_tokens = int(endpoint.raw["model"]["image"]["custom"]["env"]["MAX_TOTAL_TOKENS"])

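The hunks above switch the TGI adapter from the synchronous huggingface_hub InferenceClient to AsyncInferenceClient, so text_generation calls are now awaited and streamed with `async for`. A minimal usage sketch of that pattern follows; the endpoint URL and prompt are placeholders, not values from this commit:

import asyncio

from huggingface_hub import AsyncInferenceClient


async def main() -> None:
    # Hypothetical local TGI endpoint; substitute your own URL.
    client = AsyncInferenceClient(model="http://localhost:8080")

    # Non-streaming: the call itself must be awaited with the async client.
    result = await client.text_generation("Hello, ", max_new_tokens=16, details=True)
    print(result.generated_text, result.details.finish_reason)

    # Streaming: awaiting the call returns an async iterator of token chunks,
    # consumed with `async for` (mirroring `async for chunk in s:` in the diff).
    stream = await client.text_generation("Hello, ", max_new_tokens=16, details=True, stream=True)
    async for chunk in stream:
        print(chunk.token.text, end="")


asyncio.run(main())
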
@@ -19,7 +19,6 @@ from llama_stack.apis.inference import (
     EmbeddingsResponse,
     EmbeddingTaskType,
     Inference,
-    JsonSchemaResponseFormat,
     LogProbConfig,
     Message,
     ResponseFormat,
@@ -36,11 +35,8 @@ from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 from llama_stack.providers.utils.inference.openai_compat import (
     convert_chat_completion_request_to_openai_params,
-    convert_message_to_openai_dict_new,
     convert_openai_chat_completion_choice,
     convert_openai_chat_completion_stream,
-    convert_tooldef_to_openai_tool,
-    get_sampling_options,
 )
 from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
@@ -54,9 +50,7 @@ class LiteLLMOpenAIMixin(
     Inference,
     NeedsRequestProviderData,
 ):
-    def __init__(
-        self, model_entries, api_key_from_config: str, provider_data_api_key_field: str
-    ):
+    def __init__(self, model_entries, api_key_from_config: str, provider_data_api_key_field: str):
         ModelRegistryHelper.__init__(self, model_entries)
         self.api_key_from_config = api_key_from_config
         self.provider_data_api_key_field = provider_data_api_key_field
@@ -96,9 +90,7 @@ class LiteLLMOpenAIMixin(
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
-    ) -> Union[
-        ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
-    ]:
+    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
         if sampling_params is None:
             sampling_params = SamplingParams()
         model = await self.model_store.get_model(model_id)

@@ -6,7 +6,49 @@
 import json
 import logging
 import warnings
-from typing import AsyncGenerator, Dict, Generator, Iterable, List, Optional, Union
+from typing import AsyncGenerator, Dict, Iterable, List, Optional, Union
 
+from openai import AsyncStream
+from openai.types.chat import (
+    ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage,
+)
+from openai.types.chat import (
+    ChatCompletionChunk as OpenAIChatCompletionChunk,
+)
+from openai.types.chat import (
+    ChatCompletionContentPartImageParam as OpenAIChatCompletionContentPartImageParam,
+)
+from openai.types.chat import (
+    ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam,
+)
+from openai.types.chat import (
+    ChatCompletionContentPartTextParam as OpenAIChatCompletionContentPartTextParam,
+)
+from openai.types.chat import (
+    ChatCompletionMessageParam as OpenAIChatCompletionMessage,
+)
+from openai.types.chat import (
+    ChatCompletionMessageToolCall as OpenAIChatCompletionMessageToolCall,
+)
+from openai.types.chat import (
+    ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage,
+)
+from openai.types.chat import (
+    ChatCompletionToolMessageParam as OpenAIChatCompletionToolMessage,
+)
+from openai.types.chat import (
+    ChatCompletionUserMessageParam as OpenAIChatCompletionUserMessage,
+)
+from openai.types.chat.chat_completion import (
+    Choice as OpenAIChoice,
+)
+from openai.types.chat.chat_completion import (
+    ChoiceLogprobs as OpenAIChoiceLogprobs,  # same as chat_completion_chunk ChoiceLogprobs
+)
+from openai.types.chat.chat_completion_content_part_image_param import (
+    ImageURL as OpenAIImageURL,
+)
+from pydantic import BaseModel
+
 from llama_stack.apis.common.content_types import (
     ImageContentItem,
@@ -23,7 +65,6 @@ from llama_stack.apis.inference import (
     ChatCompletionResponseEventType,
     ChatCompletionResponseStreamChunk,
     CompletionMessage,
-    CompletionRequest,
     CompletionResponse,
     CompletionResponseStreamChunk,
     JsonSchemaResponseFormat,
@@ -49,32 +90,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     decode_assistant_message,
 )
 
-from openai import AsyncStream, Stream
-from openai.types.chat import (
-    ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage,
-    ChatCompletionChunk as OpenAIChatCompletionChunk,
-    ChatCompletionContentPartImageParam as OpenAIChatCompletionContentPartImageParam,
-    ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam,
-    ChatCompletionContentPartTextParam as OpenAIChatCompletionContentPartTextParam,
-    ChatCompletionMessageParam as OpenAIChatCompletionMessage,
-    ChatCompletionMessageToolCall as OpenAIChatCompletionMessageToolCall,
-    ChatCompletionMessageToolCallParam as OpenAIChatCompletionMessageToolCallParam,
-    ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage,
-    ChatCompletionToolMessageParam as OpenAIChatCompletionToolMessage,
-    ChatCompletionUserMessageParam as OpenAIChatCompletionUserMessage,
-)
-from openai.types.chat.chat_completion import (
-    Choice as OpenAIChoice,
-    ChoiceLogprobs as OpenAIChoiceLogprobs,  # same as chat_completion_chunk ChoiceLogprobs
-)
-from openai.types.chat.chat_completion_content_part_image_param import (
-    ImageURL as OpenAIImageURL,
-)
-from openai.types.chat.chat_completion_message_tool_call_param import (
-    Function as OpenAIFunction,
-)
-from pydantic import BaseModel
-
 logger = logging.getLogger(__name__)
 
 
@@ -169,16 +184,12 @@ def convert_openai_completion_logprobs(
     if logprobs.tokens and logprobs.token_logprobs:
         return [
             TokenLogProbs(logprobs_by_token={token: token_lp})
-            for token, token_lp in zip(
-                logprobs.tokens, logprobs.token_logprobs, strict=False
-            )
+            for token, token_lp in zip(logprobs.tokens, logprobs.token_logprobs, strict=False)
         ]
     return None
 
 
-def convert_openai_completion_logprobs_stream(
-    text: str, logprobs: Optional[Union[float, OpenAICompatLogprobs]]
-):
+def convert_openai_completion_logprobs_stream(text: str, logprobs: Optional[Union[float, OpenAICompatLogprobs]]):
     if logprobs is None:
         return None
     if isinstance(logprobs, float):
@@ -223,9 +234,7 @@ def process_chat_completion_response(
         if not choice.message or not choice.message.tool_calls:
             raise ValueError("Tool calls are not present in the response")
 
-        tool_calls = [
-            convert_tool_call(tool_call) for tool_call in choice.message.tool_calls
-        ]
+        tool_calls = [convert_tool_call(tool_call) for tool_call in choice.message.tool_calls]
         if any(isinstance(tool_call, UnparseableToolCall) for tool_call in tool_calls):
             # If we couldn't parse a tool call, jsonify the tool calls and return them
             return ChatCompletionResponse(
@@ -249,9 +258,7 @@ def process_chat_completion_response(
 
     # TODO: This does not work well with tool calls for vLLM remote provider
    # Ref: https://github.com/meta-llama/llama-stack/issues/1058
-    raw_message = decode_assistant_message(
-        text_from_choice(choice), get_stop_reason(choice.finish_reason)
-    )
+    raw_message = decode_assistant_message(text_from_choice(choice), get_stop_reason(choice.finish_reason))
 
     # NOTE: If we do not set tools in chat-completion request, we should not
     # expect the ToolCall in the response. Instead, we should return the raw
@@ -452,17 +459,13 @@ async def process_chat_completion_stream_response(
     )
 
 
-async def convert_message_to_openai_dict(
-    message: Message, download: bool = False
-) -> dict:
+async def convert_message_to_openai_dict(message: Message, download: bool = False) -> dict:
     async def _convert_content(content) -> dict:
         if isinstance(content, ImageContentItem):
             return {
                 "type": "image_url",
                 "image_url": {
-                    "url": await convert_image_content_to_url(
-                        content, download=download
-                    ),
+                    "url": await convert_image_content_to_url(content, download=download),
                 },
             }
         else:
@@ -541,9 +544,7 @@ async def convert_message_to_openai_dict_new(
         elif isinstance(content_, ImageContentItem):
             return OpenAIChatCompletionContentPartImageParam(
                 type="image_url",
-                image_url=OpenAIImageURL(
-                    url=await convert_image_content_to_url(content_)
-                ),
+                image_url=OpenAIImageURL(url=await convert_image_content_to_url(content_)),
             )
         elif isinstance(content_, list):
             return [await impl(item) for item in content_]
@@ -587,9 +588,7 @@ async def convert_message_to_openai_dict_new(
                 "id": tool.call_id,
                 "function": {
                     "name": (
-                        tool.tool_name
-                        if not isinstance(tool.tool_name, BuiltinTool)
-                        else tool.tool_name.value
+                        tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value
                     ),
                     "arguments": json.dumps(tool.arguments),
                 },
@@ -709,11 +708,7 @@ def convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
     properties = parameters["properties"]
     required = []
     for param_name, param in tool.parameters.items():
-        properties[param_name] = {
-            "type": PYTHON_TYPE_TO_LITELLM_TYPE.get(
-                param.param_type, param.param_type
-            )
-        }
+        properties[param_name] = {"type": PYTHON_TYPE_TO_LITELLM_TYPE.get(param.param_type, param.param_type)}
         if param.description:
             properties[param_name].update(description=param.description)
         if param.default:
@@ -834,11 +829,7 @@ def _convert_openai_logprobs(
         return None
 
     return [
-        TokenLogProbs(
-            logprobs_by_token={
-                logprobs.token: logprobs.logprob for logprobs in content.top_logprobs
-            }
-        )
+        TokenLogProbs(logprobs_by_token={logprobs.token: logprobs.logprob for logprobs in content.top_logprobs})
         for content in logprobs.content
     ]
 
@@ -876,17 +867,14 @@ def convert_openai_chat_completion_choice(
        end_of_message = "end_of_message"
        out_of_tokens = "out_of_tokens"
    """
-    assert (
-        hasattr(choice, "message") and choice.message
-    ), "error in server response: message not found"
-    assert (
-        hasattr(choice, "finish_reason") and choice.finish_reason
-    ), "error in server response: finish_reason not found"
+    assert hasattr(choice, "message") and choice.message, "error in server response: message not found"
+    assert hasattr(choice, "finish_reason") and choice.finish_reason, (
+        "error in server response: finish_reason not found"
+    )
 
     return ChatCompletionResponse(
         completion_message=CompletionMessage(
-            content=choice.message.content
-            or "",  # CompletionMessage content is not optional
+            content=choice.message.content or "",  # CompletionMessage content is not optional
             stop_reason=_convert_openai_finish_reason(choice.finish_reason),
             tool_calls=_convert_openai_tool_calls(choice.message.tool_calls),
         ),
@@ -895,9 +883,7 @@ def convert_openai_chat_completion_choice(
 
 
 async def convert_openai_chat_completion_stream(
-    stream: Union[
-        AsyncStream[OpenAIChatCompletionChunk], Stream[OpenAIChatCompletionChunk]
-    ],
+    stream: AsyncStream[OpenAIChatCompletionChunk],
     enable_incremental_tool_calls: bool,
 ) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
     """
@@ -905,14 +891,6 @@ async def convert_openai_chat_completion_stream(
     of ChatCompletionResponseStreamChunk.
     """
 
-    async def yield_from_stream(stream):
-        if isinstance(stream, AsyncGenerator):
-            async for chunk in stream:
-                yield chunk
-        elif isinstance(stream, Generator):
-            for chunk in stream:
-                yield chunk
-
     yield ChatCompletionResponseStreamChunk(
         event=ChatCompletionResponseEvent(
             event_type=ChatCompletionResponseEventType.start,
@@ -924,7 +902,7 @@ async def convert_openai_chat_completion_stream(
     stop_reason = None
     tool_call_idx_to_buffer = {}
 
-    async for chunk in yield_from_stream(stream):
+    async for chunk in stream:
         choice = chunk.choices[0]  # assuming only one choice per chunk
 
         # we assume there's only one finish_reason in the stream
@@ -1092,7 +1070,7 @@ async def convert_openai_chat_completion_stream(
                         stop_reason=stop_reason,
                     )
                 )
-            except json.JSONDecodeError as e:
+            except json.JSONDecodeError:
                 yield ChatCompletionResponseStreamChunk(
                     event=ChatCompletionResponseEvent(
                         event_type=ChatCompletionResponseEventType.progress,
@@ -1137,7 +1115,7 @@ async def convert_openai_chat_completion_stream(
                         stop_reason=stop_reason,
                    )
                )
-            except (KeyError, json.JSONDecodeError) as e:
+            except (KeyError, json.JSONDecodeError):
                 yield ChatCompletionResponseStreamChunk(
                     event=ChatCompletionResponseEvent(
                         event_type=ChatCompletionResponseEventType.progress,
@@ -1158,26 +1136,6 @@ async def convert_openai_chat_completion_stream(
     )
 
 
-async def convert_completion_request_to_openai_params(
-    request: CompletionRequest,
-) -> dict:
-    """
-    Convert a ChatCompletionRequest to an OpenAI API-compatible dictionary.
-    """
-    input_dict = {}
-    if request.logprobs:
-        input_dict["logprobs"] = request.logprobs.top_k
-
-    return {
-        "model": request.model,
-        "prompt": request.content,
-        **input_dict,
-        "stream": request.stream,
-        **get_sampling_options(request.sampling_params),
-        "n": 1,
-    }
-
-
 async def convert_chat_completion_request_to_openai_params(
     request: ChatCompletionRequest,
 ) -> dict:
@@ -1186,14 +1144,10 @@ async def convert_chat_completion_request_to_openai_params(
     """
     input_dict = {}
 
-    input_dict["messages"] = [
-        await convert_message_to_openai_dict_new(m) for m in request.messages
-    ]
+    input_dict["messages"] = [await convert_message_to_openai_dict_new(m) for m in request.messages]
     if fmt := request.response_format:
         if not isinstance(fmt, JsonSchemaResponseFormat):
-            raise ValueError(
-                f"Unsupported response format: {type(fmt)}. Only JsonSchemaResponseFormat is supported."
-            )
+            raise ValueError(f"Unsupported response format: {type(fmt)}. Only JsonSchemaResponseFormat is supported.")
 
         fmt = fmt.json_schema
         name = fmt["title"]
@@ -1217,9 +1171,7 @@ async def convert_chat_completion_request_to_openai_params(
         }
 
     if request.tools:
-        input_dict["tools"] = [
-            convert_tooldef_to_openai_tool(tool) for tool in request.tools
-        ]
+        input_dict["tools"] = [convert_tooldef_to_openai_tool(tool) for tool in request.tools]
         if request.tool_config.tool_choice:
             input_dict["tool_choice"] = (
                 request.tool_config.tool_choice.value

@@ -12,6 +12,7 @@ import re
 from typing import List, Optional, Tuple, Union
 
 import httpx
+from PIL import Image as PIL_Image
 
 from llama_stack.apis.common.content_types import (
     ImageContentItem,
@@ -33,7 +34,6 @@ from llama_stack.apis.inference import (
 )
 from llama_stack.log import get_logger
 from llama_stack.models.llama.datatypes import (
-    is_multimodal,
     ModelFamily,
     RawContent,
     RawContentItem,
@@ -43,6 +43,7 @@ from llama_stack.models.llama.datatypes import (
     Role,
     StopReason,
     ToolPromptFormat,
+    is_multimodal,
 )
 from llama_stack.models.llama.llama3.chat_format import ChatFormat
 from llama_stack.models.llama.llama3.prompt_templates import (
@@ -55,7 +56,6 @@ from llama_stack.models.llama.llama3.prompt_templates import (
 from llama_stack.models.llama.llama3.tokenizer import Tokenizer
 from llama_stack.models.llama.sku_list import resolve_model
 from llama_stack.providers.utils.inference import supported_inference_models
-from PIL import Image as PIL_Image
 
 log = get_logger(name=__name__, category="inference")
 
@@ -129,9 +129,7 @@ async def interleaved_content_convert_to_raw(
             if image.url.uri.startswith("data"):
                 match = re.match(r"data:image/(\w+);base64,(.+)", image.url.uri)
                 if not match:
-                    raise ValueError(
-                        f"Invalid data URL format, {image.url.uri[:40]}..."
-                    )
+                    raise ValueError(f"Invalid data URL format, {image.url.uri[:40]}...")
                 _, image_data = match.groups()
                 data = base64.b64decode(image_data)
             elif image.url.uri.startswith("file://"):
@@ -211,17 +209,13 @@ async def convert_image_content_to_url(
 
     content, format = await localize_image_content(media)
     if include_format:
-        return f"data:image/{format};base64," + base64.b64encode(content).decode(
-            "utf-8"
-        )
+        return f"data:image/{format};base64," + base64.b64encode(content).decode("utf-8")
     else:
         return base64.b64encode(content).decode("utf-8")
 
 
 async def completion_request_to_prompt(request: CompletionRequest) -> str:
-    content = augment_content_with_response_format_prompt(
-        request.response_format, request.content
-    )
+    content = augment_content_with_response_format_prompt(request.response_format, request.content)
     request.content = content
     request = await convert_request_to_raw(request)
 
@@ -233,9 +227,7 @@ async def completion_request_to_prompt(request: CompletionRequest) -> str:
 async def completion_request_to_prompt_model_input_info(
     request: CompletionRequest,
 ) -> Tuple[str, int]:
-    content = augment_content_with_response_format_prompt(
-        request.response_format, request.content
-    )
+    content = augment_content_with_response_format_prompt(request.response_format, request.content)
     request.content = content
     request = await convert_request_to_raw(request)
 
@@ -256,9 +248,7 @@ def augment_content_with_response_format_prompt(response_format, content):
     return content
 
 
-async def chat_completion_request_to_prompt(
-    request: ChatCompletionRequest, llama_model: str
-) -> str:
+async def chat_completion_request_to_prompt(request: ChatCompletionRequest, llama_model: str) -> str:
     messages = chat_completion_request_to_messages(request, llama_model)
     request.messages = messages
     request = await convert_request_to_raw(request)
@@ -266,8 +256,7 @@ async def chat_completion_request_to_prompt(
     formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
     model_input = formatter.encode_dialog_prompt(
         request.messages,
-        tool_prompt_format=request.tool_config.tool_prompt_format
-        or get_default_tool_prompt_format(llama_model),
+        tool_prompt_format=request.tool_config.tool_prompt_format or get_default_tool_prompt_format(llama_model),
     )
     return formatter.tokenizer.decode(model_input.tokens)
 
@@ -282,8 +271,7 @@ async def chat_completion_request_to_model_input_info(
     formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
     model_input = formatter.encode_dialog_prompt(
         request.messages,
-        tool_prompt_format=request.tool_config.tool_prompt_format
-        or get_default_tool_prompt_format(llama_model),
+        tool_prompt_format=request.tool_config.tool_prompt_format or get_default_tool_prompt_format(llama_model),
     )
     tokens = []
     for t in model_input.tokens:
@@ -318,8 +306,7 @@ def chat_completion_request_to_messages(
         return request.messages
 
     if model.model_family == ModelFamily.llama3_1 or (
-        model.model_family == ModelFamily.llama3_2
-        and is_multimodal(model.core_model_id)
+        model.model_family == ModelFamily.llama3_2 and is_multimodal(model.core_model_id)
     ):
         # llama3.1 and llama3.2 multimodal models follow the same tool prompt format
         messages = augment_messages_for_tools_llama_3_1(request)
@@ -355,9 +342,7 @@ def augment_messages_for_tools_llama_3_1(
     if existing_messages[0].role == Role.system.value:
         existing_system_message = existing_messages.pop(0)
 
-    assert (
-        existing_messages[0].role != Role.system.value
-    ), "Should only have 1 system message"
+    assert existing_messages[0].role != Role.system.value, "Should only have 1 system message"
 
     messages = []
 
@@ -389,13 +374,9 @@ def augment_messages_for_tools_llama_3_1(
         if isinstance(existing_system_message.content, str):
             sys_content += _process(existing_system_message.content)
         elif isinstance(existing_system_message.content, list):
-            sys_content += "\n".join(
-                [_process(c) for c in existing_system_message.content]
-            )
+            sys_content += "\n".join([_process(c) for c in existing_system_message.content])
 
-    tool_choice_prompt = _get_tool_choice_prompt(
-        request.tool_config.tool_choice, request.tools
-    )
+    tool_choice_prompt = _get_tool_choice_prompt(request.tool_config.tool_choice, request.tools)
     if tool_choice_prompt:
         sys_content += "\n" + tool_choice_prompt
 
@@ -429,9 +410,7 @@ def augment_messages_for_tools_llama_3_2(
     if existing_messages[0].role == Role.system.value:
         existing_system_message = existing_messages.pop(0)
 
-    assert (
-        existing_messages[0].role != Role.system.value
-    ), "Should only have 1 system message"
+    assert existing_messages[0].role != Role.system.value, "Should only have 1 system message"
 
     sys_content = ""
     custom_tools, builtin_tools = [], []
@@ -452,16 +431,10 @@ def augment_messages_for_tools_llama_3_2(
     if custom_tools:
         fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.python_list
         if fmt != ToolPromptFormat.python_list:
-            raise ValueError(
-                f"Non supported ToolPromptFormat {request.tool_config.tool_prompt_format}"
-            )
+            raise ValueError(f"Non supported ToolPromptFormat {request.tool_config.tool_prompt_format}")
 
         system_prompt = None
-        if (
-            existing_system_message
-            and request.tool_config.system_message_behavior
-            == SystemMessageBehavior.replace
-        ):
+        if existing_system_message and request.tool_config.system_message_behavior == SystemMessageBehavior.replace:
             system_prompt = existing_system_message.content
 
         tool_template = PythonListCustomToolGenerator().gen(custom_tools, system_prompt)
@@ -470,16 +443,11 @@ def augment_messages_for_tools_llama_3_2(
         sys_content += "\n"
 
     if existing_system_message and (
-        request.tool_config.system_message_behavior == SystemMessageBehavior.append
-        or not custom_tools
+        request.tool_config.system_message_behavior == SystemMessageBehavior.append or not custom_tools
     ):
-        sys_content += interleaved_content_as_str(
-            existing_system_message.content, sep="\n"
-        )
+        sys_content += interleaved_content_as_str(existing_system_message.content, sep="\n")
 
-    tool_choice_prompt = _get_tool_choice_prompt(
-        request.tool_config.tool_choice, request.tools
-    )
+    tool_choice_prompt = _get_tool_choice_prompt(request.tool_config.tool_choice, request.tools)
     if tool_choice_prompt:
         sys_content += "\n" + tool_choice_prompt
 
@@ -487,15 +455,11 @@ def augment_messages_for_tools_llama_3_2(
     return messages
 
 
-def _get_tool_choice_prompt(
-    tool_choice: ToolChoice | str, tools: List[ToolDefinition]
-) -> str:
+def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: List[ToolDefinition]) -> str:
     if tool_choice == ToolChoice.auto:
         return ""
     elif tool_choice == ToolChoice.required:
-        return (
-            "You MUST use one of the provided functions/tools to answer the user query."
-        )
+        return "You MUST use one of the provided functions/tools to answer the user query."
     elif tool_choice == ToolChoice.none:
         # tools are already not passed in
         return ""
@@ -507,14 +471,11 @@ def _get_tool_choice_prompt(
 def get_default_tool_prompt_format(model: str) -> ToolPromptFormat:
     llama_model = resolve_model(model)
     if llama_model is None:
-        log.warning(
-            f"Could not resolve model {model}, defaulting to json tool prompt format"
-        )
+        log.warning(f"Could not resolve model {model}, defaulting to json tool prompt format")
         return ToolPromptFormat.json
 
     if llama_model.model_family == ModelFamily.llama3_1 or (
-        llama_model.model_family == ModelFamily.llama3_2
-        and is_multimodal(llama_model.core_model_id)
+        llama_model.model_family == ModelFamily.llama3_2 and is_multimodal(llama_model.core_model_id)
     ):
         # llama3.1 and llama3.2 multimodal models follow the same tool prompt format
         return ToolPromptFormat.json

@@ -120,16 +120,13 @@ def client_with_models(
     judge_model_id,
 ):
     client = llama_stack_client
-    from rich.pretty import pprint
 
     providers = [p for p in client.providers.list() if p.api == "inference"]
-    pprint(providers)
     assert len(providers) > 0, "No inference providers found"
     inference_providers = [p.provider_id for p in providers if p.provider_type != "inline::sentence-transformers"]
 
     model_ids = {m.identifier for m in client.models.list()}
     model_ids.update(m.provider_resource_id for m in client.models.list())
-    pprint(model_ids)
 
     if text_model_id and text_model_id not in model_ids:
         client.models.register(model_id=text_model_id, provider_id=inference_providers[0])

@ -8,9 +8,9 @@
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from llama_stack.models.llama.sku_list import resolve_model
|
from llama_stack.models.llama.sku_list import resolve_model
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from ..test_cases.test_case import TestCase
|
from ..test_cases.test_case import TestCase
|
||||||
|
|
||||||
|
@ -29,9 +29,7 @@ def skip_if_model_doesnt_support_completion(client_with_models, model_id):
|
||||||
"remote::gemini",
|
"remote::gemini",
|
||||||
"remote::groq",
|
"remote::groq",
|
||||||
):
|
):
|
||||||
pytest.skip(
|
pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support completion")
|
||||||
f"Model {model_id} hosted by {provider.provider_type} doesn't support completion"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_llama_model(client_with_models, model_id):
|
def get_llama_model(client_with_models, model_id):
|
||||||
|
@@ -112,9 +110,7 @@ def test_text_completion_streaming(client_with_models, text_model_id, test_case)
         "inference:completion:stop_sequence",
     ],
 )
-def test_text_completion_stop_sequence(
-    client_with_models, text_model_id, inference_provider_type, test_case
-):
+def test_text_completion_stop_sequence(client_with_models, text_model_id, inference_provider_type, test_case):
     skip_if_model_doesnt_support_completion(client_with_models, text_model_id)
     # This is only supported/tested for remote vLLM: https://github.com/meta-llama/llama-stack/issues/1771
     if inference_provider_type != "remote::vllm":
@@ -141,9 +137,7 @@ def test_text_completion_stop_sequence(
         "inference:completion:log_probs",
     ],
 )
-def test_text_completion_log_probs_non_streaming(
-    client_with_models, text_model_id, inference_provider_type, test_case
-):
+def test_text_completion_log_probs_non_streaming(client_with_models, text_model_id, inference_provider_type, test_case):
     skip_if_model_doesnt_support_completion(client_with_models, text_model_id)
     if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K:
         pytest.xfail(f"{inference_provider_type} doesn't support log probs yet")
@@ -162,9 +156,7 @@ def test_text_completion_log_probs_non_streaming(
         },
     )
     assert response.logprobs, "Logprobs should not be empty"
-    assert (
-        1 <= len(response.logprobs) <= 5
-    )  # each token has 1 logprob and here max_tokens=5
+    assert 1 <= len(response.logprobs) <= 5  # each token has 1 logprob and here max_tokens=5
     assert all(len(logprob.logprobs_by_token) == 1 for logprob in response.logprobs)


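For context on the assertion just above: the test requests at most five generated tokens and a single top log probability per token, which is why between one and five logprob entries are expected and each entry carries exactly one logprobs_by_token value. A minimal sketch of that request, reusing the fixture names from the surrounding tests; the exact sampling_params and logprobs values are inferred from the comments in this hunk rather than shown in the diff:

    response = client_with_models.inference.completion(
        model_id=text_model_id,
        content=tc["content"],
        stream=False,
        sampling_params={"max_tokens": 5},  # at most 5 tokens, so at most 5 logprob entries
        logprobs={"top_k": 1},  # one candidate per token, so len(logprobs_by_token) == 1
    )
    assert 1 <= len(response.logprobs) <= 5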
@@ -174,9 +166,7 @@ def test_text_completion_log_probs_non_streaming(
         "inference:completion:log_probs",
     ],
 )
-def test_text_completion_log_probs_streaming(
-    client_with_models, text_model_id, inference_provider_type, test_case
-):
+def test_text_completion_log_probs_streaming(client_with_models, text_model_id, inference_provider_type, test_case):
     skip_if_model_doesnt_support_completion(client_with_models, text_model_id)
     if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K:
         pytest.xfail(f"{inference_provider_type} doesn't support log probs yet")
@@ -198,9 +188,7 @@ def test_text_completion_log_probs_streaming(
     for chunk in streamed_content:
         if chunk.delta:  # if there's a token, we expect logprobs
             assert chunk.logprobs, "Logprobs should not be empty"
-            assert all(
-                len(logprob.logprobs_by_token) == 1 for logprob in chunk.logprobs
-            )
+            assert all(len(logprob.logprobs_by_token) == 1 for logprob in chunk.logprobs)
         else:  # no token, no logprobs
             assert not chunk.logprobs, "Logprobs should be empty"

@@ -211,13 +199,9 @@ def test_text_completion_log_probs_streaming(
         "inference:completion:structured_output",
     ],
 )
-def test_text_completion_structured_output(
-    client_with_models, text_model_id, test_case, inference_provider_type
-):
+def test_text_completion_structured_output(client_with_models, text_model_id, test_case, inference_provider_type):
     if inference_provider_type == "remote::tgi":
-        pytest.xfail(
-            f"{inference_provider_type} doesn't support structured outputs yet"
-        )
+        pytest.xfail(f"{inference_provider_type} doesn't support structured outputs yet")
     skip_if_model_doesnt_support_completion(client_with_models, text_model_id)

     class AnswerFormat(BaseModel):
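The hunk cuts off right where AnswerFormat is declared; the rest of the test (outside this diff) feeds that model's JSON schema to the completion call and validates the reply against it. A rough sketch of that pattern, where the field names and the exact response_format payload are illustrative assumptions rather than lines from this commit:

    from pydantic import BaseModel

    class AnswerFormat(BaseModel):
        name: str
        year_born: str

    response = client_with_models.inference.completion(
        model_id=text_model_id,
        content=tc["user_input"],
        stream=False,
        response_format={
            "type": "json_schema",
            "json_schema": AnswerFormat.model_json_schema(),  # constrain generation to this schema
        },
    )
    answer = AnswerFormat.model_validate_json(response.content)  # raises if the reply does not match the schema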
@@ -254,9 +238,7 @@ def test_text_completion_structured_output(
         "inference:chat_completion:non_streaming_02",
     ],
 )
-def test_text_chat_completion_non_streaming(
-    client_with_models, text_model_id, test_case
-):
+def test_text_chat_completion_non_streaming(client_with_models, text_model_id, test_case):
     tc = TestCase(test_case)
     question = tc["question"]
     expected = tc["expected"]
@@ -282,18 +264,15 @@ def test_text_chat_completion_non_streaming(
         "inference:chat_completion:ttft",
     ],
 )
-def test_text_chat_completion_first_token_profiling(
-    client_with_models, text_model_id, test_case
-):
+def test_text_chat_completion_first_token_profiling(client_with_models, text_model_id, test_case):
     tc = TestCase(test_case)

     messages = tc["messages"]
-    if os.environ.get(
-        "DEBUG_TTFT"
-    ):  # debugging print number of tokens in input, ideally around 800
-        from llama_stack.apis.inference import Message
+    if os.environ.get("DEBUG_TTFT"):  # debugging print number of tokens in input, ideally around 800
         from pydantic import TypeAdapter

+        from llama_stack.apis.inference import Message
+
         tokenizer, formatter = get_llama_tokenizer()
         typed_messages = [TypeAdapter(Message).validate_python(m) for m in messages]
         encoded = formatter.encode_dialog_prompt(typed_messages, None)
@@ -307,9 +286,7 @@ def test_text_chat_completion_first_token_profiling(
     message_content = response.completion_message.content.lower().strip()
     assert len(message_content) > 0

-    if os.environ.get(
-        "DEBUG_TTFT"
-    ):  # debugging print number of tokens in response, ideally around 150
+    if os.environ.get("DEBUG_TTFT"):  # debugging print number of tokens in response, ideally around 150
         tokenizer, formatter = get_llama_tokenizer()
         encoded = formatter.encode_content(message_content)
         raise ValueError(len(encoded.tokens) if encoded and encoded.tokens else 0)
@@ -332,9 +309,7 @@ def test_text_chat_completion_streaming(client_with_models, text_model_id, test_
         messages=[{"role": "user", "content": question}],
         stream=True,
     )
-    streamed_content = [
-        str(chunk.event.delta.text.lower().strip()) for chunk in response
-    ]
+    streamed_content = [str(chunk.event.delta.text.lower().strip()) for chunk in response]
     assert len(streamed_content) > 0
     assert expected.lower() in "".join(streamed_content)

@@ -345,9 +320,7 @@ def test_text_chat_completion_streaming(client_with_models, text_model_id, test_
         "inference:chat_completion:tool_calling",
     ],
 )
-def test_text_chat_completion_with_tool_calling_and_non_streaming(
-    client_with_models, text_model_id, test_case
-):
+def test_text_chat_completion_with_tool_calling_and_non_streaming(client_with_models, text_model_id, test_case):
     tc = TestCase(test_case)

     response = client_with_models.inference.chat_completion(
@@ -361,10 +334,7 @@ def test_text_chat_completion_with_tool_calling_and_non_streaming(
     assert response.completion_message.role == "assistant"

     assert len(response.completion_message.tool_calls) == 1
-    assert (
-        response.completion_message.tool_calls[0].tool_name
-        == tc["tools"][0]["tool_name"]
-    )
+    assert response.completion_message.tool_calls[0].tool_name == tc["tools"][0]["tool_name"]
     assert response.completion_message.tool_calls[0].arguments == tc["expected"]


@@ -387,9 +357,7 @@ def extract_tool_invocation_content(response):
         "inference:chat_completion:tool_calling",
     ],
 )
-def test_text_chat_completion_with_tool_calling_and_streaming(
-    client_with_models, text_model_id, test_case
-):
+def test_text_chat_completion_with_tool_calling_and_streaming(client_with_models, text_model_id, test_case):
     tc = TestCase(test_case)

     response = client_with_models.inference.chat_completion(
@@ -415,9 +383,7 @@ def test_text_chat_completion_with_tool_choice_required(
     client_with_models, text_model_id, test_case, inference_provider_type
 ):
     if inference_provider_type == "remote::tgi":
-        pytest.xfail(
-            f"{inference_provider_type} doesn't support tool_choice 'required' parameter yet"
-        )
+        pytest.xfail(f"{inference_provider_type} doesn't support tool_choice 'required' parameter yet")

     tc = TestCase(test_case)

@@ -442,9 +408,7 @@ def test_text_chat_completion_with_tool_choice_required(
         "inference:chat_completion:tool_calling",
     ],
 )
-def test_text_chat_completion_with_tool_choice_none(
-    client_with_models, text_model_id, test_case
-):
+def test_text_chat_completion_with_tool_choice_none(client_with_models, text_model_id, test_case):
     tc = TestCase(test_case)

     response = client_with_models.inference.chat_completion(
@@ -464,13 +428,9 @@ def test_text_chat_completion_with_tool_choice_none(
         "inference:chat_completion:structured_output",
     ],
 )
-def test_text_chat_completion_structured_output(
-    client_with_models, text_model_id, test_case, inference_provider_type
-):
+def test_text_chat_completion_structured_output(client_with_models, text_model_id, test_case, inference_provider_type):
     if inference_provider_type == "remote::tgi":
-        pytest.xfail(
-            f"{inference_provider_type} doesn't support structured outputs yet"
-        )
+        pytest.xfail(f"{inference_provider_type} doesn't support structured outputs yet")

     class NBAStats(BaseModel):
         year_for_draft: int
@@ -27,9 +27,7 @@ def base64_image_url(base64_image_data, image_path):
     return f"data:image/{image_path.suffix[1:]};base64,{base64_image_data}"


-# @pytest.mark.xfail(
-#     reason="This test is failing because the image is not being downloaded correctly."
-# )
+@pytest.mark.xfail(reason="This test is failing because the image is not being downloaded correctly.")
 def test_image_chat_completion_non_streaming(client_with_models, vision_model_id):
     message = {
         "role": "user",
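For reference, the base64_image_url fixture above wraps the raw bytes in a data: URL, and the vision tests pass that URL inside an image content part of the user message. A sketch of the message shape these tests typically build; the nested url/uri structure here is an assumption for illustration, since the actual dict literals sit outside this hunk:

    message = {
        "role": "user",
        "content": [
            {
                "type": "image",
                "image": {"url": {"uri": base64_image_url}},  # assumed shape for an image content part
            },
            {"type": "text", "text": "Describe what is in this image."},
        ],
    }
    response = client_with_models.inference.chat_completion(
        model_id=vision_model_id,
        messages=[message],
        stream=False,
    )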
@@ -58,9 +56,7 @@ def test_image_chat_completion_non_streaming(client_with_models, vision_model_id
     assert any(expected in message_content for expected in {"dog", "puppy", "pup"})


-# @pytest.mark.xfail(
-#     reason="This test is failing because the image is not being downloaded correctly."
-# )
+@pytest.mark.xfail(reason="This test is failing because the image is not being downloaded correctly.")
 def test_image_chat_completion_streaming(client_with_models, vision_model_id):
     message = {
         "role": "user",
@@ -92,9 +88,7 @@ def test_image_chat_completion_streaming(client_with_models, vision_model_id):


 @pytest.mark.parametrize("type_", ["url"])
-def test_image_chat_completion_base64(
-    client_with_models, vision_model_id, base64_image_data, base64_image_url, type_
-):
+def test_image_chat_completion_base64(client_with_models, vision_model_id, base64_image_data, base64_image_url, type_):
     image_spec = {
         "url": {
             "type": "image",