Skip to content

Commit ec074f7

Browse files
committed
feat: add vectorizer infrastructure for embedding generation
- Implement BaseVectorizer abstract class with caching support
- Add MockVectorizer for testing purposes
- Create VectorizerBuilder for flexible vectorizer creation
- Support multiple embedding models and configurations
- Include automatic cache integration for performance optimization
1 parent c7d359a commit ec074f7

12 files changed

+2993
-0
lines changed
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
package com.redis.vl.utils;
2+
3+
import java.nio.ByteBuffer;
4+
import java.nio.ByteOrder;
5+
6+
/**
 * Static helpers for converting between primitive numeric arrays and their raw byte form.
 *
 * <p>The class only exposes static methods and cannot be instantiated.
 */
public class ArrayUtils {

  private ArrayUtils() {
    // Static-only holder; never instantiated.
  }

  /**
   * Serializes a float array into bytes, four bytes per element, in little-endian order.
   *
   * @param floats the values to serialize; may be {@code null}
   * @return a new array of {@code floats.length * Float.BYTES} bytes, or {@code null} when {@code
   *     floats} is {@code null}
   */
  public static byte[] floatArrayToBytes(float[] floats) {
    if (floats == null) {
      return null;
    }
    ByteBuffer packed =
        ByteBuffer.allocate(floats.length * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
    // The float view inherits the little-endian order set above and writes straight into the
    // backing array that is returned below.
    packed.asFloatBuffer().put(floats);
    return packed.array();
  }

  /**
   * Deserializes little-endian bytes into a float array, consuming four bytes per element.
   *
   * <p>The result holds {@code bytes.length / Float.BYTES} elements (integer division), so any
   * trailing bytes that do not form a complete float are ignored.
   *
   * @param bytes the bytes to decode; may be {@code null}
   * @return the decoded floats, or {@code null} when {@code bytes} is {@code null}
   */
  public static float[] bytesToFloatArray(byte[] bytes) {
    if (bytes == null) {
      return null;
    }
    float[] decoded = new float[bytes.length / Float.BYTES];
    ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer().get(decoded);
    return decoded;
  }

  /**
   * Narrows every element of a double array to {@code float} with a primitive cast.
   *
   * @param doubles the values to narrow; may be {@code null}
   * @return a new float array of the same length, or {@code null} when {@code doubles} is {@code
   *     null}
   */
  public static float[] doubleArrayToFloats(double[] doubles) {
    if (doubles == null) {
      return null;
    }
    float[] narrowed = new float[doubles.length];
    int next = 0;
    for (double value : doubles) {
      narrowed[next++] = (float) value;
    }
    return narrowed;
  }
}
Lines changed: 299 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,299 @@
1+
package com.redis.vl.utils.vectorize;
2+
3+
import com.redis.vl.extensions.cache.EmbeddingsCache;
4+
import com.redis.vl.utils.ArrayUtils;
5+
import java.util.ArrayList;
6+
import java.util.HashMap;
7+
import java.util.List;
8+
import java.util.Map;
9+
import java.util.Optional;
10+
import java.util.function.Function;
11+
12+
/**
13+
* Abstract base class for text vectorizers. Port of redis-vl-python/redisvl/utils/vectorize/base.py
14+
*/
15+
public abstract class BaseVectorizer {
16+
17+
protected final String modelName;
18+
protected final String dtype;
19+
protected int dimensions;
20+
protected Optional<EmbeddingsCache> cache;
21+
22+
/**
23+
* Creates a new BaseVectorizer.
24+
*
25+
* @param modelName The name of the embedding model
26+
* @param dimensions The dimension of the embedding vectors
27+
*/
28+
protected BaseVectorizer(String modelName, int dimensions) {
29+
this(modelName, dimensions, "float32");
30+
}
31+
32+
/**
33+
* Creates a new BaseVectorizer with specified data type.
34+
*
35+
* @param modelName The name of the embedding model
36+
* @param dimensions The dimension of the embedding vectors (-1 for auto-detect)
37+
* @param dtype The data type for embeddings (default: "float32")
38+
*/
39+
protected BaseVectorizer(String modelName, int dimensions, String dtype) {
40+
this.modelName = modelName;
41+
this.dimensions = dimensions;
42+
this.dtype = dtype != null ? dtype : "float32";
43+
this.cache = Optional.empty();
44+
}
45+
46+
/**
47+
* Set an embeddings cache for this vectorizer.
48+
*
49+
* @param cache The embeddings cache to use
50+
*/
51+
public void setCache(EmbeddingsCache cache) {
52+
this.cache = Optional.ofNullable(cache);
53+
}
54+
55+
/**
56+
* Get the embeddings cache if present.
57+
*
58+
* @return Optional containing the cache, or empty if none set
59+
*/
60+
public Optional<EmbeddingsCache> getCache() {
61+
return cache;
62+
}
63+
64+
/**
65+
* Get the vector data type.
66+
*
67+
* @return The data type (e.g. "float32")
68+
*/
69+
public String getDataType() {
70+
return dtype;
71+
}
72+
73+
/**
74+
* Get the model name.
75+
*
76+
* @return The model name
77+
*/
78+
public String getModelName() {
79+
return modelName;
80+
}
81+
82+
/**
83+
* Get the embedding dimensions.
84+
*
85+
* @return The number of dimensions
86+
*/
87+
public int getDimensions() {
88+
return dimensions;
89+
}
90+
91+
/**
92+
* Embed a single text string.
93+
*
94+
* @param text The text to embed
95+
* @return The embedding vector
96+
*/
97+
public float[] embed(String text) {
98+
return embed(text, null, false, false);
99+
}
100+
101+
/**
102+
* Embed a single text string with full options.
103+
*
104+
* @param text The text to embed
105+
* @param preprocess Optional preprocessing function
106+
* @param asBuffer Return as byte buffer (not implemented in Java version)
107+
* @param skipCache Skip cache lookup and storage
108+
* @return The embedding vector
109+
*/
110+
public float[] embed(
111+
String text, Function<String, String> preprocess, boolean asBuffer, boolean skipCache) {
112+
// Apply preprocessing if provided
113+
String processedText = preprocess != null ? preprocess.apply(text) : text;
114+
115+
// Check cache first if not skipping
116+
if (!skipCache && cache.isPresent()) {
117+
Optional<float[]> cached = cache.get().get(processedText, modelName);
118+
if (cached.isPresent()) {
119+
return cached.get();
120+
}
121+
}
122+
123+
// Generate embedding
124+
float[] embedding = generateEmbedding(processedText);
125+
126+
// Auto-detect dimensions if not set
127+
if (dimensions <= 0 && embedding != null) {
128+
dimensions = embedding.length;
129+
}
130+
131+
// Store in cache if available and not skipping
132+
if (!skipCache && cache.isPresent() && embedding != null) {
133+
cache.get().set(processedText, modelName, embedding);
134+
}
135+
136+
return embedding;
137+
}
138+
139+
/**
140+
* Convert embedding to byte buffer if requested.
141+
*
142+
* @param embedding The embedding vector
143+
* @param asBuffer Whether to return as bytes
144+
* @return The embedding as float array or byte array
145+
*/
146+
protected Object processEmbedding(float[] embedding, boolean asBuffer) {
147+
if (asBuffer) {
148+
return ArrayUtils.floatArrayToBytes(embedding);
149+
}
150+
return embedding;
151+
}
152+
153+
/**
154+
* Embed multiple text strings in batch.
155+
*
156+
* @param texts The texts to embed
157+
* @return List of embedding vectors
158+
*/
159+
public List<float[]> embedBatch(List<String> texts) {
160+
return embedBatch(texts, null, 10, false, false);
161+
}
162+
163+
/**
164+
* Embed multiple text strings with full options.
165+
*
166+
* @param texts List of texts to embed
167+
* @param preprocess Optional preprocessing function
168+
* @param batchSize Number of texts to process per batch
169+
* @param asBuffer Return as byte buffers (not implemented in Java)
170+
* @param skipCache Skip cache lookup and storage
171+
* @return List of embedding vectors
172+
*/
173+
public List<float[]> embedBatch(
174+
List<String> texts,
175+
Function<String, String> preprocess,
176+
int batchSize,
177+
boolean asBuffer,
178+
boolean skipCache) {
179+
if (texts.isEmpty()) {
180+
return new ArrayList<>();
181+
}
182+
183+
// Apply preprocessing if provided
184+
List<String> processedTexts = new ArrayList<>();
185+
for (String text : texts) {
186+
processedTexts.add(preprocess != null ? preprocess.apply(text) : text);
187+
}
188+
189+
// Get cached embeddings and identify misses
190+
BatchCacheResult cacheResult = getFromCacheBatch(processedTexts, skipCache);
191+
List<float[]> results = cacheResult.results;
192+
List<String> cacheMisses = cacheResult.cacheMisses;
193+
List<Integer> cacheMissIndices = cacheResult.cacheMissIndices;
194+
195+
// Generate embeddings for cache misses
196+
if (!cacheMisses.isEmpty()) {
197+
List<float[]> newEmbeddings = generateEmbeddingsBatch(cacheMisses, batchSize);
198+
199+
// Store new embeddings in cache
200+
storeInCacheBatch(cacheMisses, newEmbeddings, skipCache);
201+
202+
// Insert new embeddings into results array
203+
for (int i = 0; i < cacheMissIndices.size() && i < newEmbeddings.size(); i++) {
204+
int idx = cacheMissIndices.get(i);
205+
if (idx < results.size()) {
206+
results.set(idx, newEmbeddings.get(i));
207+
}
208+
}
209+
}
210+
211+
return results;
212+
}
213+
214+
/**
215+
* Generate embedding for a single text (to be implemented by subclasses).
216+
*
217+
* @param text The text to embed
218+
* @return The embedding vector
219+
*/
220+
protected abstract float[] generateEmbedding(String text);
221+
222+
/**
223+
* Generate embeddings for multiple texts in batch (to be implemented by subclasses).
224+
*
225+
* @param texts The texts to embed
226+
* @param batchSize Number of texts to process per batch
227+
* @return List of embedding vectors
228+
*/
229+
protected abstract List<float[]> generateEmbeddingsBatch(List<String> texts, int batchSize);
230+
231+
/** Helper class to hold batch cache results. */
232+
protected static class BatchCacheResult {
233+
public final List<float[]> results;
234+
public final List<String> cacheMisses;
235+
public final List<Integer> cacheMissIndices;
236+
237+
public BatchCacheResult(
238+
List<float[]> results, List<String> cacheMisses, List<Integer> cacheMissIndices) {
239+
this.results = results;
240+
this.cacheMisses = cacheMisses;
241+
this.cacheMissIndices = cacheMissIndices;
242+
}
243+
}
244+
245+
/** Get cached embeddings and identify cache misses. */
246+
private BatchCacheResult getFromCacheBatch(List<String> texts, boolean skipCache) {
247+
List<float[]> results = new ArrayList<>();
248+
List<String> cacheMisses = new ArrayList<>();
249+
List<Integer> cacheMissIndices = new ArrayList<>();
250+
251+
if (skipCache || !cache.isPresent()) {
252+
// No cache, all are misses
253+
for (int i = 0; i < texts.size(); i++) {
254+
results.add(null);
255+
cacheMisses.add(texts.get(i));
256+
cacheMissIndices.add(i);
257+
}
258+
} else {
259+
// Check cache for each text
260+
Map<String, float[]> cachedResults = cache.get().mget(texts, modelName);
261+
262+
for (int i = 0; i < texts.size(); i++) {
263+
String text = texts.get(i);
264+
if (cachedResults.containsKey(text)) {
265+
results.add(cachedResults.get(text));
266+
} else {
267+
results.add(null);
268+
cacheMisses.add(text);
269+
cacheMissIndices.add(i);
270+
}
271+
}
272+
}
273+
274+
return new BatchCacheResult(results, cacheMisses, cacheMissIndices);
275+
}
276+
277+
/** Store new embeddings in cache. */
278+
private void storeInCacheBatch(List<String> texts, List<float[]> embeddings, boolean skipCache) {
279+
if (skipCache || !cache.isPresent() || texts.size() != embeddings.size()) {
280+
return;
281+
}
282+
283+
Map<String, float[]> toStore = new HashMap<>();
284+
for (int i = 0; i < texts.size(); i++) {
285+
toStore.put(texts.get(i), embeddings.get(i));
286+
}
287+
288+
cache.get().mset(toStore, modelName);
289+
}
290+
291+
/**
292+
* Get the vector type identifier.
293+
*
294+
* @return The type of vectorizer
295+
*/
296+
public String getType() {
297+
return "base";
298+
}
299+
}

0 commit comments

Comments
 (0)