This is page 4 of 7. Use http://codebase.md/mehmetoguzderin/shaderc-vkrunner-mcp?lines=false&page={x} to view the full context. # Directory Structure ``` ├── .devcontainer │ ├── devcontainer.json │ ├── docker-compose.yml │ └── Dockerfile ├── .gitattributes ├── .github │ └── workflows │ └── build-push-image.yml ├── .gitignore ├── .vscode │ └── mcp.json ├── Cargo.lock ├── Cargo.toml ├── Dockerfile ├── LICENSE ├── README.adoc ├── shaderc-vkrunner-mcp.jpg ├── src │ └── main.rs └── vkrunner ├── .editorconfig ├── .gitignore ├── .gitlab-ci.yml ├── build.rs ├── Cargo.toml ├── COPYING ├── examples │ ├── compute-shader.shader_test │ ├── cooperative-matrix.shader_test │ ├── depth-buffer.shader_test │ ├── desc_set_and_binding.shader_test │ ├── entrypoint.shader_test │ ├── float-framebuffer.shader_test │ ├── frexp.shader_test │ ├── geometry.shader_test │ ├── indices.shader_test │ ├── layouts.shader_test │ ├── properties.shader_test │ ├── push-constants.shader_test │ ├── require-subgroup-size.shader_test │ ├── row-major.shader_test │ ├── spirv.shader_test │ ├── ssbo.shader_test │ ├── tolerance.shader_test │ ├── tricolore.shader_test │ ├── ubo.shader_test │ ├── vertex-data-piglit.shader_test │ └── vertex-data.shader_test ├── include │ ├── vk_video │ │ ├── vulkan_video_codec_av1std_decode.h │ │ ├── vulkan_video_codec_av1std_encode.h │ │ ├── vulkan_video_codec_av1std.h │ │ ├── vulkan_video_codec_h264std_decode.h │ │ ├── vulkan_video_codec_h264std_encode.h │ │ ├── vulkan_video_codec_h264std.h │ │ ├── vulkan_video_codec_h265std_decode.h │ │ ├── vulkan_video_codec_h265std_encode.h │ │ ├── vulkan_video_codec_h265std.h │ │ └── vulkan_video_codecs_common.h │ └── vulkan │ ├── vk_platform.h │ ├── vulkan_core.h │ └── vulkan.h ├── precompile-script.py ├── README.md ├── scripts │ └── update-vulkan.sh ├── src │ └── main.rs ├── test-build.sh └── vkrunner ├── allocate_store.rs ├── buffer.rs ├── compiler │ └── fake_process.rs ├── compiler.rs ├── config.rs ├── context.rs ├── enum_table.rs ├── 
env_var_test.rs ├── executor.rs ├── fake_vulkan.rs ├── features.rs ├── flush_memory.rs ├── format_table.rs ├── format.rs ├── half_float.rs ├── hex.rs ├── inspect.rs ├── lib.rs ├── logger.rs ├── make-enums.py ├── make-features.py ├── make-formats.py ├── make-pipeline-key-data.py ├── make-vulkan-funcs-data.py ├── parse_num.rs ├── pipeline_key_data.rs ├── pipeline_key.rs ├── pipeline_set.rs ├── requirements.rs ├── result.rs ├── script.rs ├── shader_stage.rs ├── slot.rs ├── small_float.rs ├── source.rs ├── stream.rs ├── temp_file.rs ├── tester.rs ├── tolerance.rs ├── util.rs ├── vbo.rs ├── vk.rs ├── vulkan_funcs_data.rs ├── vulkan_funcs.rs ├── window_format.rs └── window.rs ``` # Files -------------------------------------------------------------------------------- /vkrunner/vkrunner/pipeline_set.rs: -------------------------------------------------------------------------------- ```rust // vkrunner // // Copyright (C) 2018 Intel Corporation // Copypright 2023 Neil Roberts // // Permission is hereby granted, free of charge, to any person obtaining a // copy of this software and associated documentation files (the "Software"), // to deal in the Software without restriction, including without limitation // the rights to use, copy, modify, merge, publish, distribute, sublicense, // and/or sell copies of the Software, and to permit persons to whom the // Software is furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice (including the next // paragraph) shall be included in all copies or substantial portions of the // Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL
// THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.

use crate::window::Window;
use crate::compiler;
use crate::shader_stage;
use crate::vk;
use crate::script::{Script, Buffer, BufferType, Operation};
use crate::pipeline_key;
use crate::logger::Logger;
use crate::vbo::Vbo;
use std::rc::Rc;
use std::ptr;
use std::mem;
use std::fmt;

/// The Vulkan objects built for a [`Script`] by [`PipelineSet::new`].
///
/// It holds the compiled shader modules, a pipeline cache, the
/// descriptor pool and set layouts (only when the script has buffers),
/// a single pipeline layout, and one pipeline per pipeline key of the
/// script.
///
/// Rust drops struct fields in declaration order. The pipelines are
/// therefore destroyed first, before the layout, descriptor objects,
/// pipeline cache and shader modules they were created from.
#[derive(Debug)]
pub struct PipelineSet {
    pipelines: PipelineVec,
    layout: PipelineLayout,
    // The descriptor data is only created if there are buffers in the
    // script
    descriptor_data: Option<DescriptorData>,
    // Bitwise OR of the flag of every stage that has a shader in the
    // script
    stages: vk::VkShaderStageFlagBits,
    // These are never read but they should probably be kept alive
    // until the pipelines are destroyed.
    _pipeline_cache: PipelineCache,
    // Indexed by `shader_stage::Stage as usize`; `None` for stages
    // without a shader
    _modules: [Option<ShaderModule>; shader_stage::N_STAGES],
}

/// One vertex of the geometry used for rectangle draws
/// (`pipeline_key::Source::Rectangle`).
///
/// Its size is used as the vertex stride, and the three floats are read
/// as a single `VK_FORMAT_R32G32B32_SFLOAT` attribute at location 0 and
/// offset 0. `repr(C)` keeps the fields packed in declaration order so
/// that description matches the memory layout.
#[repr(C)]
#[derive(Debug)]
pub struct RectangleVertex {
    pub x: f32,
    pub y: f32,
    pub z: f32,
}

/// An error that can be returned by [PipelineSet::new].
#[derive(Debug)] pub enum Error { /// Compiling one of the shaders in the script failed CompileError(compiler::Error), /// vkCreatePipelineCache failed CreatePipelineCacheFailed, /// vkCreateDescriptorPool failed CreateDescriptorPoolFailed, /// vkCreateDescriptorSetLayout failed CreateDescriptorSetLayoutFailed, /// vkCreatePipelineLayout failed CreatePipelineLayoutFailed, /// vkCreatePipeline failed CreatePipelineFailed, } impl From<compiler::Error> for Error { fn from(e: compiler::Error) -> Error { Error::CompileError(e) } } impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { match self { Error::CompileError(e) => e.fmt(f), Error::CreatePipelineCacheFailed => { write!(f, "vkCreatePipelineCache failed") }, Error::CreateDescriptorPoolFailed => { write!(f, "vkCreateDescriptorPool failed") }, Error::CreateDescriptorSetLayoutFailed => { write!(f, "vkCreateDescriptorSetLayout failed") }, Error::CreatePipelineLayoutFailed => { write!(f, "vkCreatePipelineLayout failed") }, Error::CreatePipelineFailed => { write!(f, "Pipeline creation function failed") }, } } } #[derive(Debug)] struct ShaderModule { handle: vk::VkShaderModule, // needed for the destructor window: Rc<Window>, } impl Drop for ShaderModule { fn drop(&mut self) { unsafe { self.window.device().vkDestroyShaderModule.unwrap()( self.window.vk_device(), self.handle, ptr::null(), // allocator ); } } } #[derive(Debug)] struct PipelineCache { handle: vk::VkPipelineCache, // needed for the destructor window: Rc<Window>, } impl Drop for PipelineCache { fn drop(&mut self) { unsafe { self.window.device().vkDestroyPipelineCache.unwrap()( self.window.vk_device(), self.handle, ptr::null(), // allocator ); } } } impl PipelineCache { fn new(window: Rc<Window>) -> Result<PipelineCache, Error> { let create_info = vk::VkPipelineCacheCreateInfo { sType: vk::VK_STRUCTURE_TYPE_PIPELINE_CACHE_CREATE_INFO, flags: 0, pNext: ptr::null(), initialDataSize: 0, pInitialData: ptr::null(), }; let 
mut handle = vk::null_handle(); let res = unsafe { window.device().vkCreatePipelineCache.unwrap()( window.vk_device(), ptr::addr_of!(create_info), ptr::null(), // allocator ptr::addr_of_mut!(handle), ) }; if res == vk::VK_SUCCESS { Ok(PipelineCache { handle, window }) } else { Err(Error::CreatePipelineCacheFailed) } } } #[derive(Debug)] struct DescriptorPool { handle: vk::VkDescriptorPool, // needed for the destructor window: Rc<Window>, } impl Drop for DescriptorPool { fn drop(&mut self) { unsafe { self.window.device().vkDestroyDescriptorPool.unwrap()( self.window.vk_device(), self.handle, ptr::null(), // allocator ); } } } impl DescriptorPool { fn new( window: Rc<Window>, buffers: &[Buffer], ) -> Result<DescriptorPool, Error> { let mut n_ubos = 0; let mut n_ssbos = 0; for buffer in buffers { match buffer.buffer_type { BufferType::Ubo => n_ubos += 1, BufferType::Ssbo => n_ssbos += 1, } } let mut pool_sizes = Vec::<vk::VkDescriptorPoolSize>::new(); if n_ubos > 0 { pool_sizes.push(vk::VkDescriptorPoolSize { type_: vk::VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, descriptorCount: n_ubos, }); } if n_ssbos > 0 { pool_sizes.push(vk::VkDescriptorPoolSize { type_: vk::VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, descriptorCount: n_ssbos, }); } // The descriptor pool shouldn’t have been created if there // were no buffers assert!(!pool_sizes.is_empty()); let create_info = vk::VkDescriptorPoolCreateInfo { sType: vk::VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO, flags: vk::VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT, pNext: ptr::null(), maxSets: n_desc_sets(buffers) as u32, poolSizeCount: pool_sizes.len() as u32, pPoolSizes: pool_sizes.as_ptr(), }; let mut handle = vk::null_handle(); let res = unsafe { window.device().vkCreateDescriptorPool.unwrap()( window.vk_device(), ptr::addr_of!(create_info), ptr::null(), // allocator ptr::addr_of_mut!(handle), ) }; if res == vk::VK_SUCCESS { Ok(DescriptorPool { handle, window }) } else { Err(Error::CreateDescriptorPoolFailed) } } } #[derive(Debug)] 
struct DescriptorSetLayoutVec { handles: Vec<vk::VkDescriptorSetLayout>, // needed for the destructor window: Rc<Window>, } impl Drop for DescriptorSetLayoutVec { fn drop(&mut self) { for &handle in self.handles.iter() { unsafe { self.window.device().vkDestroyDescriptorSetLayout.unwrap()( self.window.vk_device(), handle, ptr::null(), // allocator ); } } } } impl DescriptorSetLayoutVec { fn new(window: Rc<Window>) -> DescriptorSetLayoutVec { DescriptorSetLayoutVec { window, handles: Vec::new() } } fn len(&self) -> usize { self.handles.len() } fn as_ptr(&self) -> *const vk::VkDescriptorSetLayout { self.handles.as_ptr() } fn add( &mut self, bindings: &[vk::VkDescriptorSetLayoutBinding], ) -> Result<(), Error> { let create_info = vk::VkDescriptorSetLayoutCreateInfo { sType: vk::VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO, flags: 0, pNext: ptr::null(), bindingCount: bindings.len() as u32, pBindings: bindings.as_ptr(), }; let mut handle = vk::null_handle(); let res = unsafe { self.window.device().vkCreateDescriptorSetLayout.unwrap()( self.window.vk_device(), ptr::addr_of!(create_info), ptr::null(), // allocator ptr::addr_of_mut!(handle), ) }; if res == vk::VK_SUCCESS { self.handles.push(handle); Ok(()) } else { Err(Error::CreateDescriptorSetLayoutFailed) } } } fn create_descriptor_set_layouts( window: Rc<Window>, buffers: &[Buffer], stages: vk::VkShaderStageFlags, ) -> Result<DescriptorSetLayoutVec, Error> { let n_desc_sets = n_desc_sets(buffers); let mut layouts = DescriptorSetLayoutVec::new(window); let mut bindings = Vec::new(); let mut buffer_num = 0; for desc_set in 0..n_desc_sets { bindings.clear(); while buffer_num < buffers.len() && buffers[buffer_num].desc_set as usize == desc_set { let buffer = &buffers[buffer_num]; bindings.push(vk::VkDescriptorSetLayoutBinding { binding: buffer.binding, descriptorType: match buffer.buffer_type { BufferType::Ubo => vk::VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, BufferType::Ssbo => vk::VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, }, 
descriptorCount: 1, stageFlags: stages, pImmutableSamplers: ptr::null(), }); buffer_num += 1; } layouts.add(&bindings)?; } assert_eq!(layouts.len(), n_desc_sets); Ok(layouts) } fn n_desc_sets(buffers: &[Buffer]) -> usize { match buffers.last() { // The number of descriptor sets is the highest used // descriptor set index + 1. The buffers are in order so the // highest one should be the last one. Some(last) => last.desc_set as usize + 1, None => 0, } } #[derive(Debug)] struct DescriptorData { pool: DescriptorPool, layouts: DescriptorSetLayoutVec, } impl DescriptorData { fn new( window: &Rc<Window>, buffers: &[Buffer], stages: vk::VkShaderStageFlagBits, ) -> Result<DescriptorData, Error> { let pool = DescriptorPool::new( Rc::clone(&window), buffers, )?; let layouts = create_descriptor_set_layouts( Rc::clone(&window), buffers, stages, )?; Ok(DescriptorData { pool, layouts }) } } fn compile_shaders( logger: &mut Logger, window: &Rc<Window>, script: &Script, show_disassembly: bool, ) -> Result<[Option<ShaderModule>; shader_stage::N_STAGES], Error> { let mut modules: [Option<ShaderModule>; shader_stage::N_STAGES] = Default::default(); for &stage in shader_stage::ALL_STAGES.iter() { if script.shaders(stage).is_empty() { continue; } modules[stage as usize] = Some(ShaderModule { handle: compiler::build_stage( logger, window.context(), script, stage, show_disassembly, )?, window: Rc::clone(window), }); } Ok(modules) } fn stage_flags(script: &Script) -> vk::VkShaderStageFlagBits { // Set a flag for each stage that has a shader in the script shader_stage::ALL_STAGES .iter() .filter_map(|&stage| if script.shaders(stage).is_empty() { None } else { Some(stage.flag()) }) .fold(0, |a, b| a | b) } fn push_constant_size(script: &Script) -> usize { script.commands() .iter() .map(|command| match &command.op { Operation::SetPushCommand { offset, data } => offset + data.len(), _ => 0, }) .max() .unwrap_or(0) } #[derive(Debug)] struct PipelineLayout { handle: vk::VkPipelineLayout, // 
needed for the destructor window: Rc<Window>, } impl Drop for PipelineLayout { fn drop(&mut self) { unsafe { self.window.device().vkDestroyPipelineLayout.unwrap()( self.window.vk_device(), self.handle, ptr::null(), // allocator ); } } } impl PipelineLayout { fn new( window: Rc<Window>, script: &Script, stages: vk::VkShaderStageFlagBits, descriptor_data: Option<&DescriptorData> ) -> Result<PipelineLayout, Error> { let mut handle = vk::null_handle(); let push_constant_range = vk::VkPushConstantRange { stageFlags: stages, offset: 0, size: push_constant_size(script) as u32, }; let mut create_info = vk::VkPipelineLayoutCreateInfo::default(); create_info.sType = vk::VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO; if push_constant_range.size > 0 { create_info.pushConstantRangeCount = 1; create_info.pPushConstantRanges = ptr::addr_of!(push_constant_range); } if let Some(descriptor_data) = descriptor_data { create_info.setLayoutCount = descriptor_data.layouts.len() as u32; create_info.pSetLayouts = descriptor_data.layouts.as_ptr(); } let res = unsafe { window.device().vkCreatePipelineLayout.unwrap()( window.vk_device(), ptr::addr_of!(create_info), ptr::null(), // allocator ptr::addr_of_mut!(handle), ) }; if res == vk::VK_SUCCESS { Ok(PipelineLayout { handle, window }) } else { Err(Error::CreatePipelineLayoutFailed) } } } #[derive(Debug)] struct VertexInputState { create_info: vk::VkPipelineVertexInputStateCreateInfo, // These are not read put the create_info struct keeps pointers to // them so they need to be kept alive _input_bindings: Vec::<vk::VkVertexInputBindingDescription>, _attribs: Vec::<vk::VkVertexInputAttributeDescription>, } impl VertexInputState { fn new(script: &Script, key: &pipeline_key::Key) -> VertexInputState { let mut input_bindings = Vec::new(); let mut attribs = Vec::new(); match key.source() { pipeline_key::Source::Rectangle => { input_bindings.push(vk::VkVertexInputBindingDescription { binding: 0, stride: mem::size_of::<RectangleVertex>() as u32, 
inputRate: vk::VK_VERTEX_INPUT_RATE_VERTEX, }); attribs.push(vk::VkVertexInputAttributeDescription { location: 0, binding: 0, format: vk::VK_FORMAT_R32G32B32_SFLOAT, offset: 0, }); }, pipeline_key::Source::VertexData => { if let Some(vbo) = script.vertex_data() { VertexInputState::set_up_vertex_data_attribs( &mut input_bindings, &mut attribs, vbo ); } } } VertexInputState { create_info: vk::VkPipelineVertexInputStateCreateInfo { sType: vk::VK_STRUCTURE_TYPE_PIPELINE_VERTEX_INPUT_STATE_CREATE_INFO, flags: 0, pNext: ptr::null(), vertexBindingDescriptionCount: input_bindings.len() as u32, pVertexBindingDescriptions: input_bindings.as_ptr(), vertexAttributeDescriptionCount: attribs.len() as u32, pVertexAttributeDescriptions: attribs.as_ptr(), }, _input_bindings: input_bindings, _attribs: attribs, } } fn set_up_vertex_data_attribs( input_bindings: &mut Vec::<vk::VkVertexInputBindingDescription>, attribs: &mut Vec::<vk::VkVertexInputAttributeDescription>, vbo: &Vbo, ) { input_bindings.push(vk::VkVertexInputBindingDescription { binding: 0, stride: vbo.stride() as u32, inputRate: vk::VK_VERTEX_INPUT_RATE_VERTEX, }); for attrib in vbo.attribs().iter() { attribs.push(vk::VkVertexInputAttributeDescription { location: attrib.location(), binding: 0, format: attrib.format().vk_format, offset: attrib.offset() as u32, }); } } } #[derive(Debug)] struct PipelineVec { handles: Vec<vk::VkPipeline>, // needed for the destructor window: Rc<Window>, } impl Drop for PipelineVec { fn drop(&mut self) { for &handle in self.handles.iter() { unsafe { self.window.device().vkDestroyPipeline.unwrap()( self.window.vk_device(), handle, ptr::null(), // allocator ); } } } } impl PipelineVec { fn new( window: Rc<Window>, script: &Script, pipeline_cache: vk::VkPipelineCache, layout: vk::VkPipelineLayout, modules: &[Option<ShaderModule>], ) -> Result<PipelineVec, Error> { let mut vec = PipelineVec { window, handles: Vec::new() }; let mut first_graphics_pipeline: Option<vk::VkPipeline> = None; for key in 
script.pipeline_keys().iter() { let pipeline = match key.pipeline_type() { pipeline_key::Type::Graphics => { let allow_derivatives = first_graphics_pipeline.is_none() && script.pipeline_keys().len() > 1; let pipeline = PipelineVec::create_graphics_pipeline( &vec.window, script, key, pipeline_cache, layout, modules, allow_derivatives, first_graphics_pipeline, )?; first_graphics_pipeline.get_or_insert(pipeline); pipeline }, pipeline_key::Type::Compute => { PipelineVec::create_compute_pipeline( &vec.window, key, pipeline_cache, layout, modules[shader_stage::Stage::Compute as usize] .as_ref() .map(|m| m.handle) .unwrap_or(vk::null_handle()), script.requirements().required_subgroup_size, )? }, }; vec.handles.push(pipeline); } Ok(vec) } fn null_terminated_entrypoint( key: &pipeline_key::Key, stage: shader_stage::Stage, ) -> String { let mut entrypoint = key.entrypoint(stage).to_string(); entrypoint.push('\0'); entrypoint } fn create_stages( modules: &[Option<ShaderModule>], entrypoints: &[String], ) -> Vec<vk::VkPipelineShaderStageCreateInfo> { let mut stages = Vec::new(); for &stage in shader_stage::ALL_STAGES.iter() { if stage == shader_stage::Stage::Compute { continue; } let module = match &modules[stage as usize] { Some(module) => module.handle, None => continue, }; stages.push(vk::VkPipelineShaderStageCreateInfo { sType: vk::VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO, flags: 0, pNext: ptr::null(), stage: stage.flag(), module, pName: entrypoints[stage as usize].as_ptr().cast(), pSpecializationInfo: ptr::null(), }); } stages } fn create_graphics_pipeline( window: &Window, script: &Script, key: &pipeline_key::Key, pipeline_cache: vk::VkPipelineCache, layout: vk::VkPipelineLayout, modules: &[Option<ShaderModule>], allow_derivatives: bool, parent_pipeline: Option<vk::VkPipeline>, ) -> Result<vk::VkPipeline, Error> { let create_info_buf = key.to_create_info(); let mut create_info = vk::VkGraphicsPipelineCreateInfo::default(); unsafe { ptr::copy_nonoverlapping( 
create_info_buf.as_ptr().cast(), ptr::addr_of_mut!(create_info), 1, // count ); } let entrypoints = shader_stage::ALL_STAGES .iter() .map(|&s| PipelineVec::null_terminated_entrypoint(key, s)) .collect::<Vec<String>>(); let stages = PipelineVec::create_stages(modules, &entrypoints); let window_format = window.format(); let viewport = vk::VkViewport { x: 0.0, y: 0.0, width: window_format.width as f32, height: window_format.height as f32, minDepth: 0.0, maxDepth: 1.0, }; let scissor = vk::VkRect2D { offset: vk::VkOffset2D { x: 0, y: 0, }, extent: vk::VkExtent2D { width: window_format.width as u32, height: window_format.height as u32, }, }; let viewport_state = vk::VkPipelineViewportStateCreateInfo { sType: vk::VK_STRUCTURE_TYPE_PIPELINE_VIEWPORT_STATE_CREATE_INFO, flags: 0, pNext: ptr::null(), viewportCount: 1, pViewports: ptr::addr_of!(viewport), scissorCount: 1, pScissors: ptr::addr_of!(scissor), }; let multisample_state = vk::VkPipelineMultisampleStateCreateInfo { sType: vk::VK_STRUCTURE_TYPE_PIPELINE_MULTISAMPLE_STATE_CREATE_INFO, pNext: ptr::null(), flags: 0, rasterizationSamples: vk::VK_SAMPLE_COUNT_1_BIT, sampleShadingEnable: vk::VK_FALSE, minSampleShading: 0.0, pSampleMask: ptr::null(), alphaToCoverageEnable: vk::VK_FALSE, alphaToOneEnable: vk::VK_FALSE, }; create_info.pViewportState = ptr::addr_of!(viewport_state); create_info.pMultisampleState = ptr::addr_of!(multisample_state); create_info.subpass = 0; create_info.basePipelineHandle = parent_pipeline.unwrap_or(vk::null_handle()); create_info.basePipelineIndex = -1; create_info.stageCount = stages.len() as u32; create_info.pStages = stages.as_ptr(); create_info.layout = layout; create_info.renderPass = window.render_passes()[0]; if allow_derivatives { create_info.flags |= vk::VK_PIPELINE_CREATE_ALLOW_DERIVATIVES_BIT; } if parent_pipeline.is_some() { create_info.flags |= vk::VK_PIPELINE_CREATE_DERIVATIVE_BIT; } if modules[shader_stage::Stage::TessCtrl as usize].is_none() && 
modules[shader_stage::Stage::TessEval as usize].is_none() { create_info.pTessellationState = ptr::null(); } let vertex_input_state = VertexInputState::new(script, key); create_info.pVertexInputState = ptr::addr_of!(vertex_input_state.create_info); let mut handle = vk::null_handle(); let res = unsafe { window.device().vkCreateGraphicsPipelines.unwrap()( window.vk_device(), pipeline_cache, 1, // nCreateInfos ptr::addr_of!(create_info), ptr::null(), // allocator ptr::addr_of_mut!(handle), ) }; if res == vk::VK_SUCCESS { Ok(handle) } else { Err(Error::CreatePipelineFailed) } } fn create_compute_pipeline( window: &Window, key: &pipeline_key::Key, pipeline_cache: vk::VkPipelineCache, layout: vk::VkPipelineLayout, module: vk::VkShaderModule, required_subgroup_size: Option<u32>, ) -> Result<vk::VkPipeline, Error> { let entrypoint = PipelineVec::null_terminated_entrypoint( key, shader_stage::Stage::Compute, ); let rss_info = required_subgroup_size.map(|size| vk::VkPipelineShaderStageRequiredSubgroupSizeCreateInfo { sType: vk::VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_REQUIRED_SUBGROUP_SIZE_CREATE_INFO, pNext: ptr::null_mut(), requiredSubgroupSize: size, }); let create_info = vk::VkComputePipelineCreateInfo { sType: vk::VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO, pNext: ptr::null(), flags: 0, stage: vk::VkPipelineShaderStageCreateInfo { sType: vk::VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO, pNext: if let Some(ref rss_info) = rss_info { ptr::addr_of!(*rss_info).cast() } else { ptr::null_mut() }, flags: 0, stage: vk::VK_SHADER_STAGE_COMPUTE_BIT, module, pName: entrypoint.as_ptr().cast(), pSpecializationInfo: ptr::null(), }, layout, basePipelineHandle: vk::null_handle(), basePipelineIndex: -1, }; let mut handle = vk::null_handle(); let res = unsafe { window.device().vkCreateComputePipelines.unwrap()( window.vk_device(), pipeline_cache, 1, // nCreateInfos ptr::addr_of!(create_info), ptr::null(), // allocator ptr::addr_of_mut!(handle), ) }; if res == vk::VK_SUCCESS { 
Ok(handle) } else { Err(Error::CreatePipelineFailed) } } } impl PipelineSet { pub fn new( logger: &mut Logger, window: Rc<Window>, script: &Script, show_disassembly: bool, ) -> Result<PipelineSet, Error> { let modules = compile_shaders( logger, &window, script, show_disassembly )?; let pipeline_cache = PipelineCache::new(Rc::clone(&window))?; let stages = stage_flags(script); let descriptor_data = if script.buffers().is_empty() { None } else { Some(DescriptorData::new(&window, script.buffers(), stages)?) }; let layout = PipelineLayout::new( Rc::clone(&window), script, stages, descriptor_data.as_ref(), )?; let pipelines = PipelineVec::new( Rc::clone(&window), script, pipeline_cache.handle, layout.handle, &modules, )?; Ok(PipelineSet { _modules: modules, _pipeline_cache: pipeline_cache, stages, descriptor_data, layout, pipelines, }) } pub fn descriptor_set_layouts(&self) -> &[vk::VkDescriptorSetLayout] { match self.descriptor_data.as_ref() { Some(data) => &data.layouts.handles, None => &[], } } pub fn stages(&self) -> vk::VkShaderStageFlagBits { self.stages } pub fn layout(&self) -> vk::VkPipelineLayout { self.layout.handle } pub fn pipelines(&self) -> &[vk::VkPipeline] { &self.pipelines.handles } pub fn descriptor_pool(&self) -> Option<vk::VkDescriptorPool> { self.descriptor_data.as_ref().map(|data| data.pool.handle) } } #[cfg(test)] mod test { use super::*; use crate::fake_vulkan::{FakeVulkan, HandleType, PipelineCreateInfo}; use crate::fake_vulkan::GraphicsPipelineCreateInfo; use crate::fake_vulkan::PipelineLayoutCreateInfo; use crate::context::Context; use crate::requirements::Requirements; use crate::source::Source; use crate::config::Config; #[derive(Debug)] struct TestData { pipeline_set: PipelineSet, window: Rc<Window>, fake_vulkan: Box<FakeVulkan>, } impl TestData { fn new_with_errors<F: FnOnce(&mut FakeVulkan)>( source: &str, queue_errors: F ) -> Result<TestData, Error> { let mut fake_vulkan = FakeVulkan::new(); 
fake_vulkan.physical_devices.push(Default::default()); fake_vulkan.physical_devices[0].format_properties.insert( vk::VK_FORMAT_B8G8R8A8_UNORM, vk::VkFormatProperties { linearTilingFeatures: 0, optimalTilingFeatures: vk::VK_FORMAT_FEATURE_COLOR_ATTACHMENT_BIT | vk::VK_FORMAT_FEATURE_BLIT_SRC_BIT, bufferFeatures: 0, }, ); let memory_properties = &mut fake_vulkan.physical_devices[0].memory_properties; memory_properties.memoryTypes[0].propertyFlags = vk::VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT; memory_properties.memoryTypeCount = 1; fake_vulkan.memory_requirements.memoryTypeBits = 1; fake_vulkan.set_override(); let context = Rc::new(Context::new( &Requirements::new(), None, // device_id ).unwrap()); let window = Rc::new(Window::new( Rc::clone(&context), &Default::default(), // format ).unwrap()); queue_errors(&mut fake_vulkan); let mut logger = Logger::new(None, ptr::null_mut()); let source = Source::from_string(source.to_string()); let script = Script::load(&Config::new(), &source).unwrap(); let pipeline_set = PipelineSet::new( &mut logger, Rc::clone(&window), &script, false, // show_disassembly )?; Ok(TestData { fake_vulkan, window, pipeline_set }) } fn new(source: &str) -> Result<TestData, Error> { TestData::new_with_errors(source, |_| ()) } fn graphics_create_info( &mut self, pipeline_num: usize, ) -> GraphicsPipelineCreateInfo { let pipeline = self.pipeline_set.pipelines.handles[pipeline_num]; match &self.fake_vulkan.get_handle(pipeline).data { HandleType::Pipeline(PipelineCreateInfo::Graphics(info)) => { info.clone() }, handle @ _ => { unreachable!("unexpected handle type: {:?}", handle) }, } } fn compute_create_info( &mut self, pipeline_num: usize, ) -> vk::VkComputePipelineCreateInfo { let pipeline = self.pipeline_set.pipelines.handles[pipeline_num]; match &self.fake_vulkan.get_handle(pipeline).data { HandleType::Pipeline(PipelineCreateInfo::Compute(info)) => { info.clone() }, handle @ _ => { unreachable!("unexpected handle type: {:?}", handle) }, } } fn 
shader_module_code( &mut self, stage: shader_stage::Stage, ) -> Vec<u32> { let module = self.pipeline_set._modules[stage as usize] .as_ref() .unwrap() .handle; match &self.fake_vulkan.get_handle(module).data { HandleType::ShaderModule { code } => code.clone(), _ => unreachable!("Unexpected Vulkan handle type"), } } fn descriptor_set_layout_bindings( &mut self, desc_set: usize, ) -> Vec<vk::VkDescriptorSetLayoutBinding> { let desc_set_layout = self.pipeline_set .descriptor_set_layouts() [desc_set]; match &self.fake_vulkan.get_handle(desc_set_layout).data { HandleType::DescriptorSetLayout { bindings } => { bindings.clone() }, _ => unreachable!("Unexpected Vulkan handle type"), } } fn pipeline_layout_create_info(&mut self) -> PipelineLayoutCreateInfo { let layout = self.pipeline_set.layout(); match &self.fake_vulkan.get_handle(layout).data { HandleType::PipelineLayout(create_info) => create_info.clone(), _ => unreachable!("Unexpected Vulkan handle type"), } } } #[test] fn base() { let mut test_data = TestData::new( "[vertex shader passthrough]\n\ [fragment shader]\n\ 03 02 23 07\n\ fe ca fe ca\n\ [vertex data]\n\ 0/R32G32_SFLOAT 1/R32_SFLOAT\n\ -0.5 -0.5 1.52\n\ 0.5 -0.5 1.55\n\ [compute shader]\n\ 03 02 23 07\n\ ca fe ca fe\n\ [test]\n\ ubo 0 1024\n\ ssbo 1 1024\n\ compute 1 1 1\n\ push float 6 42.0\n\ draw rect 0 0 1 1\n\ draw arrays TRIANGLE_LIST 0 2\n" ).unwrap(); assert_eq!( test_data.pipeline_set.stages(), vk::VK_SHADER_STAGE_VERTEX_BIT | vk::VK_SHADER_STAGE_FRAGMENT_BIT | vk::VK_SHADER_STAGE_COMPUTE_BIT, ); let descriptor_pool = test_data.pipeline_set.descriptor_pool(); assert!(matches!( test_data.fake_vulkan.get_handle(descriptor_pool.unwrap()).data, HandleType::DescriptorPool, )); assert_eq!(test_data.pipeline_set.pipelines().len(), 3); let compute_create_info = test_data.compute_create_info(0); assert_eq!(compute_create_info.sType, vk::VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO); let create_data = test_data.graphics_create_info(1); let create_info = 
&create_data.create_info; assert_eq!( create_info.flags, vk::VK_PIPELINE_CREATE_ALLOW_DERIVATIVES_BIT, ); assert_eq!(create_info.stageCount, 2); assert_eq!(create_info.layout, test_data.pipeline_set.layout()); assert_eq!( create_info.renderPass, test_data.window.render_passes()[0] ); assert_eq!(create_info.basePipelineHandle, vk::null_handle()); assert_eq!(create_data.bindings.len(), 1); assert_eq!(create_data.bindings[0].binding, 0); assert_eq!( create_data.bindings[0].stride as usize, mem::size_of::<RectangleVertex>(), ); assert_eq!( create_data.bindings[0].inputRate, vk::VK_VERTEX_INPUT_RATE_VERTEX, ); assert_eq!(create_data.attribs.len(), 1); assert_eq!(create_data.attribs[0].location, 0); assert_eq!(create_data.attribs[0].binding, 0); assert_eq!( create_data.attribs[0].format, vk::VK_FORMAT_R32G32B32_SFLOAT ); assert_eq!(create_data.attribs[0].offset, 0); let create_data = test_data.graphics_create_info(2); let create_info = &create_data.create_info; assert_eq!( create_info.flags, vk::VK_PIPELINE_CREATE_DERIVATIVE_BIT, ); assert_eq!(create_info.stageCount, 2); assert_eq!(create_info.layout, test_data.pipeline_set.layout()); assert_eq!( create_info.renderPass, test_data.window.render_passes()[0] ); assert_eq!( create_info.basePipelineHandle, test_data.pipeline_set.pipelines()[1] ); assert_eq!(create_data.bindings.len(), 1); assert_eq!(create_data.bindings[0].binding, 0); assert_eq!( create_data.bindings[0].stride as usize, mem::size_of::<f32>() * 3, ); assert_eq!( create_data.bindings[0].inputRate, vk::VK_VERTEX_INPUT_RATE_VERTEX, ); assert_eq!(create_data.attribs.len(), 2); assert_eq!(create_data.attribs[0].location, 0); assert_eq!(create_data.attribs[0].binding, 0); assert_eq!( create_data.attribs[0].format, vk::VK_FORMAT_R32G32_SFLOAT ); assert_eq!(create_data.attribs[0].offset, 0); assert_eq!(create_data.attribs[1].location, 1); assert_eq!(create_data.attribs[1].binding, 0); assert_eq!( create_data.attribs[1].format, vk::VK_FORMAT_R32_SFLOAT ); 
assert_eq!(create_data.attribs[1].offset, 8); let code = test_data.shader_module_code(shader_stage::Stage::Vertex); assert!(code.len() > 2); let code = test_data.shader_module_code(shader_stage::Stage::Fragment); assert_eq!(&code, &[0x07230203, 0xcafecafe]); let code = test_data.shader_module_code(shader_stage::Stage::Compute); assert_eq!(&code, &[0x07230203, 0xfecafeca]); assert_eq!(test_data.pipeline_set.descriptor_set_layouts().len(), 1); let bindings = test_data.descriptor_set_layout_bindings(0); assert_eq!(bindings.len(), 2); assert_eq!(bindings[0].binding, 0); assert_eq!( bindings[0].descriptorType, vk::VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER ); assert_eq!(bindings[0].descriptorCount, 1); assert_eq!( bindings[0].stageFlags, vk::VK_SHADER_STAGE_VERTEX_BIT | vk::VK_SHADER_STAGE_FRAGMENT_BIT | vk::VK_SHADER_STAGE_COMPUTE_BIT, ); assert_eq!(bindings[1].binding, 1); assert_eq!( bindings[1].descriptorType, vk::VK_DESCRIPTOR_TYPE_STORAGE_BUFFER ); assert_eq!(bindings[1].descriptorCount, 1); assert_eq!( bindings[1].stageFlags, vk::VK_SHADER_STAGE_VERTEX_BIT | vk::VK_SHADER_STAGE_FRAGMENT_BIT | vk::VK_SHADER_STAGE_COMPUTE_BIT, ); let layout = test_data.pipeline_layout_create_info(); assert_eq!(layout.create_info.sType, vk::VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO); assert_eq!(layout.push_constant_ranges.len(), 1); assert_eq!( layout.push_constant_ranges[0].stageFlags, vk::VK_SHADER_STAGE_VERTEX_BIT | vk::VK_SHADER_STAGE_FRAGMENT_BIT | vk::VK_SHADER_STAGE_COMPUTE_BIT, ); assert_eq!(layout.push_constant_ranges[0].offset, 0); assert_eq!( layout.push_constant_ranges[0].size as usize, 6 + mem::size_of::<f32>() ); assert_eq!( &layout.layouts, test_data.pipeline_set.descriptor_set_layouts() ); } #[test] fn descriptor_set_gaps() { let mut test_data = TestData::new( "[vertex shader passthrough]\n\ [fragment shader]\n\ 03 02 23 07\n\ fe ca fe ca\n\ [test]\n\ ubo 0 1024\n\ ssbo 2:1 1024\n\ ssbo 3:1 1024\n\ ubo 5:5 1024\n\ draw rect 0 0 1 1\n" ).unwrap(); 
assert_eq!(test_data.pipeline_set.descriptor_set_layouts().len(), 6); let bindings = test_data.descriptor_set_layout_bindings(0); assert_eq!(bindings.len(), 1); assert_eq!(bindings[0].binding, 0); assert_eq!( bindings[0].descriptorType, vk::VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER ); assert_eq!(bindings[0].descriptorCount, 1); assert_eq!( bindings[0].stageFlags, vk::VK_SHADER_STAGE_VERTEX_BIT | vk::VK_SHADER_STAGE_FRAGMENT_BIT, ); assert_eq!(test_data.descriptor_set_layout_bindings(1).len(), 0); let bindings = test_data.descriptor_set_layout_bindings(2); assert_eq!(bindings.len(), 1); assert_eq!(bindings[0].binding, 1); assert_eq!( bindings[0].descriptorType, vk::VK_DESCRIPTOR_TYPE_STORAGE_BUFFER ); assert_eq!(bindings[0].descriptorCount, 1); assert_eq!( bindings[0].stageFlags, vk::VK_SHADER_STAGE_VERTEX_BIT | vk::VK_SHADER_STAGE_FRAGMENT_BIT, ); let bindings = test_data.descriptor_set_layout_bindings(3); assert_eq!(bindings.len(), 1); assert_eq!(bindings[0].binding, 1); assert_eq!( bindings[0].descriptorType, vk::VK_DESCRIPTOR_TYPE_STORAGE_BUFFER ); assert_eq!(bindings[0].descriptorCount, 1); assert_eq!( bindings[0].stageFlags, vk::VK_SHADER_STAGE_VERTEX_BIT | vk::VK_SHADER_STAGE_FRAGMENT_BIT, ); assert_eq!(test_data.descriptor_set_layout_bindings(4).len(), 0); let bindings = test_data.descriptor_set_layout_bindings(5); assert_eq!(bindings.len(), 1); assert_eq!(bindings[0].binding, 5); assert_eq!( bindings[0].descriptorType, vk::VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER ); assert_eq!(bindings[0].descriptorCount, 1); assert_eq!( bindings[0].stageFlags, vk::VK_SHADER_STAGE_VERTEX_BIT | vk::VK_SHADER_STAGE_FRAGMENT_BIT, ); } #[test] fn compile_error() { let error = TestData::new( "[fragment shader]\n\ 12 34 56 78\n" ).unwrap_err(); assert_eq!( &error.to_string(), "The compiler or assembler generated an invalid SPIR-V binary" ); } #[test] fn create_errors() { let commands = [ "vkCreatePipelineCache", "vkCreateDescriptorPool", "vkCreateDescriptorSetLayout", "vkCreatePipelineLayout", 
]; for command in commands { let error = TestData::new_with_errors( "[test]\n\ ubo 0 10\n\ draw rect 0 0 1 1\n", |fake_vulkan| fake_vulkan.queue_result( command.to_string(), vk::VK_ERROR_UNKNOWN, ), ).unwrap_err(); assert_eq!( error.to_string(), format!("{} failed", command), ); } } #[test] fn create_pipeline_errors() { let commands = [ "vkCreateGraphicsPipelines", "vkCreateComputePipelines", ]; for command in commands { let error = TestData::new_with_errors( "[compute shader]\n\ 03 02 23 07\n\ ca fe ca fe\n\ [test]\n\ compute 1 1 1\n\ draw rect 0 0 1 1\n", |fake_vulkan| fake_vulkan.queue_result( command.to_string(), vk::VK_ERROR_UNKNOWN, ), ).unwrap_err(); assert_eq!( &error.to_string(), "Pipeline creation function failed", ); } } #[test] fn no_buffers() { let test_data = TestData::new("").unwrap(); assert!(test_data.pipeline_set.descriptor_pool().is_none()); assert_eq!(test_data.pipeline_set.descriptor_set_layouts().len(), 0); } } ``` -------------------------------------------------------------------------------- /vkrunner/vkrunner/slot.rs: -------------------------------------------------------------------------------- ```rust // vkrunner // // Copyright (C) 2018 Intel Corporation // Copyright 2023 Neil Roberts // // Permission is hereby granted, free of charge, to any person obtaining a // copy of this software and associated documentation files (the "Software"), // to deal in the Software without restriction, including without limitation // the rights to use, copy, modify, merge, publish, distribute, sublicense, // and/or sell copies of the Software, and to permit persons to whom the // Software is furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice (including the next // paragraph) shall be included in all copies or substantial portions of the // Software. 
// // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL // THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. //! Contains utilities for accessing values stored in a buffer //! according to the GLSL layout rules. These are decribed in detail //! in the OpenGL specs here: //! <https://registry.khronos.org/OpenGL/specs/gl/glspec45.core.pdf#page=159> use crate::util; use crate::tolerance::Tolerance; use crate::half_float; use std::mem; use std::convert::TryInto; use std::fmt; /// Describes which layout standard is being used. The only difference /// is that with `Std140` the offset between members of an /// array or between elements of the minor axis of a matrix is rounded /// up to be a multiple of 16 bytes. #[derive(Debug, Clone, Copy, Eq, PartialEq)] pub enum LayoutStd { Std140, Std430, } /// For matrix types, describes which axis is the major one. The /// elements of the major axis are stored consecutively in memory. For /// example, if the numbers indicate the order in memory of the /// components of a matrix, the two orderings would look like this: /// /// ```text /// Column major Row major /// [0 3 6] [0 1 2] /// [1 4 7] [3 4 5] /// [2 5 8] [6 7 8] /// ``` #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum MajorAxis { Column, Row } /// Combines the [MajorAxis] and the [LayoutStd] to provide a complete /// description of the layout options used for accessing the /// components of a type. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct Layout { pub std: LayoutStd, pub major: MajorAxis, } /// An enum representing all of the types that can be stored in a /// slot. 
These correspond to the types in GLSL. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum Type { Int = 0, UInt, Int8, UInt8, Int16, UInt16, Int64, UInt64, Float16, Float, Double, F16Vec2, F16Vec3, F16Vec4, Vec2, Vec3, Vec4, DVec2, DVec3, DVec4, IVec2, IVec3, IVec4, UVec2, UVec3, UVec4, I8Vec2, I8Vec3, I8Vec4, U8Vec2, U8Vec3, U8Vec4, I16Vec2, I16Vec3, I16Vec4, U16Vec2, U16Vec3, U16Vec4, I64Vec2, I64Vec3, I64Vec4, U64Vec2, U64Vec3, U64Vec4, Mat2, Mat2x3, Mat2x4, Mat3x2, Mat3, Mat3x4, Mat4x2, Mat4x3, Mat4, DMat2, DMat2x3, DMat2x4, DMat3x2, DMat3, DMat3x4, DMat4x2, DMat4x3, DMat4, } /// The type of a component of the types in [Type]. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum BaseType { Int, UInt, Int8, UInt8, Int16, UInt16, Int64, UInt64, Float16, Float, Double, } /// A base type combined with a slice of bytes. This implements /// [Display](std::fmt::Display) so it can be used to extract a base /// type from a byte array and present it as a human-readable string. pub struct BaseTypeInSlice<'a> { base_type: BaseType, slice: &'a [u8], } /// A type of comparison that can be used to compare values stored in /// a slot with the chosen criteria. Use the /// [compare](Comparison::compare) method to perform the comparison. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum Comparison { /// The comparison passes if the values are exactly equal. Equal, /// The comparison passes if the floating-point values are /// approximately equal using the given [Tolerance] settings. For /// integer types this is the same as [Equal](Comparison::Equal). FuzzyEqual, /// The comparison passes if the values are different. NotEqual, /// The comparison passes if all of the values in the first slot /// are less than the values in the second slot. Less, /// The comparison passes if all of the values in the first slot /// are greater than or equal to the values in the second slot. 
GreaterEqual, /// The comparison passes if all of the values in the first slot /// are greater than the values in the second slot. Greater, /// The comparison passes if all of the values in the first slot /// are less than or equal to the values in the second slot. LessEqual, } #[derive(Debug)] struct TypeInfo { base_type: BaseType, columns: usize, rows: usize, } static TYPE_INFOS: [TypeInfo; 62] = [ TypeInfo { base_type: BaseType::Int, columns: 1, rows: 1 }, // Int TypeInfo { base_type: BaseType::UInt, columns: 1, rows: 1 }, // UInt TypeInfo { base_type: BaseType::Int8, columns: 1, rows: 1 }, // Int8 TypeInfo { base_type: BaseType::UInt8, columns: 1, rows: 1 }, // UInt8 TypeInfo { base_type: BaseType::Int16, columns: 1, rows: 1 }, // Int16 TypeInfo { base_type: BaseType::UInt16, columns: 1, rows: 1 }, // UInt16 TypeInfo { base_type: BaseType::Int64, columns: 1, rows: 1 }, // Int64 TypeInfo { base_type: BaseType::UInt64, columns: 1, rows: 1 }, // UInt64 TypeInfo { base_type: BaseType::Float16, columns: 1, rows: 1 }, // Float16 TypeInfo { base_type: BaseType::Float, columns: 1, rows: 1 }, // Float TypeInfo { base_type: BaseType::Double, columns: 1, rows: 1 }, // Double TypeInfo { base_type: BaseType::Float16, columns: 1, rows: 2 }, // F16Vec2 TypeInfo { base_type: BaseType::Float16, columns: 1, rows: 3 }, // F16Vec3 TypeInfo { base_type: BaseType::Float16, columns: 1, rows: 4 }, // F16Vec4 TypeInfo { base_type: BaseType::Float, columns: 1, rows: 2 }, // Vec2 TypeInfo { base_type: BaseType::Float, columns: 1, rows: 3 }, // Vec3 TypeInfo { base_type: BaseType::Float, columns: 1, rows: 4 }, // Vec4 TypeInfo { base_type: BaseType::Double, columns: 1, rows: 2 }, // DVec2 TypeInfo { base_type: BaseType::Double, columns: 1, rows: 3 }, // DVec3 TypeInfo { base_type: BaseType::Double, columns: 1, rows: 4 }, // DVec4 TypeInfo { base_type: BaseType::Int, columns: 1, rows: 2 }, // IVec2 TypeInfo { base_type: BaseType::Int, columns: 1, rows: 3 }, // IVec3 TypeInfo { base_type: 
BaseType::Int, columns: 1, rows: 4 }, // IVec4 TypeInfo { base_type: BaseType::UInt, columns: 1, rows: 2 }, // UVec2 TypeInfo { base_type: BaseType::UInt, columns: 1, rows: 3 }, // UVec3 TypeInfo { base_type: BaseType::UInt, columns: 1, rows: 4 }, // UVec4 TypeInfo { base_type: BaseType::Int8, columns: 1, rows: 2 }, // I8Vec2 TypeInfo { base_type: BaseType::Int8, columns: 1, rows: 3 }, // I8Vec3 TypeInfo { base_type: BaseType::Int8, columns: 1, rows: 4 }, // I8Vec4 TypeInfo { base_type: BaseType::UInt8, columns: 1, rows: 2 }, // U8Vec2 TypeInfo { base_type: BaseType::UInt8, columns: 1, rows: 3 }, // U8Vec3 TypeInfo { base_type: BaseType::UInt8, columns: 1, rows: 4 }, // U8Vec4 TypeInfo { base_type: BaseType::Int16, columns: 1, rows: 2 }, // I16Vec2 TypeInfo { base_type: BaseType::Int16, columns: 1, rows: 3 }, // I16Vec3 TypeInfo { base_type: BaseType::Int16, columns: 1, rows: 4 }, // I16Vec4 TypeInfo { base_type: BaseType::UInt16, columns: 1, rows: 2 }, // U16Vec2 TypeInfo { base_type: BaseType::UInt16, columns: 1, rows: 3 }, // U16Vec3 TypeInfo { base_type: BaseType::UInt16, columns: 1, rows: 4 }, // U16Vec4 TypeInfo { base_type: BaseType::Int64, columns: 1, rows: 2 }, // I64Vec2 TypeInfo { base_type: BaseType::Int64, columns: 1, rows: 3 }, // I64Vec3 TypeInfo { base_type: BaseType::Int64, columns: 1, rows: 4 }, // I64Vec4 TypeInfo { base_type: BaseType::UInt64, columns: 1, rows: 2 }, // U64Vec2 TypeInfo { base_type: BaseType::UInt64, columns: 1, rows: 3 }, // U64Vec3 TypeInfo { base_type: BaseType::UInt64, columns: 1, rows: 4 }, // U64Vec4 TypeInfo { base_type: BaseType::Float, columns: 2, rows: 2 }, // Mat2 TypeInfo { base_type: BaseType::Float, columns: 2, rows: 3 }, // Mat2x3 TypeInfo { base_type: BaseType::Float, columns: 2, rows: 4 }, // Mat2x4 TypeInfo { base_type: BaseType::Float, columns: 3, rows: 2 }, // Mat3x2 TypeInfo { base_type: BaseType::Float, columns: 3, rows: 3 }, // Mat3 TypeInfo { base_type: BaseType::Float, columns: 3, rows: 4 }, // Mat3x4 
TypeInfo { base_type: BaseType::Float, columns: 4, rows: 2 }, // Mat4x2 TypeInfo { base_type: BaseType::Float, columns: 4, rows: 3 }, // Mat4x3 TypeInfo { base_type: BaseType::Float, columns: 4, rows: 4 }, // Mat4 TypeInfo { base_type: BaseType::Double, columns: 2, rows: 2 }, // DMat2 TypeInfo { base_type: BaseType::Double, columns: 2, rows: 3 }, // DMat2x3 TypeInfo { base_type: BaseType::Double, columns: 2, rows: 4 }, // DMat2x4 TypeInfo { base_type: BaseType::Double, columns: 3, rows: 2 }, // DMat3x2 TypeInfo { base_type: BaseType::Double, columns: 3, rows: 3 }, // DMat3 TypeInfo { base_type: BaseType::Double, columns: 3, rows: 4 }, // DMat3x4 TypeInfo { base_type: BaseType::Double, columns: 4, rows: 2 }, // DMat4x2 TypeInfo { base_type: BaseType::Double, columns: 4, rows: 3 }, // DMat4x3 TypeInfo { base_type: BaseType::Double, columns: 4, rows: 4 }, // DMat4 ]; // Mapping from GLSL type name to slot type. Sorted alphabetically so // we can do a binary search. static GLSL_TYPE_NAMES: [(&'static str, Type); 68] = [ ("dmat2", Type::DMat2), ("dmat2x2", Type::DMat2), ("dmat2x3", Type::DMat2x3), ("dmat2x4", Type::DMat2x4), ("dmat3", Type::DMat3), ("dmat3x2", Type::DMat3x2), ("dmat3x3", Type::DMat3), ("dmat3x4", Type::DMat3x4), ("dmat4", Type::DMat4), ("dmat4x2", Type::DMat4x2), ("dmat4x3", Type::DMat4x3), ("dmat4x4", Type::DMat4), ("double", Type::Double), ("dvec2", Type::DVec2), ("dvec3", Type::DVec3), ("dvec4", Type::DVec4), ("f16vec2", Type::F16Vec2), ("f16vec3", Type::F16Vec3), ("f16vec4", Type::F16Vec4), ("float", Type::Float), ("float16_t", Type::Float16), ("i16vec2", Type::I16Vec2), ("i16vec3", Type::I16Vec3), ("i16vec4", Type::I16Vec4), ("i64vec2", Type::I64Vec2), ("i64vec3", Type::I64Vec3), ("i64vec4", Type::I64Vec4), ("i8vec2", Type::I8Vec2), ("i8vec3", Type::I8Vec3), ("i8vec4", Type::I8Vec4), ("int", Type::Int), ("int16_t", Type::Int16), ("int64_t", Type::Int64), ("int8_t", Type::Int8), ("ivec2", Type::IVec2), ("ivec3", Type::IVec3), ("ivec4", Type::IVec4), 
("mat2", Type::Mat2), ("mat2x2", Type::Mat2), ("mat2x3", Type::Mat2x3), ("mat2x4", Type::Mat2x4), ("mat3", Type::Mat3), ("mat3x2", Type::Mat3x2), ("mat3x3", Type::Mat3), ("mat3x4", Type::Mat3x4), ("mat4", Type::Mat4), ("mat4x2", Type::Mat4x2), ("mat4x3", Type::Mat4x3), ("mat4x4", Type::Mat4), ("u16vec2", Type::U16Vec2), ("u16vec3", Type::U16Vec3), ("u16vec4", Type::U16Vec4), ("u64vec2", Type::U64Vec2), ("u64vec3", Type::U64Vec3), ("u64vec4", Type::U64Vec4), ("u8vec2", Type::U8Vec2), ("u8vec3", Type::U8Vec3), ("u8vec4", Type::U8Vec4), ("uint", Type::UInt), ("uint16_t", Type::UInt16), ("uint64_t", Type::UInt64), ("uint8_t", Type::UInt8), ("uvec2", Type::UVec2), ("uvec3", Type::UVec3), ("uvec4", Type::UVec4), ("vec2", Type::Vec2), ("vec3", Type::Vec3), ("vec4", Type::Vec4), ]; // Mapping from operator name to Comparison. Sorted by ASCII values so // we can do a binary search. static COMPARISON_NAMES: [(&'static str, Comparison); 7] = [ ("!=", Comparison::NotEqual), ("<", Comparison::Less), ("<=", Comparison::LessEqual), ("==", Comparison::Equal), (">", Comparison::Greater), (">=", Comparison::GreaterEqual), ("~=", Comparison::FuzzyEqual), ]; impl BaseType { /// Returns the size in bytes of a variable of this base type. pub fn size(self) -> usize { match self { BaseType::Int => mem::size_of::<i32>(), BaseType::UInt => mem::size_of::<u32>(), BaseType::Int8 => mem::size_of::<i8>(), BaseType::UInt8 => mem::size_of::<u8>(), BaseType::Int16 => mem::size_of::<i16>(), BaseType::UInt16 => mem::size_of::<u16>(), BaseType::Int64 => mem::size_of::<i64>(), BaseType::UInt64 => mem::size_of::<u64>(), BaseType::Float16 => mem::size_of::<u16>(), BaseType::Float => mem::size_of::<u32>(), BaseType::Double => mem::size_of::<u64>(), } } } impl<'a> BaseTypeInSlice<'a> { /// Combines the given base type and slice into a /// `BaseTypeInSlice` so that it can be displayed with the /// [Display](std::fmt::Display) trait. 
The slice is expected to /// contain a representation of the base type in native-endian /// order. The slice needs to have the same size as /// `base_type.size()` or the constructor will panic. pub fn new(base_type: BaseType, slice: &'a [u8]) -> BaseTypeInSlice<'a> { assert_eq!(slice.len(), base_type.size()); BaseTypeInSlice { base_type, slice } } } impl<'a> fmt::Display for BaseTypeInSlice<'a> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self.base_type { BaseType::Int => { let v = i32::from_ne_bytes(self.slice.try_into().unwrap()); write!(f, "{}", v) }, BaseType::UInt => { let v = u32::from_ne_bytes(self.slice.try_into().unwrap()); write!(f, "{}", v) }, BaseType::Int8 => { let v = i8::from_ne_bytes(self.slice.try_into().unwrap()); write!(f, "{}", v) }, BaseType::UInt8 => { let v = u8::from_ne_bytes(self.slice.try_into().unwrap()); write!(f, "{}", v) }, BaseType::Int16 => { let v = i16::from_ne_bytes(self.slice.try_into().unwrap()); write!(f, "{}", v) }, BaseType::UInt16 => { let v = u16::from_ne_bytes(self.slice.try_into().unwrap()); write!(f, "{}", v) }, BaseType::Int64 => { let v = i64::from_ne_bytes(self.slice.try_into().unwrap()); write!(f, "{}", v) }, BaseType::UInt64 => { let v = u64::from_ne_bytes(self.slice.try_into().unwrap()); write!(f, "{}", v) }, BaseType::Float16 => { let v = u16::from_ne_bytes(self.slice.try_into().unwrap()); let v = half_float::to_f64(v); write!(f, "{}", v) }, BaseType::Float => { let v = f32::from_ne_bytes(self.slice.try_into().unwrap()); write!(f, "{}", v) }, BaseType::Double => { let v = f64::from_ne_bytes(self.slice.try_into().unwrap()); write!(f, "{}", v) }, } } } impl Type { /// Returns the type of the components of this slot type. #[inline] pub fn base_type(self) -> BaseType { TYPE_INFOS[self as usize].base_type } /// Returns the number of rows in this slot type. For primitive types /// this will be 1, for vectors it will be the size of the vector /// and for matrix types it will be the number of rows. 
#[inline] pub fn rows(self) -> usize { TYPE_INFOS[self as usize].rows } /// Return the number of columns in this slot type. For non-matrix /// types this will be 1. #[inline] pub fn columns(self) -> usize { TYPE_INFOS[self as usize].columns } /// Returns whether the type is a matrix type #[inline] pub fn is_matrix(self) -> bool { self.columns() > 1 } // Returns the effective major axis setting that should be used // when calculating offsets. The setting should only affect matrix // types. fn effective_major_axis(self, layout: Layout) -> MajorAxis { if self.is_matrix() { layout.major } else { MajorAxis::Column } } // Return a tuple containing the size of the major and minor axes // (in that order) given the type and layout fn major_minor(self, layout: Layout) -> (usize, usize) { match self.effective_major_axis(layout) { MajorAxis::Column => (self.columns(), self.rows()), MajorAxis::Row => (self.rows(), self.columns()), } } /// Returns the matrix stride of the slot type. This is the offset /// in bytes between consecutive components of the minor axis. For /// example, in a column-major layout, the matrix stride of a vec4 /// will be 16, because to move to the next column you need to add /// an offset of 16 bytes to the address of the previous column. pub fn matrix_stride(self, layout: Layout) -> usize { let component_size = self.base_type().size(); let (_major, minor) = self.major_minor(layout); let base_stride = if minor == 3 { component_size * 4 } else { component_size * minor }; match layout.std { LayoutStd::Std140 => { // According to std140 the size is rounded up to a vec4 util::align(base_stride, 16) }, LayoutStd::Std430 => base_stride, } } /// Returns the offset between members of the elements of an array /// of this type of slot. There may be padding between the members /// to fulfill the alignment criteria of the layout. 
pub fn array_stride(self, layout: Layout) -> usize { let matrix_stride = self.matrix_stride(layout); let (major, _minor) = self.major_minor(layout); matrix_stride * major } /// Returns the size of the slot type. Note that unlike /// [array_stride](Type::array_stride), this won’t include the /// padding to align the values in an array according to the GLSL /// layout rules. It will however always have a size that is a /// multiple of the size of the base type so it is suitable as an /// array stride for packed values stored internally that aren’t /// passed to Vulkan. pub fn size(self, layout: Layout) -> usize { let matrix_stride = self.matrix_stride(layout); let base_size = self.base_type().size(); let (major, minor) = self.major_minor(layout); (major - 1) * matrix_stride + base_size * minor } /// Returns an iterator that iterates over the offsets of the /// components of this slot. This can be used to extract the /// values out of an array of bytes containing the slot. The /// offsets are returned in column-major order (ie, all of the /// offsets for a row are returned before moving to the next /// column), but the offsets are calculated to take into account /// the major axis setting in the layout. pub fn offsets(self, layout: Layout) -> OffsetsIter { OffsetsIter::new(self, layout) } /// Returns a [Type] given the GLSL type name, for example `dmat2` /// or `i64vec3`. `None` will be returned if the type name isn’t /// recognised. pub fn from_glsl_type(type_name: &str) -> Option<Type> { match GLSL_TYPE_NAMES.binary_search_by( |&(name, _)| name.cmp(type_name) ) { Ok(pos) => Some(GLSL_TYPE_NAMES[pos].1), Err(_) => None, } } } /// Iterator over the offsets into an array to extract the components /// of a slot. Created by [Type::offsets]. 
#[derive(Debug, Clone)] pub struct OffsetsIter { slot_type: Type, row: usize, column: usize, offset: usize, column_offset: isize, row_offset: isize, } impl OffsetsIter { fn new(slot_type: Type, layout: Layout) -> OffsetsIter { let column_offset; let row_offset; let stride = slot_type.matrix_stride(layout); let base_size = slot_type.base_type().size(); match slot_type.effective_major_axis(layout) { MajorAxis::Column => { column_offset = (stride - base_size * slot_type.rows()) as isize; row_offset = base_size as isize; }, MajorAxis::Row => { column_offset = base_size as isize - (stride * slot_type.rows()) as isize; row_offset = stride as isize; }, } OffsetsIter { slot_type, row: 0, column: 0, offset: 0, column_offset, row_offset, } } } impl Iterator for OffsetsIter { type Item = usize; fn next(&mut self) -> Option<usize> { if self.column >= self.slot_type.columns() { return None; } let result = self.offset; self.offset = ((self.offset as isize) + self.row_offset) as usize; self.row += 1; if self.row >= self.slot_type.rows() { self.row = 0; self.column += 1; self.offset = ((self.offset as isize) + self.column_offset) as usize; } Some(result) } fn size_hint(&self) -> (usize, Option<usize>) { let n_columns = self.slot_type.columns(); let size = if self.column >= n_columns { 0 } else { (self.slot_type.rows() - self.row) + ((n_columns - 1 - self.column) * self.slot_type.rows()) }; (size, Some(size)) } fn count(self) -> usize { self.size_hint().0 } } impl Comparison { fn compare_without_fuzzy<T: PartialOrd>(self, a: T, b: T) -> bool { match self { Comparison::Equal | Comparison::FuzzyEqual => a.eq(&b), Comparison::NotEqual => a.ne(&b), Comparison::Less => a.lt(&b), Comparison::GreaterEqual => a.ge(&b), Comparison::Greater => a.gt(&b), Comparison::LessEqual => a.le(&b), } } fn compare_value( self, tolerance: &Tolerance, component: usize, base_type: BaseType, a: &[u8], b: &[u8] ) -> bool { match base_type { BaseType::Int => { let a = 
i32::from_ne_bytes(a.try_into().unwrap()); let b = i32::from_ne_bytes(b.try_into().unwrap()); self.compare_without_fuzzy(a, b) }, BaseType::UInt => { let a = u32::from_ne_bytes(a.try_into().unwrap()); let b = u32::from_ne_bytes(b.try_into().unwrap()); self.compare_without_fuzzy(a, b) }, BaseType::Int8 => { let a = i8::from_ne_bytes(a.try_into().unwrap()); let b = i8::from_ne_bytes(b.try_into().unwrap()); self.compare_without_fuzzy(a, b) }, BaseType::UInt8 => { let a = u8::from_ne_bytes(a.try_into().unwrap()); let b = u8::from_ne_bytes(b.try_into().unwrap()); self.compare_without_fuzzy(a, b) }, BaseType::Int16 => { let a = i16::from_ne_bytes(a.try_into().unwrap()); let b = i16::from_ne_bytes(b.try_into().unwrap()); self.compare_without_fuzzy(a, b) }, BaseType::UInt16 => { let a = u16::from_ne_bytes(a.try_into().unwrap()); let b = u16::from_ne_bytes(b.try_into().unwrap()); self.compare_without_fuzzy(a, b) }, BaseType::Int64 => { let a = i64::from_ne_bytes(a.try_into().unwrap()); let b = i64::from_ne_bytes(b.try_into().unwrap()); self.compare_without_fuzzy(a, b) }, BaseType::UInt64 => { let a = u64::from_ne_bytes(a.try_into().unwrap()); let b = u64::from_ne_bytes(b.try_into().unwrap()); self.compare_without_fuzzy(a, b) }, BaseType::Float16 => { let a = u16::from_ne_bytes(a.try_into().unwrap()); let a = half_float::to_f64(a); let b = u16::from_ne_bytes(b.try_into().unwrap()); let b = half_float::to_f64(b); match self { Comparison::FuzzyEqual => tolerance.equal(component, a, b), Comparison::Equal | Comparison::NotEqual | Comparison::Less | Comparison::GreaterEqual | Comparison::Greater | Comparison::LessEqual => { self.compare_without_fuzzy(a, b) }, } }, BaseType::Float => { let a = f32::from_ne_bytes(a.try_into().unwrap()); let b = f32::from_ne_bytes(b.try_into().unwrap()); match self { Comparison::FuzzyEqual => { tolerance.equal(component, a as f64, b as f64) }, Comparison::Equal | Comparison::NotEqual | Comparison::Less | Comparison::GreaterEqual | 
Comparison::Greater | Comparison::LessEqual => { self.compare_without_fuzzy(a, b) }, } }, BaseType::Double => { let a = f64::from_ne_bytes(a.try_into().unwrap()); let b = f64::from_ne_bytes(b.try_into().unwrap()); match self { Comparison::FuzzyEqual => { tolerance.equal(component, a, b) }, Comparison::Equal | Comparison::NotEqual | Comparison::Less | Comparison::GreaterEqual | Comparison::Greater | Comparison::LessEqual => { self.compare_without_fuzzy(a, b) }, } }, } } /// Compares the values in `a` with the values in `b` using the /// chosen criteria. The values are extracted from the two byte /// slices using the layout and type provided. The tolerance /// parameter is only used if the comparison is /// [FuzzyEqual](Comparison::FuzzyEqual). pub fn compare( self, tolerance: &Tolerance, slot_type: Type, layout: Layout, a: &[u8], b: &[u8], ) -> bool { let base_type = slot_type.base_type(); let base_type_size = base_type.size(); let rows = slot_type.rows(); for (i, offset) in slot_type.offsets(layout).enumerate() { let a = &a[offset..offset + base_type_size]; let b = &b[offset..offset + base_type_size]; if !self.compare_value(tolerance, i % rows, base_type, a, b) { return false; } } true } /// Get the comparison operator given the name in GLSL like `"<"` /// or `">="`. It also accepts `"~="` to return the fuzzy equal /// comparison. If the operator isn’t recognised then it will /// return `None`. 
pub fn from_operator(operator: &str) -> Option<Comparison> { match COMPARISON_NAMES.binary_search_by( |&(probe, _value)| probe.cmp(operator) ) { Ok(pos) => Some(COMPARISON_NAMES[pos].1), Err(_) => None, } } } #[cfg(test)] mod test { use super::*; #[test] fn test_type_info() { // Test a few types // First assert!(matches!(Type::Int.base_type(), BaseType::Int)); assert_eq!(Type::Int.columns(), 1); assert_eq!(Type::Int.rows(), 1); assert!(!Type::Int.is_matrix()); // Last assert!(matches!(Type::DMat4.base_type(), BaseType::Double)); assert_eq!(Type::DMat4.columns(), 4); assert_eq!(Type::DMat4.rows(), 4); assert!(Type::DMat4.is_matrix()); // Vector assert!(matches!(Type::Vec4.base_type(), BaseType::Float)); assert_eq!(Type::Vec4.columns(), 1); assert_eq!(Type::Vec4.rows(), 4); assert!(!Type::Vec4.is_matrix()); // Non-square matrix assert!(matches!(Type::DMat4x3.base_type(), BaseType::Double)); assert_eq!(Type::DMat4x3.columns(), 4); assert_eq!(Type::DMat4x3.rows(), 3); assert!(Type::DMat4x3.is_matrix()); } #[test] fn test_base_type_size() { assert_eq!(BaseType::Int.size(), 4); assert_eq!(BaseType::UInt.size(), 4); assert_eq!(BaseType::Int8.size(), 1); assert_eq!(BaseType::UInt8.size(), 1); assert_eq!(BaseType::Int16.size(), 2); assert_eq!(BaseType::UInt16.size(), 2); assert_eq!(BaseType::Int64.size(), 8); assert_eq!(BaseType::UInt64.size(), 8); assert_eq!(BaseType::Float16.size(), 2); assert_eq!(BaseType::Float.size(), 4); assert_eq!(BaseType::Double.size(), 8); } #[test] fn test_matrix_stride() { assert_eq!( Type::Mat4.matrix_stride( Layout { std: LayoutStd::Std430, major: MajorAxis::Column } ), 16, ); // In Std140 the vecs along the minor axis are aligned to the // size of a vec4 assert_eq!( Type::Mat4x2.matrix_stride( Layout { std: LayoutStd::Std140, major: MajorAxis::Column } ), 16, ); // In Std430 they are not assert_eq!( Type::Mat4x2.matrix_stride( Layout { std: LayoutStd::Std430, major: MajorAxis::Column } ), 8, ); // 3-component axis are always aligned to a vec4 
regardless of // the layout standard assert_eq!( Type::Mat4x3.matrix_stride( Layout { std: LayoutStd::Std430, major: MajorAxis::Column } ), 16, ); // For the row-major the stride the axes are reversed assert_eq!( Type::Mat4x2.matrix_stride( Layout { std: LayoutStd::Std430, major: MajorAxis::Row } ), 16, ); assert_eq!( Type::Mat2x4.matrix_stride( Layout { std: LayoutStd::Std430, major: MajorAxis::Row } ), 8, ); // For non-matrix types the matrix stride is the array stride assert_eq!( Type::Vec3.matrix_stride( Layout { std: LayoutStd::Std430, major: MajorAxis::Column } ), 16, ); assert_eq!( Type::Float.matrix_stride( Layout { std: LayoutStd::Std430, major: MajorAxis::Column } ), 4, ); // Row-major layout does not affect vectors assert_eq!( Type::Vec2.matrix_stride( Layout { std: LayoutStd::Std430, major: MajorAxis::Row } ), 8, ); } #[test] fn test_array_stride() { assert_eq!( Type::UInt8.array_stride( Layout { std: LayoutStd::Std430, major: MajorAxis::Column } ), 1, ); assert_eq!( Type::I8Vec2.array_stride( Layout { std: LayoutStd::Std430, major: MajorAxis::Column } ), 2, ); assert_eq!( Type::I8Vec3.array_stride( Layout { std: LayoutStd::Std430, major: MajorAxis::Column } ), 4, ); assert_eq!( Type::Mat4x3.array_stride( Layout { std: LayoutStd::Std430, major: MajorAxis::Column } ), 64, ); } #[test] fn test_size() { assert_eq!( Type::UInt8.size( Layout { std: LayoutStd::Std430, major: MajorAxis::Column } ), 1, ); assert_eq!( Type::I8Vec2.size( Layout { std: LayoutStd::Std430, major: MajorAxis::Column } ), 2, ); assert_eq!( Type::I8Vec3.size( Layout { std: LayoutStd::Std430, major: MajorAxis::Column } ), 3, ); assert_eq!( Type::Mat4x3.size( Layout { std: LayoutStd::Std430, major: MajorAxis::Column } ), 3 * 4 * 4 + 3 * 4, ); } #[test] fn test_iterator() { let default_layout = Layout { std: LayoutStd::Std140, major: MajorAxis::Column, }; // The components of a vec4 are packed tightly let vec4_offsets = Type::Vec4.offsets(default_layout).collect::<Vec<usize>>(); 
assert_eq!(vec4_offsets, vec![0, 4, 8, 12]); // The major axis setting shouldn’t affect vectors let row_major_layout = Layout { std: LayoutStd::Std140, major: MajorAxis::Row, }; let vec4_offsets = Type::Vec4.offsets(row_major_layout).collect::<Vec<usize>>(); assert_eq!(vec4_offsets, vec![0, 4, 8, 12]); // Components of a mat4 are packed tightly let mat4_offsets = Type::Mat4.offsets(default_layout).collect::<Vec<usize>>(); assert_eq!( mat4_offsets, vec![0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60], ); // If there are 3 rows then they are padded out to align // each column to the size of a vec4 let mat4x3_offsets = Type::Mat4x3.offsets(default_layout).collect::<Vec<usize>>(); assert_eq!( mat4x3_offsets, vec![0, 4, 8, 16, 20, 24, 32, 36, 40, 48, 52, 56], ); // If row-major order is being used then the offsets are still // reported in column-major order but the addresses that they // point to are in row-major order. // [0 4 8] // [16 20 24] // [32 36 40] // [48 52 56] let mat3x4_offsets = Type::Mat3x4.offsets(row_major_layout).collect::<Vec<usize>>(); assert_eq!( mat3x4_offsets, vec![0, 16, 32, 48, 4, 20, 36, 52, 8, 24, 40, 56], ); // Check the size_hint and count. The size hint is always // exact and should match the count. 
let mut iter = Type::Mat4.offsets(default_layout); for i in 0..16 { let expected_count = 16 - i; assert_eq!( iter.size_hint(), (expected_count, Some(expected_count)) ); assert_eq!(iter.clone().count(), expected_count); assert!(matches!(iter.next(), Some(_))); } assert_eq!(iter.size_hint(), (0, Some(0))); assert_eq!(iter.clone().count(), 0); assert!(matches!(iter.next(), None)); } #[test] fn test_compare() { // Test every comparison mode assert!(Comparison::Equal.compare( &Default::default(), // tolerance Type::UInt8, Layout { std: LayoutStd::Std140, major: MajorAxis::Column }, &[5], &[5], )); assert!(!Comparison::Equal.compare( &Default::default(), // tolerance Type::UInt8, Layout { std: LayoutStd::Std140, major: MajorAxis::Column }, &[6], &[5], )); assert!(Comparison::NotEqual.compare( &Default::default(), // tolerance Type::UInt8, Layout { std: LayoutStd::Std140, major: MajorAxis::Column }, &[6], &[5], )); assert!(!Comparison::NotEqual.compare( &Default::default(), // tolerance Type::UInt8, Layout { std: LayoutStd::Std140, major: MajorAxis::Column }, &[5], &[5], )); assert!(Comparison::Less.compare( &Default::default(), // tolerance Type::UInt8, Layout { std: LayoutStd::Std140, major: MajorAxis::Column }, &[4], &[5], )); assert!(!Comparison::Less.compare( &Default::default(), // tolerance Type::UInt8, Layout { std: LayoutStd::Std140, major: MajorAxis::Column }, &[5], &[5], )); assert!(!Comparison::Less.compare( &Default::default(), // tolerance Type::UInt8, Layout { std: LayoutStd::Std140, major: MajorAxis::Column }, &[6], &[5], )); assert!(Comparison::LessEqual.compare( &Default::default(), // tolerance Type::UInt8, Layout { std: LayoutStd::Std140, major: MajorAxis::Column }, &[4], &[5], )); assert!(Comparison::LessEqual.compare( &Default::default(), // tolerance Type::UInt8, Layout { std: LayoutStd::Std140, major: MajorAxis::Column }, &[5], &[5], )); assert!(!Comparison::LessEqual.compare( &Default::default(), // tolerance Type::UInt8, Layout { std: 
LayoutStd::Std140, major: MajorAxis::Column }, &[6], &[5], )); assert!(Comparison::Greater.compare( &Default::default(), // tolerance Type::UInt8, Layout { std: LayoutStd::Std140, major: MajorAxis::Column }, &[5], &[4], )); assert!(!Comparison::Greater.compare( &Default::default(), // tolerance Type::UInt8, Layout { std: LayoutStd::Std140, major: MajorAxis::Column }, &[5], &[5], )); assert!(!Comparison::Greater.compare( &Default::default(), // tolerance Type::UInt8, Layout { std: LayoutStd::Std140, major: MajorAxis::Column }, &[5], &[6], )); assert!(Comparison::GreaterEqual.compare( &Default::default(), // tolerance Type::UInt8, Layout { std: LayoutStd::Std140, major: MajorAxis::Column }, &[5], &[4], )); assert!(Comparison::GreaterEqual.compare( &Default::default(), // tolerance Type::UInt8, Layout { std: LayoutStd::Std140, major: MajorAxis::Column }, &[5], &[5], )); assert!(!Comparison::GreaterEqual.compare( &Default::default(), // tolerance Type::UInt8, Layout { std: LayoutStd::Std140, major: MajorAxis::Column }, &[5], &[6], )); assert!(Comparison::FuzzyEqual.compare( &Tolerance::new([1.0; 4], false), Type::Float, Layout { std: LayoutStd::Std140, major: MajorAxis::Column }, &5.0f32.to_ne_bytes(), &5.9f32.to_ne_bytes(), )); assert!(!Comparison::FuzzyEqual.compare( &Tolerance::new([1.0; 4], false), Type::Float, Layout { std: LayoutStd::Std140, major: MajorAxis::Column }, &5.0f32.to_ne_bytes(), &6.1f32.to_ne_bytes(), )); // Test all types macro_rules! test_compare_type_with_values { ($type_enum:expr, $low_value:expr, $high_value:expr) => { assert!(Comparison::Less.compare( &Default::default(), $type_enum, Layout { std: LayoutStd::Std140, major: MajorAxis::Column }, &$low_value.to_ne_bytes(), &$high_value.to_ne_bytes(), )); }; } macro_rules! 
test_compare_simple_type { ($type_num:expr, $rust_type:ty) => { test_compare_type_with_values!( $type_num, 0 as $rust_type, <$rust_type>::MAX ); }; } test_compare_simple_type!(Type::Int, i32); test_compare_simple_type!(Type::UInt, u32); test_compare_simple_type!(Type::Int8, i8); test_compare_simple_type!(Type::UInt8, u8); test_compare_simple_type!(Type::Int16, i16); test_compare_simple_type!(Type::UInt16, u16); test_compare_simple_type!(Type::Int64, i64); test_compare_simple_type!(Type::UInt64, u64); test_compare_simple_type!(Type::Float, f32); test_compare_simple_type!(Type::Double, f64); test_compare_type_with_values!( Type::Float16, half_float::from_f32(-100.0), half_float::from_f32(300.0) ); macro_rules! test_compare_big_type { ($type_enum:expr, $bytes:expr) => { let layout = Layout { std: LayoutStd::Std140, major: MajorAxis::Column }; let mut bytes = Vec::new(); assert_eq!($type_enum.size(layout) % $bytes.len(), 0); for _ in 0..($type_enum.size(layout) / $bytes.len()) { bytes.extend_from_slice($bytes); } assert!(Comparison::FuzzyEqual.compare( &Default::default(), $type_enum, layout, &bytes, &bytes, )); }; } test_compare_big_type!( Type::F16Vec2, &half_float::from_f32(1.0).to_ne_bytes() ); test_compare_big_type!( Type::F16Vec3, &half_float::from_f32(1.0).to_ne_bytes() ); test_compare_big_type!( Type::F16Vec4, &half_float::from_f32(1.0).to_ne_bytes() ); test_compare_big_type!(Type::Vec2, &42.0f32.to_ne_bytes()); test_compare_big_type!(Type::Vec3, &24.0f32.to_ne_bytes()); test_compare_big_type!(Type::Vec4, &6.0f32.to_ne_bytes()); test_compare_big_type!(Type::DVec2, &42.0f64.to_ne_bytes()); test_compare_big_type!(Type::DVec3, &24.0f64.to_ne_bytes()); test_compare_big_type!(Type::DVec4, &6.0f64.to_ne_bytes()); test_compare_big_type!(Type::IVec2, &(-42i32).to_ne_bytes()); test_compare_big_type!(Type::IVec3, &(-24i32).to_ne_bytes()); test_compare_big_type!(Type::IVec4, &(-6i32).to_ne_bytes()); test_compare_big_type!(Type::UVec2, &42u32.to_ne_bytes()); 
test_compare_big_type!(Type::UVec3, &24u32.to_ne_bytes()); test_compare_big_type!(Type::UVec4, &6u32.to_ne_bytes()); test_compare_big_type!(Type::I8Vec2, &(-42i8).to_ne_bytes()); test_compare_big_type!(Type::I8Vec3, &(-24i8).to_ne_bytes()); test_compare_big_type!(Type::I8Vec4, &(-6i8).to_ne_bytes()); test_compare_big_type!(Type::U8Vec2, &42u8.to_ne_bytes()); test_compare_big_type!(Type::U8Vec3, &24u8.to_ne_bytes()); test_compare_big_type!(Type::U8Vec4, &6u8.to_ne_bytes()); test_compare_big_type!(Type::I16Vec2, &(-42i16).to_ne_bytes()); test_compare_big_type!(Type::I16Vec3, &(-24i16).to_ne_bytes()); test_compare_big_type!(Type::I16Vec4, &(-6i16).to_ne_bytes()); test_compare_big_type!(Type::U16Vec2, &42u16.to_ne_bytes()); test_compare_big_type!(Type::U16Vec3, &24u16.to_ne_bytes()); test_compare_big_type!(Type::U16Vec4, &6u16.to_ne_bytes()); test_compare_big_type!(Type::I64Vec2, &(-42i64).to_ne_bytes()); test_compare_big_type!(Type::I64Vec3, &(-24i64).to_ne_bytes()); test_compare_big_type!(Type::I64Vec4, &(-6i64).to_ne_bytes()); test_compare_big_type!(Type::U64Vec2, &42u64.to_ne_bytes()); test_compare_big_type!(Type::U64Vec3, &24u64.to_ne_bytes()); test_compare_big_type!(Type::U64Vec4, &6u64.to_ne_bytes()); test_compare_big_type!(Type::Mat2, &42.0f32.to_ne_bytes()); test_compare_big_type!(Type::Mat2x3, &42.0f32.to_ne_bytes()); test_compare_big_type!(Type::Mat2x4, &42.0f32.to_ne_bytes()); test_compare_big_type!(Type::Mat3, &24.0f32.to_ne_bytes()); test_compare_big_type!(Type::Mat3x2, &24.0f32.to_ne_bytes()); test_compare_big_type!(Type::Mat3x4, &24.0f32.to_ne_bytes()); test_compare_big_type!(Type::Mat4, &6.0f32.to_ne_bytes()); test_compare_big_type!(Type::Mat4x2, &6.0f32.to_ne_bytes()); test_compare_big_type!(Type::Mat4x3, &6.0f32.to_ne_bytes()); test_compare_big_type!(Type::DMat2, &42.0f64.to_ne_bytes()); test_compare_big_type!(Type::DMat2x3, &42.0f64.to_ne_bytes()); test_compare_big_type!(Type::DMat2x4, &42.0f64.to_ne_bytes()); test_compare_big_type!(Type::DMat3, 
&24.0f64.to_ne_bytes()); test_compare_big_type!(Type::DMat3x2, &24.0f64.to_ne_bytes()); test_compare_big_type!(Type::DMat3x4, &24.0f64.to_ne_bytes()); test_compare_big_type!(Type::DMat4, &6.0f64.to_ne_bytes()); test_compare_big_type!(Type::DMat4x2, &6.0f64.to_ne_bytes()); test_compare_big_type!(Type::DMat4x3, &6.0f64.to_ne_bytes()); // Check that it the comparison fails if a value other than // the first one is different assert!(!Comparison::Equal.compare( &Default::default(), // tolerance Type::U8Vec4, Layout { std: LayoutStd::Std140, major: MajorAxis::Column }, &[5, 6, 7, 8], &[5, 6, 7, 9], )); } #[test] fn test_from_glsl_type() { assert_eq!(Type::from_glsl_type("vec3"), Some(Type::Vec3)); assert_eq!(Type::from_glsl_type("dvec4"), Some(Type::DVec4)); assert_eq!(Type::from_glsl_type("i64vec4"), Some(Type::I64Vec4)); assert_eq!(Type::from_glsl_type("uint"), Some(Type::UInt)); assert_eq!(Type::from_glsl_type("uint16_t"), Some(Type::UInt16)); assert_eq!(Type::from_glsl_type("dvec5"), None); // Check that all types can be found with the binary search for &(type_name, slot_type) in GLSL_TYPE_NAMES.iter() { assert_eq!(Type::from_glsl_type(type_name), Some(slot_type)); } } #[test] fn test_comparison_from_operator() { // If a comparison is added then please add it to this test too assert_eq!(COMPARISON_NAMES.len(), 7); assert_eq!( Comparison::from_operator("!="), Some(Comparison::NotEqual), ); assert_eq!( Comparison::from_operator("<"), Some(Comparison::Less), ); assert_eq!( Comparison::from_operator("<="), Some(Comparison::LessEqual), ); assert_eq!( Comparison::from_operator("=="), Some(Comparison::Equal), ); assert_eq!( Comparison::from_operator(">"), Some(Comparison::Greater), ); assert_eq!( Comparison::from_operator(">="), Some(Comparison::GreaterEqual), ); assert_eq!( Comparison::from_operator("~="), Some(Comparison::FuzzyEqual), ); assert_eq!(Comparison::from_operator("<=>"), None); } #[track_caller] fn check_base_type_format( base_type: BaseType, bytes: &[u8], 
expected: &str, ) { assert_eq!( &BaseTypeInSlice::new(base_type, bytes).to_string(), expected ); } #[test] fn base_type_format() { check_base_type_format(BaseType::Int, &1234i32.to_ne_bytes(), "1234"); check_base_type_format( BaseType::UInt, &u32::MAX.to_ne_bytes(), "4294967295", ); check_base_type_format(BaseType::Int8, &(-4i8).to_ne_bytes(), "-4"); check_base_type_format(BaseType::UInt8, &129u8.to_ne_bytes(), "129"); check_base_type_format(BaseType::Int16, &(-4i16).to_ne_bytes(), "-4"); check_base_type_format( BaseType::UInt16, &32769u16.to_ne_bytes(), "32769" ); check_base_type_format(BaseType::Int64, &(-4i64).to_ne_bytes(), "-4"); check_base_type_format( BaseType::UInt64, &u64::MAX.to_ne_bytes(), "18446744073709551615" ); check_base_type_format( BaseType::Float16, &0xc000u16.to_ne_bytes(), "-2" ); check_base_type_format( BaseType::Float, &2.0f32.to_ne_bytes(), "2" ); check_base_type_format( BaseType::Double, &2.0f64.to_ne_bytes(), "2" ); } } ``` -------------------------------------------------------------------------------- /src/main.rs: -------------------------------------------------------------------------------- ```rust use anyhow::Result; use clap::Parser; use image::codecs::pnm::PnmDecoder; use image::{DynamicImage, ImageError, RgbImage}; use rmcp::{ Error as McpError, RoleServer, ServerHandler, ServiceExt, const_string, model::*, schemars, service::RequestContext, tool, transport::stdio, }; use serde_json::json; use shaderc::{self, CompileOptions, Compiler, OptimizationLevel, ShaderKind}; use std::fs::File; use std::io::BufReader; use std::path::{Path, PathBuf}; use std::sync::Arc; use tokio::sync::Mutex; use tracing_subscriber::{self, EnvFilter}; pub fn read_and_decode_ppm_file<P: AsRef<Path>>(path: P) -> Result<RgbImage, ImageError> { let file = File::open(path)?; let reader = BufReader::new(file); let decoder: PnmDecoder<BufReader<File>> = PnmDecoder::new(reader)?; let dynamic_image = DynamicImage::from_decoder(decoder)?; let rgb_image = 
dynamic_image.into_rgb8();
    Ok(rgb_image)
}

// Shader stage of a single `CompileRequest`; `compile_run_shaders` maps each
// variant to a shaderc `ShaderKind`. Plain `//` comments are used on these
// schema types on purpose: schemars turns `///` doc comments into JSON-schema
// descriptions, which would alter the tool schema sent to MCP clients.
// No serde rename attributes are applied, so the wire names are the variant
// names verbatim (`Vert`, `Frag`, ...).
#[derive(Debug, serde::Serialize, serde::Deserialize, schemars::JsonSchema)]
pub enum ShaderStage {
    #[schemars(description = "Vertex processing stage (transforms vertices)")]
    Vert,
    #[schemars(description = "Fragment processing stage (determines pixel colors)")]
    Frag,
    #[schemars(description = "Tessellation control stage for patch control points")]
    Tesc,
    #[schemars(description = "Tessellation evaluation stage for computing tessellated geometry")]
    Tese,
    #[schemars(description = "Geometry stage for generating/modifying primitives")]
    Geom,
    #[schemars(description = "Compute stage for general-purpose computation")]
    Comp,
}

// Requirements written under the `[require]` header of the generated
// shader_test file by `compile_run_shaders`.
#[derive(Debug, serde::Serialize, serde::Deserialize, schemars::JsonSchema)]
pub enum ShaderRunnerRequire {
    #[schemars(
        description = "Enables cooperative matrix operations with specified dimensions and data type"
    )]
    CooperativeMatrix {
        m: u32,
        n: u32,
        component_type: String,
    },
    #[schemars(
        description = "Specifies depth/stencil format for depth testing and stencil operations"
    )]
    DepthStencil(String),
    #[schemars(description = "Specifies framebuffer format for render target output")]
    Framebuffer(String),
    #[schemars(description = "Enables double-precision floating point operations in shaders")]
    ShaderFloat64,
    #[schemars(description = "Enables geometry shader stage for primitive manipulation")]
    GeometryShader,
    #[schemars(description = "Enables lines with width greater than 1.0 pixel")]
    WideLines,
    #[schemars(description = "Enables bitwise logical operations on framebuffer contents")]
    LogicOp,
    #[schemars(description = "Specifies required subgroup size for shader execution")]
    SubgroupSize(u32),
    #[schemars(description = "Enables memory stores and atomic operations in fragment shaders")]
    FragmentStoresAndAtomics,
    #[schemars(description = "Enables shaders to use raw buffer addresses")]
    BufferDeviceAddress,
}

