Skip to main content

wde_wgpu/pipelines/
render_pipeline.rs

1//! Render pipeline module.
2
3use futures_lite::future;
4use wde_logger::prelude::*;
5use wgpu::{BindGroupLayout, naga};
6
7use crate::{
8    instance::{RenderError, RenderInstanceData},
9    texture::{DEPTH_FORMAT, SWAPCHAIN_FORMAT, TextureFormat},
10    vertex::Vertex
11};
12
13/// List of available shaders.
14pub type ShaderStages = wgpu::ShaderStages;
15/// Type of the shader module.
16pub type ShaderModule = wgpu::ShaderModule;
17/// Export culling params.
18pub type Face = wgpu::Face;
19/// Export compare function.
20pub type CompareFunction = wgpu::CompareFunction;
21
22/// Depth
23pub type StencilState = wgpu::StencilState;
24pub type StencilFaceState = wgpu::StencilFaceState;
25pub type StencilOperation = wgpu::StencilOperation;
26
27/// Blend state
28pub type BlendState = wgpu::BlendState;
29pub type BlendComponent = wgpu::BlendComponent;
30pub type BlendFactor = wgpu::BlendFactor;
31pub type BlendOperation = wgpu::BlendOperation;
32pub type ColorWrites = wgpu::ColorWrites;
33
34/// Describes an optional depth attachment for a pipeline.
35#[derive(Clone)]
36pub struct DepthDescriptor {
37    /// Whether the pipeline should have a depth attachment.
38    pub enabled: bool,
39    /// Whether the stencil attachment should be read-only.
40    pub write: bool,
41    /// The comparison function that the depth attachment will use.
42    pub compare: CompareFunction,
43    /// The stencil state for the depth attachment. If `None`, stencil testing is disabled.
44    pub stencil: StencilState,
45    /// Override the depth texture format. Defaults to `DEPTH_FORMAT` when `None`.
46    pub format: Option<TextureFormat>
47}
48impl Default for DepthDescriptor {
49    fn default() -> Self {
50        Self {
51            enabled: false,
52            write: true,
53            compare: CompareFunction::Less,
54            stencil: StencilState::default(),
55            format: None
56        }
57    }
58}
59
/// Convenience enum that maps to `wgpu::PrimitiveTopology`.
///
/// The default variant, `TriangleList`, matches the topology used by `RenderPipeline::new`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum RenderTopology {
    /// Each vertex is a point.
    PointList,
    /// Every two vertices form an independent line.
    LineList,
    /// Consecutive vertices form a connected line strip.
    LineStrip,
    /// Every three vertices form an independent triangle.
    #[default]
    TriangleList,
    /// Consecutive vertices form a connected triangle strip.
    TriangleStrip
}
69
70// Render pipeline configuration
71struct RenderPipelineConfig {
72    depth: DepthDescriptor,
73    render_targets: Vec<TextureFormat>,
74    primitive_topology: wgpu::PrimitiveTopology,
75    push_constants: Vec<wgpu::PushConstantRange>,
76    bind_groups: Vec<wgpu::BindGroupLayout>,
77    vertex_shader: String,
78    fragment_shader: String,
79    fragment_blend: Option<BlendState>,
80    color_write: ColorWrites,
81    cull_mode: Option<Face>,
82    sample_count: u32,
83    use_vertices_buffer: bool
84}
85
86/// Stores a render pipeline.
87///
88/// # Example
89/// ```rust,no_run
90/// use wde_wgpu::render_pipeline::{DepthStencilDescriptor, RenderPipeline, RenderTopology, ShaderStages};
91///
92/// let mut pipeline = RenderPipeline::new("gbuffer");
93/// pipeline
94///     .set_shader(include_str!("../../../res/pbr/gbuffer.vert.wgsl"), ShaderStages::VERTEX)
95///     .set_shader(include_str!("../../../res/pbr/gbuffer.frag.wgsl"), ShaderStages::FRAGMENT)
96///     .set_topology(RenderTopology::TriangleList)
97///     .set_depth(DepthStencilDescriptor { enabled: true, write: true, compare: wgpu::CompareFunction::Less })
98///     .set_bind_groups(layouts)
99///     .add_push_constant(ShaderStages::VERTEX, 0, 64)
100///     .init(instance)
101///     .expect("shaders validated");
102/// assert!(pipeline.is_initialized());
103///
104/// let _wgpu_pipeline = pipeline.get_pipeline().unwrap();
105/// ```
106pub struct RenderPipeline {
107    pub label: String,
108    is_initialized: bool,
109    pipeline: Option<wgpu::RenderPipeline>,
110    layout: Option<wgpu::PipelineLayout>,
111    config: RenderPipelineConfig
112}
113
114impl std::fmt::Debug for RenderPipeline {
115    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
116        f.debug_struct("RenderPipeline")
117            .field("label", &self.label)
118            .field("is_initialized", &self.is_initialized)
119            .finish()
120    }
121}
122
123impl RenderPipeline {
124    /// Create a new render pipeline.
125    /// By default, the render pipeline does not have a depth or stencil.
126    /// By default, the primitive topology is `Topology::TriangleList`.
127    /// By default, the cull mode is `Some(Face::Back)`.
128    /// By default, the sample count is 1.
129    /// By default, this pipeline expects a vertex buffer to be used.
130    ///
131    /// # Arguments
132    ///
133    /// * `label` - Label of the render pipeline for debugging.
134    pub fn new(label: &str) -> Self {
135        Self {
136            label: label.to_string(),
137            pipeline: None,
138            layout: None,
139            is_initialized: false,
140            config: RenderPipelineConfig {
141                depth: DepthDescriptor::default(),
142                render_targets: Vec::from([SWAPCHAIN_FORMAT]),
143                primitive_topology: wgpu::PrimitiveTopology::TriangleList,
144                push_constants: Vec::new(),
145                bind_groups: Vec::new(),
146                vertex_shader: String::new(),
147                fragment_shader: String::new(),
148                fragment_blend: Some(wgpu::BlendState::REPLACE),
149                color_write: ColorWrites::ALL,
150                cull_mode: Some(Face::Back),
151                sample_count: 1,
152                use_vertices_buffer: true
153            }
154        }
155    }
156
157    /// Set a given shader.
158    ///
159    /// # Arguments
160    ///
161    /// * `shader` - The shader source code.
162    /// * `shader_type` - The shader type.
163    pub fn set_shader(&mut self, shader: &str, shader_type: ShaderStages) -> &mut Self {
164        match shader_type {
165            ShaderStages::VERTEX => self.config.vertex_shader = shader.to_string(),
166            ShaderStages::FRAGMENT => self.config.fragment_shader = shader.to_string(),
167            _ => {
168                error!(self.label, "Unsupported shader type.");
169            }
170        };
171        self
172    }
173
174    /// Set the primitive topology.
175    ///
176    /// # Arguments
177    ///
178    /// * `topology` - The primitive topology.
179    pub fn set_topology(&mut self, topology: RenderTopology) -> &mut Self {
180        self.config.primitive_topology = match topology {
181            RenderTopology::PointList => wgpu::PrimitiveTopology::PointList,
182            RenderTopology::LineList => wgpu::PrimitiveTopology::LineList,
183            RenderTopology::LineStrip => wgpu::PrimitiveTopology::LineStrip,
184            RenderTopology::TriangleList => wgpu::PrimitiveTopology::TriangleList,
185            RenderTopology::TriangleStrip => wgpu::PrimitiveTopology::TriangleStrip
186        };
187        self
188    }
189
190    /// Set the configuration of the depth/stencil attachment.
191    pub fn set_depth(&mut self, depth: DepthDescriptor) -> &mut Self {
192        self.config.depth = depth;
193        self
194    }
195
196    /// Set the cull mode. None means no culling.
197    pub fn set_cull_mode(&mut self, cull_mode: Option<Face>) -> &mut Self {
198        self.config.cull_mode = cull_mode;
199        self
200    }
201
202    /// Add a set of bind groups via its layout to the render pipeline.
203    /// Note that the order of the bind groups will be the same as the order of the bindings in the shaders.
204    ///
205    /// # Arguments
206    ///
207    /// * `layout` - The bind group layout.
208    pub fn set_bind_groups(&mut self, layout: Vec<BindGroupLayout>) -> &mut Self {
209        for l in layout {
210            self.config.bind_groups.push(l);
211        }
212
213        self
214    }
215
216    /// Set the render targets of the render pipeline.
217    ///
218    /// # Arguments
219    ///
220    /// * `targets` - The render targets.
221    pub fn set_render_targets(&mut self, targets: Vec<TextureFormat>) -> &mut Self {
222        self.config.render_targets = targets;
223        self
224    }
225
226    /// Set the sample count for multisampling.
227    ///
228    /// # Arguments
229    ///
230    /// * `count` - The sample count.
231    pub fn set_sample_count(&mut self, count: u32) -> &mut Self {
232        self.config.sample_count = count;
233        self
234    }
235
236    /// Set whether the pipeline uses a vertex buffer.
237    ///
238    /// # Arguments
239    ///
240    /// * `use_buffer` - Whether to use a vertex buffer.
241    pub fn set_use_vertices_buffer(&mut self, use_buffer: bool) -> &mut Self {
242        self.config.use_vertices_buffer = use_buffer;
243        self
244    }
245
246    /// Set the blend state for the fragment shader.
247    ///
248    /// # Arguments
249    ///
250    /// * `blend` - The blend state.
251    pub fn set_fragment_blend(&mut self, blend: Option<BlendState>) -> &mut Self {
252        self.config.fragment_blend = blend;
253        self
254    }
255
256    /// Set the color write mask for the fragment output.
257    pub fn set_color_write_mask(&mut self, mask: ColorWrites) -> &mut Self {
258        self.config.color_write = mask;
259        self
260    }
261
262    /// Add a push constant to the render pipeline.
263    ///
264    /// # Arguments
265    ///
266    /// * `stages` - The shader stages.
267    /// * `offset` - The offset of the push constant.
268    /// * `size` - The size of the push constant.
269    pub fn add_push_constant(&mut self, stages: ShaderStages, offset: u32, size: u32) {
270        self.config.push_constants.push(wgpu::PushConstantRange {
271            stages,
272            range: offset..offset + size
273        });
274    }
275
276    /// Initialize a new render pipeline.
277    ///
278    /// # Arguments
279    ///
280    /// * `instance` - Render instance.
281    ///
282    /// # Returns
283    ///
284    /// * `Result<(), RenderError>` - The result of the initialization.
285    pub fn init(&mut self, instance: &RenderInstanceData<'_>) -> Result<(), RenderError> {
286        event!(LogLevel::TRACE, "Creating render pipeline {}.", self.label);
287        let d = &self.config;
288
289        // Security checks
290        if d.vertex_shader.is_empty() || d.fragment_shader.is_empty() {
291            error!(
292                self.label,
293                "Pipeline does not have a vertex or fragment shader."
294            );
295            return Err(RenderError::MissingShader);
296        }
297
298        // Load vertex shader
299        trace!(self.label, "Loading shaders.");
300        let shader_module_vert = match naga::front::wgsl::parse_str(&self.config.vertex_shader) {
301            Ok(shader) => {
302                match naga::valid::Validator::new(
303                    naga::valid::ValidationFlags::all(),
304                    naga::valid::Capabilities::all()
305                )
306                .validate(&shader)
307                {
308                    Ok(_) => instance
309                        .device
310                        .create_shader_module(wgpu::ShaderModuleDescriptor {
311                            label: Some(format!("{}-render-pip-vert", self.label).as_str()),
312                            source: wgpu::ShaderSource::Wgsl(
313                                self.config.vertex_shader.to_owned().into()
314                            )
315                        }),
316                    Err(e) => {
317                        error!(self.label, "Vertex shader validation failed: {:?}.", e);
318                        return Err(RenderError::ShaderCompilationError);
319                    }
320                }
321            }
322            Err(e) => {
323                let mut error = format!("Vertex shader parsing failed \"{}\".\n", e);
324                for (span, message) in e.labels() {
325                    let location = span.location(&self.config.vertex_shader);
326                    error.push_str(&format!(
327                        " - Error on line {} at position {}: \"{}\"\n",
328                        location.line_number, location.line_position, message
329                    ));
330                }
331                error!(self.label, "{}", error);
332                return Err(RenderError::ShaderCompilationError);
333            }
334        };
335        future::block_on(async {
336            let compil_info = shader_module_vert.get_compilation_info().await;
337            for message in compil_info.messages {
338                match message.message_type {
339                    wgpu::CompilationMessageType::Error => error!(
340                        self.label,
341                        "Vertex shader {} compilation error '{}' (at {:?}).",
342                        self.label,
343                        message.message,
344                        message.location
345                    ),
346                    wgpu::CompilationMessageType::Warning => warn!(
347                        self.label,
348                        "Vertex shader {} compilation warning '{}' (at {:?}).",
349                        self.label,
350                        message.message,
351                        message.location
352                    ),
353                    wgpu::CompilationMessageType::Info => debug!(
354                        self.label,
355                        "Vertex shader {} compilation info '{}' (at {:?}).",
356                        self.label,
357                        message.message,
358                        message.location
359                    )
360                }
361            }
362        });
363
364        // Load fragment shader
365        let shader_module_frag = match naga::front::wgsl::parse_str(&self.config.fragment_shader) {
366            Ok(shader) => {
367                match naga::valid::Validator::new(
368                    naga::valid::ValidationFlags::all(),
369                    naga::valid::Capabilities::all()
370                )
371                .validate(&shader)
372                {
373                    Ok(_) => instance
374                        .device
375                        .create_shader_module(wgpu::ShaderModuleDescriptor {
376                            label: Some(format!("{}-render-pip-frag", self.label).as_str()),
377                            source: wgpu::ShaderSource::Wgsl(
378                                self.config.fragment_shader.to_owned().into()
379                            )
380                        }),
381                    Err(e) => {
382                        error!(self.label, "Fragment shader validation failed: {:?}.", e);
383                        return Err(RenderError::ShaderCompilationError);
384                    }
385                }
386            }
387            Err(e) => {
388                let mut error = format!("Fragment shader parsing failed \"{}\".\n", e);
389                for (span, message) in e.labels() {
390                    let location = span.location(&self.config.fragment_shader);
391                    error.push_str(&format!(
392                        " - Error on line {} at position {}: \"{}\"\n",
393                        location.line_number, location.line_position, message
394                    ));
395                }
396                error!(self.label, "{}", error);
397                return Err(RenderError::ShaderCompilationError);
398            }
399        };
400        future::block_on(async {
401            let compil_info = shader_module_frag.get_compilation_info().await;
402            for message in compil_info.messages {
403                match message.message_type {
404                    wgpu::CompilationMessageType::Error => error!(
405                        self.label,
406                        "Fragment shader {} compilation error '{}' (at {:?}).",
407                        self.label,
408                        message.message,
409                        message.location
410                    ),
411                    wgpu::CompilationMessageType::Warning => warn!(
412                        self.label,
413                        "Fragment shader {} compilation warning '{}' (at {:?}).",
414                        self.label,
415                        message.message,
416                        message.location
417                    ),
418                    wgpu::CompilationMessageType::Info => debug!(
419                        self.label,
420                        "Fragment shader {} compilation info '{}' (at {:?}).",
421                        self.label,
422                        message.message,
423                        message.location
424                    )
425                }
426            }
427        });
428
429        // Create pipeline layout
430        trace!(self.label, "Creating render pipeline instance.");
431        let layout = instance
432            .device
433            .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
434                label: Some(format!("{}-render-pip-layout", self.label).as_str()),
435                bind_group_layouts: &d
436                    .bind_groups
437                    .iter()
438                    .collect::<Vec<&wgpu::BindGroupLayout>>(),
439                push_constant_ranges: &d.push_constants
440            });
441
442        // Define vertex buffers
443        let vertex_buffers: &[wgpu::VertexBufferLayout] = if d.use_vertices_buffer {
444            &[Vertex::describe()]
445        } else {
446            &[]
447        };
448
449        // Create pipeline
450        let mut res: Result<(), RenderError> = Ok(());
451        instance
452            .device
453            .push_error_scope(wgpu::ErrorFilter::Validation);
454        let pipeline = instance
455            .device
456            .create_render_pipeline(&wgpu::RenderPipelineDescriptor {
457                label: Some(format!("{}-render-pip", self.label).as_str()),
458                layout: Some(&layout),
459                cache: None,
460                vertex: wgpu::VertexState {
461                    module: &shader_module_vert,
462                    entry_point: Some("main"),
463                    buffers: vertex_buffers,
464                    compilation_options: wgpu::PipelineCompilationOptions::default()
465                },
466                fragment: Some(wgpu::FragmentState {
467                    // Always write to swapchain format
468                    module: &shader_module_frag,
469                    entry_point: Some("main"),
470                    targets: d
471                        .render_targets
472                        .iter()
473                        .map(|format| {
474                            Some(wgpu::ColorTargetState {
475                                format: *format,
476                                blend: d.fragment_blend,
477                                write_mask: d.color_write
478                            })
479                        })
480                        .collect::<Vec<Option<wgpu::ColorTargetState>>>()
481                        .as_slice(),
482                    compilation_options: wgpu::PipelineCompilationOptions::default()
483                }),
484                primitive: wgpu::PrimitiveState {
485                    topology: d.primitive_topology,
486                    strip_index_format: None,
487                    front_face: wgpu::FrontFace::Ccw,
488                    cull_mode: d.cull_mode,
489                    polygon_mode: wgpu::PolygonMode::Fill,
490                    conservative: false,
491                    unclipped_depth: false
492                },
493                depth_stencil: if d.depth.enabled {
494                    Some(wgpu::DepthStencilState {
495                        format: d.depth.format.unwrap_or(DEPTH_FORMAT),
496                        depth_write_enabled: d.depth.write,
497                        depth_compare: d.depth.compare,
498                        stencil: d.depth.stencil.clone(),
499                        bias: wgpu::DepthBiasState::default()
500                    })
501                } else {
502                    None
503                },
504                multisample: wgpu::MultisampleState {
505                    count: d.sample_count,
506                    mask: !0,
507                    alpha_to_coverage_enabled: false
508                },
509                multiview: Default::default()
510            });
511
512        // Check for errors
513        future::block_on(async {
514            let error = instance.device.pop_error_scope().await;
515            match error {
516                Some(wgpu::Error::Validation {
517                    source,
518                    description
519                }) => {
520                    error!(
521                        self.label,
522                        "Failed to create render pipeline with source error: {:?}. Description: {}.",
523                        source,
524                        description
525                    );
526                    res = Err(RenderError::ShaderCompilationError);
527                }
528                Some(e) => {
529                    error!(self.label, "Failed to create render pipeline: {:?}.", e);
530                    res = Err(RenderError::ShaderCompilationError);
531                }
532                None => ()
533            }
534        });
535
536        // Set pipeline
537        self.pipeline = Some(pipeline);
538        self.layout = Some(layout);
539        self.is_initialized = true;
540
541        res
542    }
543
544    /// Get the render pipeline.
545    ///
546    /// # Returns
547    ///
548    /// * `Option<&RenderPipelineRef>` - The render pipeline.
549    pub fn get_pipeline(&self) -> Option<&wgpu::RenderPipeline> {
550        self.pipeline.as_ref()
551    }
552
553    /// Get the pipeline layout.
554    ///
555    /// # Returns
556    ///
557    /// * `Option<&PipelineLayout>` - The pipeline layout.
558    pub fn get_layout(&self) -> Option<&wgpu::PipelineLayout> {
559        self.layout.as_ref()
560    }
561
562    /// Check if the render pipeline is initialized.
563    ///
564    ///
565    /// # Returns
566    ///
567    /// * `bool` - True if the render pipeline is initialized, false otherwise.
568    pub fn is_initialized(&self) -> bool {
569        self.is_initialized
570    }
571}