wde_wgpu/pipelines/
compute_pipeline.rs

1//! Compute pipeline module.
2use futures_lite::future;
3use wde_logger::prelude::*;
4use wgpu::{BindGroupLayout, ShaderStages, naga};
5
6use crate::instance::{RenderError, RenderInstanceData};
7
8// Compute pipeline configuration
9struct ComputePipelineConfig {
10    push_constants: Vec<wgpu::PushConstantRange>,
11    bind_groups: Vec<wgpu::BindGroupLayout>,
12    shader: String
13}
14
15/// Create a compute pipeline from WGSL source and bind group layouts.
16///
17/// # Example
18/// ```rust,no_run
19/// use wde_wgpu::compute_pipeline::ComputePipeline;
20///
21/// let mut pipeline = ComputePipeline::new("particles");
22/// pipeline
23///     .set_shader(include_str!("../../../res/marching-cubes/spawn_terrain.comp.wgsl"))
24///     .set_bind_groups(layouts)
25///     .add_push_constant(16) // bytes starting at offset 0
26///     .init(instance)
27///     .expect("shader validated");
28/// assert!(pipeline.is_initialized());
29///
30/// let _wgpu_pipeline = pipeline.get_pipeline().unwrap();
31/// ```
32pub struct ComputePipeline {
33    /// Label for the compute pipeline
34    pub label: String,
35    /// The compute pipeline
36    pub pipeline: Option<wgpu::ComputePipeline>,
37    /// The pipeline layout
38    pub layout: Option<wgpu::PipelineLayout>,
39    /// Whether the compute pipeline has been initialized
40    pub is_initialized: bool,
41    /// Configuration of the compute pipeline
42    config: ComputePipelineConfig
43}
44
45impl ComputePipeline {
46    /// Create a new compute pipeline.
47    ///
48    /// # Arguments
49    ///
50    /// * `label` - Label of the render pipeline for debugging.
51    pub fn new(label: &str) -> Self {
52        Self {
53            label: label.to_string(),
54            pipeline: None,
55            layout: None,
56            is_initialized: false,
57            config: ComputePipelineConfig {
58                push_constants: Vec::new(),
59                bind_groups: Vec::new(),
60                shader: String::new()
61            }
62        }
63    }
64
65    /// Set the compute shader of the pipeline.
66    ///
67    /// # Arguments
68    ///
69    /// * `shader` - The shader source code.
70    pub fn set_shader(&mut self, shader: &str) -> &mut Self {
71        self.config.shader = shader.to_string();
72        self
73    }
74
75    /// Add a set of bind groups via its layout to the compute pipeline.
76    /// Note that the order of the bind groups will be the same as the order of the bindings in the shaders.
77    ///
78    /// # Arguments
79    ///
80    /// * `layout` - The bind group layout.
81    pub fn set_bind_groups(&mut self, layout: Vec<BindGroupLayout>) -> &mut Self {
82        for l in layout {
83            self.config.bind_groups.push(l);
84        }
85
86        self
87    }
88
89    /// Add a push constant to the compute pipeline.
90    ///
91    /// # Arguments
92    ///
93    /// * `offset` - The offset of the push constant.
94    /// * `size` - The size of the push constant.
95    pub fn add_push_constant(&mut self, offset: u32, size: u32) -> &mut Self {
96        self.config.push_constants.push(wgpu::PushConstantRange {
97            stages: ShaderStages::COMPUTE,
98            range: offset..offset + size
99        });
100        self
101    }
102
103    /// Initialize a new compute pipeline.
104    ///
105    /// # Arguments
106    ///
107    /// * `instance` - Render instance.
108    ///
109    /// # Returns
110    ///
111    /// * `Result<(), RenderError>` - The result of the initialization.
112    pub fn init(&mut self, instance: &RenderInstanceData) -> Result<(), RenderError> {
113        event!(LogLevel::TRACE, "Creating compute pipeline {}.", self.label);
114        let d = &self.config;
115
116        // Security checks
117        if d.shader.is_empty() {
118            error!(self.label, "Pipeline does not have a compute shader.");
119            return Err(RenderError::MissingShader);
120        }
121
122        // Load shader
123        trace!(self.label, "Loading compute shader.");
124        let shader_module = match naga::front::wgsl::parse_str(&self.config.shader) {
125            Ok(shader) => {
126                match naga::valid::Validator::new(
127                    naga::valid::ValidationFlags::all(),
128                    naga::valid::Capabilities::all()
129                )
130                .validate(&shader)
131                {
132                    Ok(_) => instance
133                        .device
134                        .create_shader_module(wgpu::ShaderModuleDescriptor {
135                            label: Some(format!("{}-render-pip-comp", self.label).as_str()),
136                            source: wgpu::ShaderSource::Wgsl(self.config.shader.to_owned().into())
137                        }),
138                    Err(e) => {
139                        error!(self.label, "Compute shader validation failed: {:?}.", e);
140                        return Err(RenderError::ShaderCompilationError);
141                    }
142                }
143            }
144            Err(e) => {
145                let mut error = format!("Compute shader parsing failed \"{}\".\n", e);
146                for (span, message) in e.labels() {
147                    let location = span.location(&self.config.shader);
148                    error.push_str(&format!(
149                        " - Error on line {} at position {}: \"{}\"\n",
150                        location.line_number, location.line_position, message
151                    ));
152                }
153                error!(self.label, "{}", error);
154                return Err(RenderError::ShaderCompilationError);
155            }
156        };
157        future::block_on(async {
158            let compil_info = shader_module.get_compilation_info().await;
159            for message in compil_info.messages {
160                match message.message_type {
161                    wgpu::CompilationMessageType::Error => error!(
162                        self.label,
163                        "Compute shader {} compilation error '{}' (at {:?}).",
164                        self.label,
165                        message.message,
166                        message.location
167                    ),
168                    wgpu::CompilationMessageType::Warning => warn!(
169                        self.label,
170                        "Compute shader {} compilation warning '{}' (at {:?}).",
171                        self.label,
172                        message.message,
173                        message.location
174                    ),
175                    wgpu::CompilationMessageType::Info => debug!(
176                        self.label,
177                        "Compute shader {} compilation info '{}' (at {:?}).",
178                        self.label,
179                        message.message,
180                        message.location
181                    )
182                }
183            }
184        });
185
186        // Create pipeline layout
187        trace!(self.label, "Creating compute pipeline instance.");
188        let layout = instance
189            .device
190            .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
191                label: Some(format!("{}-compute-pip-layout", self.label).as_str()),
192                bind_group_layouts: &d
193                    .bind_groups
194                    .iter()
195                    .collect::<Vec<&wgpu::BindGroupLayout>>(),
196                push_constant_ranges: &d.push_constants
197            });
198
199        // Create a compute pipeline
200        instance
201            .device
202            .push_error_scope(wgpu::ErrorFilter::Validation);
203        let pipeline = instance
204            .device
205            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
206                label: Some(format!("{}-compute-pip", self.label).as_str()),
207                layout: Some(&layout),
208                module: &shader_module,
209                entry_point: Some("main"),
210                compilation_options: wgpu::PipelineCompilationOptions::default(),
211                cache: None
212            });
213
214        // Check for errors
215        let mut res: Result<(), RenderError> = Ok(());
216        future::block_on(async {
217            let error = instance.device.pop_error_scope().await;
218            match error {
219                Some(wgpu::Error::Validation {
220                    source,
221                    description
222                }) => {
223                    error!(
224                        self.label,
225                        "Failed to create compute pipeline with source error: {:?}. Description: {}.",
226                        source,
227                        description
228                    );
229                    res = Err(RenderError::ShaderCompilationError);
230                }
231                Some(e) => {
232                    error!(self.label, "Failed to create compute pipeline: {:?}.", e);
233                    res = Err(RenderError::ShaderCompilationError);
234                }
235                None => ()
236            }
237        });
238
239        // Set pipeline
240        self.pipeline = Some(pipeline);
241        self.layout = Some(layout);
242        self.is_initialized = true;
243
244        res
245    }
246
247    /// Get the compute pipeline.
248    ///
249    /// # Returns
250    ///
251    /// * `Option<&ComputePipelineRef>` - The compute pipeline.
252    pub fn get_pipeline(&self) -> Option<&wgpu::ComputePipeline> {
253        self.pipeline.as_ref()
254    }
255
256    /// Get the pipeline layout.
257    ///
258    /// # Returns
259    ///
260    /// * `Option<&PipelineLayout>` - The pipeline layout.
261    pub fn get_layout(&self) -> Option<&wgpu::PipelineLayout> {
262        self.layout.as_ref()
263    }
264
265    /// Check if the compute pipeline is initialized.
266    ///
267    ///
268    /// # Returns
269    ///
270    /// * `bool` - True if the compute pipeline is initialized, false otherwise.
271    pub fn is_initialized(&self) -> bool {
272        self.is_initialized
273    }
274}