// Pipeline stages of the test: either the built-in pass-through vertex shader
// or a path to a previously compiled SPIR-V assembly (.spvasm) file.
#[derive(Debug, serde::Serialize, serde::Deserialize, schemars::JsonSchema)]
pub enum ShaderRunnerPass {
    #[schemars(
        description = "Use a built-in pass-through vertex shader (forwards attributes to fragment shader)"
    )]
    VertPassthrough,
    #[schemars(description = "Use compiled vertex shader from specified SPIR-V assembly file")]
    VertSpirv {
        #[schemars(
            description = "Path to the compiled vertex shader SPIR-V assembly (.spvasm) file"
        )]
        vert_spvasm_path: String,
    },
    #[schemars(description = "Use compiled fragment shader from specified SPIR-V assembly file")]
    FragSpirv {
        #[schemars(
            description = "Path to the compiled fragment shader SPIR-V assembly (.spvasm) file"
        )]
        frag_spvasm_path: String,
    },
    #[schemars(description = "Use compiled compute shader from specified SPIR-V assembly file")]
    CompSpirv {
        #[schemars(
            description = "Path to the compiled compute shader SPIR-V assembly (.spvasm) file"
        )]
        comp_spvasm_path: String,
    },
    #[schemars(description = "Use compiled geometry shader from specified SPIR-V assembly file")]
    GeomSpirv {
        #[schemars(
            description = "Path to the compiled geometry shader SPIR-V assembly (.spvasm) file"
        )]
        geom_spvasm_path: String,
    },
    #[schemars(
        description = "Use compiled tessellation control shader from specified SPIR-V assembly file"
    )]
    TescSpirv {
        #[schemars(
            description = "Path to the compiled tessellation control shader SPIR-V assembly (.spvasm) file"
        )]
        tesc_spvasm_path: String,
    },
    #[schemars(
        description = "Use compiled tessellation evaluation shader from specified SPIR-V assembly file"
    )]
    TeseSpirv {
        #[schemars(
            description = "Path to the compiled tessellation evaluation shader SPIR-V assembly (.spvasm) file"
        )]
        tese_spvasm_path: String,
    },
}

// Vertex data entries: `AttributeFormat` declares an attribute column, the
// other variants supply values (presumably one per attribute column — the
// writer is outside this view).
#[derive(Debug, serde::Serialize, serde::Deserialize, schemars::JsonSchema)]
pub enum ShaderRunnerVertexData {
    #[schemars(description = "Defines an attribute format at a shader location/binding")]
    AttributeFormat {
        #[schemars(description = "Location/binding number in the shader")]
        location: u32,
        #[schemars(description = "Format name (e.g., R32G32_SFLOAT, A8B8G8R8_UNORM_PACK32)")]
        format: String,
    },
    #[schemars(description = "2D position data (for R32G32_SFLOAT format)")]
    Vec2 {
        #[schemars(description = "X coordinate value")]
        x: f32,
        #[schemars(description = "Y coordinate value")]
        y: f32,
    },
    #[schemars(description = "3D position data (for R32G32B32_SFLOAT format)")]
    Vec3 {
        #[schemars(description = "X coordinate value")]
        x: f32,
        #[schemars(description = "Y coordinate value")]
        y: f32,
        #[schemars(description = "Z coordinate value")]
        z: f32,
    },
    #[schemars(description = "4D position/color data (for R32G32B32A32_SFLOAT format)")]
    Vec4 {
        #[schemars(description = "X coordinate or Red component")]
        x: f32,
        #[schemars(description = "Y coordinate or Green component")]
        y: f32,
        #[schemars(description = "Z coordinate or Blue component")]
        z: f32,
        #[schemars(description = "W coordinate or Alpha component")]
        w: f32,
    },
    #[schemars(description = "RGB color data (for R8G8B8_UNORM format)")]
    RGB {
        #[schemars(description = "Red component (0-255)")]
        r: u8,
        #[schemars(description = "Green component (0-255)")]
        g: u8,
        #[schemars(description = "Blue component (0-255)")]
        b: u8,
    },
    // VK_FORMAT_A8B8G8R8_UNORM_PACK32 stores R in bits 0..7, G in 8..15,
    // B in 16..23 and A in 24..31, so the packed hex literal reads
    // 0xAABBGGRR (the previous "0xAARRGGBB" swapped red and blue).
    #[schemars(description = "ABGR-packed color as hex (for A8B8G8R8_UNORM_PACK32 format)")]
    Hex {
        #[schemars(description = "Color as hex string (0xAABBGGRR format)")]
        value: String,
    },
    #[schemars(description = "Generic data components for custom formats")]
    GenericComponents {
        #[schemars(description = "Component values as strings, interpreted by format")]
        components: Vec<String>,
    },
}

// Commands for the test section of the generated shader_test file
// (presumably emitted in order by `compile_run_shaders`; the writer is
// outside this view).
#[derive(Debug, serde::Serialize, serde::Deserialize, schemars::JsonSchema)]
pub enum ShaderRunnerTest {
    #[schemars(description = "Set fragment shader entrypoint function name")]
    FragmentEntrypoint {
        #[schemars(description = "Function name to use as entrypoint")]
        name: String,
    },
    #[schemars(description = "Set vertex shader entrypoint function name")]
    VertexEntrypoint {
        #[schemars(description = "Function name to use as entrypoint")]
        name: String,
    },
    #[schemars(description = "Set compute shader entrypoint function name")]
    ComputeEntrypoint {
        #[schemars(description = "Function name to use as entrypoint")]
        name: String,
    },
    #[schemars(description = "Set geometry shader entrypoint function name")]
    GeometryEntrypoint {
        #[schemars(description = "Function name to use as entrypoint")]
        name: String,
    },
    #[schemars(description = "Draw a rectangle using normalized device coordinates")]
    DrawRect {
        #[schemars(description = "X coordinate of lower-left corner")]
        x: f32,
        #[schemars(description = "Y coordinate of lower-left corner")]
        y: f32,
        #[schemars(description = "Width of the rectangle")]
        width: f32,
        #[schemars(description = "Height of the rectangle")]
        height: f32,
    },
    #[schemars(description = "Draw primitives using vertex data")]
    DrawArrays {
        #[schemars(description = "Primitive type (TRIANGLE_LIST, POINT_LIST, etc.)")]
        primitive_type: String,
        #[schemars(description = "Index of first vertex")]
        first: u32,
        #[schemars(description = "Number of vertices to draw")]
        count: u32,
    },
    #[schemars(description = "Draw primitives using indexed vertex data")]
    DrawArraysIndexed {
        #[schemars(description = "Primitive type (TRIANGLE_LIST, etc.)")]
        primitive_type: String,
        #[schemars(description = "Index of first index")]
        first: u32,
        #[schemars(description = "Number of indices to use")]
        count: u32,
    },
    #[schemars(description = "Create or initialize a Shader Storage Buffer Object (SSBO)")]
    SSBO {
        #[schemars(description = "Binding point in the shader")]
        binding: u32,
        #[schemars(description = "Size in bytes (if data not provided)")]
        size: Option<u32>,
        #[schemars(description = "Initial buffer contents")]
        data: Option<Vec<u8>>,
        #[schemars(description = "Descriptor set number (default: 0)")]
        descriptor_set: Option<u32>,
    },
    #[schemars(description = "Update a portion of an SSBO with new data")]
    SSBOSubData {
        #[schemars(description = "Binding point in the shader")]
        binding: u32,
        #[schemars(description = "Data type (float, vec4, etc.)")]
        data_type: String,
        #[schemars(description = "Byte offset into the buffer")]
        offset: u32,
        #[schemars(description = "Values to write (as strings)")]
        values: Vec<String>,
        #[schemars(description = "Descriptor set number (default: 0)")]
        descriptor_set: Option<u32>,
    },
    #[schemars(description = "Create or initialize a Uniform Buffer Object (UBO)")]
    UBO {
        #[schemars(description = "Binding point in the shader")]
        binding: u32,
        #[schemars(description = "Buffer contents")]
        data: Vec<u8>,
        #[schemars(description = "Descriptor set number (default: 0)")]
        descriptor_set: Option<u32>,
    },
    #[schemars(description = "Update a portion of a UBO with new data")]
    UBOSubData {
        #[schemars(description = "Binding point in the shader")]
        binding: u32,
        #[schemars(description = "Data type (float, vec4, etc.)")]
        data_type: String,
        #[schemars(description = "Byte offset into the buffer")]
        offset: u32,
        #[schemars(description = "Values to write (as strings)")]
        values: Vec<String>,
        #[schemars(description = "Descriptor set number (default: 0)")]
        descriptor_set: Option<u32>,
    },
    #[schemars(description = "Set memory layout for buffer data")]
    BufferLayout {
        #[schemars(description = "Buffer type (ubo, ssbo)")]
        buffer_type: String,
        #[schemars(description = "Layout specification (std140, std430, row_major, column_major)")]
        layout_type: String,
    },
    #[schemars(description = "Set push constant values")]
    Push {
        #[schemars(description = "Data type (float, vec4, etc.)")]
        data_type: String,
        #[schemars(description = "Byte offset into push constant block")]
        offset: u32,
        #[schemars(description = "Values to write (as strings)")]
        values: Vec<String>,
    },
    #[schemars(description = "Set memory layout for push constants")]
    PushLayout {
        #[schemars(description = "Layout specification (std140, std430)")]
        layout_type: String,
    },
    #[schemars(description = "Execute compute shader with specified workgroup counts")]
    Compute {
        #[schemars(description = "Number of workgroups in X dimension")]
        x: u32,
        #[schemars(description = "Number of workgroups in Y dimension")]
        y: u32,
        #[schemars(description = "Number of workgroups in Z dimension")]
        z: u32,
    },
    #[schemars(description = "Verify framebuffer or buffer contents match expected values")]
    Probe {
        #[schemars(description = "Probe type (all, rect, ssbo, etc.)")]
        probe_type: String,
        #[schemars(description = "Component format (rgba, rgb, etc.)")]
        format: String,
        #[schemars(description = "Parameters (coordinates, expected values)")]
        args: Vec<String>,
    },
    #[schemars(description = "Verify contents using normalized (0-1) coordinates")]
    RelativeProbe {
        #[schemars(description = "Probe type (rect, etc.)")]
        probe_type: String,
        #[schemars(description = "Component format (rgba, rgb, etc.)")]
        format: String,
        #[schemars(description = "Parameters (coordinates, expected values)")]
        args: Vec<String>,
    },
    #[schemars(description = "Set acceptable error margin for value comparisons")]
    Tolerance {
        // `values` is `Vec<f32>`: a client cannot send "5%" here, so the
        // description must not advertise a percentage suffix.
        #[schemars(description = "Error margins as absolute numeric values")]
        values: Vec<f32>,
    },
    #[schemars(description = "Clear the framebuffer to default values")]
    Clear,
    #[schemars(description = "Enable/disable depth testing")]
    DepthTestEnable {
        #[schemars(description = "True to enable depth testing")]
        enable: bool,
    },
    #[schemars(description = "Enable/disable writing to depth buffer")]
    DepthWriteEnable {
        #[schemars(description = "True to enable depth writes")]
        enable: bool,
    },
    #[schemars(description = "Set depth comparison function")]
    DepthCompareOp {
        #[schemars(description = "Function name (VK_COMPARE_OP_LESS, etc.)")]
        op: String,
    },
    #[schemars(description = "Enable/disable stencil testing")]
    StencilTestEnable {
        #[schemars(description = "True to enable stencil testing")]
        enable: bool,
    },
    #[schemars(description = "Define which winding order is front-facing")]
    FrontFace {
        #[schemars(description = "Mode (VK_FRONT_FACE_CLOCKWISE, etc.)")]
        mode: String,
    },
    #[schemars(description = "Configure stencil operation for a face")]
    StencilOp {
        #[schemars(description = "Face (front, back)")]
        face: String,
        #[schemars(description = "Operation name (passOp, failOp, etc.)")]
        op_name: String,
        #[schemars(description = "Value (VK_STENCIL_OP_REPLACE, etc.)")]
        value: String,
    },
    #[schemars(description = "Set reference value for stencil comparisons")]
    StencilReference {
        #[schemars(description = "Face (front, back)")]
        face: String,
        #[schemars(description = "Reference value")]
        value: u32,
    },
    #[schemars(description = "Set stencil comparison function")]
    StencilCompareOp {
        #[schemars(description = "Face (front, back)")]
        face: String,
        #[schemars(description = "Function (VK_COMPARE_OP_EQUAL, etc.)")]
        op: String,
    },
    #[schemars(description = "Control which color channels can be written")]
    ColorWriteMask {
        #[schemars(
            description = "Bit flags (VK_COLOR_COMPONENT_R_BIT | VK_COLOR_COMPONENT_G_BIT, etc.)"
        )]
        mask: String,
    },
    #[schemars(description = "Enable/disable logical operations on colors")]
    LogicOpEnable {
        #[schemars(description = "True to enable logic operations")]
        enable: bool,
    },
    #[schemars(description = "Set type of logical operation on colors")]
    LogicOp {
        #[schemars(description = "Operation (VK_LOGIC_OP_XOR, etc.)")]
        op: String,
    },
    #[schemars(description = "Set face culling mode")]
    CullMode {
        #[schemars(description = "Mode (VK_CULL_MODE_BACK_BIT, etc.)")]
        mode: String,
    },
    #[schemars(description = "Set width for line primitives")]
    LineWidth {
        #[schemars(description = "Width in pixels")]
        width: f32,
    },
    #[schemars(description = "Specify a feature required by the test")]
    Require {
        #[schemars(description = "Feature name (subgroup_size, depthstencil, etc.)")]
        feature: String,
        #[schemars(description = "Feature parameters")]
        parameters: Vec<String>,
    },
}

// One GLSL compilation job. `stage` is serialized with the `ShaderStage`
// variant names verbatim, so the description lists those exact spellings.
// `compile_run_shaders` keeps `tmp_output_path` as-is when it starts with
// "/tmp" and otherwise prefixes it with "/tmp/".
#[derive(Debug, serde::Serialize, serde::Deserialize, schemars::JsonSchema)]
pub struct CompileRequest {
    #[schemars(description = "The shader stage to compile (Vert, Frag, Comp, Geom, Tesc, Tese)")]
    pub stage: ShaderStage,
    #[schemars(description = "GLSL shader source code to compile")]
    pub source: String,
    #[schemars(
        description = "Path where compiled SPIR-V assembly (.spvasm) will be saved (paths not starting with /tmp are prefixed with /tmp/)"
    )]
    pub tmp_output_path:
String,
}

// Request listing shaders to compile to SPIR-V assembly.
#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
pub struct CompileShadersRequest {
    #[schemars(description = "List of shader compile requests (produces SPIR-V assemblies)")]
    pub requests: Vec<CompileRequest>,
}

// Aggregated argument of the `compile_run_shaders` tool (`#[tool(aggr)]`):
// the shaders in `requests` are compiled first, then `requirements` are
// written into the `[require]` section of the generated shader_test file.
#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
pub struct CompileRunShadersRequest {
    #[schemars(
        description = "List of shader compile requests - each produces a SPIR-V assembly file"
    )]
    pub requests: Vec<CompileRequest>,
    #[schemars(description = "Optional hardware/feature requirements needed for shader execution")]
    pub requirements: Option<Vec<ShaderRunnerRequire>>,
    #[schemars(
        description = "Shader pipeline configuration (references compiled SPIR-V files by path)"
    )]
    pub passes: Vec<ShaderRunnerPass>,
    #[schemars(description = "Optional vertex data for rendering geometry")]
    pub vertex_data: Option<Vec<ShaderRunnerVertexData>>,
    #[schemars(description = "Test commands to execute (drawing, compute, verification, etc.)")]
    pub tests: Vec<ShaderRunnerTest>,
    #[schemars(description = "Optional path to save output image (PNG format)")]
    pub output_path: Option<String>,
}

// MCP tool server. It carries no state, so `Default` is derived to match the
// zero-argument `new()` constructor (clippy::new_without_default).
#[derive(Clone, Default)]
pub struct ShadercVkrunnerMcp {}

