1use std::{
2 cell::UnsafeCell,
3 sync::atomic::{AtomicU64, Ordering},
4};
5
6use twizzler_abi::syscall::{
7 ThreadSync, ThreadSyncFlags, ThreadSyncOp, ThreadSyncReference, ThreadSyncSleep, ThreadSyncWake,
8};
9
10pub struct VolatileBuffer<const N: usize> {
11 reserve: AtomicU64,
12 head: AtomicU64,
13 tail: AtomicU64,
14 buffer: UnsafeCell<[u8; N]>,
15}
16unsafe impl<const N: usize> Send for VolatileBuffer<N> {}
17unsafe impl<const N: usize> Sync for VolatileBuffer<N> {}
18
19impl<const N: usize> VolatileBuffer<N> {
20 pub fn new() -> Self {
21 Self {
22 buffer: UnsafeCell::new([0; N]),
23 head: AtomicU64::new(0),
24 tail: AtomicU64::new(0),
25 reserve: AtomicU64::new(0),
26 }
27 }
28
29 pub fn avail_space(&self) -> usize {
30 let tail = self.tail.load(Ordering::SeqCst);
31 let resv = self.reserve.load(Ordering::SeqCst);
32
33 (N - 1) - (resv - tail) as usize
34 }
35
36 pub fn pending_bytes(&self) -> usize {
37 let head = self.head.load(Ordering::SeqCst);
38 let tail = self.tail.load(Ordering::SeqCst);
39
40 (head - tail) as usize
41 }
42
43 pub fn is_empty(&self) -> bool {
44 let tail = self.tail.load(Ordering::SeqCst);
45 let head = self.head.load(Ordering::SeqCst);
46
47 head == tail
48 }
49
50 pub fn sync_for_pending_data(&self) -> ThreadSyncSleep {
51 let head = self.head.load(Ordering::SeqCst);
52 ThreadSyncSleep::new(
53 ThreadSyncReference::Virtual(&self.head),
54 head,
55 ThreadSyncOp::Equal,
56 ThreadSyncFlags::empty(),
57 )
58 }
59
60 pub fn sync_for_avail_space(&self) -> ThreadSyncSleep {
61 let tail = self.tail.load(Ordering::SeqCst);
62 ThreadSyncSleep::new(
63 ThreadSyncReference::Virtual(&self.tail),
64 tail,
65 ThreadSyncOp::Equal,
66 ThreadSyncFlags::empty(),
67 )
68 }
69
70 pub fn read_bytes(&self, mut buf: &mut [u8]) -> std::io::Result<usize> {
71 let mut count = 0;
72 while buf.len() > 0 {
73 let head = self.head.load(Ordering::SeqCst);
74 let tail = self.tail.load(Ordering::SeqCst);
75
76 if tail == head {
78 return Ok(count);
79 }
80
81 assert!(head >= tail);
82 let n = std::cmp::min(buf.len(), (head - tail) as usize);
83 let n = self.read_from_circle(&mut buf[0..n], tail as usize % N);
84
85 if self
86 .tail
87 .compare_exchange(tail, tail + n as u64, Ordering::SeqCst, Ordering::SeqCst)
88 .is_err()
89 {
90 continue;
91 }
92 self.do_wake(&self.tail);
93 buf = &mut buf[n..];
94 count += n;
95 }
96 Ok(count)
97 }
98
99 pub fn write_bytes(&self, mut buf: &[u8]) -> std::io::Result<usize> {
100 let mut count = 0;
101 while buf.len() > 0 {
102 let resv = self.reserve.load(Ordering::SeqCst);
103 let tail = self.tail.load(Ordering::SeqCst);
104
105 let avail = (N - 1) - (resv - tail) as usize;
106 if avail == 0 {
107 return Ok(count);
108 }
109
110 let n = std::cmp::min(buf.len(), avail);
111
112 if self
114 .reserve
115 .compare_exchange(resv, resv + n as u64, Ordering::SeqCst, Ordering::SeqCst)
116 .is_err()
117 {
118 continue;
120 }
121
122 while self.head.load(Ordering::SeqCst) != resv {
126 core::hint::spin_loop();
127 }
128
129 let n = self.write_to_circle(&buf[0..n], resv as usize % N);
130
131 let old_head = self.head.fetch_add(n as u64, Ordering::SeqCst);
132 if old_head != resv {
133 tracing::warn!("head incremented unexpectedly ({} != {})", old_head, resv);
134 }
135 self.do_wake(&self.head);
136
137 buf = &buf[n..];
138 count += n;
139 }
140 Ok(count)
141 }
142
143 fn get_buffer(&self) -> &[u8] {
144 let ptr = self.buffer.get();
145 unsafe { ptr.as_ref().unwrap() }
146 }
147
148 fn get_buffer_mut(&self) -> &mut [u8] {
149 let ptr = self.buffer.get();
150 unsafe { ptr.as_mut().unwrap() }
151 }
152
153 fn read_from_circle(&self, buf: &mut [u8], phase: usize) -> usize {
154 let buffer = self.get_buffer();
155 let (second, first) = buffer.split_at(phase);
156 let first_len = first.len().min(buf.len());
157 let second_len = second.len().min(buf.len().saturating_sub(first_len));
158
159 (&mut buf[0..first_len]).copy_from_slice(&first[0..first_len]);
160 (&mut buf[first_len..(first_len + second_len)]).copy_from_slice(&second[0..second_len]);
161 return first_len + second_len;
162 }
163
164 fn write_to_circle(&self, buf: &[u8], phase: usize) -> usize {
165 let buffer = self.get_buffer_mut();
166 let (second, first) = buffer.split_at_mut(phase);
167 let first_len = first.len().min(buf.len());
168 let second_len = second.len().min(buf.len().saturating_sub(first_len));
169
170 (&mut first[0..first_len]).copy_from_slice(&buf[0..first_len]);
171 (&mut second[0..second_len]).copy_from_slice(&buf[first_len..(first_len + second_len)]);
172 return first_len + second_len;
173 }
174
175 fn do_wake(&self, ptr: &AtomicU64) {
176 let _ = twizzler_abi::syscall::sys_thread_sync(
177 &mut [ThreadSync::new_wake(ThreadSyncWake::new(
178 ThreadSyncReference::Virtual(ptr),
179 usize::MAX,
180 ))],
181 None,
182 )
183 .inspect_err(|e| tracing::error!("failed to wake on volatile buffer: {}", e));
184 }
185}
186
#[cfg(test)]
pub mod tests {
    use std::sync::{
        Arc,
        atomic::{AtomicBool, AtomicUsize, Ordering},
    };

