1use core::{mem::size_of, ptr::NonNull};
2use std::sync::Arc;
3
4use twizzler_abi::{
5 device::{bus::pcie::PcieDeviceInfo, DeviceInterruptFlags},
6 syscall::{sys_thread_sync, ThreadSync},
7};
8use twizzler_driver::{bus::pcie::PcieCapability, device::Device};
9use virtio_drivers::{
10 transport::{pci::VirtioPciError, DeviceStatus, DeviceType, InterruptStatus, Transport},
11 Error,
12};
13use virtio_pcie::{VirtioIsrStatus, VirtioPciNotifyCap};
14use volatile::{map_field, VolatilePtr};
15use zerocopy::{FromBytes, Immutable, IntoBytes};
16
17pub mod virtio_pcie;
18use self::virtio_pcie::{CfgLocation, VirtioCfgType, VirtioCommonCfg, VirtioPciCap};
19
/// A virtio [`Transport`] backed by a Twizzler PCIe [`Device`].
///
/// The locations of the virtio PCI capability structures are discovered once
/// during construction and stored as (bar, offset, length) triples; the MMIO
/// is looked up again through `device` on each access.
pub struct TwizzlerTransport {
    /// Underlying PCIe device; also shared (via `Arc`) with the
    /// interrupt-forwarding thread spawned in `new`.
    device: Arc<Device>,

    /// Location of the virtio common configuration structure.
    common_cfg: CfgLocation,

    /// Location of the queue-notify region.
    notify_region: CfgLocation,
    /// Byte multiplier applied to each queue's `queue_notify_off`.
    notify_offset_multiplier: u32,

    /// Location of the ISR status register.
    isr_status: CfgLocation,

