forked from phoenix-oss/llama-stack-mirror
		
	completion() for tgi (#295)
This commit is contained in:
		
							parent
							
								
									cb84034567
								
							
						
					
					
						commit
						3e1c3fdb3f
					
				
					 9 changed files with 173 additions and 35 deletions
				
			
		|  | @ -116,7 +116,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference): | |||
|             "model": self.map_to_provider_model(request.model), | ||||
|             "prompt": chat_completion_request_to_prompt(request, self.formatter), | ||||
|             "stream": request.stream, | ||||
|             **get_sampling_options(request), | ||||
|             **get_sampling_options(request.sampling_params), | ||||
|         } | ||||
| 
 | ||||
|     async def embeddings( | ||||
|  |  | |||
|  | @ -116,7 +116,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference): | |||
|         if prompt.startswith("<|begin_of_text|>"): | ||||
|             prompt = prompt[len("<|begin_of_text|>") :] | ||||
| 
 | ||||
|         options = get_sampling_options(request) | ||||
|         options = get_sampling_options(request.sampling_params) | ||||
|         options.setdefault("max_tokens", 512) | ||||
| 
 | ||||
|         if fmt := request.response_format: | ||||
|  |  | |||
|  | @ -110,7 +110,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate): | |||
|             return await self._nonstream_completion(request) | ||||
| 
 | ||||
|     def _get_params_for_completion(self, request: CompletionRequest) -> dict: | ||||
|         sampling_options = get_sampling_options(request) | ||||
|         sampling_options = get_sampling_options(request.sampling_params) | ||||
|         # This is needed since the Ollama API expects num_predict to be set | ||||
|         # for early truncation instead of max_tokens. | ||||
|         if sampling_options["max_tokens"] is not None: | ||||
|  | @ -187,7 +187,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate): | |||
|         return { | ||||
|             "model": OLLAMA_SUPPORTED_MODELS[request.model], | ||||
|             "prompt": chat_completion_request_to_prompt(request, self.formatter), | ||||
|             "options": get_sampling_options(request), | ||||
|             "options": get_sampling_options(request.sampling_params), | ||||
|             "raw": True, | ||||
|             "stream": request.stream, | ||||
|         } | ||||
|  |  | |||
|  | @ -24,9 +24,12 @@ from llama_stack.providers.utils.inference.openai_compat import ( | |||
|     OpenAICompatCompletionResponse, | ||||
|     process_chat_completion_response, | ||||
|     process_chat_completion_stream_response, | ||||
|     process_completion_response, | ||||
|     process_completion_stream_response, | ||||
| ) | ||||
| from llama_stack.providers.utils.inference.prompt_adapter import ( | ||||
|     chat_completion_request_to_model_input_info, | ||||
|     completion_request_to_prompt_model_input_info, | ||||
| ) | ||||
| 
 | ||||
| from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig | ||||
|  | @ -75,7 +78,98 @@ class _HfAdapter(Inference, ModelsProtocolPrivate): | |||
|         stream: Optional[bool] = False, | ||||
|         logprobs: Optional[LogProbConfig] = None, | ||||
|     ) -> AsyncGenerator: | ||||
|         raise NotImplementedError() | ||||
|         request = CompletionRequest( | ||||
|             model=model, | ||||
|             content=content, | ||||
|             sampling_params=sampling_params, | ||||
|             response_format=response_format, | ||||
|             stream=stream, | ||||
|             logprobs=logprobs, | ||||
|         ) | ||||
|         if stream: | ||||
|             return self._stream_completion(request) | ||||
|         else: | ||||
|             return await self._nonstream_completion(request) | ||||
| 
 | ||||
|     def _get_max_new_tokens(self, sampling_params, input_tokens): | ||||
|         return min( | ||||
|             sampling_params.max_tokens or (self.max_tokens - input_tokens), | ||||
|             self.max_tokens - input_tokens - 1, | ||||
|         ) | ||||
| 
 | ||||
