Skip to content

Commit 4a8888e

Browse files
committed
DATAJDBC-101 - Polishing.
Refactored SQL generation. Adapted some assertions to the "Spring Data JDBC style". Minor formatting. Original pull request: #188.
1 parent 7f578ed commit 4a8888e

File tree

6 files changed

+45
-36
lines changed

6 files changed

+45
-36
lines changed

spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/JdbcAggregateOperations.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ public interface JdbcAggregateOperations {
140140
* @param <T> the type of the aggregate roots. Must not be {@code null}.
141141
* @param sort the sorting information. Must not be {@code null}.
142142
* @return Guaranteed to be not {@code null}.
143+
* @since 2.0
143144
*/
144145
<T> Iterable<T> findAll(Class<T> domainType, Sort sort);
145146

@@ -150,6 +151,7 @@ public interface JdbcAggregateOperations {
150151
* @param <T> the type of the aggregate roots. Must not be {@code null}.
151152
* @param pageable the pagination information. Must not be {@code null}.
152153
* @return Guaranteed to be not {@code null}.
154+
* @since 2.0
153155
*/
154156
<T> Page<T> findAll(Class<T> domainType, Pageable pageable);
155157
}

spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/DataAccessStrategy.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,7 @@ default Iterable<Object> findAllByPath(Identifier identifier,
226226
* @param <T> the type of entities to load.
227227
* @param sort the sorting information. Must not be {@code null}.
228228
* @return Guaranteed to be not {@code null}.
229+
* @since 2.0
229230
*/
230231
<T> Iterable<T> findAll(Class<T> domainType, Sort sort);
231232

@@ -236,6 +237,7 @@ default Iterable<Object> findAllByPath(Identifier identifier,
236237
* @param <T> the type of entities to load.
237238
* @param pageable the pagination information. Must not be {@code null}.
238239
* @return Guaranteed to be not {@code null}.
240+
* @since 2.0
239241
*/
240242
<T> Iterable<T> findAll(Class<T> domainType, Pageable pageable);
241243
}

spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SqlGenerator.java

Lines changed: 21 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,7 @@
1717

1818
import lombok.Value;
1919

20-
import java.util.ArrayList;
21-
import java.util.Collection;
22-
import java.util.Collections;
23-
import java.util.Comparator;
24-
import java.util.HashSet;
25-
import java.util.LinkedHashSet;
26-
import java.util.List;
27-
import java.util.Map;
28-
import java.util.Set;
29-
import java.util.TreeSet;
20+
import java.util.*;
3021
import java.util.function.Function;
3122
import java.util.regex.Pattern;
3223
import java.util.stream.Collectors;
@@ -420,24 +411,29 @@ private SelectBuilder.SelectWhere selectBuilder(Collection<String> keyColumns) {
420411
}
421412

422413
private SelectBuilder.SelectOrdered selectBuilder(Collection<String> keyColumns, Sort sort, Pageable pageable) {
423-
SelectBuilder.SelectWhere baseSelect = this.selectBuilder(keyColumns);
424414

425-
if (baseSelect instanceof SelectBuilder.SelectFromAndJoin) {
426-
if (pageable.isPaged()) {
427-
return ((SelectBuilder.SelectFromAndJoin) baseSelect).limitOffset(pageable.getPageSize(), pageable.getOffset())
428-
.orderBy(extractOrderByFields(sort));
429-
}
430-
return ((SelectBuilder.SelectFromAndJoin) baseSelect).orderBy(extractOrderByFields(sort));
415+
SelectBuilder.SelectOrdered sortable = this.selectBuilder(keyColumns);
416+
sortable = applyPagination(pageable, sortable);
417+
return sortable.orderBy(extractOrderByFields(sort));
431418

432-
} else if (baseSelect instanceof SelectBuilder.SelectFromAndJoinCondition) {
433-
if (pageable.isPaged()) {
434-
return ((SelectBuilder.SelectFromAndJoinCondition) baseSelect)
435-
.limitOffset(pageable.getPageSize(), pageable.getOffset()).orderBy(extractOrderByFields(sort));
436-
}
437-
return baseSelect.orderBy(extractOrderByFields(sort));
438-
} else {
439-
throw new RuntimeException("Unexpected type found!");
419+
}
420+
421+
private SelectBuilder.SelectOrdered applyPagination(Pageable pageable, SelectBuilder.SelectOrdered select) {
422+
423+
if (!pageable.isPaged()) {
424+
return select;
440425
}
426+
427+
Assert.isTrue(select instanceof SelectBuilder.SelectLimitOffset,
428+
() -> String.format("Can't apply limit clause to statement of type %s", select.getClass()));
429+
430+
SelectBuilder.SelectLimitOffset limitable = (SelectBuilder.SelectLimitOffset) select;
431+
SelectBuilder.SelectLimitOffset limitResult = limitable.limitOffset(pageable.getPageSize(), pageable.getOffset());
432+
433+
Assert.state(limitResult instanceof SelectBuilder.SelectOrdered,
434+
String.format("The result of applying the limit-clause must be of type SelectOrdered in order to apply the order-by-clause but is of type %s.", select.getClass()));
435+
436+
return (SelectBuilder.SelectOrdered) limitResult;
441437
}
442438

443439
/**

spring-data-jdbc/src/main/java/org/springframework/data/jdbc/mybatis/MyBatisDataAccessStrategy.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,7 @@ public <T> boolean existsById(Object id, Class<T> domainType) {
347347
*/
348348
@Override
349349
public <T> Iterable<T> findAll(Class<T> domainType, Sort sort) {
350+
350351
Map<String, Object> additionalContext = new HashMap<>();
351352
additionalContext.put("sort", sort);
352353
return sqlSession().selectList(namespace(domainType) + ".findAllSorted",
@@ -359,6 +360,7 @@ public <T> Iterable<T> findAll(Class<T> domainType, Sort sort) {
359360
*/
360361
@Override
361362
public <T> Iterable<T> findAll(Class<T> domainType, Pageable pageable) {
363+
362364
Map<String, Object> additionalContext = new HashMap<>();
363365
additionalContext.put("pageable", pageable);
364366
return sqlSession().selectList(namespace(domainType) + ".findAllPaged",

spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/support/SimpleJdbcRepository.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
*/
4141
@RequiredArgsConstructor
4242
@Transactional(readOnly = true)
43-
public class SimpleJdbcRepository<T, ID> implements CrudRepository<T, ID>, PagingAndSortingRepository<T, ID> {
43+
public class SimpleJdbcRepository<T, ID> implements PagingAndSortingRepository<T, ID> {
4444

4545
private final @NonNull JdbcAggregateOperations entityOperations;
4646
private final @NonNull PersistentEntity<T, ?> entity;

spring-data-jdbc/src/test/java/org/springframework/data/jdbc/core/JdbcAggregateTemplateIntegrationTests.java

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@
3939
import org.junit.ClassRule;
4040
import org.junit.Rule;
4141
import org.junit.Test;
42-
4342
import org.springframework.beans.factory.annotation.Autowired;
4443
import org.springframework.context.ApplicationEventPublisher;
4544
import org.springframework.context.annotation.Bean;
@@ -225,41 +224,51 @@ public void saveAndLoadManyEntitiesWithReferencedEntity() {
225224

226225
Iterable<LegoSet> reloadedLegoSets = template.findAll(LegoSet.class);
227226

228-
assertThat(reloadedLegoSets).hasSize(1).extracting("id", "manual.id", "manual.content")
229-
.contains(tuple(legoSet.getId(), legoSet.getManual().getId(), legoSet.getManual().getContent()));
227+
assertThat(reloadedLegoSets) //
228+
.extracting("id", "manual.id", "manual.content") //
229+
.containsExactly(tuple(legoSet.getId(), legoSet.getManual().getId(), legoSet.getManual().getContent()));
230230
}
231231

232232
@Test // DATAJDBC-101
233233
public void saveAndLoadManyEntitiesWithReferencedEntitySorted() {
234+
234235
template.save(createLegoSet("Lava"));
235236
template.save(createLegoSet("Star"));
236237
template.save(createLegoSet("Frozen"));
237238

238239
Iterable<LegoSet> reloadedLegoSets = template.findAll(LegoSet.class, Sort.by("name"));
239240

240-
assertThat(reloadedLegoSets).hasSize(3).extracting("name").isEqualTo(Arrays.asList("Frozen", "Lava", "Star"));
241+
assertThat(reloadedLegoSets) //
242+
.extracting("name") //
243+
.containsExactly("Frozen", "Lava", "Star");
241244
}
242245

243246
@Test // DATAJDBC-101
244247
public void saveAndLoadManyEntitiesWithReferencedEntityPaged() {
248+
245249
template.save(createLegoSet("Lava"));
246250
template.save(createLegoSet("Star"));
247251
template.save(createLegoSet("Frozen"));
248252

249253
Iterable<LegoSet> reloadedLegoSets = template.findAll(LegoSet.class, PageRequest.of(1, 1));
250254

251-
assertThat(reloadedLegoSets).hasSize(1).extracting("name").isEqualTo(singletonList("Star"));
255+
assertThat(reloadedLegoSets) //
256+
.extracting("name") //
257+
.containsExactly("Star");
252258
}
253259

254260
@Test // DATAJDBC-101
255261
public void saveAndLoadManyEntitiesWithReferencedEntitySortedAndPaged() {
262+
256263
template.save(createLegoSet("Lava"));
257264
template.save(createLegoSet("Star"));
258265
template.save(createLegoSet("Frozen"));
259266

260267
Iterable<LegoSet> reloadedLegoSets = template.findAll(LegoSet.class, PageRequest.of(1, 2, Sort.by("name")));
261268

262-
assertThat(reloadedLegoSets).hasSize(1).extracting("name").isEqualTo(singletonList("Star"));
269+
assertThat(reloadedLegoSets) //
270+
.extracting("name") //
271+
.containsExactly("Star");
263272
}
264273

265274
@Test // DATAJDBC-112
@@ -749,12 +758,10 @@ public void saveAndUpdateAggregateWithImmutableVersion() {
749758
AggregateWithImmutableVersion savedAgain = template.save(reloadedAggregate);
750759
AggregateWithImmutableVersion reloadedAgain = template.findById(id, aggregate.getClass());
751760

752-
assertThat(savedAgain.version)
753-
.describedAs("The object returned by save should have an increased version")
761+
assertThat(savedAgain.version).describedAs("The object returned by save should have an increased version")
754762
.isEqualTo(2L);
755763

756-
assertThat(reloadedAgain.getVersion())
757-
.describedAs("version field should increment by one with each save")
764+
assertThat(reloadedAgain.getVersion()).describedAs("version field should increment by one with each save")
758765
.isEqualTo(2L);
759766

760767
assertThatThrownBy(() -> template.save(new AggregateWithImmutableVersion(id, 1L)))

0 commit comments

Comments
 (0)