diff --git a/llama_stack/providers/remote/inference/watsonx/config.py b/llama_stack/providers/remote/inference/watsonx/config.py
index 289f782e9..2f25b54af 100644
--- a/llama_stack/providers/remote/inference/watsonx/config.py
+++ b/llama_stack/providers/remote/inference/watsonx/config.py
@@ -8,7 +8,7 @@ import os
 from typing import Optional, Dict, Any
 
 from llama_stack.schema_utils import json_schema_type
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, SecretStr
 
 
 class WatsonXProviderDataValidator(BaseModel):
@@ -24,7 +24,7 @@ class WatsonXConfig(BaseModel):
         default_factory=lambda: os.getenv("WATSONX_BASE_URL", "https://us-south.ml.cloud.ibm.com"),
         description="A base url for accessing the Watsonx.ai",
     )
-    api_key: Optional[str] = Field(
+    api_key: Optional[SecretStr] = Field(
         default_factory=lambda: os.getenv("WATSONX_API_KEY"),
         description="The Watsonx API key, only needed of using the hosted service",
     )
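
Switching api_key from str to SecretStr keeps the key masked in reprs and logs, and callers now have to unwrap it explicitly, which is what _get_client() does in watsonx.py further down. A minimal standalone sketch of that behavior, assuming pydantic's usual SecretStr semantics; DemoConfig is illustrative and not part of this patch:

    import os
    from typing import Optional

    from pydantic import BaseModel, Field, SecretStr


    class DemoConfig(BaseModel):
        # mirrors the api_key field above
        api_key: Optional[SecretStr] = Field(default_factory=lambda: os.getenv("WATSONX_API_KEY"))


    demo = DemoConfig(api_key="my-watsonx-key")
    print(demo)                             # api_key=SecretStr('**********')  -- masked when logged
    print(demo.api_key.get_secret_value())  # my-watsonx-key                   -- explicit unwrap
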
diff --git a/llama_stack/providers/remote/inference/watsonx/models.py b/llama_stack/providers/remote/inference/watsonx/models.py
index 06f1bb62b..bded586d7 100644
--- a/llama_stack/providers/remote/inference/watsonx/models.py
+++ b/llama_stack/providers/remote/inference/watsonx/models.py
@@ -11,6 +11,22 @@ MODEL_ENTRIES = [
     build_hf_repo_model_entry(
         "meta-llama/llama-3-3-70b-instruct",
         CoreModelId.llama3_3_70b_instruct.value,
+    ),
+    build_hf_repo_model_entry(
+        "meta-llama/llama-2-13b-chat",
+        CoreModelId.llama2_13b.value,
+    ),
+    build_hf_repo_model_entry(
+        "meta-llama/llama-3-1-70b-instruct",
+        CoreModelId.llama3_1_70b_instruct.value,
+    ),
+    build_hf_repo_model_entry(
+        "meta-llama/llama-3-1-8b-instruct",
+        CoreModelId.llama3_1_8b_instruct.value,
+    ),
+    build_hf_repo_model_entry(
+        "meta-llama/llama-3-2-11b-vision-instruct",
+        CoreModelId.llama3_2_11b_vision_instruct.value,
     )
 ]
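
Each build_hf_repo_model_entry(...) call pairs a watsonx provider model id with the canonical Llama model it serves, so the registry can resolve either name to the same entry. A rough, hypothetical way to eyeball the resulting mappings; the entry field names (provider_model_id, llama_model) are assumptions about the entry object rather than something shown in this patch:

    from llama_stack.providers.remote.inference.watsonx.models import MODEL_ENTRIES

    for entry in MODEL_ENTRIES:
        # assumed field names; adjust if the actual entry dataclass differs
        print(f"{entry.provider_model_id} -> {entry.llama_model}")
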
"apikey": config_api_key + } + + return Model(model_id=model_id,credentials=credentials, project_id=project_id) + + async def _nonstream_completion(self, request: CompletionRequest) -> ChatCompletionResponse: + params = await self._get_params(request) + r = self._get_client(request.model).generate(**params) + choices = [] + if "results" in r: + for result in r["results"]: + choice = OpenAICompatCompletionChoice( + finish_reason=result["stop_reason"] if result["stop_reason"] else None, + text=result["generated_text"], + ) + choices.append(choice) + response = OpenAICompatCompletionResponse( + choices=choices, + ) + return process_completion_response(response) + + async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator: + params = await self._get_params(request) + + async def _generate_and_convert_to_openai_compat(): + s = self._get_client(request.model).generate_text_stream(**params) + for chunk in s: + choice = OpenAICompatCompletionChoice( + finish_reason=None, + text=chunk, + ) + yield OpenAICompatCompletionResponse( + choices=[choice], + ) + + stream = _generate_and_convert_to_openai_compat() + async for chunk in process_completion_stream_response(stream): + yield chunk + async def chat_completion( - self, - model_id: str, - messages: List[Message], - sampling_params: Optional[SamplingParams] = None, - response_format: Optional[ResponseFormat] = None, - tools: Optional[List[ToolDefinition]] = None, - tool_choice: Optional[ToolChoice] = ToolChoice.auto, - tool_prompt_format: Optional[ToolPromptFormat] = None, - stream: Optional[bool] = False, - logprobs: Optional[LogProbConfig] = None, - tool_config: Optional[ToolConfig] = None, - ): - # Language model - model = Model( - model_id=model_id, - credentials=self._credential, - project_id=self._project_id, + self, + model_id: str, + messages: List[Message], + sampling_params: Optional[SamplingParams] = None, + tools: Optional[List[ToolDefinition]] = None, + tool_choice: Optional[ToolChoice] = ToolChoice.auto, + tool_prompt_format: Optional[ToolPromptFormat] = None, + response_format: Optional[ResponseFormat] = None, + stream: Optional[bool] = False, + logprobs: Optional[LogProbConfig] = None, + tool_config: Optional[ToolConfig] = None, + ) -> AsyncGenerator: + if sampling_params is None: + sampling_params = SamplingParams() + model = await self.model_store.get_model(model_id) + request = ChatCompletionRequest( + model=model.provider_resource_id, + messages=messages, + sampling_params=sampling_params, + tools=tools or [], + response_format=response_format, + stream=stream, + logprobs=logprobs, + tool_config=tool_config, ) - prompt = "\n".join(messages) + "\nAI: " - response = model.generate_text(prompt=prompt, params=self.params) + if stream: + return self._stream_chat_completion(request) + else: + return await self._nonstream_chat_completion(request) - return response + async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: + params = await self._get_params(request) + r = self._get_client(request.model).generate(**params) + choices = [] + if "results" in r: + for result in r["results"]: + choice = OpenAICompatCompletionChoice( + finish_reason=result["stop_reason"] if result["stop_reason"] else None, + text=result["generated_text"], + ) + choices.append(choice) + response = OpenAICompatCompletionResponse( + choices=choices, + ) + return process_chat_completion_response(response, request) + async def _stream_chat_completion(self, request: ChatCompletionRequest) -> 
diff --git a/llama_stack/templates/watsonx/run.yaml b/llama_stack/templates/watsonx/run.yaml
index 0dd439da9..851b19810 100644
--- a/llama_stack/templates/watsonx/run.yaml
+++ b/llama_stack/templates/watsonx/run.yaml
@@ -75,6 +75,26 @@ models:
   provider_id: watsonx
   provider_model_id: meta-llama/llama-3-3-70b-instruct
   model_type: llm
+- metadata: {}
+  model_id: meta-llama/llama-2-13b-chat
+  provider_id: watsonx
+  provider_model_id: meta-llama/llama-2-13b-chat
+  model_type: llm
+- metadata: {}
+  model_id: meta-llama/llama-3-1-70b-instruct
+  provider_id: watsonx
+  provider_model_id: meta-llama/llama-3-1-70b-instruct
+  model_type: llm
+- metadata: {}
+  model_id: meta-llama/llama-3-1-8b-instruct
+  provider_id: watsonx
+  provider_model_id: meta-llama/llama-3-1-8b-instruct
+  model_type: llm
+- metadata: {}
+  model_id: meta-llama/llama-3-2-11b-vision-instruct
+  provider_id: watsonx
+  provider_model_id: meta-llama/llama-3-2-11b-vision-instruct
+  model_type: llm
 shields: []
 vector_dbs: []
 datasets: []
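
With these template entries in place, the new model ids can be requested by name from a server started with this run.yaml. A hedged usage sketch: the client class and method below follow the llama-stack client conventions, but the exact signature and response fields may differ by version, and the base_url is a placeholder:

    from llama_stack_client import LlamaStackClient

    client = LlamaStackClient(base_url="http://localhost:8321")  # placeholder address

    # model_id matches one of the entries added above
    response = client.inference.chat_completion(
        model_id="meta-llama/llama-3-1-8b-instruct",
        messages=[{"role": "user", "content": "Say hello from watsonx."}],
    )
    print(response.completion_message.content)
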