diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index 7d3539dcb..98baf846e 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -431,7 +431,7 @@ class Inference(Protocol):
     - Embedding models: these models generate embeddings to be used for semantic search.
     """

-    model_store: ModelStore
+    model_store: ModelStore | None = None

     @webmethod(route="/inference/completion", method="POST")
     async def completion(
diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index eda1a179c..ecf41e50d 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 import json
 import logging
-from typing import AsyncGenerator, List, Optional, Union
+from typing import Any, AsyncGenerator, List, Optional, Union

 import httpx
 from openai import AsyncOpenAI
@@ -32,11 +32,12 @@ from llama_stack.apis.inference import (
     CompletionResponseStreamChunk,
     EmbeddingsResponse,
     EmbeddingTaskType,
+    GrammarResponseFormat,
     Inference,
+    JsonSchemaResponseFormat,
     LogProbConfig,
     Message,
     ResponseFormat,
-    ResponseFormatType,
     SamplingParams,
     TextTruncation,
     ToolChoice,
@@ -102,9 +103,6 @@ def _convert_to_vllm_tool_calls_in_response(


 def _convert_to_vllm_tools_in_request(tools: List[ToolDefinition]) -> List[dict]:
-    if tools is None:
-        return tools
-
     compat_tools = []

     for tool in tools:
@@ -141,9 +139,7 @@ def _convert_to_vllm_tools_in_request(tools: List[ToolDefinition]) -> List[dict]

         compat_tools.append(compat_tool)

-    if len(compat_tools) > 0:
-        return compat_tools
-    return None
+    return compat_tools


 def _convert_to_vllm_finish_reason(finish_reason: str) -> StopReason:
@@ -206,9 +202,10 @@ async def _process_vllm_chat_completion_stream_response(
             )
         elif choice.delta.tool_calls:
             tool_call = convert_tool_call(choice.delta.tool_calls[0])
-            tool_call_buf.tool_name += tool_call.tool_name
+            tool_call_buf.tool_name += str(tool_call.tool_name)
             tool_call_buf.call_id += tool_call.call_id
-            tool_call_buf.arguments += tool_call.arguments
+            # TODO: remove str() when dict type for 'arguments' is no longer allowed
+            tool_call_buf.arguments += str(tool_call.arguments)
         else:
             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
@@ -248,10 +245,11 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
-    ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
+    ) -> CompletionResponse | AsyncGenerator[CompletionResponseStreamChunk]:
+        assert self.model_store is not None
         if sampling_params is None:
             sampling_params = SamplingParams()
-        model = await self.model_store.get_model(model_id)
+        model = self.model_store.get_model(model_id)
         request = CompletionRequest(
             model=model.provider_resource_id,
             content=content,
@@ -270,17 +268,18 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         model_id: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = None,
-        response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = None,
+        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
-    ) -> AsyncGenerator:
+    ) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk]:
+        assert self.model_store is not None
         if sampling_params is None:
             sampling_params = SamplingParams()
-        model = await self.model_store.get_model(model_id)
+        model = self.model_store.get_model(model_id)
         # This is to be consistent with OpenAI API and support vLLM <= v0.6.3
         # References:
         # * https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
@@ -318,11 +317,13 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         )
         return result

-    async def _stream_chat_completion(self, request: ChatCompletionRequest, client: AsyncOpenAI) -> AsyncGenerator:
+    async def _stream_chat_completion(
+        self, request: ChatCompletionRequest, client: AsyncOpenAI
+    ) -> AsyncGenerator[ChatCompletionResponseStreamChunk]:
         params = await self._get_params(request)
         stream = await client.chat.completions.create(**params)

-        if len(request.tools) > 0:
+        if request.tools:
             res = _process_vllm_chat_completion_stream_response(stream)
         else:
             res = process_chat_completion_stream_response(stream, request)
@@ -330,18 +331,21 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             yield chunk

     async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
+        assert self.client is not None
         params = await self._get_params(request)
         r = await self.client.completions.create(**params)
         return process_completion_response(r)

-    async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
+    async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator[CompletionResponseStreamChunk]:
+        assert self.client is not None
         params = await self._get_params(request)
         stream = await self.client.completions.create(**params)
         async for chunk in process_completion_stream_response(stream):
             yield chunk

-    async def register_model(self, model: Model) -> Model:
+    async def register_model(self, model: Model) -> None:
+        assert self.client is not None
         model = await self.register_helper.register_model(model)
         res = await self.client.models.list()
         available_models = [m.id async for m in res]
@@ -350,14 +354,13 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
                 f"Model {model.provider_resource_id} is not being served by vLLM. "
                 f"Available models: {', '.join(available_models)}"
             )
-        return model

     async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
         options = get_sampling_options(request.sampling_params)
         if "max_tokens" not in options:
             options["max_tokens"] = self.config.max_tokens

-        input_dict = {}
+        input_dict: dict[str, Any] = {}
         if isinstance(request, ChatCompletionRequest) and request.tools is not None:
             input_dict = {"tools": _convert_to_vllm_tools_in_request(request.tools)}
@@ -368,9 +371,9 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             input_dict["prompt"] = await completion_request_to_prompt(request)

         if fmt := request.response_format:
-            if fmt.type == ResponseFormatType.json_schema.value:
-                input_dict["extra_body"] = {"guided_json": request.response_format.json_schema}
-            elif fmt.type == ResponseFormatType.grammar.value:
+            if isinstance(fmt, JsonSchemaResponseFormat):
+                input_dict["extra_body"] = {"guided_json": fmt.json_schema}
+            elif isinstance(fmt, GrammarResponseFormat):
                 raise NotImplementedError("Grammar response format not supported yet")
             else:
                 raise ValueError(f"Unknown response format {fmt.type}")
@@ -393,7 +396,8 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         output_dimension: Optional[int] = None,
         task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
-        model = await self.model_store.get_model(model_id)
+        assert self.client is not None
+        model = self.model_store.get_model(model_id)
         kwargs = {}
         assert model.model_type == ModelType.embedding

diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py
index d9e24662a..a11c734df 100644
--- a/llama_stack/providers/utils/inference/model_registry.py
+++ b/llama_stack/providers/utils/inference/model_registry.py
@@ -104,3 +104,6 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
         )

         return model
+
+    async def unregister_model(self, model_id: str) -> None:
+        pass
diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py
index 07976e811..04db45e67 100644
--- a/llama_stack/providers/utils/inference/openai_compat.py
+++ b/llama_stack/providers/utils/inference/openai_compat.py
@@ -137,7 +137,10 @@ def get_sampling_strategy_options(params: SamplingParams) -> dict:
     return options


-def get_sampling_options(params: SamplingParams) -> dict:
+def get_sampling_options(params: SamplingParams | None) -> dict:
+    if not params:
+        return {}
+
     options = {}
     if params:
         options.update(get_sampling_strategy_options(params))
diff --git a/pyproject.toml b/pyproject.toml
index 75dfcbb2f..fee0191e7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -253,7 +253,6 @@ exclude = [
     "^llama_stack/providers/remote/inference/sample/",
     "^llama_stack/providers/remote/inference/tgi/",
     "^llama_stack/providers/remote/inference/together/",
-    "^llama_stack/providers/remote/inference/vllm/",
     "^llama_stack/providers/remote/safety/bedrock/",
     "^llama_stack/providers/remote/safety/nvidia/",
     "^llama_stack/providers/remote/safety/sample/",