From 5bc5fed6df7e77c41d0c9e4725fed044fa3f12b7 Mon Sep 17 00:00:00 2001 From: Ben Browning Date: Tue, 8 Apr 2025 09:10:52 -0400 Subject: [PATCH] Clean up some more usage of direct OpenAI types --- llama_stack/distribution/routers/routers.py | 5 +- .../sentence_transformers.py | 62 +++---------------- .../providers/remote/inference/vllm/vllm.py | 4 +- .../utils/inference/openai_compat.py | 3 +- 4 files changed, 10 insertions(+), 64 deletions(-) diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 4f3e97778..19cc8ac09 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -7,9 +7,6 @@ import time from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union -from openai.types.chat import ChatCompletion as OpenAIChatCompletion -from openai.types.completion import Completion as OpenAICompletion - from llama_stack.apis.common.content_types import ( URL, InterleavedContent, @@ -38,7 +35,7 @@ from llama_stack.apis.inference import ( ToolDefinition, ToolPromptFormat, ) -from llama_stack.apis.inference.inference import OpenAIMessageParam +from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam from llama_stack.apis.models import Model, ModelType from llama_stack.apis.safety import RunShieldResponse, Safety from llama_stack.apis.scoring import ( diff --git a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py index 7cce2fb92..9c370b6c5 100644 --- a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py +++ b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py @@ -5,10 +5,7 @@ # the root directory of this source tree. import logging -from typing import Any, AsyncGenerator, Dict, List, Optional, Union - -from openai.types.chat import ChatCompletion as OpenAIChatCompletion -from openai.types.completion import Completion as OpenAICompletion +from typing import AsyncGenerator, List, Optional, Union from llama_stack.apis.inference import ( CompletionResponse, @@ -22,11 +19,14 @@ from llama_stack.apis.inference import ( ToolDefinition, ToolPromptFormat, ) -from llama_stack.apis.inference.inference import OpenAIMessageParam from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate from llama_stack.providers.utils.inference.embedding_mixin import ( SentenceTransformerEmbeddingMixin, ) +from llama_stack.providers.utils.inference.openai_compat import ( + OpenAIChatCompletionUnsupportedMixin, + OpenAICompletionUnsupportedMixin, +) from .config import SentenceTransformersInferenceConfig @@ -34,6 +34,8 @@ log = logging.getLogger(__name__) class SentenceTransformersInferenceImpl( + OpenAIChatCompletionUnsupportedMixin, + OpenAICompletionUnsupportedMixin, SentenceTransformerEmbeddingMixin, Inference, ModelsProtocolPrivate, @@ -78,53 +80,3 @@ class SentenceTransformersInferenceImpl( tool_config: Optional[ToolConfig] = None, ) -> AsyncGenerator: raise ValueError("Sentence transformers don't support chat completion") - - async def openai_completion( - self, - model: str, - prompt: str, - best_of: Optional[int] = None, - echo: Optional[bool] = None, - frequency_penalty: Optional[float] = None, - logit_bias: Optional[Dict[str, float]] = None, - logprobs: Optional[bool] = None, - max_tokens: Optional[int] = None, - n: Optional[int] = None, - presence_penalty: Optional[float] = None, - seed: Optional[int] = None, - stop: Optional[Union[str, List[str]]] = None, - stream: Optional[bool] = None, - stream_options: Optional[Dict[str, Any]] = None, - temperature: Optional[float] = None, - top_p: Optional[float] = None, - user: Optional[str] = None, - ) -> OpenAICompletion: - raise ValueError("Sentence transformers don't support openai completion") - - async def openai_chat_completion( - self, - model: str, - messages: List[OpenAIMessageParam], - frequency_penalty: Optional[float] = None, - function_call: Optional[Union[str, Dict[str, Any]]] = None, - functions: Optional[List[Dict[str, Any]]] = None, - logit_bias: Optional[Dict[str, float]] = None, - logprobs: Optional[bool] = None, - max_completion_tokens: Optional[int] = None, - max_tokens: Optional[int] = None, - n: Optional[int] = None, - parallel_tool_calls: Optional[bool] = None, - presence_penalty: Optional[float] = None, - response_format: Optional[Dict[str, str]] = None, - seed: Optional[int] = None, - stop: Optional[Union[str, List[str]]] = None, - stream: Optional[bool] = None, - stream_options: Optional[Dict[str, Any]] = None, - temperature: Optional[float] = None, - tool_choice: Optional[Union[str, Dict[str, Any]]] = None, - tools: Optional[List[Dict[str, Any]]] = None, - top_logprobs: Optional[int] = None, - top_p: Optional[float] = None, - user: Optional[str] = None, - ) -> OpenAIChatCompletion: - raise ValueError("Sentence transformers don't support openai chat completion") diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index 696e72a32..d7555c39f 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -9,11 +9,9 @@ from typing import Any, AsyncGenerator, Dict, List, Optional, Union import httpx from openai import AsyncOpenAI -from openai.types.chat import ChatCompletion as OpenAIChatCompletion from openai.types.chat.chat_completion_chunk import ( ChatCompletionChunk as OpenAIChatCompletionChunk, ) -from openai.types.completion import Completion as OpenAICompletion from llama_stack.apis.common.content_types import ( InterleavedContent, @@ -47,7 +45,7 @@ from llama_stack.apis.inference import ( ToolDefinition, ToolPromptFormat, ) -from llama_stack.apis.inference.inference import OpenAIMessageParam +from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam from llama_stack.apis.models import Model, ModelType from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall from llama_stack.models.llama.sku_list import all_registered_models diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py index 3f1846b76..d9091d5c8 100644 --- a/llama_stack/providers/utils/inference/openai_compat.py +++ b/llama_stack/providers/utils/inference/openai_compat.py @@ -9,7 +9,6 @@ import warnings from typing import Any, AsyncGenerator, Dict, Iterable, List, Optional, Union from openai import AsyncStream -from openai.types.chat import ChatCompletion as OpenAIChatCompletion from openai.types.chat import ( ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage, ) @@ -55,7 +54,6 @@ from openai.types.chat.chat_completion_content_part_image_param import ( from openai.types.chat.chat_completion_message_tool_call_param import ( Function as OpenAIFunction, ) -from openai.types.completion import Completion as OpenAICompletion from pydantic import BaseModel from llama_stack.apis.common.content_types import ( @@ -85,6 +83,7 @@ from llama_stack.apis.inference import ( TopPSamplingStrategy, UserMessage, ) +from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion from llama_stack.models.llama.datatypes import ( BuiltinTool, StopReason,