#
tokens: 45974/50000 4/78 files (page 4/7)
lines: off (toggle) GitHub
raw markdown copy
This is page 4 of 7. Use http://codebase.md/mehmetoguzderin/shaderc-vkrunner-mcp?page={x} to view the full context.

# Directory Structure

```
├── .devcontainer
│   ├── devcontainer.json
│   ├── docker-compose.yml
│   └── Dockerfile
├── .gitattributes
├── .github
│   └── workflows
│       └── build-push-image.yml
├── .gitignore
├── .vscode
│   └── mcp.json
├── Cargo.lock
├── Cargo.toml
├── Dockerfile
├── LICENSE
├── README.adoc
├── shaderc-vkrunner-mcp.jpg
├── src
│   └── main.rs
└── vkrunner
    ├── .editorconfig
    ├── .gitignore
    ├── .gitlab-ci.yml
    ├── build.rs
    ├── Cargo.toml
    ├── COPYING
    ├── examples
    │   ├── compute-shader.shader_test
    │   ├── cooperative-matrix.shader_test
    │   ├── depth-buffer.shader_test
    │   ├── desc_set_and_binding.shader_test
    │   ├── entrypoint.shader_test
    │   ├── float-framebuffer.shader_test
    │   ├── frexp.shader_test
    │   ├── geometry.shader_test
    │   ├── indices.shader_test
    │   ├── layouts.shader_test
    │   ├── properties.shader_test
    │   ├── push-constants.shader_test
    │   ├── require-subgroup-size.shader_test
    │   ├── row-major.shader_test
    │   ├── spirv.shader_test
    │   ├── ssbo.shader_test
    │   ├── tolerance.shader_test
    │   ├── tricolore.shader_test
    │   ├── ubo.shader_test
    │   ├── vertex-data-piglit.shader_test
    │   └── vertex-data.shader_test
    ├── include
    │   ├── vk_video
    │   │   ├── vulkan_video_codec_av1std_decode.h
    │   │   ├── vulkan_video_codec_av1std_encode.h
    │   │   ├── vulkan_video_codec_av1std.h
    │   │   ├── vulkan_video_codec_h264std_decode.h
    │   │   ├── vulkan_video_codec_h264std_encode.h
    │   │   ├── vulkan_video_codec_h264std.h
    │   │   ├── vulkan_video_codec_h265std_decode.h
    │   │   ├── vulkan_video_codec_h265std_encode.h
    │   │   ├── vulkan_video_codec_h265std.h
    │   │   └── vulkan_video_codecs_common.h
    │   └── vulkan
    │       ├── vk_platform.h
    │       ├── vulkan_core.h
    │       └── vulkan.h
    ├── precompile-script.py
    ├── README.md
    ├── scripts
    │   └── update-vulkan.sh
    ├── src
    │   └── main.rs
    ├── test-build.sh
    └── vkrunner
        ├── allocate_store.rs
        ├── buffer.rs
        ├── compiler
        │   └── fake_process.rs
        ├── compiler.rs
        ├── config.rs
        ├── context.rs
        ├── enum_table.rs
        ├── env_var_test.rs
        ├── executor.rs
        ├── fake_vulkan.rs
        ├── features.rs
        ├── flush_memory.rs
        ├── format_table.rs
        ├── format.rs
        ├── half_float.rs
        ├── hex.rs
        ├── inspect.rs
        ├── lib.rs
        ├── logger.rs
        ├── make-enums.py
        ├── make-features.py
        ├── make-formats.py
        ├── make-pipeline-key-data.py
        ├── make-vulkan-funcs-data.py
        ├── parse_num.rs
        ├── pipeline_key_data.rs
        ├── pipeline_key.rs
        ├── pipeline_set.rs
        ├── requirements.rs
        ├── result.rs
        ├── script.rs
        ├── shader_stage.rs
        ├── slot.rs
        ├── small_float.rs
        ├── source.rs
        ├── stream.rs
        ├── temp_file.rs
        ├── tester.rs
        ├── tolerance.rs
        ├── util.rs
        ├── vbo.rs
        ├── vk.rs
        ├── vulkan_funcs_data.rs
        ├── vulkan_funcs.rs
        ├── window_format.rs
        └── window.rs
```

# Files

--------------------------------------------------------------------------------
/vkrunner/vkrunner/pipeline_set.rs:
--------------------------------------------------------------------------------

```rust
// vkrunner
//
// Copyright (C) 2018 Intel Corporation
// Copypright 2023 Neil Roberts
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice (including the next
// paragraph) shall be included in all copies or substantial portions of the
// Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
// THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.

use crate::window::Window;
use crate::compiler;
use crate::shader_stage;
use crate::vk;
use crate::script::{Script, Buffer, BufferType, Operation};
use crate::pipeline_key;
use crate::logger::Logger;
use crate::vbo::Vbo;
use std::rc::Rc;
use std::ptr;
use std::mem;
use std::fmt;

/// All of the Vulkan objects needed to run the pipelines of a script: one
/// pipeline per pipeline key, the pipeline layout they share, and (only
/// when the script has buffers) the descriptor pool and set layouts.
///
/// Rust drops struct fields in declaration order, so the pipelines are
/// destroyed first, then the layout, the descriptor data, the pipeline
/// cache and finally the shader modules. Take this into account before
/// reordering the fields.
#[derive(Debug)]
pub struct PipelineSet {
    // One pipeline handle per key in `script.pipeline_keys()`, in the same
    // order.
    pipelines: PipelineVec,
    // Layout used by every pipeline in the set.
    layout: PipelineLayout,
    // The descriptor data is only created if there are buffers in the
    // script
    descriptor_data: Option<DescriptorData>,
    // Bitwise OR of the stage flag of every stage that has a shader in the
    // script.
    stages: vk::VkShaderStageFlagBits,
    // These are never read but they should probably be kept alive
    // until the pipelines are destroyed.
    _pipeline_cache: PipelineCache,
    // Indexed by `shader_stage::Stage as usize`; `None` for stages that
    // have no shader in the script.
    _modules: [Option<ShaderModule>; shader_stage::N_STAGES],
}

/// Vertex layout used by pipelines whose key source is
/// `pipeline_key::Source::Rectangle`.
///
/// `VertexInputState` declares it as a single binding with a stride of
/// `size_of::<RectangleVertex>()` and one `VK_FORMAT_R32G32B32_SFLOAT`
/// attribute at location 0, offset 0. `#[repr(C)]` keeps the three floats
/// contiguous and in declaration order so that description stays valid.
#[repr(C)]
#[derive(Debug)]
pub struct RectangleVertex {
    pub x: f32,
    pub y: f32,
    pub z: f32,
}

/// An error that can be returned by [PipelineSet::new].
///
/// Every variant except [Error::CompileError] corresponds to a Vulkan
/// creation call that returned something other than `VK_SUCCESS`.
#[derive(Debug)]
pub enum Error {
    /// Compiling one of the shaders in the script failed
    CompileError(compiler::Error),
    /// vkCreatePipelineCache failed
    CreatePipelineCacheFailed,
    /// vkCreateDescriptorPool failed
    CreateDescriptorPoolFailed,
    /// vkCreateDescriptorSetLayout failed
    CreateDescriptorSetLayoutFailed,
    /// vkCreatePipelineLayout failed
    CreatePipelineLayoutFailed,
    /// vkCreateGraphicsPipelines or vkCreateComputePipelines failed
    CreatePipelineFailed,
}

impl From<compiler::Error> for Error {
    fn from(e: compiler::Error) -> Error {
        Error::CompileError(e)
    }
}

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
        match self {
            Error::CompileError(e) => e.fmt(f),
            Error::CreatePipelineCacheFailed => {
                write!(f, "vkCreatePipelineCache failed")
            },
            Error::CreateDescriptorPoolFailed => {
                write!(f, "vkCreateDescriptorPool failed")
            },
            Error::CreateDescriptorSetLayoutFailed => {
                write!(f, "vkCreateDescriptorSetLayout failed")
            },
            Error::CreatePipelineLayoutFailed => {
                write!(f, "vkCreatePipelineLayout failed")
            },
            Error::CreatePipelineFailed => {
                write!(f, "Pipeline creation function failed")
            },
        }
    }
}

/// Owning wrapper around a `VkShaderModule` that destroys the module when
/// dropped.
#[derive(Debug)]
struct ShaderModule {
    handle: vk::VkShaderModule,
    // needed for the destructor; holding the Rc also keeps the window (and
    // so the device) alive at least as long as the module.
    window: Rc<Window>,
}

/// Destroys the shader module on the device it was created on.
impl Drop for ShaderModule {
    fn drop(&mut self) {
        let destroy_shader_module =
            self.window.device().vkDestroyShaderModule.unwrap();
        let device = self.window.vk_device();

        // SAFETY: `self.handle` was created on this window's device and is
        // owned exclusively by this wrapper; no allocator was used.
        unsafe {
            destroy_shader_module(device, self.handle, ptr::null());
        }
    }
}

/// Owning wrapper around a `VkPipelineCache` that destroys the cache when
/// dropped.
#[derive(Debug)]
struct PipelineCache {
    handle: vk::VkPipelineCache,
    // needed for the destructor
    window: Rc<Window>,
}

/// Destroys the pipeline cache on the device it was created on.
impl Drop for PipelineCache {
    fn drop(&mut self) {
        let destroy_pipeline_cache =
            self.window.device().vkDestroyPipelineCache.unwrap();
        let device = self.window.vk_device();

        // SAFETY: `self.handle` was created on this window's device and is
        // owned exclusively by this wrapper; no allocator was used.
        unsafe {
            destroy_pipeline_cache(device, self.handle, ptr::null());
        }
    }
}

impl PipelineCache {
    fn new(window: Rc<Window>) -> Result<PipelineCache, Error> {
        let create_info = vk::VkPipelineCacheCreateInfo {
            sType: vk::VK_STRUCTURE_TYPE_PIPELINE_CACHE_CREATE_INFO,
            flags: 0,
            pNext: ptr::null(),
            initialDataSize: 0,
            pInitialData: ptr::null(),
        };

        let mut handle = vk::null_handle();

        let res = unsafe {
            window.device().vkCreatePipelineCache.unwrap()(
                window.vk_device(),
                ptr::addr_of!(create_info),
                ptr::null(), // allocator
                ptr::addr_of_mut!(handle),
            )
        };

        if res == vk::VK_SUCCESS {
            Ok(PipelineCache { handle, window })
        } else {
            Err(Error::CreatePipelineCacheFailed)
        }
    }
}

/// Owning wrapper around a `VkDescriptorPool` that destroys the pool when
/// dropped.
#[derive(Debug)]
struct DescriptorPool {
    handle: vk::VkDescriptorPool,
    // needed for the destructor
    window: Rc<Window>,
}

/// Destroys the descriptor pool on the device it was created on.
impl Drop for DescriptorPool {
    fn drop(&mut self) {
        let destroy_descriptor_pool =
            self.window.device().vkDestroyDescriptorPool.unwrap();
        let device = self.window.vk_device();

        // SAFETY: `self.handle` was created on this window's device and is
        // owned exclusively by this wrapper; no allocator was used.
        unsafe {
            destroy_descriptor_pool(device, self.handle, ptr::null());
        }
    }
}

impl DescriptorPool {
    fn new(
        window: Rc<Window>,
        buffers: &[Buffer],
    ) -> Result<DescriptorPool, Error> {
        let mut n_ubos = 0;
        let mut n_ssbos = 0;

        for buffer in buffers {
            match buffer.buffer_type {
                BufferType::Ubo => n_ubos += 1,
                BufferType::Ssbo => n_ssbos += 1,
            }
        }

        let mut pool_sizes = Vec::<vk::VkDescriptorPoolSize>::new();

        if n_ubos > 0 {
            pool_sizes.push(vk::VkDescriptorPoolSize {
                type_: vk::VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER,
                descriptorCount: n_ubos,
            });
        }
        if n_ssbos > 0 {
            pool_sizes.push(vk::VkDescriptorPoolSize {
                type_: vk::VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
                descriptorCount: n_ssbos,
            });
        }

        // The descriptor pool shouldn’t have been created if there
        // were no buffers
        assert!(!pool_sizes.is_empty());

        let create_info = vk::VkDescriptorPoolCreateInfo {
            sType: vk::VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO,
            flags: vk::VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT,
            pNext: ptr::null(),
            maxSets: n_desc_sets(buffers) as u32,
            poolSizeCount: pool_sizes.len() as u32,
            pPoolSizes: pool_sizes.as_ptr(),
        };

        let mut handle = vk::null_handle();

        let res = unsafe {
            window.device().vkCreateDescriptorPool.unwrap()(
                window.vk_device(),
                ptr::addr_of!(create_info),
                ptr::null(), // allocator
                ptr::addr_of_mut!(handle),
            )
        };

        if res == vk::VK_SUCCESS {
            Ok(DescriptorPool { handle, window })
        } else {
            Err(Error::CreateDescriptorPoolFailed)
        }
    }
}

/// A growable list of `VkDescriptorSetLayout` handles, indexed by
/// descriptor-set number, that destroys every layout when dropped.
///
/// Because the handles are owned here, a layout added before a later
/// creation fails is still cleaned up when the partially-filled vector is
/// dropped.
#[derive(Debug)]
struct DescriptorSetLayoutVec {
    handles: Vec<vk::VkDescriptorSetLayout>,
    // needed for the destructor
    window: Rc<Window>,
}

/// Destroys every descriptor set layout in the list.
impl Drop for DescriptorSetLayoutVec {
    fn drop(&mut self) {
        self.handles.iter().for_each(|&layout| {
            let destroy_layout =
                self.window.device().vkDestroyDescriptorSetLayout.unwrap();

            // SAFETY: each handle was created on this window's device by
            // `add` and is owned exclusively by this list.
            unsafe {
                destroy_layout(self.window.vk_device(), layout, ptr::null());
            }
        });
    }
}

impl DescriptorSetLayoutVec {
    fn new(window: Rc<Window>) -> DescriptorSetLayoutVec {
        DescriptorSetLayoutVec { window, handles: Vec::new() }
    }

    fn len(&self) -> usize {
        self.handles.len()
    }

    fn as_ptr(&self) -> *const vk::VkDescriptorSetLayout {
        self.handles.as_ptr()
    }

    fn add(
        &mut self,
        bindings: &[vk::VkDescriptorSetLayoutBinding],
    ) -> Result<(), Error> {
        let create_info = vk::VkDescriptorSetLayoutCreateInfo {
            sType: vk::VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO,
            flags: 0,
            pNext: ptr::null(),
            bindingCount: bindings.len() as u32,
            pBindings: bindings.as_ptr(),
        };

        let mut handle = vk::null_handle();

        let res = unsafe {
            self.window.device().vkCreateDescriptorSetLayout.unwrap()(
                self.window.vk_device(),
                ptr::addr_of!(create_info),
                ptr::null(), // allocator
                ptr::addr_of_mut!(handle),
            )
        };

        if res == vk::VK_SUCCESS {
            self.handles.push(handle);
            Ok(())
        } else {
            Err(Error::CreateDescriptorSetLayoutFailed)
        }
    }
}

/// Creates one descriptor set layout for every descriptor-set index from 0
/// up to the highest index used by `buffers`.
///
/// `buffers` must be sorted by descriptor set; `n_desc_sets` relies on the
/// same ordering. Each layout gets one binding per buffer in that set, all
/// visible to `stages`. A set index that no buffer uses gets a layout with
/// no bindings, so layout `i` always corresponds to set `i`.
///
/// # Errors
///
/// Returns [Error::CreateDescriptorSetLayoutFailed] if any layout cannot be
/// created. The layouts created before the failure are destroyed when the
/// partially-filled list is dropped.
fn create_descriptor_set_layouts(
    window: Rc<Window>,
    buffers: &[Buffer],
    stages: vk::VkShaderStageFlags,
) -> Result<DescriptorSetLayoutVec, Error> {
    let set_count = n_desc_sets(buffers);
    let mut layouts = DescriptorSetLayoutVec::new(window);
    let mut pending = buffers.iter().peekable();
    // Reused for every set to avoid reallocating per layout.
    let mut bindings = Vec::new();

    for set_index in 0..set_count {
        bindings.clear();

        // Consume the run of buffers that belong to this set.
        while let Some(buffer) =
            pending.next_if(|b| b.desc_set as usize == set_index)
        {
            let descriptor_type = match buffer.buffer_type {
                BufferType::Ubo => vk::VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER,
                BufferType::Ssbo => vk::VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
            };

            bindings.push(vk::VkDescriptorSetLayoutBinding {
                binding: buffer.binding,
                descriptorType: descriptor_type,
                descriptorCount: 1,
                stageFlags: stages,
                pImmutableSamplers: ptr::null(),
            });
        }

        layouts.add(&bindings)?;
    }

    assert_eq!(layouts.len(), set_count);

    Ok(layouts)
}

/// Number of descriptor sets needed for `buffers`: the highest descriptor
/// set index used, plus one, or 0 when there are no buffers.
///
/// `buffers` is assumed to be sorted by descriptor set, so the last buffer
/// carries the highest index.
fn n_desc_sets(buffers: &[Buffer]) -> usize {
    buffers
        .last()
        .map_or(0, |highest| highest.desc_set as usize + 1)
}

/// The descriptor pool and per-set layouts for a script's buffers.
///
/// Only created when the script has at least one buffer.
#[derive(Debug)]
struct DescriptorData {
    pool: DescriptorPool,
    // One layout per descriptor-set index, in set order.
    layouts: DescriptorSetLayoutVec,
}

impl DescriptorData {
    /// Creates the descriptor pool and then the descriptor set layouts for
    /// `buffers`, with every binding visible to `stages`.
    ///
    /// # Errors
    ///
    /// Propagates the error from whichever creation fails first. If the
    /// layouts fail, the already-created pool is destroyed on return.
    fn new(
        window: &Rc<Window>,
        buffers: &[Buffer],
        stages: vk::VkShaderStageFlagBits,
    ) -> Result<DescriptorData, Error> {
        let pool = DescriptorPool::new(Rc::clone(window), buffers)?;
        let layouts =
            create_descriptor_set_layouts(Rc::clone(window), buffers, stages)?;

        Ok(DescriptorData { pool, layouts })
    }
}

/// Compiles every shader stage that has source in `script` into a shader
/// module.
///
/// The returned array is indexed by `shader_stage::Stage as usize`. Stages
/// without shaders are `None`.
///
/// # Errors
///
/// Returns [Error::CompileError] for the first stage that fails to build.
/// Modules already built are destroyed when the partial array is dropped.
fn compile_shaders(
    logger: &mut Logger,
    window: &Rc<Window>,
    script: &Script,
    show_disassembly: bool,
) -> Result<[Option<ShaderModule>; shader_stage::N_STAGES], Error> {
    let mut modules: [Option<ShaderModule>; shader_stage::N_STAGES] =
        Default::default();

    let stages_with_source = shader_stage::ALL_STAGES
        .iter()
        .copied()
        .filter(|&stage| !script.shaders(stage).is_empty());

    for stage in stages_with_source {
        let handle = compiler::build_stage(
            logger,
            window.context(),
            script,
            stage,
            show_disassembly,
        )?;

        modules[stage as usize] = Some(ShaderModule {
            handle,
            window: Rc::clone(window),
        });
    }

    Ok(modules)
}

/// Bitwise OR of the flag of every shader stage that has at least one
/// shader in `script`; 0 if the script has no shaders.
fn stage_flags(script: &Script) -> vk::VkShaderStageFlagBits {
    let mut flags: vk::VkShaderStageFlagBits = 0;

    for &stage in shader_stage::ALL_STAGES.iter() {
        if !script.shaders(stage).is_empty() {
            flags |= stage.flag();
        }
    }

    flags
}

/// Size in bytes of the push-constant range needed by `script`: the
/// largest `offset + data.len()` over all `SetPushCommand` operations, or 0
/// if the script never sets push constants.
fn push_constant_size(script: &Script) -> usize {
    let mut size: usize = 0;

    for command in script.commands().iter() {
        if let Operation::SetPushCommand { offset, data } = &command.op {
            size = size.max(offset + data.len());
        }
    }

    size
}

/// Owning wrapper around a `VkPipelineLayout` that destroys the layout when
/// dropped.
#[derive(Debug)]
struct PipelineLayout {
    handle: vk::VkPipelineLayout,
    // needed for the destructor
    window: Rc<Window>,
}

/// Destroys the pipeline layout on the device it was created on.
impl Drop for PipelineLayout {
    fn drop(&mut self) {
        let destroy_pipeline_layout =
            self.window.device().vkDestroyPipelineLayout.unwrap();
        let device = self.window.vk_device();

        // SAFETY: `self.handle` was created on this window's device and is
        // owned exclusively by this wrapper; no allocator was used.
        unsafe {
            destroy_pipeline_layout(device, self.handle, ptr::null());
        }
    }
}

impl PipelineLayout {
    fn new(
        window: Rc<Window>,
        script: &Script,
        stages: vk::VkShaderStageFlagBits,
        descriptor_data: Option<&DescriptorData>
    ) -> Result<PipelineLayout, Error> {
        let mut handle = vk::null_handle();

        let push_constant_range = vk::VkPushConstantRange {
            stageFlags: stages,
            offset: 0,
            size: push_constant_size(script) as u32,
        };

        let mut create_info = vk::VkPipelineLayoutCreateInfo::default();

        create_info.sType = vk::VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO;

        if push_constant_range.size > 0 {
            create_info.pushConstantRangeCount = 1;
            create_info.pPushConstantRanges =
                ptr::addr_of!(push_constant_range);
        }

        if let Some(descriptor_data) = descriptor_data {
            create_info.setLayoutCount = descriptor_data.layouts.len() as u32;
            create_info.pSetLayouts = descriptor_data.layouts.as_ptr();
        }

        let res = unsafe {
            window.device().vkCreatePipelineLayout.unwrap()(
                window.vk_device(),
                ptr::addr_of!(create_info),
                ptr::null(), // allocator
                ptr::addr_of_mut!(handle),
            )
        };

        if res == vk::VK_SUCCESS {
            Ok(PipelineLayout { handle, window })
        } else {
            Err(Error::CreatePipelineLayoutFailed)
        }
    }
}

/// A `VkPipelineVertexInputStateCreateInfo` bundled with the arrays it
/// points to, so that the pointers stay valid for as long as this struct
/// is alive.
#[derive(Debug)]
struct VertexInputState {
    create_info: vk::VkPipelineVertexInputStateCreateInfo,
    // These are not read but the create_info struct keeps pointers to
    // them so they need to be kept alive
    _input_bindings: Vec::<vk::VkVertexInputBindingDescription>,
    _attribs: Vec::<vk::VkVertexInputAttributeDescription>,
}

impl VertexInputState {
    /// Builds the vertex input state for the pipeline described by `key`.
    ///
    /// For `Source::Rectangle`, one binding with a stride of
    /// `size_of::<RectangleVertex>()` and one `R32G32B32_SFLOAT` attribute
    /// at location 0 are declared. For `Source::VertexData`, bindings and
    /// attributes come from the script's vertex data. If there is none,
    /// nothing is declared.
    fn new(script: &Script, key: &pipeline_key::Key) -> VertexInputState {
        let (bindings, attributes) = match key.source() {
            pipeline_key::Source::Rectangle => (
                vec![vk::VkVertexInputBindingDescription {
                    binding: 0,
                    stride: mem::size_of::<RectangleVertex>() as u32,
                    inputRate: vk::VK_VERTEX_INPUT_RATE_VERTEX,
                }],
                vec![vk::VkVertexInputAttributeDescription {
                    location: 0,
                    binding: 0,
                    format: vk::VK_FORMAT_R32G32B32_SFLOAT,
                    offset: 0,
                }],
            ),
            pipeline_key::Source::VertexData => {
                let mut bindings = Vec::new();
                let mut attributes = Vec::new();

                if let Some(vbo) = script.vertex_data() {
                    Self::set_up_vertex_data_attribs(
                        &mut bindings,
                        &mut attributes,
                        vbo,
                    );
                }

                (bindings, attributes)
            },
        };

        // The create info points into the vectors' heap storage. Moving the
        // vectors into the returned struct does not move that storage, so
        // the pointers stay valid while the struct lives.
        let create_info = vk::VkPipelineVertexInputStateCreateInfo {
            sType:
            vk::VK_STRUCTURE_TYPE_PIPELINE_VERTEX_INPUT_STATE_CREATE_INFO,
            pNext: ptr::null(),
            flags: 0,
            vertexBindingDescriptionCount: bindings.len() as u32,
            pVertexBindingDescriptions: bindings.as_ptr(),
            vertexAttributeDescriptionCount: attributes.len() as u32,
            pVertexAttributeDescriptions: attributes.as_ptr(),
        };

        VertexInputState {
            create_info,
            _input_bindings: bindings,
            _attribs: attributes,
        }
    }

    /// Appends the binding and attributes described by `vbo`.
    ///
    /// Every attribute reads from binding 0 at its own offset within a
    /// vertex of `vbo.stride()` bytes, advancing per vertex.
    fn set_up_vertex_data_attribs(
        input_bindings: &mut Vec::<vk::VkVertexInputBindingDescription>,
        attribs: &mut Vec::<vk::VkVertexInputAttributeDescription>,
        vbo: &Vbo,
    ) {
        input_bindings.push(vk::VkVertexInputBindingDescription {
            binding: 0,
            stride: vbo.stride() as u32,
            inputRate: vk::VK_VERTEX_INPUT_RATE_VERTEX,
        });

        attribs.extend(vbo.attribs().iter().map(|attrib| {
            vk::VkVertexInputAttributeDescription {
                location: attrib.location(),
                binding: 0,
                format: attrib.format().vk_format,
                offset: attrib.offset() as u32,
            }
        }));
    }
}

/// The pipelines of a script, one per pipeline key in key order. Every
/// pipeline is destroyed when this is dropped.
#[derive(Debug)]
struct PipelineVec {
    handles: Vec<vk::VkPipeline>,
    // needed for the destructor
    window: Rc<Window>,
}

/// Destroys every pipeline in the list.
impl Drop for PipelineVec {
    fn drop(&mut self) {
        self.handles.iter().for_each(|&pipeline| {
            let destroy_pipeline =
                self.window.device().vkDestroyPipeline.unwrap();

            // SAFETY: each handle was created on this window's device and
            // is owned exclusively by this list.
            unsafe {
                destroy_pipeline(self.window.vk_device(), pipeline, ptr::null());
            }
        });
    }
}

