Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 43 additions & 19 deletions driver/src/main/java/org/neo4j/driver/internal/DriverFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
import io.netty.util.concurrent.EventExecutorGroup;
import io.netty.util.internal.logging.InternalLoggerFactory;
import java.net.URI;
import java.util.Objects;
import java.util.function.Supplier;
import org.neo4j.driver.AuthToken;
import org.neo4j.driver.AuthTokens;
import org.neo4j.driver.Config;
Expand All @@ -39,11 +41,13 @@
import org.neo4j.driver.internal.async.connection.ChannelConnectorImpl;
import org.neo4j.driver.internal.async.pool.ConnectionPoolImpl;
import org.neo4j.driver.internal.async.pool.PoolSettings;
import org.neo4j.driver.internal.cluster.Rediscovery;
import org.neo4j.driver.internal.cluster.RediscoveryImpl;
import org.neo4j.driver.internal.cluster.RoutingContext;
import org.neo4j.driver.internal.cluster.RoutingProcedureClusterCompositionProvider;
import org.neo4j.driver.internal.cluster.RoutingSettings;
import org.neo4j.driver.internal.cluster.loadbalancing.LeastConnectedLoadBalancingStrategy;
import org.neo4j.driver.internal.cluster.loadbalancing.LoadBalancer;
import org.neo4j.driver.internal.cluster.loadbalancing.LoadBalancingStrategy;
import org.neo4j.driver.internal.logging.NettyLogging;
import org.neo4j.driver.internal.metrics.DevNullMetricsProvider;
import org.neo4j.driver.internal.metrics.InternalMetricsProvider;
Expand All @@ -70,7 +74,7 @@ public final Driver newInstance(
RetrySettings retrySettings,
Config config,
SecurityPlan securityPlan) {
return newInstance(uri, authToken, routingSettings, retrySettings, config, null, securityPlan);
return newInstance(uri, authToken, routingSettings, retrySettings, config, null, securityPlan, null);
}

