Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-28 19:04:19 +00:00)
chore: remove llama_models.llama3.api imports from providers (#1107)
There should be a choke-point for llama3.api imports -- this is the prompt adapter. Creating a ChatFormat() object on demand is inexpensive. The underlying Tokenizer is a singleton anyway.
This commit is contained in:
parent e9b8259cf9
commit cdcbeb005b
13 changed files with 77 additions and 113 deletions
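The pattern repeated across all 13 files is summarized in the sketch below: providers stop building a ChatFormat in __init__ and threading it through the openai_compat helpers, and the prompt adapter instead constructs one on demand around the tokenizer singleton. This is a minimal illustration, not the exact provider code -- the adapter class name is hypothetical, while ChatFormat, Tokenizer, and decode_assistant_message are taken from the diff.

# Sketch of the refactor; assumes the llama_models Llama 3 API used in the diff.
from llama_models.datatypes import StopReason
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer


class SomeInferenceAdapter:  # hypothetical provider, shown only for illustration
    def __init__(self) -> None:
        # Before: every provider imported llama3.api directly, kept a formatter on
        # self, and passed it into process_chat_completion_response() and friends.
        self.formatter = ChatFormat(Tokenizer.get_instance())


def decode_assistant_message(content: str, stop_reason: StopReason):
    # After: the prompt adapter is the single choke-point for llama3.api imports.
    # Building a ChatFormat here is cheap because Tokenizer.get_instance() is a singleton.
    formatter = ChatFormat(Tokenizer.get_instance())
    return formatter.decode_assistant_message_from_content(content, stop_reason)

Call sites in the openai_compat helpers then invoke decode_assistant_message(buffer, stop_reason) directly instead of receiving a formatter argument, which is exactly what the hunks below show.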
@@ -9,7 +9,6 @@ import os
 import uuid
 from typing import AsyncGenerator, List, Optional

-from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
@@ -62,7 +61,6 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
     def __init__(self, config: VLLMConfig):
         self.config = config
         self.engine = None
-        self.formatter = ChatFormat(Tokenizer.get_instance())

     async def initialize(self):
         log.info("Initializing vLLM inference provider.")
@@ -177,7 +175,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         log.info("Sampling params: %s", sampling_params)
         request_id = _random_uuid()

-        prompt = await chat_completion_request_to_prompt(request, self.config.model, self.formatter)
+        prompt = await chat_completion_request_to_prompt(request, self.config.model)
         vllm_sampling_params = self._sampling_params(request.sampling_params)
         results_generator = self.engine.generate(prompt, vllm_sampling_params, request_id)
         if stream:
@@ -201,11 +199,13 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         response = OpenAICompatCompletionResponse(
             choices=[choice],
         )
-        return process_chat_completion_response(response, self.formatter, request)
+        return process_chat_completion_response(response, request)

     async def _stream_chat_completion(
         self, request: ChatCompletionRequest, results_generator: AsyncGenerator
     ) -> AsyncGenerator:
+        tokenizer = Tokenizer.get_instance()
+
         async def _generate_and_convert_to_openai_compat():
             cur = []
             async for chunk in results_generator:
@@ -216,7 +216,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
                 output = chunk.outputs[-1]

                 new_tokens = output.token_ids[len(cur) :]
-                text = self.formatter.tokenizer.decode(new_tokens)
+                text = tokenizer.decode(new_tokens)
                 cur.extend(new_tokens)
                 choice = OpenAICompatCompletionChoice(
                     finish_reason=output.finish_reason,
@@ -227,7 +227,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
                 )

         stream = _generate_and_convert_to_openai_compat()
-        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
+        async for chunk in process_chat_completion_stream_response(stream, request):
             yield chunk

     async def embeddings(self, model_id: str, contents: List[InterleavedContent]) -> EmbeddingsResponse:
@@ -8,8 +8,6 @@ import json
 from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union

 from botocore.client import BaseClient
-from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.tokenizer import Tokenizer

 from llama_stack.apis.common.content_types import InterleavedContent
 from llama_stack.apis.inference import (
@@ -54,7 +52,6 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         self._config = config

         self._client = create_bedrock_client(config)
-        self.formatter = ChatFormat(Tokenizer.get_instance())

     @property
     def client(self) -> BaseClient:
@@ -119,7 +116,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         )

         response = OpenAICompatCompletionResponse(choices=[choice])
-        return process_chat_completion_response(response, self.formatter, request)
+        return process_chat_completion_response(response, request)

     async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
         params = await self._get_params_for_chat_completion(request)
@@ -137,7 +134,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
             yield OpenAICompatCompletionResponse(choices=[choice])

         stream = _generate_and_convert_to_openai_compat()
-        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
+        async for chunk in process_chat_completion_stream_response(stream, request):
             yield chunk

     async def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> Dict:
@@ -151,7 +148,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         if sampling_params.repetition_penalty > 0:
             options["repetition_penalty"] = sampling_params.repetition_penalty

-        prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model), self.formatter)
+        prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model))
         return {
             "modelId": bedrock_model,
             "body": json.dumps(
@@ -7,8 +7,6 @@
 from typing import AsyncGenerator, List, Optional, Union

 from cerebras.cloud.sdk import AsyncCerebras
-from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.tokenizer import Tokenizer

 from llama_stack.apis.common.content_types import InterleavedContent
 from llama_stack.apis.inference import (
@@ -53,7 +51,6 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
             model_aliases=model_aliases,
         )
         self.config = config
-        self.formatter = ChatFormat(Tokenizer.get_instance())

         self.client = AsyncCerebras(
             base_url=self.config.base_url,
@@ -96,14 +93,14 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):

         r = await self.client.completions.create(**params)

-        return process_completion_response(r, self.formatter)
+        return process_completion_response(r)

     async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)

         stream = await self.client.completions.create(**params)

-        async for chunk in process_completion_stream_response(stream, self.formatter):
+        async for chunk in process_completion_stream_response(stream):
             yield chunk

     async def chat_completion(
@@ -143,14 +140,14 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):

         r = await self.client.completions.create(**params)

-        return process_chat_completion_response(r, self.formatter, request)
+        return process_chat_completion_response(r, request)

     async def _stream_chat_completion(self, request: CompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)

         stream = await self.client.completions.create(**params)

-        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
+        async for chunk in process_chat_completion_stream_response(stream, request):
             yield chunk

     async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
@@ -159,11 +156,9 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):

         prompt = ""
         if isinstance(request, ChatCompletionRequest):
-            prompt = await chat_completion_request_to_prompt(
-                request, self.get_llama_model(request.model), self.formatter
-            )
+            prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model))
         elif isinstance(request, CompletionRequest):
-            prompt = await completion_request_to_prompt(request, self.formatter)
+            prompt = await completion_request_to_prompt(request)
         else:
             raise ValueError(f"Unknown request type {type(request)}")

@@ -6,8 +6,6 @@

 from typing import AsyncGenerator, List, Optional

-from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.tokenizer import Tokenizer
 from openai import OpenAI

 from llama_stack.apis.common.content_types import InterleavedContent
@@ -54,12 +52,8 @@ model_aliases = [

 class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
     def __init__(self, config: DatabricksImplConfig) -> None:
-        ModelRegistryHelper.__init__(
-            self,
-            model_aliases=model_aliases,
-        )
+        ModelRegistryHelper.__init__(self, model_aliases=model_aliases)
         self.config = config
-        self.formatter = ChatFormat(Tokenizer.get_instance())

     async def initialize(self) -> None:
         return
@@ -112,7 +106,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
     ) -> ChatCompletionResponse:
         params = self._get_params(request)
         r = client.completions.create(**params)
-        return process_chat_completion_response(r, self.formatter, request)
+        return process_chat_completion_response(r, request)

     async def _stream_chat_completion(self, request: ChatCompletionRequest, client: OpenAI) -> AsyncGenerator:
         params = self._get_params(request)
@@ -123,13 +117,13 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
             yield chunk

         stream = _to_async_generator()
-        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
+        async for chunk in process_chat_completion_stream_response(stream, request):
             yield chunk

     def _get_params(self, request: ChatCompletionRequest) -> dict:
         return {
             "model": request.model,
-            "prompt": chat_completion_request_to_prompt(request, self.get_llama_model(request.model), self.formatter),
+            "prompt": chat_completion_request_to_prompt(request, self.get_llama_model(request.model)),
             "stream": request.stream,
             **get_sampling_options(request.sampling_params),
         }
@@ -7,8 +7,6 @@
 from typing import AsyncGenerator, List, Optional, Union

 from fireworks.client import Fireworks
-from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.tokenizer import Tokenizer

 from llama_stack.apis.common.content_types import InterleavedContent
 from llama_stack.apis.inference import (
@@ -56,7 +54,6 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
     def __init__(self, config: FireworksImplConfig) -> None:
         ModelRegistryHelper.__init__(self, MODEL_ALIASES)
         self.config = config
-        self.formatter = ChatFormat(Tokenizer.get_instance())

     async def initialize(self) -> None:
         pass
@@ -105,7 +102,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
     async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
         params = await self._get_params(request)
         r = await self._get_client().completion.acreate(**params)
-        return process_completion_response(r, self.formatter)
+        return process_completion_response(r)

     async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
@@ -117,7 +114,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
             yield chunk

         stream = _to_async_generator()
-        async for chunk in process_completion_stream_response(stream, self.formatter):
+        async for chunk in process_completion_stream_response(stream):
             yield chunk

     def _build_options(
@@ -186,7 +183,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
             r = await self._get_client().chat.completions.acreate(**params)
         else:
             r = await self._get_client().completion.acreate(**params)
-        return process_chat_completion_response(r, self.formatter, request)
+        return process_chat_completion_response(r, request)

     async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
@@ -200,7 +197,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
             yield chunk

         stream = _to_async_generator()
-        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
+        async for chunk in process_chat_completion_stream_response(stream, request):
             yield chunk

     async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
@@ -214,11 +211,11 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
                 ]
             else:
                 input_dict["prompt"] = await chat_completion_request_to_prompt(
-                    request, self.get_llama_model(request.model), self.formatter
+                    request, self.get_llama_model(request.model)
                 )
         else:
             assert not media_present, "Fireworks does not support media for Completion requests"
-            input_dict["prompt"] = await completion_request_to_prompt(request, self.formatter)
+            input_dict["prompt"] = await completion_request_to_prompt(request)

         # Fireworks always prepends with BOS
         if "prompt" in input_dict:
@@ -8,8 +8,6 @@ import logging
 from typing import AsyncGenerator, List, Optional, Union

 import httpx
-from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.tokenizer import Tokenizer
 from ollama import AsyncClient

 from llama_stack.apis.common.content_types import (
@@ -138,7 +136,6 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
     def __init__(self, url: str) -> None:
         self.register_helper = ModelRegistryHelper(model_aliases)
         self.url = url
-        self.formatter = ChatFormat(Tokenizer.get_instance())

     @property
     def client(self) -> AsyncClient:
@@ -197,7 +194,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
                 )

         stream = _generate_and_convert_to_openai_compat()
-        async for chunk in process_completion_stream_response(stream, self.formatter):
+        async for chunk in process_completion_stream_response(stream):
             yield chunk

     async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator:
@@ -212,7 +209,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
             choices=[choice],
         )

-        return process_completion_response(response, self.formatter)
+        return process_completion_response(response)

     async def chat_completion(
         self,
@@ -262,11 +259,10 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
                 input_dict["prompt"] = await chat_completion_request_to_prompt(
                     request,
                     self.register_helper.get_llama_model(request.model),
-                    self.formatter,
                 )
         else:
             assert not media_present, "Ollama does not support media for Completion requests"
-            input_dict["prompt"] = await completion_request_to_prompt(request, self.formatter)
+            input_dict["prompt"] = await completion_request_to_prompt(request)
             input_dict["raw"] = True

         if fmt := request.response_format:
@@ -304,7 +300,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         response = OpenAICompatCompletionResponse(
             choices=[choice],
         )
-        return process_chat_completion_response(response, self.formatter, request)
+        return process_chat_completion_response(response, request)

     async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
@@ -330,7 +326,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
                 )

         stream = _generate_and_convert_to_openai_compat()
-        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
+        async for chunk in process_chat_completion_stream_response(stream, request):
             yield chunk

     async def embeddings(
@@ -5,8 +5,6 @@
 # the root directory of this source tree.
 from typing import AsyncGenerator

-from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.tokenizer import Tokenizer
 from openai import OpenAI

 from llama_stack.apis.inference import *  # noqa: F403
@@ -45,7 +43,6 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
     def __init__(self, config: RunpodImplConfig) -> None:
         ModelRegistryHelper.__init__(self, stack_to_provider_models_map=RUNPOD_SUPPORTED_MODELS)
         self.config = config
-        self.formatter = ChatFormat(Tokenizer.get_instance())

     async def initialize(self) -> None:
         return
@@ -56,7 +53,7 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
     async def completion(
         self,
         model: str,
-        content: InterleavedTextMedia,
+        content: InterleavedContent,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
@@ -97,7 +94,7 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
     ) -> ChatCompletionResponse:
         params = self._get_params(request)
         r = client.completions.create(**params)
-        return process_chat_completion_response(r, self.formatter, request)
+        return process_chat_completion_response(r, request)

     async def _stream_chat_completion(self, request: ChatCompletionRequest, client: OpenAI) -> AsyncGenerator:
         params = self._get_params(request)
@@ -108,13 +105,13 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
             yield chunk

         stream = _to_async_generator()
-        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
+        async for chunk in process_chat_completion_stream_response(stream, request):
             yield chunk

     def _get_params(self, request: ChatCompletionRequest) -> dict:
         return {
             "model": self.map_to_provider_model(request.model),
-            "prompt": chat_completion_request_to_prompt(request, self.formatter),
+            "prompt": chat_completion_request_to_prompt(request),
             "stream": request.stream,
             **get_sampling_options(request.sampling_params),
         }
@@ -122,6 +119,6 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
     async def embeddings(
         self,
         model: str,
-        contents: List[InterleavedTextMedia],
+        contents: List[InterleavedContent],
     ) -> EmbeddingsResponse:
         raise NotImplementedError()
@@ -7,8 +7,6 @@
 import json
 from typing import AsyncGenerator

-from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.tokenizer import Tokenizer
 from openai import OpenAI

 from llama_stack.apis.common.content_types import (
@@ -38,13 +36,8 @@ from .models import MODEL_ALIASES

 class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
     def __init__(self, config: SambaNovaImplConfig) -> None:
-        ModelRegistryHelper.__init__(
-            self,
-            model_aliases=MODEL_ALIASES,
-        )
-
+        ModelRegistryHelper.__init__(self, model_aliases=MODEL_ALIASES)
         self.config = config
-        self.formatter = ChatFormat(Tokenizer.get_instance())

     async def initialize(self) -> None:
         return
@@ -120,7 +113,7 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
             yield chunk

         stream = _to_async_generator()
-        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
+        async for chunk in process_chat_completion_stream_response(stream, request):
             yield chunk

     async def embeddings(
@@ -9,8 +9,6 @@ import logging
 from typing import AsyncGenerator, List, Optional

 from huggingface_hub import AsyncInferenceClient, HfApi
-from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.tokenizer import Tokenizer

 from llama_stack.apis.common.content_types import InterleavedContent
 from llama_stack.apis.inference import (
@@ -72,7 +70,6 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
     model_id: str

     def __init__(self) -> None:
-        self.formatter = ChatFormat(Tokenizer.get_instance())
         self.register_helper = ModelRegistryHelper(build_model_aliases())
         self.huggingface_repo_to_llama_model_id = {
             model.huggingface_repo: model.descriptor() for model in all_registered_models() if model.huggingface_repo
@@ -149,7 +146,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         return options

     async def _get_params_for_completion(self, request: CompletionRequest) -> dict:
-        prompt, input_tokens = await completion_request_to_prompt_model_input_info(request, self.formatter)
+        prompt, input_tokens = await completion_request_to_prompt_model_input_info(request)

         return dict(
             prompt=prompt,
@@ -177,7 +174,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
             )

         stream = _generate_and_convert_to_openai_compat()
-        async for chunk in process_completion_stream_response(stream, self.formatter):
+        async for chunk in process_completion_stream_response(stream):
             yield chunk

     async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator:
@@ -193,7 +190,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
             choices=[choice],
         )

-        return process_completion_response(response, self.formatter)
+        return process_completion_response(response)

     async def chat_completion(
         self,
@@ -236,7 +233,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         response = OpenAICompatCompletionResponse(
             choices=[choice],
         )
-        return process_chat_completion_response(response, self.formatter, request)
+        return process_chat_completion_response(response, request)

     async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
@@ -252,12 +249,12 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
             )

         stream = _generate_and_convert_to_openai_compat()
-        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
+        async for chunk in process_chat_completion_stream_response(stream, request):
             yield chunk

     async def _get_params(self, request: ChatCompletionRequest) -> dict:
         prompt, input_tokens = await chat_completion_request_to_model_input_info(
-            request, self.register_helper.get_llama_model(request.model), self.formatter
+            request, self.register_helper.get_llama_model(request.model)
         )
         return dict(
             prompt=prompt,
@@ -6,8 +6,6 @@

 from typing import AsyncGenerator, List, Optional, Union

-from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.tokenizer import Tokenizer
 from together import Together

 from llama_stack.apis.common.content_types import InterleavedContent
@@ -55,7 +53,6 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
     def __init__(self, config: TogetherImplConfig) -> None:
         ModelRegistryHelper.__init__(self, MODEL_ALIASES)
         self.config = config
-        self.formatter = ChatFormat(Tokenizer.get_instance())

     async def initialize(self) -> None:
         pass
@@ -102,7 +99,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
     async def _nonstream_completion(self, request: CompletionRequest) -> ChatCompletionResponse:
         params = await self._get_params(request)
         r = self._get_client().completions.create(**params)
-        return process_completion_response(r, self.formatter)
+        return process_completion_response(r)

     async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
@@ -114,7 +111,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
             yield chunk

         stream = _to_async_generator()
-        async for chunk in process_completion_stream_response(stream, self.formatter):
+        async for chunk in process_completion_stream_response(stream):
             yield chunk

     def _build_options(
@@ -180,7 +177,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
             r = self._get_client().chat.completions.create(**params)
         else:
             r = self._get_client().completions.create(**params)
-        return process_chat_completion_response(r, self.formatter, request)
+        return process_chat_completion_response(r, request)

     async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
@@ -195,7 +192,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
             yield chunk

         stream = _to_async_generator()
-        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
+        async for chunk in process_chat_completion_stream_response(stream, request):
             yield chunk

     async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
@@ -206,11 +203,11 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
                 input_dict["messages"] = [await convert_message_to_openai_dict(m) for m in request.messages]
             else:
                 input_dict["prompt"] = await chat_completion_request_to_prompt(
-                    request, self.get_llama_model(request.model), self.formatter
+                    request, self.get_llama_model(request.model)
                 )
         else:
             assert not media_present, "Together does not support media for Completion requests"
-            input_dict["prompt"] = await completion_request_to_prompt(request, self.formatter)
+            input_dict["prompt"] = await completion_request_to_prompt(request)

         return {
             "model": request.model,
@@ -8,8 +8,6 @@ import logging
 from typing import AsyncGenerator, List, Optional, Union

 from llama_models.datatypes import StopReason, ToolCall
-from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.tokenizer import Tokenizer
 from openai import OpenAI

 from llama_stack.apis.common.content_types import InterleavedContent, TextDelta, ToolCallDelta, ToolCallParseStatus
@@ -191,7 +189,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
         self.register_helper = ModelRegistryHelper(build_model_aliases())
         self.config = config
-        self.formatter = ChatFormat(Tokenizer.get_instance())
         self.client = None

     async def initialize(self) -> None:
@@ -286,14 +283,14 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         if len(request.tools) > 0:
             res = _process_vllm_chat_completion_stream_response(stream)
         else:
-            res = process_chat_completion_stream_response(stream, self.formatter, request)
+            res = process_chat_completion_stream_response(stream, request)
         async for chunk in res:
             yield chunk

     async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
         params = await self._get_params(request)
         r = self.client.completions.create(**params)
-        return process_completion_response(r, self.formatter)
+        return process_completion_response(r)

     async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
@@ -305,7 +302,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             yield chunk

         stream = _to_async_generator()
-        async for chunk in process_completion_stream_response(stream, self.formatter):
+        async for chunk in process_completion_stream_response(stream):
             yield chunk

     async def register_model(self, model: Model) -> Model:
@@ -332,10 +329,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             input_dict["messages"] = [await convert_message_to_openai_dict(m, download=True) for m in request.messages]
         else:
             assert not request_has_media(request), "vLLM does not support media for Completion requests"
-            input_dict["prompt"] = await completion_request_to_prompt(
-                request,
-                self.formatter,
-            )
+            input_dict["prompt"] = await completion_request_to_prompt(request)

         if fmt := request.response_format:
             if fmt.type == ResponseFormatType.json_schema.value:
@@ -7,7 +7,6 @@ import json
 import logging
 from typing import AsyncGenerator, Dict, List, Optional, Union

-from llama_models.llama3.api.chat_format import ChatFormat
 from openai.types.chat import ChatCompletionMessageToolCall
 from pydantic import BaseModel

@@ -40,6 +39,7 @@ from llama_stack.models.llama.datatypes import (
 )
 from llama_stack.providers.utils.inference.prompt_adapter import (
     convert_image_content_to_url,
+    decode_assistant_message,
 )

 logger = logging.getLogger(__name__)
@@ -149,7 +149,7 @@ def convert_openai_completion_logprobs_stream(text: str, logprobs: Optional[Unio
         return None


-def process_completion_response(response: OpenAICompatCompletionResponse, formatter: ChatFormat) -> CompletionResponse:
+def process_completion_response(response: OpenAICompatCompletionResponse) -> CompletionResponse:
     choice = response.choices[0]
     # drop suffix <eot_id> if present and return stop reason as end of turn
     if choice.text.endswith("<|eot_id|>"):
@@ -174,16 +174,13 @@ def process_completion_response(response: OpenAICompatCompletionResponse, format

 def process_chat_completion_response(
     response: OpenAICompatCompletionResponse,
-    formatter: ChatFormat,
     request: ChatCompletionRequest,
 ) -> ChatCompletionResponse:
     choice = response.choices[0]

     # TODO: This does not work well with tool calls for vLLM remote provider
     # Ref: https://github.com/meta-llama/llama-stack/issues/1058
-    raw_message = formatter.decode_assistant_message_from_content(
-        text_from_choice(choice), get_stop_reason(choice.finish_reason)
-    )
+    raw_message = decode_assistant_message(text_from_choice(choice), get_stop_reason(choice.finish_reason))

     # NOTE: If we do not set tools in chat-completion request, we should not
     # expect the ToolCall in the response. Instead, we should return the raw
@@ -217,7 +214,7 @@ def process_chat_completion_response(


 async def process_completion_stream_response(
-    stream: AsyncGenerator[OpenAICompatCompletionResponse, None], formatter: ChatFormat
+    stream: AsyncGenerator[OpenAICompatCompletionResponse, None],
 ) -> AsyncGenerator:
     stop_reason = None

@@ -254,7 +251,6 @@
 async def process_chat_completion_stream_response(
     stream: AsyncGenerator[OpenAICompatCompletionResponse, None],
-    formatter: ChatFormat,
     request: ChatCompletionRequest,
 ) -> AsyncGenerator:
     yield ChatCompletionResponseStreamChunk(
@@ -333,7 +329,7 @@ async def process_chat_completion_stream_response(
         )

     # parse tool calls and report errors
-    message = formatter.decode_assistant_message_from_content(buffer, stop_reason)
+    message = decode_assistant_message(buffer, stop_reason)

     parsed_tool_calls = len(message.tool_calls) > 0
     if ipython and not parsed_tool_calls:
@@ -13,7 +13,9 @@ import re
 from typing import List, Optional, Tuple, Union

 import httpx
+from llama_models.datatypes import StopReason
 from llama_models.llama3.api.chat_format import ChatFormat
+from llama_models.llama3.api.tokenizer import Tokenizer
 from PIL import Image as PIL_Image

 from llama_stack.apis.common.content_types import (
@@ -66,6 +68,11 @@ class CompletionRequestWithRawContent(CompletionRequest):
     content: RawContent


+def decode_assistant_message(content: str, stop_reason: StopReason) -> RawMessage:
+    formatter = ChatFormat(Tokenizer.get_instance())
+    return formatter.decode_assistant_message_from_content(content, stop_reason)
+
+
 def interleaved_content_as_str(content: InterleavedContent, sep: str = " ") -> str:
     def _process(c) -> str:
         if isinstance(c, str):
@@ -207,20 +214,22 @@ async def convert_image_content_to_url(
         return base64.b64encode(content).decode("utf-8")


-async def completion_request_to_prompt(request: CompletionRequest, formatter: ChatFormat) -> str:
+async def completion_request_to_prompt(request: CompletionRequest) -> str:
     content = augment_content_with_response_format_prompt(request.response_format, request.content)
     request.content = content
     request = await convert_request_to_raw(request)
+
+    formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
     model_input = formatter.encode_content(request.content)
     return formatter.tokenizer.decode(model_input.tokens)


-async def completion_request_to_prompt_model_input_info(
-    request: CompletionRequest, formatter: ChatFormat
-) -> Tuple[str, int]:
+async def completion_request_to_prompt_model_input_info(request: CompletionRequest) -> Tuple[str, int]:
     content = augment_content_with_response_format_prompt(request.response_format, request.content)
     request.content = content
     request = await convert_request_to_raw(request)
+
+    formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
     model_input = formatter.encode_content(request.content)
     return (formatter.tokenizer.decode(model_input.tokens), len(model_input.tokens))

@@ -237,22 +246,24 @@ def augment_content_with_response_format_prompt(response_format, content):
     return content


-async def chat_completion_request_to_prompt(
-    request: ChatCompletionRequest, llama_model: str, formatter: ChatFormat
-) -> str:
+async def chat_completion_request_to_prompt(request: ChatCompletionRequest, llama_model: str) -> str:
     messages = chat_completion_request_to_messages(request, llama_model)
     request.messages = messages
     request = await convert_request_to_raw(request)
+
+    formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
     model_input = formatter.encode_dialog_prompt(request.messages)
     return formatter.tokenizer.decode(model_input.tokens)


 async def chat_completion_request_to_model_input_info(
-    request: ChatCompletionRequest, llama_model: str, formatter: ChatFormat
+    request: ChatCompletionRequest, llama_model: str
 ) -> Tuple[str, int]:
     messages = chat_completion_request_to_messages(request, llama_model)
     request.messages = messages
     request = await convert_request_to_raw(request)
+
+    formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
     model_input = formatter.encode_dialog_prompt(request.messages)
     return (
         formatter.tokenizer.decode(model_input.tokens),