    /// Raw pointer to the device-specific config space, if the device
    /// advertises one. The slice length is in `u32` units.
    config_space: Option<NonNull<[u32]>>,
}
32
33unsafe impl Send for TwizzlerTransport {}
34
35fn get_device() -> Option<Device> {
36 let devices = devmgr::get_devices(devmgr::DriverSpec {
37 supported: devmgr::Supported::PcieClass(3, 0, 0),
38 })
39 .ok()?;
40
41 for device in &devices {
42 let device = Device::new(device.id).ok();
43 if let Some(device) = device {
44 let info = unsafe { device.get_info::<PcieDeviceInfo>(0).unwrap() };
45 if info.get_data().vendor_id == 0x1AF4 && info.get_data().device_id == 0x1050 {
46 tracing::info!(
47 "found virtio-gpu controller at {:02x}:{:02x}.{:02x}",
48 info.get_data().bus_nr,
49 info.get_data().dev_nr,
50 info.get_data().func_nr
51 );
52 return Some(device);
53 }
54 }
55 }
56 None
57}
58
impl TwizzlerTransport {
    /// Creates a transport for the first virtio-gpu function found on the bus.
    ///
    /// Scans the device's vendor-specific PCIe capabilities for the virtio
    /// common-config, notify, ISR, and device-config structures, recording
    /// where each lives. Also spawns a background thread that forwards every
    /// device interrupt as a `None` message on `notifier`.
    ///
    /// # Errors
    /// Returns `InvalidVendorId` if the device is not a virtio (0x1AF4)
    /// function, or a `Missing*Config` error if a required capability is
    /// absent.
    ///
    /// # Panics
    /// Panics if no virtio-gpu device exists or if interrupt allocation /
    /// info queries fail.
    pub fn new(notifier: std::sync::mpsc::Sender<Option<()>>) -> Result<Self, VirtioPciError> {
        let device = Arc::new(get_device().expect("failed to find virtio-gpu device"));
        // Allocate interrupt slot 0 and register it with the device.
        // `int` is (vector, number) — presumably; confirm against
        // twizzler-driver's allocate_interrupt docs.
        let int = device.allocate_interrupt(0).unwrap();
        device
            .repr_mut()
            .register_interrupt(int.1 as usize, int.0, DeviceInterruptFlags::empty());
        // Clone for the interrupt thread spawned below.
        let int_device = device.clone();

        let info = unsafe { device.get_info::<PcieDeviceInfo>(0).unwrap() };
        if info.get_data().vendor_id != 0x1AF4 {
            tracing::trace!("Vendor ID: {}", info.get_data().vendor_id);
            return Err(VirtioPciError::InvalidVendorId(info.get_data().vendor_id));
        }

        // Capability-scan results; first match of each type wins.
        let mut common_cfg = None;
        let mut notify_region = None;
        let mut notify_offset_multiplier = 0;
        let mut isr_status = None;
        let mut config_space = None;

        // 0xff: presumably a sentinel selecting the config-space MMIO mapping
        // rather than a numbered BAR — TODO confirm against twizzler-driver.
        let mm = device.find_mmio_bar(0xff).unwrap();
        for cap in device.pcie_capabilities(&mm).unwrap() {
            // Virtio structures are advertised via vendor-specific caps only.
            let off: usize = match cap {
                PcieCapability::VendorSpecific(x) => x,
                _ => {
                    continue;
                }
            };

            let mut virtio_cfg_ref = unsafe { mm.get_mmio_offset_mut::<VirtioPciCap>(off) };
            let virtio_cfg = virtio_cfg_ref.as_mut_ptr();
            match map_field!(virtio_cfg.cfg_type).read() {
                VirtioCfgType::CommonCfg if common_cfg.is_none() => {
                    tracing::trace!(
                        "Common CFG found! Bar: {:?}, Offset: {:?}, Length: {:?}",
                        map_field!(virtio_cfg.bar).read(),
                        map_field!(virtio_cfg.offset).read(),
                        map_field!(virtio_cfg.length).read()
                    );
                    common_cfg = Some(CfgLocation {
                        bar: map_field!(virtio_cfg.bar).read() as usize,
                        offset: map_field!(virtio_cfg.offset).read() as usize,
                        length: map_field!(virtio_cfg.length).read() as usize,
                    });
                }
                VirtioCfgType::NotifyCfg if notify_region.is_none() => {
                    // The notify capability is larger than the generic cap:
                    // re-read it as VirtioPciNotifyCap to get the multiplier.
                    let mut notify_ref =
                        unsafe { mm.get_mmio_offset_mut::<VirtioPciNotifyCap>(off) };
                    let notify_cap = notify_ref.as_mut_ptr();
                    notify_offset_multiplier = map_field!(notify_cap.notify_off_multiplier).read();
                    tracing::trace!("Notify CFG found! Bar: {:?}, Offset: {:?}, Length: {:?}, Offset multiplier: {:?}", map_field!(virtio_cfg.bar).read(), map_field!(virtio_cfg.offset).read(), map_field!(virtio_cfg.length).read(), notify_offset_multiplier);
                    notify_region = Some(CfgLocation {
                        bar: map_field!(virtio_cfg.bar).read() as usize,
                        offset: map_field!(virtio_cfg.offset).read() as usize,
                        length: map_field!(virtio_cfg.length).read() as usize,
                    })
                }

                VirtioCfgType::IsrCfg if isr_status.is_none() => {
                    tracing::trace!(
                        "ISR CFG found! Bar: {:?}, Offset: {:?}, Length: {:?}",
                        map_field!(virtio_cfg.bar).read(),
                        map_field!(virtio_cfg.offset).read(),
                        map_field!(virtio_cfg.length).read()
                    );
                    isr_status = Some(CfgLocation {
                        bar: map_field!(virtio_cfg.bar).read() as usize,
                        offset: map_field!(virtio_cfg.offset).read() as usize,
                        length: map_field!(virtio_cfg.length).read() as usize,
                    });
                }

                VirtioCfgType::DeviceCfg if config_space.is_none() => {
                    tracing::trace!(
                        "Device CFG found! Bar: {:?}, Offset: {:?}, Length: {:?}",
                        map_field!(virtio_cfg.bar).read(),
                        map_field!(virtio_cfg.offset).read(),
                        map_field!(virtio_cfg.length).read()
                    );
                    // Device config is kept as a raw pointer into the BAR's
                    // MMIO so read/write_config_space can access it directly.
                    let bar_num = map_field!(virtio_cfg.bar).read() as usize;
                    let bar = device.find_mmio_bar(bar_num).unwrap();
                    let mut start = unsafe {
                        bar.get_mmio_offset_mut::<u32>(map_field!(virtio_cfg.offset).read() as usize)
                    };
                    // NOTE(review): `length` here is presumably in bytes, but
                    // it is used as the u32-element count of the slice below —
                    // confirm the intended unit.
                    let len = map_field!(virtio_cfg.length).read() as usize;

                    let ptr = unsafe {
                        NonNull::from(core::slice::from_raw_parts_mut(
                            start.as_mut_ptr().as_raw_ptr().as_ptr(),
                            len,
                        ))
                    };
                    config_space = Some(ptr);
                }
                _ => {}
            }
        }
        // Common, notify, and ISR structures are mandatory; device config is
        // optional (some virtio devices have none).
        let common_cfg = common_cfg.ok_or(VirtioPciError::MissingCommonConfig)?;
        let notify_region = notify_region.ok_or(VirtioPciError::MissingNotifyConfig)?;
        let isr_status = isr_status.ok_or(VirtioPciError::MissingIsrConfig)?;

        // Interrupt-forwarding loop: drain any pending interrupt into the
        // channel, then sleep on the interrupt word until the next one.
        // NOTE(review): an interrupt arriving between the two checks is
        // forwarded twice-checked but not lost — the sleep setup re-tests it.
        let _thread = std::thread::spawn(move || loop {
            if int_device.repr().check_for_interrupt(0).is_some() {
                let _ = notifier.send(None);
            }
            if int_device.repr().check_for_interrupt(0).is_none() {
                let int_sleep = int_device.repr().setup_interrupt_sleep(0);
                let _ = sys_thread_sync(&mut [ThreadSync::new_sleep(int_sleep)], None);
            }
        });

        Ok(Self {
            device,
            common_cfg,
            notify_region,
            notify_offset_multiplier,
            isr_status,
            config_space,
        })
    }
}
190
191impl Transport for TwizzlerTransport {
192 fn device_type(&self) -> DeviceType {
193 device_type(
194 unsafe { self.device.get_info::<PcieDeviceInfo>(0) }
195 .unwrap()
196 .get_data()
197 .device_id,
198 )
199 }
200
201 fn read_device_features(&mut self) -> u64 {
202 let bar = self.device.find_mmio_bar(self.common_cfg.bar).unwrap();
203 let mut reference =
204 unsafe { bar.get_mmio_offset_mut::<VirtioCommonCfg>(self.common_cfg.offset) };
205 let ptr = reference.as_mut_ptr();
206
207 map_field!(ptr.device_feature_select).write(0);
208 let mut device_feature_bits = map_field!(ptr.device_feature).read() as u64;
209 map_field!(ptr.device_feature_select).write(1);
210 device_feature_bits |= (map_field!(ptr.device_feature).read() as u64) << 32;
211 device_feature_bits
212 }
213
214 fn write_driver_features(&mut self, driver_features: u64) {
215 let bar = self.device.find_mmio_bar(self.common_cfg.bar).unwrap();
216 let mut reference =
217 unsafe { bar.get_mmio_offset_mut::<VirtioCommonCfg>(self.common_cfg.offset) };
218 let ptr = reference.as_mut_ptr();
219
220 map_field!(ptr.driver_feature_select).write(0);
221 map_field!(ptr.driver_feature).write(driver_features as u32);
222 map_field!(ptr.driver_feature_select).write(1);
223 map_field!(ptr.driver_feature).write((driver_features >> 32) as u32);
224 }
225
226 fn max_queue_size(&mut self, queue: u16) -> u32 {
227 let bar = self.device.find_mmio_bar(self.common_cfg.bar).unwrap();
228 let mut reference =
229 unsafe { bar.get_mmio_offset_mut::<VirtioCommonCfg>(self.common_cfg.offset) };
230 let ptr = reference.as_mut_ptr();
231
232 map_field!(ptr.queue_select).write(queue);
233 map_field!(ptr.queue_size).read().into()
234 }
235
236 fn notify(&mut self, queue: u16) {
237 let bar = self.device.find_mmio_bar(self.common_cfg.bar).unwrap();
238 let mut reference =
239 unsafe { bar.get_mmio_offset_mut::<VirtioCommonCfg>(self.common_cfg.offset) };
240 let ptr = reference.as_mut_ptr();
241
242 map_field!(ptr.queue_select).write(queue);
243
244 let queue_notify_off = map_field!(ptr.queue_notify_off).read();
245
246 let offset_bytes = queue_notify_off as usize * self.notify_offset_multiplier as usize;
247 let index = offset_bytes / size_of::<u16>();
248
249 let notify_bar = self.device.find_mmio_bar(self.notify_region.bar).unwrap();
250 let start = unsafe {
251 notify_bar
252 .get_mmio_offset_mut::<u16>(self.notify_region.offset as usize)
253 .as_mut_ptr()
254 .as_raw_ptr()
255 .as_ptr()
256 };
257
258 let notify_ptr = unsafe {
259 VolatilePtr::new(NonNull::from(core::slice::from_raw_parts_mut(
260 start,
261 self.notify_region.length as usize,
262 )))
263 };
264
265 let to_write = notify_ptr.index(index);
266 to_write.write(queue);
267 }
268
269 fn get_status(&self) -> virtio_drivers::transport::DeviceStatus {
270 let bar = self.device.find_mmio_bar(self.common_cfg.bar).unwrap();
271 let mut reference =
272 unsafe { bar.get_mmio_offset_mut::<VirtioCommonCfg>(self.common_cfg.offset) };
273 let ptr = reference.as_mut_ptr();
274
275 let status = map_field!(ptr.device_status).read();
276 DeviceStatus::from_bits_truncate(status.into())
277 }
278
279 fn set_status(&mut self, status: virtio_drivers::transport::DeviceStatus) {
280 let bar = self.device.find_mmio_bar(self.common_cfg.bar).unwrap();
281 let mut reference =
282 unsafe { bar.get_mmio_offset_mut::<VirtioCommonCfg>(self.common_cfg.offset) };
283 let ptr = reference.as_mut_ptr();
284
285 map_field!(ptr.device_status).write(status.bits() as u8);
286 }
287
288 fn set_guest_page_size(&mut self, _guest_page_size: u32) {
289 }
291
292 fn requires_legacy_layout(&self) -> bool {
293 false
294 }
295
296 fn queue_set(
297 &mut self,
298 queue: u16,
299 size: u32,
300 descriptors: virtio_drivers::PhysAddr,
301 driver_area: virtio_drivers::PhysAddr,
302 device_area: virtio_drivers::PhysAddr,
303 ) {
304 let bar = self.device.find_mmio_bar(self.common_cfg.bar).unwrap();
305 let mut reference =
306 unsafe { bar.get_mmio_offset_mut::<VirtioCommonCfg>(self.common_cfg.offset) };
307 let ptr = reference.as_mut_ptr();
308
309 map_field!(ptr.config_msix_vector).write(0);
310 map_field!(ptr.queue_select).write(queue);
311 map_field!(ptr.queue_size).write(size as u16);
312 map_field!(ptr.queue_desc).write(descriptors.try_into().unwrap());
313 map_field!(ptr.queue_driver).write(driver_area.try_into().unwrap());
314 map_field!(ptr.queue_device).write(device_area.try_into().unwrap());
315 map_field!(ptr.queue_msix_vector).write(0);
316 map_field!(ptr.queue_enable).write(1);
317 }
318
319 fn queue_unset(&mut self, _queue: u16) {
320 }
323
324 fn queue_used(&mut self, queue: u16) -> bool {
325 let bar = self.device.find_mmio_bar(self.common_cfg.bar).unwrap();
326 let mut reference =
327 unsafe { bar.get_mmio_offset_mut::<VirtioCommonCfg>(self.common_cfg.offset) };
328 let ptr = reference.as_mut_ptr();
329
330 map_field!(ptr.queue_select).write(queue);
331 map_field!(ptr.queue_enable).read() == 1
332 }
333
334 fn ack_interrupt(&mut self) -> InterruptStatus {
335 let bar = self.device.find_mmio_bar(self.isr_status.bar).unwrap();
336 let mut reference =
337 unsafe { bar.get_mmio_offset_mut::<VirtioIsrStatus>(self.isr_status.offset) };
338 let ptr = reference.as_mut_ptr();
339
340 let status = ptr.read();
341 if status & 0x3 != 0 {
342 InterruptStatus::all()
343 } else {
344 InterruptStatus::empty()
345 }
346 }
347
348 fn read_config_generation(&self) -> u32 {
349 0
350 }
351
352 fn read_config_space<T: FromBytes + IntoBytes>(
353 &self,
354 offset: usize,
355 ) -> virtio_drivers::Result<T> {
356 if let Some(config_space) = self.config_space {
357 if offset > config_space.len() * size_of::<u32>() {
358 Err(Error::ConfigSpaceTooSmall)
359 } else {
360 let result = unsafe {
362 config_space
363 .as_ptr()
364 .cast::<T>()
365 .byte_add(offset)
366 .read_volatile()
367 };
368 Ok(result)
369 }
370 } else {
371 Err(Error::ConfigSpaceMissing)
372 }
373 }
374
375 fn write_config_space<T: IntoBytes + Immutable>(
376 &mut self,
377 offset: usize,
378 value: T,
379 ) -> virtio_drivers::Result<()> {
380 if let Some(config_space) = self.config_space {
381 if offset > config_space.len() * size_of::<u32>() {
382 Err(Error::ConfigSpaceTooSmall)
383 } else {
384 unsafe {
386 config_space
387 .as_ptr()
388 .cast::<T>()
389 .byte_add(offset)
390 .write_volatile(value)
391 };
392 Ok(())
393 }
394 } else {
395 Err(Error::ConfigSpaceMissing)
396 }
397 }
398}
399
400impl Drop for TwizzlerTransport {
401 fn drop(&mut self) {
402 self.set_status(DeviceStatus::empty());
404 while self.get_status() != DeviceStatus::empty() {
405 core::hint::spin_loop();
407 }
408 }
409}
410
/// Modern virtio PCI functions use device ID `0x1040 + virtio device type`.
const PCI_DEVICE_ID_OFFSET: u16 = 0x1040;

