diff --git a/src/main/java/com/rabbitmq/stream/Environment.java b/src/main/java/com/rabbitmq/stream/Environment.java
index 00baf4099d..c2a669ea72 100644
--- a/src/main/java/com/rabbitmq/stream/Environment.java
+++ b/src/main/java/com/rabbitmq/stream/Environment.java
@@ -83,6 +83,22 @@ static EnvironmentBuilder builder() {
    */
   StreamStats queryStreamStats(String stream);
 
+  /**
+   * Store the offset for a given reference on the given stream.
+   *
+   * 
This method is useful to store a given offset before a consumer is created.
+   *
+   * 
Prefer the {@link Consumer#store(long)} or {@link MessageHandler.Context#storeOffset()}
+   * methods to store offsets while consuming messages.
+   *
+   * @see Consumer#store(long)
+   * @see MessageHandler.Context#storeOffset()
+   * @param reference the reference to store the offset for, e.g. a consumer name
+   * @param stream the stream
+   * @param offset the offset to store
+   */
+  void storeOffset(String reference, String stream, long offset);
+
   /**
    * Return whether a stream exists or not.
    *
diff --git a/src/main/java/com/rabbitmq/stream/impl/OffsetTrackingUtils.java b/src/main/java/com/rabbitmq/stream/impl/OffsetTrackingUtils.java
new file mode 100644
index 0000000000..74aad57628
--- /dev/null
+++ b/src/main/java/com/rabbitmq/stream/impl/OffsetTrackingUtils.java
@@ -0,0 +1,125 @@
+// Copyright (c) 2025 Broadcom. All Rights Reserved.
+// The term "Broadcom" refers to Broadcom Inc. and/or its subsidiaries.
+//
+// This software, the RabbitMQ Stream Java client library, is dual-licensed under the
+// Mozilla Public License 2.0 ("MPL"), and the Apache License version 2 ("ASL").
+// For the MPL, please see LICENSE-MPL-RabbitMQ. For the ASL,
+// please see LICENSE-APACHE2.
+//
+// This software is distributed on an "AS IS" basis, WITHOUT WARRANTY OF ANY KIND,
+// either express or implied. See the LICENSE file for specific language governing
+// rights and limitations of this software.
+//
+// If you have any questions regarding licensing, please contact us at
+// info@rabbitmq.com.
+package com.rabbitmq.stream.impl;
+
+import static com.rabbitmq.stream.BackOffDelayPolicy.fixedWithInitialDelay;
+import static com.rabbitmq.stream.impl.AsyncRetry.asyncRetry;
+import static java.lang.String.format;
+import static java.time.Duration.ofMillis;
+
+import com.rabbitmq.stream.Constants;
+import com.rabbitmq.stream.NoOffsetException;
+import com.rabbitmq.stream.StreamException;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+import java.util.function.LongSupplier;
+import java.util.function.Supplier;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+class OffsetTrackingUtils {
+
+  private static final Logger LOGGER = LoggerFactory.getLogger(OffsetTrackingUtils.class);
+
+  private OffsetTrackingUtils() {}
+
+  static long storedOffset(Supplier clientSupplier, String name, String stream) {
+    // the client can be null, so we catch any exception
+    Client.QueryOffsetResponse response;
+    try {
+      response = clientSupplier.get().queryOffset(name, stream);
+    } catch (Exception e) {
+      throw new IllegalStateException(
+          format(
+              "Not possible to query offset for name %s on stream %s for now: %s",
+              name, stream, e.getMessage()),
+          e);
+    }
+    if (response.isOk()) {
+      return response.getOffset();
+    } else if (response.getResponseCode() == Constants.RESPONSE_CODE_NO_OFFSET) {
+      throw new NoOffsetException(
+          format(
+              "No offset stored for name %s on stream %s (%s)",
+              name, stream, Utils.formatConstant(response.getResponseCode())));
+    } else {
+      throw new StreamException(
+          format(
+              "QueryOffset for name %s on stream %s returned an error (%s)",
+              name, stream, Utils.formatConstant(response.getResponseCode())),
+          response.getResponseCode());
+    }
+  }
+
+  static void waitForOffsetToBeStored(
+      String caller,
+      ScheduledExecutorService scheduledExecutorService,
+      LongSupplier offsetSupplier,
+      String name,
+      String stream,
+      long expectedStoredOffset) {
+    String reference = String.format("{stream=%s/name=%s}", stream, name);
+    CompletableFuture storedTask =
+        asyncRetry(
+                () -> {
+                  try {
+                    long lastStoredOffset = offsetSupplier.getAsLong();
+                    boolean stored = lastStoredOffset == expectedStoredOffset;
+                    LOGGER.debug(
+                        "Last stored offset from {} on {} is {}, expecting {}",
+                        caller,
+                        reference,
+                        lastStoredOffset,
+                        expectedStoredOffset);
+                    if (!stored) {
+                      throw new IllegalStateException();
+                    } else {
+                      return true;
+                    }
+                  } catch (StreamException e) {
+                    if (e.getCode() == Constants.RESPONSE_CODE_NO_OFFSET) {
+                      LOGGER.debug(
+                          "No stored offset for {} on {}, expecting {}",
+                          caller,
+                          reference,
+                          expectedStoredOffset);
+                      throw new IllegalStateException();
+                    } else {
+                      throw e;
+                    }
+                  }
+                })
+            .description(
+                "Last stored offset for %s on %s must be %d",
+                caller, reference, expectedStoredOffset)
+            .delayPolicy(fixedWithInitialDelay(ofMillis(200), ofMillis(200)))
+            .retry(exception -> exception instanceof IllegalStateException)
+            .scheduler(scheduledExecutorService)
+            .build();
+
+    try {
+      storedTask.get(10, TimeUnit.SECONDS);
+      LOGGER.debug("Offset {} stored ({}, {})", expectedStoredOffset, caller, reference);
+    } catch (InterruptedException e) {
+      Thread.currentThread().interrupt();
+    } catch (ExecutionException | TimeoutException e) {
+      LOGGER.warn("Error while checking offset has been stored", e);
+      storedTask.cancel(true);
+    }
+  }
+}
diff --git a/src/main/java/com/rabbitmq/stream/impl/StreamConsumer.java b/src/main/java/com/rabbitmq/stream/impl/StreamConsumer.java
index 591b6db2e2..11c57e4517 100644
--- a/src/main/java/com/rabbitmq/stream/impl/StreamConsumer.java
+++ b/src/main/java/com/rabbitmq/stream/impl/StreamConsumer.java
@@ -18,11 +18,9 @@
 import static com.rabbitmq.stream.impl.AsyncRetry.asyncRetry;
 import static com.rabbitmq.stream.impl.Utils.offsetBefore;
 import static java.lang.String.format;
-import static java.time.Duration.ofMillis;
 
 import com.rabbitmq.stream.*;
 import com.rabbitmq.stream.MessageHandler.Context;
-import com.rabbitmq.stream.impl.Client.QueryOffsetResponse;
 import com.rabbitmq.stream.impl.StreamConsumerBuilder.TrackingConfiguration;
 import com.rabbitmq.stream.impl.StreamEnvironment.LocatorNotAvailableException;
 import com.rabbitmq.stream.impl.StreamEnvironment.TrackingConsumerRegistration;
@@ -329,53 +327,13 @@ static long getStoredOffsetSafely(StreamConsumer consumer, StreamEnvironment env
   }
 
   void waitForOffsetToBeStored(long expectedStoredOffset) {
-    CompletableFuture storedTask =
-        asyncRetry(
-                () -> {
-                  try {
-                    long lastStoredOffset = storedOffset();
-                    boolean stored = lastStoredOffset == expectedStoredOffset;
-                    LOGGER.debug(
-                        "Last stored offset from consumer {} on {} is {}, expecting {}",
-                        this.id,
-                        this.stream,
-                        lastStoredOffset,
-                        expectedStoredOffset);
-                    if (!stored) {
-                      throw new IllegalStateException();
-                    } else {
-                      return true;
-                    }
-                  } catch (StreamException e) {
-                    if (e.getCode() == Constants.RESPONSE_CODE_NO_OFFSET) {
-                      LOGGER.debug(
-                          "No stored offset for consumer {} on {}, expecting {}",
-                          this.id,
-                          this.stream,
-                          expectedStoredOffset);
-                      throw new IllegalStateException();
-                    } else {
-                      throw e;
-                    }
-                  }
-                })
-            .description(
-                "Last stored offset for consumer %s on stream %s must be %d",
-                this.name, this.stream, expectedStoredOffset)
-            .delayPolicy(fixedWithInitialDelay(ofMillis(200), ofMillis(200)))
-            .retry(exception -> exception instanceof IllegalStateException)
-            .scheduler(environment.scheduledExecutorService())
-            .build();
-
-    try {
-      storedTask.get(10, TimeUnit.SECONDS);
-      LOGGER.debug(
-          "Offset {} stored (consumer {}, stream {})", expectedStoredOffset, this.id, this.stream);
-    } catch (InterruptedException e) {
-      Thread.currentThread().interrupt();
-    } catch (ExecutionException | TimeoutException e) {
-      LOGGER.warn("Error while checking offset has been stored", e);
-    }
+    OffsetTrackingUtils.waitForOffsetToBeStored(
+        "consumer " + this.id,
+        this.environment.scheduledExecutorService(),
+        this::storedOffset,
+        this.name,
+        this.stream,
+        expectedStoredOffset);
   }
 
   void start() {
@@ -563,31 +521,7 @@ void running() {
   long storedOffset(Supplier clientSupplier) {
     checkNotClosed();
     if (canTrack()) {
-      // the client can be null by now, so we catch any exception
-      QueryOffsetResponse response;
-      try {
-        response = clientSupplier.get().queryOffset(this.name, this.stream);
-      } catch (Exception e) {
-        throw new IllegalStateException(
-            format(
-                "Not possible to query offset for consumer %s on stream %s for now: %s",
-                this.name, this.stream, e.getMessage()),
-            e);
-      }
-      if (response.isOk()) {
-        return response.getOffset();
-      } else if (response.getResponseCode() == Constants.RESPONSE_CODE_NO_OFFSET) {
-        throw new NoOffsetException(
-            format(
-                "No offset stored for consumer %s on stream %s (%s)",
-                this.name, this.stream, Utils.formatConstant(response.getResponseCode())));
-      } else {
-        throw new StreamException(
-            format(
-                "QueryOffset for consumer %s on stream %s returned an error (%s)",
-                this.name, this.stream, Utils.formatConstant(response.getResponseCode())),
-            response.getResponseCode());
-      }
+      return OffsetTrackingUtils.storedOffset(clientSupplier, this.name, this.stream);
     } else if (this.name == null) {
       throw new UnsupportedOperationException(
           "Not possible to query stored offset for a consumer without a name");
diff --git a/src/main/java/com/rabbitmq/stream/impl/StreamEnvironment.java b/src/main/java/com/rabbitmq/stream/impl/StreamEnvironment.java
index 8bf503beab..d3ec44d8d0 100644
--- a/src/main/java/com/rabbitmq/stream/impl/StreamEnvironment.java
+++ b/src/main/java/com/rabbitmq/stream/impl/StreamEnvironment.java
@@ -556,6 +556,29 @@ public StreamStats queryStreamStats(String stream) {
     }
   }
 
+  @Override
+  public void storeOffset(String reference, String stream, long offset) {
+    checkNotClosed();
+    this.maybeInitializeLocator();
+    locatorOperation(
+        Utils.namedFunction(
+            l -> {
+              l.storeOffset(reference, stream, offset);
+              return null;
+            },
+            "Store offset %d for stream '%s' with reference '%s'",
+            offset,
+            stream,
+            reference));
+    OffsetTrackingUtils.waitForOffsetToBeStored(
+        "env-store-offset",
+        this.scheduledExecutorService,
+        () -> OffsetTrackingUtils.storedOffset(() -> locator().client(), reference, stream),
+        reference,
+        stream,
+        offset);
+  }
+
   @Override
   public boolean streamExists(String stream) {
     checkNotClosed();
diff --git a/src/test/java/com/rabbitmq/stream/impl/Assertions.java b/src/test/java/com/rabbitmq/stream/impl/Assertions.java
index 8a3b0e966f..10b7093774 100644
--- a/src/test/java/com/rabbitmq/stream/impl/Assertions.java
+++ b/src/test/java/com/rabbitmq/stream/impl/Assertions.java
@@ -16,6 +16,7 @@
 
 import static org.assertj.core.api.Assertions.fail;
 
+import com.rabbitmq.stream.Constants;
 import java.time.Duration;
 import org.assertj.core.api.AbstractObjectAssert;
 
@@ -23,10 +24,50 @@ final class Assertions {
 
   private Assertions() {}
 
+  static ResponseAssert assertThat(Client.Response response) {
+    return new ResponseAssert(response);
+  }
+
   static SyncAssert assertThat(TestUtils.Sync sync) {
     return new SyncAssert(sync);
   }
 
+  static class ResponseAssert extends AbstractObjectAssert {
+
+    public ResponseAssert(Client.Response response) {
+      super(response, ResponseAssert.class);
+    }
+
+    ResponseAssert isOk() {
+      if (!actual.isOk()) {
+        fail(
+            "Response should be successful but was not, response code is: %s",
+            Utils.formatConstant(actual.getResponseCode()));
+      }
+      return this;
+    }
+
+    ResponseAssert isNotOk() {
+      if (actual.isOk()) {
+        fail("Response should not be successful but was, response code is: %s", actual);
+      }
+      return this;
+    }
+
+    ResponseAssert hasCode(short responseCode) {
+      if (actual.getResponseCode() != responseCode) {
+        fail(
+            "Response code should be %s but was %s",
+            Utils.formatConstant(responseCode), Utils.formatConstant(actual.getResponseCode()));
+      }
+      return this;
+    }
+
+    ResponseAssert hasCodeNoOffset() {
+      return hasCode(Constants.RESPONSE_CODE_NO_OFFSET);
+    }
+  }
+
   static class SyncAssert extends AbstractObjectAssert {
 
     private SyncAssert(TestUtils.Sync sync) {
diff --git a/src/test/java/com/rabbitmq/stream/impl/StreamConsumerTest.java b/src/test/java/com/rabbitmq/stream/impl/StreamConsumerTest.java
index 6871823f61..a51b2d179c 100644
--- a/src/test/java/com/rabbitmq/stream/impl/StreamConsumerTest.java
+++ b/src/test/java/com/rabbitmq/stream/impl/StreamConsumerTest.java
@@ -15,11 +15,12 @@
 package com.rabbitmq.stream.impl;
 
 import static com.rabbitmq.stream.ConsumerFlowStrategy.creditWhenHalfMessagesProcessed;
+import static com.rabbitmq.stream.impl.Assertions.assertThat;
 import static com.rabbitmq.stream.impl.TestUtils.*;
 import static com.rabbitmq.stream.impl.TestUtils.CountDownLatchConditions.completed;
 import static java.lang.String.format;
 import static java.util.Collections.synchronizedList;
-import static org.assertj.core.api.Assertions.*;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
 
 import com.rabbitmq.stream.*;
 import com.rabbitmq.stream.impl.Client.QueryOffsetResponse;
@@ -40,6 +41,7 @@
 import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.IntConsumer;
 import java.util.function.IntFunction;
+import java.util.function.Supplier;
 import java.util.function.UnaryOperator;
 import java.util.stream.IntStream;
 import java.util.stream.Stream;
@@ -148,8 +150,8 @@ void committedOffsetShouldBeSet() throws Exception {
                 })
             .build();
 
-    assertThat(consumeLatch.await(10, TimeUnit.SECONDS)).isTrue();
-    assertThat(committedOffset.get()).isNotZero();
+    org.assertj.core.api.Assertions.assertThat(consumeLatch.await(10, TimeUnit.SECONDS)).isTrue();
+    org.assertj.core.api.Assertions.assertThat(committedOffset.get()).isNotZero();
 
     consumer.close();
   }
@@ -172,7 +174,7 @@ void consume() throws Exception {
                     Collections.singletonList(
                         client.messageBuilder().addData("".getBytes()).build())));
 
-    assertThat(publishLatch.await(10, TimeUnit.SECONDS)).isTrue();
+    org.assertj.core.api.Assertions.assertThat(publishLatch.await(10, TimeUnit.SECONDS)).isTrue();
 
     CountDownLatch consumeLatch = new CountDownLatch(messageCount);
 
@@ -187,8 +189,8 @@ void consume() throws Exception {
                 })
             .build();
 
-    assertThat(consumeLatch.await(10, TimeUnit.SECONDS)).isTrue();
-    assertThat(chunkTimestamp.get()).isNotZero();
+    org.assertj.core.api.Assertions.assertThat(consumeLatch.await(10, TimeUnit.SECONDS)).isTrue();
+    org.assertj.core.api.Assertions.assertThat(chunkTimestamp.get()).isNotZero();
 
     consumer.close();
   }
@@ -227,7 +229,7 @@ void consumeWithAsyncConsumerFlowControl() throws Exception {
     waitAtMost(() -> receivedMessageCount.get() >= processingLimit);
     waitUntilStable(receivedMessageCount::get);
 
-    assertThat(receivedMessageCount)
+    org.assertj.core.api.Assertions.assertThat(receivedMessageCount)
         .hasValueGreaterThanOrEqualTo(processingLimit)
         .hasValueLessThan(messageCount);
 
@@ -258,7 +260,7 @@ void asynchronousProcessingWithFlowControl() {
                         ctx.processed();
                       }))
           .build();
-      assertThat(latch).is(completed());
+      org.assertj.core.api.Assertions.assertThat(latch).is(completed());
     } finally {
       executorService.shutdownNow();
     }
@@ -282,7 +284,7 @@ void closeOnCondition() throws Exception {
                     Collections.singletonList(
                         client.messageBuilder().addData("".getBytes()).build())));
 
-    assertThat(publishLatch.await(10, TimeUnit.SECONDS)).isTrue();
+    org.assertj.core.api.Assertions.assertThat(publishLatch.await(10, TimeUnit.SECONDS)).isTrue();
 
     int messagesToProcess = 20_000;
 
@@ -304,9 +306,9 @@ void closeOnCondition() throws Exception {
                 })
             .build();
 
-    assertThat(consumeLatch.await(10, TimeUnit.SECONDS)).isTrue();
+    org.assertj.core.api.Assertions.assertThat(consumeLatch.await(10, TimeUnit.SECONDS)).isTrue();
     consumer.close();
-    assertThat(processedMessages).hasValue(messagesToProcess);
+    org.assertj.core.api.Assertions.assertThat(processedMessages).hasValue(messagesToProcess);
   }
 
   @Test
@@ -339,7 +341,7 @@ void consumerShouldBeClosedWhenStreamGetsDeleted(TestInfo info) throws Exception
                     producer.messageBuilder().addData("".getBytes()).build(),
                     confirmationStatus -> publishLatch.countDown()));
 
-    assertThat(publishLatch.await(10, TimeUnit.SECONDS)).isTrue();
+    org.assertj.core.api.Assertions.assertThat(publishLatch.await(10, TimeUnit.SECONDS)).isTrue();
 
     CountDownLatch consumeLatch = new CountDownLatch(messageCount);
     StreamConsumer consumer =
@@ -349,14 +351,14 @@ void consumerShouldBeClosedWhenStreamGetsDeleted(TestInfo info) throws Exception
                 .messageHandler((offset, message) -> consumeLatch.countDown())
                 .build();
 
-    assertThat(consumeLatch.await(10, TimeUnit.SECONDS)).isTrue();
+    org.assertj.core.api.Assertions.assertThat(consumeLatch.await(10, TimeUnit.SECONDS)).isTrue();
 
-    assertThat(consumer.isOpen()).isTrue();
+    org.assertj.core.api.Assertions.assertThat(consumer.isOpen()).isTrue();
 
     environment.deleteStream(s);
 
     TestUtils.waitAtMost(10, () -> !consumer.isOpen());
-    assertThat(consumer.isOpen()).isFalse();
+    org.assertj.core.api.Assertions.assertThat(consumer.isOpen()).isFalse();
   }
 
   @Test
@@ -396,7 +398,7 @@ void manualTrackingConsumerShouldRestartWhereItLeftOff() throws Exception {
 
     messageSending.accept(messageCountFirstWave);
 
-    assertThat(latchAssert(latchConfirmFirstWave)).completes();
+    org.assertj.core.api.Assertions.assertThat(latchAssert(latchConfirmFirstWave)).completes();
 
     int storeEvery = 100;
     AtomicInteger consumedMessageCount = new AtomicInteger();
@@ -425,22 +427,24 @@ void manualTrackingConsumerShouldRestartWhereItLeftOff() throws Exception {
             .build();
 
     ConsumerInfo consumerInfo = MonitoringTestUtils.extract(consumer);
-    assertThat(consumerInfo.getId()).isGreaterThanOrEqualTo(0);
-    assertThat(consumerInfo.getStream()).isEqualTo(stream);
-    assertThat(consumerInfo.getSubscriptionClient()).contains(" -> localhost:5552");
-    assertThat(consumerInfo.getTrackingClient()).contains(" -> localhost:5552");
+    org.assertj.core.api.Assertions.assertThat(consumerInfo.getId()).isGreaterThanOrEqualTo(0);
+    org.assertj.core.api.Assertions.assertThat(consumerInfo.getStream()).isEqualTo(stream);
+    org.assertj.core.api.Assertions.assertThat(consumerInfo.getSubscriptionClient())
+        .contains(" -> localhost:5552");
+    org.assertj.core.api.Assertions.assertThat(consumerInfo.getTrackingClient())
+        .contains(" -> localhost:5552");
 
     consumerReference.set(consumer);
 
     waitAtMost(10, () -> consumedMessageCount.get() == messageCountFirstWave);
 
-    assertThat(lastStoredOffset.get()).isPositive();
+    org.assertj.core.api.Assertions.assertThat(lastStoredOffset.get()).isPositive();
 
     consumer.close();
 
     messageSending.accept(messageCountSecondWave);
 
-    assertThat(latchAssert(latchConfirmSecondWave)).completes();
+    org.assertj.core.api.Assertions.assertThat(latchAssert(latchConfirmSecondWave)).completes();
 
     AtomicLong firstOffset = new AtomicLong(0);
     consumer =
@@ -465,7 +469,8 @@ void manualTrackingConsumerShouldRestartWhereItLeftOff() throws Exception {
 
     // there will be the tracking records after the first wave of messages,
     // messages offset won't be contiguous, so it's not an exact match
-    assertThat(firstOffset.get()).isGreaterThanOrEqualTo(lastStoredOffset.get());
+    org.assertj.core.api.Assertions.assertThat(firstOffset.get())
+        .isGreaterThanOrEqualTo(lastStoredOffset.get());
 
     consumer.close();
   }
@@ -508,7 +513,7 @@ void consumerShouldReUseInitialOffsetSpecificationAfterDisruptionIfNoMessagesRec
     Cli.killConnection("rabbitmq-stream-consumer-0");
 
     // no messages should have been received
-    assertThat(consumedCount.get()).isZero();
+    org.assertj.core.api.Assertions.assertThat(consumedCount.get()).isZero();
 
     // starting the second wave, it sends a message every 100 ms
     AtomicBoolean keepPublishing = new AtomicBoolean(true);
@@ -526,7 +531,7 @@ void consumerShouldReUseInitialOffsetSpecificationAfterDisruptionIfNoMessagesRec
     // the consumer should restart consuming with its initial offset spec, "next"
     try {
       latchAssert(consumeLatch).completes(recoveryInitialDelay.multipliedBy(2));
-      assertThat(bodies).hasSize(1).contains("second wave");
+      org.assertj.core.api.Assertions.assertThat(bodies).hasSize(1).contains("second wave");
     } finally {
       keepPublishing.set(false);
     }
@@ -551,7 +556,7 @@ void consumerShouldKeepConsumingAfterDisruption(
                       producer.messageBuilder().addData("".getBytes()).build(),
                       confirmationStatus -> publishLatch.countDown()));
 
-      assertThat(publishLatch.await(10, TimeUnit.SECONDS)).isTrue();
+      org.assertj.core.api.Assertions.assertThat(publishLatch.await(10, TimeUnit.SECONDS)).isTrue();
       producer.close();
 
       AtomicInteger receivedMessageCount = new AtomicInteger(0);
@@ -569,9 +574,9 @@ void consumerShouldKeepConsumingAfterDisruption(
                       })
                   .build();
 
-      assertThat(consumeLatch.await(10, TimeUnit.SECONDS)).isTrue();
+      org.assertj.core.api.Assertions.assertThat(consumeLatch.await(10, TimeUnit.SECONDS)).isTrue();
 
-      assertThat(consumer.isOpen()).isTrue();
+      org.assertj.core.api.Assertions.assertThat(consumer.isOpen()).isTrue();
 
       disruption.accept(s);
 
@@ -592,13 +597,14 @@ void consumerShouldKeepConsumingAfterDisruption(
                       producerSecondWave.messageBuilder().addData("".getBytes()).build(),
                       confirmationStatus -> publishLatchSecondWave.countDown()));
 
-      assertThat(publishLatchSecondWave.await(10, TimeUnit.SECONDS)).isTrue();
+      org.assertj.core.api.Assertions.assertThat(publishLatchSecondWave.await(10, TimeUnit.SECONDS))
+          .isTrue();
       producerSecondWave.close();
 
       latchAssert(consumeLatchSecondWave).completes(recoveryInitialDelay.plusSeconds(2));
-      assertThat(receivedMessageCount.get())
+      org.assertj.core.api.Assertions.assertThat(receivedMessageCount.get())
           .isBetween(messageCount * 2, messageCount * 2 + 1); // there can be a duplicate
-      assertThat(consumer.isOpen()).isTrue();
+      org.assertj.core.api.Assertions.assertThat(consumer.isOpen()).isTrue();
 
     } finally {
       if (consumer != null) {
@@ -806,7 +812,8 @@ void externalOffsetTrackingWithSubscriptionListener() throws Exception {
     publish.run();
 
     waitAtMost(5, () -> receivedMessages.get() == messageCount);
-    assertThat(offsetTracking.get()).isGreaterThanOrEqualTo(messageCount - 1);
+    org.assertj.core.api.Assertions.assertThat(offsetTracking.get())
+        .isGreaterThanOrEqualTo(messageCount - 1);
 
     Cli.killConnection("rabbitmq-stream-consumer-0");
     waitAtMost(
@@ -814,7 +821,8 @@ void externalOffsetTrackingWithSubscriptionListener() throws Exception {
 
     publish.run();
     waitAtMost(5, () -> receivedMessages.get() == messageCount * 2);
-    assertThat(offsetTracking.get()).isGreaterThanOrEqualTo(messageCount * 2 - 1);
+    org.assertj.core.api.Assertions.assertThat(offsetTracking.get())
+        .isGreaterThanOrEqualTo(messageCount * 2 - 1);
   }
 
   @Test
@@ -866,7 +874,8 @@ void duplicatesWhenResubscribeAfterDisconnectionWithLongFlushInterval() throws E
         });
 
     // we have duplicates because the last stored value is behind and the re-subscription uses it
-    assertThat(receivedMessages).hasValueGreaterThan(publishedMessages.get());
+    org.assertj.core.api.Assertions.assertThat(receivedMessages)
+        .hasValueGreaterThan(publishedMessages.get());
   }
 
   @Test
@@ -927,7 +936,8 @@ void useSubscriptionListenerToRestartExactlyWhereDesired() throws Exception {
     latchAssert(poisonLatch).completes(recoveryInitialDelay.plusSeconds(2));
     // no duplicates because the custom offset tracking overrides the stored offset in the
     // subscription listener
-    assertThat(receivedMessages).hasValue(publishedMessages.get() + 1);
+    org.assertj.core.api.Assertions.assertThat(receivedMessages)
+        .hasValue(publishedMessages.get() + 1);
   }
 
   @Test
@@ -991,4 +1001,50 @@ void creationShouldFailWithDetailsWhenUnknownHost() {
           .isInstanceOfAny(ConnectTimeoutException.class, UnknownHostException.class);
     }
   }
+
+  @Test
+  void resetOffsetTrackingFromEnvironment() {
+    int messageCount = 100;
+    publishAndWaitForConfirms(cf, messageCount, stream);
+    String reference = "app";
+    Sync sync = sync(messageCount);
+    AtomicLong lastOffset = new AtomicLong(0);
+    Supplier consumerSupplier =
+        () ->
+            environment.consumerBuilder().stream(stream)
+                .name(reference)
+                .offset(OffsetSpecification.first())
+                .messageHandler(
+                    (context, message) -> {
+                      lastOffset.set(context.offset());
+                      sync.down();
+                    })
+                .autoTrackingStrategy()
+                .builder()
+                .build();
+    // consumer gets the initial message batch and stores the offset on closing
+    Consumer consumer = consumerSupplier.get();
+    assertThat(sync).completes();
+    consumer.close();
+
+    // we'll publish 1 more message and make sure the consumers only consumes that one
+    // (because it restarts where it left off)
+    long limit = lastOffset.get();
+    sync.reset(1);
+    consumer = consumerSupplier.get();
+
+    publishAndWaitForConfirms(cf, 1, stream);
+
+    assertThat(sync).completes();
+    org.assertj.core.api.Assertions.assertThat(lastOffset).hasValueGreaterThan(limit);
+    consumer.close();
+
+    // we reset the offset to 0, the consumer should restart from the beginning
+    environment.storeOffset(reference, stream, 0);
+    sync.reset(messageCount + 1);
+    consumer = consumerSupplier.get();
+
+    assertThat(sync).completes();
+    consumer.close();
+  }
 }
diff --git a/src/test/java/com/rabbitmq/stream/impl/StreamEnvironmentTest.java b/src/test/java/com/rabbitmq/stream/impl/StreamEnvironmentTest.java
index 9588002d54..a5704e14b4 100644
--- a/src/test/java/com/rabbitmq/stream/impl/StreamEnvironmentTest.java
+++ b/src/test/java/com/rabbitmq/stream/impl/StreamEnvironmentTest.java
@@ -14,6 +14,7 @@
 // info@rabbitmq.com.
 package com.rabbitmq.stream.impl;
 
+import static com.rabbitmq.stream.impl.Assertions.assertThat;
 import static com.rabbitmq.stream.impl.TestUtils.*;
 import static com.rabbitmq.stream.impl.TestUtils.CountDownLatchConditions.completed;
 import static com.rabbitmq.stream.impl.TestUtils.ExceptionConditions.responseCode;
@@ -812,4 +813,19 @@ void brokerShouldAcceptInitialMemberCountArgument(TestInfo info) {
       env.close();
     }
   }
+
+  @Test
+  void storeOffset() {
+    String ref = "app";
+    Client client = cf.get();
+    assertThat(client.queryOffset(ref, stream)).isNotOk().hasCodeNoOffset();
+    try (Environment env = environmentBuilder.build()) {
+      env.storeOffset(ref, stream, 42);
+      assertThat(client.queryOffset(ref, stream).getOffset()).isEqualTo(42);
+      env.storeOffset(ref, stream, 43);
+      assertThat(client.queryOffset(ref, stream).getOffset()).isEqualTo(43);
+      env.storeOffset(ref, stream, 0);
+      assertThat(client.queryOffset(ref, stream).getOffset()).isEqualTo(0);
+    }
+  }
 }