Stable update

Justin 2025-10-01 14:33:15 -07:00
parent 064602bc97
commit 3236f82223


@@ -5,14 +5,21 @@
 # the root directory of this source tree.
 from collections.abc import AsyncGenerator
-from openai import OpenAI
+import asyncio
+from typing import Any
+
+from openai import AsyncOpenAI
 from llama_stack.apis.inference import *
-from llama_stack.apis.inference import OpenAIEmbeddingsResponse
+from llama_stack.apis.inference import (
+    OpenAIMessageParam,
+    OpenAIResponseFormatParam,
+)
+from llama_stack.apis.common.content_types import InterleavedContentItem
 from llama_stack.apis.models import Model, ModelType
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 from llama_stack.providers.utils.inference.openai_compat import (
-    OpenAIChatCompletionToLlamaStackMixin,
-    OpenAICompletionToLlamaStackMixin,
+    convert_message_to_openai_dict,
     get_sampling_options,
     process_chat_completion_response,
     process_chat_completion_stream_response,
@@ -23,16 +30,16 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     completion_request_to_prompt,
     interleaved_content_as_str,
 )
+from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
 from .config import RunpodImplConfig

 MODEL_ENTRIES = []


 class RunpodInferenceAdapter(
+    OpenAIMixin,
     ModelRegistryHelper,
     Inference,
-    OpenAIChatCompletionToLlamaStackMixin,
-    OpenAICompletionToLlamaStackMixin,
 ):
     """
     Adapter for RunPod's OpenAI-compatible API endpoints.
@@ -41,44 +48,96 @@ class RunpodInferenceAdapter(
     """

     def __init__(self, config: RunpodImplConfig) -> None:
+        OpenAIMixin.__init__(self)
         ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
         self.config = config

+    def get_api_key(self) -> str:
+        """Get API key for OpenAI client."""
+        return self.config.api_token
+
+    def get_base_url(self) -> str:
+        """Get base URL for OpenAI client."""
+        return self.config.url
+
     async def initialize(self) -> None:
         pass

     async def shutdown(self) -> None:
         pass

+    def get_extra_client_params(self) -> dict[str, Any]:
+        """Override to add RunPod-specific client parameters if needed."""
+        return {}
+
+    async def openai_chat_completion(
+        self,
+        model: str,
+        messages: list[OpenAIMessageParam],
+        frequency_penalty: float | None = None,
+        function_call: str | dict[str, Any] | None = None,
+        functions: list[dict[str, Any]] | None = None,
+        logit_bias: dict[str, float] | None = None,
+        logprobs: bool | None = None,
+        max_completion_tokens: int | None = None,
+        max_tokens: int | None = None,
+        n: int | None = None,
+        parallel_tool_calls: bool | None = None,
+        presence_penalty: float | None = None,
+        response_format: OpenAIResponseFormatParam | None = None,
+        seed: int | None = None,
+        stop: str | list[str] | None = None,
+        stream: bool | None = None,
+        stream_options: dict[str, Any] | None = None,
+        temperature: float | None = None,
+        tool_choice: str | dict[str, Any] | None = None,
+        tools: list[dict[str, Any]] | None = None,
+        top_logprobs: int | None = None,
+        top_p: float | None = None,
+        user: str | None = None,
+    ):
+        """Override to add RunPod-specific stream_options requirement."""
+        if stream and not stream_options:
+            stream_options = {"include_usage": True}
+
+        return await super().openai_chat_completion(
+            model=model,
+            messages=messages,
+            frequency_penalty=frequency_penalty,
+            function_call=function_call,
+            functions=functions,
+            logit_bias=logit_bias,
+            logprobs=logprobs,
+            max_completion_tokens=max_completion_tokens,
+            max_tokens=max_tokens,
+            n=n,
+            parallel_tool_calls=parallel_tool_calls,
+            presence_penalty=presence_penalty,
+            response_format=response_format,
+            seed=seed,
+            stop=stop,
+            stream=stream,
+            stream_options=stream_options,
+            temperature=temperature,
+            tool_choice=tool_choice,
+            tools=tools,
+            top_logprobs=top_logprobs,
+            top_p=top_p,
+            user=user,
+        )
+
     async def register_model(self, model: Model) -> Model:
         """
-        Register any model with the runpod provider_id.
-
-        Pass-through registration - accepts any model string that the RunPod endpoint serves.
-        No static model validation since RunPod endpoints can serve arbitrary vLLM models.
-
-        YAML Configuration Example:
-
-        models:
-          - metadata: {}
-            model_id: runpod/qwen/qwen3-8b
-            model_type: llm
-            provider_id: runpod
-            provider_model_id: qwen/qwen3-8b
-          - metadata: {}
-            model_id: runpod/deepcogito/cogito-v2-preview-llama-70B
-            model_type: llm
-            provider_id: runpod
-            provider_model_id: deepcogito/cogito-v2-preview-llama-70B
-
-        The provider strips 'runpod/' prefix before API calls:
-        "runpod/qwen/qwen3-8b" -> "qwen/qwen3-8b"
+        Pass-through registration - accepts any model that the RunPod endpoint serves.
+
+        In the .yaml file a model can be defined, for example, as:
+
+        models:
+          - metadata: {}
+            model_id: qwen3-32b-awq
+            model_type: llm
+            provider_id: runpod
+            provider_model_id: Qwen/Qwen3-32B-AWQ
         """
-        if model.provider_id == "runpod":
-            logger.info(
-                f"Registering model: {model.identifier} -> {model.provider_resource_id}"
-            )
-            return model
-        return await super().register_model(model)
+        return model

     async def completion(
         self,
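The get_api_key() / get_base_url() hooks added above are the contract that OpenAIMixin consumes. The mixin's real internals live in llama_stack, but here is a minimal sketch of how such a mixin could assemble the self.client used throughout the rest of this diff (the class and its property are illustrative assumptions, not the library's code):

    from typing import Any
    from openai import AsyncOpenAI

    class SketchOpenAIMixin:
        # Illustrative stand-in for llama_stack's OpenAIMixin.
        def get_api_key(self) -> str:
            raise NotImplementedError  # adapter supplies the credential

        def get_base_url(self) -> str:
            raise NotImplementedError  # adapter supplies the endpoint URL

        def get_extra_client_params(self) -> dict[str, Any]:
            return {}  # adapters may inject extra AsyncOpenAI kwargs

        @property
        def client(self) -> AsyncOpenAI:
            # Build an async client from the adapter-provided hooks.
            return AsyncOpenAI(
                api_key=self.get_api_key(),
                base_url=self.get_base_url(),
                **self.get_extra_client_params(),
            )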
@@ -88,12 +147,16 @@ class RunpodInferenceAdapter(
         response_format: ResponseFormat | None = None,
         stream: bool | None = False,
         logprobs: LogProbConfig | None = None,
-    ) -> AsyncGenerator:
+    ) -> CompletionResponse | AsyncGenerator[CompletionResponseStreamChunk, None]:
         if sampling_params is None:
             sampling_params = SamplingParams()

+        # Resolve model_id to provider_resource_id
+        model = await self.model_store.get_model(model_id)
+        provider_model_id = model.provider_resource_id or model_id
+
         request = CompletionRequest(
-            model=model_id,
+            model=provider_model_id,
             content=content,
             sampling_params=sampling_params,
             response_format=response_format,
@@ -101,12 +164,10 @@ class RunpodInferenceAdapter(
             logprobs=logprobs,
         )

-        client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
-
         if stream:
-            return self._stream_completion(request, client)
+            return self._stream_completion(request, self.client)
         else:
-            return await self._nonstream_completion(request, client)
+            return await self._nonstream_completion(request, self.client)

     async def chat_completion(
         self,
@@ -120,13 +181,17 @@ class RunpodInferenceAdapter(
         stream: bool | None = False,
         logprobs: LogProbConfig | None = None,
         tool_config: ToolConfig | None = None,
-    ) -> AsyncGenerator:
+    ) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
         """Process chat completion requests using RunPod's OpenAI-compatible API."""
         if sampling_params is None:
             sampling_params = SamplingParams()

+        # Resolve model_id to provider_resource_id
+        model = await self.model_store.get_model(model_id)
+        provider_model_id = model.provider_resource_id or model_id
+
         request = ChatCompletionRequest(
-            model=model_id,
+            model=provider_model_id,
             messages=messages,
             sampling_params=sampling_params,
             tools=tools or [],
@@ -135,49 +200,34 @@ class RunpodInferenceAdapter(
             tool_config=tool_config,
         )

-        client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
-
         if stream:
-            return self._stream_chat_completion(request, client)
+            return self._stream_chat_completion(request, self.client)
         else:
-            return await self._nonstream_chat_completion(request, client)
+            return await self._nonstream_chat_completion(request, self.client)

     async def _nonstream_chat_completion(
-        self, request: ChatCompletionRequest, client: OpenAI
+        self, request: ChatCompletionRequest, client: AsyncOpenAI
     ) -> ChatCompletionResponse:
         params = await self._get_chat_params(request)
-        r = client.chat.completions.create(**params)
+        # Make actual RunPod API call
+        r = await client.chat.completions.create(**params)
         return process_chat_completion_response(r, request)

     async def _stream_chat_completion(
-        self, request: ChatCompletionRequest, client: OpenAI
-    ) -> AsyncGenerator:
+        self, request: ChatCompletionRequest, client: AsyncOpenAI
+    ) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
         params = await self._get_chat_params(request)
-
-        async def _to_async_generator():
-            s = client.chat.completions.create(**params)
-            for chunk in s:
-                yield chunk
-
-        stream = _to_async_generator()
+        # Make actual RunPod API call for streaming
+        stream = await client.chat.completions.create(**params)
         async for chunk in process_chat_completion_stream_response(stream, request):
             yield chunk

     async def _get_chat_params(self, request: ChatCompletionRequest) -> dict:
         """Convert Llama Stack request to RunPod API parameters."""
-        messages = [
-            {"role": msg.role, "content": msg.content} for msg in request.messages
-        ]
-
-        # Resolve model_id to provider_resource_id
-        model_obj = await self.model_store.get_model(request.model)
-        model = model_obj.provider_resource_id or request.model
-        if model.startswith("runpod/"):
-            model = model.replace("runpod/", "", 1)
+        messages = [await convert_message_to_openai_dict(m, download=False) for m in request.messages]

         params = {
-            "model": model,
+            "model": request.model,
             "messages": messages,
             "stream": request.stream,
             **get_sampling_options(request.sampling_params),
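The removed _to_async_generator shim existed only because the old code drove a synchronous OpenAI client from async methods; AsyncOpenAI returns a natively async-iterable stream, so it can be awaited and iterated directly. A standalone sketch of that pattern against a RunPod-style OpenAI-compatible endpoint (URL, key, and model id are placeholders):

    import asyncio
    from openai import AsyncOpenAI

    async def main() -> None:
        client = AsyncOpenAI(
            base_url="https://api.runpod.ai/v2/<endpoint_id>/openai/v1",  # placeholder
            api_key="<RUNPOD_API_TOKEN>",  # placeholder
        )
        stream = await client.chat.completions.create(
            model="<served-model-id>",  # placeholder
            messages=[{"role": "user", "content": "Hello"}],
            stream=True,
            stream_options={"include_usage": True},  # mirrors the override above
        )
        async for chunk in stream:
            if chunk.choices:  # the final usage-only chunk has no choices
                print(chunk.choices[0].delta.content or "", end="")

    asyncio.run(main())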
@@ -189,37 +239,27 @@ class RunpodInferenceAdapter(
         return params

     async def _nonstream_completion(
-        self, request: CompletionRequest, client: OpenAI
+        self, request: CompletionRequest, client: AsyncOpenAI
     ) -> CompletionResponse:
         params = await self._get_completion_params(request)
-        r = client.completions.create(**params)
+        # Make actual RunPod API call
+        r = await client.completions.create(**params)
         return process_completion_response(r)

     async def _stream_completion(
-        self, request: CompletionRequest, client: OpenAI
+        self, request: CompletionRequest, client: AsyncOpenAI
     ) -> AsyncGenerator:
         params = await self._get_completion_params(request)
-
-        async def _to_async_generator():
-            s = client.completions.create(**params)
-            for chunk in s:
-                yield chunk
-
-        stream = _to_async_generator()
+        # Make actual RunPod API call for streaming
+        stream = await client.completions.create(**params)
         async for chunk in process_completion_stream_response(stream):
             yield chunk

     async def _get_completion_params(self, request: CompletionRequest) -> dict:
-        # Resolve model_id to provider_resource_id
-        model_obj = await self.model_store.get_model(request.model)
-        model = model_obj.provider_resource_id or request.model
-        if model.startswith("runpod/"):
-            model = model.replace("runpod/", "", 1)
-
+        """Convert Llama Stack request to RunPod API parameters."""
         params = {
-            "model": model,
-            "prompt": completion_request_to_prompt(request),
+            "model": request.model,
+            "prompt": await completion_request_to_prompt(request),
             "stream": request.stream,
             **get_sampling_options(request.sampling_params),
         }
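For orientation, a hedged example of what _get_completion_params might now return for a simple non-streaming request. The exact sampling keys come from get_sampling_options in llama_stack's openai_compat module, so the ones shown are assumptions:

    # Hypothetical result for a registered model "qwen3-32b-awq":
    params = {
        "model": "Qwen/Qwen3-32B-AWQ",  # request.model already holds the provider_resource_id
        "prompt": "<rendered prompt>",  # from await completion_request_to_prompt(request)
        "stream": False,
        "temperature": 0.7,             # assumed get_sampling_options output
        "top_p": 0.9,                   # assumed
        "max_tokens": 256,              # assumed
    }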
@@ -241,16 +281,11 @@ class RunpodInferenceAdapter(
         model_obj = await self.model_store.get_model(model_id)
         model = model_obj.provider_resource_id or model_id

-        if model.startswith("runpod/"):
-            model = model.replace("runpod/", "", 1)
-
-        client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
-
         kwargs = {}
         if output_dimension:
             kwargs["dimensions"] = output_dimension

-        response = client.embeddings.create(
+        response = await self.client.embeddings.create(
             model=model,
             input=[interleaved_content_as_str(content) for content in contents],
             **kwargs,
@@ -269,19 +304,14 @@ class RunpodInferenceAdapter(
     ) -> OpenAIEmbeddingsResponse:
         # Resolve model_id to provider_resource_id
         model_obj = await self.model_store.get_model(model)
-        model_stripped = model_obj.provider_resource_id or model
-        if model_stripped.startswith("runpod/"):
-            model_stripped = model_stripped.replace("runpod/", "", 1)
-
-        client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
-        response = client.embeddings.create(
-            model=model_stripped,
+        provider_model_id = model_obj.provider_resource_id or model
+
+        response = await self.client.embeddings.create(
+            model=provider_model_id,
             input=input,
             encoding_format=encoding_format,
             dimensions=dimensions,
             user=user,
         )
         return response
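The embeddings path now reuses the same mixin-provided client, and the "runpod/" prefix stripping is gone to match the new registration convention, where provider_model_id is sent to the endpoint as-is. A standalone sketch of the equivalent direct call (URL, key, and model id are placeholders):

    import asyncio
    from openai import AsyncOpenAI

    async def main() -> None:
        client = AsyncOpenAI(
            base_url="https://api.runpod.ai/v2/<endpoint_id>/openai/v1",  # placeholder
            api_key="<RUNPOD_API_TOKEN>",  # placeholder
        )
        response = await client.embeddings.create(
            model="<served-embedding-model>",  # placeholder
            input=["hello world"],
            encoding_format="float",
        )
        print(len(response.data[0].embedding))

    asyncio.run(main())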