mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
Code refactoring and removing dead code
This commit is contained in:
parent
ef0736527d
commit
f6080040da
6 changed files with 302 additions and 137 deletions
|
@ -6,16 +6,7 @@
|
|||
|
||||
from collections.abc import AsyncIterator
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
Annotated,
|
||||
Any,
|
||||
Literal,
|
||||
Protocol,
|
||||
runtime_checkable,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from typing_extensions import TypedDict
|
||||
from typing import Annotated, Any, Literal, Protocol, runtime_checkable
|
||||
|
||||
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent
|
||||
from llama_stack.apis.common.responses import Order
|
||||
|
@ -32,6 +23,9 @@ from llama_stack.models.llama.datatypes import (
|
|||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
register_schema(ToolCall)
|
||||
register_schema(ToolDefinition)
|
||||
|
||||
|
@ -357,32 +351,32 @@ class CompletionRequest(BaseModel):
|
|||
logprobs: LogProbConfig | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class CompletionResponse(MetricResponseMixin):
|
||||
"""Response from a completion request.
|
||||
# @json_schema_type
|
||||
# class CompletionResponse(MetricResponseMixin):
|
||||
# """Response from a completion request.
|
||||
|
||||
:param content: The generated completion text
|
||||
:param stop_reason: Reason why generation stopped
|
||||
:param logprobs: Optional log probabilities for generated tokens
|
||||
"""
|
||||
# :param content: The generated completion text
|
||||
# :param stop_reason: Reason why generation stopped
|
||||
# :param logprobs: Optional log probabilities for generated tokens
|
||||
# """
|
||||
|
||||
content: str
|
||||
stop_reason: StopReason
|
||||
logprobs: list[TokenLogProbs] | None = None
|
||||
# content: str
|
||||
# stop_reason: StopReason
|
||||
# logprobs: list[TokenLogProbs] | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class CompletionResponseStreamChunk(MetricResponseMixin):
|
||||
"""A chunk of a streamed completion response.
|
||||
# @json_schema_type
|
||||
# class CompletionResponseStreamChunk(MetricResponseMixin):
|
||||
# """A chunk of a streamed completion response.
|
||||
|
||||
:param delta: New content generated since last chunk. This can be one or more tokens.
|
||||
:param stop_reason: Optional reason why generation stopped, if complete
|
||||
:param logprobs: Optional log probabilities for generated tokens
|
||||
"""
|
||||
# :param delta: New content generated since last chunk. This can be one or more tokens.
|
||||
# :param stop_reason: Optional reason why generation stopped, if complete
|
||||
# :param logprobs: Optional log probabilities for generated tokens
|
||||
# """
|
||||
|
||||
delta: str
|
||||
stop_reason: StopReason | None = None
|
||||
logprobs: list[TokenLogProbs] | None = None
|
||||
# delta: str
|
||||
# stop_reason: StopReason | None = None
|
||||
# logprobs: list[TokenLogProbs] | None = None
|
||||
|
||||
|
||||
class SystemMessageBehavior(Enum):
|
||||
|
@ -415,7 +409,9 @@ class ToolConfig(BaseModel):
|
|||
|
||||
tool_choice: ToolChoice | str | None = Field(default=ToolChoice.auto)
|
||||
tool_prompt_format: ToolPromptFormat | None = Field(default=None)
|
||||
system_message_behavior: SystemMessageBehavior | None = Field(default=SystemMessageBehavior.append)
|
||||
system_message_behavior: SystemMessageBehavior | None = Field(
|
||||
default=SystemMessageBehavior.append
|
||||
)
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
if isinstance(self.tool_choice, str):
|
||||
|
@ -544,15 +540,21 @@ class OpenAIFile(BaseModel):
|
|||
|
||||
|
||||
OpenAIChatCompletionContentPartParam = Annotated[
|
||||
OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam | OpenAIFile,
|
||||
OpenAIChatCompletionContentPartTextParam
|
||||
| OpenAIChatCompletionContentPartImageParam
|
||||
| OpenAIFile,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(OpenAIChatCompletionContentPartParam, name="OpenAIChatCompletionContentPartParam")
|
||||
register_schema(
|
||||
OpenAIChatCompletionContentPartParam, name="OpenAIChatCompletionContentPartParam"
|
||||
)
|
||||
|
||||
|
||||
OpenAIChatCompletionMessageContent = str | list[OpenAIChatCompletionContentPartParam]
|
||||
|
||||
OpenAIChatCompletionTextOnlyMessageContent = str | list[OpenAIChatCompletionContentPartTextParam]
|
||||
OpenAIChatCompletionTextOnlyMessageContent = (
|
||||
str | list[OpenAIChatCompletionContentPartTextParam]
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -720,7 +722,9 @@ class OpenAIResponseFormatJSONObject(BaseModel):
|
|||
|
||||
|
||||
OpenAIResponseFormatParam = Annotated[
|
||||
OpenAIResponseFormatText | OpenAIResponseFormatJSONSchema | OpenAIResponseFormatJSONObject,
|
||||
OpenAIResponseFormatText
|
||||
| OpenAIResponseFormatJSONSchema
|
||||
| OpenAIResponseFormatJSONObject,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(OpenAIResponseFormatParam, name="OpenAIResponseFormatParam")
|
||||
|
@ -1049,8 +1053,16 @@ class InferenceProvider(Protocol):
|
|||
async def rerank(
|
||||
self,
|
||||
model: str,
|
||||
query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
|
||||
items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam],
|
||||
query: (
|
||||
str
|
||||
| OpenAIChatCompletionContentPartTextParam
|
||||
| OpenAIChatCompletionContentPartImageParam
|
||||
),
|
||||
items: list[
|
||||
str
|
||||
| OpenAIChatCompletionContentPartTextParam
|
||||
| OpenAIChatCompletionContentPartImageParam
|
||||
],
|
||||
max_num_results: int | None = None,
|
||||
) -> RerankResponse:
|
||||
"""Rerank a list of documents based on their relevance to a query.
|
||||
|
@ -1064,7 +1076,12 @@ class InferenceProvider(Protocol):
|
|||
raise NotImplementedError("Reranking is not implemented")
|
||||
return # this is so mypy's safe-super rule will consider the method concrete
|
||||
|
||||
@webmethod(route="/openai/v1/completions", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(
|
||||
route="/openai/v1/completions",
|
||||
method="POST",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
deprecated=True,
|
||||
)
|
||||
@webmethod(route="/completions", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def openai_completion(
|
||||
self,
|
||||
|
@ -1116,7 +1133,12 @@ class InferenceProvider(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/chat/completions", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(
|
||||
route="/openai/v1/chat/completions",
|
||||
method="POST",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
deprecated=True,
|
||||
)
|
||||
@webmethod(route="/chat/completions", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
|
@ -1173,7 +1195,12 @@ class InferenceProvider(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/embeddings", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(
|
||||
route="/openai/v1/embeddings",
|
||||
method="POST",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
deprecated=True,
|
||||
)
|
||||
@webmethod(route="/embeddings", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
|
@ -1203,7 +1230,12 @@ class Inference(InferenceProvider):
|
|||
- Embedding models: these models generate embeddings to be used for semantic search.
|
||||
"""
|
||||
|
||||
@webmethod(route="/openai/v1/chat/completions", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(
|
||||
route="/openai/v1/chat/completions",
|
||||
method="GET",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
deprecated=True,
|
||||
)
|
||||
@webmethod(route="/chat/completions", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_chat_completions(
|
||||
self,
|
||||
|
@ -1223,10 +1255,19 @@ class Inference(InferenceProvider):
|
|||
raise NotImplementedError("List chat completions is not implemented")
|
||||
|
||||
@webmethod(
|
||||
route="/openai/v1/chat/completions/{completion_id}", method="GET", level=LLAMA_STACK_API_V1, deprecated=True
|
||||
route="/openai/v1/chat/completions/{completion_id}",
|
||||
method="GET",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
deprecated=True,
|
||||
)
|
||||
@webmethod(route="/chat/completions/{completion_id}", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
|
||||
@webmethod(
|
||||
route="/chat/completions/{completion_id}",
|
||||
method="GET",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
async def get_chat_completion(
|
||||
self, completion_id: str
|
||||
) -> OpenAICompletionWithInputMessages:
|
||||
"""Describe a chat completion by its ID.
|
||||
|
||||
:param completion_id: ID of the chat completion.
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue