@@ -737,17 +737,22 @@ pub struct Streams(RwLock<HashMap<String, StreamRef>>);
737737// 4. When first event is sent to stream (update the schema)
738738// 5. When set alert API is called (update the alert)
739739impl Streams {
740- pub fn create (
740+ /// Checks after getting an exclusive lock whether the stream already exists, else creates it.
741+ /// NOTE: This is done to ensure we don't have contention among threads.
742+ pub fn get_or_create (
741743 & self ,
742744 options : Arc < Options > ,
743745 stream_name : String ,
744746 metadata : LogStreamMetadata ,
745747 ingestor_id : Option < String > ,
746748 ) -> StreamRef {
749+ let mut guard = self . write ( ) . expect ( LOCK_EXPECT ) ;
750+ if let Some ( stream) = guard. get ( & stream_name) {
751+ return stream. clone ( ) ;
752+ }
753+
747754 let stream = Stream :: new ( options, & stream_name, metadata, ingestor_id) ;
748- self . write ( )
749- . expect ( LOCK_EXPECT )
750- . insert ( stream_name, stream. clone ( ) ) ;
755+ guard. insert ( stream_name, stream. clone ( ) ) ;
751756
752757 stream
753758 }
@@ -812,7 +817,7 @@ impl Streams {
812817
813818#[ cfg( test) ]
814819mod tests {
815- use std:: time:: Duration ;
820+ use std:: { sync :: Barrier , thread :: spawn , time:: Duration } ;
816821
817822 use arrow_array:: { Int32Array , StringArray , TimestampMillisecondArray } ;
818823 use arrow_schema:: { DataType , Field , TimeUnit } ;
@@ -1187,4 +1192,113 @@ mod tests {
11871192 assert_eq ! ( staging. parquet_files( ) . len( ) , 2 ) ;
11881193 assert_eq ! ( staging. arrow_files( ) . len( ) , 1 ) ;
11891194 }
1195+
1196+ #[ test]
1197+ fn get_or_create_returns_existing_stream ( ) {
1198+ let streams = Streams :: default ( ) ;
1199+ let options = Arc :: new ( Options :: default ( ) ) ;
1200+ let stream_name = "test_stream" ;
1201+ let metadata = LogStreamMetadata :: default ( ) ;
1202+ let ingestor_id = Some ( "test_ingestor" . to_owned ( ) ) ;
1203+
1204+ // Create the stream first
1205+ let stream1 = streams. get_or_create (
1206+ options. clone ( ) ,
1207+ stream_name. to_owned ( ) ,
1208+ metadata. clone ( ) ,
1209+ ingestor_id. clone ( ) ,
1210+ ) ;
1211+
1212+ // Call get_or_create again with the same stream_name
1213+ let stream2 = streams. get_or_create (
1214+ options. clone ( ) ,
1215+ stream_name. to_owned ( ) ,
1216+ metadata. clone ( ) ,
1217+ ingestor_id. clone ( ) ,
1218+ ) ;
1219+
1220+ // Assert that both references point to the same stream
1221+ assert ! ( Arc :: ptr_eq( & stream1, & stream2) ) ;
1222+
1223+ // Verify the map contains only one entry
1224+ let guard = streams. read ( ) . expect ( "Failed to acquire read lock" ) ;
1225+ assert_eq ! ( guard. len( ) , 1 ) ;
1226+ }
1227+
1228+ #[ test]
1229+ fn create_and_return_new_stream_when_name_does_not_exist ( ) {
1230+ let streams = Streams :: default ( ) ;
1231+ let options = Arc :: new ( Options :: default ( ) ) ;
1232+ let stream_name = "new_stream" ;
1233+ let metadata = LogStreamMetadata :: default ( ) ;
1234+ let ingestor_id = Some ( "new_ingestor" . to_owned ( ) ) ;
1235+
1236+ // Assert the stream doesn't exist already
1237+ let guard = streams. read ( ) . expect ( "Failed to acquire read lock" ) ;
1238+ assert_eq ! ( guard. len( ) , 0 ) ;
1239+ assert ! ( !guard. contains_key( stream_name) ) ;
1240+ drop ( guard) ;
1241+
1242+ // Call get_or_create with a new stream_name
1243+ let stream = streams. get_or_create (
1244+ options. clone ( ) ,
1245+ stream_name. to_owned ( ) ,
1246+ metadata. clone ( ) ,
1247+ ingestor_id. clone ( ) ,
1248+ ) ;
1249+
1250+ // verify created stream has the same ingestor_id
1251+ assert_eq ! ( stream. ingestor_id, ingestor_id) ;
1252+
1253+ // Assert that the stream is created
1254+ let guard = streams. read ( ) . expect ( "Failed to acquire read lock" ) ;
1255+ assert_eq ! ( guard. len( ) , 1 ) ;
1256+ assert ! ( guard. contains_key( stream_name) ) ;
1257+ }
1258+
1259+ #[ test]
1260+ fn get_or_create_stream_concurrently ( ) {
1261+ let streams = Arc :: new ( Streams :: default ( ) ) ;
1262+ let options = Arc :: new ( Options :: default ( ) ) ;
1263+ let stream_name = String :: from ( "concurrent_stream" ) ;
1264+ let metadata = LogStreamMetadata :: default ( ) ;
1265+ let ingestor_id = Some ( String :: from ( "concurrent_ingestor" ) ) ;
1266+
1267+ // Barrier to synchronize threads
1268+ let barrier = Arc :: new ( Barrier :: new ( 2 ) ) ;
1269+
1270+ // Clones for the first thread
1271+ let streams1 = Arc :: clone ( & streams) ;
1272+ let options1 = Arc :: clone ( & options) ;
1273+ let barrier1 = Arc :: clone ( & barrier) ;
1274+ let stream_name1 = stream_name. clone ( ) ;
1275+ let metadata1 = metadata. clone ( ) ;
1276+ let ingestor_id1 = ingestor_id. clone ( ) ;
1277+
1278+ // First thread
1279+ let handle1 = spawn ( move || {
1280+ barrier1. wait ( ) ;
1281+ streams1. get_or_create ( options1, stream_name1, metadata1, ingestor_id1)
1282+ } ) ;
1283+
1284+ // Cloned for the second thread
1285+ let streams2 = Arc :: clone ( & streams) ;
1286+
1287+ // Second thread
1288+ let handle2 = spawn ( move || {
1289+ barrier. wait ( ) ;
1290+ streams2. get_or_create ( options, stream_name, metadata, ingestor_id)
1291+ } ) ;
1292+
1293+ // Wait for both threads to complete and get their results
1294+ let stream1 = handle1. join ( ) . expect ( "Thread 1 panicked" ) ;
1295+ let stream2 = handle2. join ( ) . expect ( "Thread 2 panicked" ) ;
1296+
1297+ // Assert that both references point to the same stream
1298+ assert ! ( Arc :: ptr_eq( & stream1, & stream2) ) ;
1299+
1300+ // Verify the map contains only one entry
1301+ let guard = streams. read ( ) . expect ( "Failed to acquire read lock" ) ;
1302+ assert_eq ! ( guard. len( ) , 1 ) ;
1303+ }
11901304}
0 commit comments