virtio_net/
transport.rs

1use core::{mem::size_of, ptr::NonNull};
2use std::sync::Arc;
3
4use twizzler_abi::{
5    device::{bus::pcie::PcieDeviceInfo, DeviceInterruptFlags},
6    syscall::ThreadSync,
7};
8use twizzler_driver::{bus::pcie::PcieCapability, device::Device};
9use virtio_drivers::{
10    transport::{pci::VirtioPciError, DeviceStatus, DeviceType, InterruptStatus, Transport},
11    Error,
12};
13use virtio_pcie::{VirtioIsrStatus, VirtioPciNotifyCap};
14use volatile::{map_field, VolatilePtr};
15use zerocopy::{FromBytes, Immutable, IntoBytes};
16
17pub mod virtio_pcie;
18use self::virtio_pcie::{CfgLocation, VirtioCfgType, VirtioCommonCfg, VirtioPciCap};
19
20pub struct TwizzlerTransport {
21    device: Arc<Device>,
22
23    common_cfg: CfgLocation,
24
25    notify_region: CfgLocation,
26    notify_offset_multiplier: u32,
27
28    isr_status: CfgLocation,
29
30    config_space: Option<NonNull<[u32]>>,
31}
32
33unsafe impl Send for TwizzlerTransport {}
34
35fn get_device() -> Option<Device> {
36    let devices = devmgr::enumerate_devices(devmgr::DriverSpec {
37        supported: devmgr::Supported::PcieClass(2, 0, 0),
38    })
39    .ok()?;
40
41    for device in &devices {
42        let device = Device::new(device.id).ok();
43        if let Some(device) = device {
44            let info = unsafe { device.get_info::<PcieDeviceInfo>(0).unwrap() };
45            if info.get_data().vendor_id == 0x1AF4 {
46                tracing::info!(
47                    "found virtio-net controller at {:02x}:{:02x}.{:02x}",
48                    info.get_data().bus_nr,
49                    info.get_data().dev_nr,
50                    info.get_data().func_nr
51                );
52                return Some(device);
53            }
54        }
55    }
56    None
57}
58
59impl TwizzlerTransport {
60    pub fn new() -> Result<Self, VirtioPciError> {
61        let device = Arc::new(get_device().expect("failed to find virtio-net device"));
62        let int = device.allocate_interrupt(0).unwrap();
63        device
64            .repr_mut()
65            .register_interrupt(int.1 as usize, int.0, DeviceInterruptFlags::empty());
66
67        let info = unsafe { device.get_info::<PcieDeviceInfo>(0).unwrap() };
68        if info.get_data().vendor_id != 0x1AF4 {
69            tracing::trace!("Vendor ID: {}", info.get_data().vendor_id);
70            return Err(VirtioPciError::InvalidVendorId(info.get_data().vendor_id));
71        }
72
73        let mut common_cfg = None;
74        let mut notify_region = None;
75        let mut notify_offset_multiplier = 0;
76        let mut isr_status = None;
77        let mut config_space = None;
78
79        let mm = device.find_mmio_bar(0xff).unwrap();
80        for cap in device.pcie_capabilities(&mm).unwrap() {
81            let off: usize = match cap {
82                PcieCapability::VendorSpecific(x) => x,
83                _ => {
84                    continue;
85                }
86            };
87
88            let mut virtio_cfg_ref = unsafe { mm.get_mmio_offset_mut::<VirtioPciCap>(off) };
89            let virtio_cfg = virtio_cfg_ref.as_mut_ptr();
90            match map_field!(virtio_cfg.cfg_type).read() {
91                VirtioCfgType::CommonCfg if common_cfg.is_none() => {
92                    tracing::trace!(
93                        "Common CFG found! Bar: {:?}, Offset: {:?}, Length: {:?}",
94                        map_field!(virtio_cfg.bar).read(),
95                        map_field!(virtio_cfg.offset).read(),
96                        map_field!(virtio_cfg.length).read()
97                    );
98                    common_cfg = Some(CfgLocation {
99                        bar: map_field!(virtio_cfg.bar).read() as usize,
100                        offset: map_field!(virtio_cfg.offset).read() as usize,
101                        length: map_field!(virtio_cfg.length).read() as usize,
102                    });
103                }
104                VirtioCfgType::NotifyCfg if notify_region.is_none() => {
105                    let mut notify_ref =
106                        unsafe { mm.get_mmio_offset_mut::<VirtioPciNotifyCap>(off) };
107                    let notify_cap = notify_ref.as_mut_ptr();
108                    notify_offset_multiplier = map_field!(notify_cap.notify_off_multiplier).read();
109                    tracing::trace!("Notify CFG found! Bar: {:?}, Offset: {:?}, Length: {:?}, Offset multiplier: {:?}", map_field!(virtio_cfg.bar).read(), map_field!(virtio_cfg.offset).read(), map_field!(virtio_cfg.length).read(), notify_offset_multiplier);
110                    notify_region = Some(CfgLocation {
111                        bar: map_field!(virtio_cfg.bar).read() as usize,
112                        offset: map_field!(virtio_cfg.offset).read() as usize,
113                        length: map_field!(virtio_cfg.length).read() as usize,
114                    })
115                }
116
117                VirtioCfgType::IsrCfg if isr_status.is_none() => {
118                    tracing::trace!(
119                        "ISR CFG found! Bar: {:?}, Offset: {:?}, Length: {:?}",
120                        map_field!(virtio_cfg.bar).read(),
121                        map_field!(virtio_cfg.offset).read(),
122                        map_field!(virtio_cfg.length).read()
123                    );
124                    isr_status = Some(CfgLocation {
125                        bar: map_field!(virtio_cfg.bar).read() as usize,
126                        offset: map_field!(virtio_cfg.offset).read() as usize,
127                        length: map_field!(virtio_cfg.length).read() as usize,
128                    });
129                }
130
131                VirtioCfgType::DeviceCfg if config_space.is_none() => {
132                    tracing::trace!(
133                        "Device CFG found! Bar: {:?}, Offset: {:?}, Length: {:?}",
134                        map_field!(virtio_cfg.bar).read(),
135                        map_field!(virtio_cfg.offset).read(),
136                        map_field!(virtio_cfg.length).read()
137                    );
138                    let bar_num = map_field!(virtio_cfg.bar).read() as usize;
139                    let bar = device.find_mmio_bar(bar_num).unwrap();
140                    let mut start = unsafe {
141                        bar.get_mmio_offset_mut::<u32>(map_field!(virtio_cfg.offset).read() as usize)
142                    };
143                    let len = map_field!(virtio_cfg.length).read() as usize;
144
145                    let ptr = unsafe {
146                        NonNull::from(core::slice::from_raw_parts_mut(
147                            start.as_mut_ptr().as_raw_ptr().as_ptr(),
148                            len,
149                        ))
150                    };
151                    config_space = Some(ptr);
152                }
153                _ => {}
154            }
155        }
156        let common_cfg = common_cfg.ok_or(VirtioPciError::MissingCommonConfig)?;
157        let notify_region = notify_region.ok_or(VirtioPciError::MissingNotifyConfig)?;
158        let isr_status = isr_status.ok_or(VirtioPciError::MissingIsrConfig)?;
159
160        Ok(Self {
161            device,
162            common_cfg,
163            notify_region,
164            notify_offset_multiplier,
165            isr_status,
166            config_space,
167        })
168    }
169
170    pub fn has_work(&self) -> bool {
171        self.device.repr().check_for_interrupt(0).is_some()
172    }
173
174    pub fn get_sleep(&self) -> ThreadSync {
175        ThreadSync::new_sleep(self.device.repr().setup_interrupt_sleep(0))
176    }
177
178    pub fn device(&self) -> Arc<Device> {
179        self.device.clone()
180    }
181}
182
183impl Transport for TwizzlerTransport {
184    fn device_type(&self) -> DeviceType {
185        device_type(
186            unsafe { self.device.get_info::<PcieDeviceInfo>(0) }
187                .unwrap()
188                .get_data()
189                .device_id,
190        )
191    }
192
193    fn read_device_features(&mut self) -> u64 {
194        let bar = self.device.find_mmio_bar(self.common_cfg.bar).unwrap();
195        let mut reference =
196            unsafe { bar.get_mmio_offset_mut::<VirtioCommonCfg>(self.common_cfg.offset) };
197        let ptr = reference.as_mut_ptr();
198
199        map_field!(ptr.device_feature_select).write(0);
200        let mut device_feature_bits = map_field!(ptr.device_feature).read() as u64;
201        map_field!(ptr.device_feature_select).write(1);
202        device_feature_bits |= (map_field!(ptr.device_feature).read() as u64) << 32;
203        device_feature_bits
204    }
205
206    fn write_driver_features(&mut self, driver_features: u64) {
207        let bar = self.device.find_mmio_bar(self.common_cfg.bar).unwrap();
208        let mut reference =
209            unsafe { bar.get_mmio_offset_mut::<VirtioCommonCfg>(self.common_cfg.offset) };
210        let ptr = reference.as_mut_ptr();
211
212        map_field!(ptr.driver_feature_select).write(0);
213        map_field!(ptr.driver_feature).write(driver_features as u32);
214        map_field!(ptr.driver_feature_select).write(1);
215        map_field!(ptr.driver_feature).write((driver_features >> 32) as u32);
216    }
217
218    fn max_queue_size(&mut self, queue: u16) -> u32 {
219        let bar = self.device.find_mmio_bar(self.common_cfg.bar).unwrap();
220        let mut reference =
221            unsafe { bar.get_mmio_offset_mut::<VirtioCommonCfg>(self.common_cfg.offset) };
222        let ptr = reference.as_mut_ptr();
223
224        map_field!(ptr.queue_select).write(queue);
225        map_field!(ptr.queue_size).read().into()
226    }
227
228    fn notify(&mut self, queue: u16) {
229        let bar = self.device.find_mmio_bar(self.common_cfg.bar).unwrap();
230        let mut reference =
231            unsafe { bar.get_mmio_offset_mut::<VirtioCommonCfg>(self.common_cfg.offset) };
232        let ptr = reference.as_mut_ptr();
233
234        map_field!(ptr.queue_select).write(queue);
235
236        let queue_notify_off = map_field!(ptr.queue_notify_off).read();
237
238        let offset_bytes = queue_notify_off as usize * self.notify_offset_multiplier as usize;
239        let index = offset_bytes / size_of::<u16>();
240
241        let notify_bar = self.device.find_mmio_bar(self.notify_region.bar).unwrap();
242        let start = unsafe {
243            notify_bar
244                .get_mmio_offset_mut::<u16>(self.notify_region.offset as usize)
245                .as_mut_ptr()
246                .as_raw_ptr()
247                .as_ptr()
248        };
249
250        let notify_ptr = unsafe {
251            VolatilePtr::new(NonNull::from(core::slice::from_raw_parts_mut(
252                start,
253                self.notify_region.length as usize,
254            )))
255        };
256
257        let to_write = notify_ptr.index(index);
258        to_write.write(queue);
259    }
260
261    fn get_status(&self) -> virtio_drivers::transport::DeviceStatus {
262        let bar = self.device.find_mmio_bar(self.common_cfg.bar).unwrap();
263        let mut reference =
264            unsafe { bar.get_mmio_offset_mut::<VirtioCommonCfg>(self.common_cfg.offset) };
265        let ptr = reference.as_mut_ptr();
266
267        let status = map_field!(ptr.device_status).read();
268        DeviceStatus::from_bits_truncate(status.into())
269    }
270
271    fn set_status(&mut self, status: virtio_drivers::transport::DeviceStatus) {
272        let bar = self.device.find_mmio_bar(self.common_cfg.bar).unwrap();
273        let mut reference =
274            unsafe { bar.get_mmio_offset_mut::<VirtioCommonCfg>(self.common_cfg.offset) };
275        let ptr = reference.as_mut_ptr();
276
277        map_field!(ptr.device_status).write(status.bits() as u8);
278    }
279
280    fn set_guest_page_size(&mut self, _guest_page_size: u32) {
281        // No-op, the PCI transport doesn't care.
282    }
283
284    fn requires_legacy_layout(&self) -> bool {
285        false
286    }
287
288    fn queue_set(
289        &mut self,
290        queue: u16,
291        size: u32,
292        descriptors: virtio_drivers::PhysAddr,
293        driver_area: virtio_drivers::PhysAddr,
294        device_area: virtio_drivers::PhysAddr,
295    ) {
296        let bar = self.device.find_mmio_bar(self.common_cfg.bar).unwrap();
297        let mut reference =
298            unsafe { bar.get_mmio_offset_mut::<VirtioCommonCfg>(self.common_cfg.offset) };
299        let ptr = reference.as_mut_ptr();
300
301        map_field!(ptr.config_msix_vector).write(0);
302        map_field!(ptr.queue_select).write(queue);
303        map_field!(ptr.queue_size).write(size as u16);
304        map_field!(ptr.queue_desc).write(descriptors.try_into().unwrap());
305        map_field!(ptr.queue_driver).write(driver_area.try_into().unwrap());
306        map_field!(ptr.queue_device).write(device_area.try_into().unwrap());
307        map_field!(ptr.queue_msix_vector).write(0);
308        map_field!(ptr.queue_enable).write(1);
309    }
310
311    fn queue_unset(&mut self, _queue: u16) {
312        // The VirtIO spec doesn't allow queues to be unset once they have been set up for the PCI
313        // transport, so this is a no-op.
314    }
315
316    fn queue_used(&mut self, queue: u16) -> bool {
317        let bar = self.device.find_mmio_bar(self.common_cfg.bar).unwrap();
318        let mut reference =
319            unsafe { bar.get_mmio_offset_mut::<VirtioCommonCfg>(self.common_cfg.offset) };
320        let ptr = reference.as_mut_ptr();
321
322        map_field!(ptr.queue_select).write(queue);
323        map_field!(ptr.queue_enable).read() == 1
324    }
325
326    fn ack_interrupt(&mut self) -> InterruptStatus {
327        let bar = self.device.find_mmio_bar(self.isr_status.bar).unwrap();
328        let mut reference =
329            unsafe { bar.get_mmio_offset_mut::<VirtioIsrStatus>(self.isr_status.offset) };
330        let ptr = reference.as_mut_ptr();
331
332        let status = ptr.read();
333        InterruptStatus::from_bits_truncate(status as u32)
334    }
335
336    fn read_config_generation(&self) -> u32 {
337        let bar = self.device.find_mmio_bar(self.common_cfg.bar).unwrap();
338        let mut reference =
339            unsafe { bar.get_mmio_offset_mut::<VirtioCommonCfg>(self.common_cfg.offset) };
340        let ptr = reference.as_mut_ptr();
341        map_field!(ptr.config_generation).read() as u32
342    }
343
344    fn read_config_space<T: FromBytes + IntoBytes>(
345        &self,
346        offset: usize,
347    ) -> virtio_drivers::Result<T> {
348        let config_space = self.config_space.ok_or(Error::ConfigSpaceMissing)?;
349        let config_len = config_space.len() * size_of::<u32>();
350        if offset + size_of::<T>() > config_len {
351            return Err(Error::ConfigSpaceTooSmall);
352        }
353        let ptr = config_space.as_ptr() as *const u8;
354        let src = unsafe { core::slice::from_raw_parts(ptr.add(offset), size_of::<T>()) };
355        T::read_from_bytes(src).map_err(|_| Error::ConfigSpaceTooSmall)
356    }
357
358    fn write_config_space<T: IntoBytes + Immutable>(
359        &mut self,
360        offset: usize,
361        value: T,
362    ) -> virtio_drivers::Result<()> {
363        let config_space = self.config_space.ok_or(Error::ConfigSpaceMissing)?;
364        let config_len = config_space.len() * size_of::<u32>();
365        if offset + size_of::<T>() > config_len {
366            return Err(Error::ConfigSpaceTooSmall);
367        }
368        let ptr = config_space.as_ptr() as *mut u8;
369        let dst = unsafe { core::slice::from_raw_parts_mut(ptr.add(offset), size_of::<T>()) };
370        value.write_to(dst).map_err(|_| Error::ConfigSpaceTooSmall)
371    }
372}
373
374impl Drop for TwizzlerTransport {
375    fn drop(&mut self) {
376        // Disable the device
377        self.set_status(DeviceStatus::empty());
378        while self.get_status() != DeviceStatus::empty() {
379            // Wait for the device to acknowledge the status change
380            core::hint::spin_loop();
381        }
382    }
383}
384
385/// The offset to add to a VirtIO device ID to get the corresponding PCI device ID.
386const PCI_DEVICE_ID_OFFSET: u16 = 0x1040;
387
388const TRANSITIONAL_NETWORK: u16 = 0x1000;
389const TRANSITIONAL_BLOCK: u16 = 0x1001;
390const TRANSITIONAL_MEMORY_BALLOONING: u16 = 0x1002;
391const TRANSITIONAL_CONSOLE: u16 = 0x1003;
392const TRANSITIONAL_SCSI_HOST: u16 = 0x1004;
393const TRANSITIONAL_ENTROPY_SOURCE: u16 = 0x1005;
394const TRANSITIONAL_9P_TRANSPORT: u16 = 0x1009;
395
396fn device_type(pci_device_id: u16) -> DeviceType {
397    match pci_device_id {
398        TRANSITIONAL_NETWORK => DeviceType::Network,
399        TRANSITIONAL_BLOCK => DeviceType::Block,
400        TRANSITIONAL_MEMORY_BALLOONING => DeviceType::MemoryBalloon,
401        TRANSITIONAL_CONSOLE => DeviceType::Console,
402        TRANSITIONAL_SCSI_HOST => DeviceType::ScsiHost,
403        TRANSITIONAL_ENTROPY_SOURCE => DeviceType::EntropySource,
404        TRANSITIONAL_9P_TRANSPORT => DeviceType::_9P,
405        id if id >= PCI_DEVICE_ID_OFFSET => {
406            DeviceType::try_from(id - PCI_DEVICE_ID_OFFSET).unwrap_or(DeviceType::Network)
407        }
408        _ => DeviceType::Network, // fallback; shouldn't be reached for our use case
409    }
410}