// wde_renderer/passes/pipeline_manager.rs

1use super::{ComputePipelineDescriptor, RenderPipelineDescriptor};
2use crate::assets::PrepareAssetError;
3use crate::core::RenderInstance;
4use crate::{
5    assets::Shader,
6    core::{Extract, ExtractWorld, Render, RenderSet}
7};
8use bevy::{
9    app::{App, Plugin},
10    asset::{AssetEvent, AssetId, Assets},
11    ecs::prelude::*,
12    prelude::MessageReader
13};
14use std::collections::HashMap;
15use wde_logger::prelude::*;
16use wde_wgpu::{
17    compute_pipeline::ComputePipeline,
18    render_pipeline::{RenderPipeline, ShaderStages}
19};
20
21pub(crate) struct PipelineManagerPlugin;
22impl Plugin for PipelineManagerPlugin {
23    fn build(&self, app: &mut App) {
24        app.init_resource::<PipelineManager>()
25            .add_systems(Extract, extract_shaders)
26            .add_systems(
27                Render,
28                (load_render_pipelines, load_compute_pipelines).in_set(RenderSet::Prepare)
29            );
30    }
31}
32
33/// The status of a cached pipeline. Used to query the pipeline manager for the status of a pipeline.
34pub enum CachedPipelineStatus<'a> {
35    Loading,
36    OkRender(&'a RenderPipeline),
37    OkCompute(&'a ComputePipeline),
38    Error
39}
/// The index of a cached pipeline.
///
/// Returned when a pipeline is queued on the pipeline manager and used
/// afterwards as the key to query that pipeline's status.
pub type CachedPipelineIndex = usize;
42/// Stores queued and realized pipelines plus shader cache.
43/// Provides an interface to queue pipelines for loading and query their status.
44#[derive(Resource, Default)]
45pub struct PipelineManager {
46    /// Monotonic index generator for cached pipelines.
47    pipeline_iter: CachedPipelineIndex,
48
49    /// Render pipelines waiting to be built.
50    processing_render_pipelines: HashMap<CachedPipelineIndex, RenderPipelineDescriptor>,
51    /// Built render pipelines ready for use.
52    loaded_render_pipelines: HashMap<CachedPipelineIndex, RenderPipeline>,
53    /// Original descriptors corresponding to built render pipelines.
54    loaded_render_pipelines_desc: HashMap<CachedPipelineIndex, RenderPipelineDescriptor>,
55
56    /// Compute pipelines waiting to be built.
57    processing_compute_pipelines: HashMap<CachedPipelineIndex, ComputePipelineDescriptor>,
58    /// Built compute pipelines ready for use.
59    loaded_compute_pipelines: HashMap<CachedPipelineIndex, ComputePipeline>,
60    /// Original descriptors corresponding to built compute pipelines.
61    loaded_compute_pipelines_desc: HashMap<CachedPipelineIndex, ComputePipelineDescriptor>,
62
63    /// CPU-side cache of shader assets keyed by asset id.
64    shader_cache: HashMap<AssetId<Shader>, Shader>,
65    /// Mapping from shader asset ids to pipelines that reference them (for hot-reload).
66    shader_to_pipelines: HashMap<AssetId<Shader>, Vec<CachedPipelineIndex>>
67}
68impl PipelineManager {
69    /// Push the creation of a render pipeline to the pipeline manager queue.
70    pub fn create_render_pipeline<E: Send + Sync + 'static>(
71        &mut self,
72        descriptor: RenderPipelineDescriptor,
73        asset: E
74    ) -> Result<CachedPipelineIndex, PrepareAssetError<E>> {
75        // Check that every bind group layout is Some
76        for layout in descriptor.bind_group_layouts.iter() {
77            if layout.is_none() {
78                return Err(PrepareAssetError::RetryNextUpdate(asset));
79            }
80        }
81
82        // Store the pipeline descriptor to the queued pipelines
83        let id = self.pipeline_iter;
84        self.processing_render_pipelines.insert(id, descriptor);
85        self.pipeline_iter += 1;
86        Ok(id)
87    }
88
89    /// Push the creation of a compute pipeline to the pipeline manager queue.
90    pub fn create_compute_pipeline<E: Send + Sync + 'static>(
91        &mut self,
92        descriptor: ComputePipelineDescriptor,
93        asset: E
94    ) -> Result<CachedPipelineIndex, PrepareAssetError<E>> {
95        // Check that every bind group layout is Some
96        for layout in descriptor.bind_group_layouts.iter() {
97            if layout.is_none() {
98                return Err(PrepareAssetError::RetryNextUpdate(asset));
99            }
100        }
101
102        // Store the pipeline descriptor to the queued pipelines
103        let id = self.pipeline_iter;
104        self.processing_compute_pipelines.insert(id, descriptor);
105        self.pipeline_iter += 1;
106        Ok(id)
107    }
108
109    /// Get the status of a pipeline from its cached index.
110    /// Returns loading/ready/error state for both render and compute pipelines.
111    pub fn get_pipeline(&'_ self, id: CachedPipelineIndex) -> CachedPipelineStatus<'_> {
112        if self.processing_render_pipelines.contains_key(&id)
113            || self.processing_compute_pipelines.contains_key(&id)
114        {
115            CachedPipelineStatus::Loading
116        } else if let Some(pipeline) = self.loaded_render_pipelines.get(&id) {
117            CachedPipelineStatus::OkRender(pipeline)
118        } else if let Some(pipeline) = self.loaded_compute_pipelines.get(&id) {
119            CachedPipelineStatus::OkCompute(pipeline)
120        } else {
121            error_once!("Pipeline with id {} not found", id);
122            CachedPipelineStatus::Error
123        }
124    }
125}
126
127/// Extract the shaders from the asset server and store them in the pipeline manager.
128fn extract_shaders(
129    mut pipeline_manager: ResMut<PipelineManager>,
130    shaders: ExtractWorld<Res<Assets<Shader>>>,
131    mut shader_events: ExtractWorld<MessageReader<AssetEvent<Shader>>>
132) {
133    let cache = &mut pipeline_manager.shader_cache;
134    let mut updated_ids = Vec::new();
135    for event in shader_events.read() {
136        match event {
137            AssetEvent::Added { id } => {
138                if let Some(shader) = shaders.get(*id) {
139                    cache.insert(*id, shader.clone());
140                }
141            }
142            AssetEvent::Modified { id } => {
143                if let Some(shader) = shaders.get(*id) {
144                    cache.insert(*id, shader.clone());
145                    updated_ids.push(*id);
146                }
147            }
148            AssetEvent::Removed { id } => {
149                cache.remove(id);
150            }
151            AssetEvent::Unused { .. } => {}
152            AssetEvent::LoadedWithDependencies { .. } => {}
153        }
154    }
155
156    // Recreate the shader to pipelines map
157    for id in updated_ids {
158        let p_ids = match pipeline_manager.shader_to_pipelines.get(&id) {
159            Some(p_ids) => p_ids.clone(),
160            None => continue
161        };
162        for p_id in p_ids.iter() {
163            // Only update the pipeline if it is loaded
164            if pipeline_manager.loaded_render_pipelines.contains_key(p_id) {
165                let desc = pipeline_manager
166                    .loaded_render_pipelines_desc
167                    .remove(p_id)
168                    .unwrap();
169                pipeline_manager
170                    .processing_render_pipelines
171                    .insert(*p_id, desc.clone());
172                pipeline_manager.loaded_render_pipelines.remove(p_id);
173            }
174            if pipeline_manager.loaded_compute_pipelines.contains_key(p_id) {
175                let desc = pipeline_manager
176                    .loaded_compute_pipelines_desc
177                    .remove(p_id)
178                    .unwrap();
179                pipeline_manager
180                    .processing_compute_pipelines
181                    .insert(*p_id, desc.clone());
182                pipeline_manager.loaded_compute_pipelines.remove(p_id);
183            }
184        }
185    }
186}
187
188/// Load the pipelines that are queued in the pipeline manager.
189fn load_render_pipelines(
190    mut pipeline_manager: ResMut<PipelineManager>,
191    render_instance: Res<RenderInstance>
192) {
193    let mut pipelines_loaded_indices: Vec<(usize, RenderPipeline)> = Vec::new();
194    let mut pipelines_loaded_desc: HashMap<CachedPipelineIndex, RenderPipelineDescriptor> =
195        HashMap::new();
196    let mut shaders_to_pipelines: HashMap<AssetId<Shader>, Vec<CachedPipelineIndex>> =
197        pipeline_manager.shader_to_pipelines.clone();
198    for (id, descriptor) in pipeline_manager.processing_render_pipelines.iter() {
199        let mut can_load = true;
200
201        // Check if vertex shader is loaded
202        let vert_shader = match &descriptor.vert {
203            Some(shader) => {
204                match pipeline_manager.shader_cache.get(&shader.id()) {
205                    Some(shader) => Some(shader),
206                    None => {
207                        // Shader is not loaded yet
208                        can_load = false;
209                        None
210                    }
211                }
212            }
213            None => None
214        };
215
216        // Check if fragment shader is loaded
217        let frag_shader = match &descriptor.frag {
218            Some(shader) => {
219                match pipeline_manager.shader_cache.get(&shader.id()) {
220                    Some(shader) => Some(shader),
221                    None => {
222                        // Shader is not loaded yet
223                        can_load = false;
224                        None
225                    }
226                }
227            }
228            None => None
229        };
230
231        // Skip if shaders are not loaded
232        if !can_load {
233            continue;
234        }
235        pipelines_loaded_desc.insert(*id, descriptor.clone());
236        shaders_to_pipelines
237            .entry(descriptor.vert.as_ref().unwrap().id())
238            .or_default()
239            .push(*id);
240        shaders_to_pipelines
241            .entry(descriptor.frag.as_ref().unwrap().id())
242            .or_default()
243            .push(*id);
244
245        // Build the layouts
246        let mut bind_group_layouts = Vec::new();
247        for layout in descriptor.bind_group_layouts.iter() {
248            let layout = match layout
249                .as_ref()
250                .unwrap()
251                .build(&render_instance.0.read().unwrap())
252            {
253                Ok(layout) => layout,
254                Err(e) => {
255                    error!(
256                        "Failed to build bind group layout for pipeline {}: {:?}",
257                        descriptor.label, e
258                    );
259                    continue;
260                }
261            };
262            bind_group_layouts.push(layout);
263        }
264
265        // Load the pipeline
266        let mut pipeline = RenderPipeline::new(descriptor.label);
267        if let Some(vert_shader) = vert_shader {
268            pipeline.set_shader(&vert_shader.content, ShaderStages::VERTEX);
269        }
270        if let Some(frag_shader) = frag_shader {
271            pipeline.set_shader(&frag_shader.content, ShaderStages::FRAGMENT);
272        }
273        pipeline.set_fragment_blend(descriptor.fragment_blend);
274        pipeline.set_topology(descriptor.topology);
275        pipeline.set_cull_mode(descriptor.cull_mode);
276        pipeline.set_depth(descriptor.depth.clone());
277        pipeline.set_use_vertices_buffer(descriptor.vertex_buffer);
278        if let Some(ref render_targets) = descriptor.render_targets {
279            pipeline.set_render_targets(render_targets.clone());
280        }
281        pipeline.set_sample_count(descriptor.sample_count);
282        for push_constant in descriptor.push_constants.iter() {
283            pipeline.add_push_constant(
284                push_constant.stages,
285                push_constant.offset,
286                push_constant.size
287            );
288        }
289        pipeline.set_bind_groups(bind_group_layouts);
290        match pipeline.init(&render_instance.0.read().unwrap()) {
291            Ok(_) => (),
292            Err(e) => {
293                error_once!("Failed to load pipeline: {:?}", e);
294                continue;
295            }
296        }
297        debug!("Loaded pipeline {} with id {}", descriptor.label, id);
298
299        // Add the pipeline to the loaded pipelines
300        pipelines_loaded_indices.push((*id, pipeline));
301    }
302
303    // Remove loaded pipelines and add them to the loaded pipelines
304    while let Some((id, pipeline)) = pipelines_loaded_indices.pop() {
305        pipeline_manager.processing_render_pipelines.remove(&id);
306        pipeline_manager
307            .loaded_render_pipelines
308            .insert(id, pipeline);
309        pipeline_manager
310            .loaded_render_pipelines_desc
311            .insert(id, pipelines_loaded_desc.remove(&id).unwrap());
312    }
313
314    // Update the shader to pipelines map
315    pipeline_manager.shader_to_pipelines = shaders_to_pipelines;
316}
317
318/// Load the pipelines that are queued in the pipeline manager.
319fn load_compute_pipelines(
320    mut pipeline_manager: ResMut<PipelineManager>,
321    render_instance: Res<RenderInstance>
322) {
323    let mut pipelines_loaded_indices: Vec<(usize, ComputePipeline)> = Vec::new();
324    let mut pipelines_loaded_desc: HashMap<CachedPipelineIndex, ComputePipelineDescriptor> =
325        HashMap::new();
326    let mut shaders_to_pipelines: HashMap<AssetId<Shader>, Vec<CachedPipelineIndex>> =
327        pipeline_manager.shader_to_pipelines.clone();
328    for (id, descriptor) in pipeline_manager.processing_compute_pipelines.iter() {
329        let mut can_load = true;
330
331        // Check if compute shader is loaded
332        let compute_shader = match &descriptor.comp {
333            Some(shader) => {
334                match pipeline_manager.shader_cache.get(&shader.id()) {
335                    Some(shader) => Some(shader),
336                    None => {
337                        // Shader is not loaded yet
338                        can_load = false;
339                        None
340                    }
341                }
342            }
343            None => None
344        };
345
346        // Skip if shaders are not loaded
347        if !can_load {
348            continue;
349        }
350        pipelines_loaded_desc.insert(*id, descriptor.clone());
351        shaders_to_pipelines
352            .entry(descriptor.comp.as_ref().unwrap().id())
353            .or_default()
354            .push(*id);
355
356        // Build the layouts
357        let mut bind_group_layouts = Vec::new();
358        for layout in descriptor.bind_group_layouts.iter() {
359            let layout = match layout
360                .as_ref()
361                .unwrap()
362                .build(&render_instance.0.read().unwrap())
363            {
364                Ok(layout) => layout,
365                Err(e) => {
366                    error!(
367                        "Failed to build bind group layout for pipeline {}: {:?}",
368                        descriptor.label, e
369                    );
370                    continue;
371                }
372            };
373            bind_group_layouts.push(layout);
374        }
375
376        // Load the pipeline
377        let mut pipeline = ComputePipeline::new(descriptor.label);
378        if let Some(compute_shader) = compute_shader {
379            pipeline.set_shader(&compute_shader.content);
380        }
381        for push_constant in descriptor.push_constants.iter() {
382            pipeline.add_push_constant(push_constant.offset, push_constant.size);
383        }
384        pipeline.set_bind_groups(bind_group_layouts);
385        match pipeline.init(&render_instance.0.read().unwrap()) {
386            Ok(_) => (),
387            Err(e) => {
388                error_once!("Failed to load pipeline: {:?}", e);
389                continue;
390            }
391        }
392
393        // Add the pipeline to the loaded pipelines
394        pipelines_loaded_indices.push((*id, pipeline));
395    }
396
397    // Remove loaded pipelines and add them to the loaded pipelines
398    while let Some((id, pipeline)) = pipelines_loaded_indices.pop() {
399        pipeline_manager.processing_compute_pipelines.remove(&id);
400        pipeline_manager
401            .loaded_compute_pipelines
402            .insert(id, pipeline);
403        pipeline_manager
404            .loaded_compute_pipelines_desc
405            .insert(id, pipelines_loaded_desc.remove(&id).unwrap());
406    }
407
408    // Update the shader to pipelines map
409    pipeline_manager.shader_to_pipelines = shaders_to_pipelines;
410}