1use super::{ComputePipelineDescriptor, RenderPipelineDescriptor};
2use crate::assets::PrepareAssetError;
3use crate::core::RenderInstance;
4use crate::{
5 assets::Shader,
6 core::{Extract, ExtractWorld, Render, RenderSet}
7};
8use bevy::{
9 app::{App, Plugin},
10 asset::{AssetEvent, AssetId, Assets},
11 ecs::prelude::*,
12 prelude::MessageReader
13};
14use std::collections::HashMap;
15use wde_logger::prelude::*;
16use wde_wgpu::{
17 compute_pipeline::ComputePipeline,
18 render_pipeline::{RenderPipeline, ShaderStages}
19};
20
21pub(crate) struct PipelineManagerPlugin;
22impl Plugin for PipelineManagerPlugin {
23 fn build(&self, app: &mut App) {
24 app.init_resource::<PipelineManager>()
25 .add_systems(Extract, extract_shaders)
26 .add_systems(
27 Render,
28 (load_render_pipelines, load_compute_pipelines).in_set(RenderSet::Prepare)
29 );
30 }
31}
32
/// Result of looking up a pipeline in the [`PipelineManager`] by its
/// [`CachedPipelineIndex`].
///
/// The lifetime `'a` ties the borrowed pipeline to the manager it was
/// fetched from.
pub enum CachedPipelineStatus<'a> {
    /// The pipeline is still queued in one of the "processing" maps and has
    /// not been built yet.
    Loading,
    /// The id refers to a built render pipeline.
    OkRender(&'a RenderPipeline),
    /// The id refers to a built compute pipeline.
    OkCompute(&'a ComputePipeline),
    /// The id is neither queued nor loaded.
    Error
}
/// Identifier of a pipeline stored in the [`PipelineManager`].
///
/// Render and compute pipelines share the same id counter, so an id maps to
/// at most one pipeline of either kind.
pub type CachedPipelineIndex = usize;
/// Cache of render and compute pipelines.
///
/// Pipelines are first queued as descriptors ("processing"), then built by
/// the `load_*_pipelines` systems once their shaders are present in the
/// shader cache ("loaded"). When a shader is modified, the loaded pipelines
/// that depend on it are moved back to the processing queue.
#[derive(Resource, Default)]
pub struct PipelineManager {
    /// Next id to hand out; incremented each time a pipeline is queued.
    pipeline_iter: CachedPipelineIndex,

    /// Render pipeline descriptors waiting to be built.
    processing_render_pipelines: HashMap<CachedPipelineIndex, RenderPipelineDescriptor>,
    /// Render pipelines that were built successfully.
    loaded_render_pipelines: HashMap<CachedPipelineIndex, RenderPipeline>,
    /// Descriptors of the loaded render pipelines, kept so a pipeline can be
    /// re-queued when one of its shaders changes.
    loaded_render_pipelines_desc: HashMap<CachedPipelineIndex, RenderPipelineDescriptor>,

    /// Compute pipeline descriptors waiting to be built.
    processing_compute_pipelines: HashMap<CachedPipelineIndex, ComputePipelineDescriptor>,
    /// Compute pipelines that were built successfully.
    loaded_compute_pipelines: HashMap<CachedPipelineIndex, ComputePipeline>,
    /// Descriptors of the loaded compute pipelines, kept so a pipeline can be
    /// re-queued when its shader changes.
    loaded_compute_pipelines_desc: HashMap<CachedPipelineIndex, ComputePipelineDescriptor>,

    /// Copies of the shader assets, filled in by `extract_shaders`.
    shader_cache: HashMap<AssetId<Shader>, Shader>,
    /// For each shader, the ids of the pipelines that use it.
    shader_to_pipelines: HashMap<AssetId<Shader>, Vec<CachedPipelineIndex>>
}
68impl PipelineManager {
69 pub fn create_render_pipeline<E: Send + Sync + 'static>(
71 &mut self,
72 descriptor: RenderPipelineDescriptor,
73 asset: E
74 ) -> Result<CachedPipelineIndex, PrepareAssetError<E>> {
75 for layout in descriptor.bind_group_layouts.iter() {
77 if layout.is_none() {
78 return Err(PrepareAssetError::RetryNextUpdate(asset));
79 }
80 }
81
82 let id = self.pipeline_iter;
84 self.processing_render_pipelines.insert(id, descriptor);
85 self.pipeline_iter += 1;
86 Ok(id)
87 }
88
89 pub fn create_compute_pipeline<E: Send + Sync + 'static>(
91 &mut self,
92 descriptor: ComputePipelineDescriptor,
93 asset: E
94 ) -> Result<CachedPipelineIndex, PrepareAssetError<E>> {
95 for layout in descriptor.bind_group_layouts.iter() {
97 if layout.is_none() {
98 return Err(PrepareAssetError::RetryNextUpdate(asset));
99 }
100 }
101
102 let id = self.pipeline_iter;
104 self.processing_compute_pipelines.insert(id, descriptor);
105 self.pipeline_iter += 1;
106 Ok(id)
107 }
108
109 pub fn get_pipeline(&'_ self, id: CachedPipelineIndex) -> CachedPipelineStatus<'_> {
112 if self.processing_render_pipelines.contains_key(&id)
113 || self.processing_compute_pipelines.contains_key(&id)
114 {
115 CachedPipelineStatus::Loading
116 } else if let Some(pipeline) = self.loaded_render_pipelines.get(&id) {
117 CachedPipelineStatus::OkRender(pipeline)
118 } else if let Some(pipeline) = self.loaded_compute_pipelines.get(&id) {
119 CachedPipelineStatus::OkCompute(pipeline)
120 } else {
121 error_once!("Pipeline with id {} not found", id);
122 CachedPipelineStatus::Error
123 }
124 }
125}
126
127fn extract_shaders(
129 mut pipeline_manager: ResMut<PipelineManager>,
130 shaders: ExtractWorld<Res<Assets<Shader>>>,
131 mut shader_events: ExtractWorld<MessageReader<AssetEvent<Shader>>>
132) {
133 let cache = &mut pipeline_manager.shader_cache;
134 let mut updated_ids = Vec::new();
135 for event in shader_events.read() {
136 match event {
137 AssetEvent::Added { id } => {
138 if let Some(shader) = shaders.get(*id) {
139 cache.insert(*id, shader.clone());
140 }
141 }
142 AssetEvent::Modified { id } => {
143 if let Some(shader) = shaders.get(*id) {
144 cache.insert(*id, shader.clone());
145 updated_ids.push(*id);
146 }
147 }
148 AssetEvent::Removed { id } => {
149 cache.remove(id);
150 }
151 AssetEvent::Unused { .. } => {}
152 AssetEvent::LoadedWithDependencies { .. } => {}
153 }
154 }
155
156 for id in updated_ids {
158 let p_ids = match pipeline_manager.shader_to_pipelines.get(&id) {
159 Some(p_ids) => p_ids.clone(),
160 None => continue
161 };
162 for p_id in p_ids.iter() {
163 if pipeline_manager.loaded_render_pipelines.contains_key(p_id) {
165 let desc = pipeline_manager
166 .loaded_render_pipelines_desc
167 .remove(p_id)
168 .unwrap();
169 pipeline_manager
170 .processing_render_pipelines
171 .insert(*p_id, desc.clone());
172 pipeline_manager.loaded_render_pipelines.remove(p_id);
173 }
174 if pipeline_manager.loaded_compute_pipelines.contains_key(p_id) {
175 let desc = pipeline_manager
176 .loaded_compute_pipelines_desc
177 .remove(p_id)
178 .unwrap();
179 pipeline_manager
180 .processing_compute_pipelines
181 .insert(*p_id, desc.clone());
182 pipeline_manager.loaded_compute_pipelines.remove(p_id);
183 }
184 }
185 }
186}
187
188fn load_render_pipelines(
190 mut pipeline_manager: ResMut<PipelineManager>,
191 render_instance: Res<RenderInstance>
192) {
193 let mut pipelines_loaded_indices: Vec<(usize, RenderPipeline)> = Vec::new();
194 let mut pipelines_loaded_desc: HashMap<CachedPipelineIndex, RenderPipelineDescriptor> =
195 HashMap::new();
196 let mut shaders_to_pipelines: HashMap<AssetId<Shader>, Vec<CachedPipelineIndex>> =
197 pipeline_manager.shader_to_pipelines.clone();
198 for (id, descriptor) in pipeline_manager.processing_render_pipelines.iter() {
199 let mut can_load = true;
200
201 let vert_shader = match &descriptor.vert {
203 Some(shader) => {
204 match pipeline_manager.shader_cache.get(&shader.id()) {
205 Some(shader) => Some(shader),
206 None => {
207 can_load = false;
209 None
210 }
211 }
212 }
213 None => None
214 };
215
216 let frag_shader = match &descriptor.frag {
218 Some(shader) => {
219 match pipeline_manager.shader_cache.get(&shader.id()) {
220 Some(shader) => Some(shader),
221 None => {
222 can_load = false;
224 None
225 }
226 }
227 }
228 None => None
229 };
230
231 if !can_load {
233 continue;
234 }
235 pipelines_loaded_desc.insert(*id, descriptor.clone());
236 shaders_to_pipelines
237 .entry(descriptor.vert.as_ref().unwrap().id())
238 .or_default()
239 .push(*id);
240 shaders_to_pipelines
241 .entry(descriptor.frag.as_ref().unwrap().id())
242 .or_default()
243 .push(*id);
244
245 let mut bind_group_layouts = Vec::new();
247 for layout in descriptor.bind_group_layouts.iter() {
248 let layout = match layout
249 .as_ref()
250 .unwrap()
251 .build(&render_instance.0.read().unwrap())
252 {
253 Ok(layout) => layout,
254 Err(e) => {
255 error!(
256 "Failed to build bind group layout for pipeline {}: {:?}",
257 descriptor.label, e
258 );
259 continue;
260 }
261 };
262 bind_group_layouts.push(layout);
263 }
264
265 let mut pipeline = RenderPipeline::new(descriptor.label);
267 if let Some(vert_shader) = vert_shader {
268 pipeline.set_shader(&vert_shader.content, ShaderStages::VERTEX);
269 }
270 if let Some(frag_shader) = frag_shader {
271 pipeline.set_shader(&frag_shader.content, ShaderStages::FRAGMENT);
272 }
273 pipeline.set_fragment_blend(descriptor.fragment_blend);
274 pipeline.set_topology(descriptor.topology);
275 pipeline.set_cull_mode(descriptor.cull_mode);
276 pipeline.set_depth(descriptor.depth.clone());
277 pipeline.set_use_vertices_buffer(descriptor.vertex_buffer);
278 if let Some(ref render_targets) = descriptor.render_targets {
279 pipeline.set_render_targets(render_targets.clone());
280 }
281 pipeline.set_sample_count(descriptor.sample_count);
282 for push_constant in descriptor.push_constants.iter() {
283 pipeline.add_push_constant(
284 push_constant.stages,
285 push_constant.offset,
286 push_constant.size
287 );
288 }
289 pipeline.set_bind_groups(bind_group_layouts);
290 match pipeline.init(&render_instance.0.read().unwrap()) {
291 Ok(_) => (),
292 Err(e) => {
293 error_once!("Failed to load pipeline: {:?}", e);
294 continue;
295 }
296 }
297 debug!("Loaded pipeline {} with id {}", descriptor.label, id);
298
299 pipelines_loaded_indices.push((*id, pipeline));
301 }
302
303 while let Some((id, pipeline)) = pipelines_loaded_indices.pop() {
305 pipeline_manager.processing_render_pipelines.remove(&id);
306 pipeline_manager
307 .loaded_render_pipelines
308 .insert(id, pipeline);
309 pipeline_manager
310 .loaded_render_pipelines_desc
311 .insert(id, pipelines_loaded_desc.remove(&id).unwrap());
312 }
313
314 pipeline_manager.shader_to_pipelines = shaders_to_pipelines;
316}
317
318fn load_compute_pipelines(
320 mut pipeline_manager: ResMut<PipelineManager>,
321 render_instance: Res<RenderInstance>
322) {
323 let mut pipelines_loaded_indices: Vec<(usize, ComputePipeline)> = Vec::new();
324 let mut pipelines_loaded_desc: HashMap<CachedPipelineIndex, ComputePipelineDescriptor> =
325 HashMap::new();
326 let mut shaders_to_pipelines: HashMap<AssetId<Shader>, Vec<CachedPipelineIndex>> =
327 pipeline_manager.shader_to_pipelines.clone();
328 for (id, descriptor) in pipeline_manager.processing_compute_pipelines.iter() {
329 let mut can_load = true;
330
331 let compute_shader = match &descriptor.comp {
333 Some(shader) => {
334 match pipeline_manager.shader_cache.get(&shader.id()) {
335 Some(shader) => Some(shader),
336 None => {
337 can_load = false;
339 None
340 }
341 }
342 }
343 None => None
344 };
345
346 if !can_load {
348 continue;
349 }
350 pipelines_loaded_desc.insert(*id, descriptor.clone());
351 shaders_to_pipelines
352 .entry(descriptor.comp.as_ref().unwrap().id())
353 .or_default()
354 .push(*id);
355
356 let mut bind_group_layouts = Vec::new();
358 for layout in descriptor.bind_group_layouts.iter() {
359 let layout = match layout
360 .as_ref()
361 .unwrap()
362 .build(&render_instance.0.read().unwrap())
363 {
364 Ok(layout) => layout,
365 Err(e) => {
366 error!(
367 "Failed to build bind group layout for pipeline {}: {:?}",
368 descriptor.label, e
369 );
370 continue;
371 }
372 };
373 bind_group_layouts.push(layout);
374 }
375
376 let mut pipeline = ComputePipeline::new(descriptor.label);
378 if let Some(compute_shader) = compute_shader {
379 pipeline.set_shader(&compute_shader.content);
380 }
381 for push_constant in descriptor.push_constants.iter() {
382 pipeline.add_push_constant(push_constant.offset, push_constant.size);
383 }
384 pipeline.set_bind_groups(bind_group_layouts);
385 match pipeline.init(&render_instance.0.read().unwrap()) {
386 Ok(_) => (),
387 Err(e) => {
388 error_once!("Failed to load pipeline: {:?}", e);
389 continue;
390 }
391 }
392
393 pipelines_loaded_indices.push((*id, pipeline));
395 }
396
397 while let Some((id, pipeline)) = pipelines_loaded_indices.pop() {
399 pipeline_manager.processing_compute_pipelines.remove(&id);
400 pipeline_manager
401 .loaded_compute_pipelines
402 .insert(id, pipeline);
403 pipeline_manager
404 .loaded_compute_pipelines_desc
405 .insert(id, pipelines_loaded_desc.remove(&id).unwrap());
406 }
407
408 pipeline_manager.shader_to_pipelines = shaders_to_pipelines;
410}