impl PipelineVec {
    fn new(
        window: Rc<Window>,
        script: &Script,
        pipeline_cache: vk::VkPipelineCache,
        layout: vk::VkPipelineLayout,
        modules: &[Option<ShaderModule>],
    ) -> Result<PipelineVec, Error> {
        let mut vec = PipelineVec { window, handles: Vec::new() };

        let mut first_graphics_pipeline: Option<vk::VkPipeline> = None;

        for key in script.pipeline_keys().iter() {
            let pipeline = match key.pipeline_type() {
                pipeline_key::Type::Graphics => {
                    let allow_derivatives =
                        first_graphics_pipeline.is_none()
                        && script.pipeline_keys().len() > 1;

                    let pipeline = PipelineVec::create_graphics_pipeline(
                        &vec.window,
                        script,
                        key,
                        pipeline_cache,
                        layout,
                        modules,
                        allow_derivatives,
                        first_graphics_pipeline,
                    )?;

                    first_graphics_pipeline.get_or_insert(pipeline);

                    pipeline
                },
                pipeline_key::Type::Compute => {
                    PipelineVec::create_compute_pipeline(
                        &vec.window,
                        key,
                        pipeline_cache,
                        layout,
                        modules[shader_stage::Stage::Compute as usize]
                            .as_ref()
                            .map(|m| m.handle)
                            .unwrap_or(vk::null_handle()),
                        script.requirements().required_subgroup_size,
                    )?
                },
            };

            vec.handles.push(pipeline);
        }

        Ok(vec)
    }

    fn null_terminated_entrypoint(
        key: &pipeline_key::Key,
        stage: shader_stage::Stage,
    ) -> String {
        let mut entrypoint = key.entrypoint(stage).to_string();
        entrypoint.push('\0');
        entrypoint
    }

    fn create_stages(
        modules: &[Option<ShaderModule>],
        entrypoints: &[String],
    ) -> Vec<vk::VkPipelineShaderStageCreateInfo> {
        let mut stages = Vec::new();

        for &stage in shader_stage::ALL_STAGES.iter() {
            if stage == shader_stage::Stage::Compute {
                continue;
            }

            let module = match &modules[stage as usize] {
                Some(module) => module.handle,
                None => continue,
            };

            stages.push(vk::VkPipelineShaderStageCreateInfo {
                sType: vk::VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO,
                flags: 0,
                pNext: ptr::null(),
                stage: stage.flag(),
                module,
                pName: entrypoints[stage as usize].as_ptr().cast(),
                pSpecializationInfo: ptr::null(),
            });
        }

        stages
    }

    fn create_graphics_pipeline(
        window: &Window,
        script: &Script,
        key: &pipeline_key::Key,
        pipeline_cache: vk::VkPipelineCache,
        layout: vk::VkPipelineLayout,
        modules: &[Option<ShaderModule>],
        allow_derivatives: bool,
        parent_pipeline: Option<vk::VkPipeline>,
    ) -> Result<vk::VkPipeline, Error> {
        let create_info_buf = key.to_create_info();

        let mut create_info = vk::VkGraphicsPipelineCreateInfo::default();

        unsafe {
            ptr::copy_nonoverlapping(
                create_info_buf.as_ptr().cast(),
                ptr::addr_of_mut!(create_info),
                1, // count
            );
        }

        let entrypoints = shader_stage::ALL_STAGES
            .iter()
            .map(|&s| PipelineVec::null_terminated_entrypoint(key, s))
            .collect::<Vec<String>>();
        let stages = PipelineVec::create_stages(modules, &entrypoints);

        let window_format = window.format();

        let viewport = vk::VkViewport {
            x: 0.0,
            y: 0.0,
            width: window_format.width as f32,
            height: window_format.height as f32,
            minDepth: 0.0,
            maxDepth: 1.0,
        };

        let scissor = vk::VkRect2D {
            offset: vk::VkOffset2D {
                x: 0,
                y: 0,
            },
            extent: vk::VkExtent2D {
                width: window_format.width as u32,
                height: window_format.height as u32,
            },
        };

        let viewport_state = vk::VkPipelineViewportStateCreateInfo {
            sType: vk::VK_STRUCTURE_TYPE_PIPELINE_VIEWPORT_STATE_CREATE_INFO,
            flags: 0,
            pNext: ptr::null(),
            viewportCount: 1,
            pViewports: ptr::addr_of!(viewport),
            scissorCount: 1,
            pScissors: ptr::addr_of!(scissor),
        };

        let multisample_state = vk::VkPipelineMultisampleStateCreateInfo {
            sType: vk::VK_STRUCTURE_TYPE_PIPELINE_MULTISAMPLE_STATE_CREATE_INFO,
            pNext: ptr::null(),
            flags: 0,
            rasterizationSamples: vk::VK_SAMPLE_COUNT_1_BIT,
            sampleShadingEnable: vk::VK_FALSE,
            minSampleShading: 0.0,
            pSampleMask: ptr::null(),
            alphaToCoverageEnable: vk::VK_FALSE,
            alphaToOneEnable: vk::VK_FALSE,
        };

        create_info.pViewportState = ptr::addr_of!(viewport_state);
        create_info.pMultisampleState = ptr::addr_of!(multisample_state);
        create_info.subpass = 0;
        create_info.basePipelineHandle =
            parent_pipeline.unwrap_or(vk::null_handle());
        create_info.basePipelineIndex = -1;

        create_info.stageCount = stages.len() as u32;
        create_info.pStages = stages.as_ptr();
        create_info.layout = layout;
        create_info.renderPass = window.render_passes()[0];

        if allow_derivatives {
            create_info.flags |= vk::VK_PIPELINE_CREATE_ALLOW_DERIVATIVES_BIT;
        }
        if parent_pipeline.is_some() {
            create_info.flags |= vk::VK_PIPELINE_CREATE_DERIVATIVE_BIT;
        }

        if modules[shader_stage::Stage::TessCtrl as usize].is_none()
            && modules[shader_stage::Stage::TessEval as usize].is_none()
        {
            create_info.pTessellationState = ptr::null();
        }

        let vertex_input_state = VertexInputState::new(script, key);
        create_info.pVertexInputState =
            ptr::addr_of!(vertex_input_state.create_info);

        let mut handle = vk::null_handle();

        let res = unsafe {
            window.device().vkCreateGraphicsPipelines.unwrap()(
                window.vk_device(),
                pipeline_cache,
                1, // nCreateInfos
                ptr::addr_of!(create_info),
                ptr::null(), // allocator
                ptr::addr_of_mut!(handle),
            )
        };

        if res == vk::VK_SUCCESS {
            Ok(handle)
        } else {
            Err(Error::CreatePipelineFailed)
        }
    }

    fn create_compute_pipeline(
        window: &Window,
        key: &pipeline_key::Key,
        pipeline_cache: vk::VkPipelineCache,
        layout: vk::VkPipelineLayout,
        module: vk::VkShaderModule,
        required_subgroup_size: Option<u32>,
    ) -> Result<vk::VkPipeline, Error> {
        let entrypoint = PipelineVec::null_terminated_entrypoint(
            key,
            shader_stage::Stage::Compute,
        );

        let rss_info = required_subgroup_size.map(|size|
            vk::VkPipelineShaderStageRequiredSubgroupSizeCreateInfo {
                sType: vk::VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_REQUIRED_SUBGROUP_SIZE_CREATE_INFO,
                pNext: ptr::null_mut(),
                requiredSubgroupSize: size,
            });

        let create_info = vk::VkComputePipelineCreateInfo {
            sType: vk::VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO,
            pNext: ptr::null(),
            flags: 0,

            stage: vk::VkPipelineShaderStageCreateInfo {
                sType: vk::VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO,
                pNext: if let Some(ref rss_info) = rss_info {
                    ptr::addr_of!(*rss_info).cast()
                } else {
                    ptr::null_mut()
                },
                flags: 0,
                stage: vk::VK_SHADER_STAGE_COMPUTE_BIT,
                module,
                pName: entrypoint.as_ptr().cast(),
                pSpecializationInfo: ptr::null(),
            },

            layout,
            basePipelineHandle: vk::null_handle(),
            basePipelineIndex: -1,
        };

        let mut handle = vk::null_handle();

        let res = unsafe {
            window.device().vkCreateComputePipelines.unwrap()(
                window.vk_device(),
                pipeline_cache,
                1, // nCreateInfos
                ptr::addr_of!(create_info),
                ptr::null(), // allocator
                ptr::addr_of_mut!(handle),
            )
        };

        if res == vk::VK_SUCCESS {
            Ok(handle)
        } else {
            Err(Error::CreatePipelineFailed)
        }
    }
}

impl PipelineSet {
    /// Creates every Vulkan object needed to run `script` on `window`.
    ///
    /// Objects are built in this order:
    /// 1. the shader modules;
    /// 2. the pipeline cache;
    /// 3. the descriptor data (set layouts and pool), only when the script
    ///    declares at least one buffer;
    /// 4. the pipeline layout;
    /// 5. the pipelines.
    ///
    /// The first failure is propagated to the caller. `show_disassembly`
    /// is passed straight through to `compile_shaders`.
    pub fn new(
        logger: &mut Logger,
        window: Rc<Window>,
        script: &Script,
        show_disassembly: bool,
    ) -> Result<PipelineSet, Error> {
        let shader_modules =
            compile_shaders(logger, &window, script, show_disassembly)?;
        let cache = PipelineCache::new(Rc::clone(&window))?;
        let stages = stage_flags(script);

        // Descriptor objects are only needed when there is something to bind.
        let descriptor_data = if !script.buffers().is_empty() {
            Some(DescriptorData::new(&window, script.buffers(), stages)?)
        } else {
            None
        };

        let pipeline_layout = PipelineLayout::new(
            Rc::clone(&window),
            script,
            stages,
            descriptor_data.as_ref(),
        )?;

        let pipeline_vec = PipelineVec::new(
            Rc::clone(&window),
            script,
            cache.handle,
            pipeline_layout.handle,
            &shader_modules,
        )?;

        Ok(PipelineSet {
            _modules: shader_modules,
            _pipeline_cache: cache,
            stages,
            descriptor_data,
            layout: pipeline_layout,
            pipelines: pipeline_vec,
        })
    }

    /// Returns the descriptor set layout handles. The slice is empty
    /// when the script declares no buffers.
    pub fn descriptor_set_layouts(&self) -> &[vk::VkDescriptorSetLayout] {
        if let Some(data) = &self.descriptor_data {
            &data.layouts.handles
        } else {
            &[]
        }
    }

    /// Returns the shader stage flags that `stage_flags` computed for the
    /// script.
    pub fn stages(&self) -> vk::VkShaderStageFlagBits {
        self.stages
    }

    /// Returns the handle of the pipeline layout shared by all pipelines.
    pub fn layout(&self) -> vk::VkPipelineLayout {
        self.layout.handle
    }

    /// Returns the handles of every pipeline created for the script.
    pub fn pipelines(&self) -> &[vk::VkPipeline] {
        &self.pipelines.handles
    }

    /// Returns the descriptor pool handle, or `None` when the script
    /// declares no buffers.
    pub fn descriptor_pool(&self) -> Option<vk::VkDescriptorPool> {
        let data = self.descriptor_data.as_ref()?;
        Some(data.pool.handle)
    }
}

#[cfg(test)]
mod test {
    use super::*;
    use crate::fake_vulkan::{FakeVulkan, HandleType, PipelineCreateInfo};
    use crate::fake_vulkan::GraphicsPipelineCreateInfo;
    use crate::fake_vulkan::PipelineLayoutCreateInfo;
    use crate::context::Context;
    use crate::requirements::Requirements;
    use crate::source::Source;
    use crate::config::Config;

    /// A [PipelineSet] built against the fake Vulkan driver, together
    /// with the window and driver that the assertions inspect.
    #[derive(Debug)]
    struct TestData {
        pipeline_set: PipelineSet,
        window: Rc<Window>,
        fake_vulkan: Box<FakeVulkan>,
    }

    impl TestData {
        /// Creates a fake device, context and window, then builds a
        /// [PipelineSet] from the script text in `source`.
        ///
        /// `queue_errors` runs after the window has been created and
        /// before `PipelineSet::new`. Any results it queues on the fake
        /// driver are therefore meant to hit the calls made while building
        /// the pipeline set. The script must parse: a parse failure panics
        /// rather than returning an error.
        fn new_with_errors<F: FnOnce(&mut FakeVulkan)>(
            source: &str,
            queue_errors: F
        ) -> Result<TestData, Error> {
            let mut fake_vulkan = FakeVulkan::new();

            fake_vulkan.physical_devices.push(Default::default());
            // Presumably the window's default framebuffer format; it needs
            // colour-attachment and blit-source support.
            fake_vulkan.physical_devices[0].format_properties.insert(
                vk::VK_FORMAT_B8G8R8A8_UNORM,
                vk::VkFormatProperties {
                    linearTilingFeatures: 0,
                    optimalTilingFeatures:
                    vk::VK_FORMAT_FEATURE_COLOR_ATTACHMENT_BIT
                        | vk::VK_FORMAT_FEATURE_BLIT_SRC_BIT,
                    bufferFeatures: 0,
                },
            );

            // Expose a single host-visible memory type (index 0) and report
            // it as the only type any resource may use.
            let memory_properties =
                &mut fake_vulkan.physical_devices[0].memory_properties;
            memory_properties.memoryTypes[0].propertyFlags =
                vk::VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT;
            memory_properties.memoryTypeCount = 1;
            fake_vulkan.memory_requirements.memoryTypeBits = 1;

            fake_vulkan.set_override();
            let context = Rc::new(Context::new(
                &Requirements::new(),
                None, // device_id
            ).unwrap());

            let window = Rc::new(Window::new(
                Rc::clone(&context),
                &Default::default(), // format
            ).unwrap());

            // Queued only now so that context/window creation is unaffected.
            queue_errors(&mut fake_vulkan);

            let mut logger = Logger::new(None, ptr::null_mut());

            let source = Source::from_string(source.to_string());
            let script = Script::load(&Config::new(), &source).unwrap();

            let pipeline_set = PipelineSet::new(
                &mut logger,
                Rc::clone(&window),
                &script,
                false, // show_disassembly
            )?;

            Ok(TestData { fake_vulkan, window, pipeline_set })
        }

        /// Same as [TestData::new_with_errors] without queueing any errors.
        fn new(source: &str) -> Result<TestData, Error> {
            TestData::new_with_errors(source, |_| ())
        }

        /// Returns the create info the fake driver recorded for pipeline
        /// `pipeline_num`. Panics if it is not a graphics pipeline.
        fn graphics_create_info(
            &mut self,
            pipeline_num: usize,
        ) -> GraphicsPipelineCreateInfo {
            let pipeline = self.pipeline_set.pipelines.handles[pipeline_num];

            match &self.fake_vulkan.get_handle(pipeline).data {
                HandleType::Pipeline(PipelineCreateInfo::Graphics(info)) => {
                    info.clone()
                },
                handle => {
                    unreachable!("unexpected handle type: {:?}", handle)
                },
            }
        }

        /// Returns the create info the fake driver recorded for pipeline
        /// `pipeline_num`. Panics if it is not a compute pipeline.
        fn compute_create_info(
            &mut self,
            pipeline_num: usize,
        ) -> vk::VkComputePipelineCreateInfo {
            let pipeline = self.pipeline_set.pipelines.handles[pipeline_num];

            match &self.fake_vulkan.get_handle(pipeline).data {
                HandleType::Pipeline(PipelineCreateInfo::Compute(info)) => {
                    info.clone()
                },
                handle => {
                    unreachable!("unexpected handle type: {:?}", handle)
                },
            }
        }

        /// Returns the SPIR-V words passed to the fake driver for the
        /// shader module of `stage`. Panics if that stage has no module.
        fn shader_module_code(
            &mut self,
            stage: shader_stage::Stage,
        ) -> Vec<u32> {
            let module = self.pipeline_set._modules[stage as usize]
                .as_ref()
                .unwrap()
                .handle;

            match &self.fake_vulkan.get_handle(module).data {
                HandleType::ShaderModule { code } => code.clone(),
                _ => unreachable!("Unexpected Vulkan handle type"),
            }
        }

        /// Returns the bindings recorded for descriptor set layout
        /// number `desc_set`.
        fn descriptor_set_layout_bindings(
            &mut self,
            desc_set: usize,
        ) -> Vec<vk::VkDescriptorSetLayoutBinding> {
            let desc_set_layout = self.pipeline_set
                .descriptor_set_layouts()
                [desc_set];

            match &self.fake_vulkan.get_handle(desc_set_layout).data {
                HandleType::DescriptorSetLayout { bindings } => {
                    bindings.clone()
                },
                _ => unreachable!("Unexpected Vulkan handle type"),
            }
        }

        /// Returns the create info recorded for the pipeline layout.
        fn pipeline_layout_create_info(&mut self) -> PipelineLayoutCreateInfo {
            let layout = self.pipeline_set.layout();

            match &self.fake_vulkan.get_handle(layout).data {
                HandleType::PipelineLayout(create_info) => create_info.clone(),
                _ => unreachable!("Unexpected Vulkan handle type"),
            }
        }
    }

    /// Uses a script with vertex, fragment and compute shaders, buffers,
    /// push constants and two kinds of draw. Checks every object the
    /// pipeline set creates for it.
    #[test]
    fn base() {
        let mut test_data = TestData::new(
            "[vertex shader passthrough]\n\
             [fragment shader]\n\
             03 02 23 07\n\
             fe ca fe ca\n\
             [vertex data]\n\
             0/R32G32_SFLOAT 1/R32_SFLOAT\n\
             -0.5 -0.5 1.52\n\
             0.5 -0.5 1.55\n\
             [compute shader]\n\
             03 02 23 07\n\
             ca fe ca fe\n\
             [test]\n\
             ubo 0 1024\n\
             ssbo 1 1024\n\
             compute 1 1 1\n\
             push float 6 42.0\n\
             draw rect 0 0 1 1\n\
             draw arrays TRIANGLE_LIST 0 2\n"
        ).unwrap();

        assert_eq!(
            test_data.pipeline_set.stages(),
            vk::VK_SHADER_STAGE_VERTEX_BIT
                | vk::VK_SHADER_STAGE_FRAGMENT_BIT
                | vk::VK_SHADER_STAGE_COMPUTE_BIT,
        );

        let descriptor_pool = test_data.pipeline_set.descriptor_pool();
        assert!(matches!(
            test_data.fake_vulkan.get_handle(descriptor_pool.unwrap()).data,
            HandleType::DescriptorPool,
        ));

        // One compute pipeline plus one per distinct draw command.
        assert_eq!(test_data.pipeline_set.pipelines().len(), 3);

        // Pipeline 0 is the compute pipeline.
        let compute_create_info = test_data.compute_create_info(0);
        assert_eq!(compute_create_info.sType, vk::VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO);

        // Pipeline 1 is the base graphics pipeline, used by `draw rect`.
        let create_data = test_data.graphics_create_info(1);
        let create_info = &create_data.create_info;

        assert_eq!(
            create_info.flags,
            vk::VK_PIPELINE_CREATE_ALLOW_DERIVATIVES_BIT,
        );
        assert_eq!(create_info.stageCount, 2);
        assert_eq!(create_info.layout, test_data.pipeline_set.layout());
        assert_eq!(
            create_info.renderPass,
            test_data.window.render_passes()[0]
        );
        assert_eq!(create_info.basePipelineHandle, vk::null_handle());

        assert_eq!(create_data.bindings.len(), 1);
        assert_eq!(create_data.bindings[0].binding, 0);
        assert_eq!(
            create_data.bindings[0].stride as usize,
            mem::size_of::<RectangleVertex>(),
        );
        assert_eq!(
            create_data.bindings[0].inputRate,
            vk::VK_VERTEX_INPUT_RATE_VERTEX,
        );

        assert_eq!(create_data.attribs.len(), 1);
        assert_eq!(create_data.attribs[0].location, 0);
        assert_eq!(create_data.attribs[0].binding, 0);
        assert_eq!(
            create_data.attribs[0].format,
            vk::VK_FORMAT_R32G32B32_SFLOAT
        );
        assert_eq!(create_data.attribs[0].offset, 0);

        // Pipeline 2 derives from pipeline 1 and uses the [vertex data]
        // layout for `draw arrays`.
        let create_data = test_data.graphics_create_info(2);
        let create_info = &create_data.create_info;

        assert_eq!(
            create_info.flags,
            vk::VK_PIPELINE_CREATE_DERIVATIVE_BIT,
        );
        assert_eq!(create_info.stageCount, 2);
        assert_eq!(create_info.layout, test_data.pipeline_set.layout());
        assert_eq!(
            create_info.renderPass,
            test_data.window.render_passes()[0]
        );
        assert_eq!(
            create_info.basePipelineHandle,
            test_data.pipeline_set.pipelines()[1]
        );

        assert_eq!(create_data.bindings.len(), 1);
        assert_eq!(create_data.bindings[0].binding, 0);
        assert_eq!(
            create_data.bindings[0].stride as usize,
            mem::size_of::<f32>() * 3,
        );
        assert_eq!(
            create_data.bindings[0].inputRate,
            vk::VK_VERTEX_INPUT_RATE_VERTEX,
        );

        assert_eq!(create_data.attribs.len(), 2);
        assert_eq!(create_data.attribs[0].location, 0);
        assert_eq!(create_data.attribs[0].binding, 0);
        assert_eq!(
            create_data.attribs[0].format,
            vk::VK_FORMAT_R32G32_SFLOAT
        );
        assert_eq!(create_data.attribs[0].offset, 0);
        assert_eq!(create_data.attribs[1].location, 1);
        assert_eq!(create_data.attribs[1].binding, 0);
        assert_eq!(
            create_data.attribs[1].format,
            vk::VK_FORMAT_R32_SFLOAT
        );
        assert_eq!(create_data.attribs[1].offset, 8);

        // The passthrough vertex shader is generated, so only check that
        // it is longer than a bare header.
        let code = test_data.shader_module_code(shader_stage::Stage::Vertex);
        assert!(code.len() > 2);

        let code = test_data.shader_module_code(shader_stage::Stage::Fragment);
        assert_eq!(&code, &[0x07230203, 0xcafecafe]);

        let code = test_data.shader_module_code(shader_stage::Stage::Compute);
        assert_eq!(&code, &[0x07230203, 0xfecafeca]);

        assert_eq!(test_data.pipeline_set.descriptor_set_layouts().len(), 1);
        let bindings = test_data.descriptor_set_layout_bindings(0);
        assert_eq!(bindings.len(), 2);

        assert_eq!(bindings[0].binding, 0);
        assert_eq!(
            bindings[0].descriptorType,
            vk::VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER
        );
        assert_eq!(bindings[0].descriptorCount, 1);
        assert_eq!(
            bindings[0].stageFlags,
            vk::VK_SHADER_STAGE_VERTEX_BIT
                | vk::VK_SHADER_STAGE_FRAGMENT_BIT
                | vk::VK_SHADER_STAGE_COMPUTE_BIT,
        );

        assert_eq!(bindings[1].binding, 1);
        assert_eq!(
            bindings[1].descriptorType,
            vk::VK_DESCRIPTOR_TYPE_STORAGE_BUFFER
        );
        assert_eq!(bindings[1].descriptorCount, 1);
        assert_eq!(
            bindings[1].stageFlags,
            vk::VK_SHADER_STAGE_VERTEX_BIT
                | vk::VK_SHADER_STAGE_FRAGMENT_BIT
                | vk::VK_SHADER_STAGE_COMPUTE_BIT,
        );

        let layout = test_data.pipeline_layout_create_info();
        assert_eq!(layout.create_info.sType, vk::VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO);
        assert_eq!(layout.push_constant_ranges.len(), 1);
        assert_eq!(
            layout.push_constant_ranges[0].stageFlags,
            vk::VK_SHADER_STAGE_VERTEX_BIT
                | vk::VK_SHADER_STAGE_FRAGMENT_BIT
                | vk::VK_SHADER_STAGE_COMPUTE_BIT,
        );
        assert_eq!(layout.push_constant_ranges[0].offset, 0);
        // `push float 6 42.0` writes a float at byte offset 6, so the
        // range has to cover 6 + sizeof(float) bytes.
        assert_eq!(
            layout.push_constant_ranges[0].size as usize,
            6 + mem::size_of::<f32>()
        );
        assert_eq!(
            &layout.layouts,
            test_data.pipeline_set.descriptor_set_layouts()
        );
    }

    /// Descriptor sets that no buffer uses (1 and 4 here) still get a
    /// layout, but with no bindings.
    #[test]
    fn descriptor_set_gaps() {
        let mut test_data = TestData::new(
            "[vertex shader passthrough]\n\
             [fragment shader]\n\
             03 02 23 07\n\
             fe ca fe ca\n\
             [test]\n\
             ubo 0 1024\n\
             ssbo 2:1 1024\n\
             ssbo 3:1 1024\n\
             ubo 5:5 1024\n\
             draw rect 0 0 1 1\n"
        ).unwrap();

        assert_eq!(test_data.pipeline_set.descriptor_set_layouts().len(), 6);

        let bindings = test_data.descriptor_set_layout_bindings(0);
        assert_eq!(bindings.len(), 1);

        assert_eq!(bindings[0].binding, 0);
        assert_eq!(
            bindings[0].descriptorType,
            vk::VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER
        );
        assert_eq!(bindings[0].descriptorCount, 1);
        assert_eq!(
            bindings[0].stageFlags,
            vk::VK_SHADER_STAGE_VERTEX_BIT
                | vk::VK_SHADER_STAGE_FRAGMENT_BIT,
        );

        assert_eq!(test_data.descriptor_set_layout_bindings(1).len(), 0);

        let bindings = test_data.descriptor_set_layout_bindings(2);
        assert_eq!(bindings.len(), 1);

        assert_eq!(bindings[0].binding, 1);
        assert_eq!(
            bindings[0].descriptorType,
            vk::VK_DESCRIPTOR_TYPE_STORAGE_BUFFER
        );
        assert_eq!(bindings[0].descriptorCount, 1);
        assert_eq!(
            bindings[0].stageFlags,
            vk::VK_SHADER_STAGE_VERTEX_BIT
                | vk::VK_SHADER_STAGE_FRAGMENT_BIT,
        );

        let bindings = test_data.descriptor_set_layout_bindings(3);
        assert_eq!(bindings.len(), 1);

        assert_eq!(bindings[0].binding, 1);
        assert_eq!(
            bindings[0].descriptorType,
            vk::VK_DESCRIPTOR_TYPE_STORAGE_BUFFER
        );
        assert_eq!(bindings[0].descriptorCount, 1);
        assert_eq!(
            bindings[0].stageFlags,
            vk::VK_SHADER_STAGE_VERTEX_BIT
                | vk::VK_SHADER_STAGE_FRAGMENT_BIT,
        );

        assert_eq!(test_data.descriptor_set_layout_bindings(4).len(), 0);

        let bindings = test_data.descriptor_set_layout_bindings(5);
        assert_eq!(bindings.len(), 1);

        assert_eq!(bindings[0].binding, 5);
        assert_eq!(
            bindings[0].descriptorType,
            vk::VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER
        );
        assert_eq!(bindings[0].descriptorCount, 1);
        assert_eq!(
            bindings[0].stageFlags,
            vk::VK_SHADER_STAGE_VERTEX_BIT
                | vk::VK_SHADER_STAGE_FRAGMENT_BIT,
        );
    }

    /// A binary shader whose first word is not the SPIR-V magic number is
    /// rejected when the pipeline set is built.
    #[test]
    fn compile_error() {
        let error = TestData::new(
            "[fragment shader]\n\
             12 34 56 78\n"
        ).unwrap_err();

        assert_eq!(
            &error.to_string(),
            "The compiler or assembler generated an invalid SPIR-V binary"
        );
    }

    /// Each non-pipeline create call reports its own failure message.
    #[test]
    fn create_errors() {
        let commands = [
            "vkCreatePipelineCache",
            "vkCreateDescriptorPool",
            "vkCreateDescriptorSetLayout",
            "vkCreatePipelineLayout",
        ];
        for command in commands {
            let error = TestData::new_with_errors(
                "[test]\n\
                 ubo 0 10\n\
                 draw rect 0 0 1 1\n",
                |fake_vulkan| fake_vulkan.queue_result(
                    command.to_string(),
                    vk::VK_ERROR_UNKNOWN,
                ),
            ).unwrap_err();

            assert_eq!(
                error.to_string(),
                format!("{} failed", command),
            );
        }
    }

    /// Graphics and compute pipeline creation failures both map to the
    /// same error.
    #[test]
    fn create_pipeline_errors() {
        let commands = [
            "vkCreateGraphicsPipelines",
            "vkCreateComputePipelines",
        ];
        for command in commands {
            let error = TestData::new_with_errors(
                "[compute shader]\n\
                 03 02 23 07\n\
                 ca fe ca fe\n\
                 [test]\n\
                 compute 1 1 1\n\
                 draw rect 0 0 1 1\n",
                |fake_vulkan| fake_vulkan.queue_result(
                    command.to_string(),
                    vk::VK_ERROR_UNKNOWN,
                ),
            ).unwrap_err();

            assert_eq!(
                &error.to_string(),
                "Pipeline creation function failed",
            );
        }
    }

    /// Without buffers, no descriptor pool or set layouts are created.
    #[test]
    fn no_buffers() {
        let test_data = TestData::new("").unwrap();

        assert!(test_data.pipeline_set.descriptor_pool().is_none());
        assert_eq!(test_data.pipeline_set.descriptor_set_layouts().len(), 0);
    }
}

```

--------------------------------------------------------------------------------
/vkrunner/vkrunner/slot.rs:
--------------------------------------------------------------------------------

```rust
// vkrunner
//
// Copyright (C) 2018 Intel Corporation
// Copyright 2023 Neil Roberts
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice (including the next
// paragraph) shall be included in all copies or substantial portions of the
// Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
// THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.

//! Contains utilities for accessing values stored in a buffer
//! according to the GLSL layout rules. These are described in detail
//! in the OpenGL specs here:
//! <https://registry.khronos.org/OpenGL/specs/gl/glspec45.core.pdf#page=159>

use crate::util;
use crate::tolerance::Tolerance;
use crate::half_float;
use std::mem;
use std::convert::TryInto;
use std::fmt;

/// Selects which GLSL memory layout standard applies.
///
/// `Std140` and `Std430` differ in one respect only. Under `Std140`, the
/// offset between consecutive array members, and between elements along
/// the minor axis of a matrix, is rounded up to a multiple of 16 bytes.
#[derive(Debug)]
#[derive(Clone, Copy)]
#[derive(Eq, PartialEq)]
pub enum LayoutStd {
    /// The `std140` rules, with 16-byte rounding of array/matrix strides.
    Std140,
    /// The `std430` rules, without that rounding.
    Std430,
}

/// Tells which axis of a matrix is the major one, i.e. the axis whose
/// elements sit next to each other in memory.
///
/// Numbering the components of a 3×3 matrix by their position in memory
/// gives the following for each ordering:
///
/// ```text
/// Column major     Row major
/// [0 3 6]          [0 1 2]
/// [1 4 7]          [3 4 5]
/// [2 5 8]          [6 7 8]
/// ```
#[derive(Debug)]
#[derive(Clone, Copy)]
#[derive(PartialEq, Eq)]
pub enum MajorAxis {
    /// Each column is stored contiguously.
    Column,
    /// Each row is stored contiguously.
    Row,
}

/// The complete set of layout options needed to locate the components
/// of a type in memory. It combines a [LayoutStd] with the [MajorAxis]
/// used for matrices.
#[derive(Debug)]
#[derive(Clone, Copy)]
#[derive(PartialEq, Eq)]
pub struct Layout {
    /// Which layout standard's rounding rules apply.
    pub std: LayoutStd,
    /// Whether matrix components are stored column-major or row-major.
    pub major: MajorAxis,
}

/// An enum representing all of the types that can be stored in a
/// slot. These correspond to the types in GLSL.
///
/// Matrix variants follow the GLSL `MatCxR` naming, giving the number of
/// columns first. For example, `Mat2x3` has two columns of three rows
/// each, as its entry in `TYPE_INFOS` records.
///
/// The variants are listed in the same order as the entries of
/// `TYPE_INFOS`. The explicit `Int = 0` pins the first discriminant,
/// presumably so a variant can be used directly as an index into that
/// table (TODO confirm at the use site). Add or reorder variants only
/// together with that table.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Type {
    // Scalars
    Int = 0,
    UInt,
    Int8,
    UInt8,
    Int16,
    UInt16,
    Int64,
    UInt64,
    Float16,
    Float,
    Double,
    // Floating-point vectors
    F16Vec2,
    F16Vec3,
    F16Vec4,
    Vec2,
    Vec3,
    Vec4,
    DVec2,
    DVec3,
    DVec4,
    // Integer vectors
    IVec2,
    IVec3,
    IVec4,
    UVec2,
    UVec3,
    UVec4,
    I8Vec2,
    I8Vec3,
    I8Vec4,
    U8Vec2,
    U8Vec3,
    U8Vec4,
    I16Vec2,
    I16Vec3,
    I16Vec4,
    U16Vec2,
    U16Vec3,
    U16Vec4,
    I64Vec2,
    I64Vec3,
    I64Vec4,
    U64Vec2,
    U64Vec3,
    U64Vec4,
    // Single-precision matrices (columns x rows)
    Mat2,
    Mat2x3,
    Mat2x4,
    Mat3x2,
    Mat3,
    Mat3x4,
    Mat4x2,
    Mat4x3,
    Mat4,
    // Double-precision matrices (columns x rows)
    DMat2,
    DMat2x3,
    DMat2x4,
    DMat3x2,
    DMat3,
    DMat3x4,
    DMat4x2,
    DMat4x3,
    DMat4,
}

/// The scalar type of each individual component of a [Type].
#[derive(Debug)]
#[derive(Clone, Copy)]
#[derive(PartialEq, Eq)]
pub enum BaseType {
    /// 32-bit signed integer (GLSL `int`).
    Int,
    /// 32-bit unsigned integer (GLSL `uint`).
    UInt,
    /// 8-bit signed integer (GLSL `int8_t`).
    Int8,
    /// 8-bit unsigned integer (GLSL `uint8_t`).
    UInt8,
    /// 16-bit signed integer (GLSL `int16_t`).
    Int16,
    /// 16-bit unsigned integer (GLSL `uint16_t`).
    UInt16,
    /// 64-bit signed integer (GLSL `int64_t`).
    Int64,
    /// 64-bit unsigned integer (GLSL `uint64_t`).
    UInt64,
    /// 16-bit half-precision float (GLSL `float16_t`).
    Float16,
    /// 32-bit single-precision float (GLSL `float`).
    Float,
    /// 64-bit double-precision float (GLSL `double`).
    Double,
}

/// Pairs a [BaseType] with the raw bytes of one value of that type. It
/// implements [Display](std::fmt::Display), so the value can be pulled
/// out of a byte array and shown as human-readable text.
pub struct BaseTypeInSlice<'bytes> {
    /// How the bytes in `slice` are to be interpreted.
    base_type: BaseType,
    /// The bytes of a single `base_type` value in native-endian order.
    slice: &'bytes [u8],
}

/// A criterion for comparing the values stored in two slots. The check
/// itself is performed by [compare](Comparison::compare).
#[derive(Debug)]
#[derive(Clone, Copy)]
#[derive(PartialEq, Eq)]
pub enum Comparison {
    /// Passes when the values are exactly equal.
    Equal,
    /// Passes when floating-point values are approximately equal under
    /// the supplied [Tolerance] settings. Integer types are compared in
    /// the same way as with [Equal](Comparison::Equal).
    FuzzyEqual,
    /// Passes when the values differ.
    NotEqual,
    /// Passes when every value in the first slot is less than the
    /// corresponding value in the second slot.
    Less,
    /// Passes when every value in the first slot is greater than or
    /// equal to the corresponding value in the second slot.
    GreaterEqual,
    /// Passes when every value in the first slot is greater than the
    /// corresponding value in the second slot.
    Greater,
    /// Passes when every value in the first slot is less than or equal
    /// to the corresponding value in the second slot.
    LessEqual,
}

/// Shape of one slot [Type]: its component type plus the number of
/// columns and rows of components. Scalars are 1×1, and vectors have a
/// single column with one row per component.
#[derive(Debug)]
struct TypeInfo {
    /// The scalar type of every component.
    base_type: BaseType,
    /// Number of columns; greater than one only for matrices.
    columns: usize,
    /// Number of rows, i.e. how many components each column holds.
    rows: usize,
}

// Shape of every slot type. The entries follow the variant order of
// `Type`, as the trailing comment on each row shows. `Type` starts at
// `Int = 0`, so this is presumably looked up with `Type as usize`
// (TODO confirm at the use site). Keep the two lists in sync.
//
// `columns` x `rows` follows the GLSL `matCxR` naming: for example,
// `Mat2x3` has 2 columns and 3 rows. Vectors are a single column.
static TYPE_INFOS: [TypeInfo; 62] = [
    TypeInfo { base_type: BaseType::Int, columns: 1, rows: 1 }, // Int
    TypeInfo { base_type: BaseType::UInt, columns: 1, rows: 1 }, // UInt
    TypeInfo { base_type: BaseType::Int8, columns: 1, rows: 1 }, // Int8
    TypeInfo { base_type: BaseType::UInt8, columns: 1, rows: 1 }, // UInt8
    TypeInfo { base_type: BaseType::Int16, columns: 1, rows: 1 }, // Int16
    TypeInfo { base_type: BaseType::UInt16, columns: 1, rows: 1 }, // UInt16
    TypeInfo { base_type: BaseType::Int64, columns: 1, rows: 1 }, // Int64
    TypeInfo { base_type: BaseType::UInt64, columns: 1, rows: 1 }, // UInt64
    TypeInfo { base_type: BaseType::Float16, columns: 1, rows: 1 }, // Float16
    TypeInfo { base_type: BaseType::Float, columns: 1, rows: 1 }, // Float
    TypeInfo { base_type: BaseType::Double, columns: 1, rows: 1 }, // Double
    TypeInfo { base_type: BaseType::Float16, columns: 1, rows: 2 }, // F16Vec2
    TypeInfo { base_type: BaseType::Float16, columns: 1, rows: 3 }, // F16Vec3
    TypeInfo { base_type: BaseType::Float16, columns: 1, rows: 4 }, // F16Vec4
    TypeInfo { base_type: BaseType::Float, columns: 1, rows: 2 }, // Vec2
    TypeInfo { base_type: BaseType::Float, columns: 1, rows: 3 }, // Vec3
    TypeInfo { base_type: BaseType::Float, columns: 1, rows: 4 }, // Vec4
    TypeInfo { base_type: BaseType::Double, columns: 1, rows: 2 }, // DVec2
    TypeInfo { base_type: BaseType::Double, columns: 1, rows: 3 }, // DVec3
    TypeInfo { base_type: BaseType::Double, columns: 1, rows: 4 }, // DVec4
    TypeInfo { base_type: BaseType::Int, columns: 1, rows: 2 }, // IVec2
    TypeInfo { base_type: BaseType::Int, columns: 1, rows: 3 }, // IVec3
    TypeInfo { base_type: BaseType::Int, columns: 1, rows: 4 }, // IVec4
    TypeInfo { base_type: BaseType::UInt, columns: 1, rows: 2 }, // UVec2
    TypeInfo { base_type: BaseType::UInt, columns: 1, rows: 3 }, // UVec3
    TypeInfo { base_type: BaseType::UInt, columns: 1, rows: 4 }, // UVec4
    TypeInfo { base_type: BaseType::Int8, columns: 1, rows: 2 }, // I8Vec2
    TypeInfo { base_type: BaseType::Int8, columns: 1, rows: 3 }, // I8Vec3
    TypeInfo { base_type: BaseType::Int8, columns: 1, rows: 4 }, // I8Vec4
    TypeInfo { base_type: BaseType::UInt8, columns: 1, rows: 2 }, // U8Vec2
    TypeInfo { base_type: BaseType::UInt8, columns: 1, rows: 3 }, // U8Vec3
    TypeInfo { base_type: BaseType::UInt8, columns: 1, rows: 4 }, // U8Vec4
    TypeInfo { base_type: BaseType::Int16, columns: 1, rows: 2 }, // I16Vec2
    TypeInfo { base_type: BaseType::Int16, columns: 1, rows: 3 }, // I16Vec3
    TypeInfo { base_type: BaseType::Int16, columns: 1, rows: 4 }, // I16Vec4
    TypeInfo { base_type: BaseType::UInt16, columns: 1, rows: 2 }, // U16Vec2
    TypeInfo { base_type: BaseType::UInt16, columns: 1, rows: 3 }, // U16Vec3
    TypeInfo { base_type: BaseType::UInt16, columns: 1, rows: 4 }, // U16Vec4
    TypeInfo { base_type: BaseType::Int64, columns: 1, rows: 2 }, // I64Vec2
    TypeInfo { base_type: BaseType::Int64, columns: 1, rows: 3 }, // I64Vec3
    TypeInfo { base_type: BaseType::Int64, columns: 1, rows: 4 }, // I64Vec4
    TypeInfo { base_type: BaseType::UInt64, columns: 1, rows: 2 }, // U64Vec2
    TypeInfo { base_type: BaseType::UInt64, columns: 1, rows: 3 }, // U64Vec3
    TypeInfo { base_type: BaseType::UInt64, columns: 1, rows: 4 }, // U64Vec4
    TypeInfo { base_type: BaseType::Float, columns: 2, rows: 2 }, // Mat2
    TypeInfo { base_type: BaseType::Float, columns: 2, rows: 3 }, // Mat2x3
    TypeInfo { base_type: BaseType::Float, columns: 2, rows: 4 }, // Mat2x4
    TypeInfo { base_type: BaseType::Float, columns: 3, rows: 2 }, // Mat3x2
    TypeInfo { base_type: BaseType::Float, columns: 3, rows: 3 }, // Mat3
    TypeInfo { base_type: BaseType::Float, columns: 3, rows: 4 }, // Mat3x4
    TypeInfo { base_type: BaseType::Float, columns: 4, rows: 2 }, // Mat4x2
    TypeInfo { base_type: BaseType::Float, columns: 4, rows: 3 }, // Mat4x3
    TypeInfo { base_type: BaseType::Float, columns: 4, rows: 4 }, // Mat4
    TypeInfo { base_type: BaseType::Double, columns: 2, rows: 2 }, // DMat2
    TypeInfo { base_type: BaseType::Double, columns: 2, rows: 3 }, // DMat2x3
    TypeInfo { base_type: BaseType::Double, columns: 2, rows: 4 }, // DMat2x4
    TypeInfo { base_type: BaseType::Double, columns: 3, rows: 2 }, // DMat3x2
    TypeInfo { base_type: BaseType::Double, columns: 3, rows: 3 }, // DMat3
    TypeInfo { base_type: BaseType::Double, columns: 3, rows: 4 }, // DMat3x4
    TypeInfo { base_type: BaseType::Double, columns: 4, rows: 2 }, // DMat4x2
    TypeInfo { base_type: BaseType::Double, columns: 4, rows: 3 }, // DMat4x3
    TypeInfo { base_type: BaseType::Double, columns: 4, rows: 4 }, // DMat4
];

// Mapping from GLSL type name to slot type. Sorted alphabetically so
// we can do a binary search.
//
// "Alphabetically" here means byte-wise string order, which is what
// `str`'s `Ord` uses. Digits sort before letters, so for example
// "f16vec2" comes before "float" and "i8vec2" comes after "i64vec2".
// Any new entry must keep this order or lookups will silently miss.
//
// Square matrices are listed under both spellings ("mat2" and "mat2x2"),
// and both names map to the same `Type`.
static GLSL_TYPE_NAMES: [(&'static str, Type); 68] = [
    ("dmat2", Type::DMat2),
    ("dmat2x2", Type::DMat2),
    ("dmat2x3", Type::DMat2x3),
    ("dmat2x4", Type::DMat2x4),
    ("dmat3", Type::DMat3),
    ("dmat3x2", Type::DMat3x2),
    ("dmat3x3", Type::DMat3),
    ("dmat3x4", Type::DMat3x4),
    ("dmat4", Type::DMat4),
    ("dmat4x2", Type::DMat4x2),
    ("dmat4x3", Type::DMat4x3),
    ("dmat4x4", Type::DMat4),
    ("double", Type::Double),
    ("dvec2", Type::DVec2),
    ("dvec3", Type::DVec3),
    ("dvec4", Type::DVec4),
    ("f16vec2", Type::F16Vec2),
    ("f16vec3", Type::F16Vec3),
    ("f16vec4", Type::F16Vec4),
    ("float", Type::Float),
    ("float16_t", Type::Float16),
    ("i16vec2", Type::I16Vec2),
    ("i16vec3", Type::I16Vec3),
    ("i16vec4", Type::I16Vec4),
    ("i64vec2", Type::I64Vec2),
    ("i64vec3", Type::I64Vec3),
    ("i64vec4", Type::I64Vec4),
    ("i8vec2", Type::I8Vec2),
    ("i8vec3", Type::I8Vec3),
    ("i8vec4", Type::I8Vec4),
    ("int", Type::Int),
    ("int16_t", Type::Int16),
    ("int64_t", Type::Int64),
    ("int8_t", Type::Int8),
    ("ivec2", Type::IVec2),
    ("ivec3", Type::IVec3),
    ("ivec4", Type::IVec4),
    ("mat2", Type::Mat2),
    ("mat2x2", Type::Mat2),
    ("mat2x3", Type::Mat2x3),
    ("mat2x4", Type::Mat2x4),
    ("mat3", Type::Mat3),
    ("mat3x2", Type::Mat3x2),
    ("mat3x3", Type::Mat3),
    ("mat3x4", Type::Mat3x4),
    ("mat4", Type::Mat4),
    ("mat4x2", Type::Mat4x2),
    ("mat4x3", Type::Mat4x3),
    ("mat4x4", Type::Mat4),
    ("u16vec2", Type::U16Vec2),
    ("u16vec3", Type::U16Vec3),
    ("u16vec4", Type::U16Vec4),
    ("u64vec2", Type::U64Vec2),
    ("u64vec3", Type::U64Vec3),
    ("u64vec4", Type::U64Vec4),
    ("u8vec2", Type::U8Vec2),
    ("u8vec3", Type::U8Vec3),
    ("u8vec4", Type::U8Vec4),
    ("uint", Type::UInt),
    ("uint16_t", Type::UInt16),
    ("uint64_t", Type::UInt64),
    ("uint8_t", Type::UInt8),
    ("uvec2", Type::UVec2),
    ("uvec3", Type::UVec3),
    ("uvec4", Type::UVec4),
    ("vec2", Type::Vec2),
    ("vec3", Type::Vec3),
    ("vec4", Type::Vec4),
];

// Mapping from operator name to Comparison. Sorted by ASCII values so
// we can do a binary search.
//
// The leading characters are ordered '!' < '<' < '=' < '>' < '~' in
// ASCII. A one-character operator such as "<" sorts before its
// two-character extension "<=". Keep this order when adding operators.
static COMPARISON_NAMES: [(&'static str, Comparison); 7] = [
    ("!=", Comparison::NotEqual),
    ("<", Comparison::Less),
    ("<=", Comparison::LessEqual),
    ("==", Comparison::Equal),
    (">", Comparison::Greater),
    (">=", Comparison::GreaterEqual),
    ("~=", Comparison::FuzzyEqual),
];

impl BaseType {
    /// Returns how many bytes a single value of this base type
    /// occupies. Types with the same storage width share an arm.
    pub fn size(self) -> usize {
        match self {
            BaseType::Int8 | BaseType::UInt8 => mem::size_of::<u8>(),
            BaseType::Int16 | BaseType::UInt16 | BaseType::Float16 => {
                mem::size_of::<u16>()
            },
            BaseType::Int | BaseType::UInt | BaseType::Float => {
                mem::size_of::<u32>()
            },
            BaseType::Int64 | BaseType::UInt64 | BaseType::Double => {
                mem::size_of::<u64>()
            },
        }
    }
}

impl<'a> BaseTypeInSlice<'a> {
    /// Combines the given base type and slice into a
    /// `BaseTypeInSlice` so that it can be displayed with the
    /// [Display](std::fmt::Display) trait. The slice is expected to
    /// contain a representation of the base type in native-endian
    /// order. The slice needs to have the same size as
    /// `base_type.size()` or the constructor will panic.
    pub fn new(base_type: BaseType, slice: &'a [u8]) -> BaseTypeInSlice<'a> {
        assert_eq!(slice.len(), base_type.size());

        BaseTypeInSlice { base_type, slice }
    }
}

impl<'a> fmt::Display for BaseTypeInSlice<'a> {
    /// Decodes the stored native-endian bytes according to the base
    /// type and writes the resulting number. Half floats are widened
    /// to `f64` before printing.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let bytes = self.slice;

        // Reinterprets `bytes` as a native-endian `$t` and prints it.
        macro_rules! show {
            ($t:ty) => {
                write!(f, "{}", <$t>::from_ne_bytes(bytes.try_into().unwrap()))
            };
        }

        match self.base_type {
            BaseType::Int => show!(i32),
            BaseType::UInt => show!(u32),
            BaseType::Int8 => show!(i8),
            BaseType::UInt8 => show!(u8),
            BaseType::Int16 => show!(i16),
            BaseType::UInt16 => show!(u16),
            BaseType::Int64 => show!(i64),
            BaseType::UInt64 => show!(u64),
            BaseType::Float16 => {
                let bits = u16::from_ne_bytes(bytes.try_into().unwrap());
                write!(f, "{}", half_float::to_f64(bits))
            },
            BaseType::Float => show!(f32),
            BaseType::Double => show!(f64),
        }
    }
}

impl Type {
    /// Returns the type of the components of this slot type.
    #[inline]
    pub fn base_type(self) -> BaseType {
        TYPE_INFOS[self as usize].base_type
    }

    /// Returns the number of rows in this slot type. For primitive types
    /// this will be 1, for vectors it will be the size of the vector
    /// and for matrix types it will be the number of rows.
    #[inline]
    pub fn rows(self) -> usize {
        TYPE_INFOS[self as usize].rows
    }

    /// Returns the number of columns in this slot type. For non-matrix
    /// types this will be 1.
    #[inline]
    pub fn columns(self) -> usize {
        TYPE_INFOS[self as usize].columns
    }

    /// Returns whether the type is a matrix type, ie, whether it has
    /// more than one column.
    #[inline]
    pub fn is_matrix(self) -> bool {
        self.columns() > 1
    }

    // Returns the effective major axis setting that should be used
    // when calculating offsets. The layout’s setting only affects
    // matrix types; scalars and vectors are always treated as a
    // single column.
    fn effective_major_axis(self, layout: Layout) -> MajorAxis {
        if self.is_matrix() {
            layout.major
        } else {
            MajorAxis::Column
        }
    }

    // Returns a tuple containing the size of the major and minor axes
    // (in that order) given the type and layout. The major size is
    // the number of minor-axis vectors and the minor size is the
    // number of components in each of those vectors.
    fn major_minor(self, layout: Layout) -> (usize, usize) {
        match self.effective_major_axis(layout) {
            MajorAxis::Column => (self.columns(), self.rows()),
            MajorAxis::Row => (self.rows(), self.columns()),
        }
    }

    /// Returns the matrix stride of the slot type in bytes. This is
    /// the offset from the start of one minor-axis vector to the start
    /// of the next one. It is the distance between consecutive
    /// columns in a column-major layout, and between consecutive rows
    /// in a row-major layout. Three-component vectors are padded to
    /// the size of four components. With
    /// [Std140](LayoutStd::Std140), the stride is also rounded up to
    /// the size of a vec4. For example, in a column-major layout the
    /// matrix stride of a vec4 will be 16, because moving to the next
    /// column requires adding an offset of 16 bytes to the address of
    /// the previous column.
    pub fn matrix_stride(self, layout: Layout) -> usize {
        let component_size = self.base_type().size();
        let (_major, minor) = self.major_minor(layout);

        // A 3-component axis occupies the space of 4 components
        let base_stride = if minor == 3 {
            component_size * 4
        } else {
            component_size * minor
        };

        match layout.std {
            LayoutStd::Std140 => {
                // According to std140 the size is rounded up to a vec4
                util::align(base_stride, 16)
            },
            LayoutStd::Std430 => base_stride,
        }
    }

    /// Returns the offset between members of the elements of an array
    /// of this type of slot. There may be padding between the members
    /// to fulfill the alignment criteria of the layout.
    pub fn array_stride(self, layout: Layout) -> usize {
        let matrix_stride = self.matrix_stride(layout);
        let (major, _minor) = self.major_minor(layout);

        matrix_stride * major
    }

    /// Returns the size of the slot type. Note that unlike
    /// [array_stride](Type::array_stride), this won’t include the
    /// padding to align the values in an array according to the GLSL
    /// layout rules. It will however always have a size that is a
    /// multiple of the size of the base type so it is suitable as an
    /// array stride for packed values stored internally that aren’t
    /// passed to Vulkan.
    pub fn size(self, layout: Layout) -> usize {
        let matrix_stride = self.matrix_stride(layout);
        let base_size = self.base_type().size();
        let (major, minor) = self.major_minor(layout);

        // Every minor-axis vector but the last takes a full stride; the
        // last one only needs room for its own components.
        (major - 1) * matrix_stride + base_size * minor
    }

    /// Returns an iterator that iterates over the offsets of the
    /// components of this slot. This can be used to extract the
    /// values out of an array of bytes containing the slot. The
    /// offsets are returned in column-major order (ie, all of the
    /// offsets for a column are returned before moving on to the
    /// next column). The offsets themselves are calculated to take
    /// into account the major axis setting in the layout.
    pub fn offsets(self, layout: Layout) -> OffsetsIter {
        OffsetsIter::new(self, layout)
    }

    /// Returns a [Type] given the GLSL type name, for example `dmat2`
    /// or `i64vec3`. `None` will be returned if the type name isn’t
    /// recognised.
    pub fn from_glsl_type(type_name: &str) -> Option<Type> {
        GLSL_TYPE_NAMES
            .binary_search_by(|&(name, _)| name.cmp(type_name))
            .ok()
            .map(|pos| GLSL_TYPE_NAMES[pos].1)
    }
}

/// Iterator over the offsets into an array to extract the components
/// of a slot. Created by [Type::offsets].
///
/// Each call to `next` yields the byte offset of one component. It
/// walks every row of a column before moving on to the next column.
#[derive(Debug, Clone)]
pub struct OffsetsIter {
    // The slot type being walked; its `rows()` and `columns()` bound
    // the iteration.
    slot_type: Type,
    // Index of the next component within the current column.
    row: usize,
    // Index of the current column. Iteration is finished once this
    // reaches `slot_type.columns()`.
    column: usize,
    // Byte offset that the next call to `next` will return.
    offset: usize,
    // Extra adjustment added to `offset` after the last row of a
    // column (on top of `row_offset`) to land on the start of the
    // next column. May be negative for row-major layouts.
    column_offset: isize,
    // Amount added to `offset` after every component.
    row_offset: isize,
}

impl OffsetsIter {
    /// Creates an iterator positioned on the first component (offset
    /// 0) of `slot_type`. The per-step increments are derived from
    /// the matrix stride and the effective major axis for `layout`.
    fn new(slot_type: Type, layout: Layout) -> OffsetsIter {
        let stride = slot_type.matrix_stride(layout);
        let base_size = slot_type.base_type().size();
        let rows = slot_type.rows();

        // `row_offset` steps to the next component of a column.
        // `column_offset` is then applied on top of it after the final
        // row to reach the first component of the following column.
        let (column_offset, row_offset) =
            match slot_type.effective_major_axis(layout) {
                // Components of a column are contiguous, so only the
                // trailing padding up to the stride needs skipping.
                MajorAxis::Column => (
                    (stride - base_size * rows) as isize,
                    base_size as isize,
                ),
                // Components of a column are a stride apart, so rewind
                // to the top row and advance by one component.
                MajorAxis::Row => (
                    base_size as isize - (stride * rows) as isize,
                    stride as isize,
                ),
            };

        OffsetsIter {
            slot_type,
            row: 0,
            column: 0,
            offset: 0,
            column_offset,
            row_offset,
        }
    }
}

impl Iterator for OffsetsIter {
    type Item = usize;

    fn next(&mut self) -> Option<usize> {
        if self.column >= self.slot_type.columns() {
            return None;
        }

        let result = self.offset;

        self.offset = ((self.offset as isize) + self.row_offset) as usize;
        self.row += 1;

        if self.row >= self.slot_type.rows() {
            self.row = 0;
            self.column += 1;
            self.offset =
                ((self.offset as isize) + self.column_offset) as usize;
        }

        Some(result)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let n_columns = self.slot_type.columns();

        let size = if self.column >= n_columns {
            0
        } else {
            (self.slot_type.rows() - self.row)
                + ((n_columns - 1 - self.column) * self.slot_type.rows())
        };

        (size, Some(size))
    }

    fn count(self) -> usize {
        self.size_hint().0
    }
}

impl Comparison {
    fn compare_without_fuzzy<T: PartialOrd>(self, a: T, b: T) -> bool {
        match self {
            Comparison::Equal | Comparison::FuzzyEqual => a.eq(&b),
            Comparison::NotEqual => a.ne(&b),
            Comparison::Less => a.lt(&b),
            Comparison::GreaterEqual => a.ge(&b),
            Comparison::Greater => a.gt(&b),
            Comparison::LessEqual => a.le(&b),
        }
    }

    fn compare_value(
        self,
        tolerance: &Tolerance,
        component: usize,
        base_type: BaseType,
        a: &[u8],
        b: &[u8]
    ) -> bool {
        match base_type {
            BaseType::Int => {
                let a = i32::from_ne_bytes(a.try_into().unwrap());
                let b = i32::from_ne_bytes(b.try_into().unwrap());
                self.compare_without_fuzzy(a, b)
            },
            BaseType::UInt => {
                let a = u32::from_ne_bytes(a.try_into().unwrap());
                let b = u32::from_ne_bytes(b.try_into().unwrap());
                self.compare_without_fuzzy(a, b)
            },
            BaseType::Int8 => {
                let a = i8::from_ne_bytes(a.try_into().unwrap());
                let b = i8::from_ne_bytes(b.try_into().unwrap());
                self.compare_without_fuzzy(a, b)
            },
            BaseType::UInt8 => {
                let a = u8::from_ne_bytes(a.try_into().unwrap());
                let b = u8::from_ne_bytes(b.try_into().unwrap());
                self.compare_without_fuzzy(a, b)
            },
            BaseType::Int16 => {
                let a = i16::from_ne_bytes(a.try_into().unwrap());
                let b = i16::from_ne_bytes(b.try_into().unwrap());
                self.compare_without_fuzzy(a, b)
            },
            BaseType::UInt16 => {
                let a = u16::from_ne_bytes(a.try_into().unwrap());
                let b = u16::from_ne_bytes(b.try_into().unwrap());
                self.compare_without_fuzzy(a, b)
            },
            BaseType::Int64 => {
                let a = i64::from_ne_bytes(a.try_into().unwrap());
                let b = i64::from_ne_bytes(b.try_into().unwrap());
                self.compare_without_fuzzy(a, b)
            },
            BaseType::UInt64 => {
                let a = u64::from_ne_bytes(a.try_into().unwrap());
                let b = u64::from_ne_bytes(b.try_into().unwrap());
                self.compare_without_fuzzy(a, b)
            },
            BaseType::Float16 => {
                let a = u16::from_ne_bytes(a.try_into().unwrap());
                let a = half_float::to_f64(a);
                let b = u16::from_ne_bytes(b.try_into().unwrap());
                let b = half_float::to_f64(b);

                match self {
                    Comparison::FuzzyEqual => tolerance.equal(component, a, b),
                    Comparison::Equal
                        | Comparison::NotEqual
                        | Comparison::Less
                        | Comparison::GreaterEqual
                        | Comparison::Greater
                        | Comparison::LessEqual => {
                            self.compare_without_fuzzy(a, b)
                        },
                }
            },
            BaseType::Float => {
                let a = f32::from_ne_bytes(a.try_into().unwrap());
                let b = f32::from_ne_bytes(b.try_into().unwrap());

                match self {
                    Comparison::FuzzyEqual => {
                        tolerance.equal(component, a as f64, b as f64)
                    },
                    Comparison::Equal
                        | Comparison::NotEqual
                        | Comparison::Less
                        | Comparison::GreaterEqual
                        | Comparison::Greater
                        | Comparison::LessEqual => {
                            self.compare_without_fuzzy(a, b)
                        },
                }
            },
            BaseType::Double => {
                let a = f64::from_ne_bytes(a.try_into().unwrap());
                let b = f64::from_ne_bytes(b.try_into().unwrap());

                match self {
                    Comparison::FuzzyEqual => {
                        tolerance.equal(component, a, b)
                    },
                    Comparison::Equal
                        | Comparison::NotEqual
                        | Comparison::Less
                        | Comparison::GreaterEqual
                        | Comparison::Greater
                        | Comparison::LessEqual => {
                            self.compare_without_fuzzy(a, b)
                        },
                }
            },
        }
    }

    /// Compares the values in `a` with the values in `b` using the
    /// chosen criteria. The values are extracted from the two byte
    /// slices using the layout and type provided. The tolerance
    /// parameter is only used if the comparison is
    /// [FuzzyEqual](Comparison::FuzzyEqual).
    pub fn compare(
        self,
        tolerance: &Tolerance,
        slot_type: Type,
        layout: Layout,
        a: &[u8],
        b: &[u8],
    ) -> bool {
        let base_type = slot_type.base_type();
        let base_type_size = base_type.size();
        let rows = slot_type.rows();

        for (i, offset) in slot_type.offsets(layout).enumerate() {
            let a = &a[offset..offset + base_type_size];
            let b = &b[offset..offset + base_type_size];

            if !self.compare_value(tolerance, i % rows, base_type, a, b) {
                return false;
            }
        }

        true
    }

    /// Get the comparison operator given the name in GLSL like `"<"`
    /// or `">="`. It also accepts `"~="` to return the fuzzy equal
    /// comparison. If the operator isn’t recognised then it will
    /// return `None`.
    pub fn from_operator(operator: &str) -> Option<Comparison> {
        match COMPARISON_NAMES.binary_search_by(
            |&(probe, _value)| probe.cmp(operator)
        ) {
            Ok(pos) => Some(COMPARISON_NAMES[pos].1),
            Err(_) => None,
        }
    }
}

#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_type_info() {
        // Test a few types

        // First
        assert!(matches!(Type::Int.base_type(), BaseType::Int));
        assert_eq!(Type::Int.columns(), 1);
        assert_eq!(Type::Int.rows(), 1);
        assert!(!Type::Int.is_matrix());

        // Last
        assert!(matches!(Type::DMat4.base_type(), BaseType::Double));
        assert_eq!(Type::DMat4.columns(), 4);
        assert_eq!(Type::DMat4.rows(), 4);
        assert!(Type::DMat4.is_matrix());

        // Vector
        assert!(matches!(Type::Vec4.base_type(), BaseType::Float));
        assert_eq!(Type::Vec4.columns(), 1);
        assert_eq!(Type::Vec4.rows(), 4);
        assert!(!Type::Vec4.is_matrix());

        // Non-square matrix
        assert!(matches!(Type::DMat4x3.base_type(), BaseType::Double));
        assert_eq!(Type::DMat4x3.columns(), 4);
        assert_eq!(Type::DMat4x3.rows(), 3);
        assert!(Type::DMat4x3.is_matrix());
    }

    #[test]
    fn test_base_type_size() {
        assert_eq!(BaseType::Int.size(), 4);
        assert_eq!(BaseType::UInt.size(), 4);
        assert_eq!(BaseType::Int8.size(), 1);
        assert_eq!(BaseType::UInt8.size(), 1);
        assert_eq!(BaseType::Int16.size(), 2);
        assert_eq!(BaseType::UInt16.size(), 2);
        assert_eq!(BaseType::Int64.size(), 8);
        assert_eq!(BaseType::UInt64.size(), 8);
        assert_eq!(BaseType::Float16.size(), 2);
        assert_eq!(BaseType::Float.size(), 4);
        assert_eq!(BaseType::Double.size(), 8);
    }

    #[test]
    fn test_matrix_stride() {
        assert_eq!(
            Type::Mat4.matrix_stride(
                Layout { std: LayoutStd::Std430, major: MajorAxis::Column }
                ),
            16,
        );
        // In Std140 the vecs along the minor axis are aligned to the
        // size of a vec4
        assert_eq!(
            Type::Mat4x2.matrix_stride(
                Layout { std: LayoutStd::Std140, major: MajorAxis::Column }
                ),
            16,
        );
        // In Std430 they are not
        assert_eq!(
            Type::Mat4x2.matrix_stride(
                Layout { std: LayoutStd::Std430, major: MajorAxis::Column }
                ),
            8,
        );
        // 3-component axis are always aligned to a vec4 regardless of
        // the layout standard
        assert_eq!(
            Type::Mat4x3.matrix_stride(
                Layout { std: LayoutStd::Std430, major: MajorAxis::Column }
                ),
            16,
        );

        // For the row-major the stride the axes are reversed
        assert_eq!(
            Type::Mat4x2.matrix_stride(
                Layout { std: LayoutStd::Std430, major: MajorAxis::Row }
                ),
            16,
        );
        assert_eq!(
            Type::Mat2x4.matrix_stride(
                Layout { std: LayoutStd::Std430, major: MajorAxis::Row }
                ),
            8,
        );

        // For non-matrix types the matrix stride is the array stride
        assert_eq!(
            Type::Vec3.matrix_stride(
                Layout { std: LayoutStd::Std430, major: MajorAxis::Column }
                ),
            16,
        );
        assert_eq!(
            Type::Float.matrix_stride(
                Layout { std: LayoutStd::Std430, major: MajorAxis::Column }
                ),
            4,
        );
        // Row-major layout does not affect vectors
        assert_eq!(
            Type::Vec2.matrix_stride(
                Layout { std: LayoutStd::Std430, major: MajorAxis::Row }
                ),
            8,
        );
    }

    #[test]
    fn test_array_stride() {
        assert_eq!(
            Type::UInt8.array_stride(
                Layout { std: LayoutStd::Std430, major: MajorAxis::Column }
                ),
            1,
        );
        assert_eq!(
            Type::I8Vec2.array_stride(
                Layout { std: LayoutStd::Std430, major: MajorAxis::Column }
                ),
            2,
        );
        assert_eq!(
            Type::I8Vec3.array_stride(
                Layout { std: LayoutStd::Std430, major: MajorAxis::Column }
                ),
            4,
        );
        assert_eq!(
            Type::Mat4x3.array_stride(
                Layout { std: LayoutStd::Std430, major: MajorAxis::Column }
                ),
            64,
        );
    }

    #[test]
    fn test_size() {
        assert_eq!(
            Type::UInt8.size(
                Layout { std: LayoutStd::Std430, major: MajorAxis::Column }
                ),
            1,
        );
        assert_eq!(
            Type::I8Vec2.size(
                Layout { std: LayoutStd::Std430, major: MajorAxis::Column }
                ),
            2,
        );
        assert_eq!(
            Type::I8Vec3.size(
                Layout { std: LayoutStd::Std430, major: MajorAxis::Column }
                ),
            3,
        );
        assert_eq!(
            Type::Mat4x3.size(
                Layout { std: LayoutStd::Std430, major: MajorAxis::Column }
                ),
            3 * 4 * 4 + 3 * 4,
        );
    }

    #[test]
    fn test_iterator() {
        let default_layout = Layout {
            std: LayoutStd::Std140,
            major: MajorAxis::Column,
        };

        // The components of a vec4 are packed tightly
        let vec4_offsets =
            Type::Vec4.offsets(default_layout).collect::<Vec<usize>>();
        assert_eq!(vec4_offsets, vec![0, 4, 8, 12]);

        // The major axis setting shouldn’t affect vectors
        let row_major_layout = Layout {
            std: LayoutStd::Std140,
            major: MajorAxis::Row,
        };
        let vec4_offsets =
            Type::Vec4.offsets(row_major_layout).collect::<Vec<usize>>();
        assert_eq!(vec4_offsets, vec![0, 4, 8, 12]);

        // Components of a mat4 are packed tightly
        let mat4_offsets =
            Type::Mat4.offsets(default_layout).collect::<Vec<usize>>();
        assert_eq!(
            mat4_offsets,
            vec![0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60],
        );

        // If there are 3 rows then they are padded out to align
        // each column to the size of a vec4
        let mat4x3_offsets =
            Type::Mat4x3.offsets(default_layout).collect::<Vec<usize>>();
        assert_eq!(
            mat4x3_offsets,
            vec![0, 4, 8, 16, 20, 24, 32, 36, 40, 48, 52, 56],
        );

        // If row-major order is being used then the offsets are still
        // reported in column-major order but the addresses that they
        // point to are in row-major order.
        // [0   4   8]
        // [16  20  24]
        // [32  36  40]
        // [48  52  56]
        let mat3x4_offsets =
            Type::Mat3x4.offsets(row_major_layout).collect::<Vec<usize>>();
        assert_eq!(
            mat3x4_offsets,
            vec![0, 16, 32, 48, 4, 20, 36, 52, 8, 24, 40, 56],
        );

        // Check the size_hint and count. The size hint is always
        // exact and should match the count.
        let mut iter = Type::Mat4.offsets(default_layout);

        for i in 0..16 {
            let expected_count = 16 - i;
            assert_eq!(
                iter.size_hint(),
                (expected_count, Some(expected_count))
            );
            assert_eq!(iter.clone().count(), expected_count);

            assert!(matches!(iter.next(), Some(_)));
        }

        assert_eq!(iter.size_hint(), (0, Some(0)));
        assert_eq!(iter.clone().count(), 0);
        assert!(matches!(iter.next(), None));
    }

    #[test]
    fn test_compare() {
        // Test every comparison mode
        assert!(Comparison::Equal.compare(
            &Default::default(), // tolerance
            Type::UInt8,
            Layout { std: LayoutStd::Std140, major: MajorAxis::Column },
            &[5],
            &[5],
        ));
        assert!(!Comparison::Equal.compare(
            &Default::default(), // tolerance
            Type::UInt8,
            Layout { std: LayoutStd::Std140, major: MajorAxis::Column },
            &[6],
            &[5],
        ));
        assert!(Comparison::NotEqual.compare(
            &Default::default(), // tolerance
            Type::UInt8,
            Layout { std: LayoutStd::Std140, major: MajorAxis::Column },
            &[6],
            &[5],
        ));
        assert!(!Comparison::NotEqual.compare(
            &Default::default(), // tolerance
            Type::UInt8,
            Layout { std: LayoutStd::Std140, major: MajorAxis::Column },
            &[5],
            &[5],
        ));
        assert!(Comparison::Less.compare(
            &Default::default(), // tolerance
            Type::UInt8,
            Layout { std: LayoutStd::Std140, major: MajorAxis::Column },
            &[4],
            &[5],
        ));
        assert!(!Comparison::Less.compare(
            &Default::default(), // tolerance
            Type::UInt8,
            Layout { std: LayoutStd::Std140, major: MajorAxis::Column },
            &[5],
            &[5],
        ));
        assert!(!Comparison::Less.compare(
            &Default::default(), // tolerance
            Type::UInt8,
            Layout { std: LayoutStd::Std140, major: MajorAxis::Column },
            &[6],
            &[5],
        ));
        assert!(Comparison::LessEqual.compare(
            &Default::default(), // tolerance
            Type::UInt8,
            Layout { std: LayoutStd::Std140, major: MajorAxis::Column },
            &[4],
            &[5],
        ));
        assert!(Comparison::LessEqual.compare(
            &Default::default(), // tolerance
            Type::UInt8,
            Layout { std: LayoutStd::Std140, major: MajorAxis::Column },
            &[5],
            &[5],
        ));
        assert!(!Comparison::LessEqual.compare(
            &Default::default(), // tolerance
            Type::UInt8,
            Layout { std: LayoutStd::Std140, major: MajorAxis::Column },
            &[6],
            &[5],
        ));
        assert!(Comparison::Greater.compare(
            &Default::default(), // tolerance
            Type::UInt8,
            Layout { std: LayoutStd::Std140, major: MajorAxis::Column },
            &[5],
            &[4],
        ));
        assert!(!Comparison::Greater.compare(
            &Default::default(), // tolerance
            Type::UInt8,
            Layout { std: LayoutStd::Std140, major: MajorAxis::Column },
            &[5],
            &[5],
        ));
        assert!(!Comparison::Greater.compare(
            &Default::default(), // tolerance
            Type::UInt8,
            Layout { std: LayoutStd::Std140, major: MajorAxis::Column },
            &[5],
            &[6],
        ));
        assert!(Comparison::GreaterEqual.compare(
            &Default::default(), // tolerance
            Type::UInt8,
            Layout { std: LayoutStd::Std140, major: MajorAxis::Column },
            &[5],
            &[4],
        ));
        assert!(Comparison::GreaterEqual.compare(
            &Default::default(), // tolerance
            Type::UInt8,
            Layout { std: LayoutStd::Std140, major: MajorAxis::Column },
            &[5],
            &[5],
        ));
        assert!(!Comparison::GreaterEqual.compare(
            &Default::default(), // tolerance
            Type::UInt8,
            Layout { std: LayoutStd::Std140, major: MajorAxis::Column },
            &[5],
            &[6],
        ));
        assert!(Comparison::FuzzyEqual.compare(
            &Tolerance::new([1.0; 4], false),
            Type::Float,
            Layout { std: LayoutStd::Std140, major: MajorAxis::Column },
            &5.0f32.to_ne_bytes(),
            &5.9f32.to_ne_bytes(),
        ));
        assert!(!Comparison::FuzzyEqual.compare(
            &Tolerance::new([1.0; 4], false),
            Type::Float,
            Layout { std: LayoutStd::Std140, major: MajorAxis::Column },
            &5.0f32.to_ne_bytes(),
            &6.1f32.to_ne_bytes(),
        ));

        // Test all types

        macro_rules! test_compare_type_with_values {
            ($type_enum:expr, $low_value:expr, $high_value:expr) => {
                assert!(Comparison::Less.compare(
                    &Default::default(),
                    $type_enum,
                    Layout { std: LayoutStd::Std140, major: MajorAxis::Column },
                    &$low_value.to_ne_bytes(),
                    &$high_value.to_ne_bytes(),
                ));
            };
        }

        macro_rules! test_compare_simple_type {
            ($type_num:expr, $rust_type:ty) => {
                test_compare_type_with_values!(
                    $type_num,
                    0 as $rust_type,
                    <$rust_type>::MAX
                );
            };
        }

        test_compare_simple_type!(Type::Int, i32);
        test_compare_simple_type!(Type::UInt, u32);
        test_compare_simple_type!(Type::Int8, i8);
        test_compare_simple_type!(Type::UInt8, u8);
        test_compare_simple_type!(Type::Int16, i16);
        test_compare_simple_type!(Type::UInt16, u16);
        test_compare_simple_type!(Type::Int64, i64);
        test_compare_simple_type!(Type::UInt64, u64);
        test_compare_simple_type!(Type::Float, f32);
        test_compare_simple_type!(Type::Double, f64);
        test_compare_type_with_values!(
            Type::Float16,
            half_float::from_f32(-100.0),
            half_float::from_f32(300.0)
        );

        macro_rules! test_compare_big_type {
            ($type_enum:expr, $bytes:expr) => {
                let layout = Layout {
                    std: LayoutStd::Std140,
                    major: MajorAxis::Column
                };
                let mut bytes = Vec::new();
                assert_eq!($type_enum.size(layout) % $bytes.len(), 0);
                for _ in 0..($type_enum.size(layout) / $bytes.len()) {
                    bytes.extend_from_slice($bytes);
                }
                assert!(Comparison::FuzzyEqual.compare(
                    &Default::default(),
                    $type_enum,
                    layout,
                    &bytes,
                    &bytes,
                ));
            };
        }

        test_compare_big_type!(
            Type::F16Vec2,
            &half_float::from_f32(1.0).to_ne_bytes()
        );
        test_compare_big_type!(
            Type::F16Vec3,
            &half_float::from_f32(1.0).to_ne_bytes()
        );
        test_compare_big_type!(
            Type::F16Vec4,
            &half_float::from_f32(1.0).to_ne_bytes()
        );

        test_compare_big_type!(Type::Vec2, &42.0f32.to_ne_bytes());
        test_compare_big_type!(Type::Vec3, &24.0f32.to_ne_bytes());
        test_compare_big_type!(Type::Vec4, &6.0f32.to_ne_bytes());
        test_compare_big_type!(Type::DVec2, &42.0f64.to_ne_bytes());
        test_compare_big_type!(Type::DVec3, &24.0f64.to_ne_bytes());
        test_compare_big_type!(Type::DVec4, &6.0f64.to_ne_bytes());
        test_compare_big_type!(Type::IVec2, &(-42i32).to_ne_bytes());
        test_compare_big_type!(Type::IVec3, &(-24i32).to_ne_bytes());
        test_compare_big_type!(Type::IVec4, &(-6i32).to_ne_bytes());
        test_compare_big_type!(Type::UVec2, &42u32.to_ne_bytes());
        test_compare_big_type!(Type::UVec3, &24u32.to_ne_bytes());
        test_compare_big_type!(Type::UVec4, &6u32.to_ne_bytes());
        test_compare_big_type!(Type::I8Vec2, &(-42i8).to_ne_bytes());
        test_compare_big_type!(Type::I8Vec3, &(-24i8).to_ne_bytes());
        test_compare_big_type!(Type::I8Vec4, &(-6i8).to_ne_bytes());
        test_compare_big_type!(Type::U8Vec2, &42u8.to_ne_bytes());
        test_compare_big_type!(Type::U8Vec3, &24u8.to_ne_bytes());
        test_compare_big_type!(Type::U8Vec4, &6u8.to_ne_bytes());
        test_compare_big_type!(Type::I16Vec2, &(-42i16).to_ne_bytes());
        test_compare_big_type!(Type::I16Vec3, &(-24i16).to_ne_bytes());
        test_compare_big_type!(Type::I16Vec4, &(-6i16).to_ne_bytes());
        test_compare_big_type!(Type::U16Vec2, &42u16.to_ne_bytes());
        test_compare_big_type!(Type::U16Vec3, &24u16.to_ne_bytes());
        test_compare_big_type!(Type::U16Vec4, &6u16.to_ne_bytes());
        test_compare_big_type!(Type::I64Vec2, &(-42i64).to_ne_bytes());
        test_compare_big_type!(Type::I64Vec3, &(-24i64).to_ne_bytes());
        test_compare_big_type!(Type::I64Vec4, &(-6i64).to_ne_bytes());
        test_compare_big_type!(Type::U64Vec2, &42u64.to_ne_bytes());
        test_compare_big_type!(Type::U64Vec3, &24u64.to_ne_bytes());
        test_compare_big_type!(Type::U64Vec4, &6u64.to_ne_bytes());
        test_compare_big_type!(Type::Mat2, &42.0f32.to_ne_bytes());
        test_compare_big_type!(Type::Mat2x3, &42.0f32.to_ne_bytes());
        test_compare_big_type!(Type::Mat2x4, &42.0f32.to_ne_bytes());
        test_compare_big_type!(Type::Mat3, &24.0f32.to_ne_bytes());
        test_compare_big_type!(Type::Mat3x2, &24.0f32.to_ne_bytes());
        test_compare_big_type!(Type::Mat3x4, &24.0f32.to_ne_bytes());
        test_compare_big_type!(Type::Mat4, &6.0f32.to_ne_bytes());
        test_compare_big_type!(Type::Mat4x2, &6.0f32.to_ne_bytes());
        test_compare_big_type!(Type::Mat4x3, &6.0f32.to_ne_bytes());
        test_compare_big_type!(Type::DMat2, &42.0f64.to_ne_bytes());
        test_compare_big_type!(Type::DMat2x3, &42.0f64.to_ne_bytes());
        test_compare_big_type!(Type::DMat2x4, &42.0f64.to_ne_bytes());
        test_compare_big_type!(Type::DMat3, &24.0f64.to_ne_bytes());
        test_compare_big_type!(Type::DMat3x2, &24.0f64.to_ne_bytes());
        test_compare_big_type!(Type::DMat3x4, &24.0f64.to_ne_bytes());
        test_compare_big_type!(Type::DMat4, &6.0f64.to_ne_bytes());
        test_compare_big_type!(Type::DMat4x2, &6.0f64.to_ne_bytes());
        test_compare_big_type!(Type::DMat4x3, &6.0f64.to_ne_bytes());

        // Check that it the comparison fails if a value other than
        // the first one is different
        assert!(!Comparison::Equal.compare(
            &Default::default(), // tolerance
            Type::U8Vec4,
            Layout { std: LayoutStd::Std140, major: MajorAxis::Column },
            &[5, 6, 7, 8],
            &[5, 6, 7, 9],
        ));
    }

    #[test]
    fn test_from_glsl_type() {
        // A few spot checks, including one name that must not resolve.
        let spot_checks = [
            ("vec3", Some(Type::Vec3)),
            ("dvec4", Some(Type::DVec4)),
            ("i64vec4", Some(Type::I64Vec4)),
            ("uint", Some(Type::UInt)),
            ("uint16_t", Some(Type::UInt16)),
            ("dvec5", None),
        ];
        for (glsl_name, expected) in spot_checks {
            assert_eq!(Type::from_glsl_type(glsl_name), expected);
        }

        // Every entry of the name table must be reachable through the
        // binary search performed by from_glsl_type.
        GLSL_TYPE_NAMES
            .iter()
            .copied()
            .for_each(|(glsl_name, expected)| {
                assert_eq!(Type::from_glsl_type(glsl_name), Some(expected));
            });
    }

    #[test]
    fn test_comparison_from_operator() {
        // If a comparison is added then please add it to this test too
        assert_eq!(COMPARISON_NAMES.len(), 7);

        let known_operators = [
            ("!=", Comparison::NotEqual),
            ("<", Comparison::Less),
            ("<=", Comparison::LessEqual),
            ("==", Comparison::Equal),
            (">", Comparison::Greater),
            (">=", Comparison::GreaterEqual),
            ("~=", Comparison::FuzzyEqual),
        ];
        for (operator, comparison) in known_operators {
            assert_eq!(
                Comparison::from_operator(operator),
                Some(comparison),
                "operator {:?}",
                operator
            );
        }

        // Anything outside the table must be rejected.
        assert_eq!(Comparison::from_operator("<=>"), None);
    }

    /// Asserts that `bytes`, interpreted as a value of `base_type`, is
    /// rendered by `BaseTypeInSlice`'s `Display` impl exactly as `expected`.
    #[track_caller]
    fn check_base_type_format(
        base_type: BaseType,
        bytes: &[u8],
        expected: &str,
    ) {
        let rendered = BaseTypeInSlice::new(base_type, bytes).to_string();
        assert_eq!(rendered.as_str(), expected);
    }

    #[test]
    fn base_type_format() {
        // Each case: (type, native-endian bytes, expected Display output).
        let cases: Vec<(BaseType, Vec<u8>, &str)> = vec![
            (BaseType::Int, 1234i32.to_ne_bytes().to_vec(), "1234"),
            (BaseType::UInt, u32::MAX.to_ne_bytes().to_vec(), "4294967295"),
            (BaseType::Int8, (-4i8).to_ne_bytes().to_vec(), "-4"),
            (BaseType::UInt8, 129u8.to_ne_bytes().to_vec(), "129"),
            (BaseType::Int16, (-4i16).to_ne_bytes().to_vec(), "-4"),
            (BaseType::UInt16, 32769u16.to_ne_bytes().to_vec(), "32769"),
            (BaseType::Int64, (-4i64).to_ne_bytes().to_vec(), "-4"),
            (
                BaseType::UInt64,
                u64::MAX.to_ne_bytes().to_vec(),
                "18446744073709551615",
            ),
            // 0xc000 is the IEEE 754 half-precision bit pattern for -2.0.
            (BaseType::Float16, 0xc000u16.to_ne_bytes().to_vec(), "-2"),
            (BaseType::Float, 2.0f32.to_ne_bytes().to_vec(), "2"),
            (BaseType::Double, 2.0f64.to_ne_bytes().to_vec(), "2"),
        ];

        for (base_type, bytes, expected) in cases {
            check_base_type_format(base_type, &bytes, expected);
        }
    }
}

