Skip to main content

wde_renderer/assets/
shader.rs

1use wde_logger::prelude::*;
2
3use bevy::{
4    asset::{AssetLoader, LoadContext, io::Reader},
5    prelude::*
6};
7use std::collections::HashSet;
8use std::io::{Error, ErrorKind};
9use thiserror::Error;
10
11/// Stores a shader source as UTF-8 text. File should have a `.wgsl` extension.
12/// Most of the time, the user does not need to load shaders directly, as they can be embedded in materials and pipelines.
13/// Note: the shader will only be compiled on the GPU when used in a pipeline, so this is just a container for the source code.
14#[derive(Asset, TypePath, Clone, Debug)]
15pub struct Shader {
16    /// WGSL source contents as UTF-8 text.
17    pub content: String
18}
19
20#[derive(Debug, Error)]
21pub(crate) enum ShaderLoaderError {
22    #[error("Could not load shader: {0}")]
23    Io(#[from] std::io::Error),
24    #[error("Failed to resolve shader include: {0}")]
25    Include(String)
26}
27#[derive(Default, TypePath)]
28pub(crate) struct ShaderLoader;
29impl AssetLoader for ShaderLoader {
30    type Asset = Shader;
31    type Settings = ();
32    type Error = ShaderLoaderError;
33
34    async fn load(
35        &self,
36        reader: &mut dyn Reader,
37        _settings: &Self::Settings,
38        load_context: &mut LoadContext<'_>
39    ) -> Result<Self::Asset, Self::Error> {
40        debug!("Loading shader {}.", load_context.path());
41
42        // Read the shader data
43        let mut bytes = Vec::new();
44        reader.read_to_end(&mut bytes).await?;
45
46        // Read the content
47        let content = match String::from_utf8(bytes) {
48            Ok(content) => content,
49            Err(_) => {
50                return Err(ShaderLoaderError::Io(Error::new(
51                    ErrorKind::InvalidData,
52                    "Could not convert shader to string."
53                )));
54            }
55        };
56
57        // Resolve #include directives
58        let mut included = HashSet::new();
59        let content = resolve_includes(content, load_context, &mut included).await?;
60        Ok(Shader { content })
61    }
62
63    fn extensions(&self) -> &[&str] {
64        &["wgsl"]
65    }
66}
67
68/// Resolve `#include "path"` directives in `source`, inlining the content of each referenced file.
69///
70/// Paths are relative to the asset root (`res/`). Each file is included at most once per
71/// compilation unit — duplicate `#include` lines for the same path are silently skipped,
72/// which also prevents infinite loops from circular includes.
73///
74/// Included files are registered as hot-reload dependencies: editing an include file
75/// triggers a reload of every shader that (directly or transitively) includes it.
76async fn resolve_includes(
77    source: String,
78    load_context: &mut LoadContext<'_>,
79    included: &mut HashSet<String>
80) -> Result<String, ShaderLoaderError> {
81    // Work on a mutable list of lines so we can splice included content in-place.
82    // The cursor `i` never skips newly inserted lines, so nested includes are processed
83    // automatically without recursion.
84    let mut lines: Vec<String> = source.lines().map(str::to_owned).collect();
85    let mut i = 0;
86    while i < lines.len() {
87        let trimmed = lines[i].trim().to_owned();
88        if let Some(rest) = trimmed.strip_prefix("#include") {
89            let path_str = rest.trim();
90            if path_str.starts_with('"') && path_str.ends_with('"') && path_str.len() >= 2 {
91                let include_path = path_str[1..path_str.len() - 1].to_owned();
92                if !included.contains(&include_path) {
93                    included.insert(include_path.clone());
94                    let bytes = load_context
95                        .read_asset_bytes(include_path.clone())
96                        .await
97                        .map_err(|e| {
98                            ShaderLoaderError::Include(format!(
99                                "Could not read '{}': {}",
100                                include_path, e
101                            ))
102                        })?;
103                    let include_source = String::from_utf8(bytes).map_err(|_| {
104                        ShaderLoaderError::Include(format!("'{}' is not valid UTF-8", include_path))
105                    })?;
106                    let new_lines: Vec<String> =
107                        include_source.lines().map(str::to_owned).collect();
108                    // Replace the #include line with the file contents.
109                    // The cursor stays at `i` so nested #include lines in the
110                    // inserted block are processed on the next iterations.
111                    lines.splice(i..=i, new_lines);
112                } else {
113                    // Already included — remove this directive and advance.
114                    lines.remove(i);
115                }
116                continue;
117            } else {
118                return Err(ShaderLoaderError::Include(format!(
119                    "Invalid #include syntax on line {i}: expected `#include \"path\"`"
120                )));
121            }
122        }
123        i += 1;
124    }
125    Ok(lines.join("\n"))
126}