|     def _build_options( | ||||
|         self, | ||||
|         sampling_params: Optional[SamplingParams] = None, | ||||
|         fmt: ResponseFormat = None, | ||||
|     ): | ||||
|         options = get_sampling_options(sampling_params) | ||||
|         # delete key "max_tokens" from options since its not supported by the API | ||||
|         options.pop("max_tokens", None) | ||||
|         if fmt: | ||||
|             if fmt.type == ResponseFormatType.json_schema.value: | ||||
|                 options["grammar"] = { | ||||
|                     "type": "json", | ||||
|                     "value": fmt.schema, | ||||
|                 } | ||||
|             elif fmt.type == ResponseFormatType.grammar.value: | ||||
|                 raise ValueError("Grammar response format not supported yet") | ||||
|             else: | ||||
|                 raise ValueError(f"Unexpected response format: {fmt.type}") | ||||
| 
 | ||||
|         return options | ||||
| 
 | ||||
|     def _get_params_for_completion(self, request: CompletionRequest) -> dict: | ||||
|         prompt, input_tokens = completion_request_to_prompt_model_input_info( | ||||
|             request, self.formatter | ||||
|         ) | ||||
| 
 | ||||
|         return dict( | ||||
|             prompt=prompt, | ||||
|             stream=request.stream, | ||||
|             details=True, | ||||
|             max_new_tokens=self._get_max_new_tokens( | ||||
|                 request.sampling_params, input_tokens | ||||
|             ), | ||||
|             stop_sequences=["<|eom_id|>", "<|eot_id|>"], | ||||
|             **self._build_options(request.sampling_params, request.response_format), | ||||
|         ) | ||||
| 
 | ||||
|     async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator: | ||||
|         params = self._get_params_for_completion(request) | ||||
| 
 | ||||
|         async def _generate_and_convert_to_openai_compat(): | ||||
|             s = await self.client.text_generation(**params) | ||||
|             async for chunk in s: | ||||
|                 token_result = chunk.token | ||||
|                 finish_reason = None | ||||
|                 if chunk.details: | ||||
|                     finish_reason = chunk.details.finish_reason | ||||
| 
 | ||||
|                 choice = OpenAICompatCompletionChoice( | ||||
|                     text=token_result.text, finish_reason=finish_reason | ||||
|                 ) | ||||
|                 yield OpenAICompatCompletionResponse( | ||||
|                     choices=[choice], | ||||
|                 ) | ||||
| 
 | ||||
|         stream = _generate_and_convert_to_openai_compat() | ||||
|         async for chunk in process_completion_stream_response(stream, self.formatter): | ||||
|             yield chunk | ||||
| 
 | ||||
|     async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator: | ||||
|         params = self._get_params_for_completion(request) | ||||
|         r = await self.client.text_generation(**params) | ||||
| 
 | ||||
|         choice = OpenAICompatCompletionChoice( | ||||
|             finish_reason=r.details.finish_reason, | ||||
|             text="".join(t.text for t in r.details.tokens), | ||||
|         ) | ||||
| 
 | ||||
|         response = OpenAICompatCompletionResponse( | ||||
|             choices=[choice], | ||||
|         ) | ||||
| 
 | ||||
|         return process_completion_response(response, self.formatter) | ||||
| 
 | ||||
