llama-stack-mirror/llama_stack/providers/remote/inference/clarifai/clarifai.py
2025-03-07 17:12:52 +05:30

204 lines
7.6 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import AsyncGenerator, List, Optional, Union
from clarifai import client
from llama_stack import logcat
from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
CompletionRequest,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
ResponseFormat,
ResponseFormatType,
SamplingParams,
TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_options,
process_chat_completion_response,
process_chat_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
request_has_media,
)
from .config import ClarifaiImplConfig
from .models import MODEL_ENTRIES
class ClarifaiInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
    """Inference provider that forwards chat-completion prompts to Clarifai models.

    Only ``chat_completion`` is wired up. ``completion`` and ``embeddings``
    raise ``NotImplementedError``.
    """

    def __init__(self, config: ClarifaiImplConfig) -> None:
        """Register the known model entries and keep the provider config.

        Args:
            config: Provider configuration. Its ``PAT`` attribute is passed to
                the Clarifai client on every model call.
        """
        ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
        self.config = config

    async def initialize(self) -> None:
        """No-op: the adapter needs no asynchronous setup."""
        return

    async def shutdown(self) -> None:
        """No-op: the adapter holds no resources to release."""
        pass

    def _get_client(self) -> client:
        """Return the imported ``clarifai.client`` module used to build model handles."""
        return client

    async def completion(
        self,
        model_id: str,
        content: InterleavedContent,
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> AsyncGenerator:
        """Text completion entry point (not implemented for this provider).

        Raises:
            NotImplementedError: Always.
        """
        # TODO: implement streaming and non-streaming completion on top of
        # `_get_params` and the Clarifai model client.
        # Raise (rather than return the exception class) so callers fail loudly,
        # consistent with `embeddings`.
        raise NotImplementedError("Completion is not implemented for the Clarifai provider")

    def resolve_clarifai_model(self, model_name: str) -> str:
        """Turn a ``user_id/app_id/model_id`` identifier into a Clarifai model URL.

        Args:
            model_name: Identifier made of exactly three non-empty,
                slash-separated parts.

        Returns:
            ``https://clarifai.com/{user_id}/{app_id}/models/{model_id}``.

        Raises:
            ValueError: If ``model_name`` does not have exactly three non-empty
                parts.
        """
        parts = model_name.split("/")
        if len(parts) != 3 or not all(parts):
            raise ValueError(
                f"Invalid Clarifai model identifier {model_name!r}: expected 'user_id/app_id/model_id'"
            )
        user_id, app_id, model_id = parts
        return f"https://clarifai.com/{user_id}/{app_id}/models/{model_id}"

    def _build_options(
        self,
        sampling_params: Optional[SamplingParams],
        logprobs: Optional[LogProbConfig],
        fmt: Optional[ResponseFormat],
    ) -> dict:
        """Build the ``inference_params`` dict sent alongside the prompt.

        Args:
            sampling_params: Converted to base options by ``get_sampling_options``.
            logprobs: Optional log-probability config; only ``top_k == 1`` is accepted.
            fmt: Optional response format; only the JSON-schema type is handled.

        Returns:
            The sampling options, extended with ``response_format`` and/or
            ``logprobs`` entries when requested.

        Raises:
            NotImplementedError: For the grammar response format.
            ValueError: For an unknown response format type, or a ``top_k``
                other than 1.
        """
        options = get_sampling_options(sampling_params)

        if fmt:
            if fmt.type == ResponseFormatType.json_schema.value:
                options["response_format"] = {
                    "type": "json_object",
                    "schema": fmt.json_schema,
                }
            elif fmt.type == ResponseFormatType.grammar.value:
                raise NotImplementedError("Grammar response format not supported yet")
            else:
                raise ValueError(f"Unknown response format {fmt.type}")

        if logprobs and logprobs.top_k:
            if logprobs.top_k != 1:
                raise ValueError(
                    f"Unsupported value: Clarifai only supports logprobs top_k=1. {logprobs.top_k} was provided",
                )
            options["logprobs"] = 1

        return options

    async def chat_completion(
        self,
        model_id: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = None,
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
        tool_config: Optional[ToolConfig] = None,
    ) -> AsyncGenerator:
        """Run a chat completion against the Clarifai model registered as ``model_id``.

        ``tool_choice`` and ``tool_prompt_format`` are accepted but are not
        forwarded into the request built here; ``tool_config`` is.

        Returns:
            An async generator of chunks when ``stream`` is truthy, otherwise
            the awaited non-streaming response.
        """
        # NOTE(review): `model_store` is not assigned anywhere in this class;
        # presumably it is injected by the framework before use. Confirm.
        model = await self.model_store.get_model(model_id)
        request = ChatCompletionRequest(
            model=model.provider_resource_id,
            messages=messages,
            sampling_params=sampling_params,
            tools=tools or [],
            response_format=response_format,
            stream=stream,
            logprobs=logprobs,
            tool_config=tool_config,
        )
        if stream:
            return self._stream_chat_completion(request)
        return await self._nonstream_chat_completion(request)

    async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
        """Send one prediction call and convert its result to a chat response."""
        params = await self._get_params(request)
        model_url = self.resolve_clarifai_model(request.model)
        r = self._get_client().app.Model(url=model_url, pat=self.config.PAT).predict_by_bytes(**params)
        return process_chat_completion_response(r)

    async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
        """Yield chat-completion stream chunks for ``request``."""
        params = await self._get_params(request)
        model_url = self.resolve_clarifai_model(request.model)

        async def _to_async_generator():
            # NOTE(review): this issues the same `predict_by_bytes` call as the
            # non-streaming path and then iterates over its result. Confirm
            # whether a streaming SDK call was intended here.
            s = self._get_client().app.Model(url=model_url, pat=self.config.PAT).predict_by_bytes(**params)
            for chunk in s:
                yield chunk

        stream = _to_async_generator()
        async for chunk in process_chat_completion_stream_response(stream):
            yield chunk

    async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
        """Assemble the keyword arguments passed to the Clarifai model call.

        Returns:
            A dict with ``input_type`` set to ``"text"``, ``inference_params``
            from ``_build_options``, and, for chat requests, ``input_bytes``
            holding the encoded prompt.
        """
        input_dict = {}
        media_present = request_has_media(request)
        llama_model = self.get_llama_model(request.model)

        # NOTE(review): only chat requests populate `input_bytes`; any other
        # request type yields params without an input payload.
        if isinstance(request, ChatCompletionRequest):
            assert not media_present, "Clarifai does not support media for ChatCompletion requests"
            input_dict["input_bytes"] = (await chat_completion_request_to_prompt(request, llama_model)).encode()

        params = {
            **input_dict,
            "input_type": "text",
            "inference_params": self._build_options(request.sampling_params, request.logprobs, request.response_format),
        }
        logcat.debug("inference", f"params to clarifai: {params}")
        return params

    async def embeddings(
        self,
        model_id: str,
        contents: List[str] | List[InterleavedContentItem],
        text_truncation: Optional[TextTruncation] = TextTruncation.none,
        output_dimension: Optional[int] = None,
        task_type: Optional[EmbeddingTaskType] = None,
    ) -> EmbeddingsResponse:
        """Embeddings entry point (not implemented for this provider).

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError()