Code refactoring and removing dead code

Omar Abdelwahab 2025-10-02 18:38:30 -07:00
parent ef0736527d
commit f6080040da
6 changed files with 302 additions and 137 deletions


@@ -6,16 +6,7 @@
 from collections.abc import AsyncIterator
 from enum import Enum
-from typing import (
-    Annotated,
-    Any,
-    Literal,
-    Protocol,
-    runtime_checkable,
-)
-from pydantic import BaseModel, Field, field_validator
-from typing_extensions import TypedDict
+from typing import Annotated, Any, Literal, Protocol, runtime_checkable

 from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent
 from llama_stack.apis.common.responses import Order
@@ -32,6 +23,9 @@ from llama_stack.models.llama.datatypes import (
 from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
 from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
+from pydantic import BaseModel, Field, field_validator
+from typing_extensions import TypedDict

 register_schema(ToolCall)
 register_schema(ToolDefinition)
@@ -357,32 +351,32 @@ class CompletionRequest(BaseModel):
     logprobs: LogProbConfig | None = None


-@json_schema_type
-class CompletionResponse(MetricResponseMixin):
-    """Response from a completion request.
+# @json_schema_type
+# class CompletionResponse(MetricResponseMixin):
+#     """Response from a completion request.

-    :param content: The generated completion text
-    :param stop_reason: Reason why generation stopped
-    :param logprobs: Optional log probabilities for generated tokens
-    """
+#     :param content: The generated completion text
+#     :param stop_reason: Reason why generation stopped
+#     :param logprobs: Optional log probabilities for generated tokens
+#     """

-    content: str
-    stop_reason: StopReason
-    logprobs: list[TokenLogProbs] | None = None
+#     content: str
+#     stop_reason: StopReason
+#     logprobs: list[TokenLogProbs] | None = None


-@json_schema_type
-class CompletionResponseStreamChunk(MetricResponseMixin):
-    """A chunk of a streamed completion response.
+# @json_schema_type
+# class CompletionResponseStreamChunk(MetricResponseMixin):
+#     """A chunk of a streamed completion response.

-    :param delta: New content generated since last chunk. This can be one or more tokens.
-    :param stop_reason: Optional reason why generation stopped, if complete
-    :param logprobs: Optional log probabilities for generated tokens
-    """
+#     :param delta: New content generated since last chunk. This can be one or more tokens.
+#     :param stop_reason: Optional reason why generation stopped, if complete
+#     :param logprobs: Optional log probabilities for generated tokens
+#     """

-    delta: str
-    stop_reason: StopReason | None = None
-    logprobs: list[TokenLogProbs] | None = None
+#     delta: str
+#     stop_reason: StopReason | None = None
+#     logprobs: list[TokenLogProbs] | None = None


 class SystemMessageBehavior(Enum):
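Note on the hunk above: the dead CompletionResponse and CompletionResponseStreamChunk types are commented out rather than deleted, so any module that still imports them from this package will now fail at import time. A minimal sketch of that failure mode (the import path is assumed, not shown in this diff):

# Assumed import path; after this change the name is commented out upstream,
# so this raises ImportError.
from llama_stack.apis.inference import CompletionResponse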
@@ -415,7 +409,9 @@ class ToolConfig(BaseModel):
     tool_choice: ToolChoice | str | None = Field(default=ToolChoice.auto)
     tool_prompt_format: ToolPromptFormat | None = Field(default=None)
-    system_message_behavior: SystemMessageBehavior | None = Field(default=SystemMessageBehavior.append)
+    system_message_behavior: SystemMessageBehavior | None = Field(
+        default=SystemMessageBehavior.append
+    )

     def model_post_init(self, __context: Any) -> None:
         if isinstance(self.tool_choice, str):
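The reflowed field above sits next to model_post_init, whose isinstance branch (shown in context) coerces a plain-string tool_choice into the enum. A hedged usage sketch, assuming ToolConfig and ToolChoice are in scope from this module:

# Sketch only: the string "auto" should be normalized to ToolChoice.auto
# by the model_post_init branch shown in the hunk above.
cfg = ToolConfig(tool_choice="auto")
assert cfg.tool_choice == ToolChoice.auto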
@@ -544,15 +540,21 @@ class OpenAIFile(BaseModel):
 OpenAIChatCompletionContentPartParam = Annotated[
-    OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam | OpenAIFile,
+    OpenAIChatCompletionContentPartTextParam
+    | OpenAIChatCompletionContentPartImageParam
+    | OpenAIFile,
     Field(discriminator="type"),
 ]
-register_schema(OpenAIChatCompletionContentPartParam, name="OpenAIChatCompletionContentPartParam")
+register_schema(
+    OpenAIChatCompletionContentPartParam, name="OpenAIChatCompletionContentPartParam"
+)

 OpenAIChatCompletionMessageContent = str | list[OpenAIChatCompletionContentPartParam]
-OpenAIChatCompletionTextOnlyMessageContent = str | list[OpenAIChatCompletionContentPartTextParam]
+OpenAIChatCompletionTextOnlyMessageContent = (
+    str | list[OpenAIChatCompletionContentPartTextParam]
+)


 @json_schema_type
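Both unions in this hunk rely on pydantic's Field(discriminator="type"), so validation dispatches on the "type" key. A minimal sketch, assuming pydantic v2 and that the text content part carries type="text" with a text field:

# TypeAdapter validates an Annotated discriminated union directly (pydantic v2).
from pydantic import TypeAdapter

part = TypeAdapter(OpenAIChatCompletionContentPartParam).validate_python(
    {"type": "text", "text": "hello"}
)  # expected: an OpenAIChatCompletionContentPartTextParam instance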
@@ -720,7 +722,9 @@ class OpenAIResponseFormatJSONObject(BaseModel):
 OpenAIResponseFormatParam = Annotated[
-    OpenAIResponseFormatText | OpenAIResponseFormatJSONSchema | OpenAIResponseFormatJSONObject,
+    OpenAIResponseFormatText
+    | OpenAIResponseFormatJSONSchema
+    | OpenAIResponseFormatJSONObject,
     Field(discriminator="type"),
 ]
 register_schema(OpenAIResponseFormatParam, name="OpenAIResponseFormatParam")
@@ -1049,8 +1053,16 @@ class InferenceProvider(Protocol):
     async def rerank(
         self,
         model: str,
-        query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
-        items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam],
+        query: (
+            str
+            | OpenAIChatCompletionContentPartTextParam
+            | OpenAIChatCompletionContentPartImageParam
+        ),
+        items: list[
+            str
+            | OpenAIChatCompletionContentPartTextParam
+            | OpenAIChatCompletionContentPartImageParam
+        ],
         max_num_results: int | None = None,
     ) -> RerankResponse:
         """Rerank a list of documents based on their relevance to a query.
@@ -1064,7 +1076,12 @@ class InferenceProvider(Protocol):
         raise NotImplementedError("Reranking is not implemented")
         return  # this is so mypy's safe-super rule will consider the method concrete

-    @webmethod(route="/openai/v1/completions", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
+    @webmethod(
+        route="/openai/v1/completions",
+        method="POST",
+        level=LLAMA_STACK_API_V1,
+        deprecated=True,
+    )
     @webmethod(route="/completions", method="POST", level=LLAMA_STACK_API_V1)
     async def openai_completion(
         self,
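The stacked @webmethod decorators here and in the following hunks keep the deprecated /openai/v1/... route pointing at the same handler as the new bare route. A hedged client-side check, where the base URL and payload shape are assumptions and the paths are taken verbatim from the decorators:

import httpx

BASE = "http://localhost:8321"  # assumed local deployment
body = {"model": "example-model", "prompt": "Hello"}  # illustrative payload

r_old = httpx.post(f"{BASE}/openai/v1/completions", json=body)  # deprecated route
r_new = httpx.post(f"{BASE}/completions", json=body)  # replacement route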
@@ -1116,7 +1133,12 @@
         """
         ...

-    @webmethod(route="/openai/v1/chat/completions", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
+    @webmethod(
+        route="/openai/v1/chat/completions",
+        method="POST",
+        level=LLAMA_STACK_API_V1,
+        deprecated=True,
+    )
     @webmethod(route="/chat/completions", method="POST", level=LLAMA_STACK_API_V1)
     async def openai_chat_completion(
         self,
@@ -1173,7 +1195,12 @@
         """
         ...

-    @webmethod(route="/openai/v1/embeddings", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
+    @webmethod(
+        route="/openai/v1/embeddings",
+        method="POST",
+        level=LLAMA_STACK_API_V1,
+        deprecated=True,
+    )
     @webmethod(route="/embeddings", method="POST", level=LLAMA_STACK_API_V1)
     async def openai_embeddings(
         self,
@@ -1203,7 +1230,12 @@ class Inference(InferenceProvider):
     - Embedding models: these models generate embeddings to be used for semantic search.
     """

-    @webmethod(route="/openai/v1/chat/completions", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
+    @webmethod(
+        route="/openai/v1/chat/completions",
+        method="GET",
+        level=LLAMA_STACK_API_V1,
+        deprecated=True,
+    )
     @webmethod(route="/chat/completions", method="GET", level=LLAMA_STACK_API_V1)
     async def list_chat_completions(
         self,
@@ -1223,10 +1255,19 @@
         raise NotImplementedError("List chat completions is not implemented")

     @webmethod(
-        route="/openai/v1/chat/completions/{completion_id}", method="GET", level=LLAMA_STACK_API_V1, deprecated=True
+        route="/openai/v1/chat/completions/{completion_id}",
+        method="GET",
+        level=LLAMA_STACK_API_V1,
+        deprecated=True,
     )
-    @webmethod(route="/chat/completions/{completion_id}", method="GET", level=LLAMA_STACK_API_V1)
-    async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
+    @webmethod(
+        route="/chat/completions/{completion_id}",
+        method="GET",
+        level=LLAMA_STACK_API_V1,
+    )
+    async def get_chat_completion(
+        self, completion_id: str
+    ) -> OpenAICompletionWithInputMessages:
         """Describe a chat completion by its ID.

         :param completion_id: ID of the chat completion.
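A hedged call-site sketch for the reflowed get_chat_completion; client and the completion id are illustrative:

# Returns the stored chat completion along with its input messages.
completion = await client.get_chat_completion(completion_id="chatcmpl-123")
assert isinstance(completion, OpenAICompletionWithInputMessages)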