mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 12:07:34 +00:00
Merge remote-tracking branch 'origin/main' into migrate-eval-to-openai
This commit is contained in:
commit
1222657626
31 changed files with 247 additions and 720 deletions
|
@ -924,7 +924,7 @@ async def get_raw_document_text(document: Document) -> str:
|
|||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
elif not (document.mime_type.startswith("text/") or document.mime_type == "application/yaml"):
|
||||
elif not (document.mime_type.startswith("text/") or document.mime_type in ("application/yaml", "application/json")):
|
||||
raise ValueError(f"Unexpected document mime type: {document.mime_type}")
|
||||
|
||||
if isinstance(document.content, URL):
|
||||
|
|
|
@ -11,21 +11,17 @@ from botocore.client import BaseClient
|
|||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
InterleavedContentItem,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
OpenAIEmbeddingsResponse,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
|
@ -47,8 +43,6 @@ from llama_stack.providers.utils.inference.openai_compat import (
|
|||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
content_has_media,
|
||||
interleaved_content_as_str,
|
||||
)
|
||||
|
||||
from .models import MODEL_ENTRIES
|
||||
|
@ -218,36 +212,6 @@ class BedrockInferenceAdapter(
|
|||
),
|
||||
}
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: list[str] | list[InterleavedContentItem],
|
||||
text_truncation: TextTruncation | None = TextTruncation.none,
|
||||
output_dimension: int | None = None,
|
||||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
model = await self.model_store.get_model(model_id)
|
||||
|
||||
# Convert foundation model ID to inference profile ID
|
||||
region_name = self.client.meta.region_name
|
||||
inference_profile_id = _to_inference_profile_id(model.provider_resource_id, region_name)
|
||||
|
||||
embeddings = []
|
||||
for content in contents:
|
||||
assert not content_has_media(content), "Bedrock does not support media for embeddings"
|
||||
input_text = interleaved_content_as_str(content)
|
||||
input_body = {"inputText": input_text}
|
||||
body = json.dumps(input_body)
|
||||
response = self.client.invoke_model(
|
||||
body=body,
|
||||
modelId=inference_profile_id,
|
||||
accept="application/json",
|
||||
contentType="application/json",
|
||||
)
|
||||
response_body = json.loads(response.get("body").read())
|
||||
embeddings.append(response_body.get("embedding"))
|
||||
return EmbeddingsResponse(embeddings=embeddings)
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
|
|
|
@ -11,21 +11,17 @@ from cerebras.cloud.sdk import AsyncCerebras
|
|||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
InterleavedContentItem,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
OpenAIEmbeddingsResponse,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
|
@ -187,16 +183,6 @@ class CerebrasInferenceAdapter(
|
|||
**get_sampling_options(request.sampling_params),
|
||||
}
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: list[str] | list[InterleavedContentItem],
|
||||
text_truncation: TextTruncation | None = TextTruncation.none,
|
||||
output_dimension: int | None = None,
|
||||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
|
|
|
@ -11,15 +11,12 @@ from databricks.sdk import WorkspaceClient
|
|||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
InterleavedContentItem,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
CompletionResponse,
|
||||
CompletionResponseStreamChunk,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
|
@ -27,7 +24,6 @@ from llama_stack.apis.inference import (
|
|||
OpenAICompletion,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
|
@ -118,16 +114,6 @@ class DatabricksInferenceAdapter(
|
|||
) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: list[str] | list[InterleavedContentItem],
|
||||
text_truncation: TextTruncation | None = TextTruncation.none,
|
||||
output_dimension: int | None = None,
|
||||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def list_models(self) -> list[Model] | None:
|
||||
self._model_cache = {} # from OpenAIMixin
|
||||
ws_client = WorkspaceClient(host=self.config.url, token=self.get_api_key()) # TODO: this is not async
|
||||
|
|
|
@ -10,22 +10,18 @@ from fireworks.client import Fireworks
|
|||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
InterleavedContentItem,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
ResponseFormat,
|
||||
ResponseFormatType,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
|
@ -48,8 +44,6 @@ from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
|||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
completion_request_to_prompt,
|
||||
content_has_media,
|
||||
interleaved_content_as_str,
|
||||
request_has_media,
|
||||
)
|
||||
|
||||
|
@ -259,28 +253,3 @@ class FireworksInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, Nee
|
|||
logger.debug(f"params to fireworks: {params}")
|
||||
|
||||
return params
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: list[str] | list[InterleavedContentItem],
|
||||
text_truncation: TextTruncation | None = TextTruncation.none,
|
||||
output_dimension: int | None = None,
|
||||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
model = await self.model_store.get_model(model_id)
|
||||
|
||||
kwargs = {}
|
||||
if model.metadata.get("embedding_dimension"):
|
||||
kwargs["dimensions"] = model.metadata.get("embedding_dimension")
|
||||
assert all(not content_has_media(content) for content in contents), (
|
||||
"Fireworks does not support media for embeddings"
|
||||
)
|
||||
response = self._get_client().embeddings.create(
|
||||
model=model.provider_resource_id,
|
||||
input=[interleaved_content_as_str(content) for content in contents],
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
embeddings = [data.embedding for data in response.data]
|
||||
return EmbeddingsResponse(embeddings=embeddings)
|
||||
|
|
|
@ -11,8 +11,6 @@ from openai import NOT_GIVEN, APIConnectionError
|
|||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
InterleavedContentItem,
|
||||
TextContentItem,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
|
@ -21,8 +19,6 @@ from llama_stack.apis.inference import (
|
|||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
CompletionResponseStreamChunk,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
|
@ -31,7 +27,6 @@ from llama_stack.apis.inference import (
|
|||
OpenAIEmbeddingUsage,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
)
|
||||
|
@ -156,60 +151,6 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference):
|
|||
# we pass n=1 to get only one completion
|
||||
return convert_openai_completion_choice(response.choices[0])
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: list[str] | list[InterleavedContentItem],
|
||||
text_truncation: TextTruncation | None = TextTruncation.none,
|
||||
output_dimension: int | None = None,
|
||||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
if any(content_has_media(content) for content in contents):
|
||||
raise NotImplementedError("Media is not supported")
|
||||
|
||||
#
|
||||
# Llama Stack: contents = list[str] | list[InterleavedContentItem]
|
||||
# ->
|
||||
# OpenAI: input = str | list[str]
|
||||
#
|
||||
# we can ignore str and always pass list[str] to OpenAI
|
||||
#
|
||||
flat_contents = [content.text if isinstance(content, TextContentItem) else content for content in contents]
|
||||
input = [content.text if isinstance(content, TextContentItem) else content for content in flat_contents]
|
||||
provider_model_id = await self._get_provider_model_id(model_id)
|
||||
|
||||
extra_body = {}
|
||||
|
||||
if text_truncation is not None:
|
||||
text_truncation_options = {
|
||||
TextTruncation.none: "NONE",
|
||||
TextTruncation.end: "END",
|
||||
TextTruncation.start: "START",
|
||||
}
|
||||
extra_body["truncate"] = text_truncation_options[text_truncation]
|
||||
|
||||
if output_dimension is not None:
|
||||
extra_body["dimensions"] = output_dimension
|
||||
|
||||
if task_type is not None:
|
||||
task_type_options = {
|
||||
EmbeddingTaskType.document: "passage",
|
||||
EmbeddingTaskType.query: "query",
|
||||
}
|
||||
extra_body["input_type"] = task_type_options[task_type]
|
||||
|
||||
response = await self.client.embeddings.create(
|
||||
model=provider_model_id,
|
||||
input=input,
|
||||
extra_body=extra_body,
|
||||
)
|
||||
#
|
||||
# OpenAI: CreateEmbeddingResponse(data=[Embedding(embedding=list[float], ...)], ...)
|
||||
# ->
|
||||
# Llama Stack: EmbeddingsResponse(embeddings=list[list[float]])
|
||||
#
|
||||
return EmbeddingsResponse(embeddings=[embedding.embedding for embedding in response.data])
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
|
|
|
@ -14,7 +14,6 @@ from ollama import AsyncClient as AsyncOllamaClient
|
|||
from llama_stack.apis.common.content_types import (
|
||||
ImageContentItem,
|
||||
InterleavedContent,
|
||||
InterleavedContentItem,
|
||||
TextContentItem,
|
||||
)
|
||||
from llama_stack.apis.common.errors import UnsupportedModelError
|
||||
|
@ -25,8 +24,6 @@ from llama_stack.apis.inference import (
|
|||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
CompletionResponseStreamChunk,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
GrammarResponseFormat,
|
||||
InferenceProvider,
|
||||
JsonSchemaResponseFormat,
|
||||
|
@ -34,7 +31,6 @@ from llama_stack.apis.inference import (
|
|||
Message,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
|
@ -66,9 +62,7 @@ from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
|||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
completion_request_to_prompt,
|
||||
content_has_media,
|
||||
convert_image_content_to_url,
|
||||
interleaved_content_as_str,
|
||||
request_has_media,
|
||||
)
|
||||
|
||||
|
@ -363,27 +357,6 @@ class OllamaInferenceAdapter(
|
|||
async for chunk in process_chat_completion_stream_response(stream, request):
|
||||
yield chunk
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: list[str] | list[InterleavedContentItem],
|
||||
text_truncation: TextTruncation | None = TextTruncation.none,
|
||||
output_dimension: int | None = None,
|
||||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
model = await self._get_model(model_id)
|
||||
|
||||
assert all(not content_has_media(content) for content in contents), (
|
||||
"Ollama does not support media for embeddings"
|
||||
)
|
||||
response = await self.ollama_client.embed(
|
||||
model=model.provider_resource_id,
|
||||
input=[interleaved_content_as_str(content) for content in contents],
|
||||
)
|
||||
embeddings = response["embeddings"]
|
||||
|
||||
return EmbeddingsResponse(embeddings=embeddings)
|
||||
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
if await self.check_model_availability(model.provider_model_id):
|
||||
return model
|
||||
|
|
|
@ -14,8 +14,6 @@ from llama_stack.apis.inference import (
|
|||
ChatCompletionResponse,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
CompletionMessage,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
|
@ -27,7 +25,6 @@ from llama_stack.apis.inference import (
|
|||
OpenAIResponseFormatParam,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
|
@ -190,25 +187,6 @@ class PassthroughInferenceAdapter(Inference):
|
|||
chunk = convert_to_pydantic(ChatCompletionResponseStreamChunk, chunk)
|
||||
yield chunk
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: list[InterleavedContent],
|
||||
text_truncation: TextTruncation | None = TextTruncation.none,
|
||||
output_dimension: int | None = None,
|
||||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
client = self._get_client()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
|
||||
return await client.inference.embeddings(
|
||||
model_id=model.provider_resource_id,
|
||||
contents=contents,
|
||||
text_truncation=text_truncation,
|
||||
output_dimension=output_dimension,
|
||||
task_type=task_type,
|
||||
)
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
|
|
|
@ -136,16 +136,6 @@ class RunpodInferenceAdapter(
|
|||
**get_sampling_options(request.sampling_params),
|
||||
}
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model: str,
|
||||
contents: list[str] | list[InterleavedContentItem],
|
||||
text_truncation: TextTruncation | None = TextTruncation.none,
|
||||
output_dimension: int | None = None,
|
||||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
|
|
|
@ -12,14 +12,11 @@ from pydantic import SecretStr
|
|||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
InterleavedContentItem,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
CompletionRequest,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
|
@ -27,7 +24,6 @@ from llama_stack.apis.inference import (
|
|||
ResponseFormat,
|
||||
ResponseFormatType,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
|
@ -306,16 +302,6 @@ class _HfAdapter(
|
|||
**self._build_options(request.sampling_params, request.response_format),
|
||||
)
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: list[str] | list[InterleavedContentItem],
|
||||
text_truncation: TextTruncation | None = TextTruncation.none,
|
||||
output_dimension: int | None = None,
|
||||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
|
|
|
@ -12,14 +12,11 @@ from together.constants import BASE_URL
|
|||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
InterleavedContentItem,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
CompletionRequest,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
|
@ -27,7 +24,6 @@ from llama_stack.apis.inference import (
|
|||
ResponseFormat,
|
||||
ResponseFormatType,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
|
@ -50,8 +46,6 @@ from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
|||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
completion_request_to_prompt,
|
||||
content_has_media,
|
||||
interleaved_content_as_str,
|
||||
request_has_media,
|
||||
)
|
||||
|
||||
|
@ -247,26 +241,6 @@ class TogetherInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, Need
|
|||
logger.debug(f"params to together: {params}")
|
||||
return params
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: list[str] | list[InterleavedContentItem],
|
||||
text_truncation: TextTruncation | None = TextTruncation.none,
|
||||
output_dimension: int | None = None,
|
||||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
model = await self.model_store.get_model(model_id)
|
||||
assert all(not content_has_media(content) for content in contents), (
|
||||
"Together does not support media for embeddings"
|
||||
)
|
||||
client = self._get_client()
|
||||
r = await client.embeddings.create(
|
||||
model=model.provider_resource_id,
|
||||
input=[interleaved_content_as_str(content) for content in contents],
|
||||
)
|
||||
embeddings = [item.embedding for item in r.data]
|
||||
return EmbeddingsResponse(embeddings=embeddings)
|
||||
|
||||
async def list_models(self) -> list[Model] | None:
|
||||
self._model_cache = {}
|
||||
# Together's /v1/models is not compatible with OpenAI's /v1/models. Together support ticket #13355 -> will not fix, use Together's own client
|
||||
|
|
|
@ -16,7 +16,6 @@ from openai.types.chat.chat_completion_chunk import (
|
|||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
InterleavedContentItem,
|
||||
TextDelta,
|
||||
ToolCallDelta,
|
||||
ToolCallParseStatus,
|
||||
|
@ -31,8 +30,6 @@ from llama_stack.apis.inference import (
|
|||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
CompletionResponseStreamChunk,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
GrammarResponseFormat,
|
||||
Inference,
|
||||
JsonSchemaResponseFormat,
|
||||
|
@ -41,7 +38,6 @@ from llama_stack.apis.inference import (
|
|||
ModelStore,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
|
@ -74,8 +70,6 @@ from llama_stack.providers.utils.inference.openai_compat import (
|
|||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
completion_request_to_prompt,
|
||||
content_has_media,
|
||||
interleaved_content_as_str,
|
||||
request_has_media,
|
||||
)
|
||||
|
||||
|
@ -550,27 +544,3 @@ class VLLMInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin, Inference, ModelsPro
|
|||
"stream": request.stream,
|
||||
**options,
|
||||
}
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: list[str] | list[InterleavedContentItem],
|
||||
text_truncation: TextTruncation | None = TextTruncation.none,
|
||||
output_dimension: int | None = None,
|
||||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
model = await self._get_model(model_id)
|
||||
|
||||
kwargs = {}
|
||||
assert model.model_type == ModelType.embedding
|
||||
assert model.metadata.get("embedding_dimension")
|
||||
kwargs["dimensions"] = model.metadata.get("embedding_dimension")
|
||||
assert all(not content_has_media(content) for content in contents), "VLLM does not support media for embeddings"
|
||||
response = await self.client.embeddings.create(
|
||||
model=model.provider_resource_id,
|
||||
input=[interleaved_content_as_str(content) for content in contents],
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
embeddings = [data.embedding for data in response.data]
|
||||
return EmbeddingsResponse(embeddings=embeddings)
|
||||
|
|
|
@ -11,13 +11,11 @@ from ibm_watsonx_ai.foundation_models import Model
|
|||
from ibm_watsonx_ai.metanames import GenTextParamsMetaNames as GenParams
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
|
||||
from llama_stack.apis.common.content_types import InterleavedContent
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
CompletionRequest,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
GreedySamplingStrategy,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
|
@ -30,7 +28,6 @@ from llama_stack.apis.inference import (
|
|||
OpenAIResponseFormatParam,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
|
@ -265,16 +262,6 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
}
|
||||
return params
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: list[str] | list[InterleavedContentItem],
|
||||
text_truncation: TextTruncation | None = TextTruncation.none,
|
||||
output_dimension: int | None = None,
|
||||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError("embedding is not supported for watsonx")
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
|
|
|
@ -15,16 +15,11 @@ if TYPE_CHECKING:
|
|||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
InterleavedContentItem,
|
||||
ModelStore,
|
||||
OpenAIEmbeddingData,
|
||||
OpenAIEmbeddingsResponse,
|
||||
OpenAIEmbeddingUsage,
|
||||
TextTruncation,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
|
||||
|
||||
EMBEDDING_MODELS = {}
|
||||
|
||||
|
@ -35,23 +30,6 @@ log = get_logger(name=__name__, category="providers::utils")
|
|||
class SentenceTransformerEmbeddingMixin:
|
||||
model_store: ModelStore
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: list[str] | list[InterleavedContentItem],
|
||||
text_truncation: TextTruncation | None = TextTruncation.none,
|
||||
output_dimension: int | None = None,
|
||||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
model = await self.model_store.get_model(model_id)
|
||||
embedding_model = await self._load_sentence_transformer_model(model.provider_resource_id)
|
||||
embeddings = await asyncio.to_thread(
|
||||
embedding_model.encode,
|
||||
[interleaved_content_as_str(content) for content in contents],
|
||||
show_progress_bar=False,
|
||||
)
|
||||
return EmbeddingsResponse(embeddings=embeddings)
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
|
|
|
@ -11,14 +11,11 @@ import litellm
|
|||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
InterleavedContentItem,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
InferenceProvider,
|
||||
JsonSchemaResponseFormat,
|
||||
LogProbConfig,
|
||||
|
@ -32,7 +29,6 @@ from llama_stack.apis.inference import (
|
|||
OpenAIResponseFormatParam,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
|
@ -50,9 +46,6 @@ from llama_stack.providers.utils.inference.openai_compat import (
|
|||
get_sampling_options,
|
||||
prepare_openai_completion_params,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
interleaved_content_as_str,
|
||||
)
|
||||
|
||||
logger = get_logger(name=__name__, category="providers::utils")
|
||||
|
||||
|
@ -269,24 +262,6 @@ class LiteLLMOpenAIMixin(
|
|||
)
|
||||
return api_key
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: list[str] | list[InterleavedContentItem],
|
||||
text_truncation: TextTruncation | None = TextTruncation.none,
|
||||
output_dimension: int | None = None,
|
||||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
model = await self.model_store.get_model(model_id)
|
||||
|
||||
response = litellm.embedding(
|
||||
model=self.get_litellm_model_name(model.provider_resource_id),
|
||||
input=[interleaved_content_as_str(content) for content in contents],
|
||||
)
|
||||
|
||||
embeddings = [data["embedding"] for data in response["data"]]
|
||||
return EmbeddingsResponse(embeddings=embeddings)
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
|
|
|
@ -3,6 +3,9 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.agents import (
|
||||
Order,
|
||||
)
|
||||
|
@ -14,24 +17,51 @@ from llama_stack.apis.agents.openai_responses import (
|
|||
OpenAIResponseObject,
|
||||
OpenAIResponseObjectWithInput,
|
||||
)
|
||||
from llama_stack.core.datatypes import AccessRule
|
||||
from llama_stack.core.datatypes import AccessRule, ResponsesStoreConfig
|
||||
from llama_stack.core.utils.config_dirs import RUNTIME_BASE_DIR
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from ..sqlstore.api import ColumnDefinition, ColumnType
|
||||
from ..sqlstore.authorized_sqlstore import AuthorizedSqlStore
|
||||
from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, sqlstore_impl
|
||||
from ..sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig, SqlStoreType, sqlstore_impl
|
||||
|
||||
logger = get_logger(name=__name__, category="responses_store")
|
||||
|
||||
|
||||
class ResponsesStore:
|
||||
def __init__(self, sql_store_config: SqlStoreConfig, policy: list[AccessRule]):
|
||||
if not sql_store_config:
|
||||
sql_store_config = SqliteSqlStoreConfig(
|
||||
def __init__(
|
||||
self,
|
||||
config: ResponsesStoreConfig | SqlStoreConfig,
|
||||
policy: list[AccessRule],
|
||||
):
|
||||
# Handle backward compatibility
|
||||
if not isinstance(config, ResponsesStoreConfig):
|
||||
# Legacy: SqlStoreConfig passed directly as config
|
||||
config = ResponsesStoreConfig(
|
||||
sql_store_config=config,
|
||||
)
|
||||
|
||||
self.config = config
|
||||
self.sql_store_config = config.sql_store_config
|
||||
if not self.sql_store_config:
|
||||
self.sql_store_config = SqliteSqlStoreConfig(
|
||||
db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
|
||||
)
|
||||
self.sql_store = AuthorizedSqlStore(sqlstore_impl(sql_store_config), policy)
|
||||
self.sql_store = None
|
||||
self.policy = policy
|
||||
|
||||
# Disable write queue for SQLite to avoid concurrency issues
|
||||
self.enable_write_queue = self.sql_store_config.type != SqlStoreType.sqlite
|
||||
|
||||
# Async write queue and worker control
|
||||
self._queue: asyncio.Queue[tuple[OpenAIResponseObject, list[OpenAIResponseInput]]] | None = None
|
||||
self._worker_tasks: list[asyncio.Task[Any]] = []
|
||||
self._max_write_queue_size: int = config.max_write_queue_size
|
||||
self._num_writers: int = max(1, config.num_writers)
|
||||
|
||||
async def initialize(self):
|
||||
"""Create the necessary tables if they don't exist."""
|
||||
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config), self.policy)
|
||||
await self.sql_store.create_table(
|
||||
"openai_responses",
|
||||
{
|
||||
|
@ -42,9 +72,68 @@ class ResponsesStore:
|
|||
},
|
||||
)
|
||||
|
||||
if self.enable_write_queue:
|
||||
self._queue = asyncio.Queue(maxsize=self._max_write_queue_size)
|
||||
for _ in range(self._num_writers):
|
||||
self._worker_tasks.append(asyncio.create_task(self._worker_loop()))
|
||||
else:
|
||||
logger.info("Write queue disabled for SQLite to avoid concurrency issues")
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
if not self._worker_tasks:
|
||||
return
|
||||
if self._queue is not None:
|
||||
await self._queue.join()
|
||||
for t in self._worker_tasks:
|
||||
if not t.done():
|
||||
t.cancel()
|
||||
for t in self._worker_tasks:
|
||||
try:
|
||||
await t
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._worker_tasks.clear()
|
||||
|
||||
async def flush(self) -> None:
|
||||
"""Wait for all queued writes to complete. Useful for testing."""
|
||||
if self.enable_write_queue and self._queue is not None:
|
||||
await self._queue.join()
|
||||
|
||||
async def store_response_object(
|
||||
self, response_object: OpenAIResponseObject, input: list[OpenAIResponseInput]
|
||||
) -> None:
|
||||
if self.enable_write_queue:
|
||||
if self._queue is None:
|
||||
raise ValueError("Responses store is not initialized")
|
||||
try:
|
||||
self._queue.put_nowait((response_object, input))
|
||||
except asyncio.QueueFull:
|
||||
logger.warning(f"Write queue full; adding response id={getattr(response_object, 'id', '<unknown>')}")
|
||||
await self._queue.put((response_object, input))
|
||||
else:
|
||||
await self._write_response_object(response_object, input)
|
||||
|
||||
async def _worker_loop(self) -> None:
|
||||
assert self._queue is not None
|
||||
while True:
|
||||
try:
|
||||
item = await self._queue.get()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
response_object, input = item
|
||||
try:
|
||||
await self._write_response_object(response_object, input)
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.error(f"Error writing response object: {e}")
|
||||
finally:
|
||||
self._queue.task_done()
|
||||
|
||||
async def _write_response_object(
|
||||
self, response_object: OpenAIResponseObject, input: list[OpenAIResponseInput]
|
||||
) -> None:
|
||||
if self.sql_store is None:
|
||||
raise ValueError("Responses store is not initialized")
|
||||
|
||||
data = response_object.model_dump()
|
||||
data["input"] = [input_item.model_dump() for input_item in input]
|
||||
|
||||
|
@ -73,6 +162,9 @@ class ResponsesStore:
|
|||
:param model: The model to filter by.
|
||||
:param order: The order to sort the responses by.
|
||||
"""
|
||||
if not self.sql_store:
|
||||
raise ValueError("Responses store is not initialized")
|
||||
|
||||
if not order:
|
||||
order = Order.desc
|
||||
|
||||
|
@ -100,6 +192,9 @@ class ResponsesStore:
|
|||
"""
|
||||
Get a response object with automatic access control checking.
|
||||
"""
|
||||
if not self.sql_store:
|
||||
raise ValueError("Responses store is not initialized")
|
||||
|
||||
row = await self.sql_store.fetch_one(
|
||||
"openai_responses",
|
||||
where={"id": response_id},
|
||||
|
@ -113,6 +208,9 @@ class ResponsesStore:
|
|||
return OpenAIResponseObjectWithInput(**row["response_object"])
|
||||
|
||||
async def delete_response_object(self, response_id: str) -> OpenAIDeleteResponseObject:
|
||||
if not self.sql_store:
|
||||
raise ValueError("Responses store is not initialized")
|
||||
|
||||
row = await self.sql_store.fetch_one("openai_responses", where={"id": response_id})
|
||||
if not row:
|
||||
raise ValueError(f"Response with id {response_id} not found")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue