Skip to content
2 changes: 2 additions & 0 deletions temporal-spring-ai/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ dependencies {
testImplementation "org.mockito:mockito-core:${mockitoVersion}"
testImplementation 'org.springframework.boot:spring-boot-starter-test'
testImplementation 'org.springframework.ai:spring-ai-rag'
// Needed only so McpPluginTest can mock/reference McpSyncClient directly.
testImplementation 'org.springframework.ai:spring-ai-mcp'
// Needed only so tests can reference Spring AI's NonTransientAiException to
// verify the plugin's default retry classification.
testImplementation 'org.springframework.ai:spring-ai-retry'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import io.temporal.worker.Worker;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import javax.annotation.Nonnull;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -39,22 +40,22 @@ public void setApplicationContext(ApplicationContext applicationContext) throws
this.applicationContext = applicationContext;
}

@SuppressWarnings("unchecked")
private List<McpSyncClient> getMcpClients() {
if (!mcpClients.isEmpty()) {
return mcpClients;
}
if (applicationContext == null) {
return mcpClients;
}

if (applicationContext != null && applicationContext.containsBean("mcpSyncClients")) {
try {
Object bean = applicationContext.getBean("mcpSyncClients");
if (bean instanceof List<?> clientList && !clientList.isEmpty()) {
mcpClients = (List<McpSyncClient>) clientList;
log.info("Found {} MCP client(s) in ApplicationContext", mcpClients.size());
}
} catch (Exception e) {
log.debug("Failed to get mcpSyncClients bean: {}", e.getMessage());
try {
Map<String, McpSyncClient> beans = applicationContext.getBeansOfType(McpSyncClient.class);
if (!beans.isEmpty()) {
mcpClients = List.copyOf(beans.values());
log.info("Discovered {} MCP client bean(s): {}", beans.size(), beans.keySet());
}
} catch (Exception e) {
log.debug("Failed to look up McpSyncClient beans: {}", e.getMessage());
}

return mcpClients;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
package io.temporal.springai.plugin;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;

import io.modelcontextprotocol.client.McpSyncClient;
import io.modelcontextprotocol.spec.McpSchema;
import io.temporal.springai.mcp.McpClientActivityImpl;
import io.temporal.worker.Worker;
import java.util.LinkedHashMap;
import java.util.Map;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
import org.springframework.context.ApplicationContext;

class McpPluginTest {

@Test
void discoversMcpClientBeansByType() {
McpSyncClient clientA = mockClientNamed("alpha");
McpSyncClient clientB = mockClientNamed("beta");

// Spring's getBeansOfType keeps insertion order via LinkedHashMap; use that for determinism.
Map<String, McpSyncClient> beans = new LinkedHashMap<>();
beans.put("mcpClientAlpha", clientA);
beans.put("mcpClientBeta", clientB);

ApplicationContext ctx = mock(ApplicationContext.class);
when(ctx.getBeansOfType(McpSyncClient.class)).thenReturn(beans);

McpPlugin plugin = new McpPlugin();
plugin.setApplicationContext(ctx);

Worker worker = mock(Worker.class);
plugin.initializeWorker("mcp-tq", worker);

ArgumentCaptor<Object> captor = ArgumentCaptor.forClass(Object.class);
verify(worker, atLeastOnce()).registerActivitiesImplementations(captor.capture());
Object registered = captor.getValue();
assertEquals(McpClientActivityImpl.class, registered.getClass());
}

@Test
void twoMcpBeans_duplicateClientInfoNames_throws() {
// Two distinct beans that both report the same clientInfo().name() — the activity impl
// has to reject this because it keys its internal client map by that name.
McpSyncClient clientA = mockClientNamed("shared");
McpSyncClient clientB = mockClientNamed("shared");

Map<String, McpSyncClient> beans = new LinkedHashMap<>();
beans.put("mcpClientA", clientA);
beans.put("mcpClientB", clientB);

ApplicationContext ctx = mock(ApplicationContext.class);
when(ctx.getBeansOfType(McpSyncClient.class)).thenReturn(beans);

McpPlugin plugin = new McpPlugin();
plugin.setApplicationContext(ctx);

IllegalArgumentException thrown =
assertThrows(
IllegalArgumentException.class,
() -> plugin.initializeWorker("mcp-tq", mock(Worker.class)));
assertTrue(
thrown.getMessage().contains("shared"),
"expected duplicate name in message, got: " + thrown.getMessage());
}

@Test
void noMcpBeans_defersWorker_thenClearsAfterSingletonsInstantiated() {
ApplicationContext ctx = mock(ApplicationContext.class);
when(ctx.getBeansOfType(McpSyncClient.class)).thenReturn(Map.of());

McpPlugin plugin = new McpPlugin();
plugin.setApplicationContext(ctx);

Worker worker = mock(Worker.class);
plugin.initializeWorker("mcp-tq", worker);

// No beans → nothing registered yet, worker queued for deferred attempt.
verifyNoInteractions(worker);

plugin.afterSingletonsInstantiated();

// Still no beans — the deferred attempt also finds nothing and doesn't crash.
verify(worker, org.mockito.Mockito.never()).registerActivitiesImplementations((Object[]) any());
}

@Test
void beansAppearLate_registeredViaAfterSingletonsInstantiated() {
ApplicationContext ctx = mock(ApplicationContext.class);
// First lookup returns empty (Spring AI MCP bean hasn't been created yet when
// initializeWorker runs).
when(ctx.getBeansOfType(McpSyncClient.class))
.thenReturn(Map.of())
.thenReturn(Map.of("mcpClient", mockClientNamed("late")));

McpPlugin plugin = new McpPlugin();
plugin.setApplicationContext(ctx);

Worker worker = mock(Worker.class);
plugin.initializeWorker("mcp-tq", worker);
verifyNoInteractions(worker);

plugin.afterSingletonsInstantiated();

ArgumentCaptor<Object> captor = ArgumentCaptor.forClass(Object.class);
verify(worker, atLeastOnce()).registerActivitiesImplementations(captor.capture());
assertEquals(McpClientActivityImpl.class, captor.getValue().getClass());
}

private static McpSyncClient mockClientNamed(String name) {
McpSyncClient client = mock(McpSyncClient.class);
McpSchema.Implementation info = new McpSchema.Implementation(name, "1.0.0");
when(client.getClientInfo()).thenReturn(info);
return client;
}
}
Loading