mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-03 09:21:45 +00:00
feat: remote ramalama provider implementation
Implement remote ramalama provider using AsyncOpenAI as the client since ramalama doesn't have its own Async library. Ramalama is similar to ollama, as it is a lightweight local inference server. However, it runs by default in a containerized mode. RAMALAMA_URL is http://localhost:8080 by default Signed-off-by: Charlie Doern <cdoern@redhat.com>
This commit is contained in:
parent
94f83382eb
commit
4de45560bf
8 changed files with 680 additions and 0 deletions
|
@ -306,6 +306,7 @@ async def instantiate_provider(
|
|||
additional_protocols = additional_protocols_map()
|
||||
|
||||
provider_spec = provider.spec
|
||||
|
||||
if not hasattr(provider_spec, "module"):
|
||||
raise AttributeError(f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute")
|
||||
|
||||
|
|
|
@ -77,6 +77,15 @@ def available_providers() -> List[ProviderSpec]:
|
|||
module="llama_stack.providers.remote.inference.ollama",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="ramalama",
|
||||
pip_packages=["ramalama", "aiohttp"],
|
||||
config_class="llama_stack.providers.remote.inference.ramalama.RamalamaImplConfig",
|
||||
module="llama_stack.providers.remote.inference.ramalama",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
|
|
15
llama_stack/providers/remote/inference/ramalama/__init__.py
Normal file
15
llama_stack/providers/remote/inference/ramalama/__init__.py
Normal file
|
@ -0,0 +1,15 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import RamalamaImplConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: RamalamaImplConfig, _deps):
|
||||
from .ramalama import RamalamaInferenceAdapter
|
||||
|
||||
impl = RamalamaInferenceAdapter(config.url)
|
||||
await impl.initialize()
|
||||
return impl
|
19
llama_stack/providers/remote/inference/ramalama/config.py
Normal file
19
llama_stack/providers/remote/inference/ramalama/config.py
Normal file
|
@ -0,0 +1,19 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
DEFAULT_RAMALAMA_URL = "http://localhost:8080"
|
||||
|
||||
|
||||
class RamalamaImplConfig(BaseModel):
|
||||
url: str = DEFAULT_RAMALAMA_URL
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, url: str = "${env.RAMALAMA_URL:http://localhost:8080}", **kwargs) -> Dict[str, Any]:
|
||||
return {"url": url}
|
103
llama_stack/providers/remote/inference/ramalama/models.py
Normal file
103
llama_stack/providers/remote/inference/ramalama/models.py
Normal file
|
@ -0,0 +1,103 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.models.models import ModelType
|
||||
from llama_stack.models.llama.datatypes import CoreModelId
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ProviderModelEntry,
|
||||
build_hf_repo_model_entry,
|
||||
build_model_entry,
|
||||
)
|
||||
|
||||
model_entries = [
|
||||
build_hf_repo_model_entry(
|
||||
"llama3.1:8b-instruct-fp16",
|
||||
CoreModelId.llama3_1_8b_instruct.value,
|
||||
),
|
||||
build_model_entry(
|
||||
"llama3.1:8b",
|
||||
CoreModelId.llama3_1_8b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"llama3.1:70b-instruct-fp16",
|
||||
CoreModelId.llama3_1_70b_instruct.value,
|
||||
),
|
||||
build_model_entry(
|
||||
"llama3.1:70b",
|
||||
CoreModelId.llama3_1_70b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"llama3.1:405b-instruct-fp16",
|
||||
CoreModelId.llama3_1_405b_instruct.value,
|
||||
),
|
||||
build_model_entry(
|
||||
"llama3.1:405b",
|
||||
CoreModelId.llama3_1_405b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"llama3.2:1b-instruct-fp16",
|
||||
CoreModelId.llama3_2_1b_instruct.value,
|
||||
),
|
||||
build_model_entry(
|
||||
"llama3.2:1b",
|
||||
CoreModelId.llama3_2_1b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"llama3.2:3b-instruct-fp16",
|
||||
CoreModelId.llama3_2_3b_instruct.value,
|
||||
),
|
||||
build_model_entry(
|
||||
"llama3.2:3b",
|
||||
CoreModelId.llama3_2_3b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"llama3.2-vision:11b-instruct-fp16",
|
||||
CoreModelId.llama3_2_11b_vision_instruct.value,
|
||||
),
|
||||
build_model_entry(
|
||||
"llama3.2-vision:latest",
|
||||
CoreModelId.llama3_2_11b_vision_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"llama3.2-vision:90b-instruct-fp16",
|
||||
CoreModelId.llama3_2_90b_vision_instruct.value,
|
||||
),
|
||||
build_model_entry(
|
||||
"llama3.2-vision:90b",
|
||||
CoreModelId.llama3_2_90b_vision_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"llama3.3:70b",
|
||||
CoreModelId.llama3_3_70b_instruct.value,
|
||||
),
|
||||
# The Llama Guard models don't have their full fp16 versions
|
||||
# so we are going to alias their default version to the canonical SKU
|
||||
build_hf_repo_model_entry(
|
||||
"llama-guard3:8b",
|
||||
CoreModelId.llama_guard_3_8b.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"llama-guard3:1b",
|
||||
CoreModelId.llama_guard_3_1b.value,
|
||||
),
|
||||
ProviderModelEntry(
|
||||
provider_model_id="all-minilm:latest",
|
||||
aliases=["all-minilm"],
|
||||
model_type=ModelType.embedding,
|
||||
metadata={
|
||||
"embedding_dimension": 384,
|
||||
"context_length": 512,
|
||||
},
|
||||
),
|
||||
ProviderModelEntry(
|
||||
provider_model_id="nomic-embed-text",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={
|
||||
"embedding_dimension": 768,
|
||||
"context_length": 8192,
|
||||
},
|
||||
),
|
||||
]
|
344
llama_stack/providers/remote/inference/ramalama/openai_utils.py
Normal file
344
llama_stack/providers/remote/inference/ramalama/openai_utils.py
Normal file
|
@ -0,0 +1,344 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
GrammarResponseFormat,
|
||||
JsonSchemaResponseFormat,
|
||||
Message,
|
||||
ToolChoice,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
convert_message_to_openai_dict,
|
||||
get_sampling_options,
|
||||
)
|
||||
|
||||
|
||||
def _merge_context_into_content(message: Message) -> Message: # type: ignore
|
||||
"""
|
||||
Merge the ``context`` field of a Llama Stack ``Message`` object into
|
||||
the content field for compabilitiy with OpenAI-style APIs.
|
||||
|
||||
Generates a content string that emulates the current behavior
|
||||
of ``llama_models.llama3.api.chat_format.encode_message()``.
|
||||
|
||||
:param message: Message that may include ``context`` field
|
||||
|
||||
:returns: A version of ``message`` with any context merged into the
|
||||
``content`` field.
|
||||
"""
|
||||
if not isinstance(message, UserMessage): # Separate type check for linter
|
||||
return message
|
||||
if message.context is None:
|
||||
return message
|
||||
return UserMessage(
|
||||
role=message.role,
|
||||
# Emumate llama_models.llama3.api.chat_format.encode_message()
|
||||
content=message.content + "\n\n" + message.context,
|
||||
context=None,
|
||||
)
|
||||
|
||||
|
||||
async def llama_stack_chat_completion_to_openai_chat_completion_dict(
|
||||
request: ChatCompletionRequest,
|
||||
) -> dict:
|
||||
"""
|
||||
Convert a chat completion request in Llama Stack format into an
|
||||
equivalent set of arguments to pass to an OpenAI-compatible
|
||||
chat completions API.
|
||||
|
||||
:param request: Bundled request parameters in Llama Stack format.
|
||||
|
||||
:returns: Dictionary of key-value pairs to use as an initializer
|
||||
for a dataclass or to be converted directly to JSON and sent
|
||||
over the wire.
|
||||
"""
|
||||
|
||||
converted_messages = [
|
||||
# This mystery async call makes the parent function also be async
|
||||
await convert_message_to_openai_dict(_merge_context_into_content(m), download=True)
|
||||
for m in request.messages
|
||||
]
|
||||
# converted_tools = _llama_stack_tools_to_openai_tools(request.tools)
|
||||
|
||||
# Llama will try to use built-in tools with no tool catalog, so don't enable
|
||||
# tool choice unless at least one tool is enabled.
|
||||
converted_tool_choice = "none"
|
||||
if (
|
||||
request.tool_config is not None
|
||||
and request.tool_config.tool_choice == ToolChoice.auto
|
||||
and request.tools is not None
|
||||
and len(request.tools) > 0
|
||||
):
|
||||
converted_tool_choice = "auto"
|
||||
|
||||
# TODO: Figure out what to do with the tool_prompt_format argument.
|
||||
# Other connectors appear to drop it quietly.
|
||||
|
||||
# Use Llama Stack shared code to translate sampling parameters.
|
||||
sampling_options = get_sampling_options(request.sampling_params)
|
||||
|
||||
# get_sampling_options() translates repetition penalties to an option that
|
||||
# OpenAI's APIs don't know about.
|
||||
# vLLM's OpenAI-compatible API also handles repetition penalties wrong.
|
||||
# For now, translate repetition penalties into a format that vLLM's broken
|
||||
# API will handle correctly. Two wrongs make a right...
|
||||
if "repeat_penalty" in sampling_options:
|
||||
del sampling_options["repeat_penalty"]
|
||||
if request.sampling_params.repetition_penalty is not None and request.sampling_params.repetition_penalty != 1.0:
|
||||
sampling_options["repetition_penalty"] = request.sampling_params.repetition_penalty
|
||||
|
||||
# Convert a single response format into four different parameters, per
|
||||
# the OpenAI spec
|
||||
guided_decoding_options = dict()
|
||||
if request.response_format is None:
|
||||
# Use defaults
|
||||
pass
|
||||
elif isinstance(request.response_format, JsonSchemaResponseFormat):
|
||||
guided_decoding_options["guided_json"] = request.response_format.json_schema
|
||||
elif isinstance(request.response_format, GrammarResponseFormat):
|
||||
guided_decoding_options["guided_grammar"] = request.response_format.bnf
|
||||
else:
|
||||
raise TypeError(f"ResponseFormat object is of unexpected subtype '{type(request.response_format)}'")
|
||||
|
||||
logprob_options = dict()
|
||||
if request.logprobs is not None:
|
||||
logprob_options["logprobs"] = request.logprobs.top_k
|
||||
|
||||
# Marshall together all the arguments for a ChatCompletionRequest
|
||||
request_options = {
|
||||
"model": request.model,
|
||||
"messages": converted_messages,
|
||||
"tool_choice": converted_tool_choice,
|
||||
"stream": request.stream,
|
||||
**sampling_options,
|
||||
**guided_decoding_options,
|
||||
**logprob_options,
|
||||
}
|
||||
|
||||
return request_options
|
||||
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import warnings
|
||||
from typing import Any, AsyncGenerator, Dict
|
||||
|
||||
from openai import AsyncStream
|
||||
from openai.types.chat.chat_completion import (
|
||||
Choice as OpenAIChoice,
|
||||
)
|
||||
from openai.types.completion import Completion as OpenAICompletion
|
||||
from openai.types.completion_choice import Logprobs as OpenAICompletionLogprobs
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
CompletionResponseStreamChunk,
|
||||
TokenLogProbs,
|
||||
)
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
GreedySamplingStrategy,
|
||||
TopKSamplingStrategy,
|
||||
TopPSamplingStrategy,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
_convert_openai_finish_reason,
|
||||
convert_message_to_openai_dict_new,
|
||||
convert_tooldef_to_openai_tool,
|
||||
)
|
||||
|
||||
|
||||
async def convert_chat_completion_request(
|
||||
request: ChatCompletionRequest,
|
||||
n: int = 1,
|
||||
) -> dict:
|
||||
"""
|
||||
Convert a ChatCompletionRequest to an OpenAI API-compatible dictionary.
|
||||
"""
|
||||
# model -> model
|
||||
# messages -> messages
|
||||
# sampling_params TODO(mattf): review strategy
|
||||
# strategy=greedy -> nvext.top_k = -1, temperature = temperature
|
||||
# strategy=top_p -> nvext.top_k = -1, top_p = top_p
|
||||
# strategy=top_k -> nvext.top_k = top_k
|
||||
# temperature -> temperature
|
||||
# top_p -> top_p
|
||||
# top_k -> nvext.top_k
|
||||
# max_tokens -> max_tokens
|
||||
# repetition_penalty -> nvext.repetition_penalty
|
||||
# response_format -> GrammarResponseFormat TODO(mf)
|
||||
# response_format -> JsonSchemaResponseFormat: response_format = "json_object" & nvext["guided_json"] = json_schema
|
||||
# tools -> tools
|
||||
# tool_choice ("auto", "required") -> tool_choice
|
||||
# tool_prompt_format -> TBD
|
||||
# stream -> stream
|
||||
# logprobs -> logprobs
|
||||
|
||||
if request.response_format and not isinstance(request.response_format, JsonSchemaResponseFormat):
|
||||
raise ValueError(
|
||||
f"Unsupported response format: {request.response_format}. Only JsonSchemaResponseFormat is supported."
|
||||
)
|
||||
|
||||
nvext = {}
|
||||
payload: Dict[str, Any] = dict(
|
||||
model=request.model,
|
||||
messages=[await convert_message_to_openai_dict_new(message) for message in request.messages],
|
||||
stream=request.stream,
|
||||
n=n,
|
||||
extra_body=dict(nvext=nvext),
|
||||
extra_headers={
|
||||
b"User-Agent": b"llama-stack: nvidia-inference-adapter",
|
||||
},
|
||||
)
|
||||
|
||||
if request.response_format:
|
||||
# server bug - setting guided_json changes the behavior of response_format resulting in an error
|
||||
# payload.update(response_format="json_object")
|
||||
nvext.update(guided_json=request.response_format.json_schema)
|
||||
|
||||
if request.tools:
|
||||
payload.update(tools=[convert_tooldef_to_openai_tool(tool) for tool in request.tools])
|
||||
if request.tool_config.tool_choice:
|
||||
payload.update(
|
||||
tool_choice=request.tool_config.tool_choice.value
|
||||
) # we cannot include tool_choice w/o tools, server will complain
|
||||
|
||||
if request.logprobs:
|
||||
payload.update(logprobs=True)
|
||||
payload.update(top_logprobs=request.logprobs.top_k)
|
||||
|
||||
if request.sampling_params:
|
||||
nvext.update(repetition_penalty=request.sampling_params.repetition_penalty)
|
||||
|
||||
if request.sampling_params.max_tokens:
|
||||
payload.update(max_tokens=request.sampling_params.max_tokens)
|
||||
|
||||
strategy = request.sampling_params.strategy
|
||||
if isinstance(strategy, TopPSamplingStrategy):
|
||||
nvext.update(top_k=-1)
|
||||
payload.update(top_p=strategy.top_p)
|
||||
payload.update(temperature=strategy.temperature)
|
||||
elif isinstance(strategy, TopKSamplingStrategy):
|
||||
if strategy.top_k != -1 and strategy.top_k < 1:
|
||||
warnings.warn("top_k must be -1 or >= 1", stacklevel=2)
|
||||
nvext.update(top_k=strategy.top_k)
|
||||
elif isinstance(strategy, GreedySamplingStrategy):
|
||||
nvext.update(top_k=-1)
|
||||
else:
|
||||
raise ValueError(f"Unsupported sampling strategy: {strategy}")
|
||||
|
||||
return payload
|
||||
|
||||
|
||||
def convert_completion_request(
|
||||
request: CompletionRequest,
|
||||
n: int = 1,
|
||||
) -> dict:
|
||||
"""
|
||||
Convert a ChatCompletionRequest to an OpenAI API-compatible dictionary.
|
||||
"""
|
||||
# model -> model
|
||||
# prompt -> prompt
|
||||
# sampling_params TODO(mattf): review strategy
|
||||
# strategy=greedy -> nvext.top_k = -1, temperature = temperature
|
||||
# strategy=top_p -> nvext.top_k = -1, top_p = top_p
|
||||
# strategy=top_k -> nvext.top_k = top_k
|
||||
# temperature -> temperature
|
||||
# top_p -> top_p
|
||||
# top_k -> nvext.top_k
|
||||
# max_tokens -> max_tokens
|
||||
# repetition_penalty -> nvext.repetition_penalty
|
||||
# response_format -> nvext.guided_json
|
||||
# stream -> stream
|
||||
# logprobs.top_k -> logprobs
|
||||
|
||||
nvext = {}
|
||||
payload: Dict[str, Any] = dict(
|
||||
model=request.model,
|
||||
prompt=request.content,
|
||||
stream=request.stream,
|
||||
extra_body=dict(nvext=nvext),
|
||||
extra_headers={
|
||||
b"User-Agent": b"llama-stack: nvidia-inference-adapter",
|
||||
},
|
||||
n=n,
|
||||
)
|
||||
|
||||
if request.response_format:
|
||||
# this is not openai compliant, it is a nim extension
|
||||
nvext.update(guided_json=request.response_format.json_schema)
|
||||
|
||||
if request.logprobs:
|
||||
payload.update(logprobs=request.logprobs.top_k)
|
||||
|
||||
if request.sampling_params:
|
||||
nvext.update(repetition_penalty=request.sampling_params.repetition_penalty)
|
||||
|
||||
if request.sampling_params.max_tokens:
|
||||
payload.update(max_tokens=request.sampling_params.max_tokens)
|
||||
|
||||
if request.sampling_params.strategy == "top_p":
|
||||
nvext.update(top_k=-1)
|
||||
payload.update(top_p=request.sampling_params.top_p)
|
||||
elif request.sampling_params.strategy == "top_k":
|
||||
if request.sampling_params.top_k != -1 and request.sampling_params.top_k < 1:
|
||||
warnings.warn("top_k must be -1 or >= 1", stacklevel=2)
|
||||
nvext.update(top_k=request.sampling_params.top_k)
|
||||
elif request.sampling_params.strategy == "greedy":
|
||||
nvext.update(top_k=-1)
|
||||
payload.update(temperature=request.sampling_params.temperature)
|
||||
|
||||
return payload
|
||||
|
||||
|
||||
def _convert_openai_completion_logprobs(
|
||||
logprobs: Optional[OpenAICompletionLogprobs],
|
||||
) -> Optional[List[TokenLogProbs]]:
|
||||
"""
|
||||
Convert an OpenAI CompletionLogprobs into a list of TokenLogProbs.
|
||||
"""
|
||||
if not logprobs:
|
||||
return None
|
||||
|
||||
return [TokenLogProbs(logprobs_by_token=logprobs) for logprobs in logprobs.top_logprobs]
|
||||
|
||||
|
||||
def convert_openai_completion_choice(
|
||||
choice: OpenAIChoice,
|
||||
) -> CompletionResponse:
|
||||
"""
|
||||
Convert an OpenAI Completion Choice into a CompletionResponse.
|
||||
"""
|
||||
return CompletionResponse(
|
||||
content=choice.text,
|
||||
stop_reason=_convert_openai_finish_reason(choice.finish_reason),
|
||||
logprobs=_convert_openai_completion_logprobs(choice.logprobs),
|
||||
)
|
||||
|
||||
|
||||
async def convert_openai_completion_stream(
|
||||
stream: AsyncStream[OpenAICompletion],
|
||||
) -> AsyncGenerator[CompletionResponse, None]:
|
||||
"""
|
||||
Convert a stream of OpenAI Completions into a stream
|
||||
of ChatCompletionResponseStreamChunks.
|
||||
"""
|
||||
async for chunk in stream:
|
||||
choice = chunk.choices[0]
|
||||
yield CompletionResponseStreamChunk(
|
||||
delta=choice.text,
|
||||
stop_reason=_convert_openai_finish_reason(choice.finish_reason),
|
||||
logprobs=_convert_openai_completion_logprobs(choice.logprobs),
|
||||
)
|
188
llama_stack/providers/remote/inference/ramalama/ramalama.py
Normal file
188
llama_stack/providers/remote/inference/ramalama/ramalama.py
Normal file
|
@ -0,0 +1,188 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from typing import AsyncGenerator, List, Optional
|
||||
|
||||
from openai import AsyncOpenAI, BadRequestError
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
InterleavedContentItem,
|
||||
TextContentItem,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
CompletionRequest,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
convert_openai_chat_completion_choice,
|
||||
convert_openai_chat_completion_stream,
|
||||
)
|
||||
|
||||
from .models import model_entries
|
||||
from .openai_utils import (
|
||||
convert_chat_completion_request,
|
||||
convert_completion_request,
|
||||
convert_openai_completion_choice,
|
||||
convert_openai_completion_stream,
|
||||
)
|
||||
|
||||
logger = get_logger(name=__name__, category="inference")
|
||||
|
||||
|
||||
class RamalamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
||||
def __init__(self, url: str) -> None:
|
||||
self.register_helper = ModelRegistryHelper(model_entries)
|
||||
self.url = url
|
||||
|
||||
async def initialize(self) -> None:
|
||||
logger.info(f"checking connectivity to Ramalama at `{self.url}`...")
|
||||
self.client = AsyncOpenAI(base_url=self.url, api_key="NO KEY")
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def unregister_model(self, model_id: str) -> None:
|
||||
pass
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
model_id: str,
|
||||
content: InterleavedContent,
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
request = convert_completion_request(
|
||||
request=CompletionRequest(
|
||||
model=model.provider_resource_id,
|
||||
content=content,
|
||||
sampling_params=sampling_params,
|
||||
response_format=response_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
)
|
||||
|
||||
response = await self.client.completions.create(**request)
|
||||
if stream:
|
||||
return convert_openai_completion_stream(response)
|
||||
else:
|
||||
# we pass n=1 to get only one completion
|
||||
return convert_openai_completion_choice(response.choices[0])
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
request = await convert_chat_completion_request(
|
||||
request=ChatCompletionRequest(
|
||||
model=model.provider_resource_id,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools or [],
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
response_format=response_format,
|
||||
tool_config=tool_config,
|
||||
),
|
||||
n=1,
|
||||
)
|
||||
s = await self.client.chat.completions.create(**request)
|
||||
if stream:
|
||||
return convert_openai_chat_completion_stream(s, enable_incremental_tool_calls=False)
|
||||
else:
|
||||
# we pass n=1 to get only one completion
|
||||
return convert_openai_chat_completion_choice(s.choices[0])
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: List[str] | List[InterleavedContentItem],
|
||||
text_truncation: Optional[TextTruncation] = TextTruncation.none,
|
||||
output_dimension: Optional[int] = None,
|
||||
task_type: Optional[EmbeddingTaskType] = None,
|
||||
) -> EmbeddingsResponse:
|
||||
flat_contents = [content.text if isinstance(content, TextContentItem) else content for content in contents]
|
||||
input = [content.text if isinstance(content, TextContentItem) else content for content in flat_contents]
|
||||
model = self.get_provider_model_id(model_id)
|
||||
|
||||
extra_body = {}
|
||||
|
||||
if text_truncation is not None:
|
||||
text_truncation_options = {
|
||||
TextTruncation.none: "NONE",
|
||||
TextTruncation.end: "END",
|
||||
TextTruncation.start: "START",
|
||||
}
|
||||
extra_body["truncate"] = text_truncation_options[text_truncation]
|
||||
|
||||
if output_dimension is not None:
|
||||
extra_body["dimensions"] = output_dimension
|
||||
|
||||
if task_type is not None:
|
||||
task_type_options = {
|
||||
EmbeddingTaskType.document: "passage",
|
||||
EmbeddingTaskType.query: "query",
|
||||
}
|
||||
extra_body["input_type"] = task_type_options[task_type]
|
||||
|
||||
try:
|
||||
response = await self._client.embeddings.create(
|
||||
model=model,
|
||||
input=input,
|
||||
extra_body=extra_body,
|
||||
)
|
||||
except BadRequestError as e:
|
||||
raise ValueError(f"Failed to get embeddings: {e}") from e
|
||||
|
||||
return EmbeddingsResponse(embeddings=[embedding.embedding for embedding in response.data])
|
||||
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
model = await self.register_helper.register_model(model)
|
||||
res = await self.client.models.list()
|
||||
available_models = [m.id async for m in res]
|
||||
if model.provider_resource_id not in available_models:
|
||||
raise ValueError(
|
||||
f"Model {model.provider_resource_id} is not being served by vLLM. "
|
||||
f"Available models: {', '.join(available_models)}"
|
||||
)
|
||||
return model
|
|
@ -260,6 +260,7 @@ exclude = [
|
|||
"^llama_stack/providers/remote/inference/nvidia/",
|
||||
"^llama_stack/providers/remote/inference/openai/",
|
||||
"^llama_stack/providers/remote/inference/passthrough/",
|
||||
"^llama_stack/providers/remote/inference/ramalama/",
|
||||
"^llama_stack/providers/remote/inference/runpod/",
|
||||
"^llama_stack/providers/remote/inference/sambanova/",
|
||||
"^llama_stack/providers/remote/inference/sample/",
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue