Code refactoring and dead code removal

Omar Abdelwahab 2025-10-02 18:38:30 -07:00
parent ef0736527d
commit f6080040da
6 changed files with 302 additions and 137 deletions

View file

@@ -6,16 +6,7 @@
from collections.abc import AsyncIterator
from enum import Enum
from typing import (
Annotated,
Any,
Literal,
Protocol,
runtime_checkable,
)
from pydantic import BaseModel, Field, field_validator
from typing_extensions import TypedDict
from typing import Annotated, Any, Literal, Protocol, runtime_checkable
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent
from llama_stack.apis.common.responses import Order
@@ -32,6 +23,9 @@ from llama_stack.models.llama.datatypes import (
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
from pydantic import BaseModel, Field, field_validator
from typing_extensions import TypedDict
register_schema(ToolCall)
register_schema(ToolDefinition)
@@ -357,32 +351,32 @@ class CompletionRequest(BaseModel):
logprobs: LogProbConfig | None = None
@json_schema_type
class CompletionResponse(MetricResponseMixin):
"""Response from a completion request.
# @json_schema_type
# class CompletionResponse(MetricResponseMixin):
# """Response from a completion request.
:param content: The generated completion text
:param stop_reason: Reason why generation stopped
:param logprobs: Optional log probabilities for generated tokens
"""
# :param content: The generated completion text
# :param stop_reason: Reason why generation stopped
# :param logprobs: Optional log probabilities for generated tokens
# """
content: str
stop_reason: StopReason
logprobs: list[TokenLogProbs] | None = None
# content: str
# stop_reason: StopReason
# logprobs: list[TokenLogProbs] | None = None
@json_schema_type
class CompletionResponseStreamChunk(MetricResponseMixin):
"""A chunk of a streamed completion response.
# @json_schema_type
# class CompletionResponseStreamChunk(MetricResponseMixin):
# """A chunk of a streamed completion response.
:param delta: New content generated since last chunk. This can be one or more tokens.
:param stop_reason: Optional reason why generation stopped, if complete
:param logprobs: Optional log probabilities for generated tokens
"""
# :param delta: New content generated since last chunk. This can be one or more tokens.
# :param stop_reason: Optional reason why generation stopped, if complete
# :param logprobs: Optional log probabilities for generated tokens
# """
delta: str
stop_reason: StopReason | None = None
logprobs: list[TokenLogProbs] | None = None
# delta: str
# stop_reason: StopReason | None = None
# logprobs: list[TokenLogProbs] | None = None
class SystemMessageBehavior(Enum):
@@ -415,7 +409,9 @@ class ToolConfig(BaseModel):
tool_choice: ToolChoice | str | None = Field(default=ToolChoice.auto)
tool_prompt_format: ToolPromptFormat | None = Field(default=None)
system_message_behavior: SystemMessageBehavior | None = Field(default=SystemMessageBehavior.append)
system_message_behavior: SystemMessageBehavior | None = Field(
default=SystemMessageBehavior.append
)
def model_post_init(self, __context: Any) -> None:
if isinstance(self.tool_choice, str):
@@ -544,15 +540,21 @@ class OpenAIFile(BaseModel):
OpenAIChatCompletionContentPartParam = Annotated[
OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam | OpenAIFile,
OpenAIChatCompletionContentPartTextParam
| OpenAIChatCompletionContentPartImageParam
| OpenAIFile,
Field(discriminator="type"),
]
register_schema(OpenAIChatCompletionContentPartParam, name="OpenAIChatCompletionContentPartParam")
register_schema(
OpenAIChatCompletionContentPartParam, name="OpenAIChatCompletionContentPartParam"
)
OpenAIChatCompletionMessageContent = str | list[OpenAIChatCompletionContentPartParam]
OpenAIChatCompletionTextOnlyMessageContent = str | list[OpenAIChatCompletionContentPartTextParam]
OpenAIChatCompletionTextOnlyMessageContent = (
str | list[OpenAIChatCompletionContentPartTextParam]
)
@json_schema_type
@@ -720,7 +722,9 @@ class OpenAIResponseFormatJSONObject(BaseModel):
OpenAIResponseFormatParam = Annotated[
OpenAIResponseFormatText | OpenAIResponseFormatJSONSchema | OpenAIResponseFormatJSONObject,
OpenAIResponseFormatText
| OpenAIResponseFormatJSONSchema
| OpenAIResponseFormatJSONObject,
Field(discriminator="type"),
]
register_schema(OpenAIResponseFormatParam, name="OpenAIResponseFormatParam")
@@ -1049,8 +1053,16 @@ class InferenceProvider(Protocol):
async def rerank(
self,
model: str,
query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam],
query: (
str
| OpenAIChatCompletionContentPartTextParam
| OpenAIChatCompletionContentPartImageParam
),
items: list[
str
| OpenAIChatCompletionContentPartTextParam
| OpenAIChatCompletionContentPartImageParam
],
max_num_results: int | None = None,
) -> RerankResponse:
"""Rerank a list of documents based on their relevance to a query.
@@ -1064,7 +1076,12 @@ class InferenceProvider(Protocol):
raise NotImplementedError("Reranking is not implemented")
return # this is so mypy's safe-super rule will consider the method concrete
@webmethod(route="/openai/v1/completions", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(
route="/openai/v1/completions",
method="POST",
level=LLAMA_STACK_API_V1,
deprecated=True,
)
@webmethod(route="/completions", method="POST", level=LLAMA_STACK_API_V1)
async def openai_completion(
self,
@@ -1116,7 +1133,12 @@ class InferenceProvider(Protocol):
"""
...
@webmethod(route="/openai/v1/chat/completions", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(
route="/openai/v1/chat/completions",
method="POST",
level=LLAMA_STACK_API_V1,
deprecated=True,
)
@webmethod(route="/chat/completions", method="POST", level=LLAMA_STACK_API_V1)
async def openai_chat_completion(
self,
@@ -1173,7 +1195,12 @@ class InferenceProvider(Protocol):
"""
...
@webmethod(route="/openai/v1/embeddings", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(
route="/openai/v1/embeddings",
method="POST",
level=LLAMA_STACK_API_V1,
deprecated=True,
)
@webmethod(route="/embeddings", method="POST", level=LLAMA_STACK_API_V1)
async def openai_embeddings(
self,
@@ -1203,7 +1230,12 @@ class Inference(InferenceProvider):
- Embedding models: these models generate embeddings to be used for semantic search.
"""
@webmethod(route="/openai/v1/chat/completions", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(
route="/openai/v1/chat/completions",
method="GET",
level=LLAMA_STACK_API_V1,
deprecated=True,
)
@webmethod(route="/chat/completions", method="GET", level=LLAMA_STACK_API_V1)
async def list_chat_completions(
self,
@@ -1223,10 +1255,19 @@ class Inference(InferenceProvider):
raise NotImplementedError("List chat completions is not implemented")
@webmethod(
route="/openai/v1/chat/completions/{completion_id}", method="GET", level=LLAMA_STACK_API_V1, deprecated=True
route="/openai/v1/chat/completions/{completion_id}",
method="GET",
level=LLAMA_STACK_API_V1,
deprecated=True,
)
@webmethod(route="/chat/completions/{completion_id}", method="GET", level=LLAMA_STACK_API_V1)
async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
@webmethod(
route="/chat/completions/{completion_id}",
method="GET",
level=LLAMA_STACK_API_V1,
)
async def get_chat_completion(
self, completion_id: str
) -> OpenAICompletionWithInputMessages:
"""Describe a chat completion by its ID.
:param completion_id: ID of the chat completion.
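
Editor's note, not part of the commit: the hunks above repeatedly stack two @webmethod decorators on one endpoint, keeping the legacy /openai/v1/* route flagged deprecated next to the current route. A minimal, self-contained sketch of that pattern follows; the webmethod below is a toy stand-in (the real decorator lives in llama_stack.schema_utils) and LLAMA_STACK_API_V1 is an assumed placeholder constant.

from typing import Any, Callable

LLAMA_STACK_API_V1 = "v1"  # assumed placeholder; the real constant comes from llama_stack

def webmethod(route: str, method: str, level: str, deprecated: bool = False) -> Callable:
    """Toy stand-in: record route metadata on the function instead of registering it."""
    def decorator(fn: Callable) -> Callable:
        routes = list(getattr(fn, "__webmethods__", []))
        routes.append({"route": route, "method": method, "level": level, "deprecated": deprecated})
        fn.__webmethods__ = routes
        return fn
    return decorator

class ToyInferenceAPI:
    @webmethod(route="/openai/v1/completions", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
    @webmethod(route="/completions", method="POST", level=LLAMA_STACK_API_V1)
    async def openai_completion(self, **params: Any) -> dict:
        return {"ok": True}

# Both routes are registered for the same method; only the /openai/v1/* alias is marked deprecated.
print(ToyInferenceAPI.openai_completion.__webmethods__)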

View file

@@ -7,24 +7,16 @@
import asyncio
import time
from collections.abc import AsyncGenerator, AsyncIterator
from datetime import UTC, datetime
from datetime import datetime, UTC
from typing import Annotated, Any
from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam
from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToolParam
from pydantic import Field, TypeAdapter
from llama_stack.apis.common.content_types import (
InterleavedContent,
)
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError
from llama_stack.apis.inference import (
ChatCompletionResponse,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
CompletionMessage,
CompletionResponse,
CompletionResponseStreamChunk,
Inference,
ListOpenAIChatCompletionResponse,
LogProbConfig,
@@ -57,7 +49,16 @@ from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
from llama_stack.providers.utils.inference.inference_store import InferenceStore
from llama_stack.providers.utils.telemetry.tracing import enqueue_event, get_current_span
from llama_stack.providers.utils.telemetry.tracing import (
enqueue_event,
get_current_span,
)
from openai.types.chat import (
ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam,
ChatCompletionToolParam as OpenAIChatCompletionToolParam,
)
from pydantic import Field, TypeAdapter
logger = get_logger(name=__name__, category="core::routers")
@@ -101,7 +102,9 @@ class InferenceRouter(Inference):
logger.debug(
f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}",
)
await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
await self.routing_table.register_model(
model_id, provider_model_id, provider_id, metadata, model_type
)
def _construct_metrics(
self,
@@ -156,11 +159,16 @@ class InferenceRouter(Inference):
total_tokens: int,
model: Model,
) -> list[MetricInResponse]:
metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
metrics = self._construct_metrics(
prompt_tokens, completion_tokens, total_tokens, model
)
if self.telemetry:
for metric in metrics:
enqueue_event(metric)
return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]
return [
MetricInResponse(metric=metric.metric, value=metric.value)
for metric in metrics
]
async def _count_tokens(
self,
@@ -207,8 +215,13 @@ class InferenceRouter(Inference):
if tool_config:
if tool_choice and tool_choice != tool_config.tool_choice:
raise ValueError("tool_choice and tool_config.tool_choice must match")
if tool_prompt_format and tool_prompt_format != tool_config.tool_prompt_format:
raise ValueError("tool_prompt_format and tool_config.tool_prompt_format must match")
if (
tool_prompt_format
and tool_prompt_format != tool_config.tool_prompt_format
):
raise ValueError(
"tool_prompt_format and tool_config.tool_prompt_format must match"
)
else:
params = {}
if tool_choice:
@@ -226,9 +239,14 @@ class InferenceRouter(Inference):
pass
else:
# verify tool_choice is one of the tools
tool_names = [t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value for t in tools]
tool_names = [
t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value
for t in tools
]
if tool_config.tool_choice not in tool_names:
raise ValueError(f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}")
raise ValueError(
f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}"
)
params = dict(
model_id=model_id,
@@ -243,7 +261,9 @@ class InferenceRouter(Inference):
tool_config=tool_config,
)
provider = await self.routing_table.get_provider_impl(model_id)
prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format)
prompt_tokens = await self._count_tokens(
messages, tool_config.tool_prompt_format
)
if stream:
response_stream = await provider.chat_completion(**params)
@@ -263,7 +283,9 @@ class InferenceRouter(Inference):
)
# these metrics will show up in the client response.
response.metrics = (
metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics
metrics
if not hasattr(response, "metrics") or response.metrics is None
else response.metrics + metrics
)
return response
@@ -336,7 +358,9 @@ class InferenceRouter(Inference):
# these metrics will show up in the client response.
response.metrics = (
metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics
metrics
if not hasattr(response, "metrics") or response.metrics is None
else response.metrics + metrics
)
return response
@@ -374,9 +398,13 @@ class InferenceRouter(Inference):
# Use the OpenAI client for a bit of extra input validation without
# exposing the OpenAI client itself as part of our API surface
if tool_choice:
TypeAdapter(OpenAIChatCompletionToolChoiceOptionParam).validate_python(tool_choice)
TypeAdapter(OpenAIChatCompletionToolChoiceOptionParam).validate_python(
tool_choice
)
if tools is None:
raise ValueError("'tool_choice' is only allowed when 'tools' is also provided")
raise ValueError(
"'tool_choice' is only allowed when 'tools' is also provided"
)
if tools:
for tool in tools:
TypeAdapter(OpenAIChatCompletionToolParam).validate_python(tool)
@@ -441,7 +469,9 @@ class InferenceRouter(Inference):
enqueue_event(metric)
# these metrics will show up in the client response.
response.metrics = (
metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics
metrics
if not hasattr(response, "metrics") or response.metrics is None
else response.metrics + metrics
)
return response
@@ -477,19 +507,31 @@ class InferenceRouter(Inference):
) -> ListOpenAIChatCompletionResponse:
if self.store:
return await self.store.list_chat_completions(after, limit, model, order)
raise NotImplementedError("List chat completions is not supported: inference store is not configured.")
raise NotImplementedError(
"List chat completions is not supported: inference store is not configured."
)
async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
async def get_chat_completion(
self, completion_id: str
) -> OpenAICompletionWithInputMessages:
if self.store:
return await self.store.get_chat_completion(completion_id)
raise NotImplementedError("Get chat completion is not supported: inference store is not configured.")
raise NotImplementedError(
"Get chat completion is not supported: inference store is not configured."
)
async def _nonstream_openai_chat_completion(self, provider: Inference, params: dict) -> OpenAIChatCompletion:
async def _nonstream_openai_chat_completion(
self, provider: Inference, params: dict
) -> OpenAIChatCompletion:
response = await provider.openai_chat_completion(**params)
for choice in response.choices:
# some providers return an empty list for no tool calls in non-streaming responses
# but the OpenAI API returns None. So, set tool_calls to None if it's empty
if choice.message and choice.message.tool_calls is not None and len(choice.message.tool_calls) == 0:
if (
choice.message
and choice.message.tool_calls is not None
and len(choice.message.tool_calls) == 0
):
choice.message.tool_calls = None
return response
@@ -509,7 +551,9 @@ class InferenceRouter(Inference):
message=f"Health check timed out after {timeout} seconds",
)
except NotImplementedError:
health_statuses[provider_id] = HealthResponse(status=HealthStatus.NOT_IMPLEMENTED)
health_statuses[provider_id] = HealthResponse(
status=HealthStatus.NOT_IMPLEMENTED
)
except Exception as e:
health_statuses[provider_id] = HealthResponse(
status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}"
@@ -522,7 +566,7 @@ class InferenceRouter(Inference):
prompt_tokens,
model,
tool_prompt_format: ToolPromptFormat | None = None,
) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None] | AsyncGenerator[CompletionResponseStreamChunk, None]:
) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
completion_text = ""
async for chunk in response:
complete = False
@@ -544,7 +588,11 @@ class InferenceRouter(Inference):
else:
if hasattr(chunk, "delta"):
completion_text += chunk.delta
if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
if (
hasattr(chunk, "stop_reason")
and chunk.stop_reason
and self.telemetry
):
complete = True
completion_tokens = await self._count_tokens(completion_text)
# if we are done receiving tokens
@@ -569,9 +617,14 @@ class InferenceRouter(Inference):
# Return metrics in response
async_metrics = [
MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics
MetricInResponse(metric=metric.metric, value=metric.value)
for metric in completion_metrics
]
chunk.metrics = async_metrics if chunk.metrics is None else chunk.metrics + async_metrics
chunk.metrics = (
async_metrics
if chunk.metrics is None
else chunk.metrics + async_metrics
)
else:
# Fallback if no telemetry
completion_metrics = self._construct_metrics(
@@ -581,14 +634,19 @@ class InferenceRouter(Inference):
model,
)
async_metrics = [
MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics
MetricInResponse(metric=metric.metric, value=metric.value)
for metric in completion_metrics
]
chunk.metrics = async_metrics if chunk.metrics is None else chunk.metrics + async_metrics
chunk.metrics = (
async_metrics
if chunk.metrics is None
else chunk.metrics + async_metrics
)
yield chunk
async def count_tokens_and_compute_metrics(
self,
response: ChatCompletionResponse | CompletionResponse,
response: ChatCompletionResponse,
prompt_tokens,
model,
tool_prompt_format: ToolPromptFormat | None = None,
@@ -597,7 +655,9 @@ class InferenceRouter(Inference):
content = [response.completion_message]
else:
content = response.content
completion_tokens = await self._count_tokens(messages=content, tool_prompt_format=tool_prompt_format)
completion_tokens = await self._count_tokens(
messages=content, tool_prompt_format=tool_prompt_format
)
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
# Create a separate span for completion metrics
@@ -610,11 +670,17 @@ class InferenceRouter(Inference):
model=model,
)
for metric in completion_metrics:
if metric.metric in ["completion_tokens", "total_tokens"]: # Only log completion and total tokens
if metric.metric in [
"completion_tokens",
"total_tokens",
]: # Only log completion and total tokens
enqueue_event(metric)
# Return metrics in response
return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics]
return [
MetricInResponse(metric=metric.metric, value=metric.value)
for metric in completion_metrics
]
# Fallback if no telemetry
metrics = self._construct_metrics(
@@ -623,7 +689,10 @@ class InferenceRouter(Inference):
total_tokens,
model,
)
return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]
return [
MetricInResponse(metric=metric.metric, value=metric.value)
for metric in metrics
]
async def stream_tokens_and_compute_metrics_openai_chat(
self,
@@ -664,33 +733,48 @@ class InferenceRouter(Inference):
if choice_delta.delta:
delta = choice_delta.delta
if delta.content:
current_choice_data["content_parts"].append(delta.content)
current_choice_data["content_parts"].append(
delta.content
)
if delta.tool_calls:
for tool_call_delta in delta.tool_calls:
tc_idx = tool_call_delta.index
if tc_idx not in current_choice_data["tool_calls_builder"]:
current_choice_data["tool_calls_builder"][tc_idx] = {
if (
tc_idx
not in current_choice_data["tool_calls_builder"]
):
current_choice_data["tool_calls_builder"][
tc_idx
] = {
"id": None,
"type": "function",
"function_name_parts": [],
"function_arguments_parts": [],
}
builder = current_choice_data["tool_calls_builder"][tc_idx]
builder = current_choice_data["tool_calls_builder"][
tc_idx
]
if tool_call_delta.id:
builder["id"] = tool_call_delta.id
if tool_call_delta.type:
builder["type"] = tool_call_delta.type
if tool_call_delta.function:
if tool_call_delta.function.name:
builder["function_name_parts"].append(tool_call_delta.function.name)
builder["function_name_parts"].append(
tool_call_delta.function.name
)
if tool_call_delta.function.arguments:
builder["function_arguments_parts"].append(
tool_call_delta.function.arguments
)
if choice_delta.finish_reason:
current_choice_data["finish_reason"] = choice_delta.finish_reason
current_choice_data["finish_reason"] = (
choice_delta.finish_reason
)
if choice_delta.logprobs and choice_delta.logprobs.content:
current_choice_data["logprobs_content_parts"].extend(choice_delta.logprobs.content)
current_choice_data["logprobs_content_parts"].extend(
choice_delta.logprobs.content
)
# Compute metrics on final chunk
if chunk.choices and chunk.choices[0].finish_reason:
@@ -720,8 +804,12 @@ class InferenceRouter(Inference):
if choice_data["tool_calls_builder"]:
for tc_build_data in choice_data["tool_calls_builder"].values():
if tc_build_data["id"]:
func_name = "".join(tc_build_data["function_name_parts"])
func_args = "".join(tc_build_data["function_arguments_parts"])
func_name = "".join(
tc_build_data["function_name_parts"]
)
func_args = "".join(
tc_build_data["function_arguments_parts"]
)
assembled_tool_calls.append(
OpenAIChatCompletionToolCall(
id=tc_build_data["id"],
@@ -734,10 +822,16 @@ class InferenceRouter(Inference):
message = OpenAIAssistantMessageParam(
role="assistant",
content=content_str if content_str else None,
tool_calls=assembled_tool_calls if assembled_tool_calls else None,
tool_calls=(
assembled_tool_calls if assembled_tool_calls else None
),
)
logprobs_content = choice_data["logprobs_content_parts"]
final_logprobs = OpenAIChoiceLogprobs(content=logprobs_content) if logprobs_content else None
final_logprobs = (
OpenAIChoiceLogprobs(content=logprobs_content)
if logprobs_content
else None
)
assembled_choices.append(
OpenAIChoice(
@@ -756,4 +850,6 @@ class InferenceRouter(Inference):
object="chat.completion",
)
logger.debug(f"InferenceRouter.completion_response: {final_response}")
asyncio.create_task(self.store.store_chat_completion(final_response, messages))
asyncio.create_task(
self.store.store_chat_completion(final_response, messages)
)
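
Editor's note, not part of the commit: a hedged, standalone sketch of the tool-call reassembly idea used in stream_tokens_and_compute_metrics_openai_chat above. Each streamed delta carries an index plus fragments of the function name and arguments, and the fragments are joined once the stream completes. The delta and tool-call shapes below are simplified stand-ins, not the real OpenAI or llama_stack types.

from dataclasses import dataclass

@dataclass
class ToolCallDelta:
    index: int
    id: str | None = None
    name_part: str = ""
    arguments_part: str = ""

def assemble_tool_calls(deltas: list[ToolCallDelta]) -> list[dict]:
    # Accumulate fragments per tool-call index, mirroring tool_calls_builder above.
    builders: dict[int, dict] = {}
    for d in deltas:
        b = builders.setdefault(d.index, {"id": None, "name_parts": [], "arg_parts": []})
        if d.id:
            b["id"] = d.id
        b["name_parts"].append(d.name_part)
        b["arg_parts"].append(d.arguments_part)
    # Join the fragments; entries that never received an id are dropped, as in the router.
    return [
        {
            "id": b["id"],
            "type": "function",
            "function": {"name": "".join(b["name_parts"]), "arguments": "".join(b["arg_parts"])},
        }
        for b in builders.values()
        if b["id"]
    ]

deltas = [
    ToolCallDelta(index=0, id="call_1", name_part="get_wea"),
    ToolCallDelta(index=0, name_part="ther", arguments_part='{"city": '),
    ToolCallDelta(index=0, arguments_part='"Paris"}'),
]
print(assemble_tool_calls(deltas))  # one assembled call: get_weather with {"city": "Paris"}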

View file

@@ -25,9 +25,6 @@ from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
from llama_stack.providers.utils.inference.embedding_mixin import (
SentenceTransformerEmbeddingMixin,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionToLlamaStackMixin,
)
from .config import SentenceTransformersInferenceConfig
@@ -35,7 +32,6 @@ log = get_logger(name=__name__, category="inference")
class SentenceTransformersInferenceImpl(
OpenAIChatCompletionToLlamaStackMixin,
SentenceTransformerEmbeddingMixin,
InferenceProvider,
ModelsProtocolPrivate,
@@ -114,4 +110,6 @@ class SentenceTransformersInferenceImpl(
# for fill-in-the-middle type completion
suffix: str | None = None,
) -> OpenAICompletion:
raise NotImplementedError("OpenAI completion not supported by sentence transformers provider")
raise NotImplementedError(
"OpenAI completion not supported by sentence transformers provider"
)
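
Editor's note, not part of the commit: a small sketch of the convention this provider follows, under assumed, simplified types. Unsupported operations raise NotImplementedError, and callers such as the router's health check in the previous file map that to a NOT_IMPLEMENTED status rather than an error. The class and enum below are stand-ins, not the llama_stack definitions.

import asyncio
from enum import Enum

class HealthStatus(Enum):
    OK = "OK"
    NOT_IMPLEMENTED = "Not Implemented"
    ERROR = "Error"

class EmbeddingOnlyProvider:
    async def openai_embeddings(self, texts: list[str]) -> list[list[float]]:
        return [[0.0] * 4 for _ in texts]  # placeholder vectors

    async def openai_completion(self, prompt: str) -> str:
        raise NotImplementedError("OpenAI completion not supported by sentence transformers provider")

async def check(provider: EmbeddingOnlyProvider) -> HealthStatus:
    try:
        await provider.openai_completion("ping")
        return HealthStatus.OK
    except NotImplementedError:
        return HealthStatus.NOT_IMPLEMENTED
    except Exception:
        return HealthStatus.ERROR

print(asyncio.run(check(EmbeddingOnlyProvider())))  # HealthStatus.NOT_IMPLEMENTED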

View file

@@ -11,8 +11,7 @@ from cerebras.cloud.sdk import AsyncCerebras
from llama_stack.apis.inference import (
ChatCompletionRequest,
CompletionRequest,
CompletionResponse,
ChatCompletionResponse,
Inference,
LogProbConfig,
Message,
@@ -25,9 +24,7 @@ from llama_stack.apis.inference import (
ToolPromptFormat,
TopKSamplingStrategy,
)
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_options,
process_chat_completion_response,
@@ -36,7 +33,6 @@ from llama_stack.providers.utils.inference.openai_compat import (
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
completion_request_to_prompt,
)
from .config import CerebrasImplConfig
@@ -102,14 +98,18 @@ class CerebrasInferenceAdapter(
else:
return await self._nonstream_chat_completion(request)
async def _nonstream_chat_completion(self, request: CompletionRequest) -> CompletionResponse:
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest
) -> ChatCompletionResponse:
params = await self._get_params(request)
r = await self._cerebras_client.completions.create(**params)
return process_chat_completion_response(r, request)
async def _stream_chat_completion(self, request: CompletionRequest) -> AsyncGenerator:
async def _stream_chat_completion(
self, request: ChatCompletionRequest
) -> AsyncGenerator:
params = await self._get_params(request)
stream = await self._cerebras_client.completions.create(**params)
@@ -117,15 +117,17 @@ class CerebrasInferenceAdapter(
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
async def _get_params(self, request: ChatCompletionRequest | CompletionRequest) -> dict:
if request.sampling_params and isinstance(request.sampling_params.strategy, TopKSamplingStrategy):
async def _get_params(self, request: ChatCompletionRequest) -> dict:
if request.sampling_params and isinstance(
request.sampling_params.strategy, TopKSamplingStrategy
):
raise ValueError("`top_k` not supported by Cerebras")
prompt = ""
if isinstance(request, ChatCompletionRequest):
prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model))
elif isinstance(request, CompletionRequest):
prompt = await completion_request_to_prompt(request)
prompt = await chat_completion_request_to_prompt(
request, self.get_llama_model(request.model)
)
else:
raise ValueError(f"Unknown request type {type(request)}")

View file

@@ -10,11 +10,13 @@ from openai import OpenAI
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.inference import OpenAIEmbeddingsResponse
# from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, build_hf_repo_model_entry
from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry,
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionToLlamaStackMixin,
get_sampling_options,
OpenAIChatCompletionToLlamaStackMixin,
process_chat_completion_response,
process_chat_completion_stream_response,
)
@@ -41,13 +43,12 @@ RUNPOD_SUPPORTED_MODELS = {
"Llama3.2-3B": "meta-llama/Llama-3.2-3B",
}
SAFETY_MODELS_ENTRIES = []
# Create MODEL_ENTRIES from RUNPOD_SUPPORTED_MODELS for compatibility with starter template
MODEL_ENTRIES = [
build_hf_repo_model_entry(provider_model_id, model_descriptor)
for provider_model_id, model_descriptor in RUNPOD_SUPPORTED_MODELS.items()
] + SAFETY_MODELS_ENTRIES
]
class RunpodInferenceAdapter(
@@ -56,7 +57,9 @@ class RunpodInferenceAdapter(
OpenAIChatCompletionToLlamaStackMixin,
):
def __init__(self, config: RunpodImplConfig) -> None:
ModelRegistryHelper.__init__(self, stack_to_provider_models_map=RUNPOD_SUPPORTED_MODELS)
ModelRegistryHelper.__init__(
self, stack_to_provider_models_map=RUNPOD_SUPPORTED_MODELS
)
self.config = config
async def initialize(self) -> None:
@@ -103,7 +106,9 @@ class RunpodInferenceAdapter(
r = client.completions.create(**params)
return process_chat_completion_response(r, request)
async def _stream_chat_completion(self, request: ChatCompletionRequest, client: OpenAI) -> AsyncGenerator:
async def _stream_chat_completion(
self, request: ChatCompletionRequest, client: OpenAI
) -> AsyncGenerator:
params = self._get_params(request)
async def _to_async_generator():

View file

@@ -9,12 +9,10 @@ from typing import Any
from ibm_watsonx_ai.foundation_models import Model
from ibm_watsonx_ai.metanames import GenTextParamsMetaNames as GenParams
from openai import AsyncOpenAI
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
CompletionRequest,
GreedySamplingStrategy,
Inference,
LogProbConfig,
@@ -48,6 +46,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
completion_request_to_prompt,
request_has_media,
)
from openai import AsyncOpenAI
from . import WatsonXConfig
from .models import MODEL_ENTRIES
@@ -85,7 +84,9 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
pass
def _get_client(self, model_id) -> Model:
config_api_key = self._config.api_key.get_secret_value() if self._config.api_key else None
config_api_key = (
self._config.api_key.get_secret_value() if self._config.api_key else None
)
config_url = self._config.url
project_id = self._config.project_id
credentials = {"url": config_url, "apikey": config_api_key}
@@ -132,14 +133,18 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
else:
return await self._nonstream_chat_completion(request)
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest
) -> ChatCompletionResponse:
params = await self._get_params(request)
r = self._get_client(request.model).generate(**params)
choices = []
if "results" in r:
for result in r["results"]:
choice = OpenAICompatCompletionChoice(
finish_reason=result["stop_reason"] if result["stop_reason"] else None,
finish_reason=(
result["stop_reason"] if result["stop_reason"] else None
),
text=result["generated_text"],
)
choices.append(choice)
@@ -148,7 +153,9 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
)
return process_chat_completion_response(response, request)
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
async def _stream_chat_completion(
self, request: ChatCompletionRequest
) -> AsyncGenerator:
params = await self._get_params(request)
model_id = request.model
@@ -168,28 +175,44 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
async def _get_params(self, request: ChatCompletionRequest | CompletionRequest) -> dict:
async def _get_params(self, request: ChatCompletionRequest) -> dict:
input_dict = {"params": {}}
media_present = request_has_media(request)
llama_model = self.get_llama_model(request.model)
if isinstance(request, ChatCompletionRequest):
input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
input_dict["prompt"] = await chat_completion_request_to_prompt(
request, llama_model
)
else:
assert not media_present, "Together does not support media for Completion requests"
assert (
not media_present
), "Together does not support media for Completion requests"
input_dict["prompt"] = await completion_request_to_prompt(request)
if request.sampling_params:
if request.sampling_params.strategy:
input_dict["params"][GenParams.DECODING_METHOD] = request.sampling_params.strategy.type
input_dict["params"][
GenParams.DECODING_METHOD
] = request.sampling_params.strategy.type
if request.sampling_params.max_tokens:
input_dict["params"][GenParams.MAX_NEW_TOKENS] = request.sampling_params.max_tokens
input_dict["params"][
GenParams.MAX_NEW_TOKENS
] = request.sampling_params.max_tokens
if request.sampling_params.repetition_penalty:
input_dict["params"][GenParams.REPETITION_PENALTY] = request.sampling_params.repetition_penalty
input_dict["params"][
GenParams.REPETITION_PENALTY
] = request.sampling_params.repetition_penalty
if isinstance(request.sampling_params.strategy, TopPSamplingStrategy):
input_dict["params"][GenParams.TOP_P] = request.sampling_params.strategy.top_p
input_dict["params"][GenParams.TEMPERATURE] = request.sampling_params.strategy.temperature
input_dict["params"][
GenParams.TOP_P
] = request.sampling_params.strategy.top_p
input_dict["params"][
GenParams.TEMPERATURE
] = request.sampling_params.strategy.temperature
if isinstance(request.sampling_params.strategy, TopKSamplingStrategy):
input_dict["params"][GenParams.TOP_K] = request.sampling_params.strategy.top_k
input_dict["params"][
GenParams.TOP_K
] = request.sampling_params.strategy.top_k
if isinstance(request.sampling_params.strategy, GreedySamplingStrategy):
input_dict["params"][GenParams.TEMPERATURE] = 0.0
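
Editor's note, not part of the commit: a hedged sketch of the sampling-parameter mapping that _get_params performs above, translating llama-stack sampling strategies into watsonx generation parameters. GenParams is replaced by plain string keys so the snippet runs without the ibm_watsonx_ai SDK, and the dataclasses are simplified stand-ins for the llama-stack types.

from dataclasses import dataclass

@dataclass
class TopPSamplingStrategy:
    type: str = "top_p"
    top_p: float = 0.9
    temperature: float = 0.7

@dataclass
class GreedySamplingStrategy:
    type: str = "greedy"

@dataclass
class SamplingParams:
    strategy: object
    max_tokens: int | None = None
    repetition_penalty: float | None = None

def to_watsonx_params(sp: SamplingParams) -> dict:
    params: dict = {}
    if sp.strategy:
        params["decoding_method"] = sp.strategy.type          # GenParams.DECODING_METHOD
    if sp.max_tokens:
        params["max_new_tokens"] = sp.max_tokens              # GenParams.MAX_NEW_TOKENS
    if sp.repetition_penalty:
        params["repetition_penalty"] = sp.repetition_penalty  # GenParams.REPETITION_PENALTY
    if isinstance(sp.strategy, TopPSamplingStrategy):
        params["top_p"] = sp.strategy.top_p                   # GenParams.TOP_P
        params["temperature"] = sp.strategy.temperature       # GenParams.TEMPERATURE
    if isinstance(sp.strategy, GreedySamplingStrategy):
        params["temperature"] = 0.0                           # greedy decoding pins temperature
    return params

print(to_watsonx_params(SamplingParams(strategy=TopPSamplingStrategy(), max_tokens=256)))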