diff --git a/temporal-spring-ai/build.gradle b/temporal-spring-ai/build.gradle index 12b66cc19..206b361d0 100644 --- a/temporal-spring-ai/build.gradle +++ b/temporal-spring-ai/build.gradle @@ -46,6 +46,8 @@ dependencies { testImplementation "org.mockito:mockito-core:${mockitoVersion}" testImplementation 'org.springframework.boot:spring-boot-starter-test' testImplementation 'org.springframework.ai:spring-ai-rag' + // Needed only so McpPluginTest can mock/reference McpSyncClient directly. + testImplementation 'org.springframework.ai:spring-ai-mcp' // Needed only so tests can reference Spring AI's NonTransientAiException to // verify the plugin's default retry classification. testImplementation 'org.springframework.ai:spring-ai-retry' diff --git a/temporal-spring-ai/src/main/java/io/temporal/springai/plugin/McpPlugin.java b/temporal-spring-ai/src/main/java/io/temporal/springai/plugin/McpPlugin.java index 2f3635cfd..0208e6e1e 100644 --- a/temporal-spring-ai/src/main/java/io/temporal/springai/plugin/McpPlugin.java +++ b/temporal-spring-ai/src/main/java/io/temporal/springai/plugin/McpPlugin.java @@ -6,6 +6,7 @@ import io.temporal.worker.Worker; import java.util.ArrayList; import java.util.List; +import java.util.Map; import javax.annotation.Nonnull; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -39,22 +40,22 @@ public void setApplicationContext(ApplicationContext applicationContext) throws this.applicationContext = applicationContext; } - @SuppressWarnings("unchecked") private List getMcpClients() { if (!mcpClients.isEmpty()) { return mcpClients; } + if (applicationContext == null) { + return mcpClients; + } - if (applicationContext != null && applicationContext.containsBean("mcpSyncClients")) { - try { - Object bean = applicationContext.getBean("mcpSyncClients"); - if (bean instanceof List clientList && !clientList.isEmpty()) { - mcpClients = (List) clientList; - log.info("Found {} MCP client(s) in ApplicationContext", mcpClients.size()); - } - } catch (Exception e) { - log.debug("Failed to get mcpSyncClients bean: {}", e.getMessage()); + try { + Map beans = applicationContext.getBeansOfType(McpSyncClient.class); + if (!beans.isEmpty()) { + mcpClients = List.copyOf(beans.values()); + log.info("Discovered {} MCP client bean(s): {}", beans.size(), beans.keySet()); } + } catch (Exception e) { + log.debug("Failed to look up McpSyncClient beans: {}", e.getMessage()); } return mcpClients; diff --git a/temporal-spring-ai/src/test/java/io/temporal/springai/plugin/McpPluginTest.java b/temporal-spring-ai/src/test/java/io/temporal/springai/plugin/McpPluginTest.java new file mode 100644 index 000000000..7a2a14f34 --- /dev/null +++ b/temporal-spring-ai/src/test/java/io/temporal/springai/plugin/McpPluginTest.java @@ -0,0 +1,125 @@ +package io.temporal.springai.plugin; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.when; + +import io.modelcontextprotocol.client.McpSyncClient; +import io.modelcontextprotocol.spec.McpSchema; +import io.temporal.springai.mcp.McpClientActivityImpl; +import io.temporal.worker.Worker; +import java.util.LinkedHashMap; +import java.util.Map; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import org.springframework.context.ApplicationContext; + +class McpPluginTest { + + @Test + void discoversMcpClientBeansByType() { + McpSyncClient clientA = mockClientNamed("alpha"); + McpSyncClient clientB = mockClientNamed("beta"); + + // Spring's getBeansOfType keeps insertion order via LinkedHashMap; use that for determinism. + Map beans = new LinkedHashMap<>(); + beans.put("mcpClientAlpha", clientA); + beans.put("mcpClientBeta", clientB); + + ApplicationContext ctx = mock(ApplicationContext.class); + when(ctx.getBeansOfType(McpSyncClient.class)).thenReturn(beans); + + McpPlugin plugin = new McpPlugin(); + plugin.setApplicationContext(ctx); + + Worker worker = mock(Worker.class); + plugin.initializeWorker("mcp-tq", worker); + + ArgumentCaptor captor = ArgumentCaptor.forClass(Object.class); + verify(worker, atLeastOnce()).registerActivitiesImplementations(captor.capture()); + Object registered = captor.getValue(); + assertEquals(McpClientActivityImpl.class, registered.getClass()); + } + + @Test + void twoMcpBeans_duplicateClientInfoNames_throws() { + // Two distinct beans that both report the same clientInfo().name() — the activity impl + // has to reject this because it keys its internal client map by that name. + McpSyncClient clientA = mockClientNamed("shared"); + McpSyncClient clientB = mockClientNamed("shared"); + + Map beans = new LinkedHashMap<>(); + beans.put("mcpClientA", clientA); + beans.put("mcpClientB", clientB); + + ApplicationContext ctx = mock(ApplicationContext.class); + when(ctx.getBeansOfType(McpSyncClient.class)).thenReturn(beans); + + McpPlugin plugin = new McpPlugin(); + plugin.setApplicationContext(ctx); + + IllegalArgumentException thrown = + assertThrows( + IllegalArgumentException.class, + () -> plugin.initializeWorker("mcp-tq", mock(Worker.class))); + assertTrue( + thrown.getMessage().contains("shared"), + "expected duplicate name in message, got: " + thrown.getMessage()); + } + + @Test + void noMcpBeans_defersWorker_thenClearsAfterSingletonsInstantiated() { + ApplicationContext ctx = mock(ApplicationContext.class); + when(ctx.getBeansOfType(McpSyncClient.class)).thenReturn(Map.of()); + + McpPlugin plugin = new McpPlugin(); + plugin.setApplicationContext(ctx); + + Worker worker = mock(Worker.class); + plugin.initializeWorker("mcp-tq", worker); + + // No beans → nothing registered yet, worker queued for deferred attempt. + verifyNoInteractions(worker); + + plugin.afterSingletonsInstantiated(); + + // Still no beans — the deferred attempt also finds nothing and doesn't crash. + verify(worker, org.mockito.Mockito.never()).registerActivitiesImplementations((Object[]) any()); + } + + @Test + void beansAppearLate_registeredViaAfterSingletonsInstantiated() { + ApplicationContext ctx = mock(ApplicationContext.class); + // First lookup returns empty (Spring AI MCP bean hasn't been created yet when + // initializeWorker runs). + when(ctx.getBeansOfType(McpSyncClient.class)) + .thenReturn(Map.of()) + .thenReturn(Map.of("mcpClient", mockClientNamed("late"))); + + McpPlugin plugin = new McpPlugin(); + plugin.setApplicationContext(ctx); + + Worker worker = mock(Worker.class); + plugin.initializeWorker("mcp-tq", worker); + verifyNoInteractions(worker); + + plugin.afterSingletonsInstantiated(); + + ArgumentCaptor captor = ArgumentCaptor.forClass(Object.class); + verify(worker, atLeastOnce()).registerActivitiesImplementations(captor.capture()); + assertEquals(McpClientActivityImpl.class, captor.getValue().getClass()); + } + + private static McpSyncClient mockClientNamed(String name) { + McpSyncClient client = mock(McpSyncClient.class); + McpSchema.Implementation info = new McpSchema.Implementation(name, "1.0.0"); + when(client.getClientInfo()).thenReturn(info); + return client; + } +}