Skip to main content

wde_renderer/passes/
pipeline_manager.rs

1use super::{ComputePipelineDescriptor, RenderPipelineDescriptor};
2use crate::assets::PrepareAssetError;
3use crate::core::RenderInstance;
4use crate::{
5    assets::Shader,
6    core::{Extract, ExtractWorld, Render, RenderSet}
7};
8use bevy::{
9    app::{App, Plugin},
10    asset::{AssetEvent, AssetId, Assets},
11    ecs::prelude::*,
12    prelude::MessageReader
13};
14use std::collections::HashMap;
15use wde_logger::prelude::*;
16use wde_wgpu::{
17    compute_pipeline::ComputePipeline,
18    render_pipeline::{RenderPipeline, ShaderStages}
19};
20
21pub(crate) struct PipelineManagerPlugin;
22impl Plugin for PipelineManagerPlugin {
23    fn build(&self, app: &mut App) {
24        app.init_resource::<PipelineManager>()
25            .add_systems(Extract, extract_shaders)
26            .add_systems(
27                Render,
28                (load_render_pipelines, load_compute_pipelines).in_set(RenderSet::Prepare)
29            );
30    }
31}
32
33/// The status of a cached pipeline. Used to query the pipeline manager for the status of a pipeline.
34pub enum CachedPipelineStatus<'a> {
35    Loading,
36    OkRender(&'a RenderPipeline),
37    OkCompute(&'a ComputePipeline),
38    Error
39}
40/// The index of a cached pipeline.
41pub type CachedPipelineIndex = usize;
42/// Stores queued and realized pipelines plus shader cache.
43/// Provides an interface to queue pipelines for loading and query their status.
44#[derive(Resource, Default)]
45pub struct PipelineManager {
46    /// Monotonic index generator for cached pipelines.
47    pipeline_iter: CachedPipelineIndex,
48
49    /// Render pipelines waiting to be built.
50    processing_render_pipelines: HashMap<CachedPipelineIndex, RenderPipelineDescriptor>,
51    /// Built render pipelines ready for use.
52    loaded_render_pipelines: HashMap<CachedPipelineIndex, RenderPipeline>,
53    /// Original descriptors corresponding to built render pipelines.
54    loaded_render_pipelines_desc: HashMap<CachedPipelineIndex, RenderPipelineDescriptor>,
55
56    /// Compute pipelines waiting to be built.
57    processing_compute_pipelines: HashMap<CachedPipelineIndex, ComputePipelineDescriptor>,
58    /// Built compute pipelines ready for use.
59    loaded_compute_pipelines: HashMap<CachedPipelineIndex, ComputePipeline>,
60    /// Original descriptors corresponding to built compute pipelines.
61    loaded_compute_pipelines_desc: HashMap<CachedPipelineIndex, ComputePipelineDescriptor>,
62
63    /// CPU-side cache of shader assets keyed by asset id.
64    shader_cache: HashMap<AssetId<Shader>, Shader>,
65    /// Mapping from shader asset ids to pipelines that reference them (for hot-reload).
66    shader_to_pipelines: HashMap<AssetId<Shader>, Vec<CachedPipelineIndex>>
67}
68impl PipelineManager {
69    /// Push the creation of a render pipeline to the pipeline manager queue.
70    pub fn create_render_pipeline<E: Send + Sync + 'static>(
71        &mut self,
72        descriptor: RenderPipelineDescriptor,
73        asset: E
74    ) -> Result<CachedPipelineIndex, PrepareAssetError<E>> {
75        // Check that every bind group layout is Some
76        for layout in descriptor.bind_group_layouts.iter() {
77            if layout.is_none() {
78                return Err(PrepareAssetError::RetryNextUpdate(asset));
79            }
80        }
81
82        // Store the pipeline descriptor to the queued pipelines
83        let id = self.pipeline_iter;
84        self.processing_render_pipelines.insert(id, descriptor);
85        self.pipeline_iter += 1;
86        Ok(id)
87    }
88
89    /// Push the creation of a compute pipeline to the pipeline manager queue.
90    pub fn create_compute_pipeline<E: Send + Sync + 'static>(
91        &mut self,
92        descriptor: ComputePipelineDescriptor,
93        asset: E
94    ) -> Result<CachedPipelineIndex, PrepareAssetError<E>> {
95        // Check that every bind group layout is Some
96        for layout in descriptor.bind_group_layouts.iter() {
97            if layout.is_none() {
98                return Err(PrepareAssetError::RetryNextUpdate(asset));
99            }
100        }
101
102        // Store the pipeline descriptor to the queued pipelines
103        let id = self.pipeline_iter;
104        self.processing_compute_pipelines.insert(id, descriptor);
105        self.pipeline_iter += 1;
106        Ok(id)
107    }
108
109    /// Get the status of a pipeline from its cached index.
110    /// Returns loading/ready/error state for both render and compute pipelines.
111    pub fn get_pipeline(&'_ self, id: CachedPipelineIndex) -> CachedPipelineStatus<'_> {
112        if self.processing_render_pipelines.contains_key(&id)
113            || self.processing_compute_pipelines.contains_key(&id)
114        {
115            CachedPipelineStatus::Loading
116        } else if let Some(pipeline) = self.loaded_render_pipelines.get(&id) {
117            CachedPipelineStatus::OkRender(pipeline)
118        } else if let Some(pipeline) = self.loaded_compute_pipelines.get(&id) {
119            CachedPipelineStatus::OkCompute(pipeline)
120        } else {
121            error_once!("Pipeline with id {} not found", id);
122            CachedPipelineStatus::Error
123        }
124    }
125}
126
127/// Extract the shaders from the asset server and store them in the pipeline manager.
128fn extract_shaders(
129    mut pipeline_manager: ResMut<PipelineManager>,
130    shaders: ExtractWorld<Res<Assets<Shader>>>,
131    mut shader_events: ExtractWorld<MessageReader<AssetEvent<Shader>>>
132) {
133    let cache = &mut pipeline_manager.shader_cache;
134    let mut updated_ids = Vec::new();
135    for event in shader_events.read() {
136        match event {
137            AssetEvent::Added { id } => {
138                if let Some(shader) = shaders.get(*id) {
139                    cache.insert(*id, shader.clone());
140                }
141            }
142            AssetEvent::Modified { id } => {
143                if let Some(shader) = shaders.get(*id) {
144                    cache.insert(*id, shader.clone());
145                    updated_ids.push(*id);
146                }
147            }
148            AssetEvent::Removed { id } => {
149                cache.remove(id);
150            }
151            AssetEvent::Unused { .. } => {}
152            AssetEvent::LoadedWithDependencies { .. } => {}
153        }
154    }
155
156    // Recreate the shader to pipelines map
157    for id in updated_ids {
158        let p_ids = match pipeline_manager.shader_to_pipelines.get(&id) {
159            Some(p_ids) => p_ids.clone(),
160            None => continue
161        };
162        for p_id in p_ids.iter() {
163            // Only update the pipeline if it is loaded
164            if pipeline_manager.loaded_render_pipelines.contains_key(p_id) {
165                let desc = pipeline_manager
166                    .loaded_render_pipelines_desc
167                    .remove(p_id)
168                    .unwrap();
169                pipeline_manager
170                    .processing_render_pipelines
171                    .insert(*p_id, desc.clone());
172                pipeline_manager.loaded_render_pipelines.remove(p_id);
173            }
174            if pipeline_manager.loaded_compute_pipelines.contains_key(p_id) {
175                let desc = pipeline_manager
176                    .loaded_compute_pipelines_desc
177                    .remove(p_id)
178                    .unwrap();
179                pipeline_manager
180                    .processing_compute_pipelines
181                    .insert(*p_id, desc.clone());
182                pipeline_manager.loaded_compute_pipelines.remove(p_id);
183            }
184        }
185    }
186}
187
188/// Load the pipelines that are queued in the pipeline manager.
189fn load_render_pipelines(
190    mut pipeline_manager: ResMut<PipelineManager>,
191    render_instance: Res<RenderInstance>
192) {
193    let mut pipelines_loaded_indices: Vec<(usize, RenderPipeline)> = Vec::new();
194    let mut pipelines_loaded_desc: HashMap<CachedPipelineIndex, RenderPipelineDescriptor> =
195        HashMap::new();
196    let mut shaders_to_pipelines: HashMap<AssetId<Shader>, Vec<CachedPipelineIndex>> =
197        pipeline_manager.shader_to_pipelines.clone();
198    for (id, descriptor) in pipeline_manager.processing_render_pipelines.iter() {
199        let mut can_load = true;
200
201        // Check if vertex shader is loaded
202        let vert_shader = match &descriptor.vert {
203            Some(shader) => {
204                match pipeline_manager.shader_cache.get(&shader.id()) {
205                    Some(shader) => Some(shader),
206                    None => {
207                        // Shader is not loaded yet
208                        can_load = false;
209                        None
210                    }
211                }
212            }
213            None => None
214        };
215
216        // Check if fragment shader is loaded
217        let frag_shader = match &descriptor.frag {
218            Some(shader) => {
219                match pipeline_manager.shader_cache.get(&shader.id()) {
220                    Some(shader) => Some(shader),
221                    None => {
222                        // Shader is not loaded yet
223                        can_load = false;
224                        None
225                    }
226                }
227            }
228            None => None
229        };
230
231        // Skip if shaders are not loaded
232        if !can_load {
233            continue;
234        }
235        pipelines_loaded_desc.insert(*id, descriptor.clone());
236        shaders_to_pipelines
237            .entry(descriptor.vert.as_ref().unwrap().id())
238            .or_default()
239            .push(*id);
240        shaders_to_pipelines
241            .entry(descriptor.frag.as_ref().unwrap().id())
242            .or_default()
243            .push(*id);
244
245        // Build the layouts
246        let mut bind_group_layouts = Vec::new();
247        for layout in descriptor.bind_group_layouts.iter() {
248            let layout = match layout
249                .as_ref()
250                .unwrap()
251                .build(&render_instance.0.read().unwrap())
252            {
253                Ok(layout) => layout,
254                Err(e) => {
255                    error!(
256                        "Failed to build bind group layout for pipeline {}: {:?}",
257                        descriptor.label, e
258                    );
259                    continue;
260                }
261            };
262            bind_group_layouts.push(layout);
263        }
264
265        // Load the pipeline
266        let mut pipeline = RenderPipeline::new(descriptor.label);
267        if let Some(vert_shader) = vert_shader {
268            pipeline.set_shader(&vert_shader.content, ShaderStages::VERTEX);
269        }
270        if let Some(frag_shader) = frag_shader {
271            pipeline.set_shader(&frag_shader.content, ShaderStages::FRAGMENT);
272        }
273        pipeline.set_fragment_blend(descriptor.fragment_blend);
274        pipeline.set_color_write_mask(descriptor.color_write);
275        pipeline.set_topology(descriptor.topology);
276        pipeline.set_cull_mode(descriptor.cull_mode);
277        pipeline.set_depth(descriptor.depth.clone());
278        pipeline.set_use_vertices_buffer(descriptor.vertex_buffer);
279        if let Some(ref render_targets) = descriptor.render_targets {
280            pipeline.set_render_targets(render_targets.clone());
281        }
282        pipeline.set_sample_count(descriptor.sample_count);
283        for push_constant in descriptor.push_constants.iter() {
284            pipeline.add_push_constant(
285                push_constant.stages,
286                push_constant.offset,
287                push_constant.size
288            );
289        }
290        pipeline.set_bind_groups(bind_group_layouts);
291        match pipeline.init(&render_instance.0.read().unwrap()) {
292            Ok(_) => (),
293            Err(e) => {
294                error_once!("Failed to load pipeline: {:?}", e);
295                continue;
296            }
297        }
298        debug!("Loaded pipeline {} with id {}", descriptor.label, id);
299
300        // Add the pipeline to the loaded pipelines
301        pipelines_loaded_indices.push((*id, pipeline));
302    }
303
304    // Remove loaded pipelines and add them to the loaded pipelines
305    while let Some((id, pipeline)) = pipelines_loaded_indices.pop() {
306        pipeline_manager.processing_render_pipelines.remove(&id);
307        pipeline_manager
308            .loaded_render_pipelines
309            .insert(id, pipeline);
310        pipeline_manager
311            .loaded_render_pipelines_desc
312            .insert(id, pipelines_loaded_desc.remove(&id).unwrap());
313    }
314
315    // Update the shader to pipelines map
316    pipeline_manager.shader_to_pipelines = shaders_to_pipelines;
317}
318
319/// Load the pipelines that are queued in the pipeline manager.
320fn load_compute_pipelines(
321    mut pipeline_manager: ResMut<PipelineManager>,
322    render_instance: Res<RenderInstance>
323) {
324    let mut pipelines_loaded_indices: Vec<(usize, ComputePipeline)> = Vec::new();
325    let mut pipelines_loaded_desc: HashMap<CachedPipelineIndex, ComputePipelineDescriptor> =
326        HashMap::new();
327    let mut shaders_to_pipelines: HashMap<AssetId<Shader>, Vec<CachedPipelineIndex>> =
328        pipeline_manager.shader_to_pipelines.clone();
329    for (id, descriptor) in pipeline_manager.processing_compute_pipelines.iter() {
330        let mut can_load = true;
331
332        // Check if compute shader is loaded
333        let compute_shader = match &descriptor.comp {
334            Some(shader) => {
335                match pipeline_manager.shader_cache.get(&shader.id()) {
336                    Some(shader) => Some(shader),
337                    None => {
338                        // Shader is not loaded yet
339                        can_load = false;
340                        None
341                    }
342                }
343            }
344            None => None
345        };
346
347        // Skip if shaders are not loaded
348        if !can_load {
349            continue;
350        }
351        pipelines_loaded_desc.insert(*id, descriptor.clone());
352        shaders_to_pipelines
353            .entry(descriptor.comp.as_ref().unwrap().id())
354            .or_default()
355            .push(*id);
356
357        // Build the layouts
358        let mut bind_group_layouts = Vec::new();
359        for layout in descriptor.bind_group_layouts.iter() {
360            let layout = match layout
361                .as_ref()
362                .unwrap()
363                .build(&render_instance.0.read().unwrap())
364            {
365                Ok(layout) => layout,
366                Err(e) => {
367                    error!(
368                        "Failed to build bind group layout for pipeline {}: {:?}",
369                        descriptor.label, e
370                    );
371                    continue;
372                }
373            };
374            bind_group_layouts.push(layout);
375        }
376
377        // Load the pipeline
378        let mut pipeline = ComputePipeline::new(descriptor.label);
379        if let Some(compute_shader) = compute_shader {
380            pipeline.set_shader(&compute_shader.content);
381        }
382        for push_constant in descriptor.push_constants.iter() {
383            pipeline.add_push_constant(push_constant.offset, push_constant.size);
384        }
385        pipeline.set_bind_groups(bind_group_layouts);
386        match pipeline.init(&render_instance.0.read().unwrap()) {
387            Ok(_) => (),
388            Err(e) => {
389                error_once!("Failed to load pipeline: {:?}", e);
390                continue;
391            }
392        }
393
394        // Add the pipeline to the loaded pipelines
395        pipelines_loaded_indices.push((*id, pipeline));
396    }
397
398    // Remove loaded pipelines and add them to the loaded pipelines
399    while let Some((id, pipeline)) = pipelines_loaded_indices.pop() {
400        pipeline_manager.processing_compute_pipelines.remove(&id);
401        pipeline_manager
402            .loaded_compute_pipelines
403            .insert(id, pipeline);
404        pipeline_manager
405            .loaded_compute_pipelines_desc
406            .insert(id, pipelines_loaded_desc.remove(&id).unwrap());
407    }
408
409    // Update the shader to pipelines map
410    pipeline_manager.shader_to_pipelines = shaders_to_pipelines;
411}