wde_wgpu/passes/compute_pass.rs
1//! Compute pass abstraction.
2use crate::{compute_pipeline::ComputePipeline, instance::RenderError, pipelines::BindGroup};
3use wde_logger::prelude::*;
4
5/// RAII wrapper around `wgpu::ComputePass` created by `CommandBuffer::create_compute_pass`.
6///
7/// # Example
8/// ```rust,no_run
9/// use wde_wgpu::{
10/// command_buffer::CommandBuffer,
11/// compute_pass::WComputePass,
12/// compute_pipeline::ComputePipeline,
13/// instance::{RenderError, RenderInstanceData},
14/// };
15///
16/// let mut cmd = CommandBuffer::new(instance, "compute-frame");
17/// {
18/// let mut pass: WComputePass = cmd.create_compute_pass("cull");
19/// pass
20/// .set_pipeline(pipeline)?
21/// .set_bind_group(0, bind_group)
22/// .set_push_constants(bytemuck::bytes_of(&[4u32, 8u32]));
23/// pass.dispatch(4, 1, 1)?;
24/// }
25/// cmd.submit(instance);
26/// ```
27pub struct WComputePass<'a> {
28 pub label: String,
29 compute_pass: wgpu::ComputePass<'a>,
30 pipeline_set: bool
31}
32
33impl std::fmt::Debug for WComputePass<'_> {
34 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35 f.debug_struct("ComputePass")
36 .field("label", &self.label)
37 .finish()
38 }
39}
40
41impl<'a> WComputePass<'a> {
42 /// Create a new compute pass.
43 ///
44 /// # Arguments
45 ///
46 /// * `label` - The label of the compute pass.
47 /// * `compute_pass` - The compute pass to create.
48 pub fn new(label: &str, compute_pass: wgpu::ComputePass<'a>) -> Self {
49 event!(LogLevel::TRACE, "Creating a new compute pass {}.", label);
50 Self {
51 label: label.to_string(),
52 compute_pass,
53 pipeline_set: false
54 }
55 }
56
57 /// Set the pipeline of the compute pass.
58 /// The bind groups of the pipeline are also set.
59 ///
60 /// # Arguments
61 ///
62 /// * `pipeline` - The pipeline to set.
63 ///
64 /// # Errors
65 ///
66 /// * `RenderError::PipelineNotInitialized` - The pipeline is not initialized.
67 pub fn set_pipeline(
68 &mut self,
69 pipeline: &'a ComputePipeline
70 ) -> Result<&mut Self, RenderError> {
71 if pipeline.get_pipeline().is_none() {
72 error!(pipeline.label, "Pipeline is not created yet.");
73 return Err(RenderError::PipelineNotInitialized);
74 }
75
76 // Set pipeline
77 self.compute_pass
78 .set_pipeline(pipeline.get_pipeline().as_ref().unwrap());
79 self.pipeline_set = true;
80 Ok(self)
81 }
82
83 /// Set push constants of the compute pass.
84 ///
85 /// # Arguments
86 ///
87 /// * `data` - The data to set.
88 pub fn set_push_constants(&mut self, data: &[u8]) -> &mut Self {
89 self.compute_pass.set_push_constants(0, data);
90 self
91 }
92
93 /// Set a bind group of the compute pass at a binding.
94 ///
95 /// # Arguments
96 ///
97 /// * `binding` - The binding of the bind group.
98 /// * `bind_group` - The bind group to set.
99 pub fn set_bind_group(&mut self, binding: u32, bind_group: &'a BindGroup) -> &mut Self {
100 debug_assert!(
101 bind_group.0.is_some(),
102 "Bind group {} is not created yet.",
103 binding
104 );
105 self.compute_pass
106 .set_bind_group(binding, bind_group.0.as_ref().unwrap(), &[]);
107 self
108 }
109
110 /// Dispatch the compute pass.
111 ///
112 /// # Arguments
113 ///
114 /// * `x` - The x dimension.
115 /// * `y` - The y dimension.
116 /// * `z` - The z dimension.
117 ///
118 /// # Errors
119 ///
120 /// * `RenderError::PipelineNotSet` - The pipeline is not set.
121 pub fn dispatch(&mut self, x: u32, y: u32, z: u32) -> Result<(), RenderError> {
122 if !self.pipeline_set {
123 error!(self.label, "Pipeline is not set.");
124 return Err(RenderError::PipelineNotSet);
125 }
126
127 // Dispatch
128 event!(
129 LogLevel::TRACE,
130 "Dispatching compute pipeline {} with dimension ({}, {}, {}).",
131 self.label,
132 x,
133 y,
134 z
135 );
136 self.compute_pass.dispatch_workgroups(x, y, z);
137 Ok(())
138 }
139}