|     async def chat_completion( | ||||
|         self, | ||||
|  | @ -146,29 +240,15 @@ class _HfAdapter(Inference, ModelsProtocolPrivate): | |||
|         prompt, input_tokens = chat_completion_request_to_model_input_info( | ||||
|             request, self.formatter | ||||
|         ) | ||||
|         max_new_tokens = min( | ||||
|             request.sampling_params.max_tokens or (self.max_tokens - input_tokens), | ||||
|             self.max_tokens - input_tokens - 1, | ||||
|         ) | ||||
|         options = get_sampling_options(request) | ||||
|         if fmt := request.response_format: | ||||
|             if fmt.type == ResponseFormatType.json_schema.value: | ||||
|                 options["grammar"] = { | ||||
|                     "type": "json", | ||||
|                     "value": fmt.schema, | ||||
|                 } | ||||
|             elif fmt.type == ResponseFormatType.grammar.value: | ||||
|                 raise ValueError("Grammar response format not supported yet") | ||||
|             else: | ||||
|                 raise ValueError(f"Unexpected response format: {fmt.type}") | ||||
| 
 | ||||
|         return dict( | ||||
|             prompt=prompt, | ||||
|             stream=request.stream, | ||||
|             details=True, | ||||
|             max_new_tokens=max_new_tokens, | ||||
|             max_new_tokens=self._get_max_new_tokens( | ||||
|                 request.sampling_params, input_tokens | ||||
|             ), | ||||
|             stop_sequences=["<|eom_id|>", "<|eot_id|>"], | ||||
|             **options, | ||||
|             **self._build_options(request.sampling_params, request.response_format), | ||||
|         ) | ||||
| 
 | ||||
|     async def embeddings( | ||||
|  |  | |||
|  | @ -131,7 +131,7 @@ class TogetherInferenceAdapter( | |||
|             yield chunk | ||||
| 
 | ||||
|     def _get_params(self, request: ChatCompletionRequest) -> dict: | ||||
|         options = get_sampling_options(request) | ||||
|         options = get_sampling_options(request.sampling_params) | ||||
|         if fmt := request.response_format: | ||||
|             if fmt.type == ResponseFormatType.json_schema.value: | ||||
|                 options["response_format"] = { | ||||
|  |  | |||
|  | @ -143,7 +143,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): | |||
|             "model": VLLM_SUPPORTED_MODELS[request.model], | ||||
|             "prompt": chat_completion_request_to_prompt(request, self.formatter), | ||||
|             "stream": request.stream, | ||||
|             **get_sampling_options(request), | ||||
|             **get_sampling_options(request.sampling_params), | ||||
|         } | ||||
| 
 | ||||
|     async def embeddings( | ||||
|  |  | |||
|  | @ -137,6 +137,7 @@ async def test_completion(inference_settings): | |||
|     if provider.__provider_spec__.provider_type not in ( | ||||
|         "meta-reference", | ||||
|         "remote::ollama", | ||||
|         "remote::tgi", | ||||
|     ): | ||||
|         pytest.skip("Other inference providers don't support completion() yet") | ||||
| 
 | ||||
|  | @ -170,6 +171,46 @@ async def test_completion(inference_settings): | |||
|     assert last.stop_reason == StopReason.out_of_tokens | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.asyncio | ||||
| async def test_completions_structured_output(inference_settings): | ||||
|     inference_impl = inference_settings["impl"] | ||||
|     params = inference_settings["common_params"] | ||||
| 
 | ||||
|     provider = inference_impl.routing_table.get_provider_impl(params["model"]) | ||||
|     if provider.__provider_spec__.provider_type not in ( | ||||
|         "meta-reference", | ||||
|         "remote::tgi", | ||||
|     ): | ||||
|         pytest.skip( | ||||
|             "Other inference providers don't support structured output in completions yet" | ||||
|         ) | ||||
| 
 | ||||
|     class Output(BaseModel): | ||||
|         name: str | ||||
|         year_born: str | ||||
|         year_retired: str | ||||
| 
 | ||||
|     user_input = "Michael Jordan was born in 1963. He played basketball for the Chicago Bulls. He retired in 2003." | ||||
|     response = await inference_impl.completion( | ||||
|         content=f"input: '{user_input}'. the schema for json: {Output.schema()}, the json is: ", | ||||
|         stream=False, | ||||
|         model=params["model"], | ||||
|         sampling_params=SamplingParams( | ||||
|             max_tokens=50, | ||||
|         ), | ||||
|         response_format=JsonResponseFormat( | ||||
|             schema=Output.model_json_schema(), | ||||
|         ), | ||||
|     ) | ||||
|     assert isinstance(response, CompletionResponse) | ||||
|     assert isinstance(response.content, str) | ||||
| 
 | ||||
|     answer = Output.parse_raw(response.content) | ||||
|     assert answer.name == "Michael Jordan" | ||||
|     assert answer.year_born == "1963" | ||||
|     assert answer.year_retired == "2003" | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.asyncio | ||||
| async def test_chat_completion_non_streaming(inference_settings, sample_messages): | ||||
|     inference_impl = inference_settings["impl"] | ||||
|  |  | |||
|  | @ -29,9 +29,9 @@ class OpenAICompatCompletionResponse(BaseModel): | |||
|     choices: List[OpenAICompatCompletionChoice] | ||||
| 
 | ||||
| 
 | ||||
| def get_sampling_options(request: ChatCompletionRequest) -> dict: | ||||
| def get_sampling_options(params: SamplingParams) -> dict: | ||||
|     options = {} | ||||
|     if params := request.sampling_params: | ||||
|     if params: | ||||
|         for attr in {"temperature", "top_p", "top_k", "max_tokens"}: | ||||
|             if getattr(params, attr): | ||||
|                 options[attr] = getattr(params, attr) | ||||
|  | @ -64,7 +64,18 @@ def process_completion_response( | |||
|     response: OpenAICompatCompletionResponse, formatter: ChatFormat | ||||
| ) -> CompletionResponse: | ||||
|     choice = response.choices[0] | ||||
| 
 | ||||
|     # drop suffix <eot_id> if present and return stop reason as end of turn | ||||
|     if choice.text.endswith("<|eot_id|>"): | ||||
|         return CompletionResponse( | ||||
|             stop_reason=StopReason.end_of_turn, | ||||
|             content=choice.text[: -len("<|eot_id|>")], | ||||
|         ) | ||||
|     # drop suffix <eom_id> if present and return stop reason as end of message | ||||
|     if choice.text.endswith("<|eom_id|>"): | ||||
|         return CompletionResponse( | ||||
|             stop_reason=StopReason.end_of_message, | ||||
|             content=choice.text[: -len("<|eom_id|>")], | ||||
|         ) | ||||
|     return CompletionResponse( | ||||
|         stop_reason=get_stop_reason(choice.finish_reason), | ||||
|         content=choice.text, | ||||
|  | @ -95,13 +106,6 @@ async def process_completion_stream_response( | |||
|         choice = chunk.choices[0] | ||||
|         finish_reason = choice.finish_reason | ||||
| 
 | ||||
|         if finish_reason: | ||||
|             if finish_reason in ["stop", "eos", "eos_token"]: | ||||
|                 stop_reason = StopReason.end_of_turn | ||||
|             elif finish_reason == "length": | ||||
|                 stop_reason = StopReason.out_of_tokens | ||||
|             break | ||||
| 
 | ||||
|         text = text_from_choice(choice) | ||||
|         if text == "<|eot_id|>": | ||||
|             stop_reason = StopReason.end_of_turn | ||||
|  | @ -115,6 +119,12 @@ async def process_completion_stream_response( | |||
|             delta=text, | ||||
|             stop_reason=stop_reason, | ||||
|         ) | ||||
|         if finish_reason: | ||||
|             if finish_reason in ["stop", "eos", "eos_token"]: | ||||
|                 stop_reason = StopReason.end_of_turn | ||||
|             elif finish_reason == "length": | ||||
|                 stop_reason = StopReason.out_of_tokens | ||||
|             break | ||||
| 
 | ||||
|     yield CompletionResponseStreamChunk( | ||||
|         delta="", | ||||
|  |  | |||
|  | @ -31,6 +31,13 @@ def completion_request_to_prompt( | |||
|     return formatter.tokenizer.decode(model_input.tokens) | ||||
| 
 | ||||
| 
 | ||||
| def completion_request_to_prompt_model_input_info( | ||||
|     request: CompletionRequest, formatter: ChatFormat | ||||
| ) -> Tuple[str, int]: | ||||
|     model_input = formatter.encode_content(request.content) | ||||
|     return (formatter.tokenizer.decode(model_input.tokens), len(model_input.tokens)) | ||||
| 
 | ||||
| 
 | ||||
| def chat_completion_request_to_prompt( | ||||
|     request: ChatCompletionRequest, formatter: ChatFormat | ||||
| ) -> str: | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue