wde_wgpu/passes/compute_pass.rs
1//! Compute pass abstraction.
2use crate::{compute_pipeline::ComputePipeline, instance::RenderError, pipelines::BindGroup};
3use wde_logger::prelude::*;
4
5/// RAII wrapper around `wgpu::ComputePass` created by `CommandBuffer::create_compute_pass`.
6///
7/// # Example
8/// ```rust,no_run
9/// use wde_wgpu::{
10/// command_buffer::CommandBuffer,
11/// compute_pass::WComputePass,
12/// compute_pipeline::ComputePipeline,
13/// instance::{RenderError, RenderInstanceData},
14/// };
15///
16/// let mut cmd = CommandBuffer::new(instance, "compute-frame");
17/// {
18/// let mut pass: WComputePass = cmd.create_compute_pass("cull");
19/// pass
20/// .set_pipeline(pipeline)?
21/// .set_bind_group(0, bind_group)
22/// .set_push_constants(bytemuck::bytes_of(&[4u32, 8u32]));
23/// pass.dispatch(4, 1, 1)?;
24/// }
25/// cmd.submit(instance);
26/// ```
27pub struct WComputePass<'a> {
28 pub label: String,
29 compute_pass: wgpu::ComputePass<'a>,
30 pipeline_set: bool
31}
32
33impl std::fmt::Debug for WComputePass<'_> {
34 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35 f.debug_struct("ComputePass")
36 .field("label", &self.label)
37 .finish()
38 }
39}
40
41impl<'a> WComputePass<'a> {
42 /// Create a new compute pass.
43 ///
44 /// # Arguments
45 ///
46 /// * `label` - The label of the compute pass.
47 /// * `compute_pass` - The compute pass to create.
48 pub fn new(label: &str, compute_pass: wgpu::ComputePass<'a>) -> Self {
49 Self {
50 label: label.to_string(),
51 compute_pass,
52 pipeline_set: false
53 }
54 }
55
56 /// Set the pipeline of the compute pass.
57 /// The bind groups of the pipeline are also set.
58 ///
59 /// # Arguments
60 ///
61 /// * `pipeline` - The pipeline to set.
62 ///
63 /// # Errors
64 ///
65 /// * `RenderError::PipelineNotInitialized` - The pipeline is not initialized.
66 pub fn set_pipeline(
67 &mut self,
68 pipeline: &'a ComputePipeline
69 ) -> Result<&mut Self, RenderError> {
70 if pipeline.get_pipeline().is_none() {
71 error!(pipeline.label, "Pipeline is not created yet.");
72 return Err(RenderError::PipelineNotInitialized);
73 }
74
75 // Set pipeline
76 self.compute_pass
77 .set_pipeline(pipeline.get_pipeline().as_ref().unwrap());
78 self.pipeline_set = true;
79 Ok(self)
80 }
81
82 /// Set push constants of the compute pass.
83 ///
84 /// # Arguments
85 ///
86 /// * `data` - The data to set.
87 pub fn set_push_constants(&mut self, data: &[u8]) -> &mut Self {
88 self.compute_pass.set_push_constants(0, data);
89 self
90 }
91
92 /// Set a bind group of the compute pass at a binding.
93 ///
94 /// # Arguments
95 ///
96 /// * `binding` - The binding of the bind group.
97 /// * `bind_group` - The bind group to set.
98 pub fn set_bind_group(&mut self, binding: u32, bind_group: &'a BindGroup) -> &mut Self {
99 debug_assert!(
100 bind_group.0.is_some(),
101 "Bind group {} is not created yet.",
102 binding
103 );
104 self.compute_pass
105 .set_bind_group(binding, bind_group.0.as_ref().unwrap(), &[]);
106 self
107 }
108
109 /// Dispatch the compute pass.
110 ///
111 /// # Arguments
112 ///
113 /// * `x` - The x dimension.
114 /// * `y` - The y dimension.
115 /// * `z` - The z dimension.
116 ///
117 /// # Errors
118 ///
119 /// * `RenderError::PipelineNotSet` - The pipeline is not set.
120 pub fn dispatch(&mut self, x: u32, y: u32, z: u32) -> Result<(), RenderError> {
121 if !self.pipeline_set {
122 error!(self.label, "Pipeline is not set.");
123 return Err(RenderError::PipelineNotSet);
124 }
125
126 // Dispatch
127 self.compute_pass.dispatch_workgroups(x, y, z);
128 Ok(())
129 }
130}