```

--------------------------------------------------------------------------------
/src/main.rs:
--------------------------------------------------------------------------------

```rust
use anyhow::Result;
use clap::Parser;
use image::codecs::pnm::PnmDecoder;
use image::{DynamicImage, ImageError, RgbImage};
use rmcp::{
    Error as McpError, RoleServer, ServerHandler, ServiceExt, const_string, model::*, schemars,
    service::RequestContext, tool, transport::stdio,
};
use serde_json::json;
use shaderc::{self, CompileOptions, Compiler, OptimizationLevel, ShaderKind};
use std::fs::File;
use std::io::BufReader;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use tokio::sync::Mutex;
use tracing_subscriber::{self, EnvFilter};

/// Opens the PNM-family image at `path`, decodes it, and converts the
/// result to an 8-bit RGB buffer.
///
/// # Errors
///
/// Returns an [`ImageError`] when the file cannot be opened or when its
/// contents cannot be decoded as a PNM image.
pub fn read_and_decode_ppm_file<P: AsRef<Path>>(path: P) -> Result<RgbImage, ImageError> {
    let decoder = PnmDecoder::new(BufReader::new(File::open(path)?))?;
    DynamicImage::from_decoder(decoder).map(DynamicImage::into_rgb8)
}

// Pipeline stage targeted by a single compile request.
//
// No serde rename attribute is applied, so each variant (de)serializes under
// its Rust name ("Vert", "Frag", "Tesc", "Tese", "Geom", "Comp").
// `compile_run_shaders` maps every variant onto the matching shaderc
// `ShaderKind`.
//
// Plain `//` comments are used on purpose: schemars turns `///` doc comments
// into JSON-schema descriptions, which would alter the advertised tool schema.
#[derive(Debug, serde::Serialize, serde::Deserialize, schemars::JsonSchema)]
pub enum ShaderStage {
    #[schemars(description = "Vertex processing stage (transforms vertices)")]
    Vert,
    #[schemars(description = "Fragment processing stage (determines pixel colors)")]
    Frag,
    #[schemars(description = "Tessellation control stage for patch control points")]
    Tesc,
    #[schemars(description = "Tessellation evaluation stage for computing tessellated geometry")]
    Tese,
    #[schemars(description = "Geometry stage for generating/modifying primitives")]
    Geom,
    #[schemars(description = "Compute stage for general-purpose computation")]
    Comp,
}

// Device features/requirements a compile-and-run request may declare
// (`CompileRunShadersRequest::requirements`). The schemars descriptions below
// are the text MCP clients see in the tool schema.
//
// NOTE(review): how these variants are rendered into the vkrunner script is
// outside this view; confirm the exact requirement spellings there.
//
// `//` (not `///`) comments: schemars would fold doc comments into the
// generated JSON schema.
#[derive(Debug, serde::Serialize, serde::Deserialize, schemars::JsonSchema)]
pub enum ShaderRunnerRequire {
    // The fields carry no schemars descriptions of their own;
    // `component_type` is presumably a type name string — TODO confirm the
    // accepted spellings.
    #[schemars(
        description = "Enables cooperative matrix operations with specified dimensions and data type"
    )]
    CooperativeMatrix {
        m: u32,
        n: u32,
        component_type: String,
    },

    // Payload is presumably a Vulkan format name — confirm against the renderer.
    #[schemars(
        description = "Specifies depth/stencil format for depth testing and stencil operations"
    )]
    DepthStencil(String),

    // Payload is presumably a Vulkan format name — confirm against the renderer.
    #[schemars(description = "Specifies framebuffer format for render target output")]
    Framebuffer(String),

    #[schemars(description = "Enables double-precision floating point operations in shaders")]
    ShaderFloat64,

    #[schemars(description = "Enables geometry shader stage for primitive manipulation")]
    GeometryShader,

    #[schemars(description = "Enables lines with width greater than 1.0 pixel")]
    WideLines,

    #[schemars(description = "Enables bitwise logical operations on framebuffer contents")]
    LogicOp,

    #[schemars(description = "Specifies required subgroup size for shader execution")]
    SubgroupSize(u32),

    #[schemars(description = "Enables memory stores and atomic operations in fragment shaders")]
    FragmentStoresAndAtomics,

    #[schemars(description = "Enables shaders to use raw buffer addresses")]
    BufferDeviceAddress,
}

// One entry of the shader pipeline in a compile-and-run request
// (`CompileRunShadersRequest::passes`). Apart from `VertPassthrough`, each
// variant references an `.spvasm` file by path; per the `compile_run_shaders`
// tool description, that path should be the output of one of the request's
// compile jobs.
//
// `//` (not `///`) comments: schemars would fold doc comments into the
// generated JSON schema.
#[derive(Debug, serde::Serialize, serde::Deserialize, schemars::JsonSchema)]
pub enum ShaderRunnerPass {
    #[schemars(
        description = "Use a built-in pass-through vertex shader (forwards attributes to fragment shader)"
    )]
    VertPassthrough,

    #[schemars(description = "Use compiled vertex shader from specified SPIR-V assembly file")]
    VertSpirv {
        #[schemars(
            description = "Path to the compiled vertex shader SPIR-V assembly (.spvasm) file"
        )]
        vert_spvasm_path: String,
    },

    #[schemars(description = "Use compiled fragment shader from specified SPIR-V assembly file")]
    FragSpirv {
        #[schemars(
            description = "Path to the compiled fragment shader SPIR-V assembly (.spvasm) file"
        )]
        frag_spvasm_path: String,
    },

    #[schemars(description = "Use compiled compute shader from specified SPIR-V assembly file")]
    CompSpirv {
        #[schemars(
            description = "Path to the compiled compute shader SPIR-V assembly (.spvasm) file"
        )]
        comp_spvasm_path: String,
    },

    #[schemars(description = "Use compiled geometry shader from specified SPIR-V assembly file")]
    GeomSpirv {
        #[schemars(
            description = "Path to the compiled geometry shader SPIR-V assembly (.spvasm) file"
        )]
        geom_spvasm_path: String,
    },

    #[schemars(
        description = "Use compiled tessellation control shader from specified SPIR-V assembly file"
    )]
    TescSpirv {
        #[schemars(
            description = "Path to the compiled tessellation control shader SPIR-V assembly (.spvasm) file"
        )]
        tesc_spvasm_path: String,
    },

    #[schemars(
        description = "Use compiled tessellation evaluation shader from specified SPIR-V assembly file"
    )]
    TeseSpirv {
        #[schemars(
            description = "Path to the compiled tessellation evaluation shader SPIR-V assembly (.spvasm) file"
        )]
        tese_spvasm_path: String,
    },
}

// Vertex-data declarations and rows for a compile-and-run request
// (`CompileRunShadersRequest::vertex_data`). `AttributeFormat` declares an
// attribute; the remaining variants supply component values in the layout
// named by their descriptions.
//
// `//` (not `///`) comments: schemars would fold doc comments into the
// generated JSON schema, and the `description` attributes below are the
// intended client-facing text.
#[derive(Debug, serde::Serialize, serde::Deserialize, schemars::JsonSchema)]
pub enum ShaderRunnerVertexData {
    #[schemars(description = "Defines an attribute format at a shader location/binding")]
    AttributeFormat {
        #[schemars(description = "Location/binding number in the shader")]
        location: u32,

        #[schemars(description = "Format name (e.g., R32G32_SFLOAT, A8B8G8R8_UNORM_PACK32)")]
        format: String,
    },

    #[schemars(description = "2D position data (for R32G32_SFLOAT format)")]
    Vec2 {
        #[schemars(description = "X coordinate value")]
        x: f32,

        #[schemars(description = "Y coordinate value")]
        y: f32,
    },

    #[schemars(description = "3D position data (for R32G32B32_SFLOAT format)")]
    Vec3 {
        #[schemars(description = "X coordinate value")]
        x: f32,

        #[schemars(description = "Y coordinate value")]
        y: f32,

        #[schemars(description = "Z coordinate value")]
        z: f32,
    },

    #[schemars(description = "4D position/color data (for R32G32B32A32_SFLOAT format)")]
    Vec4 {
        #[schemars(description = "X coordinate or Red component")]
        x: f32,

        #[schemars(description = "Y coordinate or Green component")]
        y: f32,

        #[schemars(description = "Z coordinate or Blue component")]
        z: f32,

        #[schemars(description = "W coordinate or Alpha component")]
        w: f32,
    },

    #[schemars(description = "RGB color data (for R8G8B8_UNORM format)")]
    RGB {
        #[schemars(description = "Red component (0-255)")]
        r: u8,

        #[schemars(description = "Green component (0-255)")]
        g: u8,

        #[schemars(description = "Blue component (0-255)")]
        b: u8,
    },

    // VK_FORMAT_A8B8G8R8_UNORM_PACK32 stores A in bits 24..31, B in 16..23,
    // G in 8..15 and R in 0..7, so the packed value written as hex reads
    // 0xAABBGGRR (e.g. opaque red is 0xff0000ff). Describing it as
    // 0xAARRGGBB would make clients swap the red and blue channels.
    #[schemars(
        description = "Packed RGBA color as hex (for A8B8G8R8_UNORM_PACK32 format)"
    )]
    Hex {
        #[schemars(
            description = "Color as hex string (0xAABBGGRR format: alpha in the high byte, red in the low byte)"
        )]
        value: String,
    },

    #[schemars(description = "Generic data components for custom formats")]
    GenericComponents {
        #[schemars(description = "Component values as strings, interpreted by format")]
        components: Vec<String>,
    },
}

// Test-script commands for a compile-and-run request
// (`CompileRunShadersRequest::tests`): pipeline state, buffer setup, draw and
// dispatch commands, and probes that verify results.
//
// `//` (not `///`) comments: schemars would fold doc comments into the
// generated JSON schema, and the `description` attributes below are the
// intended client-facing text.
#[derive(Debug, serde::Serialize, serde::Deserialize, schemars::JsonSchema)]
pub enum ShaderRunnerTest {
    #[schemars(description = "Set fragment shader entrypoint function name")]
    FragmentEntrypoint {
        #[schemars(description = "Function name to use as entrypoint")]
        name: String,
    },

    #[schemars(description = "Set vertex shader entrypoint function name")]
    VertexEntrypoint {
        #[schemars(description = "Function name to use as entrypoint")]
        name: String,
    },

    #[schemars(description = "Set compute shader entrypoint function name")]
    ComputeEntrypoint {
        #[schemars(description = "Function name to use as entrypoint")]
        name: String,
    },

    #[schemars(description = "Set geometry shader entrypoint function name")]
    GeometryEntrypoint {
        #[schemars(description = "Function name to use as entrypoint")]
        name: String,
    },

    // NOTE(review): "lower-left" follows OpenGL conventions; Vulkan NDC has
    // +Y pointing down, so confirm which screen corner (x, y) really is.
    #[schemars(description = "Draw a rectangle using normalized device coordinates")]
    DrawRect {
        #[schemars(description = "X coordinate of lower-left corner")]
        x: f32,

        #[schemars(description = "Y coordinate of lower-left corner")]
        y: f32,

        #[schemars(description = "Width of the rectangle")]
        width: f32,

        #[schemars(description = "Height of the rectangle")]
        height: f32,
    },

    #[schemars(description = "Draw primitives using vertex data")]
    DrawArrays {
        #[schemars(description = "Primitive type (TRIANGLE_LIST, POINT_LIST, etc.)")]
        primitive_type: String,

        #[schemars(description = "Index of first vertex")]
        first: u32,

        #[schemars(description = "Number of vertices to draw")]
        count: u32,
    },

    #[schemars(description = "Draw primitives using indexed vertex data")]
    DrawArraysIndexed {
        #[schemars(description = "Primitive type (TRIANGLE_LIST, etc.)")]
        primitive_type: String,

        #[schemars(description = "Index of first index")]
        first: u32,

        #[schemars(description = "Number of indices to use")]
        count: u32,
    },

    #[schemars(description = "Create or initialize a Shader Storage Buffer Object (SSBO)")]
    SSBO {
        #[schemars(description = "Binding point in the shader")]
        binding: u32,

        #[schemars(description = "Size in bytes (if data not provided)")]
        size: Option<u32>,

        #[schemars(description = "Initial buffer contents")]
        data: Option<Vec<u8>>,

        #[schemars(description = "Descriptor set number (default: 0)")]
        descriptor_set: Option<u32>,
    },

    #[schemars(description = "Update a portion of an SSBO with new data")]
    SSBOSubData {
        #[schemars(description = "Binding point in the shader")]
        binding: u32,

        #[schemars(description = "Data type (float, vec4, etc.)")]
        data_type: String,

        #[schemars(description = "Byte offset into the buffer")]
        offset: u32,

        #[schemars(description = "Values to write (as strings)")]
        values: Vec<String>,

        #[schemars(description = "Descriptor set number (default: 0)")]
        descriptor_set: Option<u32>,
    },

    #[schemars(description = "Create or initialize a Uniform Buffer Object (UBO)")]
    UBO {
        #[schemars(description = "Binding point in the shader")]
        binding: u32,

        #[schemars(description = "Buffer contents")]
        data: Vec<u8>,

        #[schemars(description = "Descriptor set number (default: 0)")]
        descriptor_set: Option<u32>,
    },

    #[schemars(description = "Update a portion of a UBO with new data")]
    UBOSubData {
        #[schemars(description = "Binding point in the shader")]
        binding: u32,

        #[schemars(description = "Data type (float, vec4, etc.)")]
        data_type: String,

        #[schemars(description = "Byte offset into the buffer")]
        offset: u32,

        #[schemars(description = "Values to write (as strings)")]
        values: Vec<String>,

        #[schemars(description = "Descriptor set number (default: 0)")]
        descriptor_set: Option<u32>,
    },

    #[schemars(description = "Set memory layout for buffer data")]
    BufferLayout {
        #[schemars(description = "Buffer type (ubo, ssbo)")]
        buffer_type: String,

        #[schemars(description = "Layout specification (std140, std430, row_major, column_major)")]
        layout_type: String,
    },

    #[schemars(description = "Set push constant values")]
    Push {
        #[schemars(description = "Data type (float, vec4, etc.)")]
        data_type: String,

        #[schemars(description = "Byte offset into push constant block")]
        offset: u32,

        #[schemars(description = "Values to write (as strings)")]
        values: Vec<String>,
    },

    #[schemars(description = "Set memory layout for push constants")]
    PushLayout {
        #[schemars(description = "Layout specification (std140, std430)")]
        layout_type: String,
    },

    #[schemars(description = "Execute compute shader with specified workgroup counts")]
    Compute {
        #[schemars(description = "Number of workgroups in X dimension")]
        x: u32,

        #[schemars(description = "Number of workgroups in Y dimension")]
        y: u32,

        #[schemars(description = "Number of workgroups in Z dimension")]
        z: u32,
    },

    #[schemars(description = "Verify framebuffer or buffer contents match expected values")]
    Probe {
        #[schemars(description = "Probe type (all, rect, ssbo, etc.)")]
        probe_type: String,

        #[schemars(description = "Component format (rgba, rgb, etc.)")]
        format: String,

        #[schemars(description = "Parameters (coordinates, expected values)")]
        args: Vec<String>,
    },

    #[schemars(description = "Verify contents using normalized (0-1) coordinates")]
    RelativeProbe {
        #[schemars(description = "Probe type (rect, etc.)")]
        probe_type: String,

        #[schemars(description = "Component format (rgba, rgb, etc.)")]
        format: String,

        #[schemars(description = "Parameters (coordinates, expected values)")]
        args: Vec<String>,
    },

    // `values` is numeric (`Vec<f32>`), so a percentage written with a "%"
    // suffix cannot be sent through it. The description must not invite
    // one, or clients produce payloads that fail to deserialize.
    #[schemars(description = "Set acceptable error margin for value comparisons")]
    Tolerance {
        #[schemars(
            description = "Absolute error margins as plain numbers (a '%' suffix cannot be expressed in this numeric field)"
        )]
        values: Vec<f32>,
    },

    #[schemars(description = "Clear the framebuffer to default values")]
    Clear,

    #[schemars(description = "Enable/disable depth testing")]
    DepthTestEnable {
        #[schemars(description = "True to enable depth testing")]
        enable: bool,
    },

    #[schemars(description = "Enable/disable writing to depth buffer")]
    DepthWriteEnable {
        #[schemars(description = "True to enable depth writes")]
        enable: bool,
    },

    #[schemars(description = "Set depth comparison function")]
    DepthCompareOp {
        #[schemars(description = "Function name (VK_COMPARE_OP_LESS, etc.)")]
        op: String,
    },

    #[schemars(description = "Enable/disable stencil testing")]
    StencilTestEnable {
        #[schemars(description = "True to enable stencil testing")]
        enable: bool,
    },

    #[schemars(description = "Define which winding order is front-facing")]
    FrontFace {
        #[schemars(description = "Mode (VK_FRONT_FACE_CLOCKWISE, etc.)")]
        mode: String,
    },

    #[schemars(description = "Configure stencil operation for a face")]
    StencilOp {
        #[schemars(description = "Face (front, back)")]
        face: String,

        #[schemars(description = "Operation name (passOp, failOp, etc.)")]
        op_name: String,

        #[schemars(description = "Value (VK_STENCIL_OP_REPLACE, etc.)")]
        value: String,
    },

    #[schemars(description = "Set reference value for stencil comparisons")]
    StencilReference {
        #[schemars(description = "Face (front, back)")]
        face: String,

        #[schemars(description = "Reference value")]
        value: u32,
    },

    #[schemars(description = "Set stencil comparison function")]
    StencilCompareOp {
        #[schemars(description = "Face (front, back)")]
        face: String,

        #[schemars(description = "Function (VK_COMPARE_OP_EQUAL, etc.)")]
        op: String,
    },

    #[schemars(description = "Control which color channels can be written")]
    ColorWriteMask {
        #[schemars(
            description = "Bit flags (VK_COLOR_COMPONENT_R_BIT | VK_COLOR_COMPONENT_G_BIT, etc.)"
        )]
        mask: String,
    },

    #[schemars(description = "Enable/disable logical operations on colors")]
    LogicOpEnable {
        #[schemars(description = "True to enable logic operations")]
        enable: bool,
    },

    #[schemars(description = "Set type of logical operation on colors")]
    LogicOp {
        #[schemars(description = "Operation (VK_LOGIC_OP_XOR, etc.)")]
        op: String,
    },

    #[schemars(description = "Set face culling mode")]
    CullMode {
        #[schemars(description = "Mode (VK_CULL_MODE_BACK_BIT, etc.)")]
        mode: String,
    },

    #[schemars(description = "Set width for line primitives")]
    LineWidth {
        #[schemars(description = "Width in pixels")]
        width: f32,
    },

    #[schemars(description = "Specify a feature required by the test")]
    Require {
        #[schemars(description = "Feature name (subgroup_size, depthstencil, etc.)")]
        feature: String,

        #[schemars(description = "Feature parameters")]
        parameters: Vec<String>,
    },
}

// A single GLSL -> SPIR-V assembly compile job.
//
// `stage` deserializes from the `ShaderStage` variant names exactly as
// written ("Vert", "Frag", ...) because that enum has no serde rename. The
// description therefore lists those spellings; lowercase names would be
// rejected.
#[derive(Debug, serde::Serialize, serde::Deserialize, schemars::JsonSchema)]
pub struct CompileRequest {
    #[schemars(description = "The shader stage to compile (Vert, Frag, Comp, Geom, Tesc, Tese)")]
    pub stage: ShaderStage,
    #[schemars(description = "GLSL shader source code to compile")]
    pub source: String,
    // NOTE(review): `compile_run_shaders` checks whether this path starts
    // with "/tmp"; see that handler for how other paths are treated.
    #[schemars(description = "Path where compiled SPIR-V assembly (.spvasm) will be saved")]
    pub tmp_output_path: String,
}
// A batch of compile-only requests, presumably the argument of a compile-only
// tool defined outside this view. It derives only `Deserialize`, unlike
// `CompileRequest`.
// `//` (not `///`) comments: schemars would fold doc comments into the schema.
#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
pub struct CompileShadersRequest {
    #[schemars(description = "List of shader compile requests (produces SPIR-V assemblies)")]
    pub requests: Vec<CompileRequest>,
}

// Aggregated argument of the `compile_run_shaders` tool. The tool compiles
// every entry of `requests`, then runs a pipeline made of `passes`
// (referencing the compiled files by path), optional `requirements` and
// `vertex_data`, and the `tests` commands. It can optionally save the
// rendered image to `output_path`.
// `//` (not `///`) comments: schemars would fold doc comments into the tool's
// input schema.
#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
pub struct CompileRunShadersRequest {
    #[schemars(
        description = "List of shader compile requests - each produces a SPIR-V assembly file"
    )]
    pub requests: Vec<CompileRequest>,
    #[schemars(description = "Optional hardware/feature requirements needed for shader execution")]
    pub requirements: Option<Vec<ShaderRunnerRequire>>,
    #[schemars(
        description = "Shader pipeline configuration (references compiled SPIR-V files by path)"
    )]
    pub passes: Vec<ShaderRunnerPass>,
    #[schemars(description = "Optional vertex data for rendering geometry")]
    pub vertex_data: Option<Vec<ShaderRunnerVertexData>>,
    #[schemars(description = "Test commands to execute (drawing, compute, verification, etc.)")]
    pub tests: Vec<ShaderRunnerTest>,
    #[schemars(description = "Optional path to save output image (PNG format)")]
    pub output_path: Option<String>,
}
/// Stateless MCP server handler; its tools (e.g. `compile_run_shaders`) are
/// defined in the `#[tool(tool_box)]` impl block.
#[derive(Clone)]
pub struct ShadercVkrunnerMcp {}
#[tool(tool_box)]
impl ShadercVkrunnerMcp {
    #[must_use]
    pub fn new() -> Self {
        Self {}
    }

