mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-07 02:58:21 +00:00
added new models and updated chat_completin logic to follow the format
This commit is contained in:
parent
44c51efc55
commit
e7b7b102cf
4 changed files with 233 additions and 46 deletions
|
@ -8,7 +8,7 @@ import os
|
||||||
from typing import Optional, Dict, Any
|
from typing import Optional, Dict, Any
|
||||||
|
|
||||||
from llama_stack.schema_utils import json_schema_type
|
from llama_stack.schema_utils import json_schema_type
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field, SecretStr
|
||||||
|
|
||||||
|
|
||||||
class WatsonXProviderDataValidator(BaseModel):
|
class WatsonXProviderDataValidator(BaseModel):
|
||||||
|
@ -24,7 +24,7 @@ class WatsonXConfig(BaseModel):
|
||||||
default_factory=lambda: os.getenv("WATSONX_BASE_URL", "https://us-south.ml.cloud.ibm.com"),
|
default_factory=lambda: os.getenv("WATSONX_BASE_URL", "https://us-south.ml.cloud.ibm.com"),
|
||||||
description="A base url for accessing the Watsonx.ai",
|
description="A base url for accessing the Watsonx.ai",
|
||||||
)
|
)
|
||||||
api_key: Optional[str] = Field(
|
api_key: Optional[SecretStr] = Field(
|
||||||
default_factory=lambda: os.getenv("WATSONX_API_KEY"),
|
default_factory=lambda: os.getenv("WATSONX_API_KEY"),
|
||||||
description="The Watsonx API key, only needed of using the hosted service",
|
description="The Watsonx API key, only needed of using the hosted service",
|
||||||
)
|
)
|
||||||
|
|
|
@ -11,6 +11,22 @@ MODEL_ENTRIES = [
|
||||||
build_hf_repo_model_entry(
|
build_hf_repo_model_entry(
|
||||||
"meta-llama/llama-3-3-70b-instruct",
|
"meta-llama/llama-3-3-70b-instruct",
|
||||||
CoreModelId.llama3_3_70b_instruct.value,
|
CoreModelId.llama3_3_70b_instruct.value,
|
||||||
|
),
|
||||||
|
build_hf_repo_model_entry(
|
||||||
|
"meta-llama/llama-2-13b-chat",
|
||||||
|
CoreModelId.llama2_13b.value,
|
||||||
|
),
|
||||||
|
build_hf_repo_model_entry(
|
||||||
|
"meta-llama/llama-3-1-70b-instruct",
|
||||||
|
CoreModelId.llama3_1_70b_instruct.value,
|
||||||
|
),
|
||||||
|
build_hf_repo_model_entry(
|
||||||
|
"meta-llama/llama-3-1-8b-instruct",
|
||||||
|
CoreModelId.llama3_1_8b_instruct.value,
|
||||||
|
),
|
||||||
|
build_hf_repo_model_entry(
|
||||||
|
"meta-llama/llama-3-2-11b-vision-instruct",
|
||||||
|
CoreModelId.llama3_2_11b_vision_instruct.value,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
@ -4,13 +4,44 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from typing import List, Optional, Union, AsyncIterator
|
from typing import AsyncGenerator, List, Optional, Union
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
|
from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
|
||||||
from llama_stack.apis.inference import Inference, Message, ToolChoice, ResponseFormat, LogProbConfig, ToolConfig, \
|
|
||||||
ChatCompletionResponse, ChatCompletionResponseStreamChunk, EmbeddingsResponse, TextTruncation, EmbeddingTaskType
|
|
||||||
from llama_stack.models.llama.datatypes import SamplingParams, ToolDefinition, ToolPromptFormat
|
from llama_stack.models.llama.datatypes import SamplingParams, ToolDefinition, ToolPromptFormat
|
||||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||||
|
from llama_stack.apis.inference import (
|
||||||
|
ChatCompletionRequest,
|
||||||
|
ChatCompletionResponse,
|
||||||
|
CompletionRequest,
|
||||||
|
EmbeddingsResponse,
|
||||||
|
EmbeddingTaskType,
|
||||||
|
Inference,
|
||||||
|
LogProbConfig,
|
||||||
|
Message,
|
||||||
|
ResponseFormat,
|
||||||
|
ResponseFormatType,
|
||||||
|
SamplingParams,
|
||||||
|
TextTruncation,
|
||||||
|
ToolChoice,
|
||||||
|
ToolConfig,
|
||||||
|
ToolDefinition,
|
||||||
|
ToolPromptFormat,
|
||||||
|
)
|
||||||
|
from llama_stack.providers.utils.inference.openai_compat import (
|
||||||
|
OpenAICompatCompletionChoice,
|
||||||
|
OpenAICompatCompletionResponse,
|
||||||
|
convert_message_to_openai_dict,
|
||||||
|
get_sampling_options,
|
||||||
|
process_chat_completion_response,
|
||||||
|
process_chat_completion_stream_response,
|
||||||
|
process_completion_response,
|
||||||
|
process_completion_stream_response,
|
||||||
|
)
|
||||||
|
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||||
|
chat_completion_request_to_prompt,
|
||||||
|
completion_request_to_prompt,
|
||||||
|
request_has_media,
|
||||||
|
)
|
||||||
|
|
||||||
from . import WatsonXConfig
|
from . import WatsonXConfig
|
||||||
|
|
||||||
|
@ -28,10 +59,6 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
|
||||||
print(f"Initializing WatsonXInferenceAdapter({config.url})...")
|
print(f"Initializing WatsonXInferenceAdapter({config.url})...")
|
||||||
|
|
||||||
self._config = config
|
self._config = config
|
||||||
self._credential = {
|
|
||||||
"url": self._config.url,
|
|
||||||
"apikey": self._config.api_key
|
|
||||||
}
|
|
||||||
|
|
||||||
self._project_id = self._config.project_id
|
self._project_id = self._config.project_id
|
||||||
self.params = {
|
self.params = {
|
||||||
|
@ -39,49 +66,173 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
|
||||||
GenParams.STOP_SEQUENCES: ["<|endoftext|>"]
|
GenParams.STOP_SEQUENCES: ["<|endoftext|>"]
|
||||||
}
|
}
|
||||||
|
|
||||||
async def completion(
|
async def initialize(self) -> None:
|
||||||
self,
|
|
||||||
model_id: str,
|
|
||||||
content: InterleavedContent,
|
|
||||||
sampling_params: Optional[SamplingParams] = None,
|
|
||||||
response_format: Optional[ResponseFormat] = None,
|
|
||||||
stream: Optional[bool] = False,
|
|
||||||
logprobs: Optional[LogProbConfig] = None,
|
|
||||||
):
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def embeddings(
|
async def shutdown(self) -> None:
|
||||||
self,
|
|
||||||
model_id: str,
|
|
||||||
contents: List[str] | List[InterleavedContentItem],
|
|
||||||
text_truncation: Optional[TextTruncation] = TextTruncation.none,
|
|
||||||
output_dimension: Optional[int] = None,
|
|
||||||
task_type: Optional[EmbeddingTaskType] = None,
|
|
||||||
) -> EmbeddingsResponse:
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
async def completion(
|
||||||
|
self,
|
||||||
|
model_id: str,
|
||||||
|
content: InterleavedContent,
|
||||||
|
sampling_params: Optional[SamplingParams] = None,
|
||||||
|
response_format: Optional[ResponseFormat] = None,
|
||||||
|
stream: Optional[bool] = False,
|
||||||
|
logprobs: Optional[LogProbConfig] = None,
|
||||||
|
) -> AsyncGenerator:
|
||||||
|
if sampling_params is None:
|
||||||
|
sampling_params = SamplingParams()
|
||||||
|
model = await self.model_store.get_model(model_id)
|
||||||
|
request = CompletionRequest(
|
||||||
|
model=model.provider_resource_id,
|
||||||
|
content=content,
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
response_format=response_format,
|
||||||
|
stream=stream,
|
||||||
|
logprobs=logprobs,
|
||||||
|
)
|
||||||
|
if stream:
|
||||||
|
return self._stream_completion(request)
|
||||||
|
else:
|
||||||
|
return await self._nonstream_completion(request)
|
||||||
|
|
||||||
|
def _get_client(self, model_id) -> Model:
|
||||||
|
config_api_key = self._config.api_key.get_secret_value() if self._config.api_key else None
|
||||||
|
config_url = self._config.url
|
||||||
|
project_id = self._config.project_id
|
||||||
|
credentials = {
|
||||||
|
"url": config_url,
|
||||||
|
"apikey": config_api_key
|
||||||
|
}
|
||||||
|
|
||||||
|
return Model(model_id=model_id,credentials=credentials, project_id=project_id)
|
||||||
|
|
||||||
|
async def _nonstream_completion(self, request: CompletionRequest) -> ChatCompletionResponse:
|
||||||
|
params = await self._get_params(request)
|
||||||
|
r = self._get_client(request.model).generate(**params)
|
||||||
|
choices = []
|
||||||
|
if "results" in r:
|
||||||
|
for result in r["results"]:
|
||||||
|
choice = OpenAICompatCompletionChoice(
|
||||||
|
finish_reason=result["stop_reason"] if result["stop_reason"] else None,
|
||||||
|
text=result["generated_text"],
|
||||||
|
)
|
||||||
|
choices.append(choice)
|
||||||
|
response = OpenAICompatCompletionResponse(
|
||||||
|
choices=choices,
|
||||||
|
)
|
||||||
|
return process_completion_response(response)
|
||||||
|
|
||||||
|
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||||
|
params = await self._get_params(request)
|
||||||
|
|
||||||
|
async def _generate_and_convert_to_openai_compat():
|
||||||
|
s = self._get_client(request.model).generate_text_stream(**params)
|
||||||
|
for chunk in s:
|
||||||
|
choice = OpenAICompatCompletionChoice(
|
||||||
|
finish_reason=None,
|
||||||
|
text=chunk,
|
||||||
|
)
|
||||||
|
yield OpenAICompatCompletionResponse(
|
||||||
|
choices=[choice],
|
||||||
|
)
|
||||||
|
|
||||||
|
stream = _generate_and_convert_to_openai_compat()
|
||||||
|
async for chunk in process_completion_stream_response(stream):
|
||||||
|
yield chunk
|
||||||
|
|
||||||
async def chat_completion(
|
async def chat_completion(
|
||||||
self,
|
self,
|
||||||
model_id: str,
|
model_id: str,
|
||||||
messages: List[Message],
|
messages: List[Message],
|
||||||
sampling_params: Optional[SamplingParams] = None,
|
sampling_params: Optional[SamplingParams] = None,
|
||||||
response_format: Optional[ResponseFormat] = None,
|
tools: Optional[List[ToolDefinition]] = None,
|
||||||
tools: Optional[List[ToolDefinition]] = None,
|
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
response_format: Optional[ResponseFormat] = None,
|
||||||
stream: Optional[bool] = False,
|
stream: Optional[bool] = False,
|
||||||
logprobs: Optional[LogProbConfig] = None,
|
logprobs: Optional[LogProbConfig] = None,
|
||||||
tool_config: Optional[ToolConfig] = None,
|
tool_config: Optional[ToolConfig] = None,
|
||||||
):
|
) -> AsyncGenerator:
|
||||||
# Language model
|
if sampling_params is None:
|
||||||
model = Model(
|
sampling_params = SamplingParams()
|
||||||
model_id=model_id,
|
model = await self.model_store.get_model(model_id)
|
||||||
credentials=self._credential,
|
request = ChatCompletionRequest(
|
||||||
project_id=self._project_id,
|
model=model.provider_resource_id,
|
||||||
|
messages=messages,
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
tools=tools or [],
|
||||||
|
response_format=response_format,
|
||||||
|
stream=stream,
|
||||||
|
logprobs=logprobs,
|
||||||
|
tool_config=tool_config,
|
||||||
)
|
)
|
||||||
prompt = "\n".join(messages) + "\nAI: "
|
|
||||||
|
|
||||||
response = model.generate_text(prompt=prompt, params=self.params)
|
if stream:
|
||||||
|
return self._stream_chat_completion(request)
|
||||||
|
else:
|
||||||
|
return await self._nonstream_chat_completion(request)
|
||||||
|
|
||||||
return response
|
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
|
||||||
|
params = await self._get_params(request)
|
||||||
|
r = self._get_client(request.model).generate(**params)
|
||||||
|
choices = []
|
||||||
|
if "results" in r:
|
||||||
|
for result in r["results"]:
|
||||||
|
choice = OpenAICompatCompletionChoice(
|
||||||
|
finish_reason=result["stop_reason"] if result["stop_reason"] else None,
|
||||||
|
text=result["generated_text"],
|
||||||
|
)
|
||||||
|
choices.append(choice)
|
||||||
|
response = OpenAICompatCompletionResponse(
|
||||||
|
choices=choices,
|
||||||
|
)
|
||||||
|
return process_chat_completion_response(response, request)
|
||||||
|
|
||||||
|
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
|
||||||
|
params = await self._get_params(request)
|
||||||
|
model_id = request.model
|
||||||
|
# if we shift to TogetherAsyncClient, we won't need this wrapper
|
||||||
|
async def _to_async_generator():
|
||||||
|
s = self._get_client(model_id).generate_text_stream(**params)
|
||||||
|
for chunk in s:
|
||||||
|
choice = OpenAICompatCompletionChoice(
|
||||||
|
finish_reason=None,
|
||||||
|
text=chunk,
|
||||||
|
)
|
||||||
|
yield OpenAICompatCompletionResponse(
|
||||||
|
choices=[choice],
|
||||||
|
)
|
||||||
|
|
||||||
|
stream = _to_async_generator()
|
||||||
|
async for chunk in process_chat_completion_stream_response(stream, request):
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
|
||||||
|
input_dict = {}
|
||||||
|
media_present = request_has_media(request)
|
||||||
|
llama_model = self.get_llama_model(request.model)
|
||||||
|
if isinstance(request, ChatCompletionRequest):
|
||||||
|
if media_present or not llama_model:
|
||||||
|
input_dict["messages"] = [await convert_message_to_openai_dict(m) for m in request.messages]
|
||||||
|
else:
|
||||||
|
input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
|
||||||
|
else:
|
||||||
|
assert not media_present, "Together does not support media for Completion requests"
|
||||||
|
input_dict["prompt"] = await completion_request_to_prompt(request)
|
||||||
|
|
||||||
|
params = {
|
||||||
|
**input_dict,
|
||||||
|
}
|
||||||
|
return params
|
||||||
|
|
||||||
|
async def embeddings(
|
||||||
|
self,
|
||||||
|
model_id: str,
|
||||||
|
contents: List[str] | List[InterleavedContentItem],
|
||||||
|
text_truncation: Optional[TextTruncation] = TextTruncation.none,
|
||||||
|
output_dimension: Optional[int] = None,
|
||||||
|
task_type: Optional[EmbeddingTaskType] = None,
|
||||||
|
) -> EmbeddingsResponse:
|
||||||
|
pass
|
||||||
|
|
|
@ -75,6 +75,26 @@ models:
|
||||||
provider_id: watsonx
|
provider_id: watsonx
|
||||||
provider_model_id: meta-llama/llama-3-3-70b-instruct
|
provider_model_id: meta-llama/llama-3-3-70b-instruct
|
||||||
model_type: llm
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta-llama/llama-2-13b-chat
|
||||||
|
provider_id: watsonx
|
||||||
|
provider_model_id: meta-llama/llama-2-13b-chat
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta-llama/llama-3-1-70b-instruct
|
||||||
|
provider_id: watsonx
|
||||||
|
provider_model_id: meta-llama/llama-3-1-70b-instruct
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta-llama/llama-3-1-8b-instruct
|
||||||
|
provider_id: watsonx
|
||||||
|
provider_model_id: meta-llama/llama-3-1-8b-instruct
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta-llama/llama-3-2-11b-vision-instruct
|
||||||
|
provider_id: watsonx
|
||||||
|
provider_model_id: meta-llama/llama-3-2-11b-vision-instruct
|
||||||
|
model_type: llm
|
||||||
shields: []
|
shields: []
|
||||||
vector_dbs: []
|
vector_dbs: []
|
||||||
datasets: []
|
datasets: []
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue