Code refactoring and removing dead code

Omar Abdelwahab 2025-10-02 18:38:30 -07:00
parent ef0736527d
commit f6080040da
6 changed files with 302 additions and 137 deletions

View file

@@ -11,8 +11,7 @@ from cerebras.cloud.sdk import AsyncCerebras
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
-    CompletionRequest,
-    CompletionResponse,
+    ChatCompletionResponse,
     Inference,
     LogProbConfig,
     Message,
@@ -25,9 +24,7 @@ from llama_stack.apis.inference import (
     ToolPromptFormat,
     TopKSamplingStrategy,
 )
-from llama_stack.providers.utils.inference.model_registry import (
-    ModelRegistryHelper,
-)
+from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,
     process_chat_completion_response,
@@ -36,7 +33,6 @@ from llama_stack.providers.utils.inference.openai_compat import (
 from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
-    completion_request_to_prompt,
 )
 from .config import CerebrasImplConfig
@@ -102,14 +98,18 @@ class CerebrasInferenceAdapter(
         else:
             return await self._nonstream_chat_completion(request)

-    async def _nonstream_chat_completion(self, request: CompletionRequest) -> CompletionResponse:
+    async def _nonstream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> ChatCompletionResponse:
         params = await self._get_params(request)
         r = await self._cerebras_client.completions.create(**params)
         return process_chat_completion_response(r, request)

-    async def _stream_chat_completion(self, request: CompletionRequest) -> AsyncGenerator:
+    async def _stream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> AsyncGenerator:
         params = await self._get_params(request)
         stream = await self._cerebras_client.completions.create(**params)
@@ -117,15 +117,17 @@ class CerebrasInferenceAdapter(
         async for chunk in process_chat_completion_stream_response(stream, request):
             yield chunk

-    async def _get_params(self, request: ChatCompletionRequest | CompletionRequest) -> dict:
-        if request.sampling_params and isinstance(request.sampling_params.strategy, TopKSamplingStrategy):
+    async def _get_params(self, request: ChatCompletionRequest) -> dict:
+        if request.sampling_params and isinstance(
+            request.sampling_params.strategy, TopKSamplingStrategy
+        ):
             raise ValueError("`top_k` not supported by Cerebras")

         prompt = ""
-        if isinstance(request, ChatCompletionRequest):
-            prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model))
-        elif isinstance(request, CompletionRequest):
-            prompt = await completion_request_to_prompt(request)
-        else:
-            raise ValueError(f"Unknown request type {type(request)}")
+        prompt = await chat_completion_request_to_prompt(
+            request, self.get_llama_model(request.model)
+        )

View file

@@ -10,11 +10,13 @@ from openai import OpenAI
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.apis.inference import OpenAIEmbeddingsResponse

 # from llama_stack.providers.datatypes import ModelsProtocolPrivate
-from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, build_hf_repo_model_entry
+from llama_stack.providers.utils.inference.model_registry import (
+    build_hf_repo_model_entry,
+    ModelRegistryHelper,
+)
 from llama_stack.providers.utils.inference.openai_compat import (
-    OpenAIChatCompletionToLlamaStackMixin,
     get_sampling_options,
+    OpenAIChatCompletionToLlamaStackMixin,
     process_chat_completion_response,
     process_chat_completion_stream_response,
 )
@@ -41,13 +43,12 @@ RUNPOD_SUPPORTED_MODELS = {
     "Llama3.2-3B": "meta-llama/Llama-3.2-3B",
 }

-SAFETY_MODELS_ENTRIES = []

 # Create MODEL_ENTRIES from RUNPOD_SUPPORTED_MODELS for compatibility with starter template
 MODEL_ENTRIES = [
     build_hf_repo_model_entry(provider_model_id, model_descriptor)
     for provider_model_id, model_descriptor in RUNPOD_SUPPORTED_MODELS.items()
-] + SAFETY_MODELS_ENTRIES
+]
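
Note: with `SAFETY_MODELS_ENTRIES` deleted, the comprehension is the only producer of `MODEL_ENTRIES`. A self-contained sketch of the same construction, with a hypothetical stub standing in for `build_hf_repo_model_entry`:

```python
from dataclasses import dataclass


@dataclass
class ModelEntry:
    provider_model_id: str
    model_descriptor: str


def build_entry(provider_model_id: str, model_descriptor: str) -> ModelEntry:
    # Hypothetical stand-in for build_hf_repo_model_entry(); the real helper
    # returns a llama_stack model-registry entry.
    return ModelEntry(provider_model_id, model_descriptor)


SUPPORTED_MODELS = {"Llama3.2-3B": "meta-llama/Llama-3.2-3B"}

# Same shape as MODEL_ENTRIES after this hunk: the supported-models dict is
# the single source of truth, with no safety-model entries appended.
MODEL_ENTRIES = [build_entry(pid, desc) for pid, desc in SUPPORTED_MODELS.items()]
print(MODEL_ENTRIES)
```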
@@ -56,7 +57,9 @@ class RunpodInferenceAdapter(
     OpenAIChatCompletionToLlamaStackMixin,
 ):
     def __init__(self, config: RunpodImplConfig) -> None:
-        ModelRegistryHelper.__init__(self, stack_to_provider_models_map=RUNPOD_SUPPORTED_MODELS)
+        ModelRegistryHelper.__init__(
+            self, stack_to_provider_models_map=RUNPOD_SUPPORTED_MODELS
+        )
         self.config = config

     async def initialize(self) -> None:
@@ -103,7 +106,9 @@ class RunpodInferenceAdapter(
         r = client.completions.create(**params)
         return process_chat_completion_response(r, request)

-    async def _stream_chat_completion(self, request: ChatCompletionRequest, client: OpenAI) -> AsyncGenerator:
+    async def _stream_chat_completion(
+        self, request: ChatCompletionRequest, client: OpenAI
+    ) -> AsyncGenerator:
         params = self._get_params(request)

         async def _to_async_generator():
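
Note: the hunk cuts off at `_to_async_generator`. In adapters like this one, that inner function typically wraps a blocking streaming client in an async generator; a sketch under that assumption (the fake stream below is illustrative, not the RunPod client):

```python
import asyncio
from collections.abc import AsyncGenerator


def blocking_stream():
    # Stand-in for client.completions.create(..., stream=True), which
    # yields chunks synchronously.
    yield from ["Hello", ", ", "world", "\n"]


async def to_async_generator() -> AsyncGenerator[str, None]:
    # Iterate the blocking stream and hand each chunk back through the
    # event loop, yielding control between chunks.
    for chunk in blocking_stream():
        yield chunk
        await asyncio.sleep(0)


async def main() -> None:
    async for chunk in to_async_generator():
        print(chunk, end="")


asyncio.run(main())
```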

View file

@@ -9,12 +9,10 @@ from typing import Any
 from ibm_watsonx_ai.foundation_models import Model
 from ibm_watsonx_ai.metanames import GenTextParamsMetaNames as GenParams
-from openai import AsyncOpenAI

 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
-    CompletionRequest,
     GreedySamplingStrategy,
     Inference,
     LogProbConfig,
@@ -48,6 +46,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     completion_request_to_prompt,
     request_has_media,
 )
+from openai import AsyncOpenAI

 from . import WatsonXConfig
 from .models import MODEL_ENTRIES
@@ -85,7 +84,9 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
         pass

     def _get_client(self, model_id) -> Model:
-        config_api_key = self._config.api_key.get_secret_value() if self._config.api_key else None
+        config_api_key = (
+            self._config.api_key.get_secret_value() if self._config.api_key else None
+        )
         config_url = self._config.url
         project_id = self._config.project_id
         credentials = {"url": config_url, "apikey": config_api_key}
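
Note: the wrapped ternary above reads the API key out of what appears to be a pydantic `SecretStr`, and only when one is configured. A runnable sketch of that guard (the config class here is a hypothetical mirror of the fields `_get_client` touches, not the real `WatsonXConfig`):

```python
from pydantic import BaseModel, SecretStr


class WatsonXConfigSketch(BaseModel):
    # Hypothetical mirror of the fields read above; the real WatsonXConfig
    # lives in the provider's config module.
    url: str
    api_key: SecretStr | None = None
    project_id: str | None = None


config = WatsonXConfigSketch(
    url="https://us-south.ml.cloud.ibm.com",
    api_key=SecretStr("my-api-key"),
    project_id="proj-123",
)

# Same guard as the diff: unwrap the secret only when a key is configured.
config_api_key = config.api_key.get_secret_value() if config.api_key else None
credentials = {"url": config.url, "apikey": config_api_key}
print(credentials)
```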
@@ -132,14 +133,18 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
         else:
             return await self._nonstream_chat_completion(request)

-    async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
+    async def _nonstream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> ChatCompletionResponse:
         params = await self._get_params(request)
         r = self._get_client(request.model).generate(**params)
         choices = []
         if "results" in r:
             for result in r["results"]:
                 choice = OpenAICompatCompletionChoice(
-                    finish_reason=result["stop_reason"] if result["stop_reason"] else None,
+                    finish_reason=(
+                        result["stop_reason"] if result["stop_reason"] else None
+                    ),
                     text=result["generated_text"],
                 )
                 choices.append(choice)
@@ -148,7 +153,9 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
         )
         return process_chat_completion_response(response, request)

-    async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
+    async def _stream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> AsyncGenerator:
         params = await self._get_params(request)
         model_id = request.model
@@ -168,28 +175,44 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
         async for chunk in process_chat_completion_stream_response(stream, request):
             yield chunk

-    async def _get_params(self, request: ChatCompletionRequest | CompletionRequest) -> dict:
+    async def _get_params(self, request: ChatCompletionRequest) -> dict:
         input_dict = {"params": {}}
         media_present = request_has_media(request)
         llama_model = self.get_llama_model(request.model)
         if isinstance(request, ChatCompletionRequest):
-            input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
+            input_dict["prompt"] = await chat_completion_request_to_prompt(
+                request, llama_model
+            )
         else:
-            assert not media_present, "Together does not support media for Completion requests"
+            assert (
+                not media_present
+            ), "Together does not support media for Completion requests"
             input_dict["prompt"] = await completion_request_to_prompt(request)

         if request.sampling_params:
             if request.sampling_params.strategy:
-                input_dict["params"][GenParams.DECODING_METHOD] = request.sampling_params.strategy.type
+                input_dict["params"][
+                    GenParams.DECODING_METHOD
+                ] = request.sampling_params.strategy.type
             if request.sampling_params.max_tokens:
-                input_dict["params"][GenParams.MAX_NEW_TOKENS] = request.sampling_params.max_tokens
+                input_dict["params"][
+                    GenParams.MAX_NEW_TOKENS
+                ] = request.sampling_params.max_tokens
             if request.sampling_params.repetition_penalty:
-                input_dict["params"][GenParams.REPETITION_PENALTY] = request.sampling_params.repetition_penalty
+                input_dict["params"][
+                    GenParams.REPETITION_PENALTY
+                ] = request.sampling_params.repetition_penalty
             if isinstance(request.sampling_params.strategy, TopPSamplingStrategy):
-                input_dict["params"][GenParams.TOP_P] = request.sampling_params.strategy.top_p
-                input_dict["params"][GenParams.TEMPERATURE] = request.sampling_params.strategy.temperature
+                input_dict["params"][
+                    GenParams.TOP_P
+                ] = request.sampling_params.strategy.top_p
+                input_dict["params"][
+                    GenParams.TEMPERATURE
+                ] = request.sampling_params.strategy.temperature
             if isinstance(request.sampling_params.strategy, TopKSamplingStrategy):
-                input_dict["params"][GenParams.TOP_K] = request.sampling_params.strategy.top_k
+                input_dict["params"][
+                    GenParams.TOP_K
+                ] = request.sampling_params.strategy.top_k
             if isinstance(request.sampling_params.strategy, GreedySamplingStrategy):
                 input_dict["params"][GenParams.TEMPERATURE] = 0.0