diff --git a/llama_stack/providers/adapters/inference/together/together.py b/llama_stack/providers/adapters/inference/together/together.py
index daf57497a..8c92836f9 100644
--- a/llama_stack/providers/adapters/inference/together/together.py
+++ b/llama_stack/providers/adapters/inference/together/together.py
@@ -20,9 +20,12 @@ from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,
     process_chat_completion_response,
     process_chat_completion_stream_response,
+    process_completion_response,
+    process_completion_stream_response,
 )
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
+    completion_request_to_prompt,
 )
 
 from .config import TogetherImplConfig
@@ -41,6 +44,7 @@ TOGETHER_SUPPORTED_MODELS = {
 class TogetherInferenceAdapter(
     ModelRegistryHelper, Inference, NeedsRequestProviderData
 ):
+
     def __init__(self, config: TogetherImplConfig) -> None:
         ModelRegistryHelper.__init__(
             self, stack_to_provider_models_map=TOGETHER_SUPPORTED_MODELS
@@ -49,7 +53,7 @@ class TogetherInferenceAdapter(
         self.formatter = ChatFormat(Tokenizer.get_instance())
 
     async def initialize(self) -> None:
-        return
+        pass
 
     async def shutdown(self) -> None:
         pass
@@ -63,7 +67,76 @@ class TogetherInferenceAdapter(
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
-        raise NotImplementedError()
+        request = CompletionRequest(
+            model=model,
+            content=content,
+            sampling_params=sampling_params,
+            response_format=response_format,
+            stream=stream,
+            logprobs=logprobs,
+        )
+        if stream:
+            return self._stream_completion(request)
+        else:
+            return await self._nonstream_completion(request)
+
+    def _get_client(self) -> Together:
+        together_api_key = None
+        if self.config.api_key is not None:
+            together_api_key = self.config.api_key
+        else:
+            provider_data = self.get_request_provider_data()
+            if provider_data is None or not provider_data.together_api_key:
+                raise ValueError(
+                    'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'
+                )
+            together_api_key = provider_data.together_api_key
+        return Together(api_key=together_api_key)
+
+    async def _nonstream_completion(
+        self, request: CompletionRequest
+    ) -> ChatCompletionResponse:
+        params = self._get_params_for_completion(request)
+        r = self._get_client().completions.create(**params)
+        return process_completion_response(r, self.formatter)
+
+    async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
+        params = self._get_params_for_completion(request)
+
+        # if we shift to TogetherAsyncClient, we won't need this wrapper
+        async def _to_async_generator():
+            s = self._get_client().completions.create(**params)
+            for chunk in s:
+                yield chunk
+
+        stream = _to_async_generator()
+        async for chunk in process_completion_stream_response(stream, self.formatter):
+            yield chunk
+
+    def _build_options(
+        self, sampling_params: Optional[SamplingParams], fmt: ResponseFormat
+    ) -> dict:
+        options = get_sampling_options(sampling_params)
+        if fmt:
+            if fmt.type == ResponseFormatType.json_schema.value:
+                options["response_format"] = {
+                    "type": "json_object",
+                    "schema": fmt.schema,
+                }
+            elif fmt.type == ResponseFormatType.grammar.value:
+                raise NotImplementedError("Grammar response format not supported yet")
+            else:
+                raise ValueError(f"Unknown response format {fmt.type}")
+
+        return options
+
+    def _get_params_for_completion(self, request: CompletionRequest) -> dict:
+        return {
+            "model": self.map_to_provider_model(request.model),
+            "prompt": completion_request_to_prompt(request, self.formatter),
+            "stream": request.stream,
+            **self._build_options(request.sampling_params, request.response_format),
+        }
 
     async def chat_completion(
         self,
@@ -77,18 +150,7 @@ class TogetherInferenceAdapter(
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
-        together_api_key = None
-        if self.config.api_key is not None:
-            together_api_key = self.config.api_key
-        else:
-            provider_data = self.get_request_provider_data()
-            if provider_data is None or not provider_data.together_api_key:
-                raise ValueError(
-                    'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'
-                )
-            together_api_key = provider_data.together_api_key
-        client = Together(api_key=together_api_key)
 
         request = ChatCompletionRequest(
             model=model,
             messages=messages,
@@ -102,25 +164,25 @@ class TogetherInferenceAdapter(
         )
 
         if stream:
-            return self._stream_chat_completion(request, client)
+            return self._stream_chat_completion(request)
         else:
-            return await self._nonstream_chat_completion(request, client)
+            return await self._nonstream_chat_completion(request)
 
     async def _nonstream_chat_completion(
-        self, request: ChatCompletionRequest, client: Together
+        self, request: ChatCompletionRequest
     ) -> ChatCompletionResponse:
         params = self._get_params(request)
-        r = client.completions.create(**params)
+        r = self._get_client().completions.create(**params)
         return process_chat_completion_response(r, self.formatter)
 
     async def _stream_chat_completion(
-        self, request: ChatCompletionRequest, client: Together
+        self, request: ChatCompletionRequest
     ) -> AsyncGenerator:
         params = self._get_params(request)
 
         # if we shift to TogetherAsyncClient, we won't need this wrapper
         async def _to_async_generator():
-            s = client.completions.create(**params)
+            s = self._get_client().completions.create(**params)
             for chunk in s:
                 yield chunk
 
@@ -131,23 +193,11 @@ class TogetherInferenceAdapter(
             yield chunk
 
     def _get_params(self, request: ChatCompletionRequest) -> dict:
-        options = get_sampling_options(request.sampling_params)
-        if fmt := request.response_format:
-            if fmt.type == ResponseFormatType.json_schema.value:
-                options["response_format"] = {
-                    "type": "json_object",
-                    "schema": fmt.schema,
-                }
-            elif fmt.type == ResponseFormatType.grammar.value:
-                raise NotImplementedError("Grammar response format not supported yet")
-            else:
-                raise ValueError(f"Unknown response format {fmt.type}")
-
         return {
             "model": self.map_to_provider_model(request.model),
             "prompt": chat_completion_request_to_prompt(request, self.formatter),
             "stream": request.stream,
-            **options,
+            **self._build_options(request.sampling_params, request.response_format),
         }
 
     async def embeddings(
diff --git a/llama_stack/providers/tests/inference/test_inference.py b/llama_stack/providers/tests/inference/test_inference.py
index c7cbdd592..8b803808d 100644
--- a/llama_stack/providers/tests/inference/test_inference.py
+++ b/llama_stack/providers/tests/inference/test_inference.py
@@ -138,11 +138,12 @@ async def test_completion(inference_settings):
         "meta-reference",
         "remote::ollama",
         "remote::tgi",
+        "remote::together",
     ):
         pytest.skip("Other inference providers don't support completion() yet")
 
     response = await inference_impl.completion(
-        content="Roses are red,",
+        content="Michael Jordan was born in ",
         stream=False,
         model=params["model"],
         sampling_params=SamplingParams(
@@ -151,7 +152,7 @@ async def test_completion(inference_settings):
     )
 
     assert isinstance(response, CompletionResponse)
-    assert "violets are blue" in response.content
+    assert "1963" in response.content
 
     chunks = [
         r
@@ -180,6 +181,7 @@ async def test_completions_structured_output(inference_settings):
     if provider.__provider_spec__.provider_type not in (
         "meta-reference",
         "remote::tgi",
+        "remote::together",
     ):
         pytest.skip(
             "Other inference providers don't support structured output in completions yet"