wde_wgpu/passes/
compute_pass.rs

1//! Compute pass abstraction.
2use crate::{compute_pipeline::ComputePipeline, instance::RenderError, pipelines::BindGroup};
3use wde_logger::prelude::*;
4
5/// RAII wrapper around `wgpu::ComputePass` created by `CommandBuffer::create_compute_pass`.
6///
7/// # Example
8/// ```rust,no_run
9/// use wde_wgpu::{
10///     command_buffer::CommandBuffer,
11///     compute_pass::WComputePass,
12///     compute_pipeline::ComputePipeline,
13///     instance::{RenderError, RenderInstanceData},
14/// };
15///
16/// let mut cmd = CommandBuffer::new(instance, "compute-frame");
17/// {
18///     let mut pass: WComputePass = cmd.create_compute_pass("cull");
19///     pass
20///         .set_pipeline(pipeline)?
21///         .set_bind_group(0, bind_group)
22///         .set_push_constants(bytemuck::bytes_of(&[4u32, 8u32]));
23///     pass.dispatch(4, 1, 1)?;
24/// }
25/// cmd.submit(instance);
26/// ```
27pub struct WComputePass<'a> {
28    pub label: String,
29    compute_pass: wgpu::ComputePass<'a>,
30    pipeline_set: bool
31}
32
33impl std::fmt::Debug for WComputePass<'_> {
34    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35        f.debug_struct("ComputePass")
36            .field("label", &self.label)
37            .finish()
38    }
39}
40
41impl<'a> WComputePass<'a> {
42    /// Create a new compute pass.
43    ///
44    /// # Arguments
45    ///
46    /// * `label` - The label of the compute pass.
47    /// * `compute_pass` - The compute pass to create.
48    pub fn new(label: &str, compute_pass: wgpu::ComputePass<'a>) -> Self {
49        event!(LogLevel::TRACE, "Creating a new compute pass {}.", label);
50        Self {
51            label: label.to_string(),
52            compute_pass,
53            pipeline_set: false
54        }
55    }
56
57    /// Set the pipeline of the compute pass.
58    /// The bind groups of the pipeline are also set.
59    ///
60    /// # Arguments
61    ///
62    /// * `pipeline` - The pipeline to set.
63    ///
64    /// # Errors
65    ///
66    /// * `RenderError::PipelineNotInitialized` - The pipeline is not initialized.
67    pub fn set_pipeline(
68        &mut self,
69        pipeline: &'a ComputePipeline
70    ) -> Result<&mut Self, RenderError> {
71        if pipeline.get_pipeline().is_none() {
72            error!(pipeline.label, "Pipeline is not created yet.");
73            return Err(RenderError::PipelineNotInitialized);
74        }
75
76        // Set pipeline
77        self.compute_pass
78            .set_pipeline(pipeline.get_pipeline().as_ref().unwrap());
79        self.pipeline_set = true;
80        Ok(self)
81    }
82
83    /// Set push constants of the compute pass.
84    ///
85    /// # Arguments
86    ///
87    /// * `data` - The data to set.
88    pub fn set_push_constants(&mut self, data: &[u8]) -> &mut Self {
89        self.compute_pass.set_push_constants(0, data);
90        self
91    }
92
93    /// Set a bind group of the compute pass at a binding.
94    ///
95    /// # Arguments
96    ///
97    /// * `binding` - The binding of the bind group.
98    /// * `bind_group` - The bind group to set.
99    pub fn set_bind_group(&mut self, binding: u32, bind_group: &'a BindGroup) -> &mut Self {
100        debug_assert!(
101            bind_group.0.is_some(),
102            "Bind group {} is not created yet.",
103            binding
104        );
105        self.compute_pass
106            .set_bind_group(binding, bind_group.0.as_ref().unwrap(), &[]);
107        self
108    }
109
110    /// Dispatch the compute pass.
111    ///
112    /// # Arguments
113    ///
114    /// * `x` - The x dimension.
115    /// * `y` - The y dimension.
116    /// * `z` - The z dimension.
117    ///
118    /// # Errors
119    ///
120    /// * `RenderError::PipelineNotSet` - The pipeline is not set.
121    pub fn dispatch(&mut self, x: u32, y: u32, z: u32) -> Result<(), RenderError> {
122        if !self.pipeline_set {
123            error!(self.label, "Pipeline is not set.");
124            return Err(RenderError::PipelineNotSet);
125        }
126
127        // Dispatch
128        event!(
129            LogLevel::TRACE,
130            "Dispatching compute pipeline {} with dimension ({}, {}, {}).",
131            self.label,
132            x,
133            y,
134            z
135        );
136        self.compute_pass.dispatch_workgroups(x, y, z);
137        Ok(())
138    }
139}