Code refactoring and removing dead code

This commit is contained in:
Omar Abdelwahab 2025-10-02 18:38:30 -07:00
parent ef0736527d
commit f6080040da
6 changed files with 302 additions and 137 deletions

View file

@@ -10,11 +10,13 @@ from openai import OpenAI
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.inference import OpenAIEmbeddingsResponse
# from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, build_hf_repo_model_entry
from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry,
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionToLlamaStackMixin,
get_sampling_options,
OpenAIChatCompletionToLlamaStackMixin,
process_chat_completion_response,
process_chat_completion_stream_response,
)
@@ -41,13 +43,12 @@ RUNPOD_SUPPORTED_MODELS = {
"Llama3.2-3B": "meta-llama/Llama-3.2-3B",
}
SAFETY_MODELS_ENTRIES = []
# Create MODEL_ENTRIES from RUNPOD_SUPPORTED_MODELS for compatibility with starter template
MODEL_ENTRIES = [
build_hf_repo_model_entry(provider_model_id, model_descriptor)
for provider_model_id, model_descriptor in RUNPOD_SUPPORTED_MODELS.items()
] + SAFETY_MODELS_ENTRIES
]
class RunpodInferenceAdapter(
@@ -56,7 +57,9 @@ class RunpodInferenceAdapter(
OpenAIChatCompletionToLlamaStackMixin,
):
def __init__(self, config: RunpodImplConfig) -> None:
ModelRegistryHelper.__init__(self, stack_to_provider_models_map=RUNPOD_SUPPORTED_MODELS)
ModelRegistryHelper.__init__(
self, stack_to_provider_models_map=RUNPOD_SUPPORTED_MODELS
)
self.config = config
async def initialize(self) -> None:
@@ -103,7 +106,9 @@ class RunpodInferenceAdapter(
r = client.completions.create(**params)
return process_chat_completion_response(r, request)
async def _stream_chat_completion(self, request: ChatCompletionRequest, client: OpenAI) -> AsyncGenerator:
async def _stream_chat_completion(
self, request: ChatCompletionRequest, client: OpenAI
) -> AsyncGenerator:
params = self._get_params(request)
async def _to_async_generator():