// Transitional (legacy) virtio PCI device IDs, which predate the
// `0x1040 + type` scheme and must therefore be matched individually.
const TRANSITIONAL_NETWORK: u16 = 0x1000;
const TRANSITIONAL_BLOCK: u16 = 0x1001;
const TRANSITIONAL_MEMORY_BALLOONING: u16 = 0x1002;
const TRANSITIONAL_CONSOLE: u16 = 0x1003;
const TRANSITIONAL_SCSI_HOST: u16 = 0x1004;
const TRANSITIONAL_ENTROPY_SOURCE: u16 = 0x1005;
const TRANSITIONAL_9P_TRANSPORT: u16 = 0x1009;
421
422fn device_type(pci_device_id: u16) -> DeviceType {
423 match pci_device_id {
424 TRANSITIONAL_NETWORK => DeviceType::Network,
425 TRANSITIONAL_BLOCK => DeviceType::Block,
426 TRANSITIONAL_MEMORY_BALLOONING => DeviceType::MemoryBalloon,
427 TRANSITIONAL_CONSOLE => DeviceType::Console,
428 TRANSITIONAL_SCSI_HOST => DeviceType::ScsiHost,
429 TRANSITIONAL_ENTROPY_SOURCE => DeviceType::EntropySource,
430 TRANSITIONAL_9P_TRANSPORT => DeviceType::_9P,
431 id if id >= PCI_DEVICE_ID_OFFSET => {
432 DeviceType::try_from(id - PCI_DEVICE_ID_OFFSET).unwrap()
433 }
434 _ => todo!(),
435 }
436}