Merge remote-tracking branch 'origin/main' into resp_branching

Ashwin Bharambe 2025-10-01 21:13:12 -07:00
commit 1536ae0333
144 changed files with 62682 additions and 51560 deletions


@@ -15,7 +15,6 @@ from openai.types.chat.chat_completion_chunk import (
)
from llama_stack.apis.common.content_types import (
InterleavedContent,
TextDelta,
ToolCallDelta,
ToolCallParseStatus,
@@ -27,9 +26,6 @@ from llama_stack.apis.inference import (
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
CompletionMessage,
CompletionRequest,
CompletionResponse,
CompletionResponseStreamChunk,
GrammarResponseFormat,
Inference,
JsonSchemaResponseFormat,
@@ -64,14 +60,8 @@ from llama_stack.providers.utils.inference.openai_compat import (
convert_tool_call,
get_sampling_options,
process_chat_completion_stream_response,
process_completion_response,
process_completion_stream_response,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack.providers.utils.inference.prompt_adapter import (
completion_request_to_prompt,
request_has_media,
)
from .config import VLLMInferenceAdapterConfig
@@ -349,33 +339,6 @@ class VLLMInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin, Inference, ModelsPro
def get_extra_client_params(self):
return {"http_client": httpx.AsyncClient(verify=self.config.tls_verify)}
async def completion( # type: ignore[override] # Return type is more specific than the base class, which allows for both streaming and non-streaming responses.
self,
model_id: str,
content: InterleavedContent,
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
) -> CompletionResponse | AsyncGenerator[CompletionResponseStreamChunk, None]:
if sampling_params is None:
sampling_params = SamplingParams()
model = await self._get_model(model_id)
if model.provider_resource_id is None:
raise ValueError(f"Model {model_id} has no provider_resource_id set")
request = CompletionRequest(
model=model.provider_resource_id,
content=content,
sampling_params=sampling_params,
response_format=response_format,
stream=stream,
logprobs=logprobs,
)
if stream:
return self._stream_completion(request)
else:
return await self._nonstream_completion(request)
async def chat_completion(
self,
model_id: str,
@@ -460,24 +423,6 @@ class VLLMInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin, Inference, ModelsPro
async for chunk in res:
yield chunk
async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
if self.client is None:
raise RuntimeError("Client is not initialized")
params = await self._get_params(request)
r = await self.client.completions.create(**params)
return process_completion_response(r)
async def _stream_completion(
self, request: CompletionRequest
) -> AsyncGenerator[CompletionResponseStreamChunk, None]:
if self.client is None:
raise RuntimeError("Client is not initialized")
params = await self._get_params(request)
stream = await self.client.completions.create(**params)
async for chunk in process_completion_stream_response(stream):
yield chunk
async def register_model(self, model: Model) -> Model:
try:
model = await self.register_helper.register_model(model)
@@ -497,7 +442,7 @@ class VLLMInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin, Inference, ModelsPro
)
return model
async def _get_params(self, request: ChatCompletionRequest | CompletionRequest) -> dict:
async def _get_params(self, request: ChatCompletionRequest) -> dict:
options = get_sampling_options(request.sampling_params)
if "max_tokens" not in options:
options["max_tokens"] = self.config.max_tokens
@@ -507,11 +452,7 @@ class VLLMInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin, Inference, ModelsPro
if isinstance(request, ChatCompletionRequest) and request.tools:
input_dict = {"tools": _convert_to_vllm_tools_in_request(request.tools)}
if isinstance(request, ChatCompletionRequest):
input_dict["messages"] = [await convert_message_to_openai_dict(m, download=True) for m in request.messages]
else:
assert not request_has_media(request), "vLLM does not support media for Completion requests"
input_dict["prompt"] = await completion_request_to_prompt(request)
input_dict["messages"] = [await convert_message_to_openai_dict(m, download=True) for m in request.messages]
if fmt := request.response_format:
if isinstance(fmt, JsonSchemaResponseFormat):
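
With the adapter's bare completion() path removed above, prompt-style requests against a vLLM server go through the chat-completion route instead. A minimal sketch of that usage, assuming the standard openai Python client and a vLLM server exposing an OpenAI-compatible /v1 endpoint; the function name and parameters below are illustrative and not part of this commit:

    from openai import AsyncOpenAI

    async def ask_vllm(base_url: str, model: str, prompt: str) -> str:
        # vLLM's OpenAI-compatible server typically does not check the API key.
        client = AsyncOpenAI(base_url=base_url, api_key="not-needed")
        # Wrap the prompt as a single user message and call the chat endpoint,
        # mirroring how the adapter now only builds chat-completion requests.
        resp = await client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=128,
        )
        return resp.choices[0].message.content or ""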