wde_wgpu/pipelines/
render_pipeline.rs

1//! Render pipeline module.
2
3use futures_lite::future;
4use wde_logger::prelude::*;
5use wgpu::{BindGroupLayout, naga};
6
7use crate::{
8    instance::{RenderError, RenderInstanceData},
9    texture::{DEPTH_FORMAT, SWAPCHAIN_FORMAT, TextureFormat},
10    vertex::Vertex
11};
12
13/// List of available shaders.
14pub type ShaderStages = wgpu::ShaderStages;
15/// Type of the shader module.
16pub type ShaderModule = wgpu::ShaderModule;
17/// Export culling params.
18pub type Face = wgpu::Face;
19/// Export compare function.
20pub type CompareFunction = wgpu::CompareFunction;
21
22/// Depth
23pub type StencilState = wgpu::StencilState;
24pub type StencilFaceState = wgpu::StencilFaceState;
25pub type StencilOperation = wgpu::StencilOperation;
26
27/// Blend state
28pub type BlendState = wgpu::BlendState;
29pub type BlendComponent = wgpu::BlendComponent;
30pub type BlendFactor = wgpu::BlendFactor;
31pub type BlendOperation = wgpu::BlendOperation;
32
33/// Describes an optional depth attachment for a pipeline.
34#[derive(Clone)]
35pub struct DepthDescriptor {
36    /// Whether the pipeline should have a depth attachment.
37    pub enabled: bool,
38    /// Whether the stencil attachment should be read-only.
39    pub write: bool,
40    /// The comparison function that the depth attachment will use.
41    pub compare: CompareFunction,
42    /// The stencil state for the depth attachment. If `None`, stencil testing is disabled.
43    pub stencil: StencilState
44}
45impl Default for DepthDescriptor {
46    fn default() -> Self {
47        Self {
48            enabled: false,
49            write: true,
50            compare: CompareFunction::Less,
51            stencil: StencilState::default()
52        }
53    }
54}
55
/// Primitive topology of a render pipeline.
///
/// Each variant maps one-to-one onto the `wgpu::PrimitiveTopology` variant of the same name.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum RenderTopology {
    /// Every vertex is an independent point.
    PointList,
    /// Every two vertices form an independent line segment.
    LineList,
    /// Consecutive vertices are joined into one connected line.
    LineStrip,
    /// Every three vertices form an independent triangle.
    TriangleList,
    /// Each vertex after the first two forms a triangle with the two preceding vertices.
    TriangleStrip
}
65
66// Render pipeline configuration
67struct RenderPipelineConfig {
68    depth: DepthDescriptor,
69    render_targets: Vec<TextureFormat>,
70    primitive_topology: wgpu::PrimitiveTopology,
71    push_constants: Vec<wgpu::PushConstantRange>,
72    bind_groups: Vec<wgpu::BindGroupLayout>,
73    vertex_shader: String,
74    fragment_shader: String,
75    fragment_blend: Option<BlendState>,
76    cull_mode: Option<Face>,
77    sample_count: u32,
78    use_vertices_buffer: bool
79}
80
81/// Stores a render pipeline.
82///
83/// # Example
84/// ```rust,no_run
85/// use wde_wgpu::render_pipeline::{DepthStencilDescriptor, RenderPipeline, RenderTopology, ShaderStages};
86///
87/// let mut pipeline = RenderPipeline::new("gbuffer");
88/// pipeline
89///     .set_shader(include_str!("../../../res/pbr/gbuffer_vert.wgsl"), ShaderStages::VERTEX)
90///     .set_shader(include_str!("../../../res/pbr/gbuffer_frag.wgsl"), ShaderStages::FRAGMENT)
91///     .set_topology(RenderTopology::TriangleList)
92///     .set_depth(DepthStencilDescriptor { enabled: true, write: true, compare: wgpu::CompareFunction::Less })
93///     .set_bind_groups(layouts)
94///     .add_push_constant(ShaderStages::VERTEX, 0, 64)
95///     .init(instance)
96///     .expect("shaders validated");
97/// assert!(pipeline.is_initialized());
98///
99/// let _wgpu_pipeline = pipeline.get_pipeline().unwrap();
100/// ```
101pub struct RenderPipeline {
102    pub label: String,
103    is_initialized: bool,
104    pipeline: Option<wgpu::RenderPipeline>,
105    layout: Option<wgpu::PipelineLayout>,
106    config: RenderPipelineConfig
107}
108
109impl std::fmt::Debug for RenderPipeline {
110    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
111        f.debug_struct("RenderPipeline")
112            .field("label", &self.label)
113            .field("is_initialized", &self.is_initialized)
114            .finish()
115    }
116}
117
118impl RenderPipeline {
119    /// Create a new render pipeline.
120    /// By default, the render pipeline does not have a depth or stencil.
121    /// By default, the primitive topology is `Topology::TriangleList`.
122    /// By default, the cull mode is `Some(Face::Back)`.
123    /// By default, the sample count is 1.
124    /// By default, this pipeline expects a vertex buffer to be used.
125    ///
126    /// # Arguments
127    ///
128    /// * `label` - Label of the render pipeline for debugging.
129    pub fn new(label: &str) -> Self {
130        Self {
131            label: label.to_string(),
132            pipeline: None,
133            layout: None,
134            is_initialized: false,
135            config: RenderPipelineConfig {
136                depth: DepthDescriptor::default(),
137                render_targets: Vec::from([SWAPCHAIN_FORMAT]),
138                primitive_topology: wgpu::PrimitiveTopology::TriangleList,
139                push_constants: Vec::new(),
140                bind_groups: Vec::new(),
141                vertex_shader: String::new(),
142                fragment_shader: String::new(),
143                fragment_blend: Some(wgpu::BlendState::REPLACE),
144                cull_mode: Some(Face::Back),
145                sample_count: 1,
146                use_vertices_buffer: true
147            }
148        }
149    }
150
151    /// Set a given shader.
152    ///
153    /// # Arguments
154    ///
155    /// * `shader` - The shader source code.
156    /// * `shader_type` - The shader type.
157    pub fn set_shader(&mut self, shader: &str, shader_type: ShaderStages) -> &mut Self {
158        match shader_type {
159            ShaderStages::VERTEX => self.config.vertex_shader = shader.to_string(),
160            ShaderStages::FRAGMENT => self.config.fragment_shader = shader.to_string(),
161            _ => {
162                error!(self.label, "Unsupported shader type.");
163            }
164        };
165        self
166    }
167
168    /// Set the primitive topology.
169    ///
170    /// # Arguments
171    ///
172    /// * `topology` - The primitive topology.
173    pub fn set_topology(&mut self, topology: RenderTopology) -> &mut Self {
174        self.config.primitive_topology = match topology {
175            RenderTopology::PointList => wgpu::PrimitiveTopology::PointList,
176            RenderTopology::LineList => wgpu::PrimitiveTopology::LineList,
177            RenderTopology::LineStrip => wgpu::PrimitiveTopology::LineStrip,
178            RenderTopology::TriangleList => wgpu::PrimitiveTopology::TriangleList,
179            RenderTopology::TriangleStrip => wgpu::PrimitiveTopology::TriangleStrip
180        };
181        self
182    }
183
184    /// Set the configuration of the depth/stencil attachment.
185    pub fn set_depth(&mut self, depth: DepthDescriptor) -> &mut Self {
186        self.config.depth = depth;
187        self
188    }
189
190    /// Set the cull mode. None means no culling.
191    pub fn set_cull_mode(&mut self, cull_mode: Option<Face>) -> &mut Self {
192        self.config.cull_mode = cull_mode;
193        self
194    }
195
196    /// Add a set of bind groups via its layout to the render pipeline.
197    /// Note that the order of the bind groups will be the same as the order of the bindings in the shaders.
198    ///
199    /// # Arguments
200    ///
201    /// * `layout` - The bind group layout.
202    pub fn set_bind_groups(&mut self, layout: Vec<BindGroupLayout>) -> &mut Self {
203        for l in layout {
204            self.config.bind_groups.push(l);
205        }
206
207        self
208    }
209
210    /// Set the render targets of the render pipeline.
211    ///
212    /// # Arguments
213    ///
214    /// * `targets` - The render targets.
215    pub fn set_render_targets(&mut self, targets: Vec<TextureFormat>) -> &mut Self {
216        self.config.render_targets = targets;
217        self
218    }
219
220    /// Set the sample count for multisampling.
221    ///
222    /// # Arguments
223    ///
224    /// * `count` - The sample count.
225    pub fn set_sample_count(&mut self, count: u32) -> &mut Self {
226        self.config.sample_count = count;
227        self
228    }
229
230    /// Set whether the pipeline uses a vertex buffer.
231    ///
232    /// # Arguments
233    ///
234    /// * `use_buffer` - Whether to use a vertex buffer.
235    pub fn set_use_vertices_buffer(&mut self, use_buffer: bool) -> &mut Self {
236        self.config.use_vertices_buffer = use_buffer;
237        self
238    }
239
240    /// Set the blend state for the fragment shader.
241    ///
242    /// # Arguments
243    ///
244    /// * `blend` - The blend state.
245    pub fn set_fragment_blend(&mut self, blend: Option<BlendState>) -> &mut Self {
246        self.config.fragment_blend = blend;
247        self
248    }
249
250    /// Add a push constant to the render pipeline.
251    ///
252    /// # Arguments
253    ///
254    /// * `stages` - The shader stages.
255    /// * `offset` - The offset of the push constant.
256    /// * `size` - The size of the push constant.
257    pub fn add_push_constant(&mut self, stages: ShaderStages, offset: u32, size: u32) {
258        self.config.push_constants.push(wgpu::PushConstantRange {
259            stages,
260            range: offset..offset + size
261        });
262    }
263
264    /// Initialize a new render pipeline.
265    ///
266    /// # Arguments
267    ///
268    /// * `instance` - Render instance.
269    ///
270    /// # Returns
271    ///
272    /// * `Result<(), RenderError>` - The result of the initialization.
273    pub fn init(&mut self, instance: &RenderInstanceData<'_>) -> Result<(), RenderError> {
274        event!(LogLevel::TRACE, "Creating render pipeline {}.", self.label);
275        let d = &self.config;
276
277        // Security checks
278        if d.vertex_shader.is_empty() || d.fragment_shader.is_empty() {
279            error!(
280                self.label,
281                "Pipeline does not have a vertex or fragment shader."
282            );
283            return Err(RenderError::MissingShader);
284        }
285
286        // Load vertex shader
287        trace!(self.label, "Loading shaders.");
288        let shader_module_vert = match naga::front::wgsl::parse_str(&self.config.vertex_shader) {
289            Ok(shader) => {
290                match naga::valid::Validator::new(
291                    naga::valid::ValidationFlags::all(),
292                    naga::valid::Capabilities::all()
293                )
294                .validate(&shader)
295                {
296                    Ok(_) => instance
297                        .device
298                        .create_shader_module(wgpu::ShaderModuleDescriptor {
299                            label: Some(format!("{}-render-pip-vert", self.label).as_str()),
300                            source: wgpu::ShaderSource::Wgsl(
301                                self.config.vertex_shader.to_owned().into()
302                            )
303                        }),
304                    Err(e) => {
305                        error!(self.label, "Vertex shader validation failed: {:?}.", e);
306                        return Err(RenderError::ShaderCompilationError);
307                    }
308                }
309            }
310            Err(e) => {
311                let mut error = format!("Vertex shader parsing failed \"{}\".\n", e);
312                for (span, message) in e.labels() {
313                    let location = span.location(&self.config.vertex_shader);
314                    error.push_str(&format!(
315                        " - Error on line {} at position {}: \"{}\"\n",
316                        location.line_number, location.line_position, message
317                    ));
318                }
319                error!(self.label, "{}", error);
320                return Err(RenderError::ShaderCompilationError);
321            }
322        };
323        future::block_on(async {
324            let compil_info = shader_module_vert.get_compilation_info().await;
325            for message in compil_info.messages {
326                match message.message_type {
327                    wgpu::CompilationMessageType::Error => error!(
328                        self.label,
329                        "Vertex shader {} compilation error '{}' (at {:?}).",
330                        self.label,
331                        message.message,
332                        message.location
333                    ),
334                    wgpu::CompilationMessageType::Warning => warn!(
335                        self.label,
336                        "Vertex shader {} compilation warning '{}' (at {:?}).",
337                        self.label,
338                        message.message,
339                        message.location
340                    ),
341                    wgpu::CompilationMessageType::Info => debug!(
342                        self.label,
343                        "Vertex shader {} compilation info '{}' (at {:?}).",
344                        self.label,
345                        message.message,
346                        message.location
347                    )
348                }
349            }
350        });
351
352        // Load fragment shader
353        let shader_module_frag = match naga::front::wgsl::parse_str(&self.config.fragment_shader) {
354            Ok(shader) => {
355                match naga::valid::Validator::new(
356                    naga::valid::ValidationFlags::all(),
357                    naga::valid::Capabilities::all()
358                )
359                .validate(&shader)
360                {
361                    Ok(_) => instance
362                        .device
363                        .create_shader_module(wgpu::ShaderModuleDescriptor {
364                            label: Some(format!("{}-render-pip-frag", self.label).as_str()),
365                            source: wgpu::ShaderSource::Wgsl(
366                                self.config.fragment_shader.to_owned().into()
367                            )
368                        }),
369                    Err(e) => {
370                        error!(self.label, "Fragment shader validation failed: {:?}.", e);
371                        return Err(RenderError::ShaderCompilationError);
372                    }
373                }
374            }
375            Err(e) => {
376                let mut error = format!("Fragment shader parsing failed \"{}\".\n", e);
377                for (span, message) in e.labels() {
378                    let location = span.location(&self.config.fragment_shader);
379                    error.push_str(&format!(
380                        " - Error on line {} at position {}: \"{}\"\n",
381                        location.line_number, location.line_position, message
382                    ));
383                }
384                error!(self.label, "{}", error);
385                return Err(RenderError::ShaderCompilationError);
386            }
387        };
388        future::block_on(async {
389            let compil_info = shader_module_frag.get_compilation_info().await;
390            for message in compil_info.messages {
391                match message.message_type {
392                    wgpu::CompilationMessageType::Error => error!(
393                        self.label,
394                        "Fragment shader {} compilation error '{}' (at {:?}).",
395                        self.label,
396                        message.message,
397                        message.location
398                    ),
399                    wgpu::CompilationMessageType::Warning => warn!(
400                        self.label,
401                        "Fragment shader {} compilation warning '{}' (at {:?}).",
402                        self.label,
403                        message.message,
404                        message.location
405                    ),
406                    wgpu::CompilationMessageType::Info => debug!(
407                        self.label,
408                        "Fragment shader {} compilation info '{}' (at {:?}).",
409                        self.label,
410                        message.message,
411                        message.location
412                    )
413                }
414            }
415        });
416
417        // Create pipeline layout
418        trace!(self.label, "Creating render pipeline instance.");
419        let layout = instance
420            .device
421            .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
422                label: Some(format!("{}-render-pip-layout", self.label).as_str()),
423                bind_group_layouts: &d
424                    .bind_groups
425                    .iter()
426                    .collect::<Vec<&wgpu::BindGroupLayout>>(),
427                push_constant_ranges: &d.push_constants
428            });
429
430        // Define vertex buffers
431        let vertex_buffers: &[wgpu::VertexBufferLayout] = if d.use_vertices_buffer {
432            &[Vertex::describe()]
433        } else {
434            &[]
435        };
436
437        // Create pipeline
438        let mut res: Result<(), RenderError> = Ok(());
439        instance
440            .device
441            .push_error_scope(wgpu::ErrorFilter::Validation);
442        let pipeline = instance
443            .device
444            .create_render_pipeline(&wgpu::RenderPipelineDescriptor {
445                label: Some(format!("{}-render-pip", self.label).as_str()),
446                layout: Some(&layout),
447                cache: None,
448                vertex: wgpu::VertexState {
449                    module: &shader_module_vert,
450                    entry_point: Some("main"),
451                    buffers: vertex_buffers,
452                    compilation_options: wgpu::PipelineCompilationOptions::default()
453                },
454                fragment: Some(wgpu::FragmentState {
455                    // Always write to swapchain format
456                    module: &shader_module_frag,
457                    entry_point: Some("main"),
458                    targets: d
459                        .render_targets
460                        .iter()
461                        .map(|format| {
462                            Some(wgpu::ColorTargetState {
463                                format: *format,
464                                blend: d.fragment_blend,
465                                write_mask: wgpu::ColorWrites::ALL
466                            })
467                        })
468                        .collect::<Vec<Option<wgpu::ColorTargetState>>>()
469                        .as_slice(),
470                    compilation_options: wgpu::PipelineCompilationOptions::default()
471                }),
472                primitive: wgpu::PrimitiveState {
473                    topology: d.primitive_topology,
474                    strip_index_format: None,
475                    front_face: wgpu::FrontFace::Ccw,
476                    cull_mode: d.cull_mode,
477                    polygon_mode: wgpu::PolygonMode::Fill,
478                    conservative: false,
479                    unclipped_depth: false
480                },
481                depth_stencil: if d.depth.enabled {
482                    Some(wgpu::DepthStencilState {
483                        format: match DEPTH_FORMAT {
484                            TextureFormat::Depth32Float => wgpu::TextureFormat::Depth32Float,
485                            _ => {
486                                error!(
487                                    "Depth format is not supported for render pipeline '{}'.",
488                                    self.label
489                                );
490                                res = Err(RenderError::UnsupportedDepthFormat);
491                                wgpu::TextureFormat::Depth32Float
492                            }
493                        },
494                        depth_write_enabled: d.depth.write,
495                        depth_compare: d.depth.compare,
496                        stencil: d.depth.stencil.clone(),
497                        bias: wgpu::DepthBiasState::default()
498                    })
499                } else {
500                    None
501                },
502                multisample: wgpu::MultisampleState {
503                    count: d.sample_count,
504                    mask: !0,
505                    alpha_to_coverage_enabled: false
506                },
507                multiview: Default::default()
508            });
509
510        // Check for errors
511        future::block_on(async {
512            let error = instance.device.pop_error_scope().await;
513            match error {
514                Some(wgpu::Error::Validation {
515                    source,
516                    description
517                }) => {
518                    error!(
519                        self.label,
520                        "Failed to create render pipeline with source error: {:?}. Description: {}.",
521                        source,
522                        description
523                    );
524                    res = Err(RenderError::ShaderCompilationError);
525                }
526                Some(e) => {
527                    error!(self.label, "Failed to create render pipeline: {:?}.", e);
528                    res = Err(RenderError::ShaderCompilationError);
529                }
530                None => ()
531            }
532        });
533
534        // Set pipeline
535        self.pipeline = Some(pipeline);
536        self.layout = Some(layout);
537        self.is_initialized = true;
538
539        res
540    }
541
542    /// Get the render pipeline.
543    ///
544    /// # Returns
545    ///
546    /// * `Option<&RenderPipelineRef>` - The render pipeline.
547    pub fn get_pipeline(&self) -> Option<&wgpu::RenderPipeline> {
548        self.pipeline.as_ref()
549    }
550
551    /// Get the pipeline layout.
552    ///
553    /// # Returns
554    ///
555    /// * `Option<&PipelineLayout>` - The pipeline layout.
556    pub fn get_layout(&self) -> Option<&wgpu::PipelineLayout> {
557        self.layout.as_ref()
558    }
559
560    /// Check if the render pipeline is initialized.
561    ///
562    ///
563    /// # Returns
564    ///
565    /// * `bool` - True if the render pipeline is initialized, false otherwise.
566    pub fn is_initialized(&self) -> bool {
567        self.is_initialized
568    }
569}