diff --git a/pom.xml b/pom.xml index ebd48585..e984aac3 100644 --- a/pom.xml +++ b/pom.xml @@ -5,7 +5,7 @@ org.springframework.data spring-data-r2dbc - 1.1.0.BUILD-SNAPSHOT + 1.1.0.gh-220-SNAPSHOT Spring Data R2DBC Spring Data module for R2DBC diff --git a/src/main/java/org/springframework/data/r2dbc/core/DefaultDatabaseClient.java b/src/main/java/org/springframework/data/r2dbc/core/DefaultDatabaseClient.java index c35cffbc..1e332bcf 100644 --- a/src/main/java/org/springframework/data/r2dbc/core/DefaultDatabaseClient.java +++ b/src/main/java/org/springframework/data/r2dbc/core/DefaultDatabaseClient.java @@ -204,6 +204,15 @@ protected Mono getConnection() { return ConnectionFactoryUtils.getConnection(obtainConnectionFactory()); } + /** + * Obtain the {@link ReactiveDataAccessStrategy}. + * + * @return a the ReactiveDataAccessStrategy. + */ + protected ReactiveDataAccessStrategy getDataAccessStrategy() { + return dataAccessStrategy; + } + /** * Release the {@link Connection}. * @@ -300,9 +309,9 @@ private static void bindByIndex(Statement statement, Map byIndex.forEach((i, o) -> { if (o.getValue() != null) { - statement.bind(i.intValue(), o.getValue()); + statement.bind(i, o.getValue()); } else { - statement.bindNull(i.intValue(), o.getType()); + statement.bindNull(i, o.getType()); } }); } @@ -809,8 +818,8 @@ private FetchSpec exchange(BiFunction mappingFunctio StatementMapper mapper = dataAccessStrategy.getStatementMapper(); - StatementMapper.SelectSpec selectSpec = mapper.createSelect(this.table).withProjection(this.projectedFields) - .withSort(this.sort).withPage(this.page); + StatementMapper.SelectSpec selectSpec = mapper.createSelect(this.table) + .withProjection(this.projectedFields.toArray(new SqlIdentifier[0])).withSort(this.sort).withPage(this.page); if (this.criteria != null) { selectSpec = selectSpec.withCriteria(this.criteria); @@ -922,8 +931,8 @@ private FetchSpec exchange(BiFunction mappingFunctio columns = this.projectedFields; } - StatementMapper.SelectSpec selectSpec = mapper.createSelect(this.table).withProjection(columns) - .withPage(this.page).withSort(this.sort); + StatementMapper.SelectSpec selectSpec = mapper.createSelect(this.table) + .withProjection(columns.toArray(new SqlIdentifier[0])).withPage(this.page).withSort(this.sort); if (this.criteria != null) { selectSpec = selectSpec.withCriteria(this.criteria); @@ -1029,7 +1038,7 @@ private FetchSpec exchange(BiFunction mappingFunctio StatementMapper.InsertSpec insert = mapper.createInsert(this.table); for (SqlIdentifier column : this.byName.keySet()) { - insert = insert.withColumn(dataAccessStrategy.toSql(column), this.byName.get(column)); + insert = insert.withColumn(column, this.byName.get(column)); } PreparedOperation operation = mapper.getMappedObject(insert); @@ -1152,7 +1161,7 @@ private FetchSpec exchange(Object toInsert, BiFunction getMappedObject(SelectSpec selectSpec) { private PreparedOperation getMappedObject(SelectSpec selectSpec, if (selectSpec.getSort().isSorted()) { - Sort mappedSort = this.updateMapper.getMappedObject(selectSpec.getSort(), entity); - selectBuilder.orderBy(createOrderByFields(table, mappedSort)); + List sort = this.updateMapper.getMappedSort(table, selectSpec.getSort(), entity); + selectBuilder.orderBy(sort); } - if (selectSpec.getPage().isPaged()) { - - Pageable page = selectSpec.getPage(); + if (selectSpec.getLimit() > 0) { + selectBuilder.limit(selectSpec.getLimit()); + } - selectBuilder.limitOffset(page.getPageSize(), page.getOffset()); + if (selectSpec.getOffset() > 0) { + selectBuilder.offset(selectSpec.getOffset()); } Select select = selectBuilder.build(); return new DefaultPreparedOperation<>(select, this.renderContext, bindings); } - private Collection createOrderByFields(Table table, Sort sortToUse) { - - List fields = new ArrayList<>(); + protected List getSelectList(SelectSpec selectSpec, @Nullable RelationalPersistentEntity entity) { - for (Sort.Order order : sortToUse) { + if (entity == null) { + return selectSpec.getSelectList(); + } - OrderByField orderByField = OrderByField.from(table.column(order.getProperty())); + List selectList = selectSpec.getSelectList(); + List mapped = new ArrayList<>(selectList.size()); - if (order.getDirection() != null) { - fields.add(order.isAscending() ? orderByField.asc() : orderByField.desc()); - } else { - fields.add(orderByField); - } + for (Expression expression : selectList) { + mapped.add(updateMapper.getMappedObject(expression, entity)); } - return fields; + return mapped; } /* @@ -259,24 +255,13 @@ private PreparedOperation getMappedObject(DeleteSpec deleteSpec, * (non-Javadoc) * @see org.springframework.data.r2dbc.function.StatementMapper#toSql(SqlIdentifier) */ - public String toSql(SqlIdentifier identifier) { + private String toSql(SqlIdentifier identifier) { Assert.notNull(identifier, "SqlIdentifier must not be null"); return identifier.toSql(this.dialect.getIdentifierProcessing()); } - private List toSql(List identifiers) { - - List list = new ArrayList<>(identifiers.size()); - - for (SqlIdentifier sqlIdentifier : identifiers) { - list.add(toSql(sqlIdentifier)); - } - - return list; - } - /** * Default implementation of {@link PreparedOperation}. * @@ -288,7 +273,8 @@ static class DefaultPreparedOperation implements PreparedOperation { private final RenderContext renderContext; private final Bindings bindings; - public DefaultPreparedOperation(T source, RenderContext renderContext, Bindings bindings) { + DefaultPreparedOperation(T source, RenderContext renderContext, Bindings bindings) { + this.source = source; this.renderContext = renderContext; this.bindings = bindings; diff --git a/src/main/java/org/springframework/data/r2dbc/core/FluentR2dbcOperations.java b/src/main/java/org/springframework/data/r2dbc/core/FluentR2dbcOperations.java new file mode 100644 index 00000000..8ab66f6c --- /dev/null +++ b/src/main/java/org/springframework/data/r2dbc/core/FluentR2dbcOperations.java @@ -0,0 +1,26 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.r2dbc.core; + +/** + * Stripped down interface providing access to a fluent API that specifies a basic set of reactive R2DBC operations. + * + * @author Mark Paluch + * @since 1.1 + * @see R2dbcEntityOperations + */ +public interface FluentR2dbcOperations + extends ReactiveSelectOperation, ReactiveInsertOperation, ReactiveUpdateOperation, ReactiveDeleteOperation {} diff --git a/src/main/java/org/springframework/data/r2dbc/core/R2dbcEntityOperations.java b/src/main/java/org/springframework/data/r2dbc/core/R2dbcEntityOperations.java new file mode 100644 index 00000000..2cb72138 --- /dev/null +++ b/src/main/java/org/springframework/data/r2dbc/core/R2dbcEntityOperations.java @@ -0,0 +1,144 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.r2dbc.core; + +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.dao.DataAccessException; +import org.springframework.dao.TransientDataAccessResourceException; +import org.springframework.data.r2dbc.query.Query; +import org.springframework.data.r2dbc.query.Update; + +/** + * Interface specifying a basic set of reactive R2DBC operations using entities. Implemented by + * {@link R2dbcEntityTemplate}. Not often used directly, but a useful option to enhance testability, as it can easily be + * mocked or stubbed. + * + * @author Mark Paluch + * @since 1.1 + * @see DatabaseClient + */ +public interface R2dbcEntityOperations extends FluentR2dbcOperations { + + /** + * Expose the underlying {@link DatabaseClient} to allow SQL operations. + * + * @return the underlying {@link DatabaseClient}. + * @see DatabaseClient + */ + DatabaseClient getDatabaseClient(); + + // ------------------------------------------------------------------------- + // Methods dealing with org.springframework.data.r2dbc.query.Query + // ------------------------------------------------------------------------- + + /** + * Returns the number of rows for the given entity class applying {@link Query}. This overridden method allows users + * to further refine the selection Query using a {@link Query} predicate to determine how many entities of the given + * {@link Class type} match the Query. + * + * @param query user-defined count {@link Query} to execute; must not be {@literal null}. + * @param entityClass {@link Class type} of the entity; must not be {@literal null}. + * @return the number of existing entities. + * @throws DataAccessException if any problem occurs while executing the query. + */ + Mono count(Query query, Class entityClass) throws DataAccessException; + + /** + * Determine whether the result for {@code entityClass} {@link Query} yields at least one row. + * + * @param query user-defined exists {@link Query} to execute; must not be {@literal null}. + * @param entityClass {@link Class type} of the entity; must not be {@literal null}. + * @return {@literal true} if the object exists. + * @throws DataAccessException if any problem occurs while executing the query. + * @since 2.1 + */ + Mono exists(Query query, Class entityClass) throws DataAccessException; + + /** + * Execute a {@code SELECT} query and convert the resulting items to a stream of entities. + * + * @param query must not be {@literal null}. + * @param entityClass The entity type must not be {@literal null}. + * @return the result objects returned by the action. + * @throws DataAccessException if there is any problem issuing the execution. + */ + Flux select(Query query, Class entityClass) throws DataAccessException; + + /** + * Execute a {@code SELECT} query and convert the resulting item to an entity. + * + * @param query must not be {@literal null}. + * @param entityClass The entity type must not be {@literal null}. + * @return the result object returned by the action or {@link Mono#empty()}. + * @throws DataAccessException if there is any problem issuing the execution. + */ + Mono selectOne(Query query, Class entityClass) throws DataAccessException; + + /** + * Update the queried entities and return {@literal true} if the update was applied. + * + * @param query must not be {@literal null}. + * @param update must not be {@literal null}. + * @param entityClass The entity type must not be {@literal null}. + * @return the number of affected rows. + * @throws DataAccessException if there is any problem executing the query. + */ + Mono update(Query query, Update update, Class entityClass) throws DataAccessException; + + /** + * Remove entities (rows)/columns from the table by {@link Query}. + * + * @param query must not be {@literal null}. + * @param entityClass The entity type must not be {@literal null}. + * @return the number of affected rows. + * @throws DataAccessException if there is any problem issuing the execution. + */ + Mono delete(Query query, Class entityClass) throws DataAccessException; + + // ------------------------------------------------------------------------- + // Methods dealing with entities + // ------------------------------------------------------------------------- + + /** + * Insert the given entity and emit the entity if the insert was applied. + * + * @param entity The entity to insert, must not be {@literal null}. + * @return the inserted entity. + * @throws DataAccessException if there is any problem issuing the execution. + */ + Mono insert(T entity) throws DataAccessException; + + /** + * Update the given entity and emit the entity if the update was applied. + * + * @param entity The entity to update, must not be {@literal null}. + * @return the updated entity. + * @throws DataAccessException if there is any problem issuing the execution. + * @throws TransientDataAccessResourceException if the update did not affect any rows. + */ + Mono update(T entity) throws DataAccessException; + + /** + * Delete the given entity and emit the entity if the delete was applied. + * + * @param entity must not be {@literal null}. + * @return the deleted entity. + * @throws DataAccessException if there is any problem issuing the execution. + */ + Mono delete(T entity) throws DataAccessException; +} diff --git a/src/main/java/org/springframework/data/r2dbc/core/R2dbcEntityTemplate.java b/src/main/java/org/springframework/data/r2dbc/core/R2dbcEntityTemplate.java new file mode 100644 index 00000000..0544013b --- /dev/null +++ b/src/main/java/org/springframework/data/r2dbc/core/R2dbcEntityTemplate.java @@ -0,0 +1,495 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.r2dbc.core; + +import io.r2dbc.spi.Row; +import io.r2dbc.spi.RowMetadata; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import java.beans.FeatureDescriptor; +import java.util.Collections; +import java.util.List; +import java.util.Optional; +import java.util.function.BiFunction; +import java.util.stream.Collectors; + +import org.springframework.beans.BeansException; +import org.springframework.beans.factory.BeanFactory; +import org.springframework.beans.factory.BeanFactoryAware; +import org.springframework.dao.DataAccessException; +import org.springframework.dao.TransientDataAccessResourceException; +import org.springframework.data.mapping.IdentifierAccessor; +import org.springframework.data.mapping.MappingException; +import org.springframework.data.mapping.context.MappingContext; +import org.springframework.data.projection.ProjectionInformation; +import org.springframework.data.projection.SpelAwareProxyProjectionFactory; +import org.springframework.data.r2dbc.mapping.R2dbcMappingContext; +import org.springframework.data.r2dbc.query.Criteria; +import org.springframework.data.r2dbc.query.Query; +import org.springframework.data.r2dbc.query.Update; +import org.springframework.data.relational.core.mapping.RelationalPersistentEntity; +import org.springframework.data.relational.core.mapping.RelationalPersistentProperty; +import org.springframework.data.relational.core.sql.Expression; +import org.springframework.data.relational.core.sql.Functions; +import org.springframework.data.relational.core.sql.SqlIdentifier; +import org.springframework.data.relational.core.sql.Table; +import org.springframework.data.util.ProxyUtils; +import org.springframework.util.Assert; + +/** + * Implementation of {@link R2dbcEntityOperations}. It simplifies the use of Reactive R2DBC usage through entities and + * helps to avoid common errors. This class uses {@link DatabaseClient} to execute SQL queries or updates, initiating + * iteration over {@link io.r2dbc.spi.Result}. + *

+ * Can be used within a service implementation via direct instantiation with a {@link DatabaseClient} reference, or get + * prepared in an application context and given to services as bean reference. + * + * @author Mark Paluch + * @since 1.1 + */ +public class R2dbcEntityTemplate implements R2dbcEntityOperations, BeanFactoryAware { + + private final DatabaseClient databaseClient; + + private final ReactiveDataAccessStrategy dataAccessStrategy; + + private final MappingContext, ? extends RelationalPersistentProperty> mappingContext; + + private final SpelAwareProxyProjectionFactory projectionFactory; + + /** + * Create a new {@link R2dbcEntityTemplate} given {@link DatabaseClient}. + * + * @param databaseClient must not be {@literal null}. + */ + public R2dbcEntityTemplate(DatabaseClient databaseClient) { + + Assert.notNull(databaseClient, "DatabaseClient must not be null"); + + this.databaseClient = databaseClient; + this.dataAccessStrategy = getDataAccessStrategy(databaseClient); + this.mappingContext = getMappingContext(this.dataAccessStrategy); + this.projectionFactory = new SpelAwareProxyProjectionFactory(); + } + + /** + * Create a new {@link R2dbcEntityTemplate} given {@link DatabaseClient} and {@link ReactiveDataAccessStrategy}. + * + * @param databaseClient must not be {@literal null}. + */ + public R2dbcEntityTemplate(DatabaseClient databaseClient, ReactiveDataAccessStrategy strategy) { + + Assert.notNull(databaseClient, "DatabaseClient must not be null"); + Assert.notNull(strategy, "ReactiveDataAccessStrategy must not be null"); + + this.databaseClient = databaseClient; + this.dataAccessStrategy = strategy; + this.mappingContext = strategy.getConverter().getMappingContext(); + this.projectionFactory = new SpelAwareProxyProjectionFactory(); + } + + /* + * (non-Javadoc) + * @see org.springframework.data.r2dbc.core.R2dbcEntityOperations#getDatabaseClient() + */ + @Override + public DatabaseClient getDatabaseClient() { + return this.databaseClient; + } + + /* + * (non-Javadoc) + * @see org.springframework.beans.factory.BeanFactoryAware#setBeanFactory(org.springframework.beans.factory.BeanFactory) + */ + @Override + public void setBeanFactory(BeanFactory beanFactory) throws BeansException { + this.projectionFactory.setBeanFactory(beanFactory); + } + + // ------------------------------------------------------------------------- + // Methods dealing with org.springframework.data.r2dbc.core.FluentR2dbcOperations + // ------------------------------------------------------------------------- + + /* + * (non-Javadoc) + * @see org.springframework.data.r2dbc.core.ReactiveSelectOperation#select(java.lang.Class) + */ + @Override + public ReactiveSelect select(Class domainType) { + return new ReactiveSelectOperationSupport(this).select(domainType); + } + + /* + * (non-Javadoc) + * @see org.springframework.data.r2dbc.core.ReactiveInsertOperation#insert(java.lang.Class) + */ + @Override + public ReactiveInsert insert(Class domainType) { + return new ReactiveInsertOperationSupport(this).insert(domainType); + } + + /* + * (non-Javadoc) + * @see org.springframework.data.r2dbc.core.ReactiveUpdateOperation#update(java.lang.Class) + */ + @Override + public ReactiveUpdate update(Class domainType) { + return new ReactiveUpdateOperationSupport(this).update(domainType); + } + + /* + * (non-Javadoc) + * @see org.springframework.data.r2dbc.core.ReactiveDeleteOperation#delete(java.lang.Class) + */ + @Override + public ReactiveDelete delete(Class domainType) { + return new ReactiveDeleteOperationSupport(this).delete(domainType); + } + + // ------------------------------------------------------------------------- + // Methods dealing with org.springframework.data.r2dbc.query.Query + // ------------------------------------------------------------------------- + + /* + * (non-Javadoc) + * @see org.springframework.data.r2dbc.core.R2dbcEntityOperations#count(org.springframework.data.r2dbc.query.Query, java.lang.Class) + */ + @Override + public Mono count(Query query, Class entityClass) throws DataAccessException { + + Assert.notNull(query, "Query must not be null"); + Assert.notNull(entityClass, "entity class must not be null"); + + return doCount(query, entityClass, getTableName(entityClass)); + } + + Mono doCount(Query query, Class entityClass, SqlIdentifier tableName) { + + RelationalPersistentEntity entity = getRequiredEntity(entityClass); + StatementMapper statementMapper = dataAccessStrategy.getStatementMapper().forType(entityClass); + + StatementMapper.SelectSpec selectSpec = statementMapper // + .createSelect(tableName) // + .doWithTable((table, spec) -> { + return spec.withProjection(Functions.count(table.column(entity.getRequiredIdProperty().getColumnName()))); + }); + + Optional criteria = query.getCriteria(); + if (criteria.isPresent()) { + selectSpec = criteria.map(selectSpec::withCriteria).orElse(selectSpec); + } + + PreparedOperation operation = statementMapper.getMappedObject(selectSpec); + + return this.databaseClient.execute(operation) // + .map((r, md) -> r.get(0, Long.class)) // + .first() // + .defaultIfEmpty(0L); + } + + /* + * (non-Javadoc) + * @see org.springframework.data.r2dbc.core.R2dbcEntityOperations#exists(org.springframework.data.r2dbc.query.Query, java.lang.Class) + */ + @Override + public Mono exists(Query query, Class entityClass) throws DataAccessException { + + Assert.notNull(query, "Query must not be null"); + Assert.notNull(entityClass, "entity class must not be null"); + + return doExists(query, entityClass, getTableName(entityClass)); + } + + Mono doExists(Query query, Class entityClass, SqlIdentifier tableName) { + + RelationalPersistentEntity entity = getRequiredEntity(entityClass); + StatementMapper statementMapper = dataAccessStrategy.getStatementMapper().forType(entityClass); + + SqlIdentifier columnName = entity.hasIdProperty() ? entity.getRequiredIdProperty().getColumnName() + : SqlIdentifier.unquoted("*"); + + StatementMapper.SelectSpec selectSpec = statementMapper // + .createSelect(tableName) // + .withProjection(columnName); + + Optional criteria = query.getCriteria(); + if (criteria.isPresent()) { + selectSpec = criteria.map(selectSpec::withCriteria).orElse(selectSpec); + } + + PreparedOperation operation = statementMapper.getMappedObject(selectSpec); + + return this.databaseClient.execute(operation) // + .map((r, md) -> r) // + .first() // + .hasElement(); + } + + /* + * (non-Javadoc) + * @see org.springframework.data.r2dbc.core.R2dbcEntityOperations#select(org.springframework.data.r2dbc.query.Query, java.lang.Class) + */ + @Override + public Flux select(Query query, Class entityClass) throws DataAccessException { + + Assert.notNull(query, "Query must not be null"); + Assert.notNull(entityClass, "entity class must not be null"); + + return doSelect(query, entityClass, getTableName(entityClass), entityClass).all(); + } + + RowsFetchSpec doSelect(Query query, Class entityClass, SqlIdentifier tableName, Class returnType) { + + StatementMapper statementMapper = dataAccessStrategy.getStatementMapper().forType(entityClass); + + StatementMapper.SelectSpec selectSpec = statementMapper // + .createSelect(tableName) // + .doWithTable((table, spec) -> spec.withProjection(getSelectProjection(table, query, returnType))); + + if (query.getLimit() > 0) { + selectSpec = selectSpec.limit(query.getLimit()); + } + + if (query.getOffset() > 0) { + selectSpec = selectSpec.offset(query.getOffset()); + } + + if (query.isSorted()) { + selectSpec = selectSpec.withSort(query.getSort()); + } + + Optional criteria = query.getCriteria(); + if (criteria.isPresent()) { + selectSpec = criteria.map(selectSpec::withCriteria).orElse(selectSpec); + } + + PreparedOperation operation = statementMapper.getMappedObject(selectSpec); + + BiFunction rowMapper; + if (returnType.isInterface()) { + rowMapper = dataAccessStrategy.getRowMapper(entityClass) + .andThen(o -> projectionFactory.createProjection(returnType, o)); + } else { + rowMapper = dataAccessStrategy.getRowMapper(returnType); + } + + return this.databaseClient.execute(operation).map(rowMapper); + } + + /* + * (non-Javadoc) + * @see org.springframework.data.r2dbc.core.R2dbcEntityOperations#selectOne(org.springframework.data.r2dbc.query.Query, java.lang.Class) + */ + @Override + public Mono selectOne(Query query, Class entityClass) throws DataAccessException { + return doSelect(query.limit(2), entityClass, getTableName(entityClass), entityClass).one(); + } + + /* + * (non-Javadoc) + * @see org.springframework.data.r2dbc.core.R2dbcEntityOperations#update(org.springframework.data.r2dbc.query.Query, org.springframework.data.r2dbc.query.Update, java.lang.Class) + */ + @Override + public Mono update(Query query, Update update, Class entityClass) throws DataAccessException { + + Assert.notNull(query, "Query must not be null"); + Assert.notNull(update, "Update must not be null"); + Assert.notNull(entityClass, "entity class must not be null"); + + return doUpdate(query, update, entityClass, getTableName(entityClass)); + } + + Mono doUpdate(Query query, Update update, Class entityClass, SqlIdentifier tableName) { + + StatementMapper statementMapper = dataAccessStrategy.getStatementMapper().forType(entityClass); + + StatementMapper.UpdateSpec selectSpec = statementMapper // + .createUpdate(tableName, update); + + Optional criteria = query.getCriteria(); + if (criteria.isPresent()) { + selectSpec = criteria.map(selectSpec::withCriteria).orElse(selectSpec); + } + + PreparedOperation operation = statementMapper.getMappedObject(selectSpec); + return this.databaseClient.execute(operation).fetch().rowsUpdated(); + } + + /* + * (non-Javadoc) + * @see org.springframework.data.r2dbc.core.R2dbcEntityOperations#delete(org.springframework.data.r2dbc.query.Query, java.lang.Class) + */ + @Override + public Mono delete(Query query, Class entityClass) throws DataAccessException { + + Assert.notNull(query, "Query must not be null"); + Assert.notNull(entityClass, "entity class must not be null"); + + return doDelete(query, entityClass, getTableName(entityClass)); + } + + Mono doDelete(Query query, Class entityClass, SqlIdentifier tableName) { + + StatementMapper statementMapper = dataAccessStrategy.getStatementMapper().forType(entityClass); + + StatementMapper.DeleteSpec selectSpec = statementMapper // + .createDelete(tableName); + + Optional criteria = query.getCriteria(); + if (criteria.isPresent()) { + selectSpec = criteria.map(selectSpec::withCriteria).orElse(selectSpec); + } + + PreparedOperation operation = statementMapper.getMappedObject(selectSpec); + return this.databaseClient.execute(operation).fetch().rowsUpdated().defaultIfEmpty(0); + } + + // ------------------------------------------------------------------------- + // Methods dealing with entities + // ------------------------------------------------------------------------- + + /* + * (non-Javadoc) + * @see org.springframework.data.r2dbc.core.R2dbcEntityOperations#insert(java.lang.Object) + */ + @Override + public Mono insert(T entity) throws DataAccessException { + + Assert.notNull(entity, "Entity must not be null"); + + return doInsert(entity, getRequiredEntity(entity).getTableName()); + } + + Mono doInsert(T entity, SqlIdentifier tableName) { + + RelationalPersistentEntity persistentEntity = getRequiredEntity(entity); + + return this.databaseClient.insert() // + .into(persistentEntity.getType()) // + .table(tableName).using(entity) // + .map(this.dataAccessStrategy.getConverter().populateIdIfNecessary(entity)) // + .first() // + .defaultIfEmpty(entity); + } + + /* + * (non-Javadoc) + * @see org.springframework.data.r2dbc.core.R2dbcEntityOperations#update(java.lang.Object) + */ + @Override + public Mono update(T entity) throws DataAccessException { + + Assert.notNull(entity, "Entity must not be null"); + + RelationalPersistentEntity persistentEntity = getRequiredEntity(entity); + + return this.databaseClient.update() // + .table(persistentEntity.getType()) // + .table(persistentEntity.getTableName()).using(entity) // + .fetch().rowsUpdated().handle((rowsUpdated, sink) -> { + + if (rowsUpdated == 0) { + sink.error(new TransientDataAccessResourceException( + String.format("Failed to update table [%s]. Row with Id [%s] does not exist.", + persistentEntity.getTableName(), persistentEntity.getIdentifierAccessor(entity).getIdentifier()))); + } else { + sink.next(entity); + } + }); + } + + /* + * (non-Javadoc) + * @see org.springframework.data.r2dbc.core.R2dbcEntityOperations#delete(java.lang.Object) + */ + @Override + public Mono delete(T entity) throws DataAccessException { + + Assert.notNull(entity, "Entity must not be null"); + + RelationalPersistentEntity persistentEntity = getRequiredEntity(entity); + + return delete(getByIdQuery(entity, persistentEntity), persistentEntity.getType()).thenReturn(entity); + } + + private Query getByIdQuery(T entity, RelationalPersistentEntity persistentEntity) { + if (!persistentEntity.hasIdProperty()) { + throw new MappingException("No id property found for object of type " + persistentEntity.getType() + "!"); + } + + IdentifierAccessor identifierAccessor = persistentEntity.getIdentifierAccessor(entity); + Object id = identifierAccessor.getRequiredIdentifier(); + + return Query.query(Criteria.where(persistentEntity.getRequiredIdProperty().getName()).is(id)); + } + + SqlIdentifier getTableName(Class entityClass) { + return getRequiredEntity(entityClass).getTableName(); + } + + private RelationalPersistentEntity getRequiredEntity(Class entityClass) { + return this.mappingContext.getRequiredPersistentEntity(entityClass); + } + + private RelationalPersistentEntity getRequiredEntity(T entity) { + Class entityType = ProxyUtils.getUserClass(entity); + return (RelationalPersistentEntity) getRequiredEntity(entityType); + } + + private List getSelectProjection(Table table, Query query, Class returnType) { + + if (query.getColumns().isEmpty()) { + + if (returnType.isInterface()) { + + ProjectionInformation projectionInformation = projectionFactory.getProjectionInformation(returnType); + + if (projectionInformation.isClosed()) { + return projectionInformation.getInputProperties().stream().map(FeatureDescriptor::getName).map(table::column) + .collect(Collectors.toList()); + } + } + + return Collections.singletonList(table.asterisk()); + } + + return query.getColumns().stream().map(table::column).collect(Collectors.toList()); + } + + private static ReactiveDataAccessStrategy getDataAccessStrategy(DatabaseClient databaseClient) { + + if (databaseClient instanceof DefaultDatabaseClient) { + + DefaultDatabaseClient client = (DefaultDatabaseClient) databaseClient; + return client.getDataAccessStrategy(); + } + + throw new IllegalStateException("Cannot obtain ReactiveDataAccessStrategy"); + } + + private static MappingContext, ? extends RelationalPersistentProperty> getMappingContext( + ReactiveDataAccessStrategy strategy) { + + if (strategy instanceof DefaultReactiveDataAccessStrategy) { + + DefaultReactiveDataAccessStrategy strategy1 = (DefaultReactiveDataAccessStrategy) strategy; + return strategy1.getMappingContext(); + } + return new R2dbcMappingContext(); + } + +} diff --git a/src/main/java/org/springframework/data/r2dbc/core/ReactiveDeleteOperation.java b/src/main/java/org/springframework/data/r2dbc/core/ReactiveDeleteOperation.java new file mode 100644 index 00000000..dbc046dd --- /dev/null +++ b/src/main/java/org/springframework/data/r2dbc/core/ReactiveDeleteOperation.java @@ -0,0 +1,124 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.r2dbc.core; + +import reactor.core.publisher.Mono; + +import org.springframework.data.r2dbc.query.Query; +import org.springframework.data.relational.core.sql.SqlIdentifier; + +/** + * The {@link ReactiveDeleteOperation} interface allows creation and execution of {@code DELETE} operations in a fluent + * API style. + *

+ * The starting {@literal domainType} is used for mapping the {@link Query} provided via {@code matching}. By default, + * the table to operate on is derived from the initial {@literal domainType} and can be defined there via + * {@link org.springframework.data.relational.core.mapping.Table} annotation. Using {@code inTable} allows to override + * the table name for the execution. + * + *

+ *     
+ *         delete(Jedi.class)
+ *             .from("star_wars")
+ *             .matching(query(where("firstname").is("luke")))
+ *             .all();
+ *     
+ * 
+ * + * @author Mark Paluch + * @since 1.1 + */ +public interface ReactiveDeleteOperation { + + /** + * Begin creating a {@code DELETE} operation for the given {@link Class domainType}. + * + * @param domainType {@link Class type} of domain object to delete; must not be {@literal null}. + * @return new instance of {@link ReactiveDelete}. + * @throws IllegalArgumentException if {@link Class domainType} is {@literal null}. + * @see ReactiveDelete + */ + ReactiveDelete delete(Class domainType); + + /** + * Table override (optional). + */ + interface DeleteWithTable { + + /** + * Explicitly set the {@link String name} of the table on which to perform the delete. + *

+ * Skip this step to use the default table derived from the {@link Class domain type}. + * + * @param table {@link String name} of the table; must not be {@literal null} or empty. + * @return new instance of {@link DeleteWithQuery}. + * @throws IllegalArgumentException if {@link String table} is {@literal null} or empty. + * @see DeleteWithQuery + */ + default DeleteWithQuery from(String table) { + return from(SqlIdentifier.unquoted(table)); + } + + /** + * Explicitly set the {@link SqlIdentifier name} of the table on which to perform the delete. + *

+ * Skip this step to use the default table derived from the {@link Class domain type}. + * + * @param table {@link SqlIdentifier name} of the table; must not be {@literal null}. + * @return new instance of {@link DeleteWithQuery}. + * @throws IllegalArgumentException if {@link SqlIdentifier table} is {@literal null}. + * @see DeleteWithQuery + */ + DeleteWithQuery from(SqlIdentifier table); + } + + /** + * Required {@link Query filter}. + */ + interface DeleteWithQuery { + + /** + * Define the {@link Query} used to filter elements in the delete. + * + * @param query {@link Query} used as the filter in the delete; must not be {@literal null}. + * @return new instance of {@link TerminatingDelete}. + * @throws IllegalArgumentException if {@link Query} is {@literal null}. + * @see TerminatingDelete + * @see Query + */ + TerminatingDelete matching(Query query); + } + + /** + * Trigger {@code DELETE} operation by calling one of the terminating methods. + */ + interface TerminatingDelete { + + /** + * Remove all matching rows. + * + * @return the number of affected rows; never {@literal null}. + * @see Mono + */ + Mono all(); + } + + /** + * The {@link ReactiveDelete} interface provides methods for constructing {@code DELETE} operations in a fluent way. + */ + interface ReactiveDelete extends DeleteWithTable, DeleteWithQuery {} + +} diff --git a/src/main/java/org/springframework/data/r2dbc/core/ReactiveDeleteOperationSupport.java b/src/main/java/org/springframework/data/r2dbc/core/ReactiveDeleteOperationSupport.java new file mode 100644 index 00000000..c4370b40 --- /dev/null +++ b/src/main/java/org/springframework/data/r2dbc/core/ReactiveDeleteOperationSupport.java @@ -0,0 +1,103 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.r2dbc.core; + +import reactor.core.publisher.Mono; + +import org.springframework.data.r2dbc.query.Query; +import org.springframework.data.relational.core.sql.SqlIdentifier; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * Implementation of {@link ReactiveDeleteOperation}. + * + * @author Mark Paluch + * @since 1.1 + */ +class ReactiveDeleteOperationSupport implements ReactiveDeleteOperation { + + private final R2dbcEntityTemplate template; + + ReactiveDeleteOperationSupport(R2dbcEntityTemplate template) { + this.template = template; + } + + /* + * (non-Javadoc) + * @see org.springframework.data.r2dbc.core.ReactiveDeleteOperation#delete(java.lang.Class) + */ + @Override + public ReactiveDelete delete(Class domainType) { + + Assert.notNull(domainType, "DomainType must not be null"); + + return new ReactiveDeleteSupport(this.template, domainType, Query.empty(), null); + } + + static class ReactiveDeleteSupport implements ReactiveDelete, TerminatingDelete { + + private final R2dbcEntityTemplate template; + private final Class domainType; + private final Query query; + private final @Nullable SqlIdentifier tableName; + + ReactiveDeleteSupport(R2dbcEntityTemplate template, Class domainType, Query query, + @Nullable SqlIdentifier tableName) { + + this.template = template; + this.domainType = domainType; + this.query = query; + this.tableName = tableName; + } + + /* + * (non-Javadoc) + * @see org.springframework.data.r2dbc.core.ReactiveDeleteOperation.DeleteWithTable#from(SqlIdentifier) + */ + @Override + public DeleteWithQuery from(SqlIdentifier tableName) { + + Assert.notNull(tableName, "Table name must not be null"); + + return new ReactiveDeleteSupport(this.template, this.domainType, this.query, tableName); + } + + /* + * (non-Javadoc) + * @see org.springframework.data.r2dbc.core.ReactiveDeleteOperation.DeleteWithQuery#matching(org.springframework.data.r2dbc.query.Query) + */ + @Override + public TerminatingDelete matching(Query query) { + + Assert.notNull(query, "Query must not be null"); + + return new ReactiveDeleteSupport(this.template, this.domainType, query, this.tableName); + } + + /* + * (non-Javadoc) + * @see org.springframework.data.r2dbc.core.ReactiveDeleteOperation.TerminatingDelete#all() + */ + public Mono all() { + return this.template.doDelete(this.query, this.domainType, getTableName()); + } + + private SqlIdentifier getTableName() { + return this.tableName != null ? this.tableName : this.template.getTableName(this.domainType); + } + } +} diff --git a/src/main/java/org/springframework/data/r2dbc/core/ReactiveInsertOperation.java b/src/main/java/org/springframework/data/r2dbc/core/ReactiveInsertOperation.java new file mode 100644 index 00000000..89b44f85 --- /dev/null +++ b/src/main/java/org/springframework/data/r2dbc/core/ReactiveInsertOperation.java @@ -0,0 +1,105 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.r2dbc.core; + +import reactor.core.publisher.Mono; + +import org.springframework.data.relational.core.sql.SqlIdentifier; + +/** + * The {@link ReactiveInsertOperation} interface allows creation and execution of {@code INSERT} operations in a fluent + * API style. + *

+ * By default,the table to operate on is derived from the initial {@link Class domainType} and can be defined there via + * {@link org.springframework.data.relational.core.mapping.Table} annotation. Using {@code inTable} allows to override + * the table name for the execution. + * + *

+ *     
+ *         insert(Jedi.class)
+ *             .into("star_wars")
+ *             .using(luke);
+ *     
+ * 
+ * + * @author Mark Paluch + * @since 1.1 + */ +public interface ReactiveInsertOperation { + + /** + * Begin creating an {@code INSERT} operation for given {@link Class domainType}. + * + * @param {@link Class type} of the application domain object. + * @param domainType {@link Class type} of the domain object to insert; must not be {@literal null}. + * @return new instance of {@link ReactiveInsert}. + * @throws IllegalArgumentException if {@link Class domainType} is {@literal null}. + * @see ReactiveInsert + */ + ReactiveInsert insert(Class domainType); + + /** + * Table override (optional). + */ + interface InsertWithTable extends TerminatingInsert { + + /** + * Explicitly set the {@link String name} of the table. + *

+ * Skip this step to use the default table derived from the {@link Class domain type}. + * + * @param table {@link String name} of the table; must not be {@literal null} or empty. + * @return new instance of {@link TerminatingInsert}. + * @throws IllegalArgumentException if {@link String table} is {@literal null} or empty. + */ + default TerminatingInsert into(String table) { + return into(SqlIdentifier.unquoted(table)); + } + + /** + * Explicitly set the {@link SqlIdentifier name} of the table. + *

+ * Skip this step to use the default table derived from the {@link Class domain type}. + * + * @param table {@link SqlIdentifier name} of the table; must not be {@literal null}. + * @return new instance of {@link TerminatingInsert}. + * @throws IllegalArgumentException if {@link SqlIdentifier table} is {@literal null}. + */ + TerminatingInsert into(SqlIdentifier table); + } + + /** + * Trigger {@code INSERT} execution by calling one of the terminating methods. + */ + interface TerminatingInsert { + + /** + * Insert exactly one {@link Object}. + * + * @param object {@link Object} to insert; must not be {@literal null}. + * @return the write result for this operation. + * @throws IllegalArgumentException if {@link Object} is {@literal null}. + * @see Mono + */ + Mono using(T object); + } + + /** + * The {@link ReactiveInsert} interface provides methods for constructing {@code INSERT} operations in a fluent way. + */ + interface ReactiveInsert extends InsertWithTable {} + +} diff --git a/src/main/java/org/springframework/data/r2dbc/core/ReactiveInsertOperationSupport.java b/src/main/java/org/springframework/data/r2dbc/core/ReactiveInsertOperationSupport.java new file mode 100644 index 00000000..e3082305 --- /dev/null +++ b/src/main/java/org/springframework/data/r2dbc/core/ReactiveInsertOperationSupport.java @@ -0,0 +1,91 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.r2dbc.core; + +import reactor.core.publisher.Mono; + +import org.springframework.data.relational.core.sql.SqlIdentifier; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * Implementation of {@link ReactiveInsertOperation}. + * + * @author Mark Paluch + * @since 1.1 + */ +class ReactiveInsertOperationSupport implements ReactiveInsertOperation { + + private final R2dbcEntityTemplate template; + + ReactiveInsertOperationSupport(R2dbcEntityTemplate template) { + this.template = template; + } + + /* + * (non-Javadoc) + * @see org.springframework.data.r2dbc.core.ReactiveInsertOperation#insert(java.lang.Class) + */ + @Override + public ReactiveInsert insert(Class domainType) { + + Assert.notNull(domainType, "DomainType must not be null"); + + return new ReactiveInsertSupport<>(this.template, domainType, null); + } + + static class ReactiveInsertSupport implements ReactiveInsert { + + private final R2dbcEntityTemplate template; + private final Class domainType; + private final @Nullable SqlIdentifier tableName; + + ReactiveInsertSupport(R2dbcEntityTemplate template, Class domainType, @Nullable SqlIdentifier tableName) { + + this.template = template; + this.domainType = domainType; + this.tableName = tableName; + } + + /* + * (non-Javadoc) + * @see org.springframework.data.r2dbc.core.ReactiveInsertOperation.InsertWithTable#into(SqlIdentifier) + */ + @Override + public TerminatingInsert into(SqlIdentifier tableName) { + + Assert.notNull(tableName, "Table name must not be null"); + + return new ReactiveInsertSupport<>(this.template, this.domainType, tableName); + } + + /* + * (non-Javadoc) + * @see org.springframework.data.r2dbc.core.ReactiveInsertOperation.TerminatingInsert#one(java.lang.Object) + */ + @Override + public Mono using(T object) { + + Assert.notNull(object, "Object to insert must not be null"); + + return this.template.doInsert(object, getTableName()); + } + + private SqlIdentifier getTableName() { + return this.tableName != null ? this.tableName : this.template.getTableName(this.domainType); + } + } +} diff --git a/src/main/java/org/springframework/data/r2dbc/core/ReactiveSelectOperation.java b/src/main/java/org/springframework/data/r2dbc/core/ReactiveSelectOperation.java new file mode 100644 index 00000000..c67de55c --- /dev/null +++ b/src/main/java/org/springframework/data/r2dbc/core/ReactiveSelectOperation.java @@ -0,0 +1,182 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.r2dbc.core; + +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.data.r2dbc.query.Query; +import org.springframework.data.relational.core.sql.SqlIdentifier; + +/** + * The {@link ReactiveSelectOperation} interface allows creation and execution of {@code SELECT} operations in a fluent + * API style. + *

+ * The starting {@literal domainType} is used for mapping the {@link Query} provided via {@code matching}. By default, + * the originating {@literal domainType} is also used for mapping back the result from the {@link io.r2dbc.spi.Row}. + * However, it is possible to define an different {@literal returnType} via {@code as} to mapping the result. + *

+ * By default, the table to operate on is derived from the initial {@literal domainType} and can be defined there via + * the {@link org.springframework.data.relational.core.mapping.Table} annotation. Using {@code inTable} allows to + * override the table name for the execution. + * + *

+ *     
+ *         select(Human.class)
+ *             .from("star_wars")
+ *             .as(Jedi.class)
+ *             .matching(query(where("firstname").is("luke")))
+ *             .all();
+ *     
+ * 
+ * + * @author Mark Paluch + * @since 1.1 + */ +public interface ReactiveSelectOperation { + + /** + * Begin creating a {@code SELECT} operation for the given {@link Class domainType}. + * + * @param {@link Class type} of the application domain object. + * @param domainType {@link Class type} of the domain object to query; must not be {@literal null}. + * @return new instance of {@link ReactiveSelect}. + * @throws IllegalArgumentException if {@link Class domainType} is {@literal null}. + * @see ReactiveSelect + */ + ReactiveSelect select(Class domainType); + + /** + * Table override (optional). + */ + interface SelectWithTable extends SelectWithQuery { + + /** + * Explicitly set the {@link String name} of the table on which to perform the query. + *

+ * Skip this step to use the default table derived from the {@link Class domain type}. + * + * @param table {@link String name} of the table; must not be {@literal null} or empty. + * @return new instance of {@link SelectWithProjection}. + * @throws IllegalArgumentException if {@link String table} is {@literal null} or empty. + * @see SelectWithProjection + */ + default SelectWithProjection from(String table) { + return from(SqlIdentifier.unquoted(table)); + } + + /** + * Explicitly set the {@link SqlIdentifier name} of the table on which to perform the query. + *

+ * Skip this step to use the default table derived from the {@link Class domain type}. + * + * @param table {@link SqlIdentifier name} of the table; must not be {@literal null}. + * @return new instance of {@link SelectWithProjection}. + * @throws IllegalArgumentException if {@link SqlIdentifier table} is {@literal null}. + * @see SelectWithProjection + */ + SelectWithProjection from(SqlIdentifier table); + } + + /** + * Result type override (optional). + */ + interface SelectWithProjection extends SelectWithQuery { + + /** + * Define the {@link Class result target type} that the fields should be mapped to. + *

+ * Skip this step if you are only interested in the original {@link Class domain type}. + * + * @param {@link Class type} of the result. + * @param resultType desired {@link Class type} of the result; must not be {@literal null}. + * @return new instance of {@link SelectWithQuery}. + * @throws IllegalArgumentException if {@link Class resultType} is {@literal null}. + * @see SelectWithQuery + */ + SelectWithQuery as(Class resultType); + } + + /** + * Define a {@link Query} used as the filter for the {@code SELECT}. + */ + interface SelectWithQuery extends TerminatingSelect { + + /** + * Set the {@link Query} used as a filter in the {@code SELECT} statement. + * + * @param query {@link Query} used as a filter; must not be {@literal null}. + * @return new instance of {@link TerminatingSelect}. + * @throws IllegalArgumentException if {@link Query} is {@literal null}. + * @see Query + * @see TerminatingSelect + */ + TerminatingSelect matching(Query query); + } + + /** + * Trigger {@code SELECT} execution by calling one of the terminating methods. + */ + interface TerminatingSelect { + + /** + * Get the number of matching elements. + * + * @return a {@link Mono} emitting the total number of matching elements; never {@literal null}. + * @see Mono + */ + Mono count(); + + /** + * Check for the presence of matching elements. + * + * @return a {@link Mono} emitting {@literal true} if at least one matching element exists; never {@literal null}. + * @see Mono + */ + Mono exists(); + + /** + * Get the first result or no result. + * + * @return the first result or {@link Mono#empty()} if no match found; never {@literal null}. + * @see Mono + */ + Mono first(); + + /** + * Get exactly zero or one result. + * + * @return exactly one result or {@link Mono#empty()} if no match found; never {@literal null}. + * @throws org.springframework.dao.IncorrectResultSizeDataAccessException if more than one match found. + * @see Mono + */ + Mono one(); + + /** + * Get all matching elements. + * + * @return all matching elements; never {@literal null}. + * @see Flux + */ + Flux all(); + } + + /** + * The {@link ReactiveSelect} interface provides methods for constructing {@code SELECT} operations in a fluent way. + */ + interface ReactiveSelect extends SelectWithTable, SelectWithProjection {} + +} diff --git a/src/main/java/org/springframework/data/r2dbc/core/ReactiveSelectOperationSupport.java b/src/main/java/org/springframework/data/r2dbc/core/ReactiveSelectOperationSupport.java new file mode 100644 index 00000000..b4086f3e --- /dev/null +++ b/src/main/java/org/springframework/data/r2dbc/core/ReactiveSelectOperationSupport.java @@ -0,0 +1,155 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.r2dbc.core; + +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.data.r2dbc.query.Query; +import org.springframework.data.relational.core.sql.SqlIdentifier; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * Implementation of {@link ReactiveSelectOperation}. + * + * @author Mark Paluch + * @since 1.1 + */ +class ReactiveSelectOperationSupport implements ReactiveSelectOperation { + + private final R2dbcEntityTemplate template; + + ReactiveSelectOperationSupport(R2dbcEntityTemplate template) { + this.template = template; + } + + /* + * (non-Javadoc) + * @see org.springframework.data.r2dbc.core.ReactiveSelectOperation#select(java.lang.Class) + */ + @Override + public ReactiveSelect select(Class domainType) { + + Assert.notNull(domainType, "DomainType must not be null"); + + return new ReactiveSelectSupport<>(this.template, domainType, domainType, Query.empty(), null); + } + + static class ReactiveSelectSupport implements ReactiveSelect { + + private final R2dbcEntityTemplate template; + private final Class domainType; + private final Class returnType; + private final Query query; + private final @Nullable SqlIdentifier tableName; + + ReactiveSelectSupport(R2dbcEntityTemplate template, Class domainType, Class returnType, Query query, + @Nullable SqlIdentifier tableName) { + + this.template = template; + this.domainType = domainType; + this.returnType = returnType; + this.query = query; + this.tableName = tableName; + } + + /* + * (non-Javadoc) + * @see org.springframework.data.r2dbc.core.ReactiveSelectOperation.SelectWithTable#from(java.lang.String) + */ + @Override + public SelectWithProjection from(SqlIdentifier tableName) { + + Assert.notNull(tableName, "Table name must not be null"); + + return new ReactiveSelectSupport<>(this.template, this.domainType, this.returnType, this.query, tableName); + } + + /* + * (non-Javadoc) + * @see org.springframework.data.r2dbc.core.ReactiveSelectOperation.SelectWithProjection#as(java.lang.Class) + */ + @Override + public SelectWithQuery as(Class returnType) { + + Assert.notNull(returnType, "ReturnType must not be null"); + + return new ReactiveSelectSupport<>(this.template, this.domainType, returnType, this.query, this.tableName); + } + + /* + * (non-Javadoc) + * @see org.springframework.data.r2dbc.core.ReactiveSelectOperation.SelectWithQuery#matching(org.springframework.data.r2dbc.query.Query) + */ + @Override + public TerminatingSelect matching(Query query) { + + Assert.notNull(query, "Query must not be null"); + + return new ReactiveSelectSupport<>(this.template, this.domainType, this.returnType, query, this.tableName); + } + + /* + * (non-Javadoc) + * @see org.springframework.data.r2dbc.core.ReactiveSelectOperation.TerminatingSelect#count() + */ + @Override + public Mono count() { + return this.template.doCount(this.query, this.domainType, getTableName()); + } + + /* + * (non-Javadoc) + * @see org.springframework.data.r2dbc.core.ReactiveSelectOperation.TerminatingSelect#exists() + */ + @Override + public Mono exists() { + return this.template.doExists(this.query, this.domainType, getTableName()); + } + + /* + * (non-Javadoc) + * @see org.springframework.data.r2dbc.core.ReactiveSelectOperation.TerminatingSelect#first() + */ + @Override + public Mono first() { + return this.template.doSelect(this.query.limit(1), this.domainType, getTableName(), this.returnType).first(); + } + + /* + * (non-Javadoc) + * @see org.springframework.data.r2dbc.core.ReactiveSelectOperation.TerminatingSelect#one() + */ + @Override + public Mono one() { + return this.template.doSelect(this.query.limit(2), this.domainType, getTableName(), this.returnType).one(); + } + + /* + * (non-Javadoc) + * @see org.springframework.data.r2dbc.core.ReactiveSelectOperation.TerminatingSelect#all() + */ + @Override + public Flux all() { + return this.template.doSelect(this.query, this.domainType, getTableName(), this.returnType).all(); + } + + private SqlIdentifier getTableName() { + return this.tableName != null ? this.tableName : this.template.getTableName(this.domainType); + } + } +} diff --git a/src/main/java/org/springframework/data/r2dbc/core/ReactiveUpdateOperation.java b/src/main/java/org/springframework/data/r2dbc/core/ReactiveUpdateOperation.java new file mode 100644 index 00000000..c779f628 --- /dev/null +++ b/src/main/java/org/springframework/data/r2dbc/core/ReactiveUpdateOperation.java @@ -0,0 +1,128 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.r2dbc.core; + +import reactor.core.publisher.Mono; + +import org.springframework.data.r2dbc.query.Query; +import org.springframework.data.r2dbc.query.Update; +import org.springframework.data.relational.core.sql.SqlIdentifier; + +/** + * The {@link ReactiveUpdateOperation} interface allows creation and execution of {@code UPDATE} operations in a fluent + * API style. + *

+ * The starting {@literal domainType} is used for mapping the {@link Query} provided via {@code matching}, as well as + * the {@link Update} via {@code apply}. + *

+ * By default, the table to operate on is derived from the initial {@literal domainType} and can be defined there via + * the {@link org.springframework.data.relational.core.mapping.Table} annotation. Using {@code inTable} allows a + * developer to override the table name for the execution. + * + *

+ *     
+ *         update(Jedi.class)
+ *             .table("star_wars")
+ *             .matching(query(where("firstname").is("luke")))
+ *             .apply(update("lastname", "skywalker"))
+ *             .all();
+ *     
+ * 
+ * + * @author Mark Paluch + * @since 1.1 + */ +public interface ReactiveUpdateOperation { + + /** + * Begin creating an {@code UPDATE} operation for the given {@link Class domainType}. + * + * @param domainType {@link Class type} of domain object to update; must not be {@literal null}. + * @return new instance of {@link ReactiveUpdate}. + * @throws IllegalArgumentException if {@link Class domainType} is {@literal null}. + * @see ReactiveUpdate + */ + ReactiveUpdate update(Class domainType); + + /** + * Table override (optional). + */ + interface UpdateWithTable { + + /** + * Explicitly set the {@link String name} of the table on which to perform the update. + *

+ * Skip this step to use the default table derived from the {@link Class domain type}. + * + * @param table {@link String name} of the table; must not be {@literal null} or empty. + * @return new instance of {@link UpdateWithQuery}. + * @throws IllegalArgumentException if {@link String table} is {@literal null} or empty. + * @see UpdateWithQuery + */ + default UpdateWithQuery inTable(String table) { + return inTable(SqlIdentifier.unquoted(table)); + } + + /** + * Explicitly set the {@link SqlIdentifier name} of the table on which to perform the update. + *

+ * Skip this step to use the default table derived from the {@link Class domain type}. + * + * @param table {@link SqlIdentifier name} of the table; must not be {@literal null}. + * @return new instance of {@link UpdateWithQuery}. + * @throws IllegalArgumentException if {@link SqlIdentifier table} is {@literal null}. + * @see UpdateWithQuery + */ + UpdateWithQuery inTable(SqlIdentifier table); + } + + /** + * Define a {@link Query} used as the filter for the {@link Update}. + */ + interface UpdateWithQuery { + + /** + * Filter rows to update by the given {@link Query}. + * + * @param query {@link Query} used as a filter in the update; must not be {@literal null}. + * @return new instance of {@link TerminatingUpdate}. + * @throws IllegalArgumentException if {@link Query} is {@literal null}. + * @see Query + * @see TerminatingUpdate + */ + TerminatingUpdate matching(Query query); + } + + /** + * Trigger {@code UPDATE} execution by calling one of the terminating methods. + */ + interface TerminatingUpdate { + + /** + * Update all matching rows in the table. + * + * @return the number of affected rows by the update; never {@literal null}. + * @see Mono + */ + Mono apply(Update update); + } + + /** + * The {@link ReactiveUpdate} interface provides methods for constructing {@code UPDATE} operations in a fluent way. + */ + interface ReactiveUpdate extends UpdateWithTable, UpdateWithQuery {} + +} diff --git a/src/main/java/org/springframework/data/r2dbc/core/ReactiveUpdateOperationSupport.java b/src/main/java/org/springframework/data/r2dbc/core/ReactiveUpdateOperationSupport.java new file mode 100644 index 00000000..712ce558 --- /dev/null +++ b/src/main/java/org/springframework/data/r2dbc/core/ReactiveUpdateOperationSupport.java @@ -0,0 +1,108 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.r2dbc.core; + +import reactor.core.publisher.Mono; + +import org.springframework.data.r2dbc.query.Query; +import org.springframework.data.r2dbc.query.Update; +import org.springframework.data.relational.core.sql.SqlIdentifier; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * Implementation of {@link ReactiveUpdateOperation}. + * + * @author Mark Paluch + * @since 1.1 + */ +class ReactiveUpdateOperationSupport implements ReactiveUpdateOperation { + + private final R2dbcEntityTemplate template; + + ReactiveUpdateOperationSupport(R2dbcEntityTemplate template) { + this.template = template; + } + + /* + * (non-Javadoc) + * @see org.springframework.data.r2dbc.core.ReactiveUpdateOperation#update(java.lang.Class) + */ + @Override + public ReactiveUpdate update(Class domainType) { + + Assert.notNull(domainType, "DomainType must not be null"); + + return new ReactiveUpdateSupport(this.template, domainType, Query.empty(), null); + } + + static class ReactiveUpdateSupport implements ReactiveUpdate, TerminatingUpdate { + + private final R2dbcEntityTemplate template; + private final Class domainType; + private final Query query; + private final @Nullable SqlIdentifier tableName; + + ReactiveUpdateSupport(R2dbcEntityTemplate template, Class domainType, Query query, + @Nullable SqlIdentifier tableName) { + + this.template = template; + this.domainType = domainType; + this.query = query; + this.tableName = tableName; + } + + /* + * (non-Javadoc) + * @see org.springframework.data.r2dbc.core.ReactiveUpdateOperation.UpdateWithTable#inTable(SqlIdentifier) + */ + @Override + public UpdateWithQuery inTable(SqlIdentifier tableName) { + + Assert.notNull(tableName, "Table name must not be null"); + + return new ReactiveUpdateSupport(this.template, this.domainType, this.query, tableName); + } + + /* + * (non-Javadoc) + * @see org.springframework.data.r2dbc.core.ReactiveUpdateOperation.UpdateWithQuery#matching(org.springframework.data.r2dbc.query.Query) + */ + @Override + public TerminatingUpdate matching(Query query) { + + Assert.notNull(query, "Query must not be null"); + + return new ReactiveUpdateSupport(this.template, this.domainType, query, this.tableName); + } + + /* + * (non-Javadoc) + * @see org.springframework.data.r2dbc.core.ReactiveUpdateOperation.TerminatingUpdate#apply(org.springframework.data.r2dbc.query.Update) + */ + @Override + public Mono apply(Update update) { + + Assert.notNull(update, "Update must not be null"); + + return this.template.doUpdate(this.query, update, this.domainType, getTableName()); + } + + private SqlIdentifier getTableName() { + return this.tableName != null ? this.tableName : this.template.getTableName(this.domainType); + } + } +} diff --git a/src/main/java/org/springframework/data/r2dbc/core/StatementMapper.java b/src/main/java/org/springframework/data/r2dbc/core/StatementMapper.java index b348a304..88495514 100644 --- a/src/main/java/org/springframework/data/r2dbc/core/StatementMapper.java +++ b/src/main/java/org/springframework/data/r2dbc/core/StatementMapper.java @@ -22,6 +22,7 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.function.BiFunction; import java.util.stream.Collectors; import org.springframework.data.domain.Pageable; @@ -30,7 +31,9 @@ import org.springframework.data.r2dbc.mapping.SettableValue; import org.springframework.data.r2dbc.query.Criteria; import org.springframework.data.r2dbc.query.Update; +import org.springframework.data.relational.core.sql.Expression; import org.springframework.data.relational.core.sql.SqlIdentifier; +import org.springframework.data.relational.core.sql.Table; import org.springframework.lang.Nullable; /** @@ -179,19 +182,23 @@ default DeleteSpec createDelete(SqlIdentifier table) { */ class SelectSpec { - private final SqlIdentifier table; - private final List projectedFields; + private final Table table; + private final List projectedFields; + private final List selectList; private final @Nullable Criteria criteria; private final Sort sort; - private final Pageable page; + private final long offset; + private final int limit; - protected SelectSpec(SqlIdentifier table, List projectedFields, @Nullable Criteria criteria, - Sort sort, Pageable page) { + protected SelectSpec(Table table, List projectedFields, List selectList, + @Nullable Criteria criteria, Sort sort, int limit, long offset) { this.table = table; this.projectedFields = projectedFields; + this.selectList = selectList; this.criteria = criteria; this.sort = sort; - this.page = page; + this.offset = offset; + this.limit = limit; } /** @@ -212,7 +219,12 @@ public static SelectSpec create(String table) { * @since 1.1 */ public static SelectSpec create(SqlIdentifier table) { - return new SelectSpec(table, Collections.emptyList(), null, Sort.unsorted(), Pageable.unpaged()); + return new SelectSpec(Table.create(table), Collections.emptyList(), Collections.emptyList(), null, + Sort.unsorted(), -1, -1); + } + + public SelectSpec doWithTable(BiFunction function) { + return function.apply(getTable(), this); } /** @@ -220,9 +232,10 @@ public static SelectSpec create(SqlIdentifier table) { * * @param projectedFields * @return the {@link SelectSpec}. + * @since 1.1 */ public SelectSpec withProjection(String... projectedFields) { - return withProjection(Arrays.stream(projectedFields).map(SqlIdentifier::unquoted).collect(Collectors.toList())); + return withProjection(Arrays.stream(projectedFields).map(table::column).collect(Collectors.toList())); } /** @@ -232,12 +245,39 @@ public SelectSpec withProjection(String... projectedFields) { * @return the {@link SelectSpec}. * @since 1.1 */ - public SelectSpec withProjection(Collection projectedFields) { + public SelectSpec withProjection(SqlIdentifier... projectedFields) { + return withProjection(Arrays.stream(projectedFields).map(table::column).collect(Collectors.toList())); + } + + /** + * Associate {@code expressions} with the select list and create a new {@link SelectSpec}. + * + * @param expressions + * @return the {@link SelectSpec}. + * @since 1.1 + */ + public SelectSpec withProjection(Expression... expressions) { - List fields = new ArrayList<>(this.projectedFields); - fields.addAll(projectedFields); + List selectList = new ArrayList<>(this.selectList); + selectList.addAll(Arrays.asList(expressions)); - return new SelectSpec(this.table, fields, this.criteria, this.sort, this.page); + return new SelectSpec(this.table, projectedFields, selectList, this.criteria, this.sort, this.limit, this.offset); + } + + /** + * Associate {@code projectedFields} with the select and create a new {@link SelectSpec}. + * + * @param projectedFields + * @return the {@link SelectSpec}. + * @since 1.1 + */ + public SelectSpec withProjection(Collection projectedFields) { + + List selectList = new ArrayList<>(this.selectList); + selectList.addAll(projectedFields); + + return new SelectSpec(this.table, this.projectedFields, selectList, this.criteria, this.sort, this.limit, + this.offset); } /** @@ -247,7 +287,8 @@ public SelectSpec withProjection(Collection projectedFields) { * @return the {@link SelectSpec}. */ public SelectSpec withCriteria(Criteria criteria) { - return new SelectSpec(this.table, this.projectedFields, criteria, this.sort, this.page); + return new SelectSpec(this.table, this.projectedFields, this.selectList, criteria, this.sort, this.limit, + this.offset); } /** @@ -259,10 +300,12 @@ public SelectSpec withCriteria(Criteria criteria) { public SelectSpec withSort(Sort sort) { if (sort.isSorted()) { - return new SelectSpec(this.table, this.projectedFields, this.criteria, sort, this.page); + return new SelectSpec(this.table, this.projectedFields, this.selectList, this.criteria, sort, this.limit, + this.offset); } - return new SelectSpec(this.table, this.projectedFields, this.criteria, this.sort, this.page); + return new SelectSpec(this.table, this.projectedFields, this.selectList, this.criteria, this.sort, this.limit, + this.offset); } /** @@ -277,21 +320,53 @@ public SelectSpec withPage(Pageable page) { Sort sort = page.getSort(); - return new SelectSpec(this.table, this.projectedFields, this.criteria, sort.isSorted() ? sort : this.sort, - page); + return new SelectSpec(this.table, this.projectedFields, this.selectList, this.criteria, + sort.isSorted() ? sort : this.sort, page.getPageSize(), page.getOffset()); } - return new SelectSpec(this.table, this.projectedFields, this.criteria, this.sort, page); + return new SelectSpec(this.table, this.projectedFields, this.selectList, this.criteria, this.sort, this.limit, + this.offset); } - public SqlIdentifier getTable() { + /** + * Associate a result offset with the select and create a new {@link SelectSpec}. + * + * @param page + * @return the {@link SelectSpec}. + */ + public SelectSpec offset(long offset) { + return new SelectSpec(this.table, this.projectedFields, this.selectList, this.criteria, this.sort, this.limit, + offset); + } + + /** + * Associate a result limit with the select and create a new {@link SelectSpec}. + * + * @param page + * @return the {@link SelectSpec}. + */ + public SelectSpec limit(int limit) { + return new SelectSpec(this.table, this.projectedFields, this.selectList, this.criteria, this.sort, limit, + this.offset); + } + + public Table getTable() { return this.table; } - public List getProjectedFields() { + /** + * @return + * @deprecated since 1.1, use {@link #getSelectList()} instead. + */ + @Deprecated + public List getProjectedFields() { return Collections.unmodifiableList(this.projectedFields); } + public List getSelectList() { + return Collections.unmodifiableList(selectList); + } + @Nullable public Criteria getCriteria() { return this.criteria; @@ -301,8 +376,12 @@ public Sort getSort() { return this.sort; } - public Pageable getPage() { - return this.page; + public long getOffset() { + return this.offset; + } + + public int getLimit() { + return this.limit; } } @@ -312,9 +391,9 @@ public Pageable getPage() { class InsertSpec { private final SqlIdentifier table; - private final Map assignments; + private final Map assignments; - protected InsertSpec(SqlIdentifier table, Map assignments) { + protected InsertSpec(SqlIdentifier table, Map assignments) { this.table = table; this.assignments = assignments; } @@ -348,8 +427,19 @@ public static InsertSpec create(SqlIdentifier table) { * @return the {@link InsertSpec}. */ public InsertSpec withColumn(String column, SettableValue value) { + return withColumn(SqlIdentifier.unquoted(column), value); + } - Map values = new LinkedHashMap<>(this.assignments); + /** + * Associate a column with a {@link SettableValue} and create a new {@link InsertSpec}. + * + * @param column + * @param value + * @return the {@link InsertSpec}. + */ + public InsertSpec withColumn(SqlIdentifier column, SettableValue value) { + + Map values = new LinkedHashMap<>(this.assignments); values.put(column, value); return new InsertSpec(this.table, values); @@ -359,7 +449,7 @@ public SqlIdentifier getTable() { return this.table; } - public Map getAssignments() { + public Map getAssignments() { return Collections.unmodifiableMap(this.assignments); } } @@ -370,11 +460,12 @@ public Map getAssignments() { class UpdateSpec { private final SqlIdentifier table; + @Nullable private final Update update; private final @Nullable Criteria criteria; - protected UpdateSpec(SqlIdentifier table, Update update, @Nullable Criteria criteria) { + protected UpdateSpec(SqlIdentifier table, @Nullable Update update, @Nullable Criteria criteria) { this.table = table; this.update = update; @@ -416,6 +507,7 @@ public SqlIdentifier getTable() { return this.table; } + @Nullable public Update getUpdate() { return this.update; } diff --git a/src/main/java/org/springframework/data/r2dbc/query/Criteria.java b/src/main/java/org/springframework/data/r2dbc/query/Criteria.java index 73712ee3..858ad83a 100644 --- a/src/main/java/org/springframework/data/r2dbc/query/Criteria.java +++ b/src/main/java/org/springframework/data/r2dbc/query/Criteria.java @@ -19,6 +19,7 @@ import java.util.Collection; import org.springframework.dao.InvalidDataAccessApiUsageException; +import org.springframework.data.relational.core.sql.SqlIdentifier; import org.springframework.lang.Nullable; import org.springframework.util.Assert; @@ -35,15 +36,15 @@ public class Criteria { private final @Nullable Criteria previous; private final Combinator combinator; - private final String column; + private final SqlIdentifier column; private final Comparator comparator; private final @Nullable Object value; - private Criteria(String column, Comparator comparator, @Nullable Object value) { + private Criteria(SqlIdentifier column, Comparator comparator, @Nullable Object value) { this(null, Combinator.INITIAL, column, comparator, value); } - private Criteria(@Nullable Criteria previous, Combinator combinator, String column, Comparator comparator, + private Criteria(@Nullable Criteria previous, Combinator combinator, SqlIdentifier column, Comparator comparator, @Nullable Object value) { this.previous = previous; @@ -63,7 +64,7 @@ public static CriteriaStep where(String column) { Assert.hasText(column, "Column name must not be null or empty!"); - return new DefaultCriteriaStep(column); + return new DefaultCriteriaStep(SqlIdentifier.unquoted(column)); } /** @@ -76,10 +77,10 @@ public CriteriaStep and(String column) { Assert.hasText(column, "Column name must not be null or empty!"); - return new DefaultCriteriaStep(column) { + return new DefaultCriteriaStep(SqlIdentifier.unquoted(column)) { @Override protected Criteria createCriteria(Comparator comparator, Object value) { - return new Criteria(Criteria.this, Combinator.AND, column, comparator, value); + return new Criteria(Criteria.this, Combinator.AND, SqlIdentifier.unquoted(column), comparator, value); } }; } @@ -94,10 +95,10 @@ public CriteriaStep or(String column) { Assert.hasText(column, "Column name must not be null or empty!"); - return new DefaultCriteriaStep(column) { + return new DefaultCriteriaStep(SqlIdentifier.unquoted(column)) { @Override protected Criteria createCriteria(Comparator comparator, Object value) { - return new Criteria(Criteria.this, Combinator.OR, column, comparator, value); + return new Criteria(Criteria.this, Combinator.OR, SqlIdentifier.unquoted(column), comparator, value); } }; } @@ -126,9 +127,9 @@ Combinator getCombinator() { } /** - * @return the property name. + * @return the column/property name. */ - String getColumn() { + SqlIdentifier getColumn() { return column; } @@ -164,7 +165,6 @@ public interface CriteriaStep { * Creates a {@link Criteria} using equality. * * @param value must not be {@literal null}. - * @return */ Criteria is(Object value); @@ -172,7 +172,6 @@ public interface CriteriaStep { * Creates a {@link Criteria} using equality (is not). * * @param value must not be {@literal null}. - * @return */ Criteria not(Object value); @@ -180,7 +179,6 @@ public interface CriteriaStep { * Creates a {@link Criteria} using {@code IN}. * * @param values must not be {@literal null}. - * @return */ Criteria in(Object... values); @@ -188,15 +186,13 @@ public interface CriteriaStep { * Creates a {@link Criteria} using {@code IN}. * * @param values must not be {@literal null}. - * @return */ - Criteria in(Collection values); + Criteria in(Collection values); /** * Creates a {@link Criteria} using {@code NOT IN}. * * @param values must not be {@literal null}. - * @return */ Criteria notIn(Object... values); @@ -204,15 +200,13 @@ public interface CriteriaStep { * Creates a {@link Criteria} using {@code NOT IN}. * * @param values must not be {@literal null}. - * @return */ - Criteria notIn(Collection values); + Criteria notIn(Collection values); /** * Creates a {@link Criteria} using less-than ({@literal <}). * * @param value must not be {@literal null}. - * @return */ Criteria lessThan(Object value); @@ -220,7 +214,6 @@ public interface CriteriaStep { * Creates a {@link Criteria} using less-than or equal to ({@literal <=}). * * @param value must not be {@literal null}. - * @return */ Criteria lessThanOrEquals(Object value); @@ -228,7 +221,6 @@ public interface CriteriaStep { * Creates a {@link Criteria} using greater-than({@literal >}). * * @param value must not be {@literal null}. - * @return */ Criteria greaterThan(Object value); @@ -236,7 +228,6 @@ public interface CriteriaStep { * Creates a {@link Criteria} using greater-than or equal to ({@literal >=}). * * @param value must not be {@literal null}. - * @return */ Criteria greaterThanOrEquals(Object value); @@ -244,21 +235,18 @@ public interface CriteriaStep { * Creates a {@link Criteria} using {@code LIKE}. * * @param value must not be {@literal null}. - * @return */ Criteria like(Object value); /** * Creates a {@link Criteria} using {@code IS NULL}. * - * @return */ Criteria isNull(); /** * Creates a {@link Criteria} using {@code IS NOT NULL}. * - * @return */ Criteria isNotNull(); } @@ -268,9 +256,9 @@ public interface CriteriaStep { */ static class DefaultCriteriaStep implements CriteriaStep { - private final String property; + private final SqlIdentifier property; - DefaultCriteriaStep(String property) { + DefaultCriteriaStep(SqlIdentifier property) { this.property = property; } diff --git a/src/main/java/org/springframework/data/r2dbc/query/Query.java b/src/main/java/org/springframework/data/r2dbc/query/Query.java new file mode 100644 index 00000000..c51f0500 --- /dev/null +++ b/src/main/java/org/springframework/data/r2dbc/query/Query.java @@ -0,0 +1,264 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.r2dbc.query; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Optional; +import java.util.stream.Collectors; + +import org.springframework.data.domain.Pageable; +import org.springframework.data.domain.Sort; +import org.springframework.data.relational.core.sql.SqlIdentifier; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * Query object representing {@link Criteria}, columns, {@link Sort}, and limit/offset for a SQL query. {@link Query} is + * created with a fluent API creating immutable objects. + * + * @author Mark Paluch + * @since 1.1 + * @see Criteria + * @see Sort + * @see Pageable + */ +public class Query { + + private final @Nullable Criteria criteria; + + private final List columns; + private final Sort sort; + private final int limit; + private final long offset; + + /** + * Static factory method to create a {@link Query} using the provided {@link Criteria}. + * + * @param criteria must not be {@literal null}. + * @return a new {@link Query} for the given {@link Criteria}. + */ + public static Query query(Criteria criteria) { + return new Query(criteria); + } + + /** + * Creates a new {@link Query} using the given {@link Criteria}. + * + * @param criteria must not be {@literal null}. + */ + private Query(@Nullable Criteria criteria) { + + this.criteria = criteria; + this.sort = Sort.unsorted(); + this.columns = Collections.emptyList(); + this.limit = -1; + this.offset = -1; + } + + private Query(@Nullable Criteria criteria, List columns, Sort sort, int limit, long offset) { + + this.criteria = criteria; + this.columns = columns; + this.sort = sort; + this.limit = limit; + this.offset = offset; + } + + /** + * Create a new empty {@link Query}. + * + * @return + */ + public static Query empty() { + return new Query(null); + } + + /** + * Add columns to the query. + * + * @param columns + * @return a new {@link Query} object containing the former settings with {@code columns} applied. + */ + public Query columns(String... columns) { + + Assert.notNull(columns, "Columns must not be null"); + + return withColumns(Arrays.stream(columns).map(SqlIdentifier::unquoted).collect(Collectors.toList())); + } + + /** + * Add columns to the query. + * + * @param columns + * @return a new {@link Query} object containing the former settings with {@code columns} applied. + */ + public Query columns(Collection columns) { + + Assert.notNull(columns, "Columns must not be null"); + + return withColumns(columns.stream().map(SqlIdentifier::unquoted).collect(Collectors.toList())); + } + + /** + * Add columns to the query. + * + * @param columns + * @return a new {@link Query} object containing the former settings with {@code columns} applied. + * @since 1.1 + */ + public Query columns(SqlIdentifier... columns) { + + Assert.notNull(columns, "Columns must not be null"); + + return withColumns(Arrays.asList(columns)); + } + + /** + * Add columns to the query. + * + * @param columns + * @return a new {@link Query} object containing the former settings with {@code columns} applied. + */ + private Query withColumns(Collection columns) { + + Assert.notNull(columns, "Columns must not be null"); + + List newColumns = new ArrayList<>(this.columns); + newColumns.addAll(columns); + return new Query(this.criteria, newColumns, this.sort, this.limit, offset); + } + + /** + * Set number of rows to skip before returning results. + * + * @param offset + * @return a new {@link Query} object containing the former settings with {@code offset} applied. + */ + public Query offset(long offset) { + return new Query(this.criteria, this.columns, this.sort, this.limit, offset); + } + + /** + * Limit the number of returned documents to {@code limit}. + * + * @param limit + * @return a new {@link Query} object containing the former settings with {@code limit} applied. + */ + public Query limit(int limit) { + return new Query(this.criteria, this.columns, this.sort, limit, this.offset); + } + + /** + * Set the given pagination information on the {@link Query} instance. Will transparently set {@code offset} and + * {@code limit} as well as applying the {@link Sort} instance defined with the {@link Pageable}. + * + * @param pageable + * @return a new {@link Query} object containing the former settings with {@link Pageable} applied. + */ + public Query with(Pageable pageable) { + + if (pageable.isUnpaged()) { + return this; + } + + assertNoCaseSort(pageable.getSort()); + + return new Query(this.criteria, this.columns, this.sort.and(sort), pageable.getPageSize(), pageable.getOffset()); + } + + /** + * Add a {@link Sort} to the {@link Query} instance. + * + * @param sort + * @return a new {@link Query} object containing the former settings with {@link Sort} applied. + */ + public Query sort(Sort sort) { + + Assert.notNull(sort, "Sort must not be null!"); + + if (sort.isUnsorted()) { + return this; + } + + assertNoCaseSort(sort); + + return new Query(this.criteria, this.columns, this.sort.and(sort), this.limit, this.offset); + } + + /** + * Return the {@link Criteria} to be applied. + * + * @return + */ + public Optional getCriteria() { + return Optional.ofNullable(this.criteria); + } + + /** + * Return the columns that this query should project. + * + * @return + */ + public List getColumns() { + return columns; + } + + /** + * Return {@literal true} if the {@link Query} has a sort parameter. + * + * @return {@literal true} if sorted. + * @see Sort#isSorted() + */ + public boolean isSorted() { + return sort.isSorted(); + } + + public Sort getSort() { + return sort; + } + + /** + * Return the number of rows to skip. + * + * @return + */ + public long getOffset() { + return this.offset; + } + + /** + * Return the maximum number of rows to be return. + * + * @return + */ + public int getLimit() { + return this.limit; + } + + private static void assertNoCaseSort(Sort sort) { + + for (Sort.Order order : sort) { + if (order.isIgnoreCase()) { + throw new IllegalArgumentException(String.format("Given sort contained an Order for %s with ignore case;" + + " R2DBC does not support sorting ignoring case currently", order.getProperty())); + } + } + } +} diff --git a/src/main/java/org/springframework/data/r2dbc/query/QueryMapper.java b/src/main/java/org/springframework/data/r2dbc/query/QueryMapper.java index 24103f8d..8dac3094 100644 --- a/src/main/java/org/springframework/data/r2dbc/query/QueryMapper.java +++ b/src/main/java/org/springframework/data/r2dbc/query/QueryMapper.java @@ -39,13 +39,7 @@ import org.springframework.data.r2dbc.query.Criteria.Comparator; import org.springframework.data.relational.core.mapping.RelationalPersistentEntity; import org.springframework.data.relational.core.mapping.RelationalPersistentProperty; -import org.springframework.data.relational.core.sql.Column; -import org.springframework.data.relational.core.sql.Condition; -import org.springframework.data.relational.core.sql.Expression; -import org.springframework.data.relational.core.sql.IdentifierProcessing; -import org.springframework.data.relational.core.sql.SQL; -import org.springframework.data.relational.core.sql.SqlIdentifier; -import org.springframework.data.relational.core.sql.Table; +import org.springframework.data.relational.core.sql.*; import org.springframework.data.util.ClassTypeInformation; import org.springframework.data.util.TypeInformation; import org.springframework.lang.Nullable; @@ -111,7 +105,7 @@ public Sort getMappedObject(Sort sort, @Nullable RelationalPersistentEntity e for (Sort.Order order : sort) { - Field field = createPropertyField(entity, order.getProperty(), this.mappingContext); + Field field = createPropertyField(entity, SqlIdentifier.unquoted(order.getProperty()), this.mappingContext); mappedOrder.add( Sort.Order.by(toSql(field.getMappedColumnName())).with(order.getNullHandling()).with(order.getDirection())); } @@ -119,6 +113,72 @@ public Sort getMappedObject(Sort sort, @Nullable RelationalPersistentEntity e return Sort.by(mappedOrder); } + /** + * Map the {@link Sort} object to apply field name mapping using {@link Class the type to read}. + * + * @param sort must not be {@literal null}. + * @param entity related {@link RelationalPersistentEntity}, can be {@literal null}. + * @return + * @since 1.1 + */ + public List getMappedSort(Table table, Sort sort, @Nullable RelationalPersistentEntity entity) { + + List mappedOrder = new ArrayList<>(); + + for (Sort.Order order : sort) { + + Field field = createPropertyField(entity, SqlIdentifier.unquoted(order.getProperty()), this.mappingContext); + OrderByField orderBy = OrderByField.from(table.column(field.getMappedColumnName())) + .withNullHandling(order.getNullHandling()); + mappedOrder.add(order.isAscending() ? orderBy.asc() : orderBy.desc()); + } + + return mappedOrder; + } + + /** + * Map the {@link Expression} object to apply field name mapping using {@link Class the type to read}. + * + * @param expression must not be {@literal null}. + * @param entity related {@link RelationalPersistentEntity}, can be {@literal null}. + * @return the mapped {@link Expression}. + * @since 1.1 + */ + public Expression getMappedObject(Expression expression, @Nullable RelationalPersistentEntity entity) { + + if (entity == null || expression instanceof AsteriskFromTable) { + return expression; + } + + if (expression instanceof Column) { + + Column column = (Column) expression; + Field field = createPropertyField(entity, column.getName()); + Table table = column.getTable(); + + Column columnFromTable = table.column(field.getMappedColumnName()); + return column instanceof Aliased ? columnFromTable.as(((Aliased) column).getAlias()) : columnFromTable; + } + + if (expression instanceof SimpleFunction) { + + SimpleFunction function = (SimpleFunction) expression; + + List arguments = function.getExpressions(); + List mappedArguments = new ArrayList<>(arguments.size()); + + for (Expression argument : arguments) { + mappedArguments.add(getMappedObject(argument, entity)); + } + + SimpleFunction mappedFunction = SimpleFunction.create(function.getFunctionName(), mappedArguments); + + return function instanceof Aliased ? mappedFunction.as(((Aliased) function).getAlias()) : mappedFunction; + } + + throw new IllegalArgumentException(String.format("Cannot map %s", expression)); + } + /** * Map a {@link Criteria} object into {@link Condition} and consider value/{@code NULL} {@link Bindings}. * @@ -170,7 +230,7 @@ private Condition getCondition(Criteria criteria, MutableBindings bindings, Tabl @Nullable RelationalPersistentEntity entity) { Field propertyField = createPropertyField(entity, criteria.getColumn(), this.mappingContext); - Column column = table.column(toSql(propertyField.getMappedColumnName())); + Column column = table.column(propertyField.getMappedColumnName()); TypeInformation actualType = propertyField.getTypeHint().getRequiredActualType(); Object mappedValue; @@ -243,7 +303,7 @@ private Condition createCondition(Column column, @Nullable Object mappedValue, C for (Object o : (Iterable) mappedValue) { - BindMarker bindMarker = bindings.nextMarker(column.getName()); + BindMarker bindMarker = bindings.nextMarker(column.getName().getReference()); expressions.add(bind(o, valueType, bindings, bindMarker)); } @@ -251,7 +311,7 @@ private Condition createCondition(Column column, @Nullable Object mappedValue, C } else { - BindMarker bindMarker = bindings.nextMarker(column.getName()); + BindMarker bindMarker = bindings.nextMarker(column.getName().getReference()); Expression expression = bind(mappedValue, valueType, bindings, bindMarker); condition = column.in(expression); @@ -264,7 +324,7 @@ private Condition createCondition(Column column, @Nullable Object mappedValue, C return condition; } - BindMarker bindMarker = bindings.nextMarker(column.getName()); + BindMarker bindMarker = bindings.nextMarker(column.getName().getReference()); Expression expression = bind(mappedValue, valueType, bindings, bindMarker); switch (comparator) { @@ -287,7 +347,11 @@ private Condition createCondition(Column column, @Nullable Object mappedValue, C } } - Field createPropertyField(@Nullable RelationalPersistentEntity entity, String key, + Field createPropertyField(@Nullable RelationalPersistentEntity entity, SqlIdentifier key) { + return entity == null ? new Field(key) : new MetadataBackedField(key, entity, mappingContext); + } + + Field createPropertyField(@Nullable RelationalPersistentEntity entity, SqlIdentifier key, MappingContext, RelationalPersistentProperty> mappingContext) { return entity == null ? new Field(key) : new MetadataBackedField(key, entity, mappingContext); } @@ -322,14 +386,14 @@ private Expression bind(@Nullable Object mappedValue, Class valueType, Mutabl */ protected static class Field { - protected final String name; + protected final SqlIdentifier name; /** * Creates a new {@link Field} without meta-information but the given name. * * @param name must not be {@literal null} or empty. */ - public Field(String name) { + public Field(SqlIdentifier name) { Assert.notNull(name, "Name must not be null!"); this.name = name; @@ -341,7 +405,7 @@ public Field(String name) { * @return */ public SqlIdentifier getMappedColumnName() { - return new PassThruIdentifier(this.name); + return this.name; } public TypeInformation getTypeHint() { @@ -367,7 +431,7 @@ protected static class MetadataBackedField extends Field { * @param entity must not be {@literal null}. * @param context must not be {@literal null}. */ - protected MetadataBackedField(String name, RelationalPersistentEntity entity, + protected MetadataBackedField(SqlIdentifier name, RelationalPersistentEntity entity, MappingContext, RelationalPersistentProperty> context) { this(name, entity, context, null); } @@ -381,7 +445,7 @@ protected MetadataBackedField(String name, RelationalPersistentEntity entity, * @param context must not be {@literal null}. * @param property may be {@literal null}. */ - protected MetadataBackedField(String name, RelationalPersistentEntity entity, + protected MetadataBackedField(SqlIdentifier name, RelationalPersistentEntity entity, MappingContext, RelationalPersistentProperty> context, @Nullable RelationalPersistentProperty property) { @@ -392,7 +456,7 @@ protected MetadataBackedField(String name, RelationalPersistentEntity entity, this.entity = entity; this.mappingContext = context; - this.path = getPath(name); + this.path = getPath(name.getReference()); this.property = this.path == null ? property : this.path.getLeafProperty(); } diff --git a/src/main/java/org/springframework/data/r2dbc/query/Update.java b/src/main/java/org/springframework/data/r2dbc/query/Update.java index f265aa2f..068ff523 100644 --- a/src/main/java/org/springframework/data/r2dbc/query/Update.java +++ b/src/main/java/org/springframework/data/r2dbc/query/Update.java @@ -19,6 +19,7 @@ import java.util.LinkedHashMap; import java.util.Map; +import org.springframework.data.relational.core.sql.SqlIdentifier; import org.springframework.lang.Nullable; import org.springframework.util.Assert; @@ -32,9 +33,9 @@ public class Update { private static final Update EMPTY = new Update(Collections.emptyMap()); - private final Map columnsToUpdate; + private final Map columnsToUpdate; - private Update(Map columnsToUpdate) { + private Update(Map columnsToUpdate) { this.columnsToUpdate = columnsToUpdate; } @@ -57,6 +58,21 @@ public static Update update(String column, @Nullable Object value) { * @return */ public Update set(String column, @Nullable Object value) { + + Assert.hasText(column, "Column for update must not be null or blank"); + + return addMultiFieldOperation(SqlIdentifier.unquoted(column), value); + } + + /** + * Update a column by assigning a value. + * + * @param column must not be {@literal null}. + * @param value can be {@literal null}. + * @return + * @since 1.1 + */ + public Update set(SqlIdentifier column, @Nullable Object value) { return addMultiFieldOperation(column, value); } @@ -65,15 +81,15 @@ public Update set(String column, @Nullable Object value) { * * @return */ - public Map getAssignments() { + public Map getAssignments() { return Collections.unmodifiableMap(this.columnsToUpdate); } - private Update addMultiFieldOperation(String key, Object value) { + private Update addMultiFieldOperation(SqlIdentifier key, @Nullable Object value) { - Assert.hasText(key, "Column for update must not be null or blank"); + Assert.notNull(key, "Column for update must not be null"); - Map updates = new LinkedHashMap<>(this.columnsToUpdate); + Map updates = new LinkedHashMap<>(this.columnsToUpdate); updates.put(key, value); return new Update(updates); diff --git a/src/main/java/org/springframework/data/r2dbc/query/UpdateMapper.java b/src/main/java/org/springframework/data/r2dbc/query/UpdateMapper.java index 21d69b92..260df541 100644 --- a/src/main/java/org/springframework/data/r2dbc/query/UpdateMapper.java +++ b/src/main/java/org/springframework/data/r2dbc/query/UpdateMapper.java @@ -32,6 +32,7 @@ import org.springframework.data.relational.core.sql.Assignments; import org.springframework.data.relational.core.sql.Column; import org.springframework.data.relational.core.sql.SQL; +import org.springframework.data.relational.core.sql.SqlIdentifier; import org.springframework.data.relational.core.sql.Table; import org.springframework.data.util.TypeInformation; import org.springframework.lang.Nullable; @@ -77,8 +78,8 @@ public BoundAssignments getMappedObject(BindMarkers markers, Update update, Tabl * @param entity related {@link RelationalPersistentEntity}, can be {@literal null}. * @return the mapped {@link BoundAssignments}. */ - public BoundAssignments getMappedObject(BindMarkers markers, Map assignments, Table table, - @Nullable RelationalPersistentEntity entity) { + public BoundAssignments getMappedObject(BindMarkers markers, Map assignments, + Table table, @Nullable RelationalPersistentEntity entity) { Assert.notNull(markers, "BindMarkers must not be null!"); Assert.notNull(assignments, "Assignments must not be null!"); @@ -95,11 +96,11 @@ public BoundAssignments getMappedObject(BindMarkers markers, Map entity) { Field propertyField = createPropertyField(entity, columnName, getMappingContext()); - Column column = table.column(toSql(propertyField.getMappedColumnName())); + Column column = table.column(propertyField.getMappedColumnName()); TypeInformation actualType = propertyField.getTypeHint().getRequiredActualType(); Object mappedValue; @@ -128,7 +129,7 @@ private Assignment getAssignment(String columnName, Object value, MutableBinding private Assignment createAssignment(Column column, Object value, Class type, MutableBindings bindings) { - BindMarker bindMarker = bindings.nextMarker(column.getName()); + BindMarker bindMarker = bindings.nextMarker(column.getName().getReference()); AssignValue assignValue = Assignments.value(column, SQL.bindMarker(bindMarker.getPlaceholder())); if (value == null) { diff --git a/src/main/java/org/springframework/data/r2dbc/repository/support/R2dbcRepositoryFactory.java b/src/main/java/org/springframework/data/r2dbc/repository/support/R2dbcRepositoryFactory.java index 740e6368..fccfb13a 100644 --- a/src/main/java/org/springframework/data/r2dbc/repository/support/R2dbcRepositoryFactory.java +++ b/src/main/java/org/springframework/data/r2dbc/repository/support/R2dbcRepositoryFactory.java @@ -22,6 +22,7 @@ import org.springframework.data.projection.ProjectionFactory; import org.springframework.data.r2dbc.convert.R2dbcConverter; import org.springframework.data.r2dbc.core.DatabaseClient; +import org.springframework.data.r2dbc.core.R2dbcEntityTemplate; import org.springframework.data.r2dbc.core.ReactiveDataAccessStrategy; import org.springframework.data.r2dbc.repository.R2dbcRepository; import org.springframework.data.r2dbc.repository.query.R2dbcQueryMethod; @@ -92,8 +93,8 @@ protected Object getTargetRepository(RepositoryInformation information) { RelationalEntityInformation entityInformation = getEntityInformation(information.getDomainType(), information); - return getTargetRepositoryViaReflection(information, entityInformation, this.databaseClient, this.converter, - this.dataAccessStrategy); + return getTargetRepositoryViaReflection(information, entityInformation, + new R2dbcEntityTemplate(this.databaseClient, this.dataAccessStrategy), this.converter); } /* diff --git a/src/main/java/org/springframework/data/r2dbc/repository/support/SimpleR2dbcRepository.java b/src/main/java/org/springframework/data/r2dbc/repository/support/SimpleR2dbcRepository.java index 3693e67e..4ff1b44f 100644 --- a/src/main/java/org/springframework/data/r2dbc/repository/support/SimpleR2dbcRepository.java +++ b/src/main/java/org/springframework/data/r2dbc/repository/support/SimpleR2dbcRepository.java @@ -18,26 +18,19 @@ import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; -import java.util.List; - import org.reactivestreams.Publisher; -import org.springframework.dao.TransientDataAccessResourceException; import org.springframework.data.r2dbc.convert.R2dbcConverter; import org.springframework.data.r2dbc.core.DatabaseClient; -import org.springframework.data.r2dbc.core.PreparedOperation; +import org.springframework.data.r2dbc.core.R2dbcEntityOperations; +import org.springframework.data.r2dbc.core.R2dbcEntityTemplate; import org.springframework.data.r2dbc.core.ReactiveDataAccessStrategy; -import org.springframework.data.r2dbc.core.StatementMapper; import org.springframework.data.r2dbc.query.Criteria; +import org.springframework.data.r2dbc.query.Query; import org.springframework.data.relational.core.mapping.RelationalPersistentProperty; -import org.springframework.data.relational.core.sql.Functions; -import org.springframework.data.relational.core.sql.Select; -import org.springframework.data.relational.core.sql.SqlIdentifier; -import org.springframework.data.relational.core.sql.StatementBuilder; -import org.springframework.data.relational.core.sql.Table; -import org.springframework.data.relational.core.sql.render.SqlRenderer; import org.springframework.data.relational.repository.query.RelationalEntityInformation; import org.springframework.data.repository.reactive.ReactiveCrudRepository; +import org.springframework.data.util.Lazy; import org.springframework.transaction.annotation.Transactional; import org.springframework.util.Assert; @@ -51,16 +44,45 @@ public class SimpleR2dbcRepository implements ReactiveCrudRepository { private final RelationalEntityInformation entity; - private final DatabaseClient databaseClient; - private final R2dbcConverter converter; - private final ReactiveDataAccessStrategy accessStrategy; + private final R2dbcEntityOperations entityOperations; + private final Lazy idProperty; + + /** + * Create a new {@link SimpleR2dbcRepository}. + * + * @param entity + * @param entityOperations + * @param converter + * @since 1.1 + */ + SimpleR2dbcRepository(RelationalEntityInformation entity, R2dbcEntityOperations entityOperations, + R2dbcConverter converter) { + + this.entity = entity; + this.entityOperations = entityOperations; + this.idProperty = Lazy.of(() -> converter // + .getMappingContext() // + .getRequiredPersistentEntity(this.entity.getJavaType()) // + .getRequiredIdProperty()); + } + /** + * Create a new {@link SimpleR2dbcRepository}. + * + * @param entity + * @param databaseClient + * @param converter + * @param accessStrategy + */ public SimpleR2dbcRepository(RelationalEntityInformation entity, DatabaseClient databaseClient, R2dbcConverter converter, ReactiveDataAccessStrategy accessStrategy) { + this.entity = entity; - this.databaseClient = databaseClient; - this.converter = converter; - this.accessStrategy = accessStrategy; + this.entityOperations = new R2dbcEntityTemplate(databaseClient); + this.idProperty = Lazy.of(() -> converter // + .getMappingContext() // + .getRequiredPersistentEntity(this.entity.getJavaType()) // + .getRequiredIdProperty()); } /* (non-Javadoc) @@ -73,28 +95,10 @@ public Mono save(S objectToSave) { Assert.notNull(objectToSave, "Object to save must not be null!"); if (this.entity.isNew(objectToSave)) { - - return this.databaseClient.insert() // - .into(this.entity.getJavaType()) // - .table(this.entity.getTableName()).using(objectToSave) // - .map(this.converter.populateIdIfNecessary(objectToSave)) // - .first() // - .defaultIfEmpty(objectToSave); + return this.entityOperations.insert(objectToSave); } - return this.databaseClient.update() // - .table(this.entity.getJavaType()) // - .table(this.entity.getTableName()).using(objectToSave) // - .fetch().rowsUpdated().handle((rowsUpdated, sink) -> { - - if (rowsUpdated == 0) { - sink.error(new TransientDataAccessResourceException( - String.format("Failed to update table [%s]. Row with Id [%s] does not exist.", - this.entity.getTableName(), this.entity.getId(objectToSave)))); - } else { - sink.next(objectToSave); - } - }); + return this.entityOperations.update(objectToSave); } /* (non-Javadoc) @@ -129,20 +133,7 @@ public Mono findById(ID id) { Assert.notNull(id, "Id must not be null!"); - List columns = this.accessStrategy.getAllColumns(this.entity.getJavaType()); - String idProperty = getIdProperty().getName(); - - StatementMapper mapper = this.accessStrategy.getStatementMapper().forType(this.entity.getJavaType()); - StatementMapper.SelectSpec selectSpec = mapper.createSelect(this.entity.getTableName()) // - .withProjection(columns) // - .withCriteria(Criteria.where(idProperty).is(id)); - - PreparedOperation operation = mapper.getMappedObject(selectSpec); - - return this.databaseClient.execute(operation) // - .as(this.entity.getJavaType()) // - .fetch() // - .one(); + return this.entityOperations.selectOne(getIdQuery(id), this.entity.getJavaType()); } /* (non-Javadoc) @@ -161,18 +152,7 @@ public Mono existsById(ID id) { Assert.notNull(id, "Id must not be null!"); - String idProperty = getIdProperty().getName(); - - StatementMapper mapper = this.accessStrategy.getStatementMapper().forType(this.entity.getJavaType()); - StatementMapper.SelectSpec selectSpec = mapper.createSelect(this.entity.getTableName()).withProjection(idProperty) // - .withCriteria(Criteria.where(idProperty).is(id)); - - PreparedOperation operation = mapper.getMappedObject(selectSpec); - - return this.databaseClient.execute(operation) // - .map((r, md) -> r) // - .first() // - .hasElement(); + return this.entityOperations.exists(getIdQuery(id), this.entity.getJavaType()); } /* (non-Javadoc) @@ -188,7 +168,7 @@ public Mono existsById(Publisher publisher) { */ @Override public Flux findAll() { - return this.databaseClient.select().from(this.entity.getJavaType()).fetch().all(); + return this.entityOperations.select(Query.empty(), this.entity.getJavaType()); } /* (non-Javadoc) @@ -216,17 +196,9 @@ public Flux findAllById(Publisher idPublisher) { return Flux.empty(); } - List columns = this.accessStrategy.getAllColumns(this.entity.getJavaType()); String idProperty = getIdProperty().getName(); - StatementMapper mapper = this.accessStrategy.getStatementMapper().forType(this.entity.getJavaType()); - StatementMapper.SelectSpec selectSpec = mapper.createSelect(this.entity.getTableName()) // - .withProjection(columns) // - .withCriteria(Criteria.where(idProperty).in(ids)); - - PreparedOperation operation = mapper.getMappedObject(selectSpec); - - return this.databaseClient.execute(operation).as(this.entity.getJavaType()).fetch().all(); + return this.entityOperations.select(Query.query(Criteria.where(idProperty).in(ids)), this.entity.getJavaType()); }); } @@ -235,17 +207,7 @@ public Flux findAllById(Publisher idPublisher) { */ @Override public Mono count() { - - Table table = Table.create(this.accessStrategy.toSql(this.entity.getTableName())); - Select select = StatementBuilder // - .select(Functions.count(table.column(this.accessStrategy.toSql(getIdProperty().getColumnName())))) // - .from(table) // - .build(); - - return this.databaseClient.execute(SqlRenderer.toString(select)) // - .map((r, md) -> r.get(0, Long.class)) // - .first() // - .defaultIfEmpty(0L); + return this.entityOperations.count(Query.empty(), this.entity.getJavaType()); } /* (non-Javadoc) @@ -257,13 +219,7 @@ public Mono deleteById(ID id) { Assert.notNull(id, "Id must not be null!"); - return this.databaseClient.delete() // - .from(this.entity.getJavaType()) // - .table(this.entity.getTableName()) // - .matching(Criteria.where(getIdProperty().getName()).is(id)) // - .fetch() // - .rowsUpdated() // - .then(); + return this.entityOperations.delete(getIdQuery(id), this.entity.getJavaType()).then(); } /* (non-Javadoc) @@ -274,7 +230,6 @@ public Mono deleteById(ID id) { public Mono deleteById(Publisher idPublisher) { Assert.notNull(idPublisher, "The Id Publisher must not be null!"); - StatementMapper statementMapper = this.accessStrategy.getStatementMapper().forType(this.entity.getJavaType()); return Flux.from(idPublisher).buffer().filter(ids -> !ids.isEmpty()).concatMap(ids -> { @@ -282,12 +237,9 @@ public Mono deleteById(Publisher idPublisher) { return Flux.empty(); } - return this.databaseClient.delete() // - .from(this.entity.getJavaType()) // - .table(this.entity.getTableName()) // - .matching(Criteria.where(getIdProperty().getName()).in(ids)) // - .fetch() // - .rowsUpdated(); + String idProperty = getIdProperty().getName(); + + return this.entityOperations.delete(Query.query(Criteria.where(idProperty).in(ids)), this.entity.getJavaType()); }).then(); } @@ -336,14 +288,14 @@ public Mono deleteAll(Publisher objectPublisher) { @Override @Transactional public Mono deleteAll() { - return this.databaseClient.delete().from(this.entity.getTableName()).then(); + return this.entityOperations.delete(Query.empty(), this.entity.getJavaType()).then(); } private RelationalPersistentProperty getIdProperty() { + return this.idProperty.get(); + } - return this.converter // - .getMappingContext() // - .getRequiredPersistentEntity(this.entity.getJavaType()) // - .getRequiredIdProperty(); + private Query getIdQuery(Object id) { + return Query.query(Criteria.where(getIdProperty().getName()).is(id)); } } diff --git a/src/test/java/org/springframework/data/r2dbc/core/R2dbcEntityTemplateUnitTests.java b/src/test/java/org/springframework/data/r2dbc/core/R2dbcEntityTemplateUnitTests.java new file mode 100644 index 00000000..dd5031be --- /dev/null +++ b/src/test/java/org/springframework/data/r2dbc/core/R2dbcEntityTemplateUnitTests.java @@ -0,0 +1,208 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.r2dbc.core; + +import static org.assertj.core.api.Assertions.*; + +import io.r2dbc.spi.test.MockColumnMetadata; +import io.r2dbc.spi.test.MockResult; +import io.r2dbc.spi.test.MockRow; +import io.r2dbc.spi.test.MockRowMetadata; +import reactor.test.StepVerifier; + +import java.util.Collections; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.data.annotation.Id; +import org.springframework.data.domain.Sort; +import org.springframework.data.r2dbc.dialect.PostgresDialect; +import org.springframework.data.r2dbc.mapping.SettableValue; +import org.springframework.data.r2dbc.query.Criteria; +import org.springframework.data.r2dbc.query.Query; +import org.springframework.data.r2dbc.query.Update; +import org.springframework.data.r2dbc.testing.StatementRecorder; +import org.springframework.data.relational.core.mapping.Column; + +/** + * Unit tests for {@link R2dbcEntityTemplate}. + * + * @author Mark Paluch + */ +public class R2dbcEntityTemplateUnitTests { + + DatabaseClient client; + R2dbcEntityTemplate entityTemplate; + StatementRecorder recorder; + + @Before + public void before() { + + recorder = StatementRecorder.newInstance(); + client = DatabaseClient.builder().connectionFactory(recorder) + .dataAccessStrategy(new DefaultReactiveDataAccessStrategy(PostgresDialect.INSTANCE)).build(); + entityTemplate = new R2dbcEntityTemplate(client); + } + + @Test // gh-220 + public void shouldCountBy() { + + MockRowMetadata metadata = MockRowMetadata.builder() + .columnMetadata(MockColumnMetadata.builder().name("name").build()).build(); + MockResult result = MockResult.builder().rowMetadata(metadata) + .row(MockRow.builder().identified(0, Long.class, 1L).build()).build(); + + recorder.addStubbing(s -> s.startsWith("SELECT"), result); + + entityTemplate.count(Query.query(Criteria.where("name").is("Walter")), Person.class) // + .as(StepVerifier::create) // + .expectNext(1L) // + .verifyComplete(); + + StatementRecorder.RecordedStatement statement = recorder.getCreatedStatement(s -> s.startsWith("SELECT")); + + assertThat(statement.getSql()).isEqualTo("SELECT COUNT(person.id) FROM person WHERE person.THE_NAME = $1"); + assertThat(statement.getBindings()).hasSize(1).containsEntry(0, SettableValue.from("Walter")); + } + + @Test // gh-220 + public void shouldExistsByCriteria() { + + MockRowMetadata metadata = MockRowMetadata.builder() + .columnMetadata(MockColumnMetadata.builder().name("name").build()).build(); + MockResult result = MockResult.builder().rowMetadata(metadata) + .row(MockRow.builder().identified(0, Long.class, 1L).build()).build(); + + recorder.addStubbing(s -> s.startsWith("SELECT"), result); + + entityTemplate.exists(Query.query(Criteria.where("name").is("Walter")), Person.class) // + .as(StepVerifier::create) // + .expectNext(true) // + .verifyComplete(); + + StatementRecorder.RecordedStatement statement = recorder.getCreatedStatement(s -> s.startsWith("SELECT")); + + assertThat(statement.getSql()).isEqualTo("SELECT person.id FROM person WHERE person.THE_NAME = $1"); + assertThat(statement.getBindings()).hasSize(1).containsEntry(0, SettableValue.from("Walter")); + } + + @Test // gh-220 + public void shouldSelectByCriteria() { + + recorder.addStubbing(s -> s.startsWith("SELECT"), Collections.emptyList()); + + entityTemplate.select(Query.query(Criteria.where("name").is("Walter")).sort(Sort.by("name")), Person.class) // + .as(StepVerifier::create) // + .verifyComplete(); + + StatementRecorder.RecordedStatement statement = recorder.getCreatedStatement(s -> s.startsWith("SELECT")); + + assertThat(statement.getSql()) + .isEqualTo("SELECT person.* FROM person WHERE person.THE_NAME = $1 ORDER BY THE_NAME ASC"); + assertThat(statement.getBindings()).hasSize(1).containsEntry(0, SettableValue.from("Walter")); + } + + @Test // gh-220 + public void shouldSelectOne() { + + recorder.addStubbing(s -> s.startsWith("SELECT"), Collections.emptyList()); + + entityTemplate.selectOne(Query.query(Criteria.where("name").is("Walter")).sort(Sort.by("name")), Person.class) // + .as(StepVerifier::create) // + .verifyComplete(); + + StatementRecorder.RecordedStatement statement = recorder.getCreatedStatement(s -> s.startsWith("SELECT")); + + assertThat(statement.getSql()) + .isEqualTo("SELECT person.* FROM person WHERE person.THE_NAME = $1 ORDER BY THE_NAME ASC LIMIT 2"); + assertThat(statement.getBindings()).hasSize(1).containsEntry(0, SettableValue.from("Walter")); + } + + @Test // gh-220 + public void shouldUpdateByQuery() { + + MockRowMetadata metadata = MockRowMetadata.builder() + .columnMetadata(MockColumnMetadata.builder().name("name").build()).build(); + MockResult result = MockResult.builder().rowMetadata(metadata).rowsUpdated(1).build(); + + recorder.addStubbing(s -> s.startsWith("UPDATE"), result); + + entityTemplate + .update(Query.query(Criteria.where("name").is("Walter")), Update.update("name", "Heisenberg"), Person.class) // + .as(StepVerifier::create) // + .expectNext(1) // + .verifyComplete(); + + StatementRecorder.RecordedStatement statement = recorder.getCreatedStatement(s -> s.startsWith("UPDATE")); + + assertThat(statement.getSql()).isEqualTo("UPDATE person SET THE_NAME = $1 WHERE person.THE_NAME = $2"); + assertThat(statement.getBindings()).hasSize(2).containsEntry(0, SettableValue.from("Heisenberg")).containsEntry(1, + SettableValue.from("Walter")); + } + + @Test // gh-220 + public void shouldDeleteByQuery() { + + MockRowMetadata metadata = MockRowMetadata.builder() + .columnMetadata(MockColumnMetadata.builder().name("name").build()).build(); + MockResult result = MockResult.builder().rowMetadata(metadata).rowsUpdated(1).build(); + + recorder.addStubbing(s -> s.startsWith("DELETE"), result); + + entityTemplate.delete(Query.query(Criteria.where("name").is("Walter")), Person.class) // + .as(StepVerifier::create) // + .expectNext(1) // + .verifyComplete(); + + StatementRecorder.RecordedStatement statement = recorder.getCreatedStatement(s -> s.startsWith("DELETE")); + + assertThat(statement.getSql()).isEqualTo("DELETE FROM person WHERE person.THE_NAME = $1"); + assertThat(statement.getBindings()).hasSize(1).containsEntry(0, SettableValue.from("Walter")); + } + + @Test // gh-220 + public void shouldDeleteEntity() { + + Person person = new Person(); + person.id = "Walter"; + recorder.addStubbing(s -> s.startsWith("DELETE"), Collections.emptyList()); + + entityTemplate.delete(person) // + .as(StepVerifier::create) // + .expectNext(person).verifyComplete(); + + StatementRecorder.RecordedStatement statement = recorder.getCreatedStatement(s -> s.startsWith("DELETE")); + + assertThat(statement.getSql()).isEqualTo("DELETE FROM person WHERE person.id = $1"); + assertThat(statement.getBindings()).hasSize(1).containsEntry(0, SettableValue.from("Walter")); + } + + static class Person { + + @Id String id; + + @Column("THE_NAME") String name; + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + } +} diff --git a/src/test/java/org/springframework/data/r2dbc/core/ReactiveDeleteOperationUnitTests.java b/src/test/java/org/springframework/data/r2dbc/core/ReactiveDeleteOperationUnitTests.java new file mode 100644 index 00000000..0abb72b7 --- /dev/null +++ b/src/test/java/org/springframework/data/r2dbc/core/ReactiveDeleteOperationUnitTests.java @@ -0,0 +1,109 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.r2dbc.core; + +import static org.assertj.core.api.Assertions.*; +import static org.springframework.data.r2dbc.query.Criteria.*; +import static org.springframework.data.r2dbc.query.Query.*; + +import io.r2dbc.spi.test.MockResult; +import reactor.test.StepVerifier; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.data.annotation.Id; +import org.springframework.data.r2dbc.dialect.PostgresDialect; +import org.springframework.data.r2dbc.mapping.SettableValue; +import org.springframework.data.r2dbc.testing.StatementRecorder; +import org.springframework.data.relational.core.mapping.Column; + +/** + * Unit test for {@link ReactiveDeleteOperation}. + * + * @author Mark Paluch + */ +public class ReactiveDeleteOperationUnitTests { + + DatabaseClient client; + R2dbcEntityTemplate entityTemplate; + StatementRecorder recorder; + + @Before + public void before() { + + recorder = StatementRecorder.newInstance(); + client = DatabaseClient.builder().connectionFactory(recorder) + .dataAccessStrategy(new DefaultReactiveDataAccessStrategy(PostgresDialect.INSTANCE)).build(); + entityTemplate = new R2dbcEntityTemplate(client); + } + + @Test // gh-220 + public void shouldDelete() { + + MockResult result = MockResult.builder().rowsUpdated(1).build(); + + recorder.addStubbing(s -> s.startsWith("DELETE"), result); + + entityTemplate.delete(Person.class) // + .matching(query(where("name").is("Walter"))) // + .all() // + .as(StepVerifier::create) // + .expectNext(1) // + .verifyComplete(); + + StatementRecorder.RecordedStatement statement = recorder.getCreatedStatement(s -> s.startsWith("DELETE")); + + assertThat(statement.getSql()).isEqualTo("DELETE FROM person WHERE person.THE_NAME = $1"); + assertThat(statement.getBindings()).hasSize(1).containsEntry(0, SettableValue.from("Walter")); + } + + @Test // gh-220 + public void shouldDeleteInTable() { + + MockResult result = MockResult.builder().rowsUpdated(1).build(); + + recorder.addStubbing(s -> s.startsWith("DELETE"), result); + + entityTemplate.delete(Person.class) // + .from("other_table") // + .matching(query(where("name").is("Walter"))) // + .all() // + .as(StepVerifier::create) // + .expectNext(1) // + .verifyComplete(); + + StatementRecorder.RecordedStatement statement = recorder.getCreatedStatement(s -> s.startsWith("DELETE")); + + assertThat(statement.getSql()).isEqualTo("DELETE FROM other_table WHERE other_table.THE_NAME = $1"); + } + + static class Person { + + @Id String id; + + @Column("THE_NAME") String name; + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + } + +} diff --git a/src/test/java/org/springframework/data/r2dbc/core/ReactiveInsertOperationUnitTests.java b/src/test/java/org/springframework/data/r2dbc/core/ReactiveInsertOperationUnitTests.java new file mode 100644 index 00000000..f5b98f8d --- /dev/null +++ b/src/test/java/org/springframework/data/r2dbc/core/ReactiveInsertOperationUnitTests.java @@ -0,0 +1,126 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.r2dbc.core; + +import static org.assertj.core.api.Assertions.*; + +import io.r2dbc.spi.test.MockColumnMetadata; +import io.r2dbc.spi.test.MockResult; +import io.r2dbc.spi.test.MockRow; +import io.r2dbc.spi.test.MockRowMetadata; +import reactor.test.StepVerifier; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.data.annotation.Id; +import org.springframework.data.r2dbc.dialect.PostgresDialect; +import org.springframework.data.r2dbc.mapping.SettableValue; +import org.springframework.data.r2dbc.testing.StatementRecorder; +import org.springframework.data.relational.core.mapping.Column; + +/** + * Unit test for {@link ReactiveInsertOperation}. + * + * @author Mark Paluch + */ +public class ReactiveInsertOperationUnitTests { + + DatabaseClient client; + R2dbcEntityTemplate entityTemplate; + StatementRecorder recorder; + + @Before + public void before() { + + recorder = StatementRecorder.newInstance(); + client = DatabaseClient.builder().connectionFactory(recorder) + .dataAccessStrategy(new DefaultReactiveDataAccessStrategy(PostgresDialect.INSTANCE)).build(); + entityTemplate = new R2dbcEntityTemplate(client); + } + + @Test // gh-220 + public void shouldInsert() { + + MockRowMetadata metadata = MockRowMetadata.builder().columnMetadata(MockColumnMetadata.builder().name("id").build()) + .build(); + MockResult result = MockResult.builder().rowMetadata(metadata) + .row(MockRow.builder().identified("id", Object.class, 42).build()).build(); + + recorder.addStubbing(s -> s.startsWith("INSERT"), result); + + Person person = new Person(); + person.setName("Walter"); + + entityTemplate.insert(Person.class) // + .using(person) // + .as(StepVerifier::create) // + .consumeNextWith(actual -> { + + assertThat(actual.id).isEqualTo("42"); + }) // + .verifyComplete(); + + StatementRecorder.RecordedStatement statement = recorder.getCreatedStatement(s -> s.startsWith("INSERT")); + + assertThat(statement.getSql()).isEqualTo("INSERT INTO person (THE_NAME) VALUES ($1)"); + assertThat(statement.getBindings()).hasSize(1).containsEntry(0, SettableValue.from("Walter")); + } + + @Test // gh-220 + public void shouldUpdateInTable() { + + MockRowMetadata metadata = MockRowMetadata.builder().columnMetadata(MockColumnMetadata.builder().name("id").build()) + .build(); + MockResult result = MockResult.builder().rowMetadata(metadata) + .row(MockRow.builder().identified("id", Object.class, 42).build()).build(); + + recorder.addStubbing(s -> s.startsWith("INSERT"), result); + + Person person = new Person(); + person.setName("Walter"); + + entityTemplate.insert(Person.class) // + .into("the_table") // + .using(person) // + .as(StepVerifier::create) // + .consumeNextWith(actual -> { + + assertThat(actual.id).isEqualTo("42"); + }) // + .verifyComplete(); + + StatementRecorder.RecordedStatement statement = recorder.getCreatedStatement(s -> s.startsWith("INSERT")); + + assertThat(statement.getSql()).isEqualTo("INSERT INTO the_table (THE_NAME) VALUES ($1)"); + } + + static class Person { + + @Id String id; + + @Column("THE_NAME") String name; + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + } + +} diff --git a/src/test/java/org/springframework/data/r2dbc/core/ReactiveSelectOperationUnitTests.java b/src/test/java/org/springframework/data/r2dbc/core/ReactiveSelectOperationUnitTests.java new file mode 100644 index 00000000..ec2593bd --- /dev/null +++ b/src/test/java/org/springframework/data/r2dbc/core/ReactiveSelectOperationUnitTests.java @@ -0,0 +1,232 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.r2dbc.core; + +import static org.assertj.core.api.Assertions.*; +import static org.springframework.data.r2dbc.query.Criteria.*; +import static org.springframework.data.r2dbc.query.Query.*; + +import io.r2dbc.spi.test.MockColumnMetadata; +import io.r2dbc.spi.test.MockResult; +import io.r2dbc.spi.test.MockRow; +import io.r2dbc.spi.test.MockRowMetadata; +import reactor.test.StepVerifier; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.data.annotation.Id; +import org.springframework.data.r2dbc.dialect.PostgresDialect; +import org.springframework.data.r2dbc.testing.StatementRecorder; +import org.springframework.data.relational.core.mapping.Column; + +/** + * Unit test for {@link ReactiveSelectOperation}. + * + * @author Mark Paluch + */ +public class ReactiveSelectOperationUnitTests { + + DatabaseClient client; + R2dbcEntityTemplate entityTemplate; + StatementRecorder recorder; + + @Before + public void before() { + + recorder = StatementRecorder.newInstance(); + client = DatabaseClient.builder().connectionFactory(recorder) + .dataAccessStrategy(new DefaultReactiveDataAccessStrategy(PostgresDialect.INSTANCE)).build(); + entityTemplate = new R2dbcEntityTemplate(client); + } + + @Test // gh-220 + public void shouldSelectAll() { + + MockRowMetadata metadata = MockRowMetadata.builder().columnMetadata(MockColumnMetadata.builder().name("id").build()) + .build(); + MockResult result = MockResult.builder().rowMetadata(metadata) + .row(MockRow.builder().identified("id", Object.class, "Walter").build()).build(); + + recorder.addStubbing(s -> s.startsWith("SELECT"), result); + + entityTemplate.select(Person.class) // + .matching(query(where("name").is("Walter")).limit(10).offset(20)) // + .all() // + .as(StepVerifier::create) // + .expectNextCount(1) // + .verifyComplete(); + + StatementRecorder.RecordedStatement statement = recorder.getCreatedStatement(s -> s.startsWith("SELECT")); + + assertThat(statement.getSql()) + .isEqualTo("SELECT person.* FROM person WHERE person.THE_NAME = $1 LIMIT 10 OFFSET 20"); + } + + @Test // gh-220 + public void shouldSelectAs() { + + MockRowMetadata metadata = MockRowMetadata.builder().columnMetadata(MockColumnMetadata.builder().name("id").build()) + .build(); + MockResult result = MockResult.builder().rowMetadata(metadata) + .row(MockRow.builder().identified("id", Object.class, "Walter").build()).build(); + + recorder.addStubbing(s -> s.startsWith("SELECT"), result); + + entityTemplate.select(Person.class) // + .as(PersonProjection.class) // + .matching(query(where("name").is("Walter"))) // + .all() // + .as(StepVerifier::create) // + .expectNextCount(1) // + .verifyComplete(); + + StatementRecorder.RecordedStatement statement = recorder.getCreatedStatement(s -> s.startsWith("SELECT")); + + assertThat(statement.getSql()).isEqualTo("SELECT person.THE_NAME FROM person WHERE person.THE_NAME = $1"); + } + + @Test // gh-220 + public void shouldSelectFromTable() { + + MockRowMetadata metadata = MockRowMetadata.builder().columnMetadata(MockColumnMetadata.builder().name("id").build()) + .build(); + MockResult result = MockResult.builder().rowMetadata(metadata) + .row(MockRow.builder().identified("id", Object.class, "Walter").build()).build(); + + recorder.addStubbing(s -> s.startsWith("SELECT"), result); + + entityTemplate.select(Person.class) // + .from("the_table") // + .matching(query(where("name").is("Walter"))) // + .all() // + .as(StepVerifier::create) // + .expectNextCount(1) // + .verifyComplete(); + + StatementRecorder.RecordedStatement statement = recorder.getCreatedStatement(s -> s.startsWith("SELECT")); + + assertThat(statement.getSql()).isEqualTo("SELECT the_table.* FROM the_table WHERE the_table.THE_NAME = $1"); + } + + @Test // gh-220 + public void shouldSelectFirst() { + + MockRowMetadata metadata = MockRowMetadata.builder().columnMetadata(MockColumnMetadata.builder().name("id").build()) + .build(); + MockResult result = MockResult.builder().rowMetadata(metadata) + .row(MockRow.builder().identified("id", Object.class, "Walter").build()).build(); + + recorder.addStubbing(s -> s.startsWith("SELECT"), result); + + entityTemplate.select(Person.class) // + .matching(query(where("name").is("Walter"))) // + .first() // + .as(StepVerifier::create) // + .expectNextCount(1) // + .verifyComplete(); + + StatementRecorder.RecordedStatement statement = recorder.getCreatedStatement(s -> s.startsWith("SELECT")); + + assertThat(statement.getSql()).isEqualTo("SELECT person.* FROM person WHERE person.THE_NAME = $1 LIMIT 1"); + } + + @Test // gh-220 + public void shouldSelectOne() { + + MockRowMetadata metadata = MockRowMetadata.builder().columnMetadata(MockColumnMetadata.builder().name("id").build()) + .build(); + MockResult result = MockResult.builder().rowMetadata(metadata) + .row(MockRow.builder().identified("id", Object.class, "Walter").build()).build(); + + recorder.addStubbing(s -> s.startsWith("SELECT"), result); + + entityTemplate.select(Person.class) // + .matching(query(where("name").is("Walter"))) // + .one() // + .as(StepVerifier::create) // + .expectNextCount(1) // + .verifyComplete(); + + StatementRecorder.RecordedStatement statement = recorder.getCreatedStatement(s -> s.startsWith("SELECT")); + + assertThat(statement.getSql()).isEqualTo("SELECT person.* FROM person WHERE person.THE_NAME = $1 LIMIT 2"); + } + + @Test // gh-220 + public void shouldSelectExists() { + + MockRowMetadata metadata = MockRowMetadata.builder().columnMetadata(MockColumnMetadata.builder().name("id").build()) + .build(); + MockResult result = MockResult.builder().rowMetadata(metadata) + .row(MockRow.builder().identified("id", Object.class, "Walter").build()).build(); + + recorder.addStubbing(s -> s.startsWith("SELECT"), result); + + entityTemplate.select(Person.class) // + .matching(query(where("name").is("Walter"))) // + .exists() // + .as(StepVerifier::create) // + .expectNext(true) // + .verifyComplete(); + + StatementRecorder.RecordedStatement statement = recorder.getCreatedStatement(s -> s.startsWith("SELECT")); + + assertThat(statement.getSql()).isEqualTo("SELECT person.id FROM person WHERE person.THE_NAME = $1"); + } + + @Test // gh-220 + public void shouldSelectCount() { + + MockRowMetadata metadata = MockRowMetadata.builder().columnMetadata(MockColumnMetadata.builder().name("id").build()) + .build(); + MockResult result = MockResult.builder().rowMetadata(metadata) + .row(MockRow.builder().identified(0, Long.class, 1L).build()).build(); + + recorder.addStubbing(s -> s.startsWith("SELECT"), result); + + entityTemplate.select(Person.class) // + .matching(query(where("name").is("Walter"))) // + .count() // + .as(StepVerifier::create) // + .expectNext(1L) // + .verifyComplete(); + + StatementRecorder.RecordedStatement statement = recorder.getCreatedStatement(s -> s.startsWith("SELECT")); + + assertThat(statement.getSql()).isEqualTo("SELECT COUNT(person.id) FROM person WHERE person.THE_NAME = $1"); + } + + static class Person { + + @Id String id; + + @Column("THE_NAME") String name; + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + } + + interface PersonProjection { + + String getName(); + } +} diff --git a/src/test/java/org/springframework/data/r2dbc/core/ReactiveUpdateOperationUnitTests.java b/src/test/java/org/springframework/data/r2dbc/core/ReactiveUpdateOperationUnitTests.java new file mode 100644 index 00000000..7be1107e --- /dev/null +++ b/src/test/java/org/springframework/data/r2dbc/core/ReactiveUpdateOperationUnitTests.java @@ -0,0 +1,111 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.r2dbc.core; + +import static org.assertj.core.api.Assertions.*; +import static org.springframework.data.r2dbc.query.Criteria.*; +import static org.springframework.data.r2dbc.query.Query.*; + +import io.r2dbc.spi.test.MockResult; +import reactor.test.StepVerifier; + +import org.junit.Before; +import org.junit.Test; + +import org.springframework.data.annotation.Id; +import org.springframework.data.r2dbc.dialect.PostgresDialect; +import org.springframework.data.r2dbc.mapping.SettableValue; +import org.springframework.data.r2dbc.query.Update; +import org.springframework.data.r2dbc.testing.StatementRecorder; +import org.springframework.data.relational.core.mapping.Column; + +/** + * Unit test for {@link ReactiveUpdateOperation}. + * + * @author Mark Paluch + */ +public class ReactiveUpdateOperationUnitTests { + + DatabaseClient client; + R2dbcEntityTemplate entityTemplate; + StatementRecorder recorder; + + @Before + public void before() { + + recorder = StatementRecorder.newInstance(); + client = DatabaseClient.builder().connectionFactory(recorder) + .dataAccessStrategy(new DefaultReactiveDataAccessStrategy(PostgresDialect.INSTANCE)).build(); + entityTemplate = new R2dbcEntityTemplate(client); + } + + @Test // gh-220 + public void shouldUpdate() { + + MockResult result = MockResult.builder().rowsUpdated(1).build(); + + recorder.addStubbing(s -> s.startsWith("UPDATE"), result); + + entityTemplate.update(Person.class) // + .matching(query(where("name").is("Walter"))) // + .apply(Update.update("name", "Heisenberg")) // + .as(StepVerifier::create) // + .expectNext(1) // + .verifyComplete(); + + StatementRecorder.RecordedStatement statement = recorder.getCreatedStatement(s -> s.startsWith("UPDATE")); + + assertThat(statement.getSql()).isEqualTo("UPDATE person SET THE_NAME = $1 WHERE person.THE_NAME = $2"); + assertThat(statement.getBindings()).hasSize(2).containsEntry(0, SettableValue.from("Heisenberg")).containsEntry(1, + SettableValue.from("Walter")); + } + + @Test // gh-220 + public void shouldUpdateInTable() { + + MockResult result = MockResult.builder().rowsUpdated(1).build(); + + recorder.addStubbing(s -> s.startsWith("UPDATE"), result); + + entityTemplate.update(Person.class) // + .inTable("the_table") // + .matching(query(where("name").is("Walter"))) // + .apply(Update.update("name", "Heisenberg")) // + .as(StepVerifier::create) // + .expectNext(1) // + .verifyComplete(); + + StatementRecorder.RecordedStatement statement = recorder.getCreatedStatement(s -> s.startsWith("UPDATE")); + + assertThat(statement.getSql()).isEqualTo("UPDATE the_table SET THE_NAME = $1 WHERE the_table.THE_NAME = $2"); + } + + static class Person { + + @Id String id; + + @Column("THE_NAME") String name; + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + } + +} diff --git a/src/test/java/org/springframework/data/r2dbc/query/CriteriaUnitTests.java b/src/test/java/org/springframework/data/r2dbc/query/CriteriaUnitTests.java index d8e82ea3..887585ba 100644 --- a/src/test/java/org/springframework/data/r2dbc/query/CriteriaUnitTests.java +++ b/src/test/java/org/springframework/data/r2dbc/query/CriteriaUnitTests.java @@ -21,9 +21,9 @@ import java.util.Arrays; import org.junit.Test; -import org.springframework.data.r2dbc.query.Criteria; -import org.springframework.data.r2dbc.query.Criteria.Combinator; -import org.springframework.data.r2dbc.query.Criteria.Comparator; + +import org.springframework.data.r2dbc.query.Criteria.*; +import org.springframework.data.relational.core.sql.SqlIdentifier; /** * Unit tests for {@link Criteria}. @@ -37,7 +37,7 @@ public void andChainedCriteria() { Criteria criteria = where("foo").is("bar").and("baz").isNotNull(); - assertThat(criteria.getColumn()).isEqualTo("baz"); + assertThat(criteria.getColumn()).isEqualTo(SqlIdentifier.unquoted("baz")); assertThat(criteria.getComparator()).isEqualTo(Comparator.IS_NOT_NULL); assertThat(criteria.getValue()).isNull(); assertThat(criteria.getPrevious()).isNotNull(); @@ -45,7 +45,7 @@ public void andChainedCriteria() { criteria = criteria.getPrevious(); - assertThat(criteria.getColumn()).isEqualTo("foo"); + assertThat(criteria.getColumn()).isEqualTo(SqlIdentifier.unquoted("foo")); assertThat(criteria.getComparator()).isEqualTo(Comparator.EQ); assertThat(criteria.getValue()).isEqualTo("bar"); } @@ -55,7 +55,7 @@ public void orChainedCriteria() { Criteria criteria = where("foo").is("bar").or("baz").isNotNull(); - assertThat(criteria.getColumn()).isEqualTo("baz"); + assertThat(criteria.getColumn()).isEqualTo(SqlIdentifier.unquoted("baz")); assertThat(criteria.getCombinator()).isEqualTo(Combinator.OR); criteria = criteria.getPrevious(); @@ -69,7 +69,7 @@ public void shouldBuildEqualsCriteria() { Criteria criteria = where("foo").is("bar"); - assertThat(criteria.getColumn()).isEqualTo("foo"); + assertThat(criteria.getColumn()).isEqualTo(SqlIdentifier.unquoted("foo")); assertThat(criteria.getComparator()).isEqualTo(Comparator.EQ); assertThat(criteria.getValue()).isEqualTo("bar"); } @@ -79,7 +79,7 @@ public void shouldBuildNotEqualsCriteria() { Criteria criteria = where("foo").not("bar"); - assertThat(criteria.getColumn()).isEqualTo("foo"); + assertThat(criteria.getColumn()).isEqualTo(SqlIdentifier.unquoted("foo")); assertThat(criteria.getComparator()).isEqualTo(Comparator.NEQ); assertThat(criteria.getValue()).isEqualTo("bar"); } @@ -89,7 +89,7 @@ public void shouldBuildInCriteria() { Criteria criteria = where("foo").in("bar", "baz"); - assertThat(criteria.getColumn()).isEqualTo("foo"); + assertThat(criteria.getColumn()).isEqualTo(SqlIdentifier.unquoted("foo")); assertThat(criteria.getComparator()).isEqualTo(Comparator.IN); assertThat(criteria.getValue()).isEqualTo(Arrays.asList("bar", "baz")); } @@ -99,7 +99,7 @@ public void shouldBuildNotInCriteria() { Criteria criteria = where("foo").notIn("bar", "baz"); - assertThat(criteria.getColumn()).isEqualTo("foo"); + assertThat(criteria.getColumn()).isEqualTo(SqlIdentifier.unquoted("foo")); assertThat(criteria.getComparator()).isEqualTo(Comparator.NOT_IN); assertThat(criteria.getValue()).isEqualTo(Arrays.asList("bar", "baz")); } @@ -109,7 +109,7 @@ public void shouldBuildGtCriteria() { Criteria criteria = where("foo").greaterThan(1); - assertThat(criteria.getColumn()).isEqualTo("foo"); + assertThat(criteria.getColumn()).isEqualTo(SqlIdentifier.unquoted("foo")); assertThat(criteria.getComparator()).isEqualTo(Comparator.GT); assertThat(criteria.getValue()).isEqualTo(1); } @@ -119,7 +119,7 @@ public void shouldBuildGteCriteria() { Criteria criteria = where("foo").greaterThanOrEquals(1); - assertThat(criteria.getColumn()).isEqualTo("foo"); + assertThat(criteria.getColumn()).isEqualTo(SqlIdentifier.unquoted("foo")); assertThat(criteria.getComparator()).isEqualTo(Comparator.GTE); assertThat(criteria.getValue()).isEqualTo(1); } @@ -129,7 +129,7 @@ public void shouldBuildLtCriteria() { Criteria criteria = where("foo").lessThan(1); - assertThat(criteria.getColumn()).isEqualTo("foo"); + assertThat(criteria.getColumn()).isEqualTo(SqlIdentifier.unquoted("foo")); assertThat(criteria.getComparator()).isEqualTo(Comparator.LT); assertThat(criteria.getValue()).isEqualTo(1); } @@ -139,7 +139,7 @@ public void shouldBuildLteCriteria() { Criteria criteria = where("foo").lessThanOrEquals(1); - assertThat(criteria.getColumn()).isEqualTo("foo"); + assertThat(criteria.getColumn()).isEqualTo(SqlIdentifier.unquoted("foo")); assertThat(criteria.getComparator()).isEqualTo(Comparator.LTE); assertThat(criteria.getValue()).isEqualTo(1); } @@ -149,7 +149,7 @@ public void shouldBuildLikeCriteria() { Criteria criteria = where("foo").like("hello%"); - assertThat(criteria.getColumn()).isEqualTo("foo"); + assertThat(criteria.getColumn()).isEqualTo(SqlIdentifier.unquoted("foo")); assertThat(criteria.getComparator()).isEqualTo(Comparator.LIKE); assertThat(criteria.getValue()).isEqualTo("hello%"); } @@ -159,7 +159,7 @@ public void shouldBuildIsNullCriteria() { Criteria criteria = where("foo").isNull(); - assertThat(criteria.getColumn()).isEqualTo("foo"); + assertThat(criteria.getColumn()).isEqualTo(SqlIdentifier.unquoted("foo")); assertThat(criteria.getComparator()).isEqualTo(Comparator.IS_NULL); } @@ -168,7 +168,7 @@ public void shouldBuildIsNotNullCriteria() { Criteria criteria = where("foo").isNotNull(); - assertThat(criteria.getColumn()).isEqualTo("foo"); + assertThat(criteria.getColumn()).isEqualTo(SqlIdentifier.unquoted("foo")); assertThat(criteria.getComparator()).isEqualTo(Comparator.IS_NOT_NULL); } } diff --git a/src/test/java/org/springframework/data/r2dbc/query/UpdateMapperUnitTests.java b/src/test/java/org/springframework/data/r2dbc/query/UpdateMapperUnitTests.java index 26679073..9a747475 100644 --- a/src/test/java/org/springframework/data/r2dbc/query/UpdateMapperUnitTests.java +++ b/src/test/java/org/springframework/data/r2dbc/query/UpdateMapperUnitTests.java @@ -34,6 +34,7 @@ import org.springframework.data.relational.core.sql.AssignValue; import org.springframework.data.relational.core.sql.Expression; import org.springframework.data.relational.core.sql.SQL; +import org.springframework.data.relational.core.sql.SqlIdentifier; import org.springframework.data.relational.core.sql.Table; /** @@ -54,10 +55,10 @@ public void shouldMapFieldNamesInUpdate() { BoundAssignments mapped = map(update); - Map assignments = mapped.getAssignments().stream().map(it -> (AssignValue) it) + Map assignments = mapped.getAssignments().stream().map(it -> (AssignValue) it) .collect(Collectors.toMap(k -> k.getColumn().getName(), AssignValue::getValue)); - assertThat(assignments).containsEntry("another_name", SQL.bindMarker("$1")); + assertThat(assignments).containsEntry(SqlIdentifier.unquoted("another_name"), SQL.bindMarker("$1")); } @Test // gh-64 @@ -67,10 +68,10 @@ public void shouldUpdateToSettableValue() { BoundAssignments mapped = map(update); - Map assignments = mapped.getAssignments().stream().map(it -> (AssignValue) it) + Map assignments = mapped.getAssignments().stream().map(it -> (AssignValue) it) .collect(Collectors.toMap(k -> k.getColumn().getName(), AssignValue::getValue)); - assertThat(assignments).containsEntry("another_name", SQL.bindMarker("$1")); + assertThat(assignments).containsEntry(SqlIdentifier.unquoted("another_name"), SQL.bindMarker("$1")); mapped.getBindings().apply(bindTarget); verify(bindTarget).bindNull(0, String.class); @@ -87,7 +88,7 @@ public void shouldUpdateToNull() { assertThat(mapped.getAssignments().get(0).toString()).isEqualTo("person.another_name = NULL"); mapped.getBindings().apply(bindTarget); - verifyZeroInteractions(bindTarget); + verifyNoInteractions(bindTarget); } @Test // gh-195 @@ -97,12 +98,12 @@ public void shouldMapMultipleFields() { BoundAssignments mapped = map(update); - Map assignments = mapped.getAssignments().stream().map(it -> (AssignValue) it) + Map assignments = mapped.getAssignments().stream().map(it -> (AssignValue) it) .collect(Collectors.toMap(k -> k.getColumn().getName(), AssignValue::getValue)); assertThat(update.getAssignments()).hasSize(3); - assertThat(assignments).hasSize(3).containsEntry("c1", SQL.bindMarker("$1")).containsEntry("c2", - SQL.bindMarker("$2")); + assertThat(assignments).hasSize(3).containsEntry(SqlIdentifier.unquoted("c1"), SQL.bindMarker("$1")) + .containsEntry(SqlIdentifier.unquoted("c2"), SQL.bindMarker("$2")); } private BoundAssignments map(Update update) { diff --git a/src/test/java/org/springframework/data/r2dbc/repository/support/R2dbcRepositoryFactoryUnitTests.java b/src/test/java/org/springframework/data/r2dbc/repository/support/R2dbcRepositoryFactoryUnitTests.java index b18cc5b6..dea67ea1 100644 --- a/src/test/java/org/springframework/data/r2dbc/repository/support/R2dbcRepositoryFactoryUnitTests.java +++ b/src/test/java/org/springframework/data/r2dbc/repository/support/R2dbcRepositoryFactoryUnitTests.java @@ -24,11 +24,12 @@ import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; -import org.springframework.data.mapping.context.MappingContext; +import org.springframework.data.annotation.Id; +import org.springframework.data.r2dbc.convert.MappingR2dbcConverter; import org.springframework.data.r2dbc.convert.R2dbcConverter; import org.springframework.data.r2dbc.core.DatabaseClient; import org.springframework.data.r2dbc.core.ReactiveDataAccessStrategy; -import org.springframework.data.relational.core.mapping.RelationalPersistentEntity; +import org.springframework.data.r2dbc.mapping.R2dbcMappingContext; import org.springframework.data.relational.repository.query.RelationalEntityInformation; import org.springframework.data.relational.repository.support.MappingRelationalEntityInformation; import org.springframework.data.repository.Repository; @@ -41,18 +42,15 @@ @RunWith(MockitoJUnitRunner.class) public class R2dbcRepositoryFactoryUnitTests { + R2dbcConverter r2dbcConverter = new MappingR2dbcConverter(new R2dbcMappingContext()); + @Mock DatabaseClient databaseClient; - @Mock R2dbcConverter r2dbcConverter; @Mock ReactiveDataAccessStrategy dataAccessStrategy; - @Mock @SuppressWarnings("rawtypes") MappingContext mappingContext; - @Mock @SuppressWarnings("rawtypes") RelationalPersistentEntity entity; @Before @SuppressWarnings("unchecked") public void before() { - when(mappingContext.getRequiredPersistentEntity(Person.class)).thenReturn(entity); when(dataAccessStrategy.getConverter()).thenReturn(r2dbcConverter); - when(r2dbcConverter.getMappingContext()).thenReturn(mappingContext); } @Test @@ -75,5 +73,7 @@ public void createsRepositoryWithIdTypeLong() { interface MyPersonRepository extends Repository {} - static class Person {} + static class Person { + @Id long id; + } } diff --git a/src/test/java/org/springframework/data/r2dbc/testing/StatementRecorder.java b/src/test/java/org/springframework/data/r2dbc/testing/StatementRecorder.java new file mode 100644 index 00000000..d8dbd088 --- /dev/null +++ b/src/test/java/org/springframework/data/r2dbc/testing/StatementRecorder.java @@ -0,0 +1,312 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.r2dbc.testing; + +import io.r2dbc.spi.Batch; +import io.r2dbc.spi.Connection; +import io.r2dbc.spi.ConnectionFactory; +import io.r2dbc.spi.ConnectionFactoryMetadata; +import io.r2dbc.spi.ConnectionMetadata; +import io.r2dbc.spi.IsolationLevel; +import io.r2dbc.spi.Result; +import io.r2dbc.spi.Statement; +import io.r2dbc.spi.ValidationDepth; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.function.Consumer; +import java.util.function.Predicate; +import java.util.function.Supplier; +import java.util.regex.Pattern; + +import org.reactivestreams.Publisher; + +import org.springframework.data.r2dbc.mapping.SettableValue; + +/** + * Recorder utility for R2DBC {@link Statement}s. Allows stubbing and introspection. + * + * @author Mark Paluch + */ +public class StatementRecorder implements ConnectionFactory { + + private final Map, Supplier>> stubbings = new LinkedHashMap<>(); + private final List createdStatements = new ArrayList<>(); + private final List executedStatements = new ArrayList<>(); + + private StatementRecorder() {} + + /** + * Create a new {@link StatementRecorder}. + * + * @return + */ + public static StatementRecorder newInstance() { + return new StatementRecorder(); + } + + /** + * Create a new {@link StatementRecorder} accepting a {@link Consumer configurer}. + * + * @param configurer + * @return + */ + public static StatementRecorder newInstance(Consumer configurer) { + + StatementRecorder statementRecorder = new StatementRecorder(); + + configurer.accept(statementRecorder); + + return statementRecorder; + } + + /** + * Add a stubbing rule given the {@link Predicate SQL Predicate} and a {@link Result} that is emitted by the executed + * statement. Typical usage: + * + *

+	 * recorder.addStubbing(sql -> sql.startsWith("SELECT"), result);
+	 * 
+ * + * @param sqlPredicate + * @param result + */ + public void addStubbing(Predicate sqlPredicate, Result result) { + this.stubbings.put(sqlPredicate, () -> Collections.singletonList(result)); + } + + /** + * Add a stubbing rule given the {@link Predicate SQL Predicate} and a list of {@link Result results} that are emitted + * by the executed statement. Typical usage: + * + *
+	 * recorder.addStubbing(sql -> sql.startsWith("SELECT"), results);
+	 * 
+ * + * @param sqlPredicate + * @param result + */ + public void addStubbing(Predicate sqlPredicate, List results) { + this.stubbings.put(sqlPredicate, () -> results); + } + + /** + * Retrieve a statement by {@code sql}. + * + * @param sql + * @return + */ + public RecordedStatement getCreatedStatement(String sql) { + return getCreatedStatement(it -> compareSql(sql, it)); + } + + private static boolean compareSql(String pattern, String actual) { + return actual.equals(pattern) || Pattern.compile(pattern).matcher(actual).find(); + } + + /** + * Retrieve a statement by a {@link Predicate SQL predicate}. + * + * @param sql + * @return + */ + public RecordedStatement getCreatedStatement(Predicate predicate) { + + return createdStatements.stream().filter(recordedStatement -> { + return predicate.test(recordedStatement.getSql()) || predicate.test(recordedStatement.getSql().toLowerCase()) + || predicate.test(recordedStatement.getSql().toUpperCase()); + }).findFirst().orElseThrow(() -> new NoSuchElementException("No statement found")); + } + + public List getCreatedStatements() { + return createdStatements; + } + + public List getExecutedStatements() { + return executedStatements; + } + + @Override + public Publisher create() { + return Mono.just(new RecorderConnection()); + } + + @Override + public ConnectionFactoryMetadata getMetadata() { + return () -> "StatementRecorder"; + } + + class RecorderConnection implements Connection { + @Override + public Publisher beginTransaction() { + return createStatement("BEGIN").execute().then(); + } + + @Override + public Publisher close() { + return createStatement("CLOSE").execute().then(); + } + + @Override + public Publisher commitTransaction() { + return createStatement("COMMIT").execute().then(); + } + + @Override + public Batch createBatch() { + throw new UnsupportedOperationException("createBatch not yet supported"); + } + + @Override + public Publisher createSavepoint(String name) { + return createStatement("CREATE SAVEPOINT " + name).execute().then(); + } + + @Override + public RecordedStatement createStatement(String sql) { + + RecordedStatement statement = doCreateStatement(sql); + + createdStatements.add(statement); + + return statement; + } + + private RecordedStatement doCreateStatement(String sql) { + for (Map.Entry, Supplier>> entry : stubbings.entrySet()) { + + if (entry.getKey().test(sql) || entry.getKey().test(sql.toLowerCase()) + || entry.getKey().test(sql.toUpperCase())) { + return new RecordedStatement(sql, entry.getValue().get()); + } + + } + return new RecordedStatement(sql, Collections.emptyList()); + } + + @Override + public boolean isAutoCommit() { + throw new UnsupportedOperationException("isAutoCommit not yet supported"); + } + + @Override + public ConnectionMetadata getMetadata() { + throw new UnsupportedOperationException("getMetadata not yet supported"); + } + + @Override + public IsolationLevel getTransactionIsolationLevel() { + throw new UnsupportedOperationException("getTransactionIsolationLevel not yet supported"); + } + + @Override + public Publisher releaseSavepoint(String name) { + return createStatement("RELEASE SAVEPOINT " + name).execute().then(); + } + + @Override + public Publisher rollbackTransaction() { + return createStatement("ROLLBACK").execute().then(); + } + + @Override + public Publisher rollbackTransactionToSavepoint(String name) { + return createStatement("ROLLBACK TO " + name).execute().then(); + } + + @Override + public Publisher setAutoCommit(boolean autoCommit) { + return createStatement("SET AUTOCOMMIT " + autoCommit).execute().then(); + } + + @Override + public Publisher setTransactionIsolationLevel(IsolationLevel isolationLevel) { + return createStatement("SET TRANSACTION ISOLATION LEVEL " + isolationLevel.asSql()).execute().then(); + } + + @Override + public Publisher validate(ValidationDepth depth) { + return Mono.just(true); + } + } + + public class RecordedStatement implements Statement { + + private final String sql; + + private final List results; + + private final Map bindings = new LinkedHashMap<>(); + + public RecordedStatement(String sql, Result result) { + this(sql, Collections.singletonList(result)); + } + + public RecordedStatement(String sql, List results) { + this.sql = sql; + this.results = results; + } + + public Map getBindings() { + return bindings; + } + + public String getSql() { + return sql; + } + + @Override + public Statement add() { + return this; + } + + @Override + public Statement bind(int index, Object o) { + this.bindings.put(index, SettableValue.from(o)); + return this; + } + + @Override + public Statement bind(String identifier, Object o) { + this.bindings.put(identifier, SettableValue.from(o)); + return this; + } + + @Override + public Statement bindNull(int index, Class type) { + this.bindings.put(index, SettableValue.empty(type)); + return this; + } + + @Override + public Statement bindNull(String identifier, Class type) { + this.bindings.put(identifier, SettableValue.empty(type)); + return this; + } + + @Override + public Flux execute() { + return Flux.fromIterable(results).doOnSubscribe(subscription -> executedStatements.add(this)); + } + } + +}