    #[tool(
        description = "REQUIRES: (1) Shader compile requests that produce SPIR-V assembly files, (2) Passes referencing these files by path, and (3) Test commands for drawing/computation. Workflow: First compile GLSL to SPIR-V, then reference compiled files in passes, then execute drawing commands in tests to render/compute. Tests MUST include draw/compute commands to produce visible output. Optional: requirements for hardware features, vertex data for geometry, output path for saving image. Every shader MUST have its compiled output referenced in passes."
    )]
    fn compile_run_shaders(
        &self,
        #[tool(aggr)] request: CompileRunShadersRequest,
    ) -> Result<CallToolResult, McpError> {
        use std::fs::File;
        use std::io::{Read, Write};
        use std::path::Path;
        use std::process::{Command, Stdio}; // Still needed for vkrunner

        fn io_err(e: std::io::Error) -> McpError {
            McpError::internal_error("IO operation failed", Some(json!({"error": e.to_string()})))
        }

        for req in &request.requests {
            let shader_kind = match req.stage {
                ShaderStage::Vert => ShaderKind::Vertex,
                ShaderStage::Frag => ShaderKind::Fragment,
                ShaderStage::Tesc => ShaderKind::TessControl,
                ShaderStage::Tese => ShaderKind::TessEvaluation,
                ShaderStage::Geom => ShaderKind::Geometry,
                ShaderStage::Comp => ShaderKind::Compute,
            };

            let stage_flag = match req.stage {
                ShaderStage::Vert => "vert",
                ShaderStage::Frag => "frag",
                ShaderStage::Tesc => "tesc",
                ShaderStage::Tese => "tese",
                ShaderStage::Geom => "geom",
                ShaderStage::Comp => "comp",
            };

            let tmp_output_path = if req.tmp_output_path.starts_with("/tmp") {
                req.tmp_output_path.clone()
            } else {
                format!("/tmp/{}", req.tmp_output_path)
            };

            if let Some(parent) = Path::new(&tmp_output_path).parent() {
                std::fs::create_dir_all(parent).map_err(|e| {
                    McpError::internal_error(
                        "Failed to create temporary output directory",
                        Some(json!({"error": e.to_string()})),
                    )
                })?;
            }

            // Create compiler and options
            let mut compiler = Compiler::new().ok().ok_or_else(|| {
                McpError::internal_error(
                    "Failed to create shaderc compiler",
                    Some(json!({"error": "Could not instantiate shaderc compiler"})),
                )
            })?;

            let mut options = CompileOptions::new().ok().ok_or_else(|| {
                McpError::internal_error(
                    "Failed to create shaderc compile options",
                    Some(json!({"error": "Could not create compiler options"})),
                )
            })?;

            // Set options equivalent to the CLI flags
            options.set_target_env(
                shaderc::TargetEnv::Vulkan,
                shaderc::EnvVersion::Vulkan1_4 as u32,
            );
            options.set_optimization_level(OptimizationLevel::Performance);
            options.set_generate_debug_info();

            // Compile to SPIR-V assembly
            let artifact = match compiler.compile_into_spirv_assembly(
                &req.source,
                shader_kind,
                "shader.glsl", // source name for error reporting
                "main",        // entry point
                Some(&options),
            ) {
                Ok(artifact) => artifact,
                Err(e) => {
                    let stage_name = match req.stage {
                        ShaderStage::Vert => "Vertex",
                        ShaderStage::Frag => "Fragment",
                        ShaderStage::Tesc => "Tessellation Control",
                        ShaderStage::Tese => "Tessellation Evaluation",
                        ShaderStage::Geom => "Geometry",
                        ShaderStage::Comp => "Compute",
                    };

                    let error_details = format!("{}", e);

                    return Ok(CallToolResult::success(vec![Content::text(format!(
                        "Shader compilation failed for {} shader:\n\nError:\n{}\n\nShader Source:\n{}\n",
                        stage_name, error_details, req.source
                    ))]));
                }
            };

            // Write the compiled SPIR-V assembly to the output file, filtering unsupported lines
            let spv_text = artifact.as_text();
            let filtered_spv = spv_text
                .lines()
                .filter(|l| !l.trim_start().starts_with("OpModuleProcessed"))
                .collect::<Vec<_>>()
                .join("\n");
            std::fs::write(&tmp_output_path, filtered_spv).map_err(|e| {
                McpError::internal_error(
                    "Failed to write compiled shader to file",
                    Some(json!({"error": e.to_string()})),
                )
            })?;
        }

        let shader_test_path = "/tmp/vkrunner_test.shader_test";
        let mut shader_test_file = File::create(shader_test_path).map_err(io_err)?;

        if let Some(requirements) = &request.requirements {
            if !requirements.is_empty() {
                writeln!(shader_test_file, "[require]").map_err(io_err)?;

                for req in requirements {
                    match req {
                        ShaderRunnerRequire::CooperativeMatrix {
                            m,
                            n,
                            component_type,
                        } => {
                            writeln!(
                                shader_test_file,
                                "cooperative_matrix m={m} n={n} c={component_type}"
                            )
                            .map_err(io_err)?;
                        }
                        ShaderRunnerRequire::DepthStencil(format) => {
                            writeln!(shader_test_file, "depthstencil {format}").map_err(io_err)?;
                        }
                        ShaderRunnerRequire::Framebuffer(format) => {
                            writeln!(shader_test_file, "framebuffer {format}").map_err(io_err)?;
                        }
                        ShaderRunnerRequire::ShaderFloat64 => {
                            writeln!(shader_test_file, "shaderFloat64").map_err(io_err)?;
                        }
                        ShaderRunnerRequire::GeometryShader => {
                            writeln!(shader_test_file, "geometryShader").map_err(io_err)?;
                        }
                        ShaderRunnerRequire::WideLines => {
                            writeln!(shader_test_file, "wideLines").map_err(io_err)?;
                        }
                        ShaderRunnerRequire::LogicOp => {
                            writeln!(shader_test_file, "logicOp").map_err(io_err)?;
                        }
                        ShaderRunnerRequire::SubgroupSize(size) => {
                            writeln!(shader_test_file, "subgroup_size {size}").map_err(io_err)?;
                        }
                        ShaderRunnerRequire::FragmentStoresAndAtomics => {
                            writeln!(shader_test_file, "fragmentStoresAndAtomics")
                                .map_err(io_err)?;
                        }
                        ShaderRunnerRequire::BufferDeviceAddress => {
                            writeln!(shader_test_file, "bufferDeviceAddress").map_err(io_err)?;
                        }
                    }
                }

                writeln!(shader_test_file).map_err(io_err)?;
            }
        }

        for pass in &request.passes {
            match pass {
                ShaderRunnerPass::VertPassthrough => {
                    writeln!(shader_test_file, "[vertex shader passthrough]").map_err(io_err)?;
                }
                ShaderRunnerPass::VertSpirv { vert_spvasm_path } => {
                    writeln!(shader_test_file, "[vertex shader spirv]").map_err(io_err)?;

                    let mut spvasm = String::new();
                    let path = if vert_spvasm_path.starts_with("/tmp") {
                        vert_spvasm_path.clone()
                    } else {
                        format!("/tmp/{vert_spvasm_path}")
                    };

                    File::open(&path)
                        .map_err(|e| {
                            McpError::internal_error(
                                format!("Failed to open vertex shader SPIR-V file at {path}"),
                                Some(json!({"error": e.to_string()})),
                            )
                        })?
                        .read_to_string(&mut spvasm)
                        .map_err(|e| {
                            McpError::internal_error(
                                "Failed to read vertex shader SPIR-V file",
                                Some(json!({"error": e.to_string()})),
                            )
                        })?;
                    writeln!(shader_test_file, "{spvasm}").map_err(io_err)?;
                }
                ShaderRunnerPass::FragSpirv { frag_spvasm_path } => {
                    writeln!(shader_test_file, "[fragment shader spirv]").map_err(io_err)?;

                    let mut spvasm = String::new();
                    let path = if frag_spvasm_path.starts_with("/tmp") {
                        frag_spvasm_path.clone()
                    } else {
                        format!("/tmp/{frag_spvasm_path}")
                    };

                    File::open(&path)
                        .map_err(|e| {
                            McpError::internal_error(
                                format!("Failed to open fragment shader SPIR-V file at {path}"),
                                Some(json!({"error": e.to_string()})),
                            )
                        })?
                        .read_to_string(&mut spvasm)
                        .map_err(|e| {
                            McpError::internal_error(
                                "Failed to read fragment shader SPIR-V file",
                                Some(json!({"error": e.to_string()})),
                            )
                        })?;
                    writeln!(shader_test_file, "{spvasm}").map_err(io_err)?;
                }
                ShaderRunnerPass::CompSpirv { comp_spvasm_path } => {
                    writeln!(shader_test_file, "[compute shader spirv]").map_err(io_err)?;

                    let mut spvasm = String::new();
                    let path = if comp_spvasm_path.starts_with("/tmp") {
                        comp_spvasm_path.clone()
                    } else {
                        format!("/tmp/{comp_spvasm_path}")
                    };

                    File::open(&path)
                        .map_err(|e| {
                            McpError::internal_error(
                                format!("Failed to open compute shader SPIR-V file at {path}"),
                                Some(json!({"error": e.to_string()})),
                            )
                        })?
                        .read_to_string(&mut spvasm)
                        .map_err(|e| {
                            McpError::internal_error(
                                "Failed to read compute shader SPIR-V file",
                                Some(json!({"error": e.to_string()})),
                            )
                        })?;
                    writeln!(shader_test_file, "{spvasm}").map_err(io_err)?;
                }
                ShaderRunnerPass::GeomSpirv { geom_spvasm_path } => {
                    writeln!(shader_test_file, "[geometry shader spirv]").map_err(io_err)?;

                    let mut spvasm = String::new();
                    let path = if geom_spvasm_path.starts_with("/tmp") {
                        geom_spvasm_path.clone()
                    } else {
                        format!("/tmp/{geom_spvasm_path}")
                    };

                    File::open(&path)
                        .map_err(|e| {
                            McpError::internal_error(
                                format!("Failed to open geometry shader SPIR-V file at {path}"),
                                Some(json!({"error": e.to_string()})),
                            )
                        })?
                        .read_to_string(&mut spvasm)
                        .map_err(|e| {
                            McpError::internal_error(
                                "Failed to read geometry shader SPIR-V file",
                                Some(json!({"error": e.to_string()})),
                            )
                        })?;
                    writeln!(shader_test_file, "{spvasm}").map_err(io_err)?;
                }
                ShaderRunnerPass::TescSpirv { tesc_spvasm_path } => {
                    writeln!(shader_test_file, "[tessellation control shader spirv]")
                        .map_err(io_err)?;

                    let mut spvasm = String::new();
                    let path = if tesc_spvasm_path.starts_with("/tmp") {
                        tesc_spvasm_path.clone()
                    } else {
                        format!("/tmp/{tesc_spvasm_path}")
                    };

                    File::open(&path)
                        .map_err(|e| {
                            McpError::internal_error(
                                format!(
                                    "Failed to open tessellation control shader SPIR-V file at {path}"
                                ),
                                Some(json!({"error": e.to_string()})),
                            )
                        })?
                        .read_to_string(&mut spvasm)
                        .map_err(|e| {
                            McpError::internal_error(
                                "Failed to read tessellation control shader SPIR-V file",
                                Some(json!({"error": e.to_string()})),
                            )
                        })?;
                    writeln!(shader_test_file, "{spvasm}").map_err(io_err)?;
                }
                ShaderRunnerPass::TeseSpirv { tese_spvasm_path } => {
                    writeln!(shader_test_file, "[tessellation evaluation shader spirv]")
                        .map_err(io_err)?;

                    let mut spvasm = String::new();
                    let path = if tese_spvasm_path.starts_with("/tmp") {
                        tese_spvasm_path.clone()
                    } else {
                        format!("/tmp/{tese_spvasm_path}")
                    };

                    File::open(&path)
                        .map_err(|e| {
                            McpError::internal_error(
                                format!(
                                    "Failed to open tessellation evaluation shader SPIR-V file at {path}"
                                ),
                                Some(json!({"error": e.to_string()})),
                            )
                        })?
                        .read_to_string(&mut spvasm)
                        .map_err(|e| {
                            McpError::internal_error(
                                "Failed to read tessellation evaluation shader SPIR-V file",
                                Some(json!({"error": e.to_string()})),
                            )
                        })?;
                    writeln!(shader_test_file, "{spvasm}").map_err(io_err)?;
                }
            }

            writeln!(shader_test_file).map_err(io_err)?;
        }

        if let Some(vertex_data) = &request.vertex_data {
            writeln!(shader_test_file, "[vertex data]").map_err(io_err)?;

            for data in vertex_data {
                if let ShaderRunnerVertexData::AttributeFormat { location, format } = data {
                    writeln!(shader_test_file, "{location}/{format}").map_err(io_err)?;
                } else {
                    match data {
                        ShaderRunnerVertexData::Vec2 { x, y } => {
                            write!(shader_test_file, "{x} {y}").map_err(io_err)?;
                        }
                        ShaderRunnerVertexData::Vec3 { x, y, z } => {
                            write!(shader_test_file, "{x} {y} {z}").map_err(io_err)?;
                        }
                        ShaderRunnerVertexData::Vec4 { x, y, z, w } => {
                            write!(shader_test_file, "{x} {y} {z} {w}").map_err(io_err)?;
                        }
                        ShaderRunnerVertexData::RGB { r, g, b } => {
                            write!(shader_test_file, "{r} {g} {b}").map_err(io_err)?;
                        }
                        ShaderRunnerVertexData::Hex { value } => {
                            write!(shader_test_file, "{value}").map_err(io_err)?;
                        }
                        ShaderRunnerVertexData::GenericComponents { components } => {
                            for component in components {
                                write!(shader_test_file, "{component} ").map_err(io_err)?;
                            }
                        }
                        _ => {}
                    }
                    writeln!(shader_test_file).map_err(io_err)?;
                }
            }

            writeln!(shader_test_file).map_err(io_err)?;
        }

        writeln!(shader_test_file, "[test]").map_err(io_err)?;

        for test_cmd in &request.tests {
            match test_cmd {
                ShaderRunnerTest::FragmentEntrypoint { name } => {
                    writeln!(shader_test_file, "fragment entrypoint {name}").map_err(io_err)?;
                }
                ShaderRunnerTest::VertexEntrypoint { name } => {
                    writeln!(shader_test_file, "vertex entrypoint {name}").map_err(io_err)?;
                }
                ShaderRunnerTest::ComputeEntrypoint { name } => {
                    writeln!(shader_test_file, "compute entrypoint {name}").map_err(io_err)?;
                }
                ShaderRunnerTest::GeometryEntrypoint { name } => {
                    writeln!(shader_test_file, "geometry entrypoint {name}").map_err(io_err)?;
                }
                ShaderRunnerTest::DrawRect {
                    x,
                    y,
                    width,
                    height,
                } => {
                    writeln!(shader_test_file, "draw rect {x} {y} {width} {height}")
                        .map_err(io_err)?;
                }
                ShaderRunnerTest::DrawArrays {
                    primitive_type,
                    first,
                    count,
                } => {
                    writeln!(
                        shader_test_file,
                        "draw arrays {primitive_type} {first} {count}"
                    )
                    .map_err(io_err)?;
                }
                ShaderRunnerTest::DrawArraysIndexed {
                    primitive_type,
                    first,
                    count,
                } => {
                    writeln!(
                        shader_test_file,
                        "draw arrays indexed {primitive_type} {first} {count}"
                    )
                    .map_err(io_err)?;
                }
                ShaderRunnerTest::SSBO {
                    binding,
                    size,
                    data,
                    descriptor_set,
                } => {
                    let set_prefix = if let Some(set) = descriptor_set {
                        format!("{set}:")
                    } else {
                        String::new()
                    };

                    if let Some(size) = size {
                        writeln!(shader_test_file, "ssbo {set_prefix}{binding} {size}")
                            .map_err(io_err)?;
                    } else if let Some(_data) = data {
                        writeln!(shader_test_file, "ssbo {set_prefix}{binding} data")
                            .map_err(io_err)?;
                    }
                }
                ShaderRunnerTest::SSBOSubData {
                    binding,
                    data_type,
                    offset,
                    values,
                    descriptor_set,
                } => {
                    let set_prefix = if let Some(set) = descriptor_set {
                        format!("{set}:")
                    } else {
                        String::new()
                    };

                    write!(
                        shader_test_file,
                        "ssbo {set_prefix}{binding} subdata {data_type} {offset}"
                    )
                    .map_err(io_err)?;
                    for value in values {
                        write!(shader_test_file, " {value}").map_err(io_err)?;
                    }
                    writeln!(shader_test_file).map_err(io_err)?;
                }
                ShaderRunnerTest::UBO {
                    binding,
                    data: _,
                    descriptor_set,
                } => {
                    let set_prefix = if let Some(set) = descriptor_set {
                        format!("{set}:")
                    } else {
                        String::new()
                    };

                    writeln!(shader_test_file, "ubo {set_prefix}{binding} data").map_err(io_err)?;
                }
                ShaderRunnerTest::UBOSubData {
                    binding,
                    data_type,
                    offset,
                    values,
                    descriptor_set,
                } => {
                    let set_prefix = if let Some(set) = descriptor_set {
                        format!("{set}:")
                    } else {
                        String::new()
                    };

                    write!(
                        shader_test_file,
                        "ubo {set_prefix}{binding} subdata {data_type} {offset}"
                    )
                    .map_err(io_err)?;
                    for value in values {
                        write!(shader_test_file, " {value}").map_err(io_err)?;
                    }
                    writeln!(shader_test_file).map_err(io_err)?;
                }
                ShaderRunnerTest::BufferLayout {
                    buffer_type,
                    layout_type,
                } => {
                    writeln!(shader_test_file, "{buffer_type} layout {layout_type}")
                        .map_err(io_err)?;
                }
                ShaderRunnerTest::Push {
                    data_type,
                    offset,
                    values,
                } => {
                    write!(shader_test_file, "push {data_type} {offset}").map_err(io_err)?;
                    for value in values {
                        write!(shader_test_file, " {value}").map_err(io_err)?;
                    }
                    writeln!(shader_test_file).map_err(io_err)?;
                }
                ShaderRunnerTest::PushLayout { layout_type } => {
                    writeln!(shader_test_file, "push layout {layout_type}").map_err(io_err)?;
                }
                ShaderRunnerTest::Compute { x, y, z } => {
                    writeln!(shader_test_file, "compute {x} {y} {z}").map_err(io_err)?;
                }
                ShaderRunnerTest::Probe {
                    probe_type,
                    format,
                    args,
                } => {
                    write!(shader_test_file, "probe {probe_type} {format}").map_err(io_err)?;
                    for arg in args {
                        write!(shader_test_file, " {arg}").map_err(io_err)?;
                    }
                    writeln!(shader_test_file).map_err(io_err)?;
                }
                ShaderRunnerTest::RelativeProbe {
                    probe_type,
                    format,
                    args,
                } => {
                    write!(shader_test_file, "relative probe {probe_type} {format}")
                        .map_err(io_err)?;
                    for arg in args {
                        write!(shader_test_file, " {arg}").map_err(io_err)?;
                    }
                    writeln!(shader_test_file).map_err(io_err)?;
                }
                ShaderRunnerTest::Tolerance { values } => {
                    write!(shader_test_file, "tolerance").map_err(io_err)?;
                    for value in values {
                        write!(shader_test_file, " {value}").map_err(io_err)?;
                    }
                    writeln!(shader_test_file).map_err(io_err)?;
                }
                ShaderRunnerTest::Clear => {
                    writeln!(shader_test_file, "clear").map_err(io_err)?;
                }
                ShaderRunnerTest::DepthTestEnable { enable } => {
                    writeln!(shader_test_file, "depthTestEnable {enable}").map_err(io_err)?;
                }
                ShaderRunnerTest::DepthWriteEnable { enable } => {
                    writeln!(shader_test_file, "depthWriteEnable {enable}").map_err(io_err)?;
                }
                ShaderRunnerTest::DepthCompareOp { op } => {
                    writeln!(shader_test_file, "depthCompareOp {op}").map_err(io_err)?;
                }
                ShaderRunnerTest::StencilTestEnable { enable } => {
                    writeln!(shader_test_file, "stencilTestEnable {enable}").map_err(io_err)?;
                }
                ShaderRunnerTest::FrontFace { mode } => {
                    writeln!(shader_test_file, "frontFace {mode}").map_err(io_err)?;
                }
                ShaderRunnerTest::StencilOp {
                    face,
                    op_name,
                    value,
                } => {
                    writeln!(shader_test_file, "{face}.{op_name} {value}",).map_err(io_err)?;
                }
                ShaderRunnerTest::StencilReference { face, value } => {
                    writeln!(shader_test_file, "{face}.reference {value}",).map_err(io_err)?;
                }
                ShaderRunnerTest::StencilCompareOp { face, op } => {
                    writeln!(shader_test_file, "{face}.compareOp {op}",).map_err(io_err)?;
                }
                ShaderRunnerTest::ColorWriteMask { mask } => {
                    writeln!(shader_test_file, "colorWriteMask {mask}",).map_err(io_err)?;
                }
                ShaderRunnerTest::LogicOpEnable { enable } => {
                    writeln!(shader_test_file, "logicOpEnable {enable}",).map_err(io_err)?;
                }
                ShaderRunnerTest::LogicOp { op } => {
                    writeln!(shader_test_file, "logicOp {op}",).map_err(io_err)?;
                }
                ShaderRunnerTest::CullMode { mode } => {
                    writeln!(shader_test_file, "cullMode {mode}",).map_err(io_err)?;
                }
                ShaderRunnerTest::LineWidth { width } => {
                    writeln!(shader_test_file, "lineWidth {width}").map_err(io_err)?;
                }
                ShaderRunnerTest::Require {
                    feature,
                    parameters,
                } => {
                    write!(shader_test_file, "require {feature}").map_err(io_err)?;
                    for param in parameters {
                        write!(shader_test_file, " {param}").map_err(io_err)?;
                    }
                    writeln!(shader_test_file).map_err(io_err)?;
                }
            }
        }

        shader_test_file.flush().map_err(io_err)?;

        let tmp_image_path = "/tmp/vkrunner_output.ppm";
        if let Some(output_path) = &request.output_path {
            if output_path.starts_with("/tmp") {
                tmp_image_path.to_string()
            } else {
                format!("/tmp/{output_path}")
            }
        } else {
            tmp_image_path.to_string()
        };
        let mut vkrunner_args = vec![shader_test_path];

        if request.output_path.is_some() {
            vkrunner_args.push("--image");
            vkrunner_args.push(tmp_image_path);
        }

        let vkrunner_output = Command::new("vkrunner")
            .args(&vkrunner_args)
            .stdout(Stdio::piped())
            .stderr(Stdio::piped())
            .output()
            .map_err(|e| {
                McpError::internal_error(
                    "Failed to run vkrunner",
                    Some(json!({"error": e.to_string()})),
                )
            })?;

        let stdout = String::from_utf8_lossy(&vkrunner_output.stdout).to_string();
        let stderr = String::from_utf8_lossy(&vkrunner_output.stderr).to_string();

        let mut result_message = if vkrunner_output.status.success() {
            format!(
                "Shader compilation successful using shaderc-rs.\nVkRunner execution successful.\n\nOutput:\n{stdout}\n\n"
            )
        } else {
            format!(
                "Shader compilation successful using shaderc-rs.\nVkRunner execution failed.\n\nOutput:\n{stdout}\n\nError:\n{stderr}\n\n",
            )
        };

        if let Some(output_path) = &request.output_path {
            if vkrunner_output.status.success() && Path::new(tmp_image_path).exists() {
                match read_and_decode_ppm_file(tmp_image_path) {
                    Ok(img) => {
                        if let Some(parent) = Path::new(output_path).parent() {
                            if !parent.as_os_str().is_empty() {
                                std::fs::create_dir_all(parent).map_err(|e| {
                                    McpError::internal_error(
                                        "Failed to create output directory",
                                        Some(json!({"error": e.to_string()})),
                                    )
                                })?;
                            }
                        }

                        img.save(output_path).map_err(|e| {
                            McpError::internal_error(
                                "Failed to save output image",
                                Some(json!({"error": e.to_string()})),
                            )
                        })?;

                        result_message.push_str(&format!("Image saved to: {output_path}\n"));
                    }
                    Err(e) => {
                        result_message.push_str(&format!("Failed to convert output image: {e}\n"));
                    }
                }
            } else if vkrunner_output.status.success() {
                result_message.push_str("No output image was generated by VkRunner.\n");
            }
        }

        result_message.push_str("\nShader Test File Contents:\n");
        result_message.push_str(
            &std::fs::read_to_string(shader_test_path)
                .unwrap_or_else(|_| "Failed to read shader test file".to_string()),
        );

        Ok(CallToolResult::success(vec![Content::text(result_message)]))
    }
}

impl Default for ShadercVkrunnerMcp {
    fn default() -> Self {
        Self::new()
    }
}

// NOTE(review): `Echo` is not referenced anywhere in this part of the file; it may be a
// leftover. Confirm it is unused elsewhere before removing it.
const_string!(Echo = "echo");
#[tool(tool_box)]
impl ServerHandler for ShadercVkrunnerMcp {
    // Describes the server to MCP clients: protocol version, a tools-only capability set,
    // implementation info taken from the build environment, and a short usage guide.
    fn get_info(&self) -> ServerInfo {
        let capabilities = ServerCapabilities::builder().enable_tools().build();
        let instructions = "This server provides tools for compiling and running GLSL shaders using Vulkan infrastructure. The typical workflow is:

            1. Compile GLSL source code to SPIR-V assembly using the 'compile_run_shaders' tool
            2. Specify hardware/feature requirements if needed (e.g., geometry shaders, floating-point formats)
            3. Reference compiled SPIR-V files in shader passes
            4. Define vertex data if rendering geometry
            5. Set up test commands to draw or compute
            6. Optionally save the rendered output as an image";

        ServerInfo {
            protocol_version: ProtocolVersion::V_2024_11_05,
            capabilities,
            server_info: Implementation::from_build_env(),
            instructions: Some(instructions.to_string()),
        }
    }
}

#[derive(Parser, Debug)]
#[command(about)]
struct Args {
    #[clap(short, long, value_parser)]
    work_dir: Option<PathBuf>,
}

/// Entry point: applies `--work-dir`, configures logging and serves the
/// MCP protocol over stdio until the service stops.
#[tokio::main]
async fn main() -> Result<()> {
    let args = Args::parse();

    // Change directory first so that every later relative path resolves
    // against the requested directory. Logging is not initialised yet, so
    // a failure is reported directly on stderr and the process exits.
    if let Some(work_dir) = args.work_dir {
        if let Err(e) = std::env::set_current_dir(&work_dir) {
            eprintln!("Failed to set working directory to {work_dir:?}: {e}");
            std::process::exit(1);
        }
    }

    // Logs go to stderr because stdout is used by the stdio transport below.
    tracing_subscriber::fmt()
        .with_env_filter(EnvFilter::from_default_env().add_directive(tracing::Level::DEBUG.into()))
        .with_writer(std::io::stderr)
        .with_ansi(false)
        .init();

    tracing::info!("Starting MCP server");

    let service = ShadercVkrunnerMcp::new()
        .serve(stdio())
        .await
        .inspect_err(|e| {
            tracing::error!("serving error: {:?}", e);
        })?;

    service.waiting().await?;
    Ok(())
}

```

--------------------------------------------------------------------------------
/vkrunner/vkrunner/requirements.rs:
--------------------------------------------------------------------------------

```rust
// vkrunner
//
// Copyright (C) 2019 Intel Corporation
// Copyright 2023 Neil Roberts
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice (including the next
// paragraph) shall be included in all copies or substantial portions of the
// Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
// THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.

use crate::vk;
use crate::vulkan_funcs;
use crate::vulkan_funcs::{NEXT_PTR_OFFSET, FIRST_FEATURE_OFFSET};
use crate::result;
use std::mem;
use std::collections::{HashMap, HashSet};
use std::ffi::CStr;
use std::convert::TryInto;
use std::fmt;
use std::cell::UnsafeCell;
use std::ptr;


/// Static description of one extension's features struct, as stored in
/// the `EXTENSIONS` table.
#[derive(Debug)]
struct Extension {
    /// Null-terminated name of the extension providing this features
    /// struct. Kept as raw bytes because that is the form bindgen emits
    /// for string constants taken from the headers.
    name_bytes: &'static [u8],
    /// Size in bytes of the features struct.
    struct_size: usize,
    /// `sType` value that identifies the features struct.
    struct_type: vk::VkStructureType,
    /// Names of the struct's feature fields, in declaration order.
    features: &'static [&'static str],
}

