diff --git a/spring-ai-3/pom.xml b/spring-ai-3/pom.xml
index 860ecc62148d..695da7b8f951 100644
--- a/spring-ai-3/pom.xml
+++ b/spring-ai-3/pom.xml
@@ -51,6 +51,10 @@ org.springframework.ai spring-ai-openai-spring-boot-starter + + org.springframework.ai spring-ai-mcp-server-webmvc-spring-boot-starter + org.hsqldb hsqldb
@@ -61,6 +65,16 @@ spring-ai-starter-model-openai ${spring-ai-start-model-openai.version} + + org.springframework.boot spring-boot-starter-oauth2-resource-server ${oauth2-resource-server.version} + + org.springframework.boot spring-boot-starter-oauth2-authorization-server ${oauth2-authorization-server.version} +
@@ -146,6 +160,8 @@ 3.4.5 1.0.0-M6 1.0.0-M7 + 3.4.2 + 3.3.3
diff --git a/spring-ai-3/src/main/java/com/baeldung/springai/mcp/oauth2/McpServerApplication.java b/spring-ai-3/src/main/java/com/baeldung/springai/mcp/oauth2/McpServerApplication.java
new file mode 100644
index 000000000000..24d25f9ec9c7
--- /dev/null
+++ b/spring-ai-3/src/main/java/com/baeldung/springai/mcp/oauth2/McpServerApplication.java
@@ -0,0 +1,45 @@
+package com.baeldung.springai.mcp.oauth2;
+
+import org.springframework.ai.autoconfigure.chat.client.ChatClientAutoConfiguration;
+import org.springframework.ai.autoconfigure.mistralai.MistralAiAutoConfiguration;
+import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration;
+import org.springframework.ai.model.openai.autoconfigure.OpenAiAudioSpeechAutoConfiguration;
+import org.springframework.ai.model.openai.autoconfigure.OpenAiAudioTranscriptionAutoConfiguration;
+import org.springframework.ai.model.openai.autoconfigure.OpenAiChatAutoConfiguration;
+import org.springframework.ai.model.openai.autoconfigure.OpenAiEmbeddingAutoConfiguration;
+import org.springframework.ai.model.openai.autoconfigure.OpenAiImageAutoConfiguration;
+import org.springframework.ai.model.openai.autoconfigure.OpenAiModerationAutoConfiguration;
+import org.springframework.boot.SpringApplication;
+import org.springframework.boot.autoconfigure.SpringBootApplication;
+import org.springframework.boot.autoconfigure.data.mongo.MongoDataAutoConfiguration;
+import org.springframework.boot.autoconfigure.mongo.MongoAutoConfiguration;
+
+/**
+ * Entry point for the OAuth2-secured MCP server example.
+ *
+ * <p>Excludes the chat-client, chat-model, Mongo and vector-store auto-configurations
+ * present on this module's classpath, and activates the {@code mcp} profile so that
+ * {@code application-mcp.yml} is applied.
+ */
+@SpringBootApplication(exclude = {
+    ChatClientAutoConfiguration.class,
+    MongoAutoConfiguration.class,
+    MistralAiAutoConfiguration.class,
+    MongoDataAutoConfiguration.class,
+    org.springframework.ai.autoconfigure.vectorstore.mongo.MongoDBAtlasVectorStoreAutoConfiguration.class,
+    org.springframework.ai.vectorstore.mongodb.autoconfigure.MongoDBAtlasVectorStoreAutoConfiguration.class,
+    OpenAiAudioSpeechAutoConfiguration.class,
+    OpenAiAutoConfiguration.class,
+    OpenAiAudioTranscriptionAutoConfiguration.class,
+    OpenAiChatAutoConfiguration.class,
+    OpenAiEmbeddingAutoConfiguration.class,
+    OpenAiImageAutoConfiguration.class,
+    OpenAiModerationAutoConfiguration.class})
+class McpServerApplication {
+
+    public static void main(String[] args) {
+        SpringApplication app = new SpringApplication(McpServerApplication.class);
+        app.setAdditionalProfiles("mcp");
+        app.run(args);
+    }
+}
diff --git a/spring-ai-3/src/main/java/com/baeldung/springai/mcp/oauth2/StockInformationHolder.java b/spring-ai-3/src/main/java/com/baeldung/springai/mcp/oauth2/StockInformationHolder.java
new file mode 100644
index 000000000000..d30aabb73651
--- /dev/null
+++ b/spring-ai-3/src/main/java/com/baeldung/springai/mcp/oauth2/StockInformationHolder.java
@@ -0,0 +1,28 @@
+package com.baeldung.springai.mcp.oauth2;
+
+import org.springframework.ai.tool.annotation.Tool;
+import org.springframework.ai.tool.annotation.ToolParam;
+
+/**
+ * Stock-price lookup exposed as a tool; prices are hard-coded sample values.
+ */
+public class StockInformationHolder {
+
+    /**
+     * Returns a formatted price line for the given ticker symbol.
+     *
+     * @param symbol ticker symbol, matched case-insensitively
+     * @return {@code "AAPL: $150.00"} or {@code "GOOGL: $2800.00"} for the known symbols,
+     *         otherwise {@code symbol + ": Data not available"}
+     */
+    @Tool(description = "Get stock price for a company symbol")
+    public String getStockPrice(@ToolParam(description = "Company stock ticker symbol, e.g. AAPL") String symbol) {
+        if ("AAPL".equalsIgnoreCase(symbol)) {
+            return "AAPL: $150.00";
+        } else if ("GOOGL".equalsIgnoreCase(symbol)) {
+            return "GOOGL: $2800.00";
+        } else {
+            return symbol + ": Data not available";
+        }
+    }
+}
diff --git a/spring-ai-3/src/main/java/com/baeldung/springai/mcp/oauth2/configuration/McpServerConfiguration.java b/spring-ai-3/src/main/java/com/baeldung/springai/mcp/oauth2/configuration/McpServerConfiguration.java
new file mode 100644
index 000000000000..2d3d1fa1bded
--- /dev/null
+++ b/spring-ai-3/src/main/java/com/baeldung/springai/mcp/oauth2/configuration/McpServerConfiguration.java
@@ -0,0 +1,27 @@
+package com.baeldung.springai.mcp.oauth2.configuration;
+
+import com.baeldung.springai.mcp.oauth2.StockInformationHolder;
+import org.springframework.ai.tool.ToolCallbackProvider;
+import org.springframework.ai.tool.method.MethodToolCallbackProvider;
+import org.springframework.context.annotation.Bean;
+import org.springframework.context.annotation.Configuration;
+import org.springframework.context.annotation.Profile;
+
+/**
+ * Exposes the stock tools as a {@link ToolCallbackProvider} bean, only under the {@code mcp} profile.
+ */
+@Profile("mcp")
+@Configuration
+public class McpServerConfiguration {
+
+    /**
+     * Builds a provider for the {@code @Tool}-annotated methods of {@link StockInformationHolder}.
+     */
+    @Bean
+    public ToolCallbackProvider stockTools() {
+        return MethodToolCallbackProvider
+            .builder()
+            .toolObjects(new StockInformationHolder())
+            .build();
+    }
+}
diff --git a/spring-ai-3/src/main/java/com/baeldung/springai/mcp/oauth2/configuration/McpServerSecurityConfiguration.java b/spring-ai-3/src/main/java/com/baeldung/springai/mcp/oauth2/configuration/McpServerSecurityConfiguration.java
new file mode 100644
index 000000000000..fd37fcd5499e
--- /dev/null
+++ b/spring-ai-3/src/main/java/com/baeldung/springai/mcp/oauth2/configuration/McpServerSecurityConfiguration.java
@@ -0,0 +1,36 @@
+package com.baeldung.springai.mcp.oauth2.configuration;
+
+import org.springframework.context.annotation.Bean;
+import org.springframework.context.annotation.Configuration;
+import org.springframework.security.config.Customizer;
+import org.springframework.security.config.annotation.web.builders.HttpSecurity;
+import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
+import org.springframework.security.config.annotation.web.configurers.CsrfConfigurer;
+import org.springframework.security.oauth2.server.authorization.config.annotation.web.configurers.OAuth2AuthorizationServerConfigurer;
+import org.springframework.security.web.SecurityFilterChain;
+
+/**
+ * Single security filter chain that makes this application both the OAuth2 authorization
+ * server (issuing tokens) and a JWT resource server protecting the MCP endpoints.
+ */
+@Configuration
+@EnableWebSecurity
+public class McpServerSecurityConfiguration {
+
+    @Bean
+    public SecurityFilterChain filterChain(HttpSecurity http) throws Exception {
+        return http
+            // Only the SSE stream and the MCP message endpoints require an authenticated
+            // principal; all other requests are permitted by these rules.
+            .authorizeHttpRequests(auth -> auth
+                .requestMatchers("/mcp/**").authenticated()
+                .requestMatchers("/sse").authenticated()
+                .anyRequest().permitAll())
+            .with(OAuth2AuthorizationServerConfigurer.authorizationServer(), Customizer.withDefaults())
+            .oauth2ResourceServer(oauth2 -> oauth2.jwt(Customizer.withDefaults()))
+            // CSRF disabled so non-browser clients can POST to the MCP message endpoint.
+            .csrf(CsrfConfigurer::disable)
+            .cors(Customizer.withDefaults())
+            .build();
+    }
+}
diff --git a/spring-ai-3/src/main/resources/application-mcp.yml b/spring-ai-3/src/main/resources/application-mcp.yml
new file mode 100644
index 000000000000..7e6b2ad7ae3e
--- /dev/null
+++ b/spring-ai-3/src/main/resources/application-mcp.yml
@@ -0,0 +1,19 @@
+spring:
+  security:
+    oauth2:
+      authorizationserver:
+        client:
+          oidc-client:
+            registration:
+              client-id: mcp-client
+              client-secret: "{noop}secret"
+              client-authentication-methods: client_secret_basic
+              authorization-grant-types: client_credentials
+  # Avoid starting docker from the shared codebase
+  docker:
+    compose:
+      enabled: false
+
+logging:
+  level:
+    org.springframework.ai.mcp: DEBUG
diff --git a/spring-ai-3/src/test/java/com/baeldung/springai/mcp/oauth2/McpServerOAuth2LiveTest.java b/spring-ai-3/src/test/java/com/baeldung/springai/mcp/oauth2/McpServerOAuth2LiveTest.java
new file mode 100644
index 000000000000..76a6121d1884
--- /dev/null
+++ b/spring-ai-3/src/test/java/com/baeldung/springai/mcp/oauth2/McpServerOAuth2LiveTest.java
@@ -0,0 +1,183 @@
+package com.baeldung.springai.mcp.oauth2;
+
+import com.fasterxml.jackson.databind.JsonNode;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.springframework.boot.test.context.SpringBootTest;
+import org.springframework.boot.test.web.server.LocalServerPort;
+import org.springframework.http.HttpHeaders;
+import org.springframework.http.MediaType;
+import org.springframework.test.context.ActiveProfiles;
+import org.springframework.web.reactive.function.BodyInserters;
+import org.springframework.web.reactive.function.client.WebClient;
+import reactor.core.Disposable;
+
+import java.nio.charset.StandardCharsets;
+import java.time.Duration;
+import java.util.Base64;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.TimeUnit;
+import java.util.function.Predicate;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+@ActiveProfiles("mcp")
+@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT)
+class McpServerOAuth2LiveTest {
+
+    private static final Logger log = LoggerFactory.getLogger(McpServerOAuth2LiveTest.class);
+
+    private static final Duration EVENT_TIMEOUT = Duration.ofSeconds(10);
+
+    private static final String INITIALIZE_REQUEST = """
+        {
+          "jsonrpc": "2.0",
+          "id": "0",
+          "method": "initialize",
+          "params": {
+            "protocolVersion": "2024-11-05",
+            "capabilities": {},
+            "clientInfo": {
+              "name": "mcp-client",
+              "version": "1.0.0"
+            }
+          }
+        }
+        """;
+
+    private static final String INITIALIZED_NOTIFICATION = """
+        {
+          "jsonrpc": "2.0",
+          "method": "notifications/initialized"
+        }
+        """;
+
+    private static final String GET_STOCK_PRICE_REQUEST = """
+        {
+          "jsonrpc": "2.0",
+          "id": "1",
+          "method": "tools/call",
+          "params": {
+            "name": "getStockPrice",
+            "arguments": {
+              "arg0": "AAPL"
+            }
+          }
+        }
+        """;
+
+    @LocalServerPort
+    private int port;
+
+    private WebClient webClient;
+
+    @BeforeEach
+    void setup() {
+        webClient = WebClient.create("http://localhost:" + port);
+    }
+
+    @Test
+    void givenSecuredMcpServer_whenCallingTheEndpointsWithValidAuthorizationHeader_thenExpectedResponseShouldBeObtained()
+        throws InterruptedException {
+        String authorization = obtainAccessToken();
+        BlockingQueue<String> events = new LinkedBlockingQueue<>();
+
+        // Subscribe exactly once: the SSE Flux is cold, so every subscription opens a new
+        // connection (and MCP session), and the stream never completes on its own.
+        // Payloads are handed to the test thread so that assertion failures fail the test
+        // instead of being swallowed inside the subscriber callback.
+        Disposable sseSubscription = webClient.get()
+            .uri("/sse")
+            .header(HttpHeaders.AUTHORIZATION, authorization)
+            .accept(MediaType.TEXT_EVENT_STREAM)
+            .retrieve()
+            .bodyToFlux(String.class)
+            .subscribe(
+                data -> {
+                    log.info("Response received: {}", data);
+                    events.add(data);
+                },
+                error -> log.error(error.getMessage(), error),
+                () -> log.info("Stream completed"));
+
+        try {
+            // The first event announces the session-specific endpoint (/mcp/message?sessionId=...);
+            // the message endpoint expects that session id on every POST.
+            String messageEndpoint = awaitEvent(events, this::isRequestMessage);
+
+            // MCP requires the initialize handshake before any other request.
+            postMessage(messageEndpoint, authorization, INITIALIZE_REQUEST);
+            awaitEvent(events, data -> !isRequestMessage(data));
+            postMessage(messageEndpoint, authorization, INITIALIZED_NOTIFICATION);
+
+            postMessage(messageEndpoint, authorization, GET_STOCK_PRICE_REQUEST);
+            String toolResponse = awaitEvent(events, data -> data.contains("AAPL"));
+
+            assertThat(toolResponse).contains("AAPL: $150.00");
+        } finally {
+            sseSubscription.dispose();
+        }
+    }
+
+    private boolean isRequestMessage(String data) {
+        return data.contains("/mcp/message");
+    }
+
+    /**
+     * Takes SSE payloads from {@code events} until one matches {@code matcher}.
+     *
+     * @throws AssertionError if no matching payload arrives within {@link #EVENT_TIMEOUT}
+     */
+    private static String awaitEvent(BlockingQueue<String> events, Predicate<String> matcher)
+        throws InterruptedException {
+        long deadline = System.nanoTime() + EVENT_TIMEOUT.toNanos();
+        while (true) {
+            String data = events.poll(deadline - System.nanoTime(), TimeUnit.NANOSECONDS);
+            if (data == null) {
+                throw new AssertionError("No matching SSE event received within " + EVENT_TIMEOUT);
+            }
+            if (matcher.test(data)) {
+                return data;
+            }
+        }
+    }
+
+    /**
+     * Posts a JSON-RPC message to the session's message endpoint; the JSON-RPC response is
+     * read from the SSE stream, not from this HTTP response.
+     */
+    private void postMessage(String messageEndpoint, String authorization, String jsonRpcMessage) {
+        webClient.post()
+            .uri(messageEndpoint)
+            .header(HttpHeaders.AUTHORIZATION, authorization)
+            .contentType(MediaType.APPLICATION_JSON)
+            .bodyValue(jsonRpcMessage)
+            .retrieve()
+            .toBodilessEntity()
+            .block(Duration.ofSeconds(5));
+    }
+
+    /**
+     * Obtains an access token via the client-credentials grant and returns it as an
+     * {@code Authorization} header value ({@code "Bearer <token>"}).
+     */
+    public String obtainAccessToken() {
+        String clientId = "mcp-client";
+        String clientSecret = "secret";
+        String basicToken = Base64.getEncoder()
+            .encodeToString((clientId + ":" + clientSecret).getBytes(StandardCharsets.UTF_8));
+
+        return "Bearer " + webClient.post()
+            .uri("/oauth2/token")
+            .header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_FORM_URLENCODED_VALUE)
+            .header(HttpHeaders.AUTHORIZATION, "Basic " + basicToken)
+            .body(BodyInserters.fromFormData("grant_type", "client_credentials"))
+            .retrieve()
+            .bodyToMono(JsonNode.class)
+            .map(node -> node.get("access_token").asText())
+            .block(Duration.ofSeconds(5));
+    }
+}