#[tool(tool_box)]
impl ShadercVkrunnerMcp {
    // Creates the stateless server handler; equivalent to `Default::default()`.
    #[must_use]
    pub fn new() -> Self {
        Self {}
    }

    #[tool(
        description = "REQUIRES: (1) Shader compile requests that produce SPIR-V assembly files, (2) Passes referencing these files by path, and (3) Test commands for drawing/computation. Workflow: First compile GLSL to SPIR-V, then reference compiled files in passes, then execute drawing commands in tests to render/compute. Tests MUST include draw/compute commands to produce visible output. Optional: requirements for hardware features, vertex data for geometry, output path for saving image. Every shader MUST have its compiled output referenced in passes."
)] fn compile_run_shaders( &self, #[tool(aggr)] request: CompileRunShadersRequest, ) -> Result<CallToolResult, McpError> { use std::fs::File; use std::io::{Read, Write}; use std::path::Path; use std::process::{Command, Stdio}; // Still needed for vkrunner fn io_err(e: std::io::Error) -> McpError { McpError::internal_error("IO operation failed", Some(json!({"error": e.to_string()}))) } for req in &request.requests { let shader_kind = match req.stage { ShaderStage::Vert => ShaderKind::Vertex, ShaderStage::Frag => ShaderKind::Fragment, ShaderStage::Tesc => ShaderKind::TessControl, ShaderStage::Tese => ShaderKind::TessEvaluation, ShaderStage::Geom => ShaderKind::Geometry, ShaderStage::Comp => ShaderKind::Compute, }; let stage_flag = match req.stage { ShaderStage::Vert => "vert", ShaderStage::Frag => "frag", ShaderStage::Tesc => "tesc", ShaderStage::Tese => "tese", ShaderStage::Geom => "geom", ShaderStage::Comp => "comp", }; let tmp_output_path = if req.tmp_output_path.starts_with("/tmp") { req.tmp_output_path.clone() } else { format!("/tmp/{}", req.tmp_output_path) }; if let Some(parent) = Path::new(&tmp_output_path).parent() { std::fs::create_dir_all(parent).map_err(|e| { McpError::internal_error( "Failed to create temporary output directory", Some(json!({"error": e.to_string()})), ) })?; } // Create compiler and options let mut compiler = Compiler::new().ok().ok_or_else(|| { McpError::internal_error( "Failed to create shaderc compiler", Some(json!({"error": "Could not instantiate shaderc compiler"})), ) })?; let mut options = CompileOptions::new().ok().ok_or_else(|| { McpError::internal_error( "Failed to create shaderc compile options", Some(json!({"error": "Could not create compiler options"})), ) })?; // Set options equivalent to the CLI flags options.set_target_env( shaderc::TargetEnv::Vulkan, shaderc::EnvVersion::Vulkan1_4 as u32, ); options.set_optimization_level(OptimizationLevel::Performance); options.set_generate_debug_info(); // Compile to SPIR-V assembly 
let artifact = match compiler.compile_into_spirv_assembly( &req.source, shader_kind, "shader.glsl", // source name for error reporting "main", // entry point Some(&options), ) { Ok(artifact) => artifact, Err(e) => { let stage_name = match req.stage { ShaderStage::Vert => "Vertex", ShaderStage::Frag => "Fragment", ShaderStage::Tesc => "Tessellation Control", ShaderStage::Tese => "Tessellation Evaluation", ShaderStage::Geom => "Geometry", ShaderStage::Comp => "Compute", }; let error_details = format!("{}", e); return Ok(CallToolResult::success(vec![Content::text(format!( "Shader compilation failed for {} shader:\n\nError:\n{}\n\nShader Source:\n{}\n", stage_name, error_details, req.source ))])); } }; // Write the compiled SPIR-V assembly to the output file, filtering unsupported lines let spv_text = artifact.as_text(); let filtered_spv = spv_text .lines() .filter(|l| !l.trim_start().starts_with("OpModuleProcessed")) .collect::<Vec<_>>() .join("\n"); std::fs::write(&tmp_output_path, filtered_spv).map_err(|e| { McpError::internal_error( "Failed to write compiled shader to file", Some(json!({"error": e.to_string()})), ) })?; } let shader_test_path = "/tmp/vkrunner_test.shader_test"; let mut shader_test_file = File::create(shader_test_path).map_err(io_err)?; if let Some(requirements) = &request.requirements { if !requirements.is_empty() { writeln!(shader_test_file, "[require]").map_err(io_err)?; for req in requirements { match req { ShaderRunnerRequire::CooperativeMatrix { m, n, component_type, } => { writeln!( shader_test_file, "cooperative_matrix m={m} n={n} c={component_type}" ) .map_err(io_err)?; } ShaderRunnerRequire::DepthStencil(format) => { writeln!(shader_test_file, "depthstencil {format}").map_err(io_err)?; } ShaderRunnerRequire::Framebuffer(format) => { writeln!(shader_test_file, "framebuffer {format}").map_err(io_err)?; } ShaderRunnerRequire::ShaderFloat64 => { writeln!(shader_test_file, "shaderFloat64").map_err(io_err)?; } ShaderRunnerRequire::GeometryShader 
=> { writeln!(shader_test_file, "geometryShader").map_err(io_err)?; } ShaderRunnerRequire::WideLines => { writeln!(shader_test_file, "wideLines").map_err(io_err)?; } ShaderRunnerRequire::LogicOp => { writeln!(shader_test_file, "logicOp").map_err(io_err)?; } ShaderRunnerRequire::SubgroupSize(size) => { writeln!(shader_test_file, "subgroup_size {size}").map_err(io_err)?; } ShaderRunnerRequire::FragmentStoresAndAtomics => { writeln!(shader_test_file, "fragmentStoresAndAtomics") .map_err(io_err)?; } ShaderRunnerRequire::BufferDeviceAddress => { writeln!(shader_test_file, "bufferDeviceAddress").map_err(io_err)?; } } } writeln!(shader_test_file).map_err(io_err)?; } } for pass in &request.passes { match pass { ShaderRunnerPass::VertPassthrough => { writeln!(shader_test_file, "[vertex shader passthrough]").map_err(io_err)?; } ShaderRunnerPass::VertSpirv { vert_spvasm_path } => { writeln!(shader_test_file, "[vertex shader spirv]").map_err(io_err)?; let mut spvasm = String::new(); let path = if vert_spvasm_path.starts_with("/tmp") { vert_spvasm_path.clone() } else { format!("/tmp/{vert_spvasm_path}") }; File::open(&path) .map_err(|e| { McpError::internal_error( format!("Failed to open vertex shader SPIR-V file at {path}"), Some(json!({"error": e.to_string()})), ) })? .read_to_string(&mut spvasm) .map_err(|e| { McpError::internal_error( "Failed to read vertex shader SPIR-V file", Some(json!({"error": e.to_string()})), ) })?; writeln!(shader_test_file, "{spvasm}").map_err(io_err)?; } ShaderRunnerPass::FragSpirv { frag_spvasm_path } => { writeln!(shader_test_file, "[fragment shader spirv]").map_err(io_err)?; let mut spvasm = String::new(); let path = if frag_spvasm_path.starts_with("/tmp") { frag_spvasm_path.clone() } else { format!("/tmp/{frag_spvasm_path}") }; File::open(&path) .map_err(|e| { McpError::internal_error( format!("Failed to open fragment shader SPIR-V file at {path}"), Some(json!({"error": e.to_string()})), ) })? 
.read_to_string(&mut spvasm) .map_err(|e| { McpError::internal_error( "Failed to read fragment shader SPIR-V file", Some(json!({"error": e.to_string()})), ) })?; writeln!(shader_test_file, "{spvasm}").map_err(io_err)?; } ShaderRunnerPass::CompSpirv { comp_spvasm_path } => { writeln!(shader_test_file, "[compute shader spirv]").map_err(io_err)?; let mut spvasm = String::new(); let path = if comp_spvasm_path.starts_with("/tmp") { comp_spvasm_path.clone() } else { format!("/tmp/{comp_spvasm_path}") }; File::open(&path) .map_err(|e| { McpError::internal_error( format!("Failed to open compute shader SPIR-V file at {path}"), Some(json!({"error": e.to_string()})), ) })? .read_to_string(&mut spvasm) .map_err(|e| { McpError::internal_error( "Failed to read compute shader SPIR-V file", Some(json!({"error": e.to_string()})), ) })?; writeln!(shader_test_file, "{spvasm}").map_err(io_err)?; } ShaderRunnerPass::GeomSpirv { geom_spvasm_path } => { writeln!(shader_test_file, "[geometry shader spirv]").map_err(io_err)?; let mut spvasm = String::new(); let path = if geom_spvasm_path.starts_with("/tmp") { geom_spvasm_path.clone() } else { format!("/tmp/{geom_spvasm_path}") }; File::open(&path) .map_err(|e| { McpError::internal_error( format!("Failed to open geometry shader SPIR-V file at {path}"), Some(json!({"error": e.to_string()})), ) })? 
.read_to_string(&mut spvasm) .map_err(|e| { McpError::internal_error( "Failed to read geometry shader SPIR-V file", Some(json!({"error": e.to_string()})), ) })?; writeln!(shader_test_file, "{spvasm}").map_err(io_err)?; } ShaderRunnerPass::TescSpirv { tesc_spvasm_path } => { writeln!(shader_test_file, "[tessellation control shader spirv]") .map_err(io_err)?; let mut spvasm = String::new(); let path = if tesc_spvasm_path.starts_with("/tmp") { tesc_spvasm_path.clone() } else { format!("/tmp/{tesc_spvasm_path}") }; File::open(&path) .map_err(|e| { McpError::internal_error( format!( "Failed to open tessellation control shader SPIR-V file at {path}" ), Some(json!({"error": e.to_string()})), ) })? .read_to_string(&mut spvasm) .map_err(|e| { McpError::internal_error( "Failed to read tessellation control shader SPIR-V file", Some(json!({"error": e.to_string()})), ) })?; writeln!(shader_test_file, "{spvasm}").map_err(io_err)?; } ShaderRunnerPass::TeseSpirv { tese_spvasm_path } => { writeln!(shader_test_file, "[tessellation evaluation shader spirv]") .map_err(io_err)?; let mut spvasm = String::new(); let path = if tese_spvasm_path.starts_with("/tmp") { tese_spvasm_path.clone() } else { format!("/tmp/{tese_spvasm_path}") }; File::open(&path) .map_err(|e| { McpError::internal_error( format!( "Failed to open tessellation evaluation shader SPIR-V file at {path}" ), Some(json!({"error": e.to_string()})), ) })? 
.read_to_string(&mut spvasm) .map_err(|e| { McpError::internal_error( "Failed to read tessellation evaluation shader SPIR-V file", Some(json!({"error": e.to_string()})), ) })?; writeln!(shader_test_file, "{spvasm}").map_err(io_err)?; } } writeln!(shader_test_file).map_err(io_err)?; } if let Some(vertex_data) = &request.vertex_data { writeln!(shader_test_file, "[vertex data]").map_err(io_err)?; for data in vertex_data { if let ShaderRunnerVertexData::AttributeFormat { location, format } = data { writeln!(shader_test_file, "{location}/{format}").map_err(io_err)?; } else { match data { ShaderRunnerVertexData::Vec2 { x, y } => { write!(shader_test_file, "{x} {y}").map_err(io_err)?; } ShaderRunnerVertexData::Vec3 { x, y, z } => { write!(shader_test_file, "{x} {y} {z}").map_err(io_err)?; } ShaderRunnerVertexData::Vec4 { x, y, z, w } => { write!(shader_test_file, "{x} {y} {z} {w}").map_err(io_err)?; } ShaderRunnerVertexData::RGB { r, g, b } => { write!(shader_test_file, "{r} {g} {b}").map_err(io_err)?; } ShaderRunnerVertexData::Hex { value } => { write!(shader_test_file, "{value}").map_err(io_err)?; } ShaderRunnerVertexData::GenericComponents { components } => { for component in components { write!(shader_test_file, "{component} ").map_err(io_err)?; } } _ => {} } writeln!(shader_test_file).map_err(io_err)?; } } writeln!(shader_test_file).map_err(io_err)?; } writeln!(shader_test_file, "[test]").map_err(io_err)?; for test_cmd in &request.tests { match test_cmd { ShaderRunnerTest::FragmentEntrypoint { name } => { writeln!(shader_test_file, "fragment entrypoint {name}").map_err(io_err)?; } ShaderRunnerTest::VertexEntrypoint { name } => { writeln!(shader_test_file, "vertex entrypoint {name}").map_err(io_err)?; } ShaderRunnerTest::ComputeEntrypoint { name } => { writeln!(shader_test_file, "compute entrypoint {name}").map_err(io_err)?; } ShaderRunnerTest::GeometryEntrypoint { name } => { writeln!(shader_test_file, "geometry entrypoint {name}").map_err(io_err)?; } 
ShaderRunnerTest::DrawRect { x, y, width, height, } => { writeln!(shader_test_file, "draw rect {x} {y} {width} {height}") .map_err(io_err)?; } ShaderRunnerTest::DrawArrays { primitive_type, first, count, } => { writeln!( shader_test_file, "draw arrays {primitive_type} {first} {count}" ) .map_err(io_err)?; } ShaderRunnerTest::DrawArraysIndexed { primitive_type, first, count, } => { writeln!( shader_test_file, "draw arrays indexed {primitive_type} {first} {count}" ) .map_err(io_err)?; } ShaderRunnerTest::SSBO { binding, size, data, descriptor_set, } => { let set_prefix = if let Some(set) = descriptor_set { format!("{set}:") } else { String::new() }; if let Some(size) = size { writeln!(shader_test_file, "ssbo {set_prefix}{binding} {size}") .map_err(io_err)?; } else if let Some(_data) = data { writeln!(shader_test_file, "ssbo {set_prefix}{binding} data") .map_err(io_err)?; } } ShaderRunnerTest::SSBOSubData { binding, data_type, offset, values, descriptor_set, } => { let set_prefix = if let Some(set) = descriptor_set { format!("{set}:") } else { String::new() }; write!( shader_test_file, "ssbo {set_prefix}{binding} subdata {data_type} {offset}" ) .map_err(io_err)?; for value in values { write!(shader_test_file, " {value}").map_err(io_err)?; } writeln!(shader_test_file).map_err(io_err)?; } ShaderRunnerTest::UBO { binding, data: _, descriptor_set, } => { let set_prefix = if let Some(set) = descriptor_set { format!("{set}:") } else { String::new() }; writeln!(shader_test_file, "ubo {set_prefix}{binding} data").map_err(io_err)?; } ShaderRunnerTest::UBOSubData { binding, data_type, offset, values, descriptor_set, } => { let set_prefix = if let Some(set) = descriptor_set { format!("{set}:") } else { String::new() }; write!( shader_test_file, "ubo {set_prefix}{binding} subdata {data_type} {offset}" ) .map_err(io_err)?; for value in values { write!(shader_test_file, " {value}").map_err(io_err)?; } writeln!(shader_test_file).map_err(io_err)?; } ShaderRunnerTest::BufferLayout { 
buffer_type, layout_type, } => { writeln!(shader_test_file, "{buffer_type} layout {layout_type}") .map_err(io_err)?; } ShaderRunnerTest::Push { data_type, offset, values, } => { write!(shader_test_file, "push {data_type} {offset}").map_err(io_err)?; for value in values { write!(shader_test_file, " {value}").map_err(io_err)?; } writeln!(shader_test_file).map_err(io_err)?; } ShaderRunnerTest::PushLayout { layout_type } => { writeln!(shader_test_file, "push layout {layout_type}").map_err(io_err)?; } ShaderRunnerTest::Compute { x, y, z } => { writeln!(shader_test_file, "compute {x} {y} {z}").map_err(io_err)?; } ShaderRunnerTest::Probe { probe_type, format, args, } => { write!(shader_test_file, "probe {probe_type} {format}").map_err(io_err)?; for arg in args { write!(shader_test_file, " {arg}").map_err(io_err)?; } writeln!(shader_test_file).map_err(io_err)?; } ShaderRunnerTest::RelativeProbe { probe_type, format, args, } => { write!(shader_test_file, "relative probe {probe_type} {format}") .map_err(io_err)?; for arg in args { write!(shader_test_file, " {arg}").map_err(io_err)?; } writeln!(shader_test_file).map_err(io_err)?; } ShaderRunnerTest::Tolerance { values } => { write!(shader_test_file, "tolerance").map_err(io_err)?; for value in values { write!(shader_test_file, " {value}").map_err(io_err)?; } writeln!(shader_test_file).map_err(io_err)?; } ShaderRunnerTest::Clear => { writeln!(shader_test_file, "clear").map_err(io_err)?; } ShaderRunnerTest::DepthTestEnable { enable } => { writeln!(shader_test_file, "depthTestEnable {enable}").map_err(io_err)?; } ShaderRunnerTest::DepthWriteEnable { enable } => { writeln!(shader_test_file, "depthWriteEnable {enable}").map_err(io_err)?; } ShaderRunnerTest::DepthCompareOp { op } => { writeln!(shader_test_file, "depthCompareOp {op}").map_err(io_err)?; } ShaderRunnerTest::StencilTestEnable { enable } => { writeln!(shader_test_file, "stencilTestEnable {enable}").map_err(io_err)?; } ShaderRunnerTest::FrontFace { mode } => { 
writeln!(shader_test_file, "frontFace {mode}").map_err(io_err)?; } ShaderRunnerTest::StencilOp { face, op_name, value, } => { writeln!(shader_test_file, "{face}.{op_name} {value}",).map_err(io_err)?; } ShaderRunnerTest::StencilReference { face, value } => { writeln!(shader_test_file, "{face}.reference {value}",).map_err(io_err)?; } ShaderRunnerTest::StencilCompareOp { face, op } => { writeln!(shader_test_file, "{face}.compareOp {op}",).map_err(io_err)?; } ShaderRunnerTest::ColorWriteMask { mask } => { writeln!(shader_test_file, "colorWriteMask {mask}",).map_err(io_err)?; } ShaderRunnerTest::LogicOpEnable { enable } => { writeln!(shader_test_file, "logicOpEnable {enable}",).map_err(io_err)?; } ShaderRunnerTest::LogicOp { op } => { writeln!(shader_test_file, "logicOp {op}",).map_err(io_err)?; } ShaderRunnerTest::CullMode { mode } => { writeln!(shader_test_file, "cullMode {mode}",).map_err(io_err)?; } ShaderRunnerTest::LineWidth { width } => { writeln!(shader_test_file, "lineWidth {width}").map_err(io_err)?; } ShaderRunnerTest::Require { feature, parameters, } => { write!(shader_test_file, "require {feature}").map_err(io_err)?; for param in parameters { write!(shader_test_file, " {param}").map_err(io_err)?; } writeln!(shader_test_file).map_err(io_err)?; } } } shader_test_file.flush().map_err(io_err)?; let tmp_image_path = "/tmp/vkrunner_output.ppm"; if let Some(output_path) = &request.output_path { if output_path.starts_with("/tmp") { tmp_image_path.to_string() } else { format!("/tmp/{output_path}") } } else { tmp_image_path.to_string() }; let mut vkrunner_args = vec![shader_test_path]; if request.output_path.is_some() { vkrunner_args.push("--image"); vkrunner_args.push(tmp_image_path); } let vkrunner_output = Command::new("vkrunner") .args(&vkrunner_args) .stdout(Stdio::piped()) .stderr(Stdio::piped()) .output() .map_err(|e| { McpError::internal_error( "Failed to run vkrunner", Some(json!({"error": e.to_string()})), ) })?; let stdout = 
String::from_utf8_lossy(&vkrunner_output.stdout).to_string(); let stderr = String::from_utf8_lossy(&vkrunner_output.stderr).to_string(); let mut result_message = if vkrunner_output.status.success() { format!( "Shader compilation successful using shaderc-rs.\nVkRunner execution successful.\n\nOutput:\n{stdout}\n\n" ) } else { format!( "Shader compilation successful using shaderc-rs.\nVkRunner execution failed.\n\nOutput:\n{stdout}\n\nError:\n{stderr}\n\n", ) }; if let Some(output_path) = &request.output_path { if vkrunner_output.status.success() && Path::new(tmp_image_path).exists() { match read_and_decode_ppm_file(tmp_image_path) { Ok(img) => { if let Some(parent) = Path::new(output_path).parent() { if !parent.as_os_str().is_empty() { std::fs::create_dir_all(parent).map_err(|e| { McpError::internal_error( "Failed to create output directory", Some(json!({"error": e.to_string()})), ) })?; } } img.save(output_path).map_err(|e| { McpError::internal_error( "Failed to save output image", Some(json!({"error": e.to_string()})), ) })?; result_message.push_str(&format!("Image saved to: {output_path}\n")); } Err(e) => { result_message.push_str(&format!("Failed to convert output image: {e}\n")); } } } else if vkrunner_output.status.success() { result_message.push_str("No output image was generated by VkRunner.\n"); } } result_message.push_str("\nShader Test File Contents:\n"); result_message.push_str( &std::fs::read_to_string(shader_test_path) .unwrap_or_else(|_| "Failed to read shader test file".to_string()), ); Ok(CallToolResult::success(vec![Content::text(result_message)])) } } impl Default for ShadercVkrunnerMcp { fn default() -> Self { Self::new() } } const_string!(Echo = "echo"); #[tool(tool_box)] impl ServerHandler for ShadercVkrunnerMcp { fn get_info(&self) -> ServerInfo { ServerInfo { protocol_version: ProtocolVersion::V_2024_11_05, capabilities: ServerCapabilities::builder() .enable_tools() .build(), server_info: Implementation::from_build_env(), instructions: 
Some("This server provides tools for compiling and running GLSL shaders using Vulkan infrastructure. The typical workflow is: 1. Compile GLSL source code to SPIR-V assembly using the 'compile_run_shaders' tool 2. Specify hardware/feature requirements if needed (e.g., geometry shaders, floating-point formats) 3. Reference compiled SPIR-V files in shader passes 4. Define vertex data if rendering geometry 5. Set up test commands to draw or compute 6. Optionally save the rendered output as an image".to_string()), } } } #[derive(Parser, Debug)] #[command(about)] struct Args { #[clap(short, long, value_parser)] work_dir: Option<PathBuf>, } #[tokio::main] async fn main() -> Result<()> { let args = Args::parse(); if let Some(work_dir) = args.work_dir { std::env::set_current_dir(&work_dir) .map_err(|e| { eprintln!("Failed to set working directory to {work_dir:?}: {e}"); std::process::exit(1); }) .unwrap(); } tracing_subscriber::fmt() .with_env_filter(EnvFilter::from_default_env().add_directive(tracing::Level::DEBUG.into())) .with_writer(std::io::stderr) .with_ansi(false) .init(); tracing::info!("Starting MCP server"); let service = ShadercVkrunnerMcp::new() .serve(stdio()) .await .inspect_err(|e| { tracing::error!("serving error: {:?}", e); })?; service.waiting().await?; Ok(()) } ``` -------------------------------------------------------------------------------- /vkrunner/vkrunner/requirements.rs: -------------------------------------------------------------------------------- ```rust // vkrunner // // Copyright (C) 2019 Intel Corporation // Copyright 2023 Neil Roberts // // Permission is hereby granted, free of charge, to any person obtaining a // copy of this software and associated documentation files (the "Software"), // to deal in the Software without restriction, including without limitation // the rights to use, copy, modify, merge, publish, distribute, sublicense, // and/or sell copies of the Software, and to permit persons to whom the // Software is furnished to 
do so, subject to the following conditions: // // The above copyright notice and this permission notice (including the next // paragraph) shall be included in all copies or substantial portions of the // Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL // THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. use crate::vk; use crate::vulkan_funcs; use crate::vulkan_funcs::{NEXT_PTR_OFFSET, FIRST_FEATURE_OFFSET}; use crate::result; use std::mem; use std::collections::{HashMap, HashSet}; use std::ffi::CStr; use std::convert::TryInto; use std::fmt; use std::cell::UnsafeCell; use std::ptr; #[derive(Debug)] struct Extension { // The name of the extension that provides this set of features // stored as a null-terminated byte sequence. It is stored this // way because that is how bindgen generates string literals from // headers. name_bytes: &'static [u8], // The size of the corresponding features struct struct_size: usize, // The enum for this struct struct_type: vk::VkStructureType, // List of feature names in this extension in the order they // appear in the features struct. The names are as written in the // struct definition. 
features: &'static [&'static str], } impl Extension { fn name(&self) -> &str { // The name comes from static data so it should always be // valid CStr::from_bytes_with_nul(self.name_bytes) .unwrap() .to_str() .unwrap() } } include!("features.rs"); // Lazily generated extensions data used for passing to Vulkan struct LazyExtensions { // All of the null-terminated strings concatenated into a buffer strings: Vec<u8>, // Pointers to the start of each string pointers: Vec<* const u8>, } // Lazily generated structures data used for passing to Vulkan struct LazyStructures { // A buffer used to return a linked chain of structs that can be // accessed from C. list: Vec<u8>, } // Number of features in VkPhysicalDeviceFeatures. The struct is just // a series of VkBool32 so we should be able to safely calculate the // number based on the struct size. const N_PHYSICAL_DEVICE_FEATURES: usize = mem::size_of::<vk::VkPhysicalDeviceFeatures>() / mem::size_of::<vk::VkBool32>(); // Lazily generated VkPhysicalDeviceFeatures struct to pass to Vulkan struct LazyBaseFeatures { // Array big enough to store a VkPhysicalDeciveFeatures struct. // It’s easier to manipulate as an array instead of the actual // struct because we want to update the bools by index. bools: [vk::VkBool32; N_PHYSICAL_DEVICE_FEATURES], } #[derive(Debug, Default, Clone)] pub struct CooperativeMatrix { pub m_size: Option<u32>, pub n_size: Option<u32>, pub k_size: Option<u32>, pub a_type: Option<vk::VkComponentTypeKHR>, pub b_type: Option<vk::VkComponentTypeKHR>, pub c_type: Option<vk::VkComponentTypeKHR>, pub result_type: Option<vk::VkComponentTypeKHR>, pub saturating_accumulation: Option<vk::VkBool32>, pub scope: Option<vk::VkScopeKHR>, pub line: String, } #[derive(Debug)] pub struct Requirements { // Minimum vulkan version version: u32, // Set of extension names required extensions: HashSet<String>, // A map indexed by extension number. 
The value is an array of // bools representing whether we need each feature in the // extension’s feature struct. features: HashMap<usize, Box<[bool]>>, // An array of bools corresponding to each feature in the base // VkPhysicalDeviceFeatures struct base_features: [bool; N_BASE_FEATURES], // Required subgroup size to be used with // VkPipelineShaderStageRequiredSubgroupSizeCreateInfo. pub required_subgroup_size: Option<u32>, cooperative_matrix_reqs: Vec<CooperativeMatrix>, // The rest of the struct is lazily created from the above data // and shouldn’t be part of the PartialEq implementation. lazy_extensions: UnsafeCell<Option<LazyExtensions>>, lazy_structures: UnsafeCell<Option<LazyStructures>>, lazy_base_features: UnsafeCell<Option<LazyBaseFeatures>>, } /// Error returned by [Requirements::check] #[derive(Debug)] pub enum Error { EnumerateDeviceExtensionPropertiesFailed, ExtensionMissingNullTerminator, ExtensionInvalidUtf8, /// A required base feature from VkPhysicalDeviceFeatures is missing. MissingBaseFeature(usize), /// A required extension is missing. The string is the name of the /// extension. MissingExtension(String), /// A required feature is missing. MissingFeature { extension: usize, feature: usize }, /// The API version reported by the driver is too low VersionTooLow { required_version: u32, actual_version: u32 }, /// Required subgroup size out of supported range. RequiredSubgroupSizeInvalid { size: u32, min: u32, max: u32 }, MissingCooperativeMatrixProperties(String), } impl Error { pub fn result(&self) -> result::Result { match self { Error::EnumerateDeviceExtensionPropertiesFailed => { result::Result::Fail }, Error::ExtensionMissingNullTerminator => result::Result::Fail, Error::ExtensionInvalidUtf8 => result::Result::Fail, Error::MissingBaseFeature(_) => result::Result::Skip, Error::MissingExtension(_) => result::Result::Skip, Error::MissingFeature { .. } => result::Result::Skip, Error::VersionTooLow { .. 
} => result::Result::Skip, Error::RequiredSubgroupSizeInvalid { .. } => result::Result::Skip, Error::MissingCooperativeMatrixProperties(_) => result::Result::Skip, } } } impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { Error::EnumerateDeviceExtensionPropertiesFailed => { write!(f, "vkEnumerateDeviceExtensionProperties failed") }, Error::ExtensionMissingNullTerminator => { write!( f, "NULL terminator missing in string returned from \ vkEnumerateDeviceExtensionProperties" ) }, Error::ExtensionInvalidUtf8 => { write!( f, "Invalid UTF-8 in string returned from \ vkEnumerateDeviceExtensionProperties" ) }, &Error::MissingBaseFeature(feature_num) => { write!( f, "Missing required feature: {}", BASE_FEATURES[feature_num], ) }, Error::MissingExtension(s) => { write!(f, "Missing required extension: {}", s) }, &Error::MissingFeature { extension, feature } => { write!( f, "Missing required feature “{}” from extension “{}”", EXTENSIONS[extension].features[feature], EXTENSIONS[extension].name(), ) }, &Error::VersionTooLow { required_version, actual_version } => { let (req_major, req_minor, req_patch) = extract_version(required_version); let (actual_major, actual_minor, actual_patch) = extract_version(actual_version); write!( f, "Vulkan API version {}.{}.{} required but the driver \ reported {}.{}.{}", req_major, req_minor, req_patch, actual_major, actual_minor, actual_patch, ) }, &Error::RequiredSubgroupSizeInvalid { size, min, max } => { write!( f, "Required subgroup size {} is not in the supported range {}-{}", size, min, max ) }, Error::MissingCooperativeMatrixProperties(s) => { write!(f, "Physical device can't fulfill cooperative matrix requirements: {}", s.trim()) }, } } } /// Convert a decomposed Vulkan version into an integer. This is the /// same as the `VK_MAKE_VERSION` macro in the Vulkan headers. 
pub const fn make_version(major: u32, minor: u32, patch: u32) -> u32 { (major << 22) | (minor << 12) | patch } /// Decompose a Vulkan version into its component major, minor and /// patch parts. This is the equivalent of the `VK_VERSION_MAJOR`, /// `VK_VERSION_MINOR` and `VK_VERSION_PATCH` macros in the Vulkan /// header. pub const fn extract_version(version: u32) -> (u32, u32, u32) { (version >> 22, (version >> 12) & 0x3ff, version & 0xfff) } impl Requirements { pub fn new() -> Requirements { Requirements { version: make_version(1, 0, 0), extensions: HashSet::new(), features: HashMap::new(), base_features: [false; N_BASE_FEATURES], required_subgroup_size: None, lazy_extensions: UnsafeCell::new(None), lazy_structures: UnsafeCell::new(None), lazy_base_features: UnsafeCell::new(None), cooperative_matrix_reqs: Vec::new(), } } /// Get the required Vulkan version that was previously set with /// [add_version](Requirements::add_version). pub fn version(&self) -> u32 { self.version } /// Set the minimum required Vulkan version. pub fn add_version(&mut self, major: u32, minor: u32, patch: u32) { self.version = std::cmp::max(self.version, make_version(major, minor, patch)); } pub fn add_cooperative_matrix_req(&mut self, req: CooperativeMatrix) { self.cooperative_matrix_reqs.push(req); } fn update_lazy_extensions(&self) -> &[* const u8] { // SAFETY: The only case where lazy_extensions will be // modified is if it is None and the only way for that to // happen is if something modifies the requirements through a // mutable reference. If that is the case then at that point // there were no other references to the requirements so // nothing can be holding a reference to the lazy data. let extensions = unsafe { match &mut *self.lazy_extensions.get() { Some(data) => return data.pointers.as_ref(), option @ None => { option.insert(LazyExtensions { strings: Vec::new(), pointers: Vec::new(), }) }, } }; // Store a list of offsets into the c_extensions array for // each extension. 
We can’t directly store the pointers yet // because the Vec will probably be reallocated while we are // adding to it. let mut offsets = Vec::<usize>::new(); for extension in self.extensions.iter() { offsets.push(extensions.strings.len()); extensions.strings.extend_from_slice(extension.as_bytes()); // Add the null terminator extensions.strings.push(0); } let base_ptr = extensions.strings.as_ptr(); extensions.pointers.reserve(offsets.len()); for offset in offsets { // SAFETY: These are all valid offsets into the // c_extensions Vec so they should all be in the same // allocation and no overflow is possible. extensions.pointers.push(unsafe { base_ptr.add(offset) }); } extensions.pointers.as_ref() } /// Return a reference to an array of pointers to C-style strings /// that can be passed to `vkCreateDevice`. pub fn c_extensions(&self) -> &[* const u8] { self.update_lazy_extensions() } // Make a linked list of features structures that is suitable for // passing to VkPhysicalDeviceFeatures2 or vkCreateDevice. All of // the bools are set to 0. The vec with the structs is returned as // well as a list of offsets to the bools along with the extension // number. fn make_empty_structures( &self ) -> (Vec<u8>, Vec<(usize, usize)>) { let mut structures = Vec::new(); // Keep an array of offsets and corresponding extension num in a // vec for each structure that will be returned. The offset is // also used to update the pNext pointers. We have to do this // after filling the vec because the vec will probably // reallocate while we are adding to it which would invalidate // the pointers. 
let mut offsets = Vec::<(usize, usize)>::new(); for (&extension_num, features) in self.features.iter() { let extension = &EXTENSIONS[extension_num]; // Make sure the struct size is big enough to hold all of // the bools assert!( extension.struct_size >= (FIRST_FEATURE_OFFSET + features.len() * mem::size_of::<vk::VkBool32>()) ); let offset = structures.len(); structures.resize(offset + extension.struct_size, 0); // The structure type is the first field in the structure let type_end = offset + mem::size_of::<vk::VkStructureType>(); structures[offset..type_end] .copy_from_slice(&extension.struct_type.to_ne_bytes()); offsets.push((offset + FIRST_FEATURE_OFFSET, extension_num)); } let mut last_offset = 0; for &(offset, _) in &offsets[1..] { let offset = offset - FIRST_FEATURE_OFFSET; // SAFETY: The offsets are all valid offsets into the // structures vec so they should all be in the same // allocation and no overflow should occur. let ptr = unsafe { structures.as_ptr().add(offset) }; let ptr_start = last_offset + NEXT_PTR_OFFSET; let ptr_end = ptr_start + mem::size_of::<*const u8>(); structures[ptr_start..ptr_end] .copy_from_slice(&(ptr as usize).to_ne_bytes()); last_offset = offset; } (structures, offsets) } fn update_lazy_structures(&self) -> &[u8] { // SAFETY: The only case where lazy_structures will be // modified is if it is None and the only way for that to // happen is if something modifies the requirements through a // mutable reference. If that is the case then at that point // there were no other references to the requirements so // nothing can be holding a reference to the lazy data. 
let (structures, offsets) = unsafe { match &mut *self.lazy_structures.get() { Some(data) => return data.list.as_slice(), option @ None => { let (list, offsets) = self.make_empty_structures(); (option.insert(LazyStructures { list }), offsets) }, } }; for (offset, extension_num) in offsets { let features = &self.features[&extension_num]; for (feature_num, &feature) in features.iter().enumerate() { let feature_start = offset + mem::size_of::<vk::VkBool32>() * feature_num; let feature_end = feature_start + mem::size_of::<vk::VkBool32>(); structures.list[feature_start..feature_end] .copy_from_slice(&(feature as vk::VkBool32).to_ne_bytes()); } } structures.list.as_slice() } /// Return a pointer to a linked list of feature structures that /// can be passed to `vkCreateDevice`, or `None` if no feature /// structs are required. pub fn c_structures(&self) -> Option<&[u8]> { if self.features.is_empty() { None } else { Some(self.update_lazy_structures()) } } fn update_lazy_base_features(&self) -> &[vk::VkBool32] { // SAFETY: The only case where lazy_structures will be // modified is if it is None and the only way for that to // happen is if something modifies the requirements through a // mutable reference. If that is the case then at that point // there were no other references to the requirements so // nothing can be holding a reference to the lazy data. let base_features = unsafe { match &mut *self.lazy_base_features.get() { Some(data) => return &data.bools, option @ None => { option.insert(LazyBaseFeatures { bools: [0; N_PHYSICAL_DEVICE_FEATURES], }) }, } }; for (feature_num, &feature) in self.base_features.iter().enumerate() { base_features.bools[feature_num] = feature as vk::VkBool32; } &base_features.bools } /// Return a pointer to a `VkPhysicalDeviceFeatures` struct that /// can be passed to `vkCreateDevice`. 
pub fn c_base_features(&self) -> &vk::VkPhysicalDeviceFeatures { unsafe { &*self.update_lazy_base_features().as_ptr().cast() } } fn add_extension_name(&mut self, name: &str) { // It would be nice to use get_or_insert_owned here if it // becomes stable. It’s probably better not to use // HashSet::replace directly because if it’s already in the // set then we’ll pointlessly copy the str slice and // immediately free it. if !self.extensions.contains(name) { self.extensions.insert(name.to_owned()); // SAFETY: self is immutable so there should be no other // reference to the lazy data unsafe { *self.lazy_extensions.get() = None; } } } /// Adds a requirement to the list of requirements. /// /// The name can be either a feature as written in the /// corresponding features struct or the name of an extension. If /// it is a feature it needs to be either the name of a field in /// the `VkPhysicalDeviceFeatures` struct or a field in any of the /// features structs of the extensions that vkrunner knows about. /// In the latter case the name of the corresponding extension /// will be automatically added as a requirement. 
pub fn add(&mut self, name: &str) { if let Some((extension_num, feature_num)) = find_feature(name) { let extension = &EXTENSIONS[extension_num]; self.add_extension_name(extension.name()); let features = self .features .entry(extension_num) .or_insert_with(|| { vec![false; extension.features.len()].into_boxed_slice() }); if !features[feature_num] { // SAFETY: self is immutable so there should be no other // reference to the lazy data unsafe { *self.lazy_structures.get() = None; } features[feature_num] = true; } } else if let Some(num) = find_base_feature(name) { if !self.base_features[num] { self.base_features[num] = true; // SAFETY: self is immutable so there should be no other // reference to the lazy data unsafe { *self.lazy_structures.get() = None; } } } else { self.add_extension_name(name); } } fn check_base_features( &self, vkinst: &vulkan_funcs::Instance, device: vk::VkPhysicalDevice ) -> Result<(), Error> { let mut actual_features: [vk::VkBool32; N_BASE_FEATURES] = [0; N_BASE_FEATURES]; unsafe { vkinst.vkGetPhysicalDeviceFeatures.unwrap()( device, actual_features.as_mut_ptr().cast() ); } for (feature_num, &required) in self.base_features.iter().enumerate() { if required && actual_features[feature_num] == 0 { return Err(Error::MissingBaseFeature(feature_num)); } } Ok(()) } fn get_device_extensions( &self, vkinst: &vulkan_funcs::Instance, device: vk::VkPhysicalDevice ) -> Result<HashSet::<String>, Error> { let mut property_count = 0u32; let res = unsafe { vkinst.vkEnumerateDeviceExtensionProperties.unwrap()( device, std::ptr::null(), // layerName &mut property_count as *mut u32, std::ptr::null_mut(), // properties ) }; if res != vk::VK_SUCCESS { return Err(Error::EnumerateDeviceExtensionPropertiesFailed); } let mut extensions = Vec::<vk::VkExtensionProperties>::with_capacity( property_count as usize ); unsafe { let res = vkinst.vkEnumerateDeviceExtensionProperties.unwrap()( device, std::ptr::null(), // layerName &mut property_count as *mut u32, 
extensions.as_mut_ptr(), ); if res != vk::VK_SUCCESS { return Err(Error::EnumerateDeviceExtensionPropertiesFailed); } // SAFETY: The FFI call to // vkEnumerateDeviceExtensionProperties should have filled // the extensions array with valid values so we can safely // set the length to the capacity we allocated earlier extensions.set_len(property_count as usize); }; let mut extensions_set = HashSet::new(); for extension in extensions.iter() { let name = &extension.extensionName; // Make sure it has a NULL terminator if let None = name.iter().find(|&&b| b == 0) { return Err(Error::ExtensionMissingNullTerminator); } // SAFETY: we just checked that the array has a null terminator let name = unsafe { CStr::from_ptr(name.as_ptr()) }; let name = match name.to_str() { Err(_) => { return Err(Error::ExtensionInvalidUtf8); }, Ok(s) => s, }; extensions_set.insert(name.to_owned()); } Ok(extensions_set) } fn check_extensions( &self, vkinst: &vulkan_funcs::Instance, device: vk::VkPhysicalDevice ) -> Result<(), Error> { if self.extensions.is_empty() { return Ok(()); } let actual_extensions = self.get_device_extensions(vkinst, device)?; for extension in self.extensions.iter() { if !actual_extensions.contains(extension) { return Err(Error::MissingExtension(extension.to_string())); } } Ok(()) } fn check_structures( &self, vkinst: &vulkan_funcs::Instance, device: vk::VkPhysicalDevice ) -> Result<(), Error> { if self.features.is_empty() { return Ok(()); } // If vkGetPhysicalDeviceFeatures2KHR isn’t available then we // can probably assume that none of the extensions are // available. 
let get_features = match vkinst.vkGetPhysicalDeviceFeatures2KHR { None => { // Find the first feature and report that as missing for (&extension, features) in self.features.iter() { for (feature, &enabled) in features.iter().enumerate() { if enabled { return Err(Error::MissingFeature { extension, feature, }); } } } unreachable!("Requirements::features should be empty if no \ features are required"); }, Some(func) => func, }; let (mut actual_structures, offsets) = self.make_empty_structures(); let mut features_query = vk::VkPhysicalDeviceFeatures2 { sType: vk::VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2, pNext: actual_structures.as_mut_ptr().cast(), features: Default::default(), }; unsafe { get_features( device, &mut features_query as *mut vk::VkPhysicalDeviceFeatures2, ); } for (offset, extension_num) in offsets { let features = &self.features[&extension_num]; for (feature_num, &feature) in features.iter().enumerate() { if !feature { continue; } let feature_start = offset + mem::size_of::<vk::VkBool32>() * feature_num; let feature_end = feature_start + mem::size_of::<vk::VkBool32>(); let actual_value = vk::VkBool32::from_ne_bytes( actual_structures[feature_start..feature_end] .try_into() .unwrap() ); if actual_value == 0 { return Err(Error::MissingFeature { extension: extension_num, feature: feature_num, }); } } } Ok(()) } fn check_version( &self, vkinst: &vulkan_funcs::Instance, device: vk::VkPhysicalDevice, ) -> Result<(), Error> { let actual_version = if self.version() >= make_version(1, 1, 0) { let mut subgroup_size_props = vk::VkPhysicalDeviceSubgroupSizeControlProperties { sType: vk::VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_SIZE_CONTROL_PROPERTIES, ..Default::default() }; let mut props = vk::VkPhysicalDeviceProperties2 { sType: vk::VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2, pNext: std::ptr::addr_of_mut!(subgroup_size_props).cast(), ..Default::default() }; unsafe { vkinst.vkGetPhysicalDeviceProperties2.unwrap()( device, &mut props as *mut 
vk::VkPhysicalDeviceProperties2, ); } if let Some(size) = self.required_subgroup_size { let min = subgroup_size_props.minSubgroupSize; let max = subgroup_size_props.maxSubgroupSize; if size < min || size > max { return Err(Error::RequiredSubgroupSizeInvalid { size, min, max }); } } props.properties.apiVersion } else { let mut props = vk::VkPhysicalDeviceProperties::default(); unsafe { vkinst.vkGetPhysicalDeviceProperties.unwrap()( device, &mut props as *mut vk::VkPhysicalDeviceProperties, ); } props.apiVersion }; if actual_version < self.version() { Err(Error::VersionTooLow { required_version: self.version(), actual_version, }) } else { Ok(()) } } fn check_cooperative_matrix(&self, vkinst: &vulkan_funcs::Instance, device: vk::VkPhysicalDevice) -> Result<(), Error> { if self.cooperative_matrix_reqs.len() == 0 { return Ok(()); } let mut count = 0u32; unsafe { vkinst.vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR.unwrap()( device, ptr::addr_of_mut!(count), ptr::null_mut(), ); } let mut props = Vec::<vk::VkCooperativeMatrixPropertiesKHR>::new(); props.resize_with(count as usize, Default::default); for prop in &mut props { prop.sType = vk::VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_PROPERTIES_KHR; } unsafe { vkinst.vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR.unwrap()( device, ptr::addr_of_mut!(count), props.as_mut_ptr(), ); } for req in &self.cooperative_matrix_reqs { let mut found = false; for prop in &props { if let Some(m) = req.m_size { if m != prop.MSize { continue } } if let Some(n) = req.n_size { if n != prop.NSize { continue } } if let Some(k) = req.k_size { if k != prop.KSize { continue } } if let Some(a) = req.a_type { if a != prop.AType { continue } } if let Some(b) = req.b_type { if b != prop.BType { continue } } if let Some(c) = req.c_type { if c != prop.CType { continue } } if let Some(result) = req.result_type { if result != prop.ResultType { continue } } if let Some(sa) = req.saturating_accumulation { if sa != prop.saturatingAccumulation { continue } 
} if let Some(scope) = req.scope { if scope != prop.scope { continue } } found = true; break } if !found { return Err(Error::MissingCooperativeMatrixProperties(req.line.clone())); } } Ok(()) } pub fn check( &self, vkinst: &vulkan_funcs::Instance, device: vk::VkPhysicalDevice ) -> Result<(), Error> { self.check_base_features(vkinst, device)?; self.check_extensions(vkinst, device)?; self.check_structures(vkinst, device)?; self.check_version(vkinst, device)?; self.check_cooperative_matrix(vkinst, device)?; Ok(()) } } // Looks for a feature with the given name. If found it returns the // index of the extension it was found in and the index of the feature // name. fn find_feature(name: &str) -> Option<(usize, usize)> { for (extension_num, extension) in EXTENSIONS.iter().enumerate() { for (feature_num, &feature) in extension.features.iter().enumerate() { if feature == name { return Some((extension_num, feature_num)); } } } None } fn find_base_feature(name: &str) -> Option<usize> { BASE_FEATURES.iter().position(|&f| f == name) } impl PartialEq for Requirements { fn eq(&self, other: &Requirements) -> bool { self.version == other.version && self.extensions == other.extensions && self.features == other.features && self.base_features == other.base_features && self.required_subgroup_size == other.required_subgroup_size } } impl Eq for Requirements { } // Manual implementation of clone to avoid copying the lazy state impl Clone for Requirements { fn clone(&self) -> Requirements { Requirements { version: self.version, extensions: self.extensions.clone(), features: self.features.clone(), base_features: self.base_features.clone(), required_subgroup_size: self.required_subgroup_size, lazy_extensions: UnsafeCell::new(None), lazy_structures: UnsafeCell::new(None), lazy_base_features: UnsafeCell::new(None), cooperative_matrix_reqs: self.cooperative_matrix_reqs.clone(), } } fn clone_from(&mut self, source: &Requirements) { self.version = source.version; 
self.extensions.clone_from(&source.extensions); self.features.clone_from(&source.features); self.base_features.clone_from(&source.base_features); self.required_subgroup_size = source.required_subgroup_size; self.cooperative_matrix_reqs.clone_from(&source.cooperative_matrix_reqs); // SAFETY: self is immutable so there should be no other // reference to the lazy data unsafe { *self.lazy_extensions.get() = None; *self.lazy_structures.get() = None; *self.lazy_base_features.get() = None; } } } #[cfg(test)] mod test { use super::*; use std::ffi::{c_char, c_void}; use crate::fake_vulkan; fn get_struct_type(structure: &[u8]) -> vk::VkStructureType { let slice = &structure[0..mem::size_of::<vk::VkStructureType>()]; let mut bytes = [0; mem::size_of::<vk::VkStructureType>()]; bytes.copy_from_slice(slice); vk::VkStructureType::from_ne_bytes(bytes) } fn get_next_structure(structure: &[u8]) -> &[u8] { let slice = &structure[ NEXT_PTR_OFFSET..NEXT_PTR_OFFSET + mem::size_of::<usize>() ]; let mut bytes = [0; mem::size_of::<usize>()]; bytes.copy_from_slice(slice); let pointer = usize::from_ne_bytes(bytes); let offset = pointer - structure.as_ptr() as usize; &structure[offset..] 
} fn find_bools_for_extension<'a, 'b>( mut structures: &'a [u8], extension: &'b Extension ) -> &'a [vk::VkBool32] { while !structures.is_empty() { let struct_type = get_struct_type(structures); if struct_type == extension.struct_type { unsafe { let bools_ptr = structures .as_ptr() .add(FIRST_FEATURE_OFFSET); return std::slice::from_raw_parts( bools_ptr as *const vk::VkBool32, extension.features.len() ); } } structures = get_next_structure(structures); } unreachable!("No structure found for extension “{}”", extension.name()); } unsafe fn extension_in_c_extensions( reqs: &mut Requirements, ext: &str ) -> bool { for &p in reqs.c_extensions().iter() { if CStr::from_ptr(p as *const c_char).to_str().unwrap() == ext { return true; } } false } #[test] fn test_all_features() { let mut reqs = Requirements::new(); for extension in EXTENSIONS.iter() { for feature in extension.features.iter() { reqs.add(feature); } } for feature in BASE_FEATURES.iter() { reqs.add(feature); } for (extension_num, extension) in EXTENSIONS.iter().enumerate() { // All of the extensions should be in the set assert!(reqs.extensions.contains(extension.name())); // All of the features of every extension should be true assert!(reqs.features[&extension_num].iter().all(|&b| b)); } // All of the base features should be enabled assert!(reqs.base_features.iter().all(|&b| b)); let base_features = reqs.c_base_features(); assert_eq!(base_features.robustBufferAccess, 1); assert_eq!(base_features.fullDrawIndexUint32, 1); assert_eq!(base_features.imageCubeArray, 1); assert_eq!(base_features.independentBlend, 1); assert_eq!(base_features.geometryShader, 1); assert_eq!(base_features.tessellationShader, 1); assert_eq!(base_features.sampleRateShading, 1); assert_eq!(base_features.dualSrcBlend, 1); assert_eq!(base_features.logicOp, 1); assert_eq!(base_features.multiDrawIndirect, 1); assert_eq!(base_features.drawIndirectFirstInstance, 1); assert_eq!(base_features.depthClamp, 1); assert_eq!(base_features.depthBiasClamp, 
1); assert_eq!(base_features.fillModeNonSolid, 1); assert_eq!(base_features.depthBounds, 1); assert_eq!(base_features.wideLines, 1); assert_eq!(base_features.largePoints, 1); assert_eq!(base_features.alphaToOne, 1); assert_eq!(base_features.multiViewport, 1); assert_eq!(base_features.samplerAnisotropy, 1); assert_eq!(base_features.textureCompressionETC2, 1); assert_eq!(base_features.textureCompressionASTC_LDR, 1); assert_eq!(base_features.textureCompressionBC, 1); assert_eq!(base_features.occlusionQueryPrecise, 1); assert_eq!(base_features.pipelineStatisticsQuery, 1); assert_eq!(base_features.vertexPipelineStoresAndAtomics, 1); assert_eq!(base_features.fragmentStoresAndAtomics, 1); assert_eq!(base_features.shaderTessellationAndGeometryPointSize, 1); assert_eq!(base_features.shaderImageGatherExtended, 1); assert_eq!(base_features.shaderStorageImageExtendedFormats, 1); assert_eq!(base_features.shaderStorageImageMultisample, 1); assert_eq!(base_features.shaderStorageImageReadWithoutFormat, 1); assert_eq!(base_features.shaderStorageImageWriteWithoutFormat, 1); assert_eq!(base_features.shaderUniformBufferArrayDynamicIndexing, 1); assert_eq!(base_features.shaderSampledImageArrayDynamicIndexing, 1); assert_eq!(base_features.shaderStorageBufferArrayDynamicIndexing, 1); assert_eq!(base_features.shaderStorageImageArrayDynamicIndexing, 1); assert_eq!(base_features.shaderClipDistance, 1); assert_eq!(base_features.shaderCullDistance, 1); assert_eq!(base_features.shaderFloat64, 1); assert_eq!(base_features.shaderInt64, 1); assert_eq!(base_features.shaderInt16, 1); assert_eq!(base_features.shaderResourceResidency, 1); assert_eq!(base_features.shaderResourceMinLod, 1); assert_eq!(base_features.sparseBinding, 1); assert_eq!(base_features.sparseResidencyBuffer, 1); assert_eq!(base_features.sparseResidencyImage2D, 1); assert_eq!(base_features.sparseResidencyImage3D, 1); assert_eq!(base_features.sparseResidency2Samples, 1); assert_eq!(base_features.sparseResidency4Samples, 1); 
assert_eq!(base_features.sparseResidency8Samples, 1); assert_eq!(base_features.sparseResidency16Samples, 1); assert_eq!(base_features.sparseResidencyAliased, 1); assert_eq!(base_features.variableMultisampleRate, 1); assert_eq!(base_features.inheritedQueries, 1); // All of the values should be set in the C structs for extension in EXTENSIONS.iter() { let structs = reqs.c_structures().unwrap(); assert!( find_bools_for_extension(structs, extension) .iter() .all(|&b| b == 1) ); assert!(unsafe { &*reqs.lazy_structures.get() }.is_some()); } // All of the extensions should be in the c_extensions for extension in EXTENSIONS.iter() { assert!(unsafe { extension_in_c_extensions(&mut reqs, extension.name()) }); assert!(unsafe { &*reqs.lazy_extensions.get() }.is_some()); } // Sanity check that a made-up extension isn’t in c_extensions assert!(!unsafe { extension_in_c_extensions(&mut reqs, "not_a_real_ext") }); } #[test] fn test_version() { let mut reqs = Requirements::new(); reqs.add_version(2, 1, 5); assert_eq!(reqs.version(), 0x801005); } #[test] fn test_empty() { let reqs = Requirements::new(); assert!(unsafe { &*reqs.lazy_extensions.get() }.is_none()); assert!(unsafe { &*reqs.lazy_structures.get() }.is_none()); let base_features_ptr = reqs.c_base_features() as *const vk::VkPhysicalDeviceFeatures as *const vk::VkBool32; unsafe { let base_features = std::slice::from_raw_parts( base_features_ptr, N_BASE_FEATURES ); assert!(base_features.iter().all(|&b| b == 0)); } } #[test] fn test_eq() { let mut reqs_a = Requirements::new(); let mut reqs_b = Requirements::new(); assert_eq!(reqs_a, reqs_b); reqs_a.add("advancedBlendCoherentOperations"); assert_ne!(reqs_a, reqs_b); reqs_b.add("advancedBlendCoherentOperations"); assert_eq!(reqs_a, reqs_b); // Getting the C data shouldn’t affect the equality reqs_a.c_structures(); reqs_a.c_base_features(); reqs_a.c_extensions(); assert_eq!(reqs_a, reqs_b); // The order of adding shouldn’t matter reqs_a.add("fake_extension"); 
reqs_a.add("another_fake_extension"); reqs_b.add("another_fake_extension"); reqs_b.add("fake_extension"); assert_eq!(reqs_a, reqs_b); reqs_a.add("wideLines"); assert_ne!(reqs_a, reqs_b); reqs_b.add("wideLines"); assert_eq!(reqs_a, reqs_b); reqs_a.add_version(3, 1, 2); assert_ne!(reqs_a, reqs_b); reqs_b.add_version(3, 1, 2); assert_eq!(reqs_a, reqs_b); } #[test] fn test_clone() { let mut reqs = Requirements::new(); reqs.add("wideLines"); reqs.add("fake_extension"); reqs.add("advancedBlendCoherentOperations"); assert_eq!(reqs.clone(), reqs); assert!(unsafe { &*reqs.lazy_structures.get() }.is_none()); assert_eq!(reqs.c_extensions().len(), 2); assert_eq!(reqs.c_base_features().wideLines, 1); let empty = Requirements::new(); reqs.clone_from(&empty); assert!(unsafe { &*reqs.lazy_structures.get() }.is_none()); assert_eq!(reqs.c_extensions().len(), 0); assert_eq!(reqs.c_base_features().wideLines, 0); assert_eq!(reqs, empty); } struct FakeVulkanData { fake_vulkan: Box<fake_vulkan::FakeVulkan>, _vklib: vulkan_funcs::Library, vkinst: vulkan_funcs::Instance, instance: vk::VkInstance, device: vk::VkPhysicalDevice, } impl FakeVulkanData { fn new() -> FakeVulkanData { let mut fake_vulkan = fake_vulkan::FakeVulkan::new(); fake_vulkan.physical_devices.push(Default::default()); fake_vulkan.set_override(); let vklib = vulkan_funcs::Library::new().unwrap(); let mut instance = std::ptr::null_mut(); unsafe { let res = vklib.vkCreateInstance.unwrap()( std::ptr::null(), // pCreateInfo std::ptr::null(), // pAllocator std::ptr::addr_of_mut!(instance), ); assert_eq!(res, vk::VK_SUCCESS); } extern "C" fn get_instance_proc( func_name: *const c_char, user_data: *const c_void, ) -> *const c_void { unsafe { let fake_vulkan = &*user_data.cast::<fake_vulkan::FakeVulkan>(); std::mem::transmute( fake_vulkan.get_function(func_name.cast()) ) } } let vkinst = unsafe { vulkan_funcs::Instance::new( get_instance_proc, fake_vulkan.as_ref() as *const fake_vulkan::FakeVulkan as *const c_void, ) }; let mut 
device = std::ptr::null_mut(); unsafe { let mut count = 1; let res = vkinst.vkEnumeratePhysicalDevices.unwrap()( instance, std::ptr::addr_of_mut!(count), std::ptr::addr_of_mut!(device), ); assert_eq!(res, vk::VK_SUCCESS); } FakeVulkanData { fake_vulkan, _vklib: vklib, vkinst, instance, device, } } } impl Drop for FakeVulkanData { fn drop(&mut self) { if !std::thread::panicking() { unsafe { self.vkinst.vkDestroyInstance.unwrap()( self.instance, std::ptr::null(), // allocator ); } } } } fn check_base_features<'a>( reqs: &'a Requirements, features: &vk::VkPhysicalDeviceFeatures ) -> Result<(), Error> { let mut data = FakeVulkanData::new(); data.fake_vulkan.physical_devices[0].features = features.clone(); reqs.check(&data.vkinst, data.device) } #[test] fn test_check_base_features() { let mut features = Default::default(); assert!(matches!( check_base_features(&Requirements::new(), &features), Ok(()), )); features.geometryShader = vk::VK_TRUE; assert!(matches!( check_base_features(&Requirements::new(), &features), Ok(()), )); let mut reqs = Requirements::new(); reqs.add("geometryShader"); assert!(matches!( check_base_features(&reqs, &features), Ok(()), )); reqs.add("depthBounds"); match check_base_features(&reqs, &features) { Ok(()) => unreachable!("Requirements::check was supposed to fail"), Err(e) => { assert!(matches!(e, Error::MissingBaseFeature(_))); assert_eq!( e.to_string(), "Missing required feature: depthBounds" ); assert_eq!(e.result(), result::Result::Skip); }, } } #[test] fn test_check_extensions() { let mut data = FakeVulkanData::new(); let mut reqs = Requirements::new(); assert!(matches!( reqs.check(&data.vkinst, data.device), Ok(()), )); reqs.add("fake_extension"); match reqs.check(&data.vkinst, data.device) { Ok(()) => unreachable!("expected extensions check to fail"), Err(e) => { assert!(matches!( e, Error::MissingExtension(_), )); assert_eq!( e.to_string(), "Missing required extension: fake_extension" ); assert_eq!(e.result(), result::Result::Skip); }, 
}; data.fake_vulkan.physical_devices[0].add_extension("fake_extension"); assert!(matches!( reqs.check(&data.vkinst, data.device), Ok(()), )); // Add an extension via a feature reqs.add("multiviewGeometryShader"); match reqs.check(&data.vkinst, data.device) { Ok(()) => unreachable!("expected extensions check to fail"), Err(e) => { assert!(matches!( e, Error::MissingExtension(_) )); assert_eq!( e.to_string(), "Missing required extension: VK_KHR_multiview", ); assert_eq!(e.result(), result::Result::Skip); }, }; data.fake_vulkan.physical_devices[0].add_extension("VK_KHR_multiview"); match reqs.check(&data.vkinst, data.device) { Ok(()) => unreachable!("expected extensions check to fail"), Err(e) => { assert!(matches!( e, Error::MissingFeature { .. }, )); assert_eq!( e.to_string(), "Missing required feature “multiviewGeometryShader” from \ extension “VK_KHR_multiview”", ); assert_eq!(e.result(), result::Result::Skip); }, }; // Make an unterminated UTF-8 character let extension_name = &mut data .fake_vulkan .physical_devices[0] .extensions[0] .extensionName; extension_name[0] = -1i8 as c_char; extension_name[1] = 0; match reqs.check(&data.vkinst, data.device) { Ok(()) => unreachable!("expected extensions check to fail"), Err(e) => { assert_eq!( e.to_string(), "Invalid UTF-8 in string returned from \ vkEnumerateDeviceExtensionProperties" ); assert!(matches!(e, Error::ExtensionInvalidUtf8)); assert_eq!(e.result(), result::Result::Fail); }, }; // No null-terminator in the extension extension_name.fill(32); match reqs.check(&data.vkinst, data.device) { Ok(()) => unreachable!("expected extensions check to fail"), Err(e) => { assert_eq!( e.to_string(), "NULL terminator missing in string returned from \ vkEnumerateDeviceExtensionProperties" ); assert!(matches!(e, Error::ExtensionMissingNullTerminator)); assert_eq!(e.result(), result::Result::Fail); }, }; } #[test] fn test_check_structures() { let mut data = FakeVulkanData::new(); let mut reqs = Requirements::new(); 
reqs.add("multiview"); data.fake_vulkan.physical_devices[0].add_extension("VK_KHR_multiview"); match reqs.check(&data.vkinst, data.device) { Ok(()) => unreachable!("expected features check to fail"), Err(e) => { assert_eq!( e.to_string(), "Missing required feature “multiview” \ from extension “VK_KHR_multiview”", ); assert!(matches!( e, Error::MissingFeature { .. }, )); assert_eq!(e.result(), result::Result::Skip); }, }; data.fake_vulkan.physical_devices[0].multiview.multiview = vk::VK_TRUE; assert!(matches!( reqs.check(&data.vkinst, data.device), Ok(()), )); reqs.add("shaderBufferInt64Atomics"); data.fake_vulkan.physical_devices[0].add_extension( "VK_KHR_shader_atomic_int64" ); match reqs.check(&data.vkinst, data.device) { Ok(()) => unreachable!("expected features check to fail"), Err(e) => { assert_eq!( e.to_string(), "Missing required feature “shaderBufferInt64Atomics” \ from extension “VK_KHR_shader_atomic_int64”", ); assert!(matches!( e, Error::MissingFeature { .. }, )); assert_eq!(e.result(), result::Result::Skip); }, }; data.fake_vulkan .physical_devices[0] .shader_atomic .shaderBufferInt64Atomics = vk::VK_TRUE; assert!(matches!( reqs.check(&data.vkinst, data.device), Ok(()), )); } #[test] fn test_check_version() { let mut data = FakeVulkanData::new(); let mut reqs = Requirements::new(); reqs.add_version(1, 2, 0); data.fake_vulkan.physical_devices[0].properties.apiVersion = make_version(1, 1, 0); match reqs.check(&data.vkinst, data.device) { Ok(()) => unreachable!("expected version check to fail"), Err(e) => { // The check will report that the version is 1.0.0 // because vkEnumerateInstanceVersion is not available assert_eq!( e.to_string(), "Vulkan API version 1.2.0 required but the driver \ reported 1.1.0", ); assert!(matches!( e, Error::VersionTooLow { .. 
}, )); assert_eq!(e.result(), result::Result::Skip); }, }; // Set a valid version data.fake_vulkan.physical_devices[0].properties.apiVersion = make_version(1, 3, 0); assert!(matches!( reqs.check(&data.vkinst, data.device), Ok(()), )); } } ```