Skip to content

Commit f20eaa0

Browse files
committed
feat: implement semantic and embeddings cache extensions
- Add SemanticCache for LLM response caching based on semantic similarity - Implement EmbeddingsCache for efficient embedding storage and retrieval - Support configurable distance thresholds and TTL policies - Include batch operations for improved performance - Add comprehensive test coverage for cache functionality
1 parent ec074f7 commit f20eaa0

File tree

11 files changed

+2535
-0
lines changed

11 files changed

+2535
-0
lines changed
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
package com.redis.vl.extensions;
2+
3+
/** Constants used within the extension classes. */
4+
public final class ExtensionConstants {
5+
6+
// BaseMessageHistory
7+
public static final String ID_FIELD_NAME = "entry_id";
8+
public static final String ROLE_FIELD_NAME = "role";
9+
public static final String CONTENT_FIELD_NAME = "content";
10+
public static final String TOOL_FIELD_NAME = "tool_call_id";
11+
public static final String TIMESTAMP_FIELD_NAME = "timestamp";
12+
public static final String SESSION_FIELD_NAME = "session_tag";
13+
14+
// SemanticMessageHistory
15+
public static final String MESSAGE_VECTOR_FIELD_NAME = "vector_field";
16+
17+
// SemanticCache
18+
public static final String REDIS_KEY_FIELD_NAME = "key";
19+
public static final String ENTRY_ID_FIELD_NAME = "entry_id";
20+
public static final String PROMPT_FIELD_NAME = "prompt";
21+
public static final String RESPONSE_FIELD_NAME = "response";
22+
public static final String CACHE_VECTOR_FIELD_NAME = "prompt_vector";
23+
public static final String INSERTED_AT_FIELD_NAME = "inserted_at";
24+
public static final String UPDATED_AT_FIELD_NAME = "updated_at";
25+
public static final String METADATA_FIELD_NAME = "metadata"; // also used in MessageHistory
26+
27+
// EmbeddingsCache
28+
public static final String TEXT_FIELD_NAME = "text";
29+
public static final String MODEL_NAME_FIELD_NAME = "model_name";
30+
public static final String EMBEDDING_FIELD_NAME = "embedding";
31+
public static final String DIMENSIONS_FIELD_NAME = "dimensions";
32+
33+
// SemanticRouter
34+
public static final String ROUTE_VECTOR_FIELD_NAME = "vector";
35+
36+
private ExtensionConstants() {
37+
// Prevent instantiation
38+
}
39+
}
Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
package com.redis.vl.extensions.cache;
2+
3+
import java.util.List;
4+
import java.util.Set;
5+
import redis.clients.jedis.UnifiedJedis;
6+
import redis.clients.jedis.params.ScanParams;
7+
import redis.clients.jedis.params.SetParams;
8+
import redis.clients.jedis.resps.ScanResult;
9+
10+
/** Abstract base class for all cache implementations. */
11+
public abstract class BaseCache {
12+
13+
protected final String name;
14+
protected final String prefix;
15+
protected Integer ttl;
16+
protected UnifiedJedis redisClient;
17+
18+
/**
19+
* Creates a new BaseCache instance.
20+
*
21+
* @param name The name of the cache
22+
* @param redisClient The Redis client connection
23+
* @param ttl Default time-to-live in seconds for cache entries (null for no expiration)
24+
*/
25+
protected BaseCache(String name, UnifiedJedis redisClient, Integer ttl) {
26+
if (name == null || name.trim().isEmpty()) {
27+
throw new IllegalArgumentException("Cache name cannot be null or empty");
28+
}
29+
if (redisClient == null) {
30+
throw new IllegalArgumentException("Redis client cannot be null");
31+
}
32+
33+
this.name = name;
34+
this.prefix = name + ":";
35+
this.redisClient = redisClient;
36+
this.ttl = ttl;
37+
}
38+
39+
/** Creates a new BaseCache instance without TTL. */
40+
protected BaseCache(String name, UnifiedJedis redisClient) {
41+
this(name, redisClient, null);
42+
}
43+
44+
/**
45+
* Generate a Redis key with the cache prefix.
46+
*
47+
* @param entryId The unique identifier for the cache entry
48+
* @return The prefixed Redis key
49+
*/
50+
protected String makeKey(String entryId) {
51+
return prefix + entryId;
52+
}
53+
54+
/**
55+
* Get the cache prefix.
56+
*
57+
* @return The cache prefix
58+
*/
59+
public String getPrefix() {
60+
return prefix;
61+
}
62+
63+
/**
64+
* Get the cache name.
65+
*
66+
* @return The cache name
67+
*/
68+
public String getName() {
69+
return name;
70+
}
71+
72+
/**
73+
* Get the default TTL for cache entries.
74+
*
75+
* @return Time-to-live in seconds (null if no expiration)
76+
*/
77+
public Integer getTtl() {
78+
return ttl;
79+
}
80+
81+
/**
82+
* Set the default TTL for cache entries.
83+
*
84+
* @param ttl Time-to-live in seconds (null for no expiration)
85+
*/
86+
public void setTtl(Integer ttl) {
87+
this.ttl = ttl;
88+
}
89+
90+
/**
91+
* Set expiration on a key.
92+
*
93+
* @param key The Redis key
94+
* @param ttl Time-to-live in seconds (uses default if null)
95+
*/
96+
public void expire(String key, Integer ttl) {
97+
Integer effectiveTtl = ttl != null ? ttl : this.ttl;
98+
if (effectiveTtl != null && effectiveTtl > 0) {
99+
redisClient.expire(key, effectiveTtl);
100+
}
101+
}
102+
103+
/** Clear all entries in the cache. */
104+
public void clear() {
105+
// Use SCAN to iterate through keys with our prefix
106+
String cursor = "0";
107+
ScanParams scanParams = new ScanParams();
108+
scanParams.match(prefix + "*");
109+
scanParams.count(100);
110+
111+
do {
112+
ScanResult<String> scanResult = redisClient.scan(cursor, scanParams);
113+
List<String> keys = scanResult.getResult();
114+
115+
if (!keys.isEmpty()) {
116+
redisClient.del(keys.toArray(new String[0]));
117+
}
118+
119+
cursor = scanResult.getCursor();
120+
} while (!"0".equals(cursor));
121+
}
122+
123+
/**
124+
* Get the number of entries in the cache.
125+
*
126+
* @return The number of cache entries
127+
*/
128+
public long size() {
129+
Set<String> keys = redisClient.keys(prefix + "*");
130+
return keys.size();
131+
}
132+
133+
/**
134+
* Check if the cache is connected to Redis.
135+
*
136+
* @return true if connected, false otherwise
137+
*/
138+
public boolean isConnected() {
139+
try {
140+
return "PONG".equals(redisClient.ping());
141+
} catch (Exception e) {
142+
return false;
143+
}
144+
}
145+
146+
/** Disconnect from Redis. */
147+
public void disconnect() {
148+
if (redisClient != null) {
149+
redisClient.close();
150+
redisClient = null;
151+
}
152+
}
153+
154+
/** Helper method to set a value with optional TTL. */
155+
protected void setWithTtl(String key, String value, Integer ttl) {
156+
SetParams params = new SetParams();
157+
Integer effectiveTtl = ttl != null ? ttl : this.ttl;
158+
if (effectiveTtl != null && effectiveTtl > 0) {
159+
params.ex(effectiveTtl);
160+
}
161+
redisClient.set(key, value, params);
162+
}
163+
164+
/** Helper method to set a byte array value with optional TTL. */
165+
protected void setWithTtl(byte[] key, byte[] value, Integer ttl) {
166+
if (ttl != null || this.ttl != null) {
167+
Integer effectiveTtl = ttl != null ? ttl : this.ttl;
168+
redisClient.setex(key, effectiveTtl, value);
169+
} else {
170+
redisClient.set(key, value);
171+
}
172+
}
173+
}
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
package com.redis.vl.extensions.cache;
2+
3+
import java.util.Map;
4+
5+
/**
6+
* Represents a cache hit from SemanticCache. Contains the matched response and metadata about the
7+
* match.
8+
*/
9+
public class CacheHit {
10+
11+
private final String prompt;
12+
private final String response;
13+
private final float distance;
14+
private final Map<String, Object> metadata;
15+
16+
/**
17+
* Creates a new CacheHit.
18+
*
19+
* @param prompt The original prompt that was matched
20+
* @param response The cached response
21+
* @param distance The vector distance from the query
22+
* @param metadata Additional metadata stored with the cache entry
23+
*/
24+
public CacheHit(String prompt, String response, float distance, Map<String, Object> metadata) {
25+
this.prompt = prompt;
26+
this.response = response;
27+
this.distance = distance;
28+
this.metadata = metadata;
29+
}
30+
31+
/**
32+
* Get the matched prompt.
33+
*
34+
* @return The prompt
35+
*/
36+
public String getPrompt() {
37+
return prompt;
38+
}
39+
40+
/**
41+
* Get the cached response.
42+
*
43+
* @return The response
44+
*/
45+
public String getResponse() {
46+
return response;
47+
}
48+
49+
/**
50+
* Get the vector distance.
51+
*
52+
* @return The distance (0 = exact match, higher = less similar)
53+
*/
54+
public float getDistance() {
55+
return distance;
56+
}
57+
58+
/**
59+
* Get the metadata.
60+
*
61+
* @return The metadata map
62+
*/
63+
public Map<String, Object> getMetadata() {
64+
return metadata;
65+
}
66+
67+
@Override
68+
public String toString() {
69+
return String.format(
70+
"CacheHit{prompt='%s', response='%s', distance=%.4f}", prompt, response, distance);
71+
}
72+
}

0 commit comments

Comments
 (0)