    use crate::buffer::VolatileBuffer;

    /// Round-trips 1 KiB chunks through a 2 KiB buffer many times, so the stream
    /// offsets wrap around the backing array repeatedly.
    #[test]
    pub fn test_basic() {
        let vb = VolatileBuffer::<2048>::new();

        let mut buf = [0; 1024];
        assert_eq!(vb.read_bytes(&mut buf).unwrap(), 0);

        for i in 0..100 {
            buf.fill(i);

            assert_eq!(vb.write_bytes(&buf).unwrap(), 1024);
            assert_eq!(vb.read_bytes(&mut buf).unwrap(), 1024);
            assert_eq!(buf, [i; 1024]);
        }
    }

    /// Checks capacity (N - 1), partial writes, the space/pending counters, and a read
    /// and a write that each straddle the end of the backing array.
    #[test]
    pub fn test_space_accounting() {
        let vb = VolatileBuffer::<16>::new();
        let data: Vec<u8> = (0..20).collect();

        assert!(vb.is_empty());
        assert_eq!(vb.avail_space(), 15);

        // Only N - 1 bytes fit.
        assert_eq!(vb.write_bytes(&data).unwrap(), 15);
        assert_eq!(vb.avail_space(), 0);
        assert_eq!(vb.pending_bytes(), 15);

        let mut head_part = [0u8; 10];
        assert_eq!(vb.read_bytes(&mut head_part).unwrap(), 10);
        assert_eq!(&head_part[..], &data[..10]);
        assert_eq!(vb.avail_space(), 10);

        // This write starts at array index 15 and wraps to index 0.
        assert_eq!(vb.write_bytes(&data[15..]).unwrap(), 5);
        let mut rest = [0u8; 16];
        assert_eq!(vb.read_bytes(&mut rest).unwrap(), 10);
        assert_eq!(&rest[..10], &data[10..20]);
        assert!(vb.is_empty());
        assert_eq!(vb.pending_bytes(), 0);
    }

    /// Several writers each push `ITER * BS` copies of their own id while one reader
    /// tallies what it receives. Every byte must arrive exactly once.
    #[test]
    pub fn test_mt() {
        const ITER: usize = 100;
        const BS: usize = 1;
        const NR_TH: usize = 8;
        std::thread::scope(|scope| {
            let vb = Arc::new(VolatileBuffer::<2048>::new());

            let counts = Arc::new([const { AtomicUsize::new(0) }; NR_TH]);
            let wcounts = counts.clone();
            let done = Arc::new(AtomicBool::new(false));
            tracing::info!("starting mt pty test");

            let reader = move |done: &AtomicBool, pty: &VolatileBuffer<_>| {
                let do_read = || -> usize {
                    let mut buf = [0; 8];
                    let len = pty.read_bytes(&mut buf).unwrap();
                    if len > 0 {
                        tracing::info!("rr: {} {}", len, buf[0]);
                    }
                    for b in &buf[0..len] {
                        let idx = *b as usize;
                        tracing::info!(" => {}", idx);
                        wcounts[idx].fetch_add(1, Ordering::SeqCst);
                    }
                    len
                };
                while !done.load(Ordering::SeqCst) {
                    do_read();
                }
                // Drain whatever the writers published before `done` was set.
                while do_read() > 0 {}
            };

            let writer = |pty: &VolatileBuffer<_>, c: u8| {
                for i in 0..ITER {
                    let buf = [c; BS];
                    tracing::info!("ww: {} {}", c, i);
                    let mut len = pty.write_bytes(&buf).unwrap();
                    while len == 0 {
                        tracing::info!("{} had to retry", c);
                        len = pty.write_bytes(&buf).unwrap();
                    }
                }
            };

            let wpty = vb.clone();
            let wdone = done.clone();
            let rd = scope.spawn(move || reader(&wdone, &*wpty));
            let ws = (0..NR_TH)
                .map(|i| {
                    let pty = vb.clone();
                    scope.spawn(move || writer(&pty, i as u8))
                })
                .collect::<Vec<_>>();

            for t in ws {
                t.join().unwrap();
            }
            done.store(true, Ordering::SeqCst);
            rd.join().unwrap();

            // Fail loudly on lost or duplicated bytes instead of only logging them.
            let expected = ITER * BS;
            for (writer_id, count) in counts.iter().enumerate() {
                let nr = count.load(Ordering::SeqCst);
                assert_eq!(nr, expected, "writer {writer_id}: wrong byte count");
            }
        });
        tracing::info!("finished mt pty test");
    }
}