Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-29 21:34:18 +00:00)

Merge branch 'main' of https://github.com/meta-llama/llama-stack into register_custom_model

Commit 0990f60dad: 74 changed files with 4854 additions and 1869 deletions
@@ -178,6 +178,8 @@ class ChatAgent(ShieldRunnerMixin):
span.set_attribute("request", request.model_dump_json())
turn_id = str(uuid.uuid4())
span.set_attribute("turn_id", turn_id)
if self.agent_config.name:
    span.set_attribute("agent_name", self.agent_config.name)

await self._initialize_tools(request.toolgroups)
async for chunk in self._run_turn(request, turn_id):
@@ -190,6 +192,8 @@ class ChatAgent(ShieldRunnerMixin):
span.set_attribute("session_id", request.session_id)
span.set_attribute("request", request.model_dump_json())
span.set_attribute("turn_id", request.turn_id)
if self.agent_config.name:
    span.set_attribute("agent_name", self.agent_config.name)

await self._initialize_tools()
async for chunk in self._run_turn(request):
@@ -498,6 +502,8 @@ class ChatAgent(ShieldRunnerMixin):
stop_reason = None

async with tracing.span("inference") as span:
    if self.agent_config.name:
        span.set_attribute("agent_name", self.agent_config.name)
    async for chunk in await self.inference_api.chat_completion(
        self.agent_config.model,
        input_messages,
@@ -515,7 +515,8 @@ class MetaReferenceInferenceImpl(
stop_reason = None
ipython = False

for token_result in self.generator.chat_completion(request):
for token_results in self.generator.chat_completion([request]):
    token_result = token_results[0]
    if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1":
        cprint(token_result.text, "cyan", end="")
    if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2":
@@ -33,6 +33,7 @@ from llama_stack.apis.tools import (
)
from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
from llama_stack.providers.utils.memory.vector_store import (
    content_from_doc,
    make_overlapped_chunks,
@@ -153,6 +154,11 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
        )
    )
picked.append(TextContentItem(text="END of knowledge_search tool results.\n"))
picked.append(
    TextContentItem(
        text=f'The above results were retrieved to help answer the user\'s query: "{interleaved_content_as_str(content)}". Use them as supporting information only in answering this query.\n',
    )
)

return RAGQueryResult(
    content=picked,
@@ -362,6 +362,39 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
    user: Optional[str] = None,
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
    model_obj = await self.model_store.get_model(model)

    # Divert Llama Models through Llama Stack inference APIs because
    # Fireworks chat completions OpenAI-compatible API does not support
    # tool calls properly.
    llama_model = self.get_llama_model(model_obj.provider_resource_id)
    if llama_model:
        return await OpenAIChatCompletionToLlamaStackMixin.openai_chat_completion(
            self,
            model=model,
            messages=messages,
            frequency_penalty=frequency_penalty,
            function_call=function_call,
            functions=functions,
            logit_bias=logit_bias,
            logprobs=logprobs,
            max_completion_tokens=max_completion_tokens,
            max_tokens=max_tokens,
            n=n,
            parallel_tool_calls=parallel_tool_calls,
            presence_penalty=presence_penalty,
            response_format=response_format,
            seed=seed,
            stop=stop,
            stream=stream,
            stream_options=stream_options,
            temperature=temperature,
            tool_choice=tool_choice,
            tools=tools,
            top_logprobs=top_logprobs,
            top_p=top_p,
            user=user,
        )

    params = await prepare_openai_completion_params(
        messages=messages,
        frequency_penalty=frequency_penalty,
@@ -387,11 +420,4 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
        user=user,
    )

    # Divert Llama Models through Llama Stack inference APIs because
    # Fireworks chat completions OpenAI-compatible API does not support
    # tool calls properly.
    llama_model = self.get_llama_model(model_obj.provider_resource_id)
    if llama_model:
        return await OpenAIChatCompletionToLlamaStackMixin.openai_chat_completion(self, model=model, **params)

    return await self._get_openai_client().chat.completions.create(model=model_obj.provider_resource_id, **params)
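The two Fireworks hunks above move the tool-call divert so that it now runs before prepare_openai_completion_params instead of after it. A minimal sketch of the kind of request the divert targets, using only parameters from the openai_chat_completion signature shown above; the adapter instance and the tool definition are illustrative, not taken from the source:

```python
# Hypothetical usage sketch; `adapter` stands in for a configured FireworksInferenceAdapter.
response = await adapter.openai_chat_completion(
    model="meta-llama/Llama-3.1-8B-Instruct",
    messages=[{"role": "user", "content": "What's the weather in Paris?"}],
    tools=[
        {
            "type": "function",
            "function": {
                "name": "get_weather",  # illustrative tool, not defined in the diff
                "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
            },
        }
    ],
    stream=False,
)
# For a registered Llama model the request is served through
# OpenAIChatCompletionToLlamaStackMixin rather than Fireworks' OpenAI-compatible
# endpoint, which does not handle tool calls properly.
```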
llama_stack/providers/remote/inference/nvidia/NVIDIA.md (new file, 85 lines)
@@ -0,0 +1,85 @@
# NVIDIA Inference Provider for LlamaStack

This provider enables running inference using NVIDIA NIM.

## Features
- Endpoints for completions, chat completions, and embeddings for registered models

## Getting Started

### Prerequisites

- LlamaStack with NVIDIA configuration
- Access to NVIDIA NIM deployment
- A NIM for the model you want to use for inference is deployed

### Setup

Build the NVIDIA environment:

```bash
llama stack build --template nvidia --image-type conda
```

### Basic Usage using the LlamaStack Python Client

#### Initialize the client

```python
import os

os.environ["NVIDIA_API_KEY"] = (
    ""  # Required if using hosted NIM endpoint. If self-hosted, not required.
)
os.environ["NVIDIA_BASE_URL"] = "http://nim.test"  # NIM URL

from llama_stack.distribution.library_client import LlamaStackAsLibraryClient

client = LlamaStackAsLibraryClient("nvidia")
client.initialize()
```

### Create Completion

```python
response = client.completion(
    model_id="meta-llama/Llama-3.1-8b-Instruct",
    content="Complete the sentence using one word: Roses are red, violets are :",
    stream=False,
    sampling_params={
        "max_tokens": 50,
    },
)
print(f"Response: {response.content}")
```

### Create Chat Completion

```python
response = client.chat_completion(
    model_id="meta-llama/Llama-3.1-8b-Instruct",
    messages=[
        {
            "role": "system",
            "content": "You must respond to each message with only one word",
        },
        {
            "role": "user",
            "content": "Complete the sentence using one word: Roses are red, violets are:",
        },
    ],
    stream=False,
    sampling_params={
        "max_tokens": 50,
    },
)
print(f"Response: {response.completion_message.content}")
```

### Create Embeddings

```python
response = client.embeddings(
    model_id="meta-llama/Llama-3.1-8b-Instruct", contents=["foo", "bar", "baz"]
)
print(f"Embeddings: {response.embeddings}")
```
@@ -48,6 +48,10 @@ MODEL_ENTRIES = [
        "meta/llama-3.2-90b-vision-instruct",
        CoreModelId.llama3_2_90b_vision_instruct.value,
    ),
    build_hf_repo_model_entry(
        "meta/llama-3.3-70b-instruct",
        CoreModelId.llama3_3_70b_instruct.value,
    ),
    # NeMo Retriever Text Embedding models -
    #
    # https://docs.nvidia.com/nim/nemo-retriever/text-embedding/latest/support-matrix.html
@@ -129,6 +129,14 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
        base_url = special_model_urls[provider_model_id]
    return _get_client_for_base_url(base_url)

async def _get_provider_model_id(self, model_id: str) -> str:
    if not self.model_store:
        raise RuntimeError("Model store is not set")
    model = await self.model_store.get_model(model_id)
    if model is None:
        raise ValueError(f"Model {model_id} is unknown")
    return model.provider_model_id

async def completion(
    self,
    model_id: str,
@@ -147,7 +155,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
# removing this health check as NeMo customizer endpoint health check is returning 404
# await check_health(self._config)  # this raises errors

provider_model_id = self.get_provider_model_id(model_id)
provider_model_id = await self._get_provider_model_id(model_id)
request = convert_completion_request(
    request=CompletionRequest(
        model=provider_model_id,
@@ -191,7 +199,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
#
flat_contents = [content.text if isinstance(content, TextContentItem) else content for content in contents]
input = [content.text if isinstance(content, TextContentItem) else content for content in flat_contents]
model = self.get_provider_model_id(model_id)
provider_model_id = await self._get_provider_model_id(model_id)

extra_body = {}
@@ -214,8 +222,8 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
    extra_body["input_type"] = task_type_options[task_type]

try:
    response = await self._get_client(model).embeddings.create(
        model=model,
    response = await self._get_client(provider_model_id).embeddings.create(
        model=provider_model_id,
        input=input,
        extra_body=extra_body,
    )
@@ -249,10 +257,10 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):

# await check_health(self._config)  # this raises errors

provider_model_id = self.get_provider_model_id(model_id)
provider_model_id = await self._get_provider_model_id(model_id)
request = await convert_chat_completion_request(
    request=ChatCompletionRequest(
        model=self.get_provider_model_id(model_id),
        model=provider_model_id,
        messages=messages,
        sampling_params=sampling_params,
        response_format=response_format,
@@ -297,7 +305,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
    guided_choice: Optional[List[str]] = None,
    prompt_logprobs: Optional[int] = None,
) -> OpenAICompletion:
    provider_model_id = self.get_provider_model_id(model)
    provider_model_id = await self._get_provider_model_id(model)

    params = await prepare_openai_completion_params(
        model=provider_model_id,
@@ -350,7 +358,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
    top_p: Optional[float] = None,
    user: Optional[str] = None,
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
    provider_model_id = self.get_provider_model_id(model)
    provider_model_id = await self._get_provider_model_id(model)

    params = await prepare_openai_completion_params(
        model=provider_model_id,
@@ -76,8 +76,11 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi

async def shutdown(self) -> None:
    if self._client:
        await self._client.close()
        # Together client has no close method, so just set to None
        self._client = None
    if self._openai_client:
        await self._openai_client.close()
        self._openai_client = None

async def completion(
    self,
@@ -359,7 +362,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
    top_p=top_p,
    user=user,
)
if params.get("stream", True):
if params.get("stream", False):
    return self._stream_openai_chat_completion(params)
return await self._get_openai_client().chat.completions.create(**params)  # type: ignore
@@ -231,12 +231,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
    self.client = None

async def initialize(self) -> None:
    log.info(f"Initializing VLLM client with base_url={self.config.url}")
    self.client = AsyncOpenAI(
        base_url=self.config.url,
        api_key=self.config.api_token,
        http_client=None if self.config.tls_verify else httpx.AsyncClient(verify=False),
    )
    pass

async def shutdown(self) -> None:
    pass
@@ -249,6 +244,20 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
        raise ValueError("Model store not set")
    return await self.model_store.get_model(model_id)

def _lazy_initialize_client(self):
    if self.client is not None:
        return

    log.info(f"Initializing vLLM client with base_url={self.config.url}")
    self.client = self._create_client()

def _create_client(self):
    return AsyncOpenAI(
        base_url=self.config.url,
        api_key=self.config.api_token,
        http_client=None if self.config.tls_verify else httpx.AsyncClient(verify=False),
    )

async def completion(
    self,
    model_id: str,
@@ -258,6 +267,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
    stream: Optional[bool] = False,
    logprobs: Optional[LogProbConfig] = None,
) -> CompletionResponse | AsyncGenerator[CompletionResponseStreamChunk, None]:
    self._lazy_initialize_client()
    if sampling_params is None:
        sampling_params = SamplingParams()
    model = await self._get_model(model_id)
@@ -287,6 +297,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
    logprobs: Optional[LogProbConfig] = None,
    tool_config: Optional[ToolConfig] = None,
) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
    self._lazy_initialize_client()
    if sampling_params is None:
        sampling_params = SamplingParams()
    model = await self._get_model(model_id)
@@ -357,9 +368,12 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
        yield chunk

async def register_model(self, model: Model) -> Model:
    assert self.client is not None
    # register_model is called during Llama Stack initialization, hence we cannot init self.client if not initialized yet.
    # self.client should only be created after the initialization is complete to avoid asyncio cross-context errors.
    # Changing this may lead to unpredictable behavior.
    client = self._create_client() if self.client is None else self.client
    model = await self.register_helper.register_model(model)
    res = await self.client.models.list()
    res = await client.models.list()
    available_models = [m.id async for m in res]
    if model.provider_resource_id not in available_models:
        raise ValueError(
@@ -410,6 +424,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
    output_dimension: Optional[int] = None,
    task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
    self._lazy_initialize_client()
    assert self.client is not None
    model = await self._get_model(model_id)
@@ -449,6 +464,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
    guided_choice: Optional[List[str]] = None,
    prompt_logprobs: Optional[int] = None,
) -> OpenAICompletion:
    self._lazy_initialize_client()
    model_obj = await self._get_model(model)

    extra_body: Dict[str, Any] = {}
@@ -505,6 +521,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
    top_p: Optional[float] = None,
    user: Optional[str] = None,
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
    self._lazy_initialize_client()
    model_obj = await self._get_model(model)
    params = await prepare_openai_completion_params(
        model=model_obj.provider_resource_id,
@@ -16,7 +16,11 @@ _MODEL_ENTRIES = [
    build_hf_repo_model_entry(
        "meta/llama-3.1-8b-instruct",
        CoreModelId.llama3_1_8b_instruct.value,
    )
    ),
    build_hf_repo_model_entry(
        "meta/llama-3.2-1b-instruct",
        CoreModelId.llama3_2_1b_instruct.value,
    ),
]
@@ -27,11 +27,12 @@ from .models import _MODEL_ENTRIES

# Map API status to JobStatus enum
STATUS_MAPPING = {
    "running": "in_progress",
    "completed": "completed",
    "failed": "failed",
    "cancelled": "cancelled",
    "pending": "scheduled",
    "running": JobStatus.in_progress.value,
    "completed": JobStatus.completed.value,
    "failed": JobStatus.failed.value,
    "cancelled": JobStatus.cancelled.value,
    "pending": JobStatus.scheduled.value,
    "unknown": JobStatus.scheduled.value,
}
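The mapping now resolves to JobStatus enum values instead of bare strings. A minimal sketch of how such a mapping is typically applied when translating a provider-reported status, assuming JobStatus is already imported in this module as the diff implies; the helper name is illustrative, not from the source:

```python
# Hypothetical helper, for illustration only.
def _map_api_status(api_status: str) -> JobStatus:
    # Statuses the API reports but the mapping does not list fall back to "scheduled",
    # mirroring the "unknown" entry above.
    return JobStatus(STATUS_MAPPING.get(api_status, JobStatus.scheduled.value))


# _map_api_status("running") -> JobStatus.in_progress
```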
llama_stack/providers/remote/safety/nvidia/README.md (new file, 77 lines)
@@ -0,0 +1,77 @@
# NVIDIA Safety Provider for LlamaStack

This provider enables safety checks and guardrails for LLM interactions using NVIDIA's NeMo Guardrails service.

## Features

- Run safety checks for messages

## Getting Started

### Prerequisites

- LlamaStack with NVIDIA configuration
- Access to NVIDIA NeMo Guardrails service
- A NIM for the model used for safety checks is deployed

### Setup

Build the NVIDIA environment:

```bash
llama stack build --template nvidia --image-type conda
```

### Basic Usage using the LlamaStack Python Client

#### Initialize the client

```python
import os

os.environ["NVIDIA_API_KEY"] = "your-api-key"
os.environ["NVIDIA_GUARDRAILS_URL"] = "http://guardrails.test"

from llama_stack.distribution.library_client import LlamaStackAsLibraryClient

client = LlamaStackAsLibraryClient("nvidia")
client.initialize()
```

#### Create a safety shield

```python
from llama_stack.apis.safety import Shield
from llama_stack.apis.inference import Message

# Create a safety shield
shield = Shield(
    shield_id="your-shield-id",
    provider_resource_id="safety-model-id",  # The model to use for safety checks
    description="Safety checks for content moderation",
)

# Register the shield
await client.safety.register_shield(shield)
```

#### Run safety checks

```python
# Messages to check
messages = [Message(role="user", content="Your message to check")]

# Run safety check
response = await client.safety.run_shield(
    shield_id="your-shield-id",
    messages=messages,
)

# Check for violations
if response.violation:
    print(f"Safety violation detected: {response.violation.user_message}")
    print(f"Violation level: {response.violation.violation_level}")
    print(f"Metadata: {response.violation.metadata}")
else:
    print("No safety violations detected")
```
@@ -8,7 +8,17 @@ import logging
import time
import uuid
import warnings
from typing import Any, AsyncGenerator, AsyncIterator, Awaitable, Dict, Iterable, List, Optional, Union
from typing import (
    Any,
    AsyncGenerator,
    AsyncIterator,
    Awaitable,
    Dict,
    Iterable,
    List,
    Optional,
    Union,
)

from openai import AsyncStream
from openai.types.chat import (
@@ -78,6 +88,7 @@ from llama_stack.apis.common.content_types import (
    TextDelta,
    ToolCallDelta,
    ToolCallParseStatus,
    _URLOrData,
)
from llama_stack.apis.inference import (
    ChatCompletionRequest,
@@ -93,6 +104,7 @@ from llama_stack.apis.inference import (
    SamplingParams,
    SystemMessage,
    TokenLogProbs,
    ToolChoice,
    ToolResponseMessage,
    TopKSamplingStrategy,
    TopPSamplingStrategy,
@@ -103,7 +115,6 @@ from llama_stack.apis.inference.inference import (
    OpenAIChatCompletion,
    OpenAICompletion,
    OpenAICompletionChoice,
    OpenAIMessageParam,
    OpenAIResponseFormatParam,
    ToolConfig,
)
@@ -612,13 +623,10 @@ async def convert_message_to_openai_dict_new(
            )
            for tool in message.tool_calls
        ]
        params = {}
        if tool_calls:
            params = {"tool_calls": tool_calls}
        out = OpenAIChatCompletionAssistantMessage(
            role="assistant",
            content=await _convert_message_content(message.content),
            **params,
            tool_calls=tool_calls or None,
        )
    elif isinstance(message, ToolResponseMessage):
        out = OpenAIChatCompletionToolMessage(
@@ -695,7 +703,10 @@ def to_openai_param_type(param_type: str) -> dict:
    if param_type.startswith("list[") and param_type.endswith("]"):
        inner_type = param_type[5:-1]
        if inner_type in basic_types:
            return {"type": "array", "items": {"type": basic_types.get(inner_type, inner_type)}}
            return {
                "type": "array",
                "items": {"type": basic_types.get(inner_type, inner_type)},
            }

    return {"type": param_type}
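The reformatted branch above changes only layout. For reference, a sketch of the behavior it preserves; the concrete output assumes the module's basic_types mapping contains an entry such as "str" mapped to "string", which is not shown in this diff:

```python
# Illustrative only; the exact items type depends on basic_types.
to_openai_param_type("list[str]")
# -> {"type": "array", "items": {"type": "string"}}  (assuming basic_types["str"] == "string")
```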
@@ -815,6 +826,10 @@ def _convert_openai_finish_reason(finish_reason: str) -> StopReason:
def _convert_openai_request_tool_config(tool_choice: Optional[Union[str, Dict[str, Any]]] = None) -> ToolConfig:
    tool_config = ToolConfig()
    if tool_choice:
        try:
            tool_choice = ToolChoice(tool_choice)
        except ValueError:
            pass
        tool_config.tool_choice = tool_choice
    return tool_config
@@ -849,7 +864,9 @@ def _convert_openai_request_tools(tools: Optional[List[Dict[str, Any]]] = None)

def _convert_openai_request_response_format(response_format: OpenAIResponseFormatParam = None):
def _convert_openai_request_response_format(
    response_format: OpenAIResponseFormatParam = None,
):
    if not response_format:
        return None
    # response_format can be a dict or a pydantic model
@@ -957,38 +974,50 @@ def _convert_openai_sampling_params(
    return sampling_params


def _convert_openai_request_messages(messages: List[OpenAIMessageParam]):
    # Llama Stack messages and OpenAI messages are similar, but not identical.
    lls_messages = []
def openai_messages_to_messages(
    messages: List[OpenAIChatCompletionMessage],
) -> List[Message]:
    """
    Convert a list of OpenAIChatCompletionMessage into a list of Message.
    """
    converted_messages = []
    for message in messages:
        lls_message = dict(message)
        if message.role == "system":
            converted_message = SystemMessage(content=message.content)
        elif message.role == "user":
            converted_message = UserMessage(content=openai_content_to_content(message.content))
        elif message.role == "assistant":
            converted_message = CompletionMessage(
                content=message.content,
                tool_calls=_convert_openai_tool_calls(message.tool_calls),
                stop_reason=StopReason.end_of_turn,
            )
        elif message.role == "tool":
            converted_message = ToolResponseMessage(
                role="tool",
                call_id=message.tool_call_id,
                content=openai_content_to_content(message.content),
            )
        else:
            raise ValueError(f"Unknown role {message.role}")
        converted_messages.append(converted_message)
    return converted_messages

        # Llama Stack expects `call_id` but OpenAI uses `tool_call_id`
        tool_call_id = lls_message.pop("tool_call_id", None)
        if tool_call_id:
            lls_message["call_id"] = tool_call_id

        content = lls_message.get("content", None)
        if isinstance(content, list):
            lls_content = []
            for item in content:
                # items can either by pydantic models or dicts here...
                item = dict(item)
                if item.get("type", "") == "image_url":
                    lls_item = ImageContentItem(
                        type="image",
                        image=URL(uri=item.get("image_url", {}).get("url", "")),
                    )
                elif item.get("type", "") == "text":
                    lls_item = TextContentItem(
                        type="text",
                        text=item.get("text", ""),
                    )
                lls_content.append(lls_item)
            lls_message["content"] = lls_content
        lls_messages.append(lls_message)

    return lls_messages

def openai_content_to_content(content: Union[str, Iterable[OpenAIChatCompletionContentPartParam]]):
    if isinstance(content, str):
        return content
    elif isinstance(content, list):
        return [openai_content_to_content(c) for c in content]
    elif hasattr(content, "type"):
        if content.type == "text":
            return TextContentItem(type="text", text=content.text)
        elif content.type == "image_url":
            return ImageContentItem(type="image", image=_URLOrData(url=URL(uri=content.image_url.url)))
        else:
            raise ValueError(f"Unknown content type: {content.type}")
    else:
        raise ValueError(f"Unknown content type: {content}")


def convert_openai_chat_completion_choice(
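For quick reference, a summary of the role handling that openai_messages_to_messages implements, restating only the branches shown above:

```python
# Role mapping performed by openai_messages_to_messages:
#   "system"    -> SystemMessage
#   "user"      -> UserMessage         (content run through openai_content_to_content)
#   "assistant" -> CompletionMessage   (tool_calls converted, stop_reason=end_of_turn)
#   "tool"      -> ToolResponseMessage (OpenAI tool_call_id becomes Llama Stack call_id)
# Any other role raises ValueError.
```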
@@ -1313,7 +1342,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
    top_p: Optional[float] = None,
    user: Optional[str] = None,
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
    messages = _convert_openai_request_messages(messages)
    messages = openai_messages_to_messages(messages)
    response_format = _convert_openai_request_response_format(response_format)
    sampling_params = _convert_openai_sampling_params(
        max_tokens=max_tokens,
@@ -1321,7 +1350,10 @@ class OpenAIChatCompletionToLlamaStackMixin:
        top_p=top_p,
    )
    tool_config = _convert_openai_request_tool_config(tool_choice)

    tools = _convert_openai_request_tools(tools)
    if tool_config.tool_choice == ToolChoice.none:
        tools = []

    outstanding_responses = []
    # "n" is the number of completions to generate per prompt
@@ -1346,7 +1378,9 @@ class OpenAIChatCompletionToLlamaStackMixin:
    )

async def _process_stream_response(
    self, model: str, outstanding_responses: List[Awaitable[AsyncIterator[ChatCompletionResponseStreamChunk]]]
    self,
    model: str,
    outstanding_responses: List[Awaitable[AsyncIterator[ChatCompletionResponseStreamChunk]]],
):
    id = f"chatcmpl-{uuid.uuid4()}"
    for outstanding_response in outstanding_responses:
@@ -1369,11 +1403,31 @@ class OpenAIChatCompletionToLlamaStackMixin:
elif isinstance(event.delta, ToolCallDelta):
    if event.delta.parse_status == ToolCallParseStatus.succeeded:
        tool_call = event.delta.tool_call

        # First chunk includes full structure
        openai_tool_call = OpenAIChoiceDeltaToolCall(
            index=0,
            id=tool_call.call_id,
            function=OpenAIChoiceDeltaToolCallFunction(
                name=tool_call.tool_name, arguments=tool_call.arguments_json
                name=tool_call.tool_name,
                arguments="",
            ),
        )
        delta = OpenAIChoiceDelta(tool_calls=[openai_tool_call])
        yield OpenAIChatCompletionChunk(
            id=id,
            choices=[
                OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)
            ],
            created=int(time.time()),
            model=model,
            object="chat.completion.chunk",
        )
        # arguments
        openai_tool_call = OpenAIChoiceDeltaToolCall(
            index=0,
            function=OpenAIChoiceDeltaToolCallFunction(
                arguments=tool_call.arguments_json,
            ),
        )
        delta = OpenAIChoiceDelta(tool_calls=[openai_tool_call])
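The hunk above splits a completed tool call into two OpenAI-style chunks: the first carries the call id and function name with empty arguments, the second carries only the serialized arguments. A minimal sketch of how a consumer of the stream can stitch the pieces back together; the loop and variable names are illustrative, not part of the diff:

```python
# Hypothetical client-side reassembly of streamed tool calls.
calls: dict[int, dict] = {}  # keyed by the tool call's index

async for chunk in stream:  # stream: AsyncIterator[OpenAIChatCompletionChunk]
    for tc in chunk.choices[0].delta.tool_calls or []:
        entry = calls.setdefault(tc.index, {"id": None, "name": None, "arguments": ""})
        if tc.id:  # first chunk: id and name, empty arguments
            entry["id"] = tc.id
        if tc.function and tc.function.name:
            entry["name"] = tc.function.name
        if tc.function and tc.function.arguments:  # follow-up chunk: argument text
            entry["arguments"] += tc.function.arguments

# After the stream ends each entry holds a full call, e.g.
# {"id": "...", "name": "get_weather", "arguments": '{"city": "Paris"}'}
```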