diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index 829a94a6a..d149a4dc2 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -6,16 +6,7 @@ from collections.abc import AsyncIterator from enum import Enum -from typing import ( - Annotated, - Any, - Literal, - Protocol, - runtime_checkable, -) - -from pydantic import BaseModel, Field, field_validator -from typing_extensions import TypedDict +from typing import Annotated, Any, Literal, Protocol, runtime_checkable from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent from llama_stack.apis.common.responses import Order @@ -32,6 +23,9 @@ from llama_stack.models.llama.datatypes import ( from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol from llama_stack.schema_utils import json_schema_type, register_schema, webmethod +from pydantic import BaseModel, Field, field_validator +from typing_extensions import TypedDict + register_schema(ToolCall) register_schema(ToolDefinition) @@ -357,32 +351,32 @@ class CompletionRequest(BaseModel): logprobs: LogProbConfig | None = None -@json_schema_type -class CompletionResponse(MetricResponseMixin): - """Response from a completion request. +# @json_schema_type +# class CompletionResponse(MetricResponseMixin): +# """Response from a completion request. - :param content: The generated completion text - :param stop_reason: Reason why generation stopped - :param logprobs: Optional log probabilities for generated tokens - """ +# :param content: The generated completion text +# :param stop_reason: Reason why generation stopped +# :param logprobs: Optional log probabilities for generated tokens +# """ - content: str - stop_reason: StopReason - logprobs: list[TokenLogProbs] | None = None +# content: str +# stop_reason: StopReason +# logprobs: list[TokenLogProbs] | None = None -@json_schema_type -class CompletionResponseStreamChunk(MetricResponseMixin): - """A chunk of a streamed completion response. +# @json_schema_type +# class CompletionResponseStreamChunk(MetricResponseMixin): +# """A chunk of a streamed completion response. - :param delta: New content generated since last chunk. This can be one or more tokens. - :param stop_reason: Optional reason why generation stopped, if complete - :param logprobs: Optional log probabilities for generated tokens - """ +# :param delta: New content generated since last chunk. This can be one or more tokens. +# :param stop_reason: Optional reason why generation stopped, if complete +# :param logprobs: Optional log probabilities for generated tokens +# """ - delta: str - stop_reason: StopReason | None = None - logprobs: list[TokenLogProbs] | None = None +# delta: str +# stop_reason: StopReason | None = None +# logprobs: list[TokenLogProbs] | None = None class SystemMessageBehavior(Enum): @@ -415,7 +409,9 @@ class ToolConfig(BaseModel): tool_choice: ToolChoice | str | None = Field(default=ToolChoice.auto) tool_prompt_format: ToolPromptFormat | None = Field(default=None) - system_message_behavior: SystemMessageBehavior | None = Field(default=SystemMessageBehavior.append) + system_message_behavior: SystemMessageBehavior | None = Field( + default=SystemMessageBehavior.append + ) def model_post_init(self, __context: Any) -> None: if isinstance(self.tool_choice, str): @@ -544,15 +540,21 @@ class OpenAIFile(BaseModel): OpenAIChatCompletionContentPartParam = Annotated[ - OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam | OpenAIFile, + OpenAIChatCompletionContentPartTextParam + | OpenAIChatCompletionContentPartImageParam + | OpenAIFile, Field(discriminator="type"), ] -register_schema(OpenAIChatCompletionContentPartParam, name="OpenAIChatCompletionContentPartParam") +register_schema( + OpenAIChatCompletionContentPartParam, name="OpenAIChatCompletionContentPartParam" +) OpenAIChatCompletionMessageContent = str | list[OpenAIChatCompletionContentPartParam] -OpenAIChatCompletionTextOnlyMessageContent = str | list[OpenAIChatCompletionContentPartTextParam] +OpenAIChatCompletionTextOnlyMessageContent = ( + str | list[OpenAIChatCompletionContentPartTextParam] +) @json_schema_type @@ -720,7 +722,9 @@ class OpenAIResponseFormatJSONObject(BaseModel): OpenAIResponseFormatParam = Annotated[ - OpenAIResponseFormatText | OpenAIResponseFormatJSONSchema | OpenAIResponseFormatJSONObject, + OpenAIResponseFormatText + | OpenAIResponseFormatJSONSchema + | OpenAIResponseFormatJSONObject, Field(discriminator="type"), ] register_schema(OpenAIResponseFormatParam, name="OpenAIResponseFormatParam") @@ -1049,8 +1053,16 @@ class InferenceProvider(Protocol): async def rerank( self, model: str, - query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam, - items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam], + query: ( + str + | OpenAIChatCompletionContentPartTextParam + | OpenAIChatCompletionContentPartImageParam + ), + items: list[ + str + | OpenAIChatCompletionContentPartTextParam + | OpenAIChatCompletionContentPartImageParam + ], max_num_results: int | None = None, ) -> RerankResponse: """Rerank a list of documents based on their relevance to a query. @@ -1064,7 +1076,12 @@ class InferenceProvider(Protocol): raise NotImplementedError("Reranking is not implemented") return # this is so mypy's safe-super rule will consider the method concrete - @webmethod(route="/openai/v1/completions", method="POST", level=LLAMA_STACK_API_V1, deprecated=True) + @webmethod( + route="/openai/v1/completions", + method="POST", + level=LLAMA_STACK_API_V1, + deprecated=True, + ) @webmethod(route="/completions", method="POST", level=LLAMA_STACK_API_V1) async def openai_completion( self, @@ -1116,7 +1133,12 @@ class InferenceProvider(Protocol): """ ... - @webmethod(route="/openai/v1/chat/completions", method="POST", level=LLAMA_STACK_API_V1, deprecated=True) + @webmethod( + route="/openai/v1/chat/completions", + method="POST", + level=LLAMA_STACK_API_V1, + deprecated=True, + ) @webmethod(route="/chat/completions", method="POST", level=LLAMA_STACK_API_V1) async def openai_chat_completion( self, @@ -1173,7 +1195,12 @@ class InferenceProvider(Protocol): """ ... - @webmethod(route="/openai/v1/embeddings", method="POST", level=LLAMA_STACK_API_V1, deprecated=True) + @webmethod( + route="/openai/v1/embeddings", + method="POST", + level=LLAMA_STACK_API_V1, + deprecated=True, + ) @webmethod(route="/embeddings", method="POST", level=LLAMA_STACK_API_V1) async def openai_embeddings( self, @@ -1203,7 +1230,12 @@ class Inference(InferenceProvider): - Embedding models: these models generate embeddings to be used for semantic search. """ - @webmethod(route="/openai/v1/chat/completions", method="GET", level=LLAMA_STACK_API_V1, deprecated=True) + @webmethod( + route="/openai/v1/chat/completions", + method="GET", + level=LLAMA_STACK_API_V1, + deprecated=True, + ) @webmethod(route="/chat/completions", method="GET", level=LLAMA_STACK_API_V1) async def list_chat_completions( self, @@ -1223,10 +1255,19 @@ class Inference(InferenceProvider): raise NotImplementedError("List chat completions is not implemented") @webmethod( - route="/openai/v1/chat/completions/{completion_id}", method="GET", level=LLAMA_STACK_API_V1, deprecated=True + route="/openai/v1/chat/completions/{completion_id}", + method="GET", + level=LLAMA_STACK_API_V1, + deprecated=True, ) - @webmethod(route="/chat/completions/{completion_id}", method="GET", level=LLAMA_STACK_API_V1) - async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages: + @webmethod( + route="/chat/completions/{completion_id}", + method="GET", + level=LLAMA_STACK_API_V1, + ) + async def get_chat_completion( + self, completion_id: str + ) -> OpenAICompletionWithInputMessages: """Describe a chat completion by its ID. :param completion_id: ID of the chat completion. diff --git a/llama_stack/core/routers/inference.py b/llama_stack/core/routers/inference.py index 4b004a82c..7f0e4a352 100644 --- a/llama_stack/core/routers/inference.py +++ b/llama_stack/core/routers/inference.py @@ -7,24 +7,16 @@ import asyncio import time from collections.abc import AsyncGenerator, AsyncIterator -from datetime import UTC, datetime +from datetime import datetime, UTC from typing import Annotated, Any -from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam -from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToolParam -from pydantic import Field, TypeAdapter - -from llama_stack.apis.common.content_types import ( - InterleavedContent, -) +from llama_stack.apis.common.content_types import InterleavedContent from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError from llama_stack.apis.inference import ( ChatCompletionResponse, ChatCompletionResponseEventType, ChatCompletionResponseStreamChunk, CompletionMessage, - CompletionResponse, - CompletionResponseStreamChunk, Inference, ListOpenAIChatCompletionResponse, LogProbConfig, @@ -57,7 +49,16 @@ from llama_stack.models.llama.llama3.chat_format import ChatFormat from llama_stack.models.llama.llama3.tokenizer import Tokenizer from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable from llama_stack.providers.utils.inference.inference_store import InferenceStore -from llama_stack.providers.utils.telemetry.tracing import enqueue_event, get_current_span +from llama_stack.providers.utils.telemetry.tracing import ( + enqueue_event, + get_current_span, +) + +from openai.types.chat import ( + ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam, + ChatCompletionToolParam as OpenAIChatCompletionToolParam, +) +from pydantic import Field, TypeAdapter logger = get_logger(name=__name__, category="core::routers") @@ -101,7 +102,9 @@ class InferenceRouter(Inference): logger.debug( f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}", ) - await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type) + await self.routing_table.register_model( + model_id, provider_model_id, provider_id, metadata, model_type + ) def _construct_metrics( self, @@ -156,11 +159,16 @@ class InferenceRouter(Inference): total_tokens: int, model: Model, ) -> list[MetricInResponse]: - metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model) + metrics = self._construct_metrics( + prompt_tokens, completion_tokens, total_tokens, model + ) if self.telemetry: for metric in metrics: enqueue_event(metric) - return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics] + return [ + MetricInResponse(metric=metric.metric, value=metric.value) + for metric in metrics + ] async def _count_tokens( self, @@ -207,8 +215,13 @@ class InferenceRouter(Inference): if tool_config: if tool_choice and tool_choice != tool_config.tool_choice: raise ValueError("tool_choice and tool_config.tool_choice must match") - if tool_prompt_format and tool_prompt_format != tool_config.tool_prompt_format: - raise ValueError("tool_prompt_format and tool_config.tool_prompt_format must match") + if ( + tool_prompt_format + and tool_prompt_format != tool_config.tool_prompt_format + ): + raise ValueError( + "tool_prompt_format and tool_config.tool_prompt_format must match" + ) else: params = {} if tool_choice: @@ -226,9 +239,14 @@ class InferenceRouter(Inference): pass else: # verify tool_choice is one of the tools - tool_names = [t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value for t in tools] + tool_names = [ + t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value + for t in tools + ] if tool_config.tool_choice not in tool_names: - raise ValueError(f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}") + raise ValueError( + f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}" + ) params = dict( model_id=model_id, @@ -243,7 +261,9 @@ class InferenceRouter(Inference): tool_config=tool_config, ) provider = await self.routing_table.get_provider_impl(model_id) - prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format) + prompt_tokens = await self._count_tokens( + messages, tool_config.tool_prompt_format + ) if stream: response_stream = await provider.chat_completion(**params) @@ -263,7 +283,9 @@ class InferenceRouter(Inference): ) # these metrics will show up in the client response. response.metrics = ( - metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics + metrics + if not hasattr(response, "metrics") or response.metrics is None + else response.metrics + metrics ) return response @@ -336,7 +358,9 @@ class InferenceRouter(Inference): # these metrics will show up in the client response. response.metrics = ( - metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics + metrics + if not hasattr(response, "metrics") or response.metrics is None + else response.metrics + metrics ) return response @@ -374,9 +398,13 @@ class InferenceRouter(Inference): # Use the OpenAI client for a bit of extra input validation without # exposing the OpenAI client itself as part of our API surface if tool_choice: - TypeAdapter(OpenAIChatCompletionToolChoiceOptionParam).validate_python(tool_choice) + TypeAdapter(OpenAIChatCompletionToolChoiceOptionParam).validate_python( + tool_choice + ) if tools is None: - raise ValueError("'tool_choice' is only allowed when 'tools' is also provided") + raise ValueError( + "'tool_choice' is only allowed when 'tools' is also provided" + ) if tools: for tool in tools: TypeAdapter(OpenAIChatCompletionToolParam).validate_python(tool) @@ -441,7 +469,9 @@ class InferenceRouter(Inference): enqueue_event(metric) # these metrics will show up in the client response. response.metrics = ( - metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics + metrics + if not hasattr(response, "metrics") or response.metrics is None + else response.metrics + metrics ) return response @@ -477,19 +507,31 @@ class InferenceRouter(Inference): ) -> ListOpenAIChatCompletionResponse: if self.store: return await self.store.list_chat_completions(after, limit, model, order) - raise NotImplementedError("List chat completions is not supported: inference store is not configured.") + raise NotImplementedError( + "List chat completions is not supported: inference store is not configured." + ) - async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages: + async def get_chat_completion( + self, completion_id: str + ) -> OpenAICompletionWithInputMessages: if self.store: return await self.store.get_chat_completion(completion_id) - raise NotImplementedError("Get chat completion is not supported: inference store is not configured.") + raise NotImplementedError( + "Get chat completion is not supported: inference store is not configured." + ) - async def _nonstream_openai_chat_completion(self, provider: Inference, params: dict) -> OpenAIChatCompletion: + async def _nonstream_openai_chat_completion( + self, provider: Inference, params: dict + ) -> OpenAIChatCompletion: response = await provider.openai_chat_completion(**params) for choice in response.choices: # some providers return an empty list for no tool calls in non-streaming responses # but the OpenAI API returns None. So, set tool_calls to None if it's empty - if choice.message and choice.message.tool_calls is not None and len(choice.message.tool_calls) == 0: + if ( + choice.message + and choice.message.tool_calls is not None + and len(choice.message.tool_calls) == 0 + ): choice.message.tool_calls = None return response @@ -509,7 +551,9 @@ class InferenceRouter(Inference): message=f"Health check timed out after {timeout} seconds", ) except NotImplementedError: - health_statuses[provider_id] = HealthResponse(status=HealthStatus.NOT_IMPLEMENTED) + health_statuses[provider_id] = HealthResponse( + status=HealthStatus.NOT_IMPLEMENTED + ) except Exception as e: health_statuses[provider_id] = HealthResponse( status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}" @@ -522,7 +566,7 @@ class InferenceRouter(Inference): prompt_tokens, model, tool_prompt_format: ToolPromptFormat | None = None, - ) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None] | AsyncGenerator[CompletionResponseStreamChunk, None]: + ) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]: completion_text = "" async for chunk in response: complete = False @@ -544,7 +588,11 @@ class InferenceRouter(Inference): else: if hasattr(chunk, "delta"): completion_text += chunk.delta - if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry: + if ( + hasattr(chunk, "stop_reason") + and chunk.stop_reason + and self.telemetry + ): complete = True completion_tokens = await self._count_tokens(completion_text) # if we are done receiving tokens @@ -569,9 +617,14 @@ class InferenceRouter(Inference): # Return metrics in response async_metrics = [ - MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics + MetricInResponse(metric=metric.metric, value=metric.value) + for metric in completion_metrics ] - chunk.metrics = async_metrics if chunk.metrics is None else chunk.metrics + async_metrics + chunk.metrics = ( + async_metrics + if chunk.metrics is None + else chunk.metrics + async_metrics + ) else: # Fallback if no telemetry completion_metrics = self._construct_metrics( @@ -581,14 +634,19 @@ class InferenceRouter(Inference): model, ) async_metrics = [ - MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics + MetricInResponse(metric=metric.metric, value=metric.value) + for metric in completion_metrics ] - chunk.metrics = async_metrics if chunk.metrics is None else chunk.metrics + async_metrics + chunk.metrics = ( + async_metrics + if chunk.metrics is None + else chunk.metrics + async_metrics + ) yield chunk async def count_tokens_and_compute_metrics( self, - response: ChatCompletionResponse | CompletionResponse, + response: ChatCompletionResponse, prompt_tokens, model, tool_prompt_format: ToolPromptFormat | None = None, @@ -597,7 +655,9 @@ class InferenceRouter(Inference): content = [response.completion_message] else: content = response.content - completion_tokens = await self._count_tokens(messages=content, tool_prompt_format=tool_prompt_format) + completion_tokens = await self._count_tokens( + messages=content, tool_prompt_format=tool_prompt_format + ) total_tokens = (prompt_tokens or 0) + (completion_tokens or 0) # Create a separate span for completion metrics @@ -610,11 +670,17 @@ class InferenceRouter(Inference): model=model, ) for metric in completion_metrics: - if metric.metric in ["completion_tokens", "total_tokens"]: # Only log completion and total tokens + if metric.metric in [ + "completion_tokens", + "total_tokens", + ]: # Only log completion and total tokens enqueue_event(metric) # Return metrics in response - return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics] + return [ + MetricInResponse(metric=metric.metric, value=metric.value) + for metric in completion_metrics + ] # Fallback if no telemetry metrics = self._construct_metrics( @@ -623,7 +689,10 @@ class InferenceRouter(Inference): total_tokens, model, ) - return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics] + return [ + MetricInResponse(metric=metric.metric, value=metric.value) + for metric in metrics + ] async def stream_tokens_and_compute_metrics_openai_chat( self, @@ -664,33 +733,48 @@ class InferenceRouter(Inference): if choice_delta.delta: delta = choice_delta.delta if delta.content: - current_choice_data["content_parts"].append(delta.content) + current_choice_data["content_parts"].append( + delta.content + ) if delta.tool_calls: for tool_call_delta in delta.tool_calls: tc_idx = tool_call_delta.index - if tc_idx not in current_choice_data["tool_calls_builder"]: - current_choice_data["tool_calls_builder"][tc_idx] = { + if ( + tc_idx + not in current_choice_data["tool_calls_builder"] + ): + current_choice_data["tool_calls_builder"][ + tc_idx + ] = { "id": None, "type": "function", "function_name_parts": [], "function_arguments_parts": [], } - builder = current_choice_data["tool_calls_builder"][tc_idx] + builder = current_choice_data["tool_calls_builder"][ + tc_idx + ] if tool_call_delta.id: builder["id"] = tool_call_delta.id if tool_call_delta.type: builder["type"] = tool_call_delta.type if tool_call_delta.function: if tool_call_delta.function.name: - builder["function_name_parts"].append(tool_call_delta.function.name) + builder["function_name_parts"].append( + tool_call_delta.function.name + ) if tool_call_delta.function.arguments: builder["function_arguments_parts"].append( tool_call_delta.function.arguments ) if choice_delta.finish_reason: - current_choice_data["finish_reason"] = choice_delta.finish_reason + current_choice_data["finish_reason"] = ( + choice_delta.finish_reason + ) if choice_delta.logprobs and choice_delta.logprobs.content: - current_choice_data["logprobs_content_parts"].extend(choice_delta.logprobs.content) + current_choice_data["logprobs_content_parts"].extend( + choice_delta.logprobs.content + ) # Compute metrics on final chunk if chunk.choices and chunk.choices[0].finish_reason: @@ -720,8 +804,12 @@ class InferenceRouter(Inference): if choice_data["tool_calls_builder"]: for tc_build_data in choice_data["tool_calls_builder"].values(): if tc_build_data["id"]: - func_name = "".join(tc_build_data["function_name_parts"]) - func_args = "".join(tc_build_data["function_arguments_parts"]) + func_name = "".join( + tc_build_data["function_name_parts"] + ) + func_args = "".join( + tc_build_data["function_arguments_parts"] + ) assembled_tool_calls.append( OpenAIChatCompletionToolCall( id=tc_build_data["id"], @@ -734,10 +822,16 @@ class InferenceRouter(Inference): message = OpenAIAssistantMessageParam( role="assistant", content=content_str if content_str else None, - tool_calls=assembled_tool_calls if assembled_tool_calls else None, + tool_calls=( + assembled_tool_calls if assembled_tool_calls else None + ), ) logprobs_content = choice_data["logprobs_content_parts"] - final_logprobs = OpenAIChoiceLogprobs(content=logprobs_content) if logprobs_content else None + final_logprobs = ( + OpenAIChoiceLogprobs(content=logprobs_content) + if logprobs_content + else None + ) assembled_choices.append( OpenAIChoice( @@ -756,4 +850,6 @@ class InferenceRouter(Inference): object="chat.completion", ) logger.debug(f"InferenceRouter.completion_response: {final_response}") - asyncio.create_task(self.store.store_chat_completion(final_response, messages)) + asyncio.create_task( + self.store.store_chat_completion(final_response, messages) + ) diff --git a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py index cd682dca6..b975fb13f 100644 --- a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py +++ b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py @@ -25,9 +25,6 @@ from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate from llama_stack.providers.utils.inference.embedding_mixin import ( SentenceTransformerEmbeddingMixin, ) -from llama_stack.providers.utils.inference.openai_compat import ( - OpenAIChatCompletionToLlamaStackMixin, -) from .config import SentenceTransformersInferenceConfig @@ -35,7 +32,6 @@ log = get_logger(name=__name__, category="inference") class SentenceTransformersInferenceImpl( - OpenAIChatCompletionToLlamaStackMixin, SentenceTransformerEmbeddingMixin, InferenceProvider, ModelsProtocolPrivate, @@ -114,4 +110,6 @@ class SentenceTransformersInferenceImpl( # for fill-in-the-middle type completion suffix: str | None = None, ) -> OpenAICompletion: - raise NotImplementedError("OpenAI completion not supported by sentence transformers provider") + raise NotImplementedError( + "OpenAI completion not supported by sentence transformers provider" + ) diff --git a/llama_stack/providers/remote/inference/cerebras/cerebras.py b/llama_stack/providers/remote/inference/cerebras/cerebras.py index 95da71de8..d1ddf9670 100644 --- a/llama_stack/providers/remote/inference/cerebras/cerebras.py +++ b/llama_stack/providers/remote/inference/cerebras/cerebras.py @@ -11,8 +11,7 @@ from cerebras.cloud.sdk import AsyncCerebras from llama_stack.apis.inference import ( ChatCompletionRequest, - CompletionRequest, - CompletionResponse, + ChatCompletionResponse, Inference, LogProbConfig, Message, @@ -25,9 +24,7 @@ from llama_stack.apis.inference import ( ToolPromptFormat, TopKSamplingStrategy, ) -from llama_stack.providers.utils.inference.model_registry import ( - ModelRegistryHelper, -) +from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper from llama_stack.providers.utils.inference.openai_compat import ( get_sampling_options, process_chat_completion_response, @@ -36,7 +33,6 @@ from llama_stack.providers.utils.inference.openai_compat import ( from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin from llama_stack.providers.utils.inference.prompt_adapter import ( chat_completion_request_to_prompt, - completion_request_to_prompt, ) from .config import CerebrasImplConfig @@ -102,14 +98,18 @@ class CerebrasInferenceAdapter( else: return await self._nonstream_chat_completion(request) - async def _nonstream_chat_completion(self, request: CompletionRequest) -> CompletionResponse: + async def _nonstream_chat_completion( + self, request: ChatCompletionRequest + ) -> ChatCompletionResponse: params = await self._get_params(request) r = await self._cerebras_client.completions.create(**params) return process_chat_completion_response(r, request) - async def _stream_chat_completion(self, request: CompletionRequest) -> AsyncGenerator: + async def _stream_chat_completion( + self, request: ChatCompletionRequest + ) -> AsyncGenerator: params = await self._get_params(request) stream = await self._cerebras_client.completions.create(**params) @@ -117,15 +117,17 @@ class CerebrasInferenceAdapter( async for chunk in process_chat_completion_stream_response(stream, request): yield chunk - async def _get_params(self, request: ChatCompletionRequest | CompletionRequest) -> dict: - if request.sampling_params and isinstance(request.sampling_params.strategy, TopKSamplingStrategy): + async def _get_params(self, request: ChatCompletionRequest) -> dict: + if request.sampling_params and isinstance( + request.sampling_params.strategy, TopKSamplingStrategy + ): raise ValueError("`top_k` not supported by Cerebras") prompt = "" if isinstance(request, ChatCompletionRequest): - prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model)) - elif isinstance(request, CompletionRequest): - prompt = await completion_request_to_prompt(request) + prompt = await chat_completion_request_to_prompt( + request, self.get_llama_model(request.model) + ) else: raise ValueError(f"Unknown request type {type(request)}") diff --git a/llama_stack/providers/remote/inference/runpod/runpod.py b/llama_stack/providers/remote/inference/runpod/runpod.py index 77c5c7187..15d04d8d6 100644 --- a/llama_stack/providers/remote/inference/runpod/runpod.py +++ b/llama_stack/providers/remote/inference/runpod/runpod.py @@ -10,11 +10,13 @@ from openai import OpenAI from llama_stack.apis.inference import * # noqa: F403 from llama_stack.apis.inference import OpenAIEmbeddingsResponse -# from llama_stack.providers.datatypes import ModelsProtocolPrivate -from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, build_hf_repo_model_entry +from llama_stack.providers.utils.inference.model_registry import ( + build_hf_repo_model_entry, + ModelRegistryHelper, +) from llama_stack.providers.utils.inference.openai_compat import ( - OpenAIChatCompletionToLlamaStackMixin, get_sampling_options, + OpenAIChatCompletionToLlamaStackMixin, process_chat_completion_response, process_chat_completion_stream_response, ) @@ -41,13 +43,12 @@ RUNPOD_SUPPORTED_MODELS = { "Llama3.2-3B": "meta-llama/Llama-3.2-3B", } -SAFETY_MODELS_ENTRIES = [] # Create MODEL_ENTRIES from RUNPOD_SUPPORTED_MODELS for compatibility with starter template MODEL_ENTRIES = [ build_hf_repo_model_entry(provider_model_id, model_descriptor) for provider_model_id, model_descriptor in RUNPOD_SUPPORTED_MODELS.items() -] + SAFETY_MODELS_ENTRIES +] class RunpodInferenceAdapter( @@ -56,7 +57,9 @@ class RunpodInferenceAdapter( OpenAIChatCompletionToLlamaStackMixin, ): def __init__(self, config: RunpodImplConfig) -> None: - ModelRegistryHelper.__init__(self, stack_to_provider_models_map=RUNPOD_SUPPORTED_MODELS) + ModelRegistryHelper.__init__( + self, stack_to_provider_models_map=RUNPOD_SUPPORTED_MODELS + ) self.config = config async def initialize(self) -> None: @@ -103,7 +106,9 @@ class RunpodInferenceAdapter( r = client.completions.create(**params) return process_chat_completion_response(r, request) - async def _stream_chat_completion(self, request: ChatCompletionRequest, client: OpenAI) -> AsyncGenerator: + async def _stream_chat_completion( + self, request: ChatCompletionRequest, client: OpenAI + ) -> AsyncGenerator: params = self._get_params(request) async def _to_async_generator(): diff --git a/llama_stack/providers/remote/inference/watsonx/watsonx.py b/llama_stack/providers/remote/inference/watsonx/watsonx.py index cb9d61102..9c8831c0d 100644 --- a/llama_stack/providers/remote/inference/watsonx/watsonx.py +++ b/llama_stack/providers/remote/inference/watsonx/watsonx.py @@ -9,12 +9,10 @@ from typing import Any from ibm_watsonx_ai.foundation_models import Model from ibm_watsonx_ai.metanames import GenTextParamsMetaNames as GenParams -from openai import AsyncOpenAI from llama_stack.apis.inference import ( ChatCompletionRequest, ChatCompletionResponse, - CompletionRequest, GreedySamplingStrategy, Inference, LogProbConfig, @@ -48,6 +46,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( completion_request_to_prompt, request_has_media, ) +from openai import AsyncOpenAI from . import WatsonXConfig from .models import MODEL_ENTRIES @@ -85,7 +84,9 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper): pass def _get_client(self, model_id) -> Model: - config_api_key = self._config.api_key.get_secret_value() if self._config.api_key else None + config_api_key = ( + self._config.api_key.get_secret_value() if self._config.api_key else None + ) config_url = self._config.url project_id = self._config.project_id credentials = {"url": config_url, "apikey": config_api_key} @@ -132,14 +133,18 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper): else: return await self._nonstream_chat_completion(request) - async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: + async def _nonstream_chat_completion( + self, request: ChatCompletionRequest + ) -> ChatCompletionResponse: params = await self._get_params(request) r = self._get_client(request.model).generate(**params) choices = [] if "results" in r: for result in r["results"]: choice = OpenAICompatCompletionChoice( - finish_reason=result["stop_reason"] if result["stop_reason"] else None, + finish_reason=( + result["stop_reason"] if result["stop_reason"] else None + ), text=result["generated_text"], ) choices.append(choice) @@ -148,7 +153,9 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper): ) return process_chat_completion_response(response, request) - async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator: + async def _stream_chat_completion( + self, request: ChatCompletionRequest + ) -> AsyncGenerator: params = await self._get_params(request) model_id = request.model @@ -168,28 +175,44 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper): async for chunk in process_chat_completion_stream_response(stream, request): yield chunk - async def _get_params(self, request: ChatCompletionRequest | CompletionRequest) -> dict: + async def _get_params(self, request: ChatCompletionRequest) -> dict: input_dict = {"params": {}} media_present = request_has_media(request) llama_model = self.get_llama_model(request.model) if isinstance(request, ChatCompletionRequest): - input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model) + input_dict["prompt"] = await chat_completion_request_to_prompt( + request, llama_model + ) else: - assert not media_present, "Together does not support media for Completion requests" + assert ( + not media_present + ), "Together does not support media for Completion requests" input_dict["prompt"] = await completion_request_to_prompt(request) if request.sampling_params: if request.sampling_params.strategy: - input_dict["params"][GenParams.DECODING_METHOD] = request.sampling_params.strategy.type + input_dict["params"][ + GenParams.DECODING_METHOD + ] = request.sampling_params.strategy.type if request.sampling_params.max_tokens: - input_dict["params"][GenParams.MAX_NEW_TOKENS] = request.sampling_params.max_tokens + input_dict["params"][ + GenParams.MAX_NEW_TOKENS + ] = request.sampling_params.max_tokens if request.sampling_params.repetition_penalty: - input_dict["params"][GenParams.REPETITION_PENALTY] = request.sampling_params.repetition_penalty + input_dict["params"][ + GenParams.REPETITION_PENALTY + ] = request.sampling_params.repetition_penalty if isinstance(request.sampling_params.strategy, TopPSamplingStrategy): - input_dict["params"][GenParams.TOP_P] = request.sampling_params.strategy.top_p - input_dict["params"][GenParams.TEMPERATURE] = request.sampling_params.strategy.temperature + input_dict["params"][ + GenParams.TOP_P + ] = request.sampling_params.strategy.top_p + input_dict["params"][ + GenParams.TEMPERATURE + ] = request.sampling_params.strategy.temperature if isinstance(request.sampling_params.strategy, TopKSamplingStrategy): - input_dict["params"][GenParams.TOP_K] = request.sampling_params.strategy.top_k + input_dict["params"][ + GenParams.TOP_K + ] = request.sampling_params.strategy.top_k if isinstance(request.sampling_params.strategy, GreedySamplingStrategy): input_dict["params"][GenParams.TEMPERATURE] = 0.0