Skip to main content

wde_wgpu/passes/
compute_pass.rs

1//! Compute pass abstraction.
2use crate::{compute_pipeline::ComputePipeline, instance::RenderError, pipelines::BindGroup};
3use wde_logger::prelude::*;
4
5/// RAII wrapper around `wgpu::ComputePass` created by `CommandBuffer::create_compute_pass`.
6///
7/// # Example
8/// ```rust,no_run
9/// use wde_wgpu::{
10///     command_buffer::CommandBuffer,
11///     compute_pass::WComputePass,
12///     compute_pipeline::ComputePipeline,
13///     instance::{RenderError, RenderInstanceData},
14/// };
15///
16/// let mut cmd = CommandBuffer::new(instance, "compute-frame");
17/// {
18///     let mut pass: WComputePass = cmd.create_compute_pass("cull");
19///     pass
20///         .set_pipeline(pipeline)?
21///         .set_bind_group(0, bind_group)
22///         .set_push_constants(bytemuck::bytes_of(&[4u32, 8u32]));
23///     pass.dispatch(4, 1, 1)?;
24/// }
25/// cmd.submit(instance);
26/// ```
27pub struct WComputePass<'a> {
28    pub label: String,
29    compute_pass: wgpu::ComputePass<'a>,
30    pipeline_set: bool
31}
32
33impl std::fmt::Debug for WComputePass<'_> {
34    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35        f.debug_struct("ComputePass")
36            .field("label", &self.label)
37            .finish()
38    }
39}
40
41impl<'a> WComputePass<'a> {
42    /// Create a new compute pass.
43    ///
44    /// # Arguments
45    ///
46    /// * `label` - The label of the compute pass.
47    /// * `compute_pass` - The compute pass to create.
48    pub fn new(label: &str, compute_pass: wgpu::ComputePass<'a>) -> Self {
49        Self {
50            label: label.to_string(),
51            compute_pass,
52            pipeline_set: false
53        }
54    }
55
56    /// Set the pipeline of the compute pass.
57    /// The bind groups of the pipeline are also set.
58    ///
59    /// # Arguments
60    ///
61    /// * `pipeline` - The pipeline to set.
62    ///
63    /// # Errors
64    ///
65    /// * `RenderError::PipelineNotInitialized` - The pipeline is not initialized.
66    pub fn set_pipeline(
67        &mut self,
68        pipeline: &'a ComputePipeline
69    ) -> Result<&mut Self, RenderError> {
70        if pipeline.get_pipeline().is_none() {
71            error!(pipeline.label, "Pipeline is not created yet.");
72            return Err(RenderError::PipelineNotInitialized);
73        }
74
75        // Set pipeline
76        self.compute_pass
77            .set_pipeline(pipeline.get_pipeline().as_ref().unwrap());
78        self.pipeline_set = true;
79        Ok(self)
80    }
81
82    /// Set push constants of the compute pass.
83    ///
84    /// # Arguments
85    ///
86    /// * `data` - The data to set.
87    pub fn set_push_constants(&mut self, data: &[u8]) -> &mut Self {
88        self.compute_pass.set_push_constants(0, data);
89        self
90    }
91
92    /// Set a bind group of the compute pass at a binding.
93    ///
94    /// # Arguments
95    ///
96    /// * `binding` - The binding of the bind group.
97    /// * `bind_group` - The bind group to set.
98    pub fn set_bind_group(&mut self, binding: u32, bind_group: &'a BindGroup) -> &mut Self {
99        debug_assert!(
100            bind_group.0.is_some(),
101            "Bind group {} is not created yet.",
102            binding
103        );
104        self.compute_pass
105            .set_bind_group(binding, bind_group.0.as_ref().unwrap(), &[]);
106        self
107    }
108
109    /// Dispatch the compute pass.
110    ///
111    /// # Arguments
112    ///
113    /// * `x` - The x dimension.
114    /// * `y` - The y dimension.
115    /// * `z` - The z dimension.
116    ///
117    /// # Errors
118    ///
119    /// * `RenderError::PipelineNotSet` - The pipeline is not set.
120    pub fn dispatch(&mut self, x: u32, y: u32, z: u32) -> Result<(), RenderError> {
121        if !self.pipeline_set {
122            error!(self.label, "Pipeline is not set.");
123            return Err(RenderError::PipelineNotSet);
124        }
125
126        // Dispatch
127        self.compute_pass.dispatch_workgroups(x, y, z);
128        Ok(())
129    }
130}