chore: mypy for remote::vllm

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
Ihar Hrachyshka 2025-03-25 11:11:15 -04:00
parent d5b52923c1
commit 3e6c47ce10
5 changed files with 37 additions and 28 deletions

@@ -431,7 +431,7 @@ class Inference(Protocol):
     - Embedding models: these models generate embeddings to be used for semantic search.
     """

-    model_store: ModelStore
+    model_store: ModelStore | None = None

     @webmethod(route="/inference/completion", method="POST")
     async def completion(
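
Note: with model_store now typed as ModelStore | None on the protocol, mypy requires providers to narrow the attribute before dereferencing it, which is why the vLLM adapter below adds "assert self.model_store is not None". A minimal sketch of that narrowing pattern; the Store and Adapter names are illustrative stand-ins, not llama_stack types:

    from typing import Optional, Protocol

    class Store(Protocol):  # hypothetical stand-in for ModelStore
        def get_model(self, model_id: str) -> str: ...

    class Adapter:
        # Optional until the stack wires a store in at runtime
        model_store: Optional[Store] = None

        def lookup(self, model_id: str) -> str:
            # Without this assert, mypy reports that "Store | None" has no
            # attribute "get_model"; the assert narrows the type to Store.
            assert self.model_store is not None
            return self.model_store.get_model(model_id)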

@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 import json
 import logging
-from typing import AsyncGenerator, List, Optional, Union
+from typing import Any, AsyncGenerator, List, Optional, Union

 import httpx
 from openai import AsyncOpenAI
@@ -32,11 +32,12 @@ from llama_stack.apis.inference import (
     CompletionResponseStreamChunk,
     EmbeddingsResponse,
     EmbeddingTaskType,
+    GrammarResponseFormat,
     Inference,
+    JsonSchemaResponseFormat,
     LogProbConfig,
     Message,
     ResponseFormat,
-    ResponseFormatType,
     SamplingParams,
     TextTruncation,
     ToolChoice,
@@ -102,9 +103,6 @@ def _convert_to_vllm_tool_calls_in_response(


 def _convert_to_vllm_tools_in_request(tools: List[ToolDefinition]) -> List[dict]:
-    if tools is None:
-        return tools
-
     compat_tools = []
     for tool in tools:
@@ -141,9 +139,7 @@ def _convert_to_vllm_tools_in_request(tools: List[ToolDefinition]) -> List[dict]
         compat_tools.append(compat_tool)

-    if len(compat_tools) > 0:
-        return compat_tools
-    return None
+    return compat_tools


 def _convert_to_vllm_finish_reason(finish_reason: str) -> StopReason:
@@ -206,9 +202,10 @@ async def _process_vllm_chat_completion_stream_response(
             )
         elif choice.delta.tool_calls:
             tool_call = convert_tool_call(choice.delta.tool_calls[0])
-            tool_call_buf.tool_name += tool_call.tool_name
+            tool_call_buf.tool_name += str(tool_call.tool_name)
             tool_call_buf.call_id += tool_call.call_id
-            tool_call_buf.arguments += tool_call.arguments
+            # TODO: remove str() when dict type for 'arguments' is no longer allowed
+            tool_call_buf.arguments += str(tool_call.arguments)
         else:
             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
@@ -248,10 +245,11 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
-    ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
+    ) -> CompletionResponse | AsyncGenerator[CompletionResponseStreamChunk]:
+        assert self.model_store is not None
         if sampling_params is None:
             sampling_params = SamplingParams()
-        model = await self.model_store.get_model(model_id)
+        model = self.model_store.get_model(model_id)
         request = CompletionRequest(
             model=model.provider_resource_id,
             content=content,
@@ -270,17 +268,18 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         model_id: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = None,
+        response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = None,
-        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
-    ) -> AsyncGenerator:
+    ) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk]:
+        assert self.model_store is not None
         if sampling_params is None:
             sampling_params = SamplingParams()
-        model = await self.model_store.get_model(model_id)
+        model = self.model_store.get_model(model_id)
         # This is to be consistent with OpenAI API and support vLLM <= v0.6.3
         # References:
         # * https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
@@ -318,11 +317,13 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         )
         return result

-    async def _stream_chat_completion(self, request: ChatCompletionRequest, client: AsyncOpenAI) -> AsyncGenerator:
+    async def _stream_chat_completion(
+        self, request: ChatCompletionRequest, client: AsyncOpenAI
+    ) -> AsyncGenerator[ChatCompletionResponseStreamChunk]:
         params = await self._get_params(request)

         stream = await client.chat.completions.create(**params)
-        if len(request.tools) > 0:
+        if request.tools:
             res = _process_vllm_chat_completion_stream_response(stream)
         else:
             res = process_chat_completion_stream_response(stream, request)
@@ -330,18 +331,21 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             yield chunk

     async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
+        assert self.client is not None
         params = await self._get_params(request)
         r = await self.client.completions.create(**params)
         return process_completion_response(r)

-    async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
+    async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator[CompletionResponseStreamChunk]:
+        assert self.client is not None
         params = await self._get_params(request)
         stream = await self.client.completions.create(**params)
         async for chunk in process_completion_stream_response(stream):
             yield chunk

-    async def register_model(self, model: Model) -> Model:
+    async def register_model(self, model: Model) -> None:
+        assert self.client is not None
         model = await self.register_helper.register_model(model)
         res = await self.client.models.list()
         available_models = [m.id async for m in res]
@@ -350,14 +354,13 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
                 f"Model {model.provider_resource_id} is not being served by vLLM. "
                 f"Available models: {', '.join(available_models)}"
             )
-        return model

     async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
         options = get_sampling_options(request.sampling_params)
         if "max_tokens" not in options:
             options["max_tokens"] = self.config.max_tokens

-        input_dict = {}
+        input_dict: dict[str, Any] = {}
         if isinstance(request, ChatCompletionRequest) and request.tools is not None:
             input_dict = {"tools": _convert_to_vllm_tools_in_request(request.tools)}
@@ -368,9 +371,9 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             input_dict["prompt"] = await completion_request_to_prompt(request)

         if fmt := request.response_format:
-            if fmt.type == ResponseFormatType.json_schema.value:
-                input_dict["extra_body"] = {"guided_json": request.response_format.json_schema}
-            elif fmt.type == ResponseFormatType.grammar.value:
+            if isinstance(fmt, JsonSchemaResponseFormat):
+                input_dict["extra_body"] = {"guided_json": fmt.json_schema}
+            elif isinstance(fmt, GrammarResponseFormat):
                 raise NotImplementedError("Grammar response format not supported yet")
             else:
                 raise ValueError(f"Unknown response format {fmt.type}")
@@ -393,7 +396,8 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         output_dimension: Optional[int] = None,
         task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
-        model = await self.model_store.get_model(model_id)
+        assert self.client is not None
+        model = self.model_store.get_model(model_id)
         kwargs = {}
         assert model.model_type == ModelType.embedding
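
Note: the response_format handling in the @@ -368,9 +371,9 @@ hunk above swaps string comparison on fmt.type for isinstance checks, which lets mypy narrow the union and verify the json_schema access. A minimal sketch of that pattern with simplified stand-in classes, not the llama_stack types:

    from dataclasses import dataclass
    from typing import Any, Union

    @dataclass
    class JsonSchemaFormat:  # stand-in for JsonSchemaResponseFormat
        json_schema: dict[str, Any]

    @dataclass
    class GrammarFormat:  # stand-in for GrammarResponseFormat
        bnf: str

    def to_extra_body(fmt: Union[JsonSchemaFormat, GrammarFormat]) -> dict[str, Any]:
        # isinstance() narrows the union, so mypy knows .json_schema exists here;
        # comparing a string .type field would not narrow the type.
        if isinstance(fmt, JsonSchemaFormat):
            return {"guided_json": fmt.json_schema}
        raise NotImplementedError("Grammar response format not supported yet")

    # Example: build the guided-JSON extra body for a schema-constrained response.
    body = to_extra_body(JsonSchemaFormat(json_schema={"type": "object"}))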

@@ -104,3 +104,6 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
             )
         return model
+
+    async def unregister_model(self, model_id: str) -> None:
+        pass

@@ -137,7 +137,10 @@ def get_sampling_strategy_options(params: SamplingParams) -> dict:
     return options


-def get_sampling_options(params: SamplingParams) -> dict:
+def get_sampling_options(params: SamplingParams | None) -> dict:
+    if not params:
+        return {}
+
     options = {}
     if params:
         options.update(get_sampling_strategy_options(params))
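
Note: the early return above makes the None case explicit, so callers can pass an Optional value straight through instead of guarding at every call site. A small illustration of the guard pattern with simplified types; build_options and DEFAULTS are hypothetical, not llama_stack helpers:

    from typing import Any, Optional

    DEFAULTS: dict[str, Any] = {"max_tokens": 512}

    def build_options(params: Optional[dict[str, Any]]) -> dict[str, Any]:
        # Accept None directly and normalize it to an empty dict up front,
        # so the rest of the function works with a concrete (non-None) value.
        if not params:
            return {}
        options = dict(DEFAULTS)
        options.update(params)
        return options

    # Callers no longer need their own None check:
    assert build_options(None) == {}
    assert build_options({"temperature": 0.7})["max_tokens"] == 512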

@@ -253,7 +253,6 @@ exclude = [
     "^llama_stack/providers/remote/inference/sample/",
     "^llama_stack/providers/remote/inference/tgi/",
     "^llama_stack/providers/remote/inference/together/",
-    "^llama_stack/providers/remote/inference/vllm/",
     "^llama_stack/providers/remote/safety/bedrock/",
     "^llama_stack/providers/remote/safety/nvidia/",
     "^llama_stack/providers/remote/safety/sample/",