@@ -23,14 +23,18 @@ use core::mem::MaybeUninit;
2323use hyperlight_common:: outb:: OutBAction ;
2424use spin:: Mutex ;
2525
26+ /// Type alias for the function that sends trace records to the host.
27+ type SendToHostFn = fn ( u64 , & [ TraceRecord ] ) ;
28+
2629/// Global trace buffer for storing trace records.
27- static TRACE_BUFFER : Mutex < TraceBuffer > = Mutex :: new ( TraceBuffer :: new ( ) ) ;
30+ static TRACE_BUFFER : Mutex < TraceBuffer > = Mutex :: new ( TraceBuffer :: new ( send_to_host ) ) ;
2831
2932/// Maximum number of entries in the trace buffer.
30- const MAX_NO_OF_ENTRIES : usize = 64 ;
33+ /// From local testing, 32 entries seems to be a good balance between performance and memory usage.
34+ const MAX_NO_OF_ENTRIES : usize = 32 ;
3135
3236/// Maximum length of a trace message in bytes.
33- pub const MAX_TRACE_MSG_LEN : usize = 64 ;
37+ const MAX_TRACE_MSG_LEN : usize = 64 ;
3438
3539#[ derive( Debug , Copy , Clone ) ]
3640/// Represents a trace record of a guest with a number of cycles and a message.
@@ -43,20 +47,45 @@ pub struct TraceRecord {
4347 pub msg : [ u8 ; MAX_TRACE_MSG_LEN ] ,
4448}
4549
50+ impl TryFrom < & str > for TraceRecord {
51+ type Error = & ' static str ;
52+
53+ fn try_from ( msg : & str ) -> Result < Self , Self :: Error > {
54+ if msg. len ( ) > MAX_TRACE_MSG_LEN {
55+ return Err ( "Message too long" ) ;
56+ }
57+
58+ let cycles = invariant_tsc:: read_tsc ( ) ;
59+
60+ Ok ( TraceRecord {
61+ cycles,
62+ msg : {
63+ let mut arr = [ 0u8 ; MAX_TRACE_MSG_LEN ] ;
64+ arr[ ..msg. len ( ) ] . copy_from_slice ( msg. as_bytes ( ) ) ;
65+ arr
66+ } ,
67+ msg_len : msg. len ( ) ,
68+ } )
69+ }
70+ }
71+
4672/// A buffer for storing trace records.
4773struct TraceBuffer {
4874 /// The entries in the trace buffer.
4975 entries : [ TraceRecord ; MAX_NO_OF_ENTRIES ] ,
5076 /// The index where the next entry will be written.
5177 write_index : usize ,
78+ /// Function to send the trace records to the host.
79+ send_to_host : SendToHostFn ,
5280}
5381
5482impl TraceBuffer {
5583 /// Creates a new `TraceBuffer` with uninitialized entries.
56- const fn new ( ) -> Self {
84+ const fn new ( f : SendToHostFn ) -> Self {
5785 Self {
5886 entries : unsafe { [ MaybeUninit :: zeroed ( ) . assume_init ( ) ; MAX_NO_OF_ENTRIES ] } ,
5987 write_index : 0 ,
88+ send_to_host : f,
6089 }
6190 }
6291
@@ -72,26 +101,26 @@ impl TraceBuffer {
72101
73102 if write_index == 0 {
74103 // If buffer is full send to host
75- self . send_to_host ( MAX_NO_OF_ENTRIES ) ;
104+ ( self . send_to_host ) ( MAX_NO_OF_ENTRIES as u64 , & self . entries ) ;
76105 }
77106 }
78107
79108 /// Flush the trace buffer, sending any remaining records to the host.
80109 fn flush ( & mut self ) {
81110 if self . write_index > 0 {
82- self . send_to_host ( self . write_index ) ;
111+ ( self . send_to_host ) ( self . write_index as u64 , & self . entries ) ;
83112 self . write_index = 0 ; // Reset write index after flushing
84113 }
85114 }
115+ }
86116
87- /// Send the trace records to the host.
88- fn send_to_host ( & self , count : usize ) {
89- unsafe {
90- core:: arch:: asm!( "out dx, al" ,
117+ /// Send the trace records to the host.
118+ fn send_to_host ( len : u64 , records : & [ TraceRecord ] ) {
119+ unsafe {
120+ core:: arch:: asm!( "out dx, al" ,
91121 in( "dx" ) OutBAction :: TraceRecord as u16 ,
92- in( "rax" ) count as u64 ,
93- in( "rcx" ) & self . entries as * const _ as u64 ) ;
94- }
122+ in( "rax" ) len,
123+ in( "rcx" ) records. as_ptr( ) as u64 ) ;
95124 }
96125}
97126
@@ -129,33 +158,120 @@ pub mod invariant_tsc {
129158 }
130159}
131160
132- /// Create a trace record with the given message.
161+ /// Attempt to create a trace record from the message.
162+ /// If the message is too long it will skip the record.
163+ /// This is useful for ensuring that the trace buffer does not overflow.
164+ /// If the message is successfully converted, it will be pushed to the trace buffer.
133165///
134- /// Note: The message must not exceed `MAX_TRACE_MSG_LEN` bytes.
135- /// If the message is too long, it will be skipped.
166+ /// **Note**: The message must not exceed `MAX_TRACE_MSG_LEN` bytes.
136167pub fn create_trace_record ( msg : & str ) {
137- if msg. len ( ) > MAX_TRACE_MSG_LEN {
138- return ; // Message too long, skip tracing
168+ // Maybe panic if the message is too long?
169+ if let Ok ( entry) = TraceRecord :: try_from ( msg) {
170+ let mut buffer = TRACE_BUFFER . lock ( ) ;
171+ buffer. push ( entry) ;
139172 }
140-
141- let cycles = invariant_tsc:: read_tsc ( ) ;
142-
143- let entry = TraceRecord {
144- cycles,
145- msg : {
146- let mut arr = [ 0u8 ; MAX_TRACE_MSG_LEN ] ;
147- arr[ ..msg. len ( ) ] . copy_from_slice ( msg. as_bytes ( ) ) ;
148- arr
149- } ,
150- msg_len : msg. len ( ) ,
151- } ;
152-
153- let mut buffer = TRACE_BUFFER . lock ( ) ;
154- buffer. push ( entry) ;
155173}
156174
157175/// Flush the trace buffer to send any remaining trace records to the host.
158176pub fn flush_trace_buffer ( ) {
159177 let mut buffer = TRACE_BUFFER . lock ( ) ;
160178 buffer. flush ( ) ;
161179}
180+
181+ #[ cfg( test) ]
182+ mod tests {
183+ use alloc:: format;
184+
185+ use super :: * ;
186+
187+ /// This is a mock function for testing purposes.
188+ /// In a real scenario, this would send the trace records to the host.
189+ fn mock_send_to_host ( _len : u64 , _records : & [ TraceRecord ] ) { }
190+
191+ fn create_test_entry ( msg : & str ) -> TraceRecord {
192+ let cycles = invariant_tsc:: read_tsc ( ) ;
193+
194+ TraceRecord {
195+ cycles,
196+ msg : {
197+ let mut arr = [ 0u8 ; MAX_TRACE_MSG_LEN ] ;
198+ arr[ ..msg. len ( ) ] . copy_from_slice ( msg. as_bytes ( ) ) ;
199+ arr
200+ } ,
201+ msg_len : msg. len ( ) ,
202+ }
203+ }
204+
205+ #[ test]
206+ fn test_push_trace_record ( ) {
207+ let mut buffer = TraceBuffer :: new ( mock_send_to_host) ;
208+
209+ let msg = "Test message" ;
210+ let entry = create_test_entry ( msg) ;
211+
212+ buffer. push ( entry) ;
213+ assert_eq ! ( buffer. write_index, 1 ) ;
214+ assert_eq ! ( buffer. entries[ 0 ] . msg_len, msg. len( ) ) ;
215+ assert_eq ! ( & buffer. entries[ 0 ] . msg[ ..msg. len( ) ] , msg. as_bytes( ) ) ;
216+ assert ! ( buffer. entries[ 0 ] . cycles > 0 ) ; // Ensure cycles is set
217+ }
218+
219+ #[ test]
220+ fn test_flush_trace_buffer ( ) {
221+ let mut buffer = TraceBuffer :: new ( mock_send_to_host) ;
222+
223+ let msg = "Test message" ;
224+ let entry = create_test_entry ( msg) ;
225+
226+ buffer. push ( entry) ;
227+ assert_eq ! ( buffer. write_index, 1 ) ;
228+ assert_eq ! ( buffer. entries[ 0 ] . msg_len, msg. len( ) ) ;
229+ assert_eq ! ( & buffer. entries[ 0 ] . msg[ ..msg. len( ) ] , msg. as_bytes( ) ) ;
230+ assert ! ( buffer. entries[ 0 ] . cycles > 0 ) ;
231+
232+ // Flush the buffer
233+ buffer. flush ( ) ;
234+
235+ // After flushing, the entryes should still be intact, we don't clear them
236+ assert_eq ! ( buffer. write_index, 0 ) ;
237+ assert_eq ! ( buffer. entries[ 0 ] . msg_len, msg. len( ) ) ;
238+ assert_eq ! ( & buffer. entries[ 0 ] . msg[ ..msg. len( ) ] , msg. as_bytes( ) ) ;
239+ assert ! ( buffer. entries[ 0 ] . cycles > 0 ) ;
240+ }
241+
242+ #[ test]
243+ fn test_auto_flush_on_full ( ) {
244+ let mut buffer = TraceBuffer :: new ( mock_send_to_host) ;
245+
246+ // Fill the buffer to trigger auto-flush
247+ for i in 0 ..MAX_NO_OF_ENTRIES {
248+ let msg = format ! ( "Message {}" , i) ;
249+ let entry = create_test_entry ( & msg) ;
250+ buffer. push ( entry) ;
251+ }
252+
253+ // After filling, the write index should be 0 (buffer is full)
254+ assert_eq ! ( buffer. write_index, 0 ) ;
255+
256+ // The first entry should still be intact
257+ assert_eq ! ( buffer. entries[ 0 ] . msg_len, "Message 0" . len( ) ) ;
258+ }
259+
260+ /// Test TraceRecord creation with a valid message
261+ #[ test]
262+ fn test_trace_record_creation_valid ( ) {
263+ let msg = "Valid message" ;
264+ let entry = TraceRecord :: try_from ( msg) . expect ( "Failed to create TraceRecord" ) ;
265+ assert_eq ! ( entry. msg_len, msg. len( ) ) ;
266+ assert_eq ! ( & entry. msg[ ..msg. len( ) ] , msg. as_bytes( ) ) ;
267+ assert ! ( entry. cycles > 0 ) ; // Ensure cycles is set
268+ }
269+
270+ /// Test TraceRecord creation with a message that exceeds the maximum length
271+ #[ test]
272+ fn test_trace_record_creation_too_long ( ) {
273+ let long_msg = "A" . repeat ( MAX_TRACE_MSG_LEN + 1 ) ;
274+ let result = TraceRecord :: try_from ( long_msg. as_str ( ) ) ;
275+ assert ! ( result. is_err( ) , "Expected error for message too long" ) ;
276+ }
277+ }
0 commit comments