twizzler_io/
buffer.rs

1use std::{
2    cell::UnsafeCell,
3    sync::atomic::{AtomicU64, Ordering},
4};
5
6use twizzler_abi::syscall::{
7    ThreadSync, ThreadSyncFlags, ThreadSyncOp, ThreadSyncReference, ThreadSyncSleep, ThreadSyncWake,
8};
9
/// A fixed-size byte ring buffer of `N` bytes that is shared through `&self`.
///
/// Positions are monotonically increasing `u64` byte counters; a position maps
/// to an array slot with `% N`. One slot is always left free, so at most `N - 1`
/// bytes are buffered at once. The counters satisfy `tail <= head <= reserve`:
/// - `reserve`: end of the space claimed by writers (advanced by CAS before copying),
/// - `head`: end of the data published to readers (advanced after a copy completes),
/// - `tail`: end of the data consumed by readers.
pub struct VolatileBuffer<const N: usize> {
    reserve: AtomicU64,
    head: AtomicU64,
    tail: AtomicU64,
    buffer: UnsafeCell<[u8; N]>,
}

// SAFETY: every field except `buffer` is an atomic. `buffer` is only accessed by
// the methods of `VolatileBuffer`, which partition it with the counters above: a
// writer copies only into the slots it claimed through `reserve` once `head` has
// reached the start of its claim, and readers copy only out of `[tail, head)`.
// NOTE(review): with several concurrent readers, a reader holding a stale `tail`
// can copy slots that were already released and are being refilled; its commit
// CAS then fails and the copy is discarded, but the copy itself races with the
// writer. A single reader never hits this.
unsafe impl<const N: usize> Send for VolatileBuffer<N> {}
// SAFETY: see the `Send` impl above; all cross-thread coordination goes through
// the atomic counters.
unsafe impl<const N: usize> Sync for VolatileBuffer<N> {}
18
19impl<const N: usize> VolatileBuffer<N> {
20    pub fn new() -> Self {
21        Self {
22            buffer: UnsafeCell::new([0; N]),
23            head: AtomicU64::new(0),
24            tail: AtomicU64::new(0),
25            reserve: AtomicU64::new(0),
26        }
27    }
28
29    pub fn avail_space(&self) -> usize {
30        let tail = self.tail.load(Ordering::SeqCst);
31        let resv = self.reserve.load(Ordering::SeqCst);
32
33        (N - 1) - (resv - tail) as usize
34    }
35
36    pub fn pending_bytes(&self) -> usize {
37        let head = self.head.load(Ordering::SeqCst);
38        let tail = self.tail.load(Ordering::SeqCst);
39
40        (head - tail) as usize
41    }
42
43    pub fn is_empty(&self) -> bool {
44        let tail = self.tail.load(Ordering::SeqCst);
45        let head = self.head.load(Ordering::SeqCst);
46
47        head == tail
48    }
49
50    pub fn sync_for_pending_data(&self) -> ThreadSyncSleep {
51        let head = self.head.load(Ordering::SeqCst);
52        ThreadSyncSleep::new(
53            ThreadSyncReference::Virtual(&self.head),
54            head,
55            ThreadSyncOp::Equal,
56            ThreadSyncFlags::empty(),
57        )
58    }
59
60    pub fn sync_for_avail_space(&self) -> ThreadSyncSleep {
61        let tail = self.tail.load(Ordering::SeqCst);
62        ThreadSyncSleep::new(
63            ThreadSyncReference::Virtual(&self.tail),
64            tail,
65            ThreadSyncOp::Equal,
66            ThreadSyncFlags::empty(),
67        )
68    }
69
70    pub fn read_bytes(&self, mut buf: &mut [u8]) -> std::io::Result<usize> {
71        let mut count = 0;
72        while buf.len() > 0 {
73            let head = self.head.load(Ordering::SeqCst);
74            let tail = self.tail.load(Ordering::SeqCst);
75
76            // Empty
77            if tail == head {
78                return Ok(count);
79            }
80
81            assert!(head >= tail);
82            let n = std::cmp::min(buf.len(), (head - tail) as usize);
83            let n = self.read_from_circle(&mut buf[0..n], tail as usize % N);
84
85            if self
86                .tail
87                .compare_exchange(tail, tail + n as u64, Ordering::SeqCst, Ordering::SeqCst)
88                .is_err()
89            {
90                continue;
91            }
92            self.do_wake(&self.tail);
93            buf = &mut buf[n..];
94            count += n;
95        }
96        Ok(count)
97    }
98
99    pub fn write_bytes(&self, mut buf: &[u8]) -> std::io::Result<usize> {
100        let mut count = 0;
101        while buf.len() > 0 {
102            let resv = self.reserve.load(Ordering::SeqCst);
103            let tail = self.tail.load(Ordering::SeqCst);
104
105            let avail = (N - 1) - (resv - tail) as usize;
106            if avail == 0 {
107                return Ok(count);
108            }
109
110            let n = std::cmp::min(buf.len(), avail);
111
112            // Step 1: reserve space
113            if self
114                .reserve
115                .compare_exchange(resv, resv + n as u64, Ordering::SeqCst, Ordering::SeqCst)
116                .is_err()
117            {
118                // Someone else reserved space. Try again.
119                continue;
120            }
121
122            // Step 2: wait until our head catches up to the old reserve. Note that since
123            // we succeeded the compare-exchange above, we have to complete this operation
124            // for the pty to remain in a consistent state.
125            while self.head.load(Ordering::SeqCst) != resv {
126                core::hint::spin_loop();
127            }
128
129            let n = self.write_to_circle(&buf[0..n], resv as usize % N);
130
131            let old_head = self.head.fetch_add(n as u64, Ordering::SeqCst);
132            if old_head != resv {
133                tracing::warn!("head incremented unexpectedly ({} != {})", old_head, resv);
134            }
135            self.do_wake(&self.head);
136
137            buf = &buf[n..];
138            count += n;
139        }
140        Ok(count)
141    }
142
143    fn get_buffer(&self) -> &[u8] {
144        let ptr = self.buffer.get();
145        unsafe { ptr.as_ref().unwrap() }
146    }
147
148    fn get_buffer_mut(&self) -> &mut [u8] {
149        let ptr = self.buffer.get();
150        unsafe { ptr.as_mut().unwrap() }
151    }
152
153    fn read_from_circle(&self, buf: &mut [u8], phase: usize) -> usize {
154        let buffer = self.get_buffer();
155        let (second, first) = buffer.split_at(phase);
156        let first_len = first.len().min(buf.len());
157        let second_len = second.len().min(buf.len().saturating_sub(first_len));
158
159        (&mut buf[0..first_len]).copy_from_slice(&first[0..first_len]);
160        (&mut buf[first_len..(first_len + second_len)]).copy_from_slice(&second[0..second_len]);
161        return first_len + second_len;
162    }
163
164    fn write_to_circle(&self, buf: &[u8], phase: usize) -> usize {
165        let buffer = self.get_buffer_mut();
166        let (second, first) = buffer.split_at_mut(phase);
167        let first_len = first.len().min(buf.len());
168        let second_len = second.len().min(buf.len().saturating_sub(first_len));
169
170        (&mut first[0..first_len]).copy_from_slice(&buf[0..first_len]);
171        (&mut second[0..second_len]).copy_from_slice(&buf[first_len..(first_len + second_len)]);
172        return first_len + second_len;
173    }
174
175    fn do_wake(&self, ptr: &AtomicU64) {
176        let _ = twizzler_abi::syscall::sys_thread_sync(
177            &mut [ThreadSync::new_wake(ThreadSyncWake::new(
178                ThreadSyncReference::Virtual(ptr),
179                usize::MAX,
180            ))],
181            None,
182        )
183        .inspect_err(|e| tracing::error!("failed to wake on volatile buffer: {}", e));
184    }
185}
186
#[cfg(test)]
pub mod tests {
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};

    use crate::buffer::VolatileBuffer;

    #[test]
    pub fn test_basic() {
        let vb = VolatileBuffer::<2048>::new();

        let mut buf = [0; 1024];
        assert_eq!(vb.read_bytes(&mut buf).unwrap(), 0);

        // Repeatedly round-trip 1 KiB chunks so the positions wrap past `N` many times.
        for i in 0..100 {
            buf.fill(i);

            assert_eq!(vb.write_bytes(&buf).unwrap(), 1024);
            assert_eq!(vb.read_bytes(&mut buf).unwrap(), 1024);
            assert_eq!(buf, [i; 1024]);
        }
    }

    /// Exercises copies that straddle the end of the array, partial writes into a
    /// nearly full buffer, and the space/pending accounting.
    #[test]
    pub fn test_wraparound_split() {
        let vb = VolatileBuffer::<8>::new();
        assert_eq!(vb.avail_space(), 7);

        let mut out = [0u8; 8];
        assert_eq!(vb.write_bytes(&[1, 2, 3, 4, 5]).unwrap(), 5);
        assert_eq!(vb.pending_bytes(), 5);
        assert_eq!(vb.read_bytes(&mut out).unwrap(), 5);
        assert_eq!(&out[..5], &[1, 2, 3, 4, 5]);
        assert!(vb.is_empty());

        // Starts at slot 5, so this copy is split across the end of the array.
        assert_eq!(vb.write_bytes(&[10, 11, 12, 13, 14, 15]).unwrap(), 6);
        assert_eq!(vb.avail_space(), 1);
        // Only one byte fits; the write is partial.
        assert_eq!(vb.write_bytes(&[20, 21]).unwrap(), 1);
        assert_eq!(vb.avail_space(), 0);
        assert_eq!(vb.write_bytes(&[99]).unwrap(), 0);

        assert_eq!(vb.pending_bytes(), 7);
        assert_eq!(vb.read_bytes(&mut out).unwrap(), 7);
        assert_eq!(&out[..7], &[10, 11, 12, 13, 14, 15, 20]);
        assert!(vb.is_empty());
    }

    /// Runs `NR_WR` writers (each writing `ITER * BS` copies of its own id byte)
    /// against `nr_readers` readers, then checks that every byte arrived exactly once.
    fn run_mt(nr_readers: usize) {
        const ITER: usize = 100;
        const BS: usize = 1;
        const NR_WR: usize = 8;

        let vb = VolatileBuffer::<2048>::new();
        let counts = [const { AtomicUsize::new(0) }; NR_WR];
        let done = AtomicBool::new(false);
        tracing::info!("starting mt pty test ({} readers)", nr_readers);

        // Reads one small chunk and tallies each byte under the writer id it encodes.
        let read_once = || -> usize {
            let mut buf = [0u8; 8];
            let len = vb.read_bytes(&mut buf).unwrap();
            for &b in &buf[..len] {
                counts[b as usize].fetch_add(1, Ordering::SeqCst);
            }
            len
        };

        let (vb, read_once, done) = (&vb, &read_once, &done);
        std::thread::scope(|scope| {
            let readers = (0..nr_readers)
                .map(move |_| {
                    scope.spawn(move || {
                        while !done.load(Ordering::SeqCst) {
                            read_once();
                        }
                        // All writers have finished; drain whatever is left.
                        while read_once() > 0 {}
                    })
                })
                .collect::<Vec<_>>();

            let writers = (0..NR_WR)
                .map(move |id| {
                    scope.spawn(move || {
                        let buf = [id as u8; BS];
                        for _ in 0..ITER {
                            // Writes may be partial or zero-length when the buffer is
                            // full; keep going until the whole chunk is in.
                            let mut rest = &buf[..];
                            while !rest.is_empty() {
                                let n = vb.write_bytes(rest).unwrap();
                                if n == 0 {
                                    std::hint::spin_loop();
                                }
                                rest = &rest[n..];
                            }
                        }
                    })
                })
                .collect::<Vec<_>>();

            for t in writers {
                t.join().unwrap();
            }
            done.store(true, Ordering::SeqCst);
            for t in readers {
                t.join().unwrap();
            }
        });

        let expected = ITER * BS;
        for (id, count) in counts.iter().enumerate() {
            assert_eq!(
                count.load(Ordering::SeqCst),
                expected,
                "writer {id}: wrong number of bytes received"
            );
        }
        assert!(vb.is_empty());
        tracing::info!("finished mt pty test");
    }

    #[test]
    pub fn test_mt() {
        run_mt(1);
    }

    /// Several readers race on `tail`; this must neither panic nor lose or duplicate bytes.
    #[test]
    pub fn test_mt_multi_reader() {
        run_mt(4);
    }
}