Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-06 18:50:44 +00:00)

chore: mypy for remote::vllm

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>

Parent: d5b52923c1
Commit: 3e6c47ce10

5 changed files with 37 additions and 28 deletions
@@ -431,7 +431,7 @@ class Inference(Protocol):
     - Embedding models: these models generate embeddings to be used for semantic search.
     """
 
-    model_store: ModelStore
+    model_store: ModelStore | None = None
 
     @webmethod(route="/inference/completion", method="POST")
     async def completion(
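Note: declaring model_store as ModelStore | None = None means mypy treats every access as possibly None, which is why the adapter methods further down gain assert self.model_store is not None (and assert self.client is not None) guards before use. A minimal sketch of that narrowing pattern, using stand-in names rather than the real llama-stack classes:

from typing import Protocol


class Store(Protocol):
    def get(self, key: str) -> str: ...


class Adapter:
    # The attribute starts out unset; the framework assigns it later.
    store: Store | None = None

    def lookup(self, key: str) -> str:
        # Without this assert, mypy reports that `self.store` may be None
        # at the call to .get().
        assert self.store is not None
        return self.store.get(key)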
@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 import json
 import logging
-from typing import AsyncGenerator, List, Optional, Union
+from typing import Any, AsyncGenerator, List, Optional, Union
 
 import httpx
 from openai import AsyncOpenAI
@@ -32,11 +32,12 @@ from llama_stack.apis.inference import (
     CompletionResponseStreamChunk,
     EmbeddingsResponse,
     EmbeddingTaskType,
+    GrammarResponseFormat,
     Inference,
+    JsonSchemaResponseFormat,
     LogProbConfig,
     Message,
     ResponseFormat,
-    ResponseFormatType,
     SamplingParams,
     TextTruncation,
     ToolChoice,
@@ -102,9 +103,6 @@ def _convert_to_vllm_tool_calls_in_response(
 
 
 def _convert_to_vllm_tools_in_request(tools: List[ToolDefinition]) -> List[dict]:
-    if tools is None:
-        return tools
-
     compat_tools = []
 
     for tool in tools:
@@ -141,9 +139,7 @@ def _convert_to_vllm_tools_in_request(tools: List[ToolDefinition]) -> List[dict]
 
         compat_tools.append(compat_tool)
 
-    if len(compat_tools) > 0:
-        return compat_tools
-    return None
+    return compat_tools
 
 
 def _convert_to_vllm_finish_reason(finish_reason: str) -> StopReason:
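Note: dropping the early None return keeps the function consistent with its declared List[dict] return type; an empty list now stands in for "no tools". An illustrative sketch of the pattern (the names here are invented for the example):

from typing import List


def convert_tools(tools: List[str]) -> List[dict]:
    # Returning a (possibly empty) list on every path keeps the function
    # consistent with its List[dict] annotation; callers can rely on
    # truthiness instead of a None sentinel.
    return [{"name": name} for name in tools]


assert convert_tools([]) == []
assert convert_tools(["web_search"]) == [{"name": "web_search"}]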
@@ -206,9 +202,10 @@ async def _process_vllm_chat_completion_stream_response(
            )
        elif choice.delta.tool_calls:
            tool_call = convert_tool_call(choice.delta.tool_calls[0])
-            tool_call_buf.tool_name += tool_call.tool_name
+            tool_call_buf.tool_name += str(tool_call.tool_name)
            tool_call_buf.call_id += tool_call.call_id
-            tool_call_buf.arguments += tool_call.arguments
+            # TODO: remove str() when dict type for 'arguments' is no longer allowed
+            tool_call_buf.arguments += str(tool_call.arguments)
        else:
            yield ChatCompletionResponseStreamChunk(
                event=ChatCompletionResponseEvent(
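Note: the buffer fields are plain strings, while the converted tool call's tool_name (and, per the TODO above, its arguments) may carry a non-str type, so str() keeps the += concatenation well-typed. A rough sketch under that assumption; the enum and field types below are illustrative, not the exact llama-stack definitions:

from dataclasses import dataclass
from enum import Enum


class BuiltinTool(Enum):
    code_interpreter = "code_interpreter"


@dataclass
class ToolCallBuffer:
    tool_name: str = ""
    arguments: str = ""


buf = ToolCallBuffer()
incoming: str | BuiltinTool = BuiltinTool.code_interpreter
# str() collapses the union to str, so concatenating onto the str buffer
# type-checks; the same trick covers arguments while dicts are still allowed.
buf.tool_name += str(incoming)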
@@ -248,10 +245,11 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
-    ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
+    ) -> CompletionResponse | AsyncGenerator[CompletionResponseStreamChunk]:
+        assert self.model_store is not None
         if sampling_params is None:
             sampling_params = SamplingParams()
-        model = await self.model_store.get_model(model_id)
+        model = self.model_store.get_model(model_id)
         request = CompletionRequest(
             model=model.provider_resource_id,
             content=content,
@@ -270,17 +268,18 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         model_id: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = None,
-        response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = None,
+        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
-    ) -> AsyncGenerator:
+    ) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk]:
+        assert self.model_store is not None
         if sampling_params is None:
             sampling_params = SamplingParams()
-        model = await self.model_store.get_model(model_id)
+        model = self.model_store.get_model(model_id)
         # This is to be consistent with OpenAI API and support vLLM <= v0.6.3
         # References:
         # * https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
@@ -318,11 +317,13 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             )
             return result
 
-    async def _stream_chat_completion(self, request: ChatCompletionRequest, client: AsyncOpenAI) -> AsyncGenerator:
+    async def _stream_chat_completion(
+        self, request: ChatCompletionRequest, client: AsyncOpenAI
+    ) -> AsyncGenerator[ChatCompletionResponseStreamChunk]:
         params = await self._get_params(request)
 
         stream = await client.chat.completions.create(**params)
-        if len(request.tools) > 0:
+        if request.tools:
             res = _process_vllm_chat_completion_stream_response(stream)
         else:
             res = process_chat_completion_stream_response(stream, request)
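Note: with tools typed as Optional[List[ToolDefinition]], calling len(request.tools) is rejected when the value may be None; a plain truthiness test covers both None and the empty list. A small self-contained sketch with a hypothetical request type:

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Request:
    tools: Optional[List[dict]] = None


def wants_tool_parsing(request: Request) -> bool:
    # Truthiness covers both None and the empty list, so there is no call
    # to len() on a possibly-None value.
    return bool(request.tools)


assert not wants_tool_parsing(Request())
assert not wants_tool_parsing(Request(tools=[]))
assert wants_tool_parsing(Request(tools=[{"name": "web_search"}]))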
@@ -330,18 +331,21 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             yield chunk
 
     async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
+        assert self.client is not None
         params = await self._get_params(request)
         r = await self.client.completions.create(**params)
         return process_completion_response(r)
 
-    async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
+    async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator[CompletionResponseStreamChunk]:
+        assert self.client is not None
         params = await self._get_params(request)
 
         stream = await self.client.completions.create(**params)
         async for chunk in process_completion_stream_response(stream):
             yield chunk
 
-    async def register_model(self, model: Model) -> Model:
+    async def register_model(self, model: Model) -> None:
+        assert self.client is not None
         model = await self.register_helper.register_model(model)
         res = await self.client.models.list()
         available_models = [m.id async for m in res]
@@ -350,14 +354,13 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
                 f"Model {model.provider_resource_id} is not being served by vLLM. "
                 f"Available models: {', '.join(available_models)}"
             )
-        return model
 
     async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
         options = get_sampling_options(request.sampling_params)
         if "max_tokens" not in options:
             options["max_tokens"] = self.config.max_tokens
 
-        input_dict = {}
+        input_dict: dict[str, Any] = {}
         if isinstance(request, ChatCompletionRequest) and request.tools is not None:
             input_dict = {"tools": _convert_to_vllm_tools_in_request(request.tools)}
 
@@ -368,9 +371,9 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             input_dict["prompt"] = await completion_request_to_prompt(request)
 
         if fmt := request.response_format:
-            if fmt.type == ResponseFormatType.json_schema.value:
-                input_dict["extra_body"] = {"guided_json": request.response_format.json_schema}
-            elif fmt.type == ResponseFormatType.grammar.value:
+            if isinstance(fmt, JsonSchemaResponseFormat):
+                input_dict["extra_body"] = {"guided_json": fmt.json_schema}
+            elif isinstance(fmt, GrammarResponseFormat):
                 raise NotImplementedError("Grammar response format not supported yet")
             else:
                 raise ValueError(f"Unknown response format {fmt.type}")
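Note: comparing fmt.type against enum values tells mypy nothing about which ResponseFormat variant fmt is, so fmt.json_schema would not type-check; isinstance narrows the union so the attribute access does. A minimal sketch of the same idea with stand-in classes rather than the real llama-stack models:

from dataclasses import dataclass
from typing import Union


@dataclass
class JsonSchemaFormat:
    json_schema: dict


@dataclass
class GrammarFormat:
    bnf: str


ResponseFmt = Union[JsonSchemaFormat, GrammarFormat]


def to_extra_body(fmt: ResponseFmt) -> dict:
    if isinstance(fmt, JsonSchemaFormat):
        # After the isinstance check, mypy knows fmt exposes json_schema.
        return {"guided_json": fmt.json_schema}
    raise NotImplementedError("grammar formats are not handled in this sketch")


assert to_extra_body(JsonSchemaFormat(json_schema={"type": "object"})) == {
    "guided_json": {"type": "object"}
}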
@@ -393,7 +396,8 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         output_dimension: Optional[int] = None,
         task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
-        model = await self.model_store.get_model(model_id)
+        assert self.client is not None
+        model = self.model_store.get_model(model_id)
 
         kwargs = {}
         assert model.model_type == ModelType.embedding
@@ -104,3 +104,6 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
             )
 
         return model
+
+    async def unregister_model(self, model_id: str) -> None:
+        pass
@@ -137,7 +137,10 @@ def get_sampling_strategy_options(params: SamplingParams) -> dict:
     return options
 
 
-def get_sampling_options(params: SamplingParams) -> dict:
+def get_sampling_options(params: SamplingParams | None) -> dict:
+    if not params:
+        return {}
+
     options = {}
     if params:
         options.update(get_sampling_strategy_options(params))
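Note: widening the parameter to SamplingParams | None with an early return {} lets callers pass a possibly-None request.sampling_params straight through without their own guard. A generic sketch of that shape (plain dicts stand in for SamplingParams here):

from typing import Optional


def get_options(params: Optional[dict]) -> dict:
    # An explicit None guard up front means call sites can pass a
    # possibly-None value without their own check.
    if not params:
        return {}
    return dict(params)


assert get_options(None) == {}
assert get_options({"temperature": 0.7}) == {"temperature": 0.7}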
@@ -253,7 +253,6 @@ exclude = [
     "^llama_stack/providers/remote/inference/sample/",
     "^llama_stack/providers/remote/inference/tgi/",
     "^llama_stack/providers/remote/inference/together/",
-    "^llama_stack/providers/remote/inference/vllm/",
     "^llama_stack/providers/remote/safety/bedrock/",
     "^llama_stack/providers/remote/safety/nvidia/",
     "^llama_stack/providers/remote/safety/sample/",