Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-10-04 04:04:14 +00:00)

Commit f6080040da: Code refactoring and removing dead code
Parent: ef0736527d

6 changed files with 302 additions and 137 deletions
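Most of the diff below applies one mechanical pattern: statements that previously sat on a single long line are re-wrapped across several shorter lines, while the dead CompletionResponse / CompletionResponseStreamChunk path is commented out or dropped. As a hedged, self-contained sketch of that re-wrapping pattern (the function and its arguments here are invented for illustration and are not part of the commit):

# Illustration only: a toy function, not llama-stack code.
def describe(route: str, method: str, level: str, deprecated: bool) -> str:
    return f"{method} {route} (level={level}, deprecated={deprecated})"

# Before the refactor, call sites kept every argument on one long line:
before = describe(route="/openai/v1/completions", method="POST", level="v1", deprecated=True)

# After the refactor, the same call is wrapped with one argument per line
# and a trailing comma, keeping each line under the shorter length limit:
after = describe(
    route="/openai/v1/completions",
    method="POST",
    level="v1",
    deprecated=True,
)

assert before == after  # the change is purely cosmetic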
@@ -6,16 +6,7 @@
 
 from collections.abc import AsyncIterator
 from enum import Enum
-from typing import (
-    Annotated,
-    Any,
-    Literal,
-    Protocol,
-    runtime_checkable,
-)
-
-from pydantic import BaseModel, Field, field_validator
-from typing_extensions import TypedDict
+from typing import Annotated, Any, Literal, Protocol, runtime_checkable
 
 from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent
 from llama_stack.apis.common.responses import Order
@@ -32,6 +23,9 @@ from llama_stack.models.llama.datatypes import (
 from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
 from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
 
+from pydantic import BaseModel, Field, field_validator
+from typing_extensions import TypedDict
+
 register_schema(ToolCall)
 register_schema(ToolDefinition)
 
@@ -357,32 +351,32 @@ class CompletionRequest(BaseModel):
     logprobs: LogProbConfig | None = None
 
 
-@json_schema_type
-class CompletionResponse(MetricResponseMixin):
-    """Response from a completion request.
+# @json_schema_type
+# class CompletionResponse(MetricResponseMixin):
+#     """Response from a completion request.
 
-    :param content: The generated completion text
-    :param stop_reason: Reason why generation stopped
-    :param logprobs: Optional log probabilities for generated tokens
-    """
+#     :param content: The generated completion text
+#     :param stop_reason: Reason why generation stopped
+#     :param logprobs: Optional log probabilities for generated tokens
+#     """
 
-    content: str
-    stop_reason: StopReason
-    logprobs: list[TokenLogProbs] | None = None
+#     content: str
+#     stop_reason: StopReason
+#     logprobs: list[TokenLogProbs] | None = None
 
 
-@json_schema_type
-class CompletionResponseStreamChunk(MetricResponseMixin):
-    """A chunk of a streamed completion response.
+# @json_schema_type
+# class CompletionResponseStreamChunk(MetricResponseMixin):
+#     """A chunk of a streamed completion response.
 
-    :param delta: New content generated since last chunk. This can be one or more tokens.
-    :param stop_reason: Optional reason why generation stopped, if complete
-    :param logprobs: Optional log probabilities for generated tokens
-    """
+#     :param delta: New content generated since last chunk. This can be one or more tokens.
+#     :param stop_reason: Optional reason why generation stopped, if complete
+#     :param logprobs: Optional log probabilities for generated tokens
+#     """
 
-    delta: str
-    stop_reason: StopReason | None = None
-    logprobs: list[TokenLogProbs] | None = None
+#     delta: str
+#     stop_reason: StopReason | None = None
+#     logprobs: list[TokenLogProbs] | None = None
 
 
 class SystemMessageBehavior(Enum):
@@ -415,7 +409,9 @@ class ToolConfig(BaseModel):
 
     tool_choice: ToolChoice | str | None = Field(default=ToolChoice.auto)
    tool_prompt_format: ToolPromptFormat | None = Field(default=None)
-    system_message_behavior: SystemMessageBehavior | None = Field(default=SystemMessageBehavior.append)
+    system_message_behavior: SystemMessageBehavior | None = Field(
+        default=SystemMessageBehavior.append
+    )
 
     def model_post_init(self, __context: Any) -> None:
         if isinstance(self.tool_choice, str):
@@ -544,15 +540,21 @@ class OpenAIFile(BaseModel):
 
 
 OpenAIChatCompletionContentPartParam = Annotated[
-    OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam | OpenAIFile,
+    OpenAIChatCompletionContentPartTextParam
+    | OpenAIChatCompletionContentPartImageParam
+    | OpenAIFile,
     Field(discriminator="type"),
 ]
-register_schema(OpenAIChatCompletionContentPartParam, name="OpenAIChatCompletionContentPartParam")
+register_schema(
+    OpenAIChatCompletionContentPartParam, name="OpenAIChatCompletionContentPartParam"
+)
 
 
 OpenAIChatCompletionMessageContent = str | list[OpenAIChatCompletionContentPartParam]
 
-OpenAIChatCompletionTextOnlyMessageContent = str | list[OpenAIChatCompletionContentPartTextParam]
+OpenAIChatCompletionTextOnlyMessageContent = (
+    str | list[OpenAIChatCompletionContentPartTextParam]
+)
 
 
 @json_schema_type
@@ -720,7 +722,9 @@ class OpenAIResponseFormatJSONObject(BaseModel):
 
 
 OpenAIResponseFormatParam = Annotated[
-    OpenAIResponseFormatText | OpenAIResponseFormatJSONSchema | OpenAIResponseFormatJSONObject,
+    OpenAIResponseFormatText
+    | OpenAIResponseFormatJSONSchema
+    | OpenAIResponseFormatJSONObject,
     Field(discriminator="type"),
 ]
 register_schema(OpenAIResponseFormatParam, name="OpenAIResponseFormatParam")
@@ -1049,8 +1053,16 @@ class InferenceProvider(Protocol):
     async def rerank(
         self,
         model: str,
-        query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
-        items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam],
+        query: (
+            str
+            | OpenAIChatCompletionContentPartTextParam
+            | OpenAIChatCompletionContentPartImageParam
+        ),
+        items: list[
+            str
+            | OpenAIChatCompletionContentPartTextParam
+            | OpenAIChatCompletionContentPartImageParam
+        ],
         max_num_results: int | None = None,
     ) -> RerankResponse:
         """Rerank a list of documents based on their relevance to a query.
@@ -1064,7 +1076,12 @@ class InferenceProvider(Protocol):
         raise NotImplementedError("Reranking is not implemented")
         return  # this is so mypy's safe-super rule will consider the method concrete
 
-    @webmethod(route="/openai/v1/completions", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
+    @webmethod(
+        route="/openai/v1/completions",
+        method="POST",
+        level=LLAMA_STACK_API_V1,
+        deprecated=True,
+    )
     @webmethod(route="/completions", method="POST", level=LLAMA_STACK_API_V1)
     async def openai_completion(
         self,
@@ -1116,7 +1133,12 @@ class InferenceProvider(Protocol):
         """
         ...
 
-    @webmethod(route="/openai/v1/chat/completions", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
+    @webmethod(
+        route="/openai/v1/chat/completions",
+        method="POST",
+        level=LLAMA_STACK_API_V1,
+        deprecated=True,
+    )
     @webmethod(route="/chat/completions", method="POST", level=LLAMA_STACK_API_V1)
     async def openai_chat_completion(
         self,
@@ -1173,7 +1195,12 @@ class InferenceProvider(Protocol):
         """
         ...
 
-    @webmethod(route="/openai/v1/embeddings", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
+    @webmethod(
+        route="/openai/v1/embeddings",
+        method="POST",
+        level=LLAMA_STACK_API_V1,
+        deprecated=True,
+    )
     @webmethod(route="/embeddings", method="POST", level=LLAMA_STACK_API_V1)
     async def openai_embeddings(
         self,
@@ -1203,7 +1230,12 @@ class Inference(InferenceProvider):
     - Embedding models: these models generate embeddings to be used for semantic search.
     """
 
-    @webmethod(route="/openai/v1/chat/completions", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
+    @webmethod(
+        route="/openai/v1/chat/completions",
+        method="GET",
+        level=LLAMA_STACK_API_V1,
+        deprecated=True,
+    )
     @webmethod(route="/chat/completions", method="GET", level=LLAMA_STACK_API_V1)
     async def list_chat_completions(
         self,
@@ -1223,10 +1255,19 @@ class Inference(InferenceProvider):
         raise NotImplementedError("List chat completions is not implemented")
 
     @webmethod(
-        route="/openai/v1/chat/completions/{completion_id}", method="GET", level=LLAMA_STACK_API_V1, deprecated=True
+        route="/openai/v1/chat/completions/{completion_id}",
+        method="GET",
+        level=LLAMA_STACK_API_V1,
+        deprecated=True,
     )
-    @webmethod(route="/chat/completions/{completion_id}", method="GET", level=LLAMA_STACK_API_V1)
-    async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
+    @webmethod(
+        route="/chat/completions/{completion_id}",
+        method="GET",
+        level=LLAMA_STACK_API_V1,
+    )
+    async def get_chat_completion(
+        self, completion_id: str
+    ) -> OpenAICompletionWithInputMessages:
         """Describe a chat completion by its ID.
 
         :param completion_id: ID of the chat completion.
@@ -7,24 +7,16 @@
 import asyncio
 import time
 from collections.abc import AsyncGenerator, AsyncIterator
-from datetime import UTC, datetime
+from datetime import datetime, UTC
 from typing import Annotated, Any
 
-from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam
-from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToolParam
-from pydantic import Field, TypeAdapter
-
-from llama_stack.apis.common.content_types import (
-    InterleavedContent,
-)
+from llama_stack.apis.common.content_types import InterleavedContent
 from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError
 from llama_stack.apis.inference import (
     ChatCompletionResponse,
     ChatCompletionResponseEventType,
     ChatCompletionResponseStreamChunk,
     CompletionMessage,
-    CompletionResponse,
-    CompletionResponseStreamChunk,
     Inference,
     ListOpenAIChatCompletionResponse,
     LogProbConfig,
@@ -57,7 +49,16 @@ from llama_stack.models.llama.llama3.chat_format import ChatFormat
 from llama_stack.models.llama.llama3.tokenizer import Tokenizer
 from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
 from llama_stack.providers.utils.inference.inference_store import InferenceStore
-from llama_stack.providers.utils.telemetry.tracing import enqueue_event, get_current_span
+from llama_stack.providers.utils.telemetry.tracing import (
+    enqueue_event,
+    get_current_span,
+)
 
+from openai.types.chat import (
+    ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam,
+    ChatCompletionToolParam as OpenAIChatCompletionToolParam,
+)
+from pydantic import Field, TypeAdapter
+
 logger = get_logger(name=__name__, category="core::routers")
 
@@ -101,7 +102,9 @@ class InferenceRouter(Inference):
         logger.debug(
             f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}",
         )
-        await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
+        await self.routing_table.register_model(
+            model_id, provider_model_id, provider_id, metadata, model_type
+        )
 
     def _construct_metrics(
         self,
@@ -156,11 +159,16 @@ class InferenceRouter(Inference):
         total_tokens: int,
         model: Model,
     ) -> list[MetricInResponse]:
-        metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
+        metrics = self._construct_metrics(
+            prompt_tokens, completion_tokens, total_tokens, model
+        )
         if self.telemetry:
             for metric in metrics:
                 enqueue_event(metric)
-        return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]
+        return [
+            MetricInResponse(metric=metric.metric, value=metric.value)
+            for metric in metrics
+        ]
 
     async def _count_tokens(
         self,
@@ -207,8 +215,13 @@ class InferenceRouter(Inference):
         if tool_config:
             if tool_choice and tool_choice != tool_config.tool_choice:
                 raise ValueError("tool_choice and tool_config.tool_choice must match")
-            if tool_prompt_format and tool_prompt_format != tool_config.tool_prompt_format:
-                raise ValueError("tool_prompt_format and tool_config.tool_prompt_format must match")
+            if (
+                tool_prompt_format
+                and tool_prompt_format != tool_config.tool_prompt_format
+            ):
+                raise ValueError(
+                    "tool_prompt_format and tool_config.tool_prompt_format must match"
+                )
         else:
             params = {}
             if tool_choice:
@@ -226,9 +239,14 @@ class InferenceRouter(Inference):
             pass
         else:
             # verify tool_choice is one of the tools
-            tool_names = [t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value for t in tools]
+            tool_names = [
+                t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value
+                for t in tools
+            ]
             if tool_config.tool_choice not in tool_names:
-                raise ValueError(f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}")
+                raise ValueError(
+                    f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}"
+                )
 
         params = dict(
             model_id=model_id,
@@ -243,7 +261,9 @@ class InferenceRouter(Inference):
             tool_config=tool_config,
         )
         provider = await self.routing_table.get_provider_impl(model_id)
-        prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format)
+        prompt_tokens = await self._count_tokens(
+            messages, tool_config.tool_prompt_format
+        )
 
         if stream:
             response_stream = await provider.chat_completion(**params)
@@ -263,7 +283,9 @@ class InferenceRouter(Inference):
         )
         # these metrics will show up in the client response.
         response.metrics = (
-            metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics
+            metrics
+            if not hasattr(response, "metrics") or response.metrics is None
+            else response.metrics + metrics
         )
         return response
 
@@ -336,7 +358,9 @@ class InferenceRouter(Inference):
 
         # these metrics will show up in the client response.
         response.metrics = (
-            metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics
+            metrics
+            if not hasattr(response, "metrics") or response.metrics is None
+            else response.metrics + metrics
        )
         return response
 
@@ -374,9 +398,13 @@ class InferenceRouter(Inference):
         # Use the OpenAI client for a bit of extra input validation without
         # exposing the OpenAI client itself as part of our API surface
         if tool_choice:
-            TypeAdapter(OpenAIChatCompletionToolChoiceOptionParam).validate_python(tool_choice)
+            TypeAdapter(OpenAIChatCompletionToolChoiceOptionParam).validate_python(
+                tool_choice
+            )
             if tools is None:
-                raise ValueError("'tool_choice' is only allowed when 'tools' is also provided")
+                raise ValueError(
+                    "'tool_choice' is only allowed when 'tools' is also provided"
+                )
         if tools:
             for tool in tools:
                 TypeAdapter(OpenAIChatCompletionToolParam).validate_python(tool)
@@ -441,7 +469,9 @@ class InferenceRouter(Inference):
                 enqueue_event(metric)
         # these metrics will show up in the client response.
         response.metrics = (
-            metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics
+            metrics
+            if not hasattr(response, "metrics") or response.metrics is None
+            else response.metrics + metrics
         )
         return response
 
@@ -477,19 +507,31 @@ class InferenceRouter(Inference):
     ) -> ListOpenAIChatCompletionResponse:
         if self.store:
             return await self.store.list_chat_completions(after, limit, model, order)
-        raise NotImplementedError("List chat completions is not supported: inference store is not configured.")
+        raise NotImplementedError(
+            "List chat completions is not supported: inference store is not configured."
+        )
 
-    async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
+    async def get_chat_completion(
+        self, completion_id: str
+    ) -> OpenAICompletionWithInputMessages:
         if self.store:
             return await self.store.get_chat_completion(completion_id)
-        raise NotImplementedError("Get chat completion is not supported: inference store is not configured.")
+        raise NotImplementedError(
+            "Get chat completion is not supported: inference store is not configured."
+        )
 
-    async def _nonstream_openai_chat_completion(self, provider: Inference, params: dict) -> OpenAIChatCompletion:
+    async def _nonstream_openai_chat_completion(
+        self, provider: Inference, params: dict
+    ) -> OpenAIChatCompletion:
         response = await provider.openai_chat_completion(**params)
         for choice in response.choices:
             # some providers return an empty list for no tool calls in non-streaming responses
             # but the OpenAI API returns None. So, set tool_calls to None if it's empty
-            if choice.message and choice.message.tool_calls is not None and len(choice.message.tool_calls) == 0:
+            if (
+                choice.message
+                and choice.message.tool_calls is not None
+                and len(choice.message.tool_calls) == 0
+            ):
                 choice.message.tool_calls = None
         return response
 
@@ -509,7 +551,9 @@ class InferenceRouter(Inference):
                     message=f"Health check timed out after {timeout} seconds",
                 )
             except NotImplementedError:
-                health_statuses[provider_id] = HealthResponse(status=HealthStatus.NOT_IMPLEMENTED)
+                health_statuses[provider_id] = HealthResponse(
+                    status=HealthStatus.NOT_IMPLEMENTED
+                )
             except Exception as e:
                 health_statuses[provider_id] = HealthResponse(
                     status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}"
@@ -522,7 +566,7 @@ class InferenceRouter(Inference):
         prompt_tokens,
         model,
         tool_prompt_format: ToolPromptFormat | None = None,
-    ) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None] | AsyncGenerator[CompletionResponseStreamChunk, None]:
+    ) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
         completion_text = ""
         async for chunk in response:
             complete = False
@@ -544,7 +588,11 @@ class InferenceRouter(Inference):
             else:
                 if hasattr(chunk, "delta"):
                     completion_text += chunk.delta
-                if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
+                if (
+                    hasattr(chunk, "stop_reason")
+                    and chunk.stop_reason
+                    and self.telemetry
+                ):
                     complete = True
                     completion_tokens = await self._count_tokens(completion_text)
                     # if we are done receiving tokens
@@ -569,9 +617,14 @@ class InferenceRouter(Inference):
 
                 # Return metrics in response
                 async_metrics = [
-                    MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics
+                    MetricInResponse(metric=metric.metric, value=metric.value)
+                    for metric in completion_metrics
                 ]
-                chunk.metrics = async_metrics if chunk.metrics is None else chunk.metrics + async_metrics
+                chunk.metrics = (
+                    async_metrics
+                    if chunk.metrics is None
+                    else chunk.metrics + async_metrics
+                )
             else:
                 # Fallback if no telemetry
                 completion_metrics = self._construct_metrics(
@@ -581,14 +634,19 @@ class InferenceRouter(Inference):
                     model,
                 )
                 async_metrics = [
-                    MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics
+                    MetricInResponse(metric=metric.metric, value=metric.value)
+                    for metric in completion_metrics
                 ]
-                chunk.metrics = async_metrics if chunk.metrics is None else chunk.metrics + async_metrics
+                chunk.metrics = (
+                    async_metrics
+                    if chunk.metrics is None
+                    else chunk.metrics + async_metrics
+                )
             yield chunk
 
     async def count_tokens_and_compute_metrics(
         self,
-        response: ChatCompletionResponse | CompletionResponse,
+        response: ChatCompletionResponse,
         prompt_tokens,
         model,
         tool_prompt_format: ToolPromptFormat | None = None,
@@ -597,7 +655,9 @@ class InferenceRouter(Inference):
             content = [response.completion_message]
         else:
            content = response.content
-        completion_tokens = await self._count_tokens(messages=content, tool_prompt_format=tool_prompt_format)
+        completion_tokens = await self._count_tokens(
+            messages=content, tool_prompt_format=tool_prompt_format
+        )
         total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
 
         # Create a separate span for completion metrics
@@ -610,11 +670,17 @@ class InferenceRouter(Inference):
                 model=model,
             )
             for metric in completion_metrics:
-                if metric.metric in ["completion_tokens", "total_tokens"]:  # Only log completion and total tokens
+                if metric.metric in [
+                    "completion_tokens",
+                    "total_tokens",
+                ]:  # Only log completion and total tokens
                     enqueue_event(metric)
 
             # Return metrics in response
-            return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics]
+            return [
+                MetricInResponse(metric=metric.metric, value=metric.value)
+                for metric in completion_metrics
+            ]
 
         # Fallback if no telemetry
         metrics = self._construct_metrics(
@@ -623,7 +689,10 @@ class InferenceRouter(Inference):
             total_tokens,
             model,
         )
-        return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]
+        return [
+            MetricInResponse(metric=metric.metric, value=metric.value)
+            for metric in metrics
+        ]
 
     async def stream_tokens_and_compute_metrics_openai_chat(
         self,
@@ -664,33 +733,48 @@ class InferenceRouter(Inference):
                     if choice_delta.delta:
                         delta = choice_delta.delta
                         if delta.content:
-                            current_choice_data["content_parts"].append(delta.content)
+                            current_choice_data["content_parts"].append(
+                                delta.content
+                            )
                         if delta.tool_calls:
                             for tool_call_delta in delta.tool_calls:
                                 tc_idx = tool_call_delta.index
-                                if tc_idx not in current_choice_data["tool_calls_builder"]:
-                                    current_choice_data["tool_calls_builder"][tc_idx] = {
+                                if (
+                                    tc_idx
+                                    not in current_choice_data["tool_calls_builder"]
+                                ):
+                                    current_choice_data["tool_calls_builder"][
+                                        tc_idx
+                                    ] = {
                                         "id": None,
                                         "type": "function",
                                         "function_name_parts": [],
                                         "function_arguments_parts": [],
                                     }
-                                builder = current_choice_data["tool_calls_builder"][tc_idx]
+                                builder = current_choice_data["tool_calls_builder"][
+                                    tc_idx
+                                ]
                                 if tool_call_delta.id:
                                     builder["id"] = tool_call_delta.id
                                 if tool_call_delta.type:
                                     builder["type"] = tool_call_delta.type
                                 if tool_call_delta.function:
                                     if tool_call_delta.function.name:
-                                        builder["function_name_parts"].append(tool_call_delta.function.name)
+                                        builder["function_name_parts"].append(
+                                            tool_call_delta.function.name
+                                        )
                                     if tool_call_delta.function.arguments:
                                         builder["function_arguments_parts"].append(
                                             tool_call_delta.function.arguments
                                         )
                     if choice_delta.finish_reason:
-                        current_choice_data["finish_reason"] = choice_delta.finish_reason
+                        current_choice_data["finish_reason"] = (
+                            choice_delta.finish_reason
+                        )
                    if choice_delta.logprobs and choice_delta.logprobs.content:
-                        current_choice_data["logprobs_content_parts"].extend(choice_delta.logprobs.content)
+                        current_choice_data["logprobs_content_parts"].extend(
+                            choice_delta.logprobs.content
+                        )
 
             # Compute metrics on final chunk
             if chunk.choices and chunk.choices[0].finish_reason:
@@ -720,8 +804,12 @@ class InferenceRouter(Inference):
             if choice_data["tool_calls_builder"]:
                 for tc_build_data in choice_data["tool_calls_builder"].values():
                     if tc_build_data["id"]:
-                        func_name = "".join(tc_build_data["function_name_parts"])
-                        func_args = "".join(tc_build_data["function_arguments_parts"])
+                        func_name = "".join(
+                            tc_build_data["function_name_parts"]
+                        )
+                        func_args = "".join(
+                            tc_build_data["function_arguments_parts"]
+                        )
                         assembled_tool_calls.append(
                             OpenAIChatCompletionToolCall(
                                 id=tc_build_data["id"],
@@ -734,10 +822,16 @@ class InferenceRouter(Inference):
             message = OpenAIAssistantMessageParam(
                 role="assistant",
                 content=content_str if content_str else None,
-                tool_calls=assembled_tool_calls if assembled_tool_calls else None,
+                tool_calls=(
+                    assembled_tool_calls if assembled_tool_calls else None
+                ),
             )
             logprobs_content = choice_data["logprobs_content_parts"]
-            final_logprobs = OpenAIChoiceLogprobs(content=logprobs_content) if logprobs_content else None
+            final_logprobs = (
+                OpenAIChoiceLogprobs(content=logprobs_content)
+                if logprobs_content
+                else None
+            )
 
             assembled_choices.append(
                 OpenAIChoice(
@@ -756,4 +850,6 @@ class InferenceRouter(Inference):
             object="chat.completion",
         )
         logger.debug(f"InferenceRouter.completion_response: {final_response}")
-        asyncio.create_task(self.store.store_chat_completion(final_response, messages))
+        asyncio.create_task(
+            self.store.store_chat_completion(final_response, messages)
+        )
@@ -25,9 +25,6 @@ from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.embedding_mixin import (
     SentenceTransformerEmbeddingMixin,
 )
-from llama_stack.providers.utils.inference.openai_compat import (
-    OpenAIChatCompletionToLlamaStackMixin,
-)
 
 from .config import SentenceTransformersInferenceConfig
 
@@ -35,7 +32,6 @@ log = get_logger(name=__name__, category="inference")
 
 
 class SentenceTransformersInferenceImpl(
-    OpenAIChatCompletionToLlamaStackMixin,
     SentenceTransformerEmbeddingMixin,
     InferenceProvider,
     ModelsProtocolPrivate,
@@ -114,4 +110,6 @@ class SentenceTransformersInferenceImpl(
         # for fill-in-the-middle type completion
         suffix: str | None = None,
     ) -> OpenAICompletion:
-        raise NotImplementedError("OpenAI completion not supported by sentence transformers provider")
+        raise NotImplementedError(
+            "OpenAI completion not supported by sentence transformers provider"
+        )
@@ -11,8 +11,7 @@ from cerebras.cloud.sdk import AsyncCerebras
 
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
-    CompletionRequest,
-    CompletionResponse,
+    ChatCompletionResponse,
     Inference,
     LogProbConfig,
     Message,
@@ -25,9 +24,7 @@ from llama_stack.apis.inference import (
     ToolPromptFormat,
     TopKSamplingStrategy,
 )
-from llama_stack.providers.utils.inference.model_registry import (
-    ModelRegistryHelper,
-)
+from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,
     process_chat_completion_response,
@@ -36,7 +33,6 @@ from llama_stack.providers.utils.inference.openai_compat import (
 from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
-    completion_request_to_prompt,
 )
 
 from .config import CerebrasImplConfig
@@ -102,14 +98,18 @@ class CerebrasInferenceAdapter(
         else:
             return await self._nonstream_chat_completion(request)
 
-    async def _nonstream_chat_completion(self, request: CompletionRequest) -> CompletionResponse:
+    async def _nonstream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> ChatCompletionResponse:
         params = await self._get_params(request)
 
         r = await self._cerebras_client.completions.create(**params)
 
         return process_chat_completion_response(r, request)
 
-    async def _stream_chat_completion(self, request: CompletionRequest) -> AsyncGenerator:
+    async def _stream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> AsyncGenerator:
         params = await self._get_params(request)
 
         stream = await self._cerebras_client.completions.create(**params)
@@ -117,15 +117,17 @@ class CerebrasInferenceAdapter(
         async for chunk in process_chat_completion_stream_response(stream, request):
             yield chunk
 
-    async def _get_params(self, request: ChatCompletionRequest | CompletionRequest) -> dict:
-        if request.sampling_params and isinstance(request.sampling_params.strategy, TopKSamplingStrategy):
+    async def _get_params(self, request: ChatCompletionRequest) -> dict:
+        if request.sampling_params and isinstance(
+            request.sampling_params.strategy, TopKSamplingStrategy
+        ):
             raise ValueError("`top_k` not supported by Cerebras")
 
         prompt = ""
         if isinstance(request, ChatCompletionRequest):
-            prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model))
-        elif isinstance(request, CompletionRequest):
-            prompt = await completion_request_to_prompt(request)
+            prompt = await chat_completion_request_to_prompt(
+                request, self.get_llama_model(request.model)
+            )
         else:
             raise ValueError(f"Unknown request type {type(request)}")
 
@@ -10,11 +10,13 @@ from openai import OpenAI
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.apis.inference import OpenAIEmbeddingsResponse
 
-# from llama_stack.providers.datatypes import ModelsProtocolPrivate
-from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, build_hf_repo_model_entry
+from llama_stack.providers.utils.inference.model_registry import (
+    build_hf_repo_model_entry,
+    ModelRegistryHelper,
+)
 from llama_stack.providers.utils.inference.openai_compat import (
-    OpenAIChatCompletionToLlamaStackMixin,
     get_sampling_options,
+    OpenAIChatCompletionToLlamaStackMixin,
     process_chat_completion_response,
     process_chat_completion_stream_response,
 )
@@ -41,13 +43,12 @@ RUNPOD_SUPPORTED_MODELS = {
     "Llama3.2-3B": "meta-llama/Llama-3.2-3B",
 }
 
-SAFETY_MODELS_ENTRIES = []
 
 # Create MODEL_ENTRIES from RUNPOD_SUPPORTED_MODELS for compatibility with starter template
 MODEL_ENTRIES = [
     build_hf_repo_model_entry(provider_model_id, model_descriptor)
     for provider_model_id, model_descriptor in RUNPOD_SUPPORTED_MODELS.items()
-] + SAFETY_MODELS_ENTRIES
+]
 
 
 class RunpodInferenceAdapter(
@@ -56,7 +57,9 @@ class RunpodInferenceAdapter(
     OpenAIChatCompletionToLlamaStackMixin,
 ):
     def __init__(self, config: RunpodImplConfig) -> None:
-        ModelRegistryHelper.__init__(self, stack_to_provider_models_map=RUNPOD_SUPPORTED_MODELS)
+        ModelRegistryHelper.__init__(
+            self, stack_to_provider_models_map=RUNPOD_SUPPORTED_MODELS
+        )
         self.config = config
 
     async def initialize(self) -> None:
@@ -103,7 +106,9 @@ class RunpodInferenceAdapter(
         r = client.completions.create(**params)
         return process_chat_completion_response(r, request)
 
-    async def _stream_chat_completion(self, request: ChatCompletionRequest, client: OpenAI) -> AsyncGenerator:
+    async def _stream_chat_completion(
+        self, request: ChatCompletionRequest, client: OpenAI
+    ) -> AsyncGenerator:
         params = self._get_params(request)
 
         async def _to_async_generator():
@@ -9,12 +9,10 @@ from typing import Any
 
 from ibm_watsonx_ai.foundation_models import Model
 from ibm_watsonx_ai.metanames import GenTextParamsMetaNames as GenParams
-from openai import AsyncOpenAI
 
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
-    CompletionRequest,
     GreedySamplingStrategy,
     Inference,
     LogProbConfig,
@@ -48,6 +46,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     completion_request_to_prompt,
     request_has_media,
 )
+from openai import AsyncOpenAI
 
 from . import WatsonXConfig
 from .models import MODEL_ENTRIES
@@ -85,7 +84,9 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
         pass
 
     def _get_client(self, model_id) -> Model:
-        config_api_key = self._config.api_key.get_secret_value() if self._config.api_key else None
+        config_api_key = (
+            self._config.api_key.get_secret_value() if self._config.api_key else None
+        )
         config_url = self._config.url
         project_id = self._config.project_id
         credentials = {"url": config_url, "apikey": config_api_key}
@@ -132,14 +133,18 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
         else:
             return await self._nonstream_chat_completion(request)
 
-    async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
+    async def _nonstream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> ChatCompletionResponse:
         params = await self._get_params(request)
         r = self._get_client(request.model).generate(**params)
         choices = []
         if "results" in r:
             for result in r["results"]:
                 choice = OpenAICompatCompletionChoice(
-                    finish_reason=result["stop_reason"] if result["stop_reason"] else None,
+                    finish_reason=(
+                        result["stop_reason"] if result["stop_reason"] else None
+                    ),
                     text=result["generated_text"],
                 )
                 choices.append(choice)
@@ -148,7 +153,9 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
         )
         return process_chat_completion_response(response, request)
 
-    async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
+    async def _stream_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> AsyncGenerator:
         params = await self._get_params(request)
         model_id = request.model
 
@@ -168,28 +175,44 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
         async for chunk in process_chat_completion_stream_response(stream, request):
             yield chunk
 
-    async def _get_params(self, request: ChatCompletionRequest | CompletionRequest) -> dict:
+    async def _get_params(self, request: ChatCompletionRequest) -> dict:
         input_dict = {"params": {}}
         media_present = request_has_media(request)
         llama_model = self.get_llama_model(request.model)
         if isinstance(request, ChatCompletionRequest):
-            input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
+            input_dict["prompt"] = await chat_completion_request_to_prompt(
+                request, llama_model
+            )
         else:
-            assert not media_present, "Together does not support media for Completion requests"
+            assert (
+                not media_present
+            ), "Together does not support media for Completion requests"
             input_dict["prompt"] = await completion_request_to_prompt(request)
         if request.sampling_params:
             if request.sampling_params.strategy:
-                input_dict["params"][GenParams.DECODING_METHOD] = request.sampling_params.strategy.type
+                input_dict["params"][
+                    GenParams.DECODING_METHOD
+                ] = request.sampling_params.strategy.type
             if request.sampling_params.max_tokens:
-                input_dict["params"][GenParams.MAX_NEW_TOKENS] = request.sampling_params.max_tokens
+                input_dict["params"][
+                    GenParams.MAX_NEW_TOKENS
+                ] = request.sampling_params.max_tokens
             if request.sampling_params.repetition_penalty:
-                input_dict["params"][GenParams.REPETITION_PENALTY] = request.sampling_params.repetition_penalty
+                input_dict["params"][
+                    GenParams.REPETITION_PENALTY
+                ] = request.sampling_params.repetition_penalty
 
             if isinstance(request.sampling_params.strategy, TopPSamplingStrategy):
-                input_dict["params"][GenParams.TOP_P] = request.sampling_params.strategy.top_p
-                input_dict["params"][GenParams.TEMPERATURE] = request.sampling_params.strategy.temperature
+                input_dict["params"][
+                    GenParams.TOP_P
+                ] = request.sampling_params.strategy.top_p
+                input_dict["params"][
+                    GenParams.TEMPERATURE
+                ] = request.sampling_params.strategy.temperature
             if isinstance(request.sampling_params.strategy, TopKSamplingStrategy):
-                input_dict["params"][GenParams.TOP_K] = request.sampling_params.strategy.top_k
+                input_dict["params"][
+                    GenParams.TOP_K
+                ] = request.sampling_params.strategy.top_k
             if isinstance(request.sampling_params.strategy, GreedySamplingStrategy):
                 input_dict["params"][GenParams.TEMPERATURE] = 0.0
 