public final Driver newInstance(
Expand All @@ -80,7 +84,8 @@ public final Driver newInstance(
RetrySettings retrySettings,
Config config,
EventLoopGroup eventLoopGroup,
SecurityPlan securityPlan) {
SecurityPlan securityPlan,
Supplier<Rediscovery> rediscoverySupplier) {
Bootstrap bootstrap;
boolean ownsEventLoopGroup;
if (eventLoopGroup == null) {
Expand Down Expand Up @@ -119,6 +124,7 @@ public final Driver newInstance(
newRoutingSettings,
retryLogic,
metricsProvider,
rediscoverySupplier,
config);
}

Expand Down Expand Up @@ -185,6 +191,7 @@ private InternalDriver createDriver(
RoutingSettings routingSettings,
RetryLogic retryLogic,
MetricsProvider metricsProvider,
Supplier<Rediscovery> rediscoverySupplier,
Config config) {
try {
String scheme = uri.getScheme().toLowerCase();
Expand All @@ -198,6 +205,7 @@ private InternalDriver createDriver(
routingSettings,
retryLogic,
metricsProvider,
rediscoverySupplier,
config);
} else {
assertNoRoutingContext(uri, routingSettings);
Expand Down Expand Up @@ -243,9 +251,10 @@ protected InternalDriver createRoutingDriver(
RoutingSettings routingSettings,
RetryLogic retryLogic,
MetricsProvider metricsProvider,
Supplier<Rediscovery> rediscoverySupplier,
Config config) {
ConnectionProvider connectionProvider =
createLoadBalancer(address, connectionPool, eventExecutorGroup, config, routingSettings);
ConnectionProvider connectionProvider = createLoadBalancer(
address, connectionPool, eventExecutorGroup, config, routingSettings, rediscoverySupplier);
SessionFactory sessionFactory = createSessionFactory(connectionProvider, retryLogic, config);
InternalDriver driver = createDriver(securityPlan, sessionFactory, metricsProvider, config);
Logger log = config.logging().getLog(getClass());
Expand Down Expand Up @@ -273,24 +282,41 @@ protected LoadBalancer createLoadBalancer(
ConnectionPool connectionPool,
EventExecutorGroup eventExecutorGroup,
Config config,
RoutingSettings routingSettings) {
LoadBalancingStrategy loadBalancingStrategy =
new LeastConnectedLoadBalancingStrategy(connectionPool, config.logging());
ServerAddressResolver resolver = createResolver(config);
LoadBalancer loadBalancer = new LoadBalancer(
address,
routingSettings,
RoutingSettings routingSettings,
Supplier<Rediscovery> rediscoverySupplier) {
var loadBalancingStrategy = new LeastConnectedLoadBalancingStrategy(connectionPool, config.logging());
var resolver = createResolver(config);
var domainNameResolver = Objects.requireNonNull(getDomainNameResolver(), "domainNameResolver must not be null");
var clock = createClock();
var logging = config.logging();
if (rediscoverySupplier == null) {
rediscoverySupplier =
() -> createRediscovery(address, resolver, routingSettings, clock, logging, domainNameResolver);
}
var loadBalancer = new LoadBalancer(
connectionPool,
eventExecutorGroup,
createClock(),
config.logging(),
rediscoverySupplier.get(),
routingSettings,
loadBalancingStrategy,
resolver,
getDomainNameResolver());
eventExecutorGroup,
clock,
logging);
handleNewLoadBalancer(loadBalancer);
return loadBalancer;
}

protected Rediscovery createRediscovery(
BoltServerAddress initialRouter,
ServerAddressResolver resolver,
RoutingSettings settings,
Clock clock,
Logging logging,
DomainNameResolver domainNameResolver) {
var clusterCompositionProvider =
new RoutingProcedureClusterCompositionProvider(clock, settings.routingContext());
return new RediscoveryImpl(initialRouter, clusterCompositionProvider, resolver, logging, domainNameResolver);
}

/**
* Handles new {@link LoadBalancer} instance.
* <p>
Expand All @@ -307,8 +333,6 @@ private static ServerAddressResolver createResolver(Config config) {

/**
* Creates new {@link Clock}.
* <p>
* <b>This method is protected only for testing</b>
*/
protected Clock createClock() {
return Clock.SYSTEM;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
*/
package org.neo4j.driver.internal.cluster;

import static java.util.Objects.requireNonNull;
import static org.neo4j.driver.internal.async.ConnectionContext.PENDING_DATABASE_NAME_EXCEPTION_SUPPLIER;

import java.util.HashMap;
Expand Down Expand Up @@ -72,6 +73,7 @@ public RoutingTableRegistryImpl(
ConnectionPool connectionPool,
Rediscovery rediscovery,
Logging logging) {
requireNonNull(rediscovery, "rediscovery must not be null");
this.factory = factory;
this.routingTableHandlers = routingTableHandlers;
this.principalToDatabaseNameStage = new HashMap<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,9 @@
import org.neo4j.driver.exceptions.ServiceUnavailableException;
import org.neo4j.driver.exceptions.SessionExpiredException;
import org.neo4j.driver.internal.BoltServerAddress;
import org.neo4j.driver.internal.DomainNameResolver;
import org.neo4j.driver.internal.async.ConnectionContext;
import org.neo4j.driver.internal.async.connection.RoutingConnection;
import org.neo4j.driver.internal.cluster.ClusterCompositionProvider;
import org.neo4j.driver.internal.cluster.Rediscovery;
import org.neo4j.driver.internal.cluster.RediscoveryImpl;
import org.neo4j.driver.internal.cluster.RoutingProcedureClusterCompositionProvider;
import org.neo4j.driver.internal.cluster.RoutingSettings;
import org.neo4j.driver.internal.cluster.RoutingTable;
import org.neo4j.driver.internal.cluster.RoutingTableRegistry;
Expand All @@ -56,7 +52,6 @@
import org.neo4j.driver.internal.spi.ConnectionProvider;
import org.neo4j.driver.internal.util.Clock;
import org.neo4j.driver.internal.util.Futures;
import org.neo4j.driver.net.ServerAddressResolver;

public class LoadBalancer implements ConnectionProvider {
private static final String CONNECTION_ACQUISITION_COMPLETION_FAILURE_MESSAGE =
Expand All @@ -73,27 +68,6 @@ public class LoadBalancer implements ConnectionProvider {
private final Rediscovery rediscovery;

public LoadBalancer(
BoltServerAddress initialRouter,
RoutingSettings settings,
ConnectionPool connectionPool,
EventExecutorGroup eventExecutorGroup,
Clock clock,
Logging logging,
LoadBalancingStrategy loadBalancingStrategy,
ServerAddressResolver resolver,
DomainNameResolver domainNameResolver) {
this(
connectionPool,
createRediscovery(
initialRouter, resolver, settings, clock, logging, requireNonNull(domainNameResolver)),
settings,
loadBalancingStrategy,
eventExecutorGroup,
clock,
logging);
}

private LoadBalancer(
ConnectionPool connectionPool,
Rediscovery rediscovery,
RoutingSettings settings,
Expand All @@ -117,6 +91,7 @@ private LoadBalancer(
LoadBalancingStrategy loadBalancingStrategy,
EventExecutorGroup eventExecutorGroup,
Logging logging) {
requireNonNull(rediscovery, "rediscovery must not be null");
this.connectionPool = connectionPool;
this.routingTables = routingTables;
this.rediscovery = rediscovery;
Expand Down Expand Up @@ -281,19 +256,14 @@ private static RoutingTableRegistry createRoutingTables(
connectionPool, rediscovery, clock, logging, settings.routingTablePurgeDelayMs());
}

private static Rediscovery createRediscovery(
BoltServerAddress initialRouter,
ServerAddressResolver resolver,
RoutingSettings settings,
Clock clock,
Logging logging,
DomainNameResolver domainNameResolver) {
ClusterCompositionProvider clusterCompositionProvider =
new RoutingProcedureClusterCompositionProvider(clock, settings.routingContext());
return new RediscoveryImpl(initialRouter, clusterCompositionProvider, resolver, logging, domainNameResolver);
}

private static RuntimeException unknownMode(AccessMode mode) {
return new IllegalArgumentException("Mode '" + mode + "' is not supported");
}

/**
* <b>This method is only for testing</b>
*/
public Rediscovery getRediscovery() {
return rediscovery;
}
}
3 changes: 3 additions & 0 deletions driver/src/test/java/org/neo4j/driver/GraphDatabaseTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,13 @@
import java.net.URI;
import java.util.Iterator;
import java.util.List;
import java.util.function.Supplier;
import org.junit.jupiter.api.Test;
import org.neo4j.driver.exceptions.ServiceUnavailableException;
import org.neo4j.driver.internal.BoltServerAddress;
import org.neo4j.driver.internal.DriverFactory;
import org.neo4j.driver.internal.InternalDriver;
import org.neo4j.driver.internal.cluster.Rediscovery;
import org.neo4j.driver.internal.cluster.RoutingSettings;
import org.neo4j.driver.internal.metrics.MetricsProvider;
import org.neo4j.driver.internal.retry.RetryLogic;
Expand Down Expand Up @@ -147,6 +149,7 @@ protected InternalDriver createRoutingDriver(
RoutingSettings routingSettings,
RetryLogic retryLogic,
MetricsProvider metricsProvider,
Supplier<Rediscovery> rediscoverySupplier,
Config config) {
return driverIterator.next();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ private Driver createDriver(EventLoopGroup eventLoopGroup) {
RetrySettings.DEFAULT,
Config.defaultConfig(),
eventLoopGroup,
SecurityPlanImpl.insecure());
SecurityPlanImpl.insecure(),
null);
}

private void testConnection(Driver driver) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ void testCustomSecurityPlanUsed() {
RetrySettings.DEFAULT,
Config.defaultConfig(),
null,
securityPlan);
securityPlan,
null);

assertFalse(driverFactory.capturedSecurityPlans.isEmpty());
assertTrue(driverFactory.capturedSecurityPlans.stream().allMatch(capturePlan -> capturePlan == securityPlan));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,13 @@
import static org.hamcrest.Matchers.is;
import static org.hamcrest.junit.MatcherAssert.assertThat;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
import static org.mockito.BDDMockito.given;
import static org.mockito.BDDMockito.then;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
Expand All @@ -41,6 +45,7 @@
import io.netty.bootstrap.Bootstrap;
import io.netty.util.concurrent.EventExecutorGroup;
import java.net.URI;
import java.util.function.Supplier;
import java.util.stream.Stream;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
Expand All @@ -55,6 +60,8 @@
import org.neo4j.driver.internal.async.LeakLoggingNetworkSession;
import org.neo4j.driver.internal.async.NetworkSession;
import org.neo4j.driver.internal.async.connection.BootstrapFactory;
import org.neo4j.driver.internal.cluster.Rediscovery;
import org.neo4j.driver.internal.cluster.RediscoveryImpl;
import org.neo4j.driver.internal.cluster.RoutingContext;
import org.neo4j.driver.internal.cluster.RoutingSettings;
import org.neo4j.driver.internal.cluster.loadbalancing.LoadBalancer;
Expand Down Expand Up @@ -191,6 +198,61 @@ void shouldCreateAppropriateDriverType(String uri) {
}
}

@Test
void shouldUseBuiltInRediscoveryByDefault() {
// GIVEN
var driverFactory = new DriverFactory();
var securityPlan =
new SecuritySettings.SecuritySettingsBuilder().build().createSecurityPlan("neo4j");

// WHEN
var driver = driverFactory.newInstance(
URI.create("neo4j://localhost:7687"),
AuthTokens.none(),
RoutingSettings.DEFAULT,
RetrySettings.DEFAULT,
Config.defaultConfig(),
null,
securityPlan,
null);

// THEN
var sessionFactory = ((InternalDriver) driver).getSessionFactory();
var connectionProvider = ((SessionFactoryImpl) sessionFactory).getConnectionProvider();
var rediscovery = ((LoadBalancer) connectionProvider).getRediscovery();
assertTrue(rediscovery instanceof RediscoveryImpl);
}

@Test
void shouldUseSuppliedRediscovery() {
// GIVEN
var driverFactory = new DriverFactory();
var securityPlan =
new SecuritySettings.SecuritySettingsBuilder().build().createSecurityPlan("neo4j");
@SuppressWarnings("unchecked")
Supplier<Rediscovery> rediscoverySupplier = mock(Supplier.class);
var rediscovery = mock(Rediscovery.class);
given(rediscoverySupplier.get()).willReturn(rediscovery);

// WHEN
var driver = driverFactory.newInstance(
URI.create("neo4j://localhost:7687"),
AuthTokens.none(),
RoutingSettings.DEFAULT,
RetrySettings.DEFAULT,
Config.defaultConfig(),
null,
securityPlan,
rediscoverySupplier);

// THEN
var sessionFactory = ((InternalDriver) driver).getSessionFactory();
var connectionProvider = ((SessionFactoryImpl) sessionFactory).getConnectionProvider();
var actualRediscovery = ((LoadBalancer) connectionProvider).getRediscovery();
then(rediscoverySupplier).should().get();
assertEquals(rediscovery, actualRediscovery);
}

private Driver createDriver(String uri, DriverFactory driverFactory) {
return createDriver(uri, driverFactory, defaultConfig());
}
Expand Down Expand Up @@ -239,6 +301,7 @@ protected InternalDriver createRoutingDriver(
RoutingSettings routingSettings,
RetryLogic retryLogic,
MetricsProvider metricsProvider,
Supplier<Rediscovery> rediscoverySupplier,
Config config) {
throw new UnsupportedOperationException("Can't create routing driver");
}
Expand Down Expand Up @@ -276,7 +339,8 @@ protected LoadBalancer createLoadBalancer(
ConnectionPool connectionPool,
EventExecutorGroup eventExecutorGroup,
Config config,
RoutingSettings routingSettings) {
RoutingSettings routingSettings,
Supplier<Rediscovery> rediscoverySupplier) {
return null;
}

Expand Down
Loading