impl Extension {
    /// Returns the extension name as a string slice.
    ///
    /// # Panics
    ///
    /// Panics if `name_bytes` is not a valid null-terminated UTF-8
    /// string, which would mean the static table is broken.
    fn name(&self) -> &str {
        let c_name = CStr::from_bytes_with_nul(self.name_bytes).unwrap();
        c_name.to_str().unwrap()
    }
}

// Pulls in the generated feature tables. NOTE(review): this presumably
// defines EXTENSIONS, BASE_FEATURES and N_BASE_FEATURES, which are used
// below but not defined in this file — confirm against features.rs.
include!("features.rs");

// Cached C-compatible view of the required extension names, built on
// demand so it can be handed to Vulkan.
struct LazyExtensions {
    // Every extension name followed by a NUL byte, packed into a single
    // buffer.
    strings: Vec<u8>,
    // One pointer per extension to where its name starts in `strings`.
    pointers: Vec<* const u8>,
}

// Cached chain of extension feature structs, built on demand so it can
// be handed to Vulkan.
struct LazyStructures {
    // Raw bytes of the feature structs laid out back to back, with each
    // struct's pNext pointing at the next one so C code can walk the
    // chain.
    list: Vec<u8>,
}

// How many feature flags VkPhysicalDeviceFeatures contains. The struct
// consists solely of VkBool32 members, so its size divided by the size
// of one VkBool32 gives the count.
const N_PHYSICAL_DEVICE_FEATURES: usize = {
    let struct_size = mem::size_of::<vk::VkPhysicalDeviceFeatures>();
    let bool_size = mem::size_of::<vk::VkBool32>();
    struct_size / bool_size
};

// Cached VkPhysicalDeviceFeatures contents, built on demand so they can
// be handed to Vulkan.
struct LazyBaseFeatures {
    // Backing storage large enough for a VkPhysicalDeviceFeatures struct.
    // A plain array is used instead of the struct type because the flags
    // are written by index.
    bools: [vk::VkBool32; N_PHYSICAL_DEVICE_FEATURES],
}

/// One cooperative matrix configuration that the device must support.
///
/// Each `Option` field constrains the field of the same meaning in
/// `VkCooperativeMatrixPropertiesKHR`; `None` accepts any value.
#[derive(Debug, Default, Clone)]
pub struct CooperativeMatrix {
    /// Required `MSize`.
    pub m_size: Option<u32>,
    /// Required `NSize`.
    pub n_size: Option<u32>,
    /// Required `KSize`.
    pub k_size: Option<u32>,
    /// Required `AType`.
    pub a_type: Option<vk::VkComponentTypeKHR>,
    /// Required `BType`.
    pub b_type: Option<vk::VkComponentTypeKHR>,
    /// Required `CType`.
    pub c_type: Option<vk::VkComponentTypeKHR>,
    /// Required `ResultType`.
    pub result_type: Option<vk::VkComponentTypeKHR>,
    /// Required `saturatingAccumulation`.
    pub saturating_accumulation: Option<vk::VkBool32>,
    /// Required `scope`.
    pub scope: Option<vk::VkScopeKHR>,

    /// Text reported in the error when no matching configuration is
    /// found (presumably the script line this requirement came from).
    pub line: String,
}

/// The set of device capabilities a test needs.
///
/// The `lazy_*` fields cache C-compatible views of the other fields;
/// they are derived data and are not compared by `PartialEq`.
#[derive(Debug)]
pub struct Requirements {
    /// Minimum Vulkan version, packed as by [make_version].
    version: u32,
    /// Names of the extensions that must be present.
    extensions: HashSet<String>,
    /// Keyed by index into `EXTENSIONS`. Each bool says whether the
    /// feature at the same position in that extension's features struct
    /// is required.
    features: HashMap<usize, Box<[bool]>>,
    /// One bool per base feature (indexed like `BASE_FEATURES`) saying
    /// whether it is required in `VkPhysicalDeviceFeatures`.
    base_features: [bool; N_BASE_FEATURES],
    /// Required subgroup size to be used with
    /// VkPipelineShaderStageRequiredSubgroupSizeCreateInfo.
    pub required_subgroup_size: Option<u32>,

    /// Cooperative matrix configurations; each must be matched by one
    /// the device reports.
    cooperative_matrix_reqs: Vec<CooperativeMatrix>,

    // Everything below is rebuilt on demand from the fields above and
    // must stay out of the PartialEq implementation.

    /// Cached C strings for `extensions`.
    lazy_extensions: UnsafeCell<Option<LazyExtensions>>,
    /// Cached feature-struct chain for `features`.
    lazy_structures: UnsafeCell<Option<LazyStructures>>,
    /// Cached `VkPhysicalDeviceFeatures` for `base_features`.
    lazy_base_features: UnsafeCell<Option<LazyBaseFeatures>>,
}

/// Error returned by [Requirements::check]
#[derive(Debug)]
pub enum Error {
    /// `vkEnumerateDeviceExtensionProperties` did not return
    /// `VK_SUCCESS`.
    EnumerateDeviceExtensionPropertiesFailed,
    /// An extension name reported by the driver has no NUL terminator.
    ExtensionMissingNullTerminator,
    /// An extension name reported by the driver is not valid UTF-8.
    ExtensionInvalidUtf8,
    /// A required base feature from VkPhysicalDeviceFeatures is missing.
    /// The value is the feature's index into `BASE_FEATURES`.
    MissingBaseFeature(usize),
    /// A required extension is missing. The string is the name of the
    /// extension.
    MissingExtension(String),
    /// A required feature is missing. `extension` indexes `EXTENSIONS`
    /// and `feature` indexes that extension's feature list.
    MissingFeature { extension: usize, feature: usize },
    /// The API version reported by the driver is too low
    VersionTooLow { required_version: u32, actual_version: u32 },
    /// Required subgroup size out of supported range.
    RequiredSubgroupSizeInvalid { size: u32, min: u32, max: u32 },
    /// No cooperative matrix configuration reported by the device
    /// satisfies a requirement. The string is the requirement's `line`.
    MissingCooperativeMatrixProperties(String),
}

impl Error {
    pub fn result(&self) -> result::Result {
        match self {
            Error::EnumerateDeviceExtensionPropertiesFailed => {
                result::Result::Fail
            },
            Error::ExtensionMissingNullTerminator => result::Result::Fail,
            Error::ExtensionInvalidUtf8 => result::Result::Fail,
            Error::MissingBaseFeature(_) => result::Result::Skip,
            Error::MissingExtension(_) => result::Result::Skip,
            Error::MissingFeature { .. } => result::Result::Skip,
            Error::VersionTooLow { .. } => result::Result::Skip,
            Error::RequiredSubgroupSizeInvalid { .. } => result::Result::Skip,
            Error::MissingCooperativeMatrixProperties(_) => result::Result::Skip,
        }
    }
}

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Error::EnumerateDeviceExtensionPropertiesFailed => {
                write!(f, "vkEnumerateDeviceExtensionProperties failed")
            },
            Error::ExtensionMissingNullTerminator => {
                write!(
                    f,
                    "NULL terminator missing in string returned from \
                     vkEnumerateDeviceExtensionProperties"
                )
            },
            Error::ExtensionInvalidUtf8 => {
                write!(
                    f,
                    "Invalid UTF-8 in string returned from \
                     vkEnumerateDeviceExtensionProperties"
                )
            },
            &Error::MissingBaseFeature(feature_num) => {
                write!(
                    f,
                    "Missing required feature: {}",
                    BASE_FEATURES[feature_num],
                )
            },
            Error::MissingExtension(s) => {
                write!(f, "Missing required extension: {}", s)
            },
            &Error::MissingFeature { extension, feature } => {
                write!(
                    f,
                    "Missing required feature “{}” from extension “{}”",
                    EXTENSIONS[extension].features[feature],
                    EXTENSIONS[extension].name(),
                )
            },
            &Error::VersionTooLow { required_version, actual_version } => {
                let (req_major, req_minor, req_patch) =
                    extract_version(required_version);
                let (actual_major, actual_minor, actual_patch) =
                    extract_version(actual_version);
                write!(
                    f,
                    "Vulkan API version {}.{}.{} required but the driver \
                     reported {}.{}.{}",
                    req_major, req_minor, req_patch,
                    actual_major, actual_minor, actual_patch,
                )
            },
            &Error::RequiredSubgroupSizeInvalid { size, min, max } => {
                write!(
                    f,
                    "Required subgroup size {} is not in the supported range {}-{}",
                    size, min, max
                )
            },
            Error::MissingCooperativeMatrixProperties(s) => {
                write!(f, "Physical device can't fulfill cooperative matrix requirements: {}", s.trim())
            },
        }
    }
}

/// Packs a major/minor/patch Vulkan version into one `u32` using the
/// bit layout of the `VK_MAKE_VERSION` macro from the Vulkan headers:
/// the major number starts at bit 22, the minor number at bit 12 and
/// the patch number fills the low bits.
pub const fn make_version(major: u32, minor: u32, patch: u32) -> u32 {
    let major_bits = major << 22;
    let minor_bits = minor << 12;
    major_bits | minor_bits | patch
}

/// Splits a packed Vulkan version back into `(major, minor, patch)`,
/// mirroring the `VK_VERSION_MAJOR`, `VK_VERSION_MINOR` and
/// `VK_VERSION_PATCH` macros from the Vulkan header.
pub const fn extract_version(version: u32) -> (u32, u32, u32) {
    let major = version >> 22;
    let minor = (version >> 12) & 0x3ff;
    let patch = version & 0xfff;
    (major, minor, patch)
}

impl Requirements {
    /// Creates a set of requirements that asks only for Vulkan 1.0.0.
    pub fn new() -> Requirements {
        Requirements {
            version: make_version(1, 0, 0),
            extensions: HashSet::new(),
            features: HashMap::new(),
            base_features: [false; N_BASE_FEATURES],
            required_subgroup_size: None,
            lazy_extensions: UnsafeCell::new(None),
            lazy_structures: UnsafeCell::new(None),
            lazy_base_features: UnsafeCell::new(None),
            cooperative_matrix_reqs: Vec::new(),
        }
    }

    /// Get the required Vulkan version that was previously set with
    /// [add_version](Requirements::add_version).
    pub fn version(&self) -> u32 {
        self.version
    }

    /// Set the minimum required Vulkan version. The stored version only
    /// ever increases: asking for a version lower than one already
    /// required has no effect.
    pub fn add_version(&mut self, major: u32, minor: u32, patch: u32) {
        self.version = std::cmp::max(self.version, make_version(major, minor, patch));
    }

    /// Adds a cooperative matrix configuration that
    /// [check](Requirements::check) will require the device to report.
    pub fn add_cooperative_matrix_req(&mut self, req: CooperativeMatrix) {
        self.cooperative_matrix_reqs.push(req);
    }

    // Returns the cached array of C string pointers for the required
    // extensions, building it first if the cache is empty.
    fn update_lazy_extensions(&self) -> &[* const u8] {
        // SAFETY: Only a shared reference is created here, so it cannot
        // conflict with references previously handed out from the cache.
        if let Some(data) = unsafe { &*self.lazy_extensions.get() } {
            return data.pointers.as_ref();
        }

        // SAFETY: The cache is only cleared through `&mut self` and only
        // filled here. It is currently empty, so nothing can be holding
        // a reference into it and creating a unique reference is sound.
        let extensions = unsafe {
            (*self.lazy_extensions.get()).insert(LazyExtensions {
                strings: Vec::new(),
                pointers: Vec::new(),
            })
        };

        // Store a list of offsets into the strings buffer for each
        // extension. We can’t directly store the pointers yet because
        // the Vec will probably be reallocated while we are adding to
        // it.
        let mut offsets = Vec::<usize>::new();

        for extension in self.extensions.iter() {
            offsets.push(extensions.strings.len());

            extensions.strings.extend_from_slice(extension.as_bytes());
            // Add the null terminator
            extensions.strings.push(0);
        }

        let base_ptr = extensions.strings.as_ptr();

        extensions.pointers.reserve(offsets.len());

        for offset in offsets {
            // SAFETY: These are all valid offsets into the strings Vec
            // so they are all in the same allocation and no overflow is
            // possible.
            extensions.pointers.push(unsafe { base_ptr.add(offset) });
        }

        extensions.pointers.as_ref()
    }

    /// Return a reference to an array of pointers to C-style strings
    /// that can be passed to `vkCreateDevice`.
    pub fn c_extensions(&self) -> &[* const u8] {
        self.update_lazy_extensions()
    }

    // Make a linked list of features structures that is suitable for
    // passing to VkPhysicalDeviceFeatures2 or vkCreateDevice. All of
    // the bools are set to 0. The vec with the structs is returned as
    // well as a list of (offset of the first bool, extension number)
    // pairs, one per struct.
    fn make_empty_structures(
        &self
    ) -> (Vec<u8>, Vec<(usize, usize)>)
    {
        let mut structures = Vec::new();

        // Keep an array of offsets and corresponding extension num in a
        // vec for each structure that will be returned. The offset is
        // also used to update the pNext pointers. We have to do this
        // after filling the vec because the vec will probably
        // reallocate while we are adding to it which would invalidate
        // the pointers.
        let mut offsets = Vec::<(usize, usize)>::new();

        for (&extension_num, features) in self.features.iter() {
            let extension = &EXTENSIONS[extension_num];

            // Make sure the struct size is big enough to hold all of
            // the bools
            assert!(
                extension.struct_size
                    >= (FIRST_FEATURE_OFFSET
                        + features.len() * mem::size_of::<vk::VkBool32>())
            );

            let offset = structures.len();
            structures.resize(offset + extension.struct_size, 0);

            // The structure type is the first field in the structure
            let type_end = offset + mem::size_of::<vk::VkStructureType>();
            structures[offset..type_end]
                .copy_from_slice(&extension.struct_type.to_ne_bytes());

            offsets.push((offset + FIRST_FEATURE_OFFSET, extension_num));
        }

        // Link each struct to the next one through its pNext field. The
        // last struct keeps a zeroed (null) pNext. `skip(1)` rather than
        // slicing `[1..]` so that an empty list does not panic.
        let mut last_offset = 0;

        for &(offset, _) in offsets.iter().skip(1) {
            let offset = offset - FIRST_FEATURE_OFFSET;

            // SAFETY: The offsets are all valid offsets into the
            // structures vec so they should all be in the same
            // allocation and no overflow should occur.
            let ptr = unsafe { structures.as_ptr().add(offset) };

            let ptr_start = last_offset + NEXT_PTR_OFFSET;
            let ptr_end = ptr_start + mem::size_of::<*const u8>();
            structures[ptr_start..ptr_end]
                .copy_from_slice(&(ptr as usize).to_ne_bytes());

            last_offset = offset;
        }

        (structures, offsets)
    }

    // Returns the cached feature-struct chain with the required bools
    // set, building it first if the cache is empty.
    fn update_lazy_structures(&self) -> &[u8] {
        // SAFETY: Only a shared reference is created here, so it cannot
        // conflict with references previously handed out from the cache.
        if let Some(data) = unsafe { &*self.lazy_structures.get() } {
            return data.list.as_slice();
        }

        let (list, offsets) = self.make_empty_structures();

        // SAFETY: The cache is only cleared through `&mut self` and only
        // filled here. It is currently empty, so nothing can be holding
        // a reference into it and creating a unique reference is sound.
        let structures = unsafe {
            (*self.lazy_structures.get()).insert(LazyStructures { list })
        };

        for (offset, extension_num) in offsets {
            let features = &self.features[&extension_num];

            for (feature_num, &feature) in features.iter().enumerate() {
                let feature_start =
                    offset
                    + mem::size_of::<vk::VkBool32>()
                    * feature_num;
                let feature_end =
                    feature_start + mem::size_of::<vk::VkBool32>();

                structures.list[feature_start..feature_end]
                    .copy_from_slice(&(feature as vk::VkBool32).to_ne_bytes());
            }
        }

        structures.list.as_slice()
    }

    /// Return a pointer to a linked list of feature structures that
    /// can be passed to `vkCreateDevice`, or `None` if no feature
    /// structs are required.
    pub fn c_structures(&self) -> Option<&[u8]> {
        if self.features.is_empty() {
            None
        } else {
            Some(self.update_lazy_structures())
        }
    }

    // Returns the cached VkPhysicalDeviceFeatures bools, building them
    // first if the cache is empty.
    fn update_lazy_base_features(&self) -> &[vk::VkBool32] {
        // SAFETY: Only a shared reference is created here, so it cannot
        // conflict with references previously handed out from the cache.
        if let Some(data) = unsafe { &*self.lazy_base_features.get() } {
            return &data.bools;
        }

        // SAFETY: The cache is only cleared through `&mut self` and only
        // filled here. It is currently empty, so nothing can be holding
        // a reference into it and creating a unique reference is sound.
        let base_features = unsafe {
            (*self.lazy_base_features.get()).insert(LazyBaseFeatures {
                bools: [0; N_PHYSICAL_DEVICE_FEATURES],
            })
        };

        for (feature_num, &feature) in self.base_features.iter().enumerate() {
            base_features.bools[feature_num] = feature as vk::VkBool32;
        }

        &base_features.bools
    }

    /// Return a pointer to a `VkPhysicalDeviceFeatures` struct that
    /// can be passed to `vkCreateDevice`.
    pub fn c_base_features(&self) -> &vk::VkPhysicalDeviceFeatures {
        // SAFETY: The cached array holds N_PHYSICAL_DEVICE_FEATURES
        // VkBool32 values, which covers the whole struct because the
        // struct is only a series of VkBool32. For the same reason the
        // array's alignment is sufficient.
        unsafe {
            &*self.update_lazy_base_features().as_ptr().cast()
        }
    }

    // Adds `name` to the set of required extensions and invalidates the
    // cached C strings if the set changed.
    fn add_extension_name(&mut self, name: &str) {
        // It would be nice to use get_or_insert_owned here if it
        // becomes stable. It’s probably better not to use
        // HashSet::replace directly because if it’s already in the
        // set then we’ll pointlessly copy the str slice and
        // immediately free it.
        if !self.extensions.contains(name) {
            self.extensions.insert(name.to_owned());
            // `&mut self` guarantees no reference into the cache exists,
            // so it can be cleared without unsafe code.
            *self.lazy_extensions.get_mut() = None;
        }
    }

    /// Adds a requirement to the list of requirements.
    ///
    /// The name can be either a feature as written in the
    /// corresponding features struct or the name of an extension. If
    /// it is a feature it needs to be either the name of a field in
    /// the `VkPhysicalDeviceFeatures` struct or a field in any of the
    /// features structs of the extensions that vkrunner knows about.
    /// In the latter case the name of the corresponding extension
    /// will be automatically added as a requirement.
    pub fn add(&mut self, name: &str) {
        if let Some((extension_num, feature_num)) = find_feature(name) {
            let extension = &EXTENSIONS[extension_num];

            self.add_extension_name(extension.name());

            let features = self
                .features
                .entry(extension_num)
                .or_insert_with(|| {
                    vec![false; extension.features.len()].into_boxed_slice()
                });

            if !features[feature_num] {
                features[feature_num] = true;
                // The feature-struct chain no longer matches.
                *self.lazy_structures.get_mut() = None;
            }
        } else if let Some(num) = find_base_feature(name) {
            if !self.base_features[num] {
                self.base_features[num] = true;
                // Base features are cached in lazy_base_features (used by
                // c_base_features), so that is the cache to invalidate.
                *self.lazy_base_features.get_mut() = None;
            }
        } else {
            self.add_extension_name(name);
        }
    }

    // Fails if any required base feature is reported as unsupported by
    // vkGetPhysicalDeviceFeatures.
    fn check_base_features(
        &self,
        vkinst: &vulkan_funcs::Instance,
        device: vk::VkPhysicalDevice
    ) -> Result<(), Error> {
        // The driver writes a whole VkPhysicalDeviceFeatures struct, so
        // size the buffer from that struct rather than from the number
        // of features vkrunner knows about.
        let mut actual_features: [vk::VkBool32; N_PHYSICAL_DEVICE_FEATURES] =
            [0; N_PHYSICAL_DEVICE_FEATURES];

        unsafe {
            vkinst.vkGetPhysicalDeviceFeatures.unwrap()(
                device,
                actual_features.as_mut_ptr().cast()
            );
        }

        for (feature_num, &required) in self.base_features.iter().enumerate() {
            if required && actual_features[feature_num] == 0 {
                return Err(Error::MissingBaseFeature(feature_num));
            }
        }

        Ok(())
    }

    // Returns the names of all extensions the device reports, using the
    // usual two-call enumeration (count first, then contents).
    fn get_device_extensions(
        &self,
        vkinst: &vulkan_funcs::Instance,
        device: vk::VkPhysicalDevice
    ) -> Result<HashSet::<String>, Error> {
        let mut property_count = 0u32;

        let res = unsafe {
            vkinst.vkEnumerateDeviceExtensionProperties.unwrap()(
                device,
                std::ptr::null(), // layerName
                &mut property_count as *mut u32,
                std::ptr::null_mut(), // properties
            )
        };

        if res != vk::VK_SUCCESS {
            return Err(Error::EnumerateDeviceExtensionPropertiesFailed);
        }

        let mut extensions = Vec::<vk::VkExtensionProperties>::with_capacity(
            property_count as usize
        );

        unsafe {
            let res = vkinst.vkEnumerateDeviceExtensionProperties.unwrap()(
                device,
                std::ptr::null(), // layerName
                &mut property_count as *mut u32,
                extensions.as_mut_ptr(),
            );

            if res != vk::VK_SUCCESS {
                return Err(Error::EnumerateDeviceExtensionPropertiesFailed);
            }

            // SAFETY: The FFI call to
            // vkEnumerateDeviceExtensionProperties should have filled
            // the first property_count entries (at most the capacity
            // we allocated) with valid values.
            extensions.set_len(property_count as usize);
        };

        let mut extensions_set = HashSet::new();

        for extension in extensions.iter() {
            let name = &extension.extensionName;
            // Make sure it has a NULL terminator
            if !name.iter().any(|&b| b == 0) {
                return Err(Error::ExtensionMissingNullTerminator);
            }
            // SAFETY: we just checked that the array has a null terminator
            let name = unsafe { CStr::from_ptr(name.as_ptr()) };
            let name = match name.to_str() {
                Err(_) => {
                    return Err(Error::ExtensionInvalidUtf8);
                },
                Ok(s) => s,
            };
            extensions_set.insert(name.to_owned());
        }

        Ok(extensions_set)
    }

    // Fails with the first required extension the device lacks.
    fn check_extensions(
        &self,
        vkinst: &vulkan_funcs::Instance,
        device: vk::VkPhysicalDevice
    ) -> Result<(), Error> {
        if self.extensions.is_empty() {
            return Ok(());
        }

        let actual_extensions = self.get_device_extensions(vkinst, device)?;

        for extension in self.extensions.iter() {
            if !actual_extensions.contains(extension) {
                return Err(Error::MissingExtension(extension.to_string()));
            }
        }

        Ok(())
    }

    // Queries the extension feature structs and fails with the first
    // required feature reported as unsupported.
    fn check_structures(
        &self,
        vkinst: &vulkan_funcs::Instance,
        device: vk::VkPhysicalDevice
    ) -> Result<(), Error> {
        if self.features.is_empty() {
            return Ok(());
        }

        // If vkGetPhysicalDeviceFeatures2KHR isn’t available then we
        // can probably assume that none of the extensions are
        // available.
        let get_features = match vkinst.vkGetPhysicalDeviceFeatures2KHR {
            None => {
                // Find the first feature and report that as missing
                for (&extension, features) in self.features.iter() {
                    for (feature, &enabled) in features.iter().enumerate() {
                        if enabled {
                            return Err(Error::MissingFeature {
                                extension,
                                feature,
                            });
                        }
                    }
                }
                unreachable!("Requirements::features should be empty if no \
                              features are required");
            },
            Some(func) => func,
        };

        let (mut actual_structures, offsets) = self.make_empty_structures();

        let mut features_query = vk::VkPhysicalDeviceFeatures2 {
            sType: vk::VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2,
            pNext: actual_structures.as_mut_ptr().cast(),
            features: Default::default(),
        };

        unsafe {
            get_features(
                device,
                &mut features_query as *mut vk::VkPhysicalDeviceFeatures2,
            );
        }

        for (offset, extension_num) in offsets {
            let features = &self.features[&extension_num];

            for (feature_num, &feature) in features.iter().enumerate() {
                if !feature {
                    continue;
                }

                let feature_start =
                    offset
                    + mem::size_of::<vk::VkBool32>()
                    * feature_num;
                let feature_end =
                    feature_start + mem::size_of::<vk::VkBool32>();
                let actual_value = vk::VkBool32::from_ne_bytes(
                    actual_structures[feature_start..feature_end]
                        .try_into()
                        .unwrap()
                );

                if actual_value == 0 {
                    return Err(Error::MissingFeature {
                        extension: extension_num,
                        feature: feature_num,
                    });
                }
            }
        }

        Ok(())
    }

    // Fails if the device's API version is below the required one. When
    // Vulkan >= 1.1 is required, the subgroup size range is queried in
    // the same call and required_subgroup_size is validated against it.
    // NOTE(review): required_subgroup_size is therefore not checked when
    // the required version is below 1.1.
    fn check_version(
        &self,
        vkinst: &vulkan_funcs::Instance,
        device: vk::VkPhysicalDevice,
    ) -> Result<(), Error> {
        let actual_version = if self.version() >= make_version(1, 1, 0) {
            let mut subgroup_size_props = vk::VkPhysicalDeviceSubgroupSizeControlProperties {
                sType: vk::VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_SIZE_CONTROL_PROPERTIES,
                ..Default::default()
            };
            let mut props = vk::VkPhysicalDeviceProperties2 {
                sType: vk::VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2,
                pNext: std::ptr::addr_of_mut!(subgroup_size_props).cast(),
                ..Default::default()
            };

            unsafe {
                vkinst.vkGetPhysicalDeviceProperties2.unwrap()(
                    device,
                    &mut props as *mut vk::VkPhysicalDeviceProperties2,
                );
            }

            if let Some(size) = self.required_subgroup_size {
                let min = subgroup_size_props.minSubgroupSize;
                let max = subgroup_size_props.maxSubgroupSize;
                if size < min || size > max {
                    return Err(Error::RequiredSubgroupSizeInvalid {
                        size, min, max
                    });
                }
            }

            props.properties.apiVersion
        } else {
            let mut props = vk::VkPhysicalDeviceProperties::default();

            unsafe {
                vkinst.vkGetPhysicalDeviceProperties.unwrap()(
                    device,
                    &mut props as *mut vk::VkPhysicalDeviceProperties,
                );
            }

            props.apiVersion
        };

        if actual_version < self.version() {
            Err(Error::VersionTooLow {
                required_version: self.version(),
                actual_version,
            })
        } else {
            Ok(())
        }
    }

