diff --git a/spring-ai-modules/pom.xml b/spring-ai-modules/pom.xml index a8a9b050cb30..4bc8bdfd2f83 100644 --- a/spring-ai-modules/pom.xml +++ b/spring-ai-modules/pom.xml @@ -16,6 +16,7 @@ + spring-ai-chat-stream spring-ai-introduction spring-ai-mcp spring-ai-text-to-sql diff --git a/spring-ai-modules/spring-ai-chat-stream/pom.xml b/spring-ai-modules/spring-ai-chat-stream/pom.xml new file mode 100644 index 000000000000..ea9a0f9e6f37 --- /dev/null +++ b/spring-ai-modules/spring-ai-chat-stream/pom.xml @@ -0,0 +1,103 @@ + + + 4.0.0 + com.baeldung + spring-ai-chat-stream + 0.0.1 + spring-ai-chat-stream + + + com.baeldung + spring-ai-modules + 0.0.1 + ../pom.xml + + + + + spring-milestones + Spring Milestones + https://repo.spring.io/milestone + + false + + + true + + + + spring-snapshots + Spring Snapshots + https://repo.spring.io/snapshot + + true + + + false + + + + + + + + org.springframework.ai + spring-ai-bom + ${spring-ai.version} + pom + import + + + org.springframework.boot + spring-boot-dependencies + ${spring-boot.version} + pom + import + + + + + + + org.springframework.boot + spring-boot-starter-webflux + + + org.springframework.boot + spring-boot-starter-validation + + + org.springframework.ai + spring-ai-starter-model-openai + + + org.springframework.boot + spring-boot-starter-test + test + + + + + 21 + 1.0.1 + 3.5.5 + + + + + + org.springframework.boot + spring-boot-maven-plugin + + + org.apache.maven.plugins + maven-compiler-plugin + + ${java.version} + + + + + + \ No newline at end of file diff --git a/spring-ai-modules/spring-ai-chat-stream/src/main/java/com/baeldung/springai/streaming/Application.java b/spring-ai-modules/spring-ai-chat-stream/src/main/java/com/baeldung/springai/streaming/Application.java new file mode 100644 index 000000000000..d793aff5a7a0 --- /dev/null +++ b/spring-ai-modules/spring-ai-chat-stream/src/main/java/com/baeldung/springai/streaming/Application.java @@ -0,0 +1,13 @@ +package com.baeldung.springai.streaming; + +import org.springframework.boot.SpringApplication; +import 
org.springframework.boot.autoconfigure.SpringBootApplication; + +@SpringBootApplication +public class Application { + + public static void main(String[] args) { + SpringApplication.run(Application.class, args); + } + +} diff --git a/spring-ai-modules/spring-ai-chat-stream/src/main/java/com/baeldung/springai/streaming/ChatController.java b/spring-ai-modules/spring-ai-chat-stream/src/main/java/com/baeldung/springai/streaming/ChatController.java new file mode 100644 index 000000000000..645ef3ce2f61 --- /dev/null +++ b/spring-ai-modules/spring-ai-chat-stream/src/main/java/com/baeldung/springai/streaming/ChatController.java @@ -0,0 +1,46 @@ +package com.baeldung.springai.streaming; + +import jakarta.validation.Valid; + +import org.springframework.http.MediaType; +import org.springframework.validation.annotation.Validated; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.RequestBody; +import org.springframework.web.bind.annotation.RestController; + +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +@RestController +@Validated +public class ChatController { + + private final ChatService chatService; + + public ChatController(ChatService chatService) { + this.chatService = chatService; + } + + @PostMapping(value = "/chat") + public Mono chat(@RequestBody @Valid ChatRequest request) { + return chatService.chat(request.getPrompt()) + .collectList() + .map(list -> String.join("", list)); + } + + @PostMapping(value = "/chat-word", produces = MediaType.TEXT_EVENT_STREAM_VALUE) + public Flux chatAsWord(@RequestBody @Valid ChatRequest request) { + return chatService.chatAsWord(request.getPrompt()); + } + + @PostMapping(value = "/chat-chunk", produces = MediaType.TEXT_EVENT_STREAM_VALUE) + public Flux chatAsChunk(@RequestBody @Valid ChatRequest request) { + return chatService.chatAsChunk(request.getPrompt()); + } + 
@PostMapping(value = "/chat-json", produces = MediaType.TEXT_EVENT_STREAM_VALUE) + public Flux chatAsJson(@RequestBody @Valid ChatRequest request) { + return chatService.chatAsJson(request.getPrompt()); + } + +} diff --git a/spring-ai-modules/spring-ai-chat-stream/src/main/java/com/baeldung/springai/streaming/ChatRequest.java b/spring-ai-modules/spring-ai-chat-stream/src/main/java/com/baeldung/springai/streaming/ChatRequest.java new file mode 100644 index 000000000000..f3b192216031 --- /dev/null +++ b/spring-ai-modules/spring-ai-chat-stream/src/main/java/com/baeldung/springai/streaming/ChatRequest.java @@ -0,0 +1,18 @@ +package com.baeldung.springai.streaming; + +import jakarta.validation.constraints.NotNull; + +public class ChatRequest { + + @NotNull + private String prompt; + + public String getPrompt() { + return prompt; + } + + public void setPrompt(String prompt) { + this.prompt = prompt; + } + +} diff --git a/spring-ai-modules/spring-ai-chat-stream/src/main/java/com/baeldung/springai/streaming/ChatService.java b/spring-ai-modules/spring-ai-chat-stream/src/main/java/com/baeldung/springai/streaming/ChatService.java new file mode 100644 index 000000000000..2d361b3cefa3 --- /dev/null +++ b/spring-ai-modules/spring-ai-chat-stream/src/main/java/com/baeldung/springai/streaming/ChatService.java @@ -0,0 +1,106 @@ +package com.baeldung.springai.streaming; + +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.stereotype.Component; + +import reactor.core.publisher.Flux; + +@Component +public class ChatService { + + private final ChatClient chatClient; + + public ChatService(ChatModel chatModel) { + this.chatClient = ChatClient.builder(chatModel) + .build(); + } + + public Flux chat(String prompt) { + return chatClient.prompt() + .user(userMessage -> userMessage.text(prompt)) + .stream() + .content(); + } + + public Flux chatAsWord(String prompt) { + return chatClient.prompt() + 
.user(userMessage -> userMessage.text(prompt)) + .stream() + .content(); + } + + public Flux chatAsChunk(String prompt) { + return chatClient.prompt() + .user(userMessage -> userMessage.text(prompt)) + .stream() + .content() + .transform(flux -> toChunk(flux, 100)); + } + + public Flux chatAsJson(String prompt) { + return chatClient.prompt() + .system(systemMessage -> systemMessage.text( + """ + Respond in NDJSON format. + Each JSON object should contains around 100 characters. + Sample json object format: {"part":0,"text":"Once in a small town..."} + """)) + .user(userMessage -> userMessage.text(prompt)) + .stream() + .content() + .transform(this::toJsonChunk); + } + + private Flux toChunk(Flux tokenFlux, int chunkSize) { + return Flux.create(sink -> { + StringBuilder buffer = new StringBuilder(); + // Cancel the upstream model stream when the downstream subscriber cancels. + sink.onDispose(tokenFlux.subscribe( + token -> { + buffer.append(token); + if (buffer.length() >= chunkSize) { + sink.next(buffer.toString()); + buffer.setLength(0); + } + }, + sink::error, + () -> { + if (buffer.length() > 0) { + sink.next(buffer.toString()); + } + sink.complete(); + } + )); + }); + } + + private Flux toJsonChunk(Flux tokenFlux) { + return Flux.create(sink -> { + StringBuilder buffer = new StringBuilder(); + // Cancel the upstream model stream when the downstream subscriber cancels. + sink.onDispose(tokenFlux.subscribe( + token -> { + buffer.append(token); + int idx; + // A token may contain several newlines: emit every complete line, skipping blank ones. + while ((idx = buffer.indexOf("\n")) >= 0) { + String line = buffer.substring(0, idx); + if (!line.isBlank()) { + sink.next(line); + } + buffer.delete(0, idx + 1); + } + }, + sink::error, + () -> { + if (!buffer.toString().isBlank()) { + sink.next(buffer.toString()); + } + sink.complete(); + } + )); + }); + } + +} \ No newline at end of file diff --git 
a/spring-ai-modules/spring-ai-chat-stream/src/main/resources/application.yml b/spring-ai-modules/spring-ai-chat-stream/src/main/resources/application.yml new file mode 100644 index 000000000000..20cc0b043bce --- /dev/null +++ b/spring-ai-modules/spring-ai-chat-stream/src/main/resources/application.yml @@ -0,0 +1,4 @@ +spring: + ai: + openai: + api-key: "" diff --git 
a/spring-ai-modules/spring-ai-chat-stream/src/main/resources/logback.xml b/spring-ai-modules/spring-ai-chat-stream/src/main/resources/logback.xml new file mode 100644 index 000000000000..449efbdaebb0 --- /dev/null +++ b/spring-ai-modules/spring-ai-chat-stream/src/main/resources/logback.xml @@ -0,0 +1,15 @@ + + + + [%d{yyyy-MM-dd HH:mm:ss}] [%p] [%c{1}] - %m%n + + + + + + + + + + + \ No newline at end of file