Clean up some more usage of direct OpenAI types

Ben Browning 2025-04-08 09:10:52 -04:00
parent 92fdf6d0c9
commit 5bc5fed6df
4 changed files with 10 additions and 64 deletions
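
The pattern applied in all four files below: provider and utility code stops importing the OpenAI SDK's response models directly and instead pulls the equivalent OpenAI-compatible models from llama_stack's own inference API. A minimal before/after sketch of the import change (only the import source moves; the names the code uses stay the same):

    # Before: the SDK's types were aliased at each import site.
    from openai.types.chat import ChatCompletion as OpenAIChatCompletion
    from openai.types.completion import Completion as OpenAICompletion

    # After: the same names come from llama_stack's own API surface, so the
    # openai package is no longer needed for these annotations.
    from llama_stack.apis.inference.inference import (
        OpenAIChatCompletion,
        OpenAICompletion,
        OpenAIMessageParam,
    )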

@@ -7,9 +7,6 @@
 import time
 from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
 
-from openai.types.chat import ChatCompletion as OpenAIChatCompletion
-from openai.types.completion import Completion as OpenAICompletion
-
 from llama_stack.apis.common.content_types import (
     URL,
     InterleavedContent,
@@ -38,7 +35,7 @@ from llama_stack.apis.inference import (
     ToolDefinition,
     ToolPromptFormat,
 )
-from llama_stack.apis.inference.inference import OpenAIMessageParam
+from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam
 from llama_stack.apis.models import Model, ModelType
 from llama_stack.apis.safety import RunShieldResponse, Safety
 from llama_stack.apis.scoring import (

@@ -5,10 +5,7 @@
 # the root directory of this source tree.
 
 import logging
-from typing import Any, AsyncGenerator, Dict, List, Optional, Union
-
-from openai.types.chat import ChatCompletion as OpenAIChatCompletion
-from openai.types.completion import Completion as OpenAICompletion
+from typing import AsyncGenerator, List, Optional, Union
 
 from llama_stack.apis.inference import (
     CompletionResponse,
@@ -22,11 +19,14 @@ from llama_stack.apis.inference import (
     ToolDefinition,
     ToolPromptFormat,
 )
-from llama_stack.apis.inference.inference import OpenAIMessageParam
 from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.embedding_mixin import (
     SentenceTransformerEmbeddingMixin,
 )
+from llama_stack.providers.utils.inference.openai_compat import (
+    OpenAIChatCompletionUnsupportedMixin,
+    OpenAICompletionUnsupportedMixin,
+)
 
 from .config import SentenceTransformersInferenceConfig
 
@@ -34,6 +34,8 @@ log = logging.getLogger(__name__)
 
 class SentenceTransformersInferenceImpl(
+    OpenAIChatCompletionUnsupportedMixin,
+    OpenAICompletionUnsupportedMixin,
     SentenceTransformerEmbeddingMixin,
     Inference,
     ModelsProtocolPrivate,
 ):
@@ -78,53 +80,3 @@ class SentenceTransformersInferenceImpl(
         tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
         raise ValueError("Sentence transformers don't support chat completion")
-
-    async def openai_completion(
-        self,
-        model: str,
-        prompt: str,
-        best_of: Optional[int] = None,
-        echo: Optional[bool] = None,
-        frequency_penalty: Optional[float] = None,
-        logit_bias: Optional[Dict[str, float]] = None,
-        logprobs: Optional[bool] = None,
-        max_tokens: Optional[int] = None,
-        n: Optional[int] = None,
-        presence_penalty: Optional[float] = None,
-        seed: Optional[int] = None,
-        stop: Optional[Union[str, List[str]]] = None,
-        stream: Optional[bool] = None,
-        stream_options: Optional[Dict[str, Any]] = None,
-        temperature: Optional[float] = None,
-        top_p: Optional[float] = None,
-        user: Optional[str] = None,
-    ) -> OpenAICompletion:
-        raise ValueError("Sentence transformers don't support openai completion")
-
-    async def openai_chat_completion(
-        self,
-        model: str,
-        messages: List[OpenAIMessageParam],
-        frequency_penalty: Optional[float] = None,
-        function_call: Optional[Union[str, Dict[str, Any]]] = None,
-        functions: Optional[List[Dict[str, Any]]] = None,
-        logit_bias: Optional[Dict[str, float]] = None,
-        logprobs: Optional[bool] = None,
-        max_completion_tokens: Optional[int] = None,
-        max_tokens: Optional[int] = None,
-        n: Optional[int] = None,
-        parallel_tool_calls: Optional[bool] = None,
-        presence_penalty: Optional[float] = None,
-        response_format: Optional[Dict[str, str]] = None,
-        seed: Optional[int] = None,
-        stop: Optional[Union[str, List[str]]] = None,
-        stream: Optional[bool] = None,
-        stream_options: Optional[Dict[str, Any]] = None,
-        temperature: Optional[float] = None,
-        tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
-        tools: Optional[List[Dict[str, Any]]] = None,
-        top_logprobs: Optional[int] = None,
-        top_p: Optional[float] = None,
-        user: Optional[str] = None,
-    ) -> OpenAIChatCompletion:
-        raise ValueError("Sentence transformers don't support openai chat completion")

@@ -9,11 +9,9 @@ from typing import Any, AsyncGenerator, Dict, List, Optional, Union
 
 import httpx
 from openai import AsyncOpenAI
-from openai.types.chat import ChatCompletion as OpenAIChatCompletion
 from openai.types.chat.chat_completion_chunk import (
     ChatCompletionChunk as OpenAIChatCompletionChunk,
 )
-from openai.types.completion import Completion as OpenAICompletion
 
 from llama_stack.apis.common.content_types import (
     InterleavedContent,
@@ -47,7 +45,7 @@ from llama_stack.apis.inference import (
     ToolDefinition,
     ToolPromptFormat,
 )
-from llama_stack.apis.inference.inference import OpenAIMessageParam
+from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam
 from llama_stack.apis.models import Model, ModelType
 from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
 from llama_stack.models.llama.sku_list import all_registered_models

@@ -9,7 +9,6 @@ import warnings
 from typing import Any, AsyncGenerator, Dict, Iterable, List, Optional, Union
 
 from openai import AsyncStream
-from openai.types.chat import ChatCompletion as OpenAIChatCompletion
 from openai.types.chat import (
     ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage,
 )
@@ -55,7 +54,6 @@ from openai.types.chat.chat_completion_content_part_image_param import (
 from openai.types.chat.chat_completion_message_tool_call_param import (
     Function as OpenAIFunction,
 )
-from openai.types.completion import Completion as OpenAICompletion
 from pydantic import BaseModel
 
 from llama_stack.apis.common.content_types import (
@@ -85,6 +83,7 @@ from llama_stack.apis.inference import (
     TopPSamplingStrategy,
     UserMessage,
 )
+from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion
 from llama_stack.models.llama.datatypes import (
     BuiltinTool,
     StopReason,
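
Note that this last file still imports AsyncStream from the openai SDK, so the SDK remains a runtime dependency of the compatibility layer; only the response-model types now come from llama_stack. Assuming the llama_stack models are Pydantic v2 models mirroring the OpenAI schema (this diff does not show their definitions), bridging an SDK response into the llama_stack type could look roughly like this sketch (convert_sdk_chat_completion is a hypothetical helper, not part of this commit):

    from openai.types.chat import ChatCompletion

    from llama_stack.apis.inference.inference import OpenAIChatCompletion

    def convert_sdk_chat_completion(sdk_response: ChatCompletion) -> OpenAIChatCompletion:
        # Re-validate the SDK payload into llama_stack's own model so that
        # callers never depend on the openai package's types.
        return OpenAIChatCompletion.model_validate(sdk_response.model_dump())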