    // True when `prop` satisfies every constraint set in `req`; unset
    // (None) constraints accept any value.
    fn cooperative_matrix_matches(
        req: &CooperativeMatrix,
        prop: &vk::VkCooperativeMatrixPropertiesKHR,
    ) -> bool {
        req.m_size.map_or(true, |m| m == prop.MSize)
            && req.n_size.map_or(true, |n| n == prop.NSize)
            && req.k_size.map_or(true, |k| k == prop.KSize)
            && req.a_type.map_or(true, |a| a == prop.AType)
            && req.b_type.map_or(true, |b| b == prop.BType)
            && req.c_type.map_or(true, |c| c == prop.CType)
            && req.result_type.map_or(true, |r| r == prop.ResultType)
            && req
                .saturating_accumulation
                .map_or(true, |sa| sa == prop.saturatingAccumulation)
            && req.scope.map_or(true, |s| s == prop.scope)
    }

    // Fails with the first cooperative matrix requirement that none of
    // the device's reported configurations satisfies.
    fn check_cooperative_matrix(&self, vkinst: &vulkan_funcs::Instance, device: vk::VkPhysicalDevice) -> Result<(), Error> {
        let first_req = match self.cooperative_matrix_reqs.first() {
            Some(req) => req,
            None => return Ok(()),
        };

        // Without the entry point no configuration can be queried, so
        // report the requirement as unmet instead of panicking.
        let get_props = match vkinst.vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR {
            Some(func) => func,
            None => {
                return Err(Error::MissingCooperativeMatrixProperties(
                    first_req.line.clone()
                ));
            },
        };

        let mut count = 0u32;

        unsafe {
            get_props(
                device,
                ptr::addr_of_mut!(count),
                ptr::null_mut(),
            );
        }

        let mut props = Vec::<vk::VkCooperativeMatrixPropertiesKHR>::new();
        props.resize_with(count as usize, Default::default);

        for prop in &mut props {
            prop.sType = vk::VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_PROPERTIES_KHR;
        }

        unsafe {
            get_props(
                device,
                ptr::addr_of_mut!(count),
                props.as_mut_ptr(),
            );
        }

        // The second call writes back how many entries it filled; drop
        // any default-initialised entries it did not touch so they can't
        // be mistaken for real configurations.
        props.truncate(count as usize);

        for req in &self.cooperative_matrix_reqs {
            let supported = props
                .iter()
                .any(|prop| Self::cooperative_matrix_matches(req, prop));

            if !supported {
                return Err(Error::MissingCooperativeMatrixProperties(req.line.clone()));
            }
        }

        Ok(())
    }

    /// Checks every requirement against `device`, returning the first
    /// one that is not met.
    ///
    /// # Errors
    ///
    /// Returns an [Error] describing the first unmet requirement or the
    /// failed driver query; [Error::result] says whether it should be a
    /// failure or a skip.
    pub fn check(
        &self,
        vkinst: &vulkan_funcs::Instance,
        device: vk::VkPhysicalDevice
    ) -> Result<(), Error> {
        self.check_base_features(vkinst, device)?;
        self.check_extensions(vkinst, device)?;
        self.check_structures(vkinst, device)?;
        self.check_version(vkinst, device)?;
        self.check_cooperative_matrix(vkinst, device)?;

        Ok(())
    }
}

// Searches every extension in EXTENSIONS for a feature called `name`.
// Returns `(extension_index, feature_index)` for the first match, scanning
// extensions in table order, or `None` if no extension declares it.
fn find_feature(name: &str) -> Option<(usize, usize)> {
    EXTENSIONS
        .iter()
        .enumerate()
        .find_map(|(ext_index, ext)| {
            ext.features
                .iter()
                .position(|&candidate| candidate == name)
                .map(|feat_index| (ext_index, feat_index))
        })
}

// Returns the index of `name` within BASE_FEATURES, or `None` if it is not
// one of the core VkPhysicalDeviceFeatures names.
fn find_base_feature(name: &str) -> Option<usize> {
    for (index, &candidate) in BASE_FEATURES.iter().enumerate() {
        if candidate == name {
            return Some(index);
        }
    }

    None
}

// Two sets of requirements are equal when every user-specified requirement
// matches. The lazily generated C data (the `lazy_*` caches) is derived
// from these fields, so it is deliberately not compared.
impl PartialEq for Requirements {
    fn eq(&self, other: &Requirements) -> bool {
        self.version == other.version
            && self.extensions == other.extensions
            && self.features == other.features
            && self.base_features == other.base_features
            && self.required_subgroup_size == other.required_subgroup_size
            // The cooperative-matrix constraints are part of what
            // `check` verifies, so requirement sets that differ only here
            // must not compare equal.
            && self.cooperative_matrix_reqs == other.cooperative_matrix_reqs
    }
}

// `eq` is a plain field-by-field comparison, so it is a total equivalence.
impl Eq for Requirements {
}

// Manual implementation of clone to avoid copying the lazy state. The
// cached C data is rebuilt on demand from the copied fields, so a clone
// always starts with empty caches.
impl Clone for Requirements {
    fn clone(&self) -> Requirements {
        Requirements {
            version: self.version,
            extensions: self.extensions.clone(),
            features: self.features.clone(),
            base_features: self.base_features.clone(),
            required_subgroup_size: self.required_subgroup_size,
            lazy_extensions: UnsafeCell::new(None),
            lazy_structures: UnsafeCell::new(None),
            lazy_base_features: UnsafeCell::new(None),
            cooperative_matrix_reqs: self.cooperative_matrix_reqs.clone(),
        }
    }

    fn clone_from(&mut self, source: &Requirements) {
        self.version = source.version;
        self.extensions.clone_from(&source.extensions);
        self.features.clone_from(&source.features);
        self.base_features.clone_from(&source.base_features);
        self.required_subgroup_size = source.required_subgroup_size;
        self.cooperative_matrix_reqs.clone_from(&source.cooperative_matrix_reqs);
        // The caches describe the old contents and must be invalidated.
        // Holding `&mut self` guarantees exclusive access, so
        // `UnsafeCell::get_mut` resets them without any `unsafe` code.
        *self.lazy_extensions.get_mut() = None;
        *self.lazy_structures.get_mut() = None;
        *self.lazy_base_features.get_mut() = None;
    }
}

#[cfg(test)]
mod test {
    use super::*;
    use std::ffi::{c_char, c_void};
    use crate::fake_vulkan;

    /// Reads the `sType` field at the start of a raw structure in native
    /// byte order.
    fn get_struct_type(structure: &[u8]) -> vk::VkStructureType {
        let slice = &structure[0..mem::size_of::<vk::VkStructureType>()];
        let mut bytes = [0; mem::size_of::<vk::VkStructureType>()];
        bytes.copy_from_slice(slice);
        vk::VkStructureType::from_ne_bytes(bytes)
    }

    /// Follows the `pNext` pointer of the structure at the start of
    /// `structure` and returns the rest of the buffer beginning at the
    /// next structure.
    ///
    /// This assumes every chained structure lives later in the same buffer.
    /// A NULL `pNext` (end of the chain) yields an empty slice, so callers
    /// walking the chain stop cleanly. Without that check the offset
    /// subtraction would underflow.
    fn get_next_structure(structure: &[u8]) -> &[u8] {
        let slice = &structure[
            NEXT_PTR_OFFSET..NEXT_PTR_OFFSET + mem::size_of::<usize>()
        ];
        let mut bytes = [0; mem::size_of::<usize>()];
        bytes.copy_from_slice(slice);
        let pointer = usize::from_ne_bytes(bytes);

        if pointer == 0 {
            // End of the pNext chain
            return &[];
        }

        let offset = pointer - structure.as_ptr() as usize;
        &structure[offset..]
    }

    /// Walks the structure chain in `structures` and returns the feature
    /// booleans of the structure whose `sType` matches `extension`.
    ///
    /// Panics if no structure in the chain belongs to `extension`.
    fn find_bools_for_extension<'a>(
        mut structures: &'a [u8],
        extension: &Extension
    ) -> &'a [vk::VkBool32] {
        while !structures.is_empty() {
            let struct_type = get_struct_type(structures);

            if struct_type == extension.struct_type {
                // SAFETY: this assumes the matching structure stores one
                // VkBool32 per feature of `extension`, starting at
                // FIRST_FEATURE_OFFSET within the borrowed buffer.
                unsafe {
                    let bools_ptr = structures
                        .as_ptr()
                        .add(FIRST_FEATURE_OFFSET);
                    return std::slice::from_raw_parts(
                        bools_ptr as *const vk::VkBool32,
                        extension.features.len()
                    );
                }
            }

            structures = get_next_structure(structures);
        }

        unreachable!("No structure found for extension “{}”", extension.name());
    }

    /// Returns whether `ext` is one of the names reported by
    /// `Requirements::c_extensions`.
    ///
    /// # Safety
    ///
    /// Every pointer returned by `c_extensions` must point to a valid
    /// NUL-terminated string, as required by `CStr::from_ptr`.
    unsafe fn extension_in_c_extensions(
        reqs: &mut Requirements,
        ext: &str
    ) -> bool {
        for &p in reqs.c_extensions().iter() {
            if CStr::from_ptr(p as *const c_char).to_str().unwrap() == ext {
                return true;
            }
        }

        false
    }

    #[test]
    fn test_all_features() {
        let mut reqs = Requirements::new();

        for extension in EXTENSIONS.iter() {
            for feature in extension.features.iter() {
                reqs.add(feature);
            }
        }

        for feature in BASE_FEATURES.iter() {
            reqs.add(feature);
        }

        for (extension_num, extension) in EXTENSIONS.iter().enumerate() {
            // All of the extensions should be in the set
            assert!(reqs.extensions.contains(extension.name()));
            // All of the features of every extension should be true
            assert!(reqs.features[&extension_num].iter().all(|&b| b));
        }

        // All of the base features should be enabled
        assert!(reqs.base_features.iter().all(|&b| b));

        let base_features = reqs.c_base_features();

        assert_eq!(base_features.robustBufferAccess, 1);
        assert_eq!(base_features.fullDrawIndexUint32, 1);
        assert_eq!(base_features.imageCubeArray, 1);
        assert_eq!(base_features.independentBlend, 1);
        assert_eq!(base_features.geometryShader, 1);
        assert_eq!(base_features.tessellationShader, 1);
        assert_eq!(base_features.sampleRateShading, 1);
        assert_eq!(base_features.dualSrcBlend, 1);
        assert_eq!(base_features.logicOp, 1);
        assert_eq!(base_features.multiDrawIndirect, 1);
        assert_eq!(base_features.drawIndirectFirstInstance, 1);
        assert_eq!(base_features.depthClamp, 1);
        assert_eq!(base_features.depthBiasClamp, 1);
        assert_eq!(base_features.fillModeNonSolid, 1);
        assert_eq!(base_features.depthBounds, 1);
        assert_eq!(base_features.wideLines, 1);
        assert_eq!(base_features.largePoints, 1);
        assert_eq!(base_features.alphaToOne, 1);
        assert_eq!(base_features.multiViewport, 1);
        assert_eq!(base_features.samplerAnisotropy, 1);
        assert_eq!(base_features.textureCompressionETC2, 1);
        assert_eq!(base_features.textureCompressionASTC_LDR, 1);
        assert_eq!(base_features.textureCompressionBC, 1);
        assert_eq!(base_features.occlusionQueryPrecise, 1);
        assert_eq!(base_features.pipelineStatisticsQuery, 1);
        assert_eq!(base_features.vertexPipelineStoresAndAtomics, 1);
        assert_eq!(base_features.fragmentStoresAndAtomics, 1);
        assert_eq!(base_features.shaderTessellationAndGeometryPointSize, 1);
        assert_eq!(base_features.shaderImageGatherExtended, 1);
        assert_eq!(base_features.shaderStorageImageExtendedFormats, 1);
        assert_eq!(base_features.shaderStorageImageMultisample, 1);
        assert_eq!(base_features.shaderStorageImageReadWithoutFormat, 1);
        assert_eq!(base_features.shaderStorageImageWriteWithoutFormat, 1);
        assert_eq!(base_features.shaderUniformBufferArrayDynamicIndexing, 1);
        assert_eq!(base_features.shaderSampledImageArrayDynamicIndexing, 1);
        assert_eq!(base_features.shaderStorageBufferArrayDynamicIndexing, 1);
        assert_eq!(base_features.shaderStorageImageArrayDynamicIndexing, 1);
        assert_eq!(base_features.shaderClipDistance, 1);
        assert_eq!(base_features.shaderCullDistance, 1);
        assert_eq!(base_features.shaderFloat64, 1);
        assert_eq!(base_features.shaderInt64, 1);
        assert_eq!(base_features.shaderInt16, 1);
        assert_eq!(base_features.shaderResourceResidency, 1);
        assert_eq!(base_features.shaderResourceMinLod, 1);
        assert_eq!(base_features.sparseBinding, 1);
        assert_eq!(base_features.sparseResidencyBuffer, 1);
        assert_eq!(base_features.sparseResidencyImage2D, 1);
        assert_eq!(base_features.sparseResidencyImage3D, 1);
        assert_eq!(base_features.sparseResidency2Samples, 1);
        assert_eq!(base_features.sparseResidency4Samples, 1);
        assert_eq!(base_features.sparseResidency8Samples, 1);
        assert_eq!(base_features.sparseResidency16Samples, 1);
        assert_eq!(base_features.sparseResidencyAliased, 1);
        assert_eq!(base_features.variableMultisampleRate, 1);
        assert_eq!(base_features.inheritedQueries, 1);

        // All of the values should be set in the C structs
        for extension in EXTENSIONS.iter() {
            let structs = reqs.c_structures().unwrap();

            assert!(
                find_bools_for_extension(structs, extension)
                    .iter()
                    .all(|&b| b == 1)
            );

            assert!(unsafe { &*reqs.lazy_structures.get() }.is_some());
        }

        // All of the extensions should be in the c_extensions
        for extension in EXTENSIONS.iter() {
            assert!(unsafe {
                extension_in_c_extensions(&mut reqs, extension.name())
            });
            assert!(unsafe { &*reqs.lazy_extensions.get() }.is_some());
        }

        // Sanity check that a made-up extension isn’t in c_extensions
        assert!(!unsafe {
            extension_in_c_extensions(&mut reqs, "not_a_real_ext")
        });
    }

    #[test]
    fn test_version() {
        let mut reqs = Requirements::new();
        reqs.add_version(2, 1, 5);
        assert_eq!(reqs.version(), 0x801005);
    }

    #[test]
    fn test_empty() {
        let reqs = Requirements::new();

        assert!(unsafe { &*reqs.lazy_extensions.get() }.is_none());
        assert!(unsafe { &*reqs.lazy_structures.get() }.is_none());

        let base_features_ptr = reqs.c_base_features()
            as *const vk::VkPhysicalDeviceFeatures
            as *const vk::VkBool32;

        // SAFETY: VkPhysicalDeviceFeatures is treated here as
        // N_BASE_FEATURES consecutive VkBool32 fields.
        unsafe {
            let base_features = std::slice::from_raw_parts(
                base_features_ptr,
                N_BASE_FEATURES
            );

            assert!(base_features.iter().all(|&b| b == 0));
        }
    }

    #[test]
    fn test_eq() {
        let mut reqs_a = Requirements::new();
        let mut reqs_b = Requirements::new();

        assert_eq!(reqs_a, reqs_b);

        reqs_a.add("advancedBlendCoherentOperations");
        assert_ne!(reqs_a, reqs_b);

        reqs_b.add("advancedBlendCoherentOperations");
        assert_eq!(reqs_a, reqs_b);

        // Getting the C data shouldn’t affect the equality
        reqs_a.c_structures();
        reqs_a.c_base_features();
        reqs_a.c_extensions();

        assert_eq!(reqs_a, reqs_b);

        // The order of adding shouldn’t matter
        reqs_a.add("fake_extension");
        reqs_a.add("another_fake_extension");

        reqs_b.add("another_fake_extension");
        reqs_b.add("fake_extension");

        assert_eq!(reqs_a, reqs_b);

        reqs_a.add("wideLines");
        assert_ne!(reqs_a, reqs_b);

        reqs_b.add("wideLines");
        assert_eq!(reqs_a, reqs_b);

        reqs_a.add_version(3, 1, 2);
        assert_ne!(reqs_a, reqs_b);

        reqs_b.add_version(3, 1, 2);
        assert_eq!(reqs_a, reqs_b);
    }

    #[test]
    fn test_clone() {
        let mut reqs = Requirements::new();

        reqs.add("wideLines");
        reqs.add("fake_extension");
        reqs.add("advancedBlendCoherentOperations");
        assert_eq!(reqs.clone(), reqs);

        assert!(unsafe { &*reqs.lazy_structures.get() }.is_none());
        assert_eq!(reqs.c_extensions().len(), 2);
        assert_eq!(reqs.c_base_features().wideLines, 1);

        let empty = Requirements::new();

        reqs.clone_from(&empty);

        assert!(unsafe { &*reqs.lazy_structures.get() }.is_none());
        assert_eq!(reqs.c_extensions().len(), 0);
        assert_eq!(reqs.c_base_features().wideLines, 0);

        assert_eq!(reqs, empty);
    }

    /// Owns a fake Vulkan driver with a single physical device plus the
    /// instance created on it. The instance is destroyed on drop.
    struct FakeVulkanData {
        fake_vulkan: Box<fake_vulkan::FakeVulkan>,
        _vklib: vulkan_funcs::Library,
        vkinst: vulkan_funcs::Instance,
        instance: vk::VkInstance,
        device: vk::VkPhysicalDevice,
    }

    impl FakeVulkanData {
        fn new() -> FakeVulkanData {
            let mut fake_vulkan = fake_vulkan::FakeVulkan::new();

            fake_vulkan.physical_devices.push(Default::default());

            fake_vulkan.set_override();
            let vklib = vulkan_funcs::Library::new().unwrap();

            let mut instance = std::ptr::null_mut();

            unsafe {
                let res = vklib.vkCreateInstance.unwrap()(
                    std::ptr::null(), // pCreateInfo
                    std::ptr::null(), // pAllocator
                    std::ptr::addr_of_mut!(instance),
                );
                assert_eq!(res, vk::VK_SUCCESS);
            }

            // Resolves instance functions through the boxed FakeVulkan
            // passed as `user_data`.
            extern "C" fn get_instance_proc(
                func_name: *const c_char,
                user_data: *const c_void,
            ) -> *const c_void {
                unsafe {
                    let fake_vulkan =
                        &*user_data.cast::<fake_vulkan::FakeVulkan>();
                    std::mem::transmute(
                        fake_vulkan.get_function(func_name.cast())
                    )
                }
            }

            let vkinst = unsafe {
                vulkan_funcs::Instance::new(
                    get_instance_proc,
                    fake_vulkan.as_ref()
                        as *const fake_vulkan::FakeVulkan
                        as *const c_void,
                )
            };

            let mut device = std::ptr::null_mut();

            unsafe {
                let mut count = 1;
                let res = vkinst.vkEnumeratePhysicalDevices.unwrap()(
                    instance,
                    std::ptr::addr_of_mut!(count),
                    std::ptr::addr_of_mut!(device),
                );

                assert_eq!(res, vk::VK_SUCCESS);
            }

            FakeVulkanData {
                fake_vulkan,
                _vklib: vklib,
                vkinst,
                instance,
                device,
            }
        }
    }

    impl Drop for FakeVulkanData {
        fn drop(&mut self) {
            // Skip cleanup while unwinding so that a failed assertion is not
            // masked by a second panic.
            if !std::thread::panicking() {
                unsafe {
                    self.vkinst.vkDestroyInstance.unwrap()(
                        self.instance,
                        std::ptr::null(), // allocator
                    );
                }
            }
        }
    }

    /// Runs `reqs.check` against a fake device exposing `features`.
    fn check_base_features(
        reqs: &Requirements,
        features: &vk::VkPhysicalDeviceFeatures
    ) -> Result<(), Error> {
        let mut data = FakeVulkanData::new();

        data.fake_vulkan.physical_devices[0].features = features.clone();

        reqs.check(&data.vkinst, data.device)
    }

    #[test]
    fn test_check_base_features() {
        let mut features = Default::default();

        assert!(matches!(
            check_base_features(&Requirements::new(), &features),
            Ok(()),
        ));

        features.geometryShader = vk::VK_TRUE;

        assert!(matches!(
            check_base_features(&Requirements::new(), &features),
            Ok(()),
        ));

        let mut reqs = Requirements::new();
        reqs.add("geometryShader");
        assert!(matches!(
            check_base_features(&reqs, &features),
            Ok(()),
        ));

        reqs.add("depthBounds");
        match check_base_features(&reqs, &features) {
            Ok(()) => unreachable!("Requirements::check was supposed to fail"),
            Err(e) => {
                assert!(matches!(e, Error::MissingBaseFeature(_)));
                assert_eq!(
                    e.to_string(),
                    "Missing required feature: depthBounds"
                );
                assert_eq!(e.result(), result::Result::Skip);
            },
        }
    }

    #[test]
    fn test_check_extensions() {
        let mut data = FakeVulkanData::new();
        let mut reqs = Requirements::new();

        assert!(matches!(
            reqs.check(&data.vkinst, data.device),
            Ok(()),
        ));

        reqs.add("fake_extension");

        match reqs.check(&data.vkinst, data.device) {
            Ok(()) => unreachable!("expected extensions check to fail"),
            Err(e) => {
                assert!(matches!(
                    e,
                    Error::MissingExtension(_),
                ));
                assert_eq!(
                    e.to_string(),
                    "Missing required extension: fake_extension"
                );
                assert_eq!(e.result(), result::Result::Skip);
            },
        };

        data.fake_vulkan.physical_devices[0].add_extension("fake_extension");

        assert!(matches!(
            reqs.check(&data.vkinst, data.device),
            Ok(()),
        ));

        // Add an extension via a feature
        reqs.add("multiviewGeometryShader");

        match reqs.check(&data.vkinst, data.device) {
            Ok(()) => unreachable!("expected extensions check to fail"),
            Err(e) => {
                assert!(matches!(
                    e,
                    Error::MissingExtension(_)
                ));
                assert_eq!(
                    e.to_string(),
                    "Missing required extension: VK_KHR_multiview",
                );
                assert_eq!(e.result(), result::Result::Skip);
            },
        };

        data.fake_vulkan.physical_devices[0].add_extension("VK_KHR_multiview");

        match reqs.check(&data.vkinst, data.device) {
            Ok(()) => unreachable!("expected extensions check to fail"),
            Err(e) => {
                assert!(matches!(
                    e,
                    Error::MissingFeature { .. },
                ));
                assert_eq!(
                    e.to_string(),
                    "Missing required feature “multiviewGeometryShader” from \
                     extension “VK_KHR_multiview”",
                );
                assert_eq!(e.result(), result::Result::Skip);
            },
        };

        // Make an unterminated UTF-8 character
        let extension_name = &mut data
            .fake_vulkan
            .physical_devices[0]
            .extensions[0]
            .extensionName;
        extension_name[0] = -1i8 as c_char;
        extension_name[1] = 0;

        match reqs.check(&data.vkinst, data.device) {
            Ok(()) => unreachable!("expected extensions check to fail"),
            Err(e) => {
                assert_eq!(
                    e.to_string(),
                    "Invalid UTF-8 in string returned from \
                     vkEnumerateDeviceExtensionProperties"
                );
                assert!(matches!(e, Error::ExtensionInvalidUtf8));
                assert_eq!(e.result(), result::Result::Fail);
            },
        };

        // No null-terminator in the extension
        extension_name.fill(32);

        match reqs.check(&data.vkinst, data.device) {
            Ok(()) => unreachable!("expected extensions check to fail"),
            Err(e) => {
                assert_eq!(
                    e.to_string(),
                    "NULL terminator missing in string returned from \
                     vkEnumerateDeviceExtensionProperties"
                );
                assert!(matches!(e, Error::ExtensionMissingNullTerminator));
                assert_eq!(e.result(), result::Result::Fail);
            },
        };
    }

    #[test]
    fn test_check_structures() {
        let mut data = FakeVulkanData::new();
        let mut reqs = Requirements::new();

        reqs.add("multiview");
        data.fake_vulkan.physical_devices[0].add_extension("VK_KHR_multiview");

        match reqs.check(&data.vkinst, data.device) {
            Ok(()) => unreachable!("expected features check to fail"),
            Err(e) => {
                assert_eq!(
                    e.to_string(),
                    "Missing required feature “multiview” \
                     from extension “VK_KHR_multiview”",
                );
                assert!(matches!(
                    e,
                    Error::MissingFeature { .. },
                ));
                assert_eq!(e.result(), result::Result::Skip);
            },
        };

        data.fake_vulkan.physical_devices[0].multiview.multiview = vk::VK_TRUE;

        assert!(matches!(
            reqs.check(&data.vkinst, data.device),
            Ok(()),
        ));

        reqs.add("shaderBufferInt64Atomics");
        data.fake_vulkan.physical_devices[0].add_extension(
            "VK_KHR_shader_atomic_int64"
        );

        match reqs.check(&data.vkinst, data.device) {
            Ok(()) => unreachable!("expected features check to fail"),
            Err(e) => {
                assert_eq!(
                    e.to_string(),
                    "Missing required feature “shaderBufferInt64Atomics” \
                     from extension “VK_KHR_shader_atomic_int64”",
                );
                assert!(matches!(
                    e,
                    Error::MissingFeature { .. },
                ));
                assert_eq!(e.result(), result::Result::Skip);
            },
        };

        data.fake_vulkan
            .physical_devices[0]
            .shader_atomic
            .shaderBufferInt64Atomics =
            vk::VK_TRUE;

        assert!(matches!(
            reqs.check(&data.vkinst, data.device),
            Ok(()),
        ));
    }

    #[test]
    fn test_check_version() {
        let mut data = FakeVulkanData::new();
        let mut reqs = Requirements::new();

        reqs.add_version(1, 2, 0);
        data.fake_vulkan.physical_devices[0].properties.apiVersion =
            make_version(1, 1, 0);

        match reqs.check(&data.vkinst, data.device) {
            Ok(()) => unreachable!("expected version check to fail"),
            Err(e) => {
                // The error reports the apiVersion from the physical
                // device properties, which was set to 1.1.0 above.
                assert_eq!(
                    e.to_string(),
                    "Vulkan API version 1.2.0 required but the driver \
                     reported 1.1.0",
                );
                assert!(matches!(
                    e,
                    Error::VersionTooLow { .. },
                ));
                assert_eq!(e.result(), result::Result::Skip);
            },
        };

        // Set a valid version
        data.fake_vulkan.physical_devices[0].properties.apiVersion =
            make_version(1, 3, 0);

        assert!(matches!(
            reqs.check(&data.vkinst, data.device),
            Ok(()),
        ));
    }
}

```
Page 4/7FirstPrevNextLast