// virtio_gpu/transport.rs

1use core::{mem::size_of, ptr::NonNull};
2use std::sync::Arc;
3
4use twizzler_abi::{
5    device::{bus::pcie::PcieDeviceInfo, DeviceInterruptFlags},
6    syscall::{sys_thread_sync, ThreadSync},
7};
8use twizzler_driver::{bus::pcie::PcieCapability, device::Device};
9use virtio_drivers::{
10    transport::{pci::VirtioPciError, DeviceStatus, DeviceType, InterruptStatus, Transport},
11    Error,
12};
13use virtio_pcie::{VirtioIsrStatus, VirtioPciNotifyCap};
14use volatile::{map_field, VolatilePtr};
15use zerocopy::{FromBytes, Immutable, IntoBytes};
16
17pub mod virtio_pcie;
18use self::virtio_pcie::{CfgLocation, VirtioCfgType, VirtioCommonCfg, VirtioPciCap};
19
20pub struct TwizzlerTransport {
21    device: Arc<Device>,
22
23    common_cfg: CfgLocation,
24
25    notify_region: CfgLocation,
26    notify_offset_multiplier: u32,
27
28    isr_status: CfgLocation,
29
30    config_space: Option<NonNull<[u32]>>,
31}
32
33unsafe impl Send for TwizzlerTransport {}
34
35fn get_device() -> Option<Device> {
36    let devices = devmgr::get_devices(devmgr::DriverSpec {
37        supported: devmgr::Supported::PcieClass(3, 0, 0),
38    })
39    .ok()?;
40
41    for device in &devices {
42        let device = Device::new(device.id).ok();
43        if let Some(device) = device {
44            let info = unsafe { device.get_info::<PcieDeviceInfo>(0).unwrap() };
45            if info.get_data().vendor_id == 0x1AF4 && info.get_data().device_id == 0x1050 {
46                tracing::info!(
47                    "found virtio-gpu controller at {:02x}:{:02x}.{:02x}",
48                    info.get_data().bus_nr,
49                    info.get_data().dev_nr,
50                    info.get_data().func_nr
51                );
52                return Some(device);
53            }
54        }
55    }
56    None
57}
58
59impl TwizzlerTransport {
60    pub fn new(notifier: std::sync::mpsc::Sender<Option<()>>) -> Result<Self, VirtioPciError> {
61        let device = Arc::new(get_device().expect("failed to find virtio-gpu device"));
62        let int = device.allocate_interrupt(0).unwrap();
63        device
64            .repr_mut()
65            .register_interrupt(int.1 as usize, int.0, DeviceInterruptFlags::empty());
66        let int_device = device.clone();
67
68        let info = unsafe { device.get_info::<PcieDeviceInfo>(0).unwrap() };
69        if info.get_data().vendor_id != 0x1AF4 {
70            tracing::trace!("Vendor ID: {}", info.get_data().vendor_id);
71            return Err(VirtioPciError::InvalidVendorId(info.get_data().vendor_id));
72        }
73
74        let mut common_cfg = None;
75        let mut notify_region = None;
76        let mut notify_offset_multiplier = 0;
77        let mut isr_status = None;
78        let mut config_space = None;
79
80        let mm = device.find_mmio_bar(0xff).unwrap();
81        for cap in device.pcie_capabilities(&mm).unwrap() {
82            let off: usize = match cap {
83                PcieCapability::VendorSpecific(x) => x,
84                _ => {
85                    continue;
86                }
87            };
88
89            let mut virtio_cfg_ref = unsafe { mm.get_mmio_offset_mut::<VirtioPciCap>(off) };
90            let virtio_cfg = virtio_cfg_ref.as_mut_ptr();
91            match map_field!(virtio_cfg.cfg_type).read() {
92                VirtioCfgType::CommonCfg if common_cfg.is_none() => {
93                    tracing::trace!(
94                        "Common CFG found! Bar: {:?}, Offset: {:?}, Length: {:?}",
95                        map_field!(virtio_cfg.bar).read(),
96                        map_field!(virtio_cfg.offset).read(),
97                        map_field!(virtio_cfg.length).read()
98                    );
99                    common_cfg = Some(CfgLocation {
100                        bar: map_field!(virtio_cfg.bar).read() as usize,
101                        offset: map_field!(virtio_cfg.offset).read() as usize,
102                        length: map_field!(virtio_cfg.length).read() as usize,
103                    });
104                }
105                VirtioCfgType::NotifyCfg if notify_region.is_none() => {
106                    let mut notify_ref =
107                        unsafe { mm.get_mmio_offset_mut::<VirtioPciNotifyCap>(off) };
108                    let notify_cap = notify_ref.as_mut_ptr();
109                    notify_offset_multiplier = map_field!(notify_cap.notify_off_multiplier).read();
110                    tracing::trace!("Notify CFG found! Bar: {:?}, Offset: {:?}, Length: {:?}, Offset multiplier: {:?}", map_field!(virtio_cfg.bar).read(), map_field!(virtio_cfg.offset).read(), map_field!(virtio_cfg.length).read(), notify_offset_multiplier);
111                    notify_region = Some(CfgLocation {
112                        bar: map_field!(virtio_cfg.bar).read() as usize,
113                        offset: map_field!(virtio_cfg.offset).read() as usize,
114                        length: map_field!(virtio_cfg.length).read() as usize,
115                    })
116                }
117
118                VirtioCfgType::IsrCfg if isr_status.is_none() => {
119                    tracing::trace!(
120                        "ISR CFG found! Bar: {:?}, Offset: {:?}, Length: {:?}",
121                        map_field!(virtio_cfg.bar).read(),
122                        map_field!(virtio_cfg.offset).read(),
123                        map_field!(virtio_cfg.length).read()
124                    );
125                    isr_status = Some(CfgLocation {
126                        bar: map_field!(virtio_cfg.bar).read() as usize,
127                        offset: map_field!(virtio_cfg.offset).read() as usize,
128                        length: map_field!(virtio_cfg.length).read() as usize,
129                    });
130                }
131
132                VirtioCfgType::DeviceCfg if config_space.is_none() => {
133                    tracing::trace!(
134                        "Device CFG found! Bar: {:?}, Offset: {:?}, Length: {:?}",
135                        map_field!(virtio_cfg.bar).read(),
136                        map_field!(virtio_cfg.offset).read(),
137                        map_field!(virtio_cfg.length).read()
138                    );
139                    let bar_num = map_field!(virtio_cfg.bar).read() as usize;
140                    let bar = device.find_mmio_bar(bar_num).unwrap();
141                    let mut start = unsafe {
142                        bar.get_mmio_offset_mut::<u32>(map_field!(virtio_cfg.offset).read() as usize)
143                    };
144                    let len = map_field!(virtio_cfg.length).read() as usize;
145
146                    let ptr = unsafe {
147                        NonNull::from(core::slice::from_raw_parts_mut(
148                            start.as_mut_ptr().as_raw_ptr().as_ptr(),
149                            len,
150                        ))
151                    };
152                    config_space = Some(ptr);
153                }
154                _ => {}
155            }
156        }
157        let common_cfg = common_cfg.ok_or(VirtioPciError::MissingCommonConfig)?;
158        let notify_region = notify_region.ok_or(VirtioPciError::MissingNotifyConfig)?;
159        let isr_status = isr_status.ok_or(VirtioPciError::MissingIsrConfig)?;
160
161        let _thread = std::thread::spawn(move || loop {
162            //for _ in 0..10 {
163            //    for _ in 0..100 {
164            if int_device.repr().check_for_interrupt(0).is_some() {
165                //tracing::info!("virtio int");
166                let _ = notifier.send(None);
167            }
168            //      core::hint::spin_loop();
169            //  }
170            // twizzler_abi::syscall::sys_thread_yield();
171            // }
172
173            if int_device.repr().check_for_interrupt(0).is_none() {
174                let int_sleep = int_device.repr().setup_interrupt_sleep(0);
175                //tracing::info!("virtio int: sleep");
176                let _ = sys_thread_sync(&mut [ThreadSync::new_sleep(int_sleep)], None);
177            }
178        });
179
180        Ok(Self {
181            device,
182            common_cfg,
183            notify_region,
184            notify_offset_multiplier,
185            isr_status,
186            config_space,
187        })
188    }
189}
190
191impl Transport for TwizzlerTransport {
192    fn device_type(&self) -> DeviceType {
193        device_type(
194            unsafe { self.device.get_info::<PcieDeviceInfo>(0) }
195                .unwrap()
196                .get_data()
197                .device_id,
198        )
199    }
200
201    fn read_device_features(&mut self) -> u64 {
202        let bar = self.device.find_mmio_bar(self.common_cfg.bar).unwrap();
203        let mut reference =
204            unsafe { bar.get_mmio_offset_mut::<VirtioCommonCfg>(self.common_cfg.offset) };
205        let ptr = reference.as_mut_ptr();
206
207        map_field!(ptr.device_feature_select).write(0);
208        let mut device_feature_bits = map_field!(ptr.device_feature).read() as u64;
209        map_field!(ptr.device_feature_select).write(1);
210        device_feature_bits |= (map_field!(ptr.device_feature).read() as u64) << 32;
211        device_feature_bits
212    }
213
214    fn write_driver_features(&mut self, driver_features: u64) {
215        let bar = self.device.find_mmio_bar(self.common_cfg.bar).unwrap();
216        let mut reference =
217            unsafe { bar.get_mmio_offset_mut::<VirtioCommonCfg>(self.common_cfg.offset) };
218        let ptr = reference.as_mut_ptr();
219
220        map_field!(ptr.driver_feature_select).write(0);
221        map_field!(ptr.driver_feature).write(driver_features as u32);
222        map_field!(ptr.driver_feature_select).write(1);
223        map_field!(ptr.driver_feature).write((driver_features >> 32) as u32);
224    }
225
226    fn max_queue_size(&mut self, queue: u16) -> u32 {
227        let bar = self.device.find_mmio_bar(self.common_cfg.bar).unwrap();
228        let mut reference =
229            unsafe { bar.get_mmio_offset_mut::<VirtioCommonCfg>(self.common_cfg.offset) };
230        let ptr = reference.as_mut_ptr();
231
232        map_field!(ptr.queue_select).write(queue);
233        map_field!(ptr.queue_size).read().into()
234    }
235
236    fn notify(&mut self, queue: u16) {
237        let bar = self.device.find_mmio_bar(self.common_cfg.bar).unwrap();
238        let mut reference =
239            unsafe { bar.get_mmio_offset_mut::<VirtioCommonCfg>(self.common_cfg.offset) };
240        let ptr = reference.as_mut_ptr();
241
242        map_field!(ptr.queue_select).write(queue);
243
244        let queue_notify_off = map_field!(ptr.queue_notify_off).read();
245
246        let offset_bytes = queue_notify_off as usize * self.notify_offset_multiplier as usize;
247        let index = offset_bytes / size_of::<u16>();
248
249        let notify_bar = self.device.find_mmio_bar(self.notify_region.bar).unwrap();
250        let start = unsafe {
251            notify_bar
252                .get_mmio_offset_mut::<u16>(self.notify_region.offset as usize)
253                .as_mut_ptr()
254                .as_raw_ptr()
255                .as_ptr()
256        };
257
258        let notify_ptr = unsafe {
259            VolatilePtr::new(NonNull::from(core::slice::from_raw_parts_mut(
260                start,
261                self.notify_region.length as usize,
262            )))
263        };
264
265        let to_write = notify_ptr.index(index);
266        to_write.write(queue);
267    }
268
269    fn get_status(&self) -> virtio_drivers::transport::DeviceStatus {
270        let bar = self.device.find_mmio_bar(self.common_cfg.bar).unwrap();
271        let mut reference =
272            unsafe { bar.get_mmio_offset_mut::<VirtioCommonCfg>(self.common_cfg.offset) };
273        let ptr = reference.as_mut_ptr();
274
275        let status = map_field!(ptr.device_status).read();
276        DeviceStatus::from_bits_truncate(status.into())
277    }
278
279    fn set_status(&mut self, status: virtio_drivers::transport::DeviceStatus) {
280        let bar = self.device.find_mmio_bar(self.common_cfg.bar).unwrap();
281        let mut reference =
282            unsafe { bar.get_mmio_offset_mut::<VirtioCommonCfg>(self.common_cfg.offset) };
283        let ptr = reference.as_mut_ptr();
284
285        map_field!(ptr.device_status).write(status.bits() as u8);
286    }
287
288    fn set_guest_page_size(&mut self, _guest_page_size: u32) {
289        // No-op, the PCI transport doesn't care.
290    }
291
292    fn requires_legacy_layout(&self) -> bool {
293        false
294    }
295
296    fn queue_set(
297        &mut self,
298        queue: u16,
299        size: u32,
300        descriptors: virtio_drivers::PhysAddr,
301        driver_area: virtio_drivers::PhysAddr,
302        device_area: virtio_drivers::PhysAddr,
303    ) {
304        let bar = self.device.find_mmio_bar(self.common_cfg.bar).unwrap();
305        let mut reference =
306            unsafe { bar.get_mmio_offset_mut::<VirtioCommonCfg>(self.common_cfg.offset) };
307        let ptr = reference.as_mut_ptr();
308
309        map_field!(ptr.config_msix_vector).write(0);
310        map_field!(ptr.queue_select).write(queue);
311        map_field!(ptr.queue_size).write(size as u16);
312        map_field!(ptr.queue_desc).write(descriptors.try_into().unwrap());
313        map_field!(ptr.queue_driver).write(driver_area.try_into().unwrap());
314        map_field!(ptr.queue_device).write(device_area.try_into().unwrap());
315        map_field!(ptr.queue_msix_vector).write(0);
316        map_field!(ptr.queue_enable).write(1);
317    }
318
319    fn queue_unset(&mut self, _queue: u16) {
320        // The VirtIO spec doesn't allow queues to be unset once they have been set up for the PCI
321        // transport, so this is a no-op.
322    }
323
324    fn queue_used(&mut self, queue: u16) -> bool {
325        let bar = self.device.find_mmio_bar(self.common_cfg.bar).unwrap();
326        let mut reference =
327            unsafe { bar.get_mmio_offset_mut::<VirtioCommonCfg>(self.common_cfg.offset) };
328        let ptr = reference.as_mut_ptr();
329
330        map_field!(ptr.queue_select).write(queue);
331        map_field!(ptr.queue_enable).read() == 1
332    }
333
334    fn ack_interrupt(&mut self) -> InterruptStatus {
335        let bar = self.device.find_mmio_bar(self.isr_status.bar).unwrap();
336        let mut reference =
337            unsafe { bar.get_mmio_offset_mut::<VirtioIsrStatus>(self.isr_status.offset) };
338        let ptr = reference.as_mut_ptr();
339
340        let status = ptr.read();
341        if status & 0x3 != 0 {
342            InterruptStatus::all()
343        } else {
344            InterruptStatus::empty()
345        }
346    }
347
348    fn read_config_generation(&self) -> u32 {
349        0
350    }
351
352    fn read_config_space<T: FromBytes + IntoBytes>(
353        &self,
354        offset: usize,
355    ) -> virtio_drivers::Result<T> {
356        if let Some(config_space) = self.config_space {
357            if offset > config_space.len() * size_of::<u32>() {
358                Err(Error::ConfigSpaceTooSmall)
359            } else {
360                // TODO: Use NonNull::as_non_null_ptr once it is stable.
361                let result = unsafe {
362                    config_space
363                        .as_ptr()
364                        .cast::<T>()
365                        .byte_add(offset)
366                        .read_volatile()
367                };
368                Ok(result)
369            }
370        } else {
371            Err(Error::ConfigSpaceMissing)
372        }
373    }
374
375    fn write_config_space<T: IntoBytes + Immutable>(
376        &mut self,
377        offset: usize,
378        value: T,
379    ) -> virtio_drivers::Result<()> {
380        if let Some(config_space) = self.config_space {
381            if offset > config_space.len() * size_of::<u32>() {
382                Err(Error::ConfigSpaceTooSmall)
383            } else {
384                // TODO: Use NonNull::as_non_null_ptr once it is stable.
385                unsafe {
386                    config_space
387                        .as_ptr()
388                        .cast::<T>()
389                        .byte_add(offset)
390                        .write_volatile(value)
391                };
392                Ok(())
393            }
394        } else {
395            Err(Error::ConfigSpaceMissing)
396        }
397    }
398}
399
400impl Drop for TwizzlerTransport {
401    fn drop(&mut self) {
402        // Disable the device
403        self.set_status(DeviceStatus::empty());
404        while self.get_status() != DeviceStatus::empty() {
405            // Wait for the device to acknowledge the status change
406            core::hint::spin_loop();
407        }
408    }
409}
410
411/// The offset to add to a VirtIO device ID to get the corresponding PCI device ID.
412const PCI_DEVICE_ID_OFFSET: u16 = 0x1040;
413
414const TRANSITIONAL_NETWORK: u16 = 0x1000;
415const TRANSITIONAL_BLOCK: u16 = 0x1001;
416const TRANSITIONAL_MEMORY_BALLOONING: u16 = 0x1002;
417const TRANSITIONAL_CONSOLE: u16 = 0x1003;
418const TRANSITIONAL_SCSI_HOST: u16 = 0x1004;
419const TRANSITIONAL_ENTROPY_SOURCE: u16 = 0x1005;
420const TRANSITIONAL_9P_TRANSPORT: u16 = 0x1009;
421
422fn device_type(pci_device_id: u16) -> DeviceType {
423    match pci_device_id {
424        TRANSITIONAL_NETWORK => DeviceType::Network,
425        TRANSITIONAL_BLOCK => DeviceType::Block,
426        TRANSITIONAL_MEMORY_BALLOONING => DeviceType::MemoryBalloon,
427        TRANSITIONAL_CONSOLE => DeviceType::Console,
428        TRANSITIONAL_SCSI_HOST => DeviceType::ScsiHost,
429        TRANSITIONAL_ENTROPY_SOURCE => DeviceType::EntropySource,
430        TRANSITIONAL_9P_TRANSPORT => DeviceType::_9P,
431        id if id >= PCI_DEVICE_ID_OFFSET => {
432            DeviceType::try_from(id - PCI_DEVICE_ID_OFFSET).unwrap()
433        }
434        _ => todo!(),
435    }
436}