This is page 4 of 7. Use http://codebase.md/mehmetoguzderin/shaderc-vkrunner-mcp?page={x} to view the full context.
# Directory Structure
```
├── .devcontainer
│ ├── devcontainer.json
│ ├── docker-compose.yml
│ └── Dockerfile
├── .gitattributes
├── .github
│ └── workflows
│ └── build-push-image.yml
├── .gitignore
├── .vscode
│ └── mcp.json
├── Cargo.lock
├── Cargo.toml
├── Dockerfile
├── LICENSE
├── README.adoc
├── shaderc-vkrunner-mcp.jpg
├── src
│ └── main.rs
└── vkrunner
├── .editorconfig
├── .gitignore
├── .gitlab-ci.yml
├── build.rs
├── Cargo.toml
├── COPYING
├── examples
│ ├── compute-shader.shader_test
│ ├── cooperative-matrix.shader_test
│ ├── depth-buffer.shader_test
│ ├── desc_set_and_binding.shader_test
│ ├── entrypoint.shader_test
│ ├── float-framebuffer.shader_test
│ ├── frexp.shader_test
│ ├── geometry.shader_test
│ ├── indices.shader_test
│ ├── layouts.shader_test
│ ├── properties.shader_test
│ ├── push-constants.shader_test
│ ├── require-subgroup-size.shader_test
│ ├── row-major.shader_test
│ ├── spirv.shader_test
│ ├── ssbo.shader_test
│ ├── tolerance.shader_test
│ ├── tricolore.shader_test
│ ├── ubo.shader_test
│ ├── vertex-data-piglit.shader_test
│ └── vertex-data.shader_test
├── include
│ ├── vk_video
│ │ ├── vulkan_video_codec_av1std_decode.h
│ │ ├── vulkan_video_codec_av1std_encode.h
│ │ ├── vulkan_video_codec_av1std.h
│ │ ├── vulkan_video_codec_h264std_decode.h
│ │ ├── vulkan_video_codec_h264std_encode.h
│ │ ├── vulkan_video_codec_h264std.h
│ │ ├── vulkan_video_codec_h265std_decode.h
│ │ ├── vulkan_video_codec_h265std_encode.h
│ │ ├── vulkan_video_codec_h265std.h
│ │ └── vulkan_video_codecs_common.h
│ └── vulkan
│ ├── vk_platform.h
│ ├── vulkan_core.h
│ └── vulkan.h
├── precompile-script.py
├── README.md
├── scripts
│ └── update-vulkan.sh
├── src
│ └── main.rs
├── test-build.sh
└── vkrunner
├── allocate_store.rs
├── buffer.rs
├── compiler
│ └── fake_process.rs
├── compiler.rs
├── config.rs
├── context.rs
├── enum_table.rs
├── env_var_test.rs
├── executor.rs
├── fake_vulkan.rs
├── features.rs
├── flush_memory.rs
├── format_table.rs
├── format.rs
├── half_float.rs
├── hex.rs
├── inspect.rs
├── lib.rs
├── logger.rs
├── make-enums.py
├── make-features.py
├── make-formats.py
├── make-pipeline-key-data.py
├── make-vulkan-funcs-data.py
├── parse_num.rs
├── pipeline_key_data.rs
├── pipeline_key.rs
├── pipeline_set.rs
├── requirements.rs
├── result.rs
├── script.rs
├── shader_stage.rs
├── slot.rs
├── small_float.rs
├── source.rs
├── stream.rs
├── temp_file.rs
├── tester.rs
├── tolerance.rs
├── util.rs
├── vbo.rs
├── vk.rs
├── vulkan_funcs_data.rs
├── vulkan_funcs.rs
├── window_format.rs
└── window.rs
```
# Files
--------------------------------------------------------------------------------
/vkrunner/vkrunner/pipeline_set.rs:
--------------------------------------------------------------------------------
```rust
// vkrunner
//
// Copyright (C) 2018 Intel Corporation
// Copyright 2023 Neil Roberts
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice (including the next
// paragraph) shall be included in all copies or substantial portions of the
// Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
// THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
use crate::window::Window;
use crate::compiler;
use crate::shader_stage;
use crate::vk;
use crate::script::{Script, Buffer, BufferType, Operation};
use crate::pipeline_key;
use crate::logger::Logger;
use crate::vbo::Vbo;
use std::rc::Rc;
use std::ptr;
use std::mem;
use std::fmt;
#[derive(Debug)]
pub struct PipelineSet {
pipelines: PipelineVec,
layout: PipelineLayout,
// The descriptor data is only created if there are buffers in the
// script
descriptor_data: Option<DescriptorData>,
stages: vk::VkShaderStageFlagBits,
// These are never read but they should probably be kept alive
// until the pipelines are destroyed.
_pipeline_cache: PipelineCache,
_modules: [Option<ShaderModule>; shader_stage::N_STAGES],
}
#[repr(C)]
#[derive(Debug)]
pub struct RectangleVertex {
pub x: f32,
pub y: f32,
pub z: f32,
}
/// An error that can be returned by [PipelineSet::new].
#[derive(Debug)]
pub enum Error {
/// Compiling one of the shaders in the script failed
CompileError(compiler::Error),
/// vkCreatePipelineCache failed
CreatePipelineCacheFailed,
/// vkCreateDescriptorPool failed
CreateDescriptorPoolFailed,
/// vkCreateDescriptorSetLayout failed
CreateDescriptorSetLayoutFailed,
/// vkCreatePipelineLayout failed
CreatePipelineLayoutFailed,
/// vkCreatePipeline failed
CreatePipelineFailed,
}
impl From<compiler::Error> for Error {
fn from(e: compiler::Error) -> Error {
Error::CompileError(e)
}
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
match self {
Error::CompileError(e) => e.fmt(f),
Error::CreatePipelineCacheFailed => {
write!(f, "vkCreatePipelineCache failed")
},
Error::CreateDescriptorPoolFailed => {
write!(f, "vkCreateDescriptorPool failed")
},
Error::CreateDescriptorSetLayoutFailed => {
write!(f, "vkCreateDescriptorSetLayout failed")
},
Error::CreatePipelineLayoutFailed => {
write!(f, "vkCreatePipelineLayout failed")
},
Error::CreatePipelineFailed => {
write!(f, "Pipeline creation function failed")
},
}
}
}
#[derive(Debug)]
struct ShaderModule {
handle: vk::VkShaderModule,
// needed for the destructor
window: Rc<Window>,
}
impl Drop for ShaderModule {
fn drop(&mut self) {
unsafe {
self.window.device().vkDestroyShaderModule.unwrap()(
self.window.vk_device(),
self.handle,
ptr::null(), // allocator
);
}
}
}
#[derive(Debug)]
struct PipelineCache {
handle: vk::VkPipelineCache,
// needed for the destructor
window: Rc<Window>,
}
impl Drop for PipelineCache {
fn drop(&mut self) {
unsafe {
self.window.device().vkDestroyPipelineCache.unwrap()(
self.window.vk_device(),
self.handle,
ptr::null(), // allocator
);
}
}
}
impl PipelineCache {
fn new(window: Rc<Window>) -> Result<PipelineCache, Error> {
let create_info = vk::VkPipelineCacheCreateInfo {
sType: vk::VK_STRUCTURE_TYPE_PIPELINE_CACHE_CREATE_INFO,
flags: 0,
pNext: ptr::null(),
initialDataSize: 0,
pInitialData: ptr::null(),
};
let mut handle = vk::null_handle();
let res = unsafe {
window.device().vkCreatePipelineCache.unwrap()(
window.vk_device(),
ptr::addr_of!(create_info),
ptr::null(), // allocator
ptr::addr_of_mut!(handle),
)
};
if res == vk::VK_SUCCESS {
Ok(PipelineCache { handle, window })
} else {
Err(Error::CreatePipelineCacheFailed)
}
}
}
#[derive(Debug)]
struct DescriptorPool {
handle: vk::VkDescriptorPool,
// needed for the destructor
window: Rc<Window>,
}
impl Drop for DescriptorPool {
fn drop(&mut self) {
unsafe {
self.window.device().vkDestroyDescriptorPool.unwrap()(
self.window.vk_device(),
self.handle,
ptr::null(), // allocator
);
}
}
}
impl DescriptorPool {
fn new(
window: Rc<Window>,
buffers: &[Buffer],
) -> Result<DescriptorPool, Error> {
let mut n_ubos = 0;
let mut n_ssbos = 0;
for buffer in buffers {
match buffer.buffer_type {
BufferType::Ubo => n_ubos += 1,
BufferType::Ssbo => n_ssbos += 1,
}
}
let mut pool_sizes = Vec::<vk::VkDescriptorPoolSize>::new();
if n_ubos > 0 {
pool_sizes.push(vk::VkDescriptorPoolSize {
type_: vk::VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER,
descriptorCount: n_ubos,
});
}
if n_ssbos > 0 {
pool_sizes.push(vk::VkDescriptorPoolSize {
type_: vk::VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
descriptorCount: n_ssbos,
});
}
// The descriptor pool shouldn’t have been created if there
// were no buffers
assert!(!pool_sizes.is_empty());
let create_info = vk::VkDescriptorPoolCreateInfo {
sType: vk::VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO,
flags: vk::VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT,
pNext: ptr::null(),
maxSets: n_desc_sets(buffers) as u32,
poolSizeCount: pool_sizes.len() as u32,
pPoolSizes: pool_sizes.as_ptr(),
};
let mut handle = vk::null_handle();
let res = unsafe {
window.device().vkCreateDescriptorPool.unwrap()(
window.vk_device(),
ptr::addr_of!(create_info),
ptr::null(), // allocator
ptr::addr_of_mut!(handle),
)
};
if res == vk::VK_SUCCESS {
Ok(DescriptorPool { handle, window })
} else {
Err(Error::CreateDescriptorPoolFailed)
}
}
}
#[derive(Debug)]
struct DescriptorSetLayoutVec {
handles: Vec<vk::VkDescriptorSetLayout>,
// needed for the destructor
window: Rc<Window>,
}
impl Drop for DescriptorSetLayoutVec {
fn drop(&mut self) {
for &handle in self.handles.iter() {
unsafe {
self.window.device().vkDestroyDescriptorSetLayout.unwrap()(
self.window.vk_device(),
handle,
ptr::null(), // allocator
);
}
}
}
}
impl DescriptorSetLayoutVec {
fn new(window: Rc<Window>) -> DescriptorSetLayoutVec {
DescriptorSetLayoutVec { window, handles: Vec::new() }
}
fn len(&self) -> usize {
self.handles.len()
}
fn as_ptr(&self) -> *const vk::VkDescriptorSetLayout {
self.handles.as_ptr()
}
fn add(
&mut self,
bindings: &[vk::VkDescriptorSetLayoutBinding],
) -> Result<(), Error> {
let create_info = vk::VkDescriptorSetLayoutCreateInfo {
sType: vk::VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO,
flags: 0,
pNext: ptr::null(),
bindingCount: bindings.len() as u32,
pBindings: bindings.as_ptr(),
};
let mut handle = vk::null_handle();
let res = unsafe {
self.window.device().vkCreateDescriptorSetLayout.unwrap()(
self.window.vk_device(),
ptr::addr_of!(create_info),
ptr::null(), // allocator
ptr::addr_of_mut!(handle),
)
};
if res == vk::VK_SUCCESS {
self.handles.push(handle);
Ok(())
} else {
Err(Error::CreateDescriptorSetLayoutFailed)
}
}
}
fn create_descriptor_set_layouts(
window: Rc<Window>,
buffers: &[Buffer],
stages: vk::VkShaderStageFlags,
) -> Result<DescriptorSetLayoutVec, Error> {
let n_desc_sets = n_desc_sets(buffers);
let mut layouts = DescriptorSetLayoutVec::new(window);
let mut bindings = Vec::new();
let mut buffer_num = 0;
for desc_set in 0..n_desc_sets {
bindings.clear();
while buffer_num < buffers.len()
&& buffers[buffer_num].desc_set as usize == desc_set
{
let buffer = &buffers[buffer_num];
bindings.push(vk::VkDescriptorSetLayoutBinding {
binding: buffer.binding,
descriptorType: match buffer.buffer_type {
BufferType::Ubo => vk::VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER,
BufferType::Ssbo => vk::VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
},
descriptorCount: 1,
stageFlags: stages,
pImmutableSamplers: ptr::null(),
});
buffer_num += 1;
}
layouts.add(&bindings)?;
}
assert_eq!(layouts.len(), n_desc_sets);
Ok(layouts)
}
fn n_desc_sets(buffers: &[Buffer]) -> usize {
match buffers.last() {
// The number of descriptor sets is the highest used
// descriptor set index + 1. The buffers are in order so the
// highest one should be the last one.
Some(last) => last.desc_set as usize + 1,
None => 0,
}
}
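// Worked example (illustrative): the buffers are sorted by descriptor
// set index, so a script whose buffers use sets [0, 2, 3, 5] makes
// n_desc_sets() return 6, and create_descriptor_set_layouts() emits
// empty layouts for the unused sets 1 and 4; this is exactly what the
// descriptor_set_gaps test below checks.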
#[derive(Debug)]
struct DescriptorData {
pool: DescriptorPool,
layouts: DescriptorSetLayoutVec,
}
impl DescriptorData {
fn new(
window: &Rc<Window>,
buffers: &[Buffer],
stages: vk::VkShaderStageFlagBits,
) -> Result<DescriptorData, Error> {
let pool = DescriptorPool::new(
Rc::clone(&window),
buffers,
)?;
let layouts = create_descriptor_set_layouts(
Rc::clone(&window),
buffers,
stages,
)?;
Ok(DescriptorData { pool, layouts })
}
}
fn compile_shaders(
logger: &mut Logger,
window: &Rc<Window>,
script: &Script,
show_disassembly: bool,
) -> Result<[Option<ShaderModule>; shader_stage::N_STAGES], Error> {
let mut modules: [Option<ShaderModule>; shader_stage::N_STAGES] =
Default::default();
for &stage in shader_stage::ALL_STAGES.iter() {
if script.shaders(stage).is_empty() {
continue;
}
modules[stage as usize] = Some(ShaderModule {
handle: compiler::build_stage(
logger,
window.context(),
script,
stage,
show_disassembly,
)?,
window: Rc::clone(window),
});
}
Ok(modules)
}
fn stage_flags(script: &Script) -> vk::VkShaderStageFlagBits {
// Set a flag for each stage that has a shader in the script
shader_stage::ALL_STAGES
.iter()
.filter_map(|&stage| if script.shaders(stage).is_empty() {
None
} else {
Some(stage.flag())
})
.fold(0, |a, b| a | b)
}
fn push_constant_size(script: &Script) -> usize {
script.commands()
.iter()
.map(|command| match &command.op {
Operation::SetPushCommand { offset, data } => offset + data.len(),
_ => 0,
})
.max()
.unwrap_or(0)
}
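// Worked example (illustrative): the command `push float 6 42.0`
// becomes a SetPushCommand with offset 6 and 4 bytes of data, so
// push_constant_size() reports 6 + 4 = 10 bytes; the `base` test below
// checks this as `6 + mem::size_of::<f32>()`.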
#[derive(Debug)]
struct PipelineLayout {
handle: vk::VkPipelineLayout,
// needed for the destructor
window: Rc<Window>,
}
impl Drop for PipelineLayout {
fn drop(&mut self) {
unsafe {
self.window.device().vkDestroyPipelineLayout.unwrap()(
self.window.vk_device(),
self.handle,
ptr::null(), // allocator
);
}
}
}
impl PipelineLayout {
fn new(
window: Rc<Window>,
script: &Script,
stages: vk::VkShaderStageFlagBits,
descriptor_data: Option<&DescriptorData>
) -> Result<PipelineLayout, Error> {
let mut handle = vk::null_handle();
let push_constant_range = vk::VkPushConstantRange {
stageFlags: stages,
offset: 0,
size: push_constant_size(script) as u32,
};
let mut create_info = vk::VkPipelineLayoutCreateInfo::default();
create_info.sType = vk::VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO;
if push_constant_range.size > 0 {
create_info.pushConstantRangeCount = 1;
create_info.pPushConstantRanges =
ptr::addr_of!(push_constant_range);
}
if let Some(descriptor_data) = descriptor_data {
create_info.setLayoutCount = descriptor_data.layouts.len() as u32;
create_info.pSetLayouts = descriptor_data.layouts.as_ptr();
}
let res = unsafe {
window.device().vkCreatePipelineLayout.unwrap()(
window.vk_device(),
ptr::addr_of!(create_info),
ptr::null(), // allocator
ptr::addr_of_mut!(handle),
)
};
if res == vk::VK_SUCCESS {
Ok(PipelineLayout { handle, window })
} else {
Err(Error::CreatePipelineLayoutFailed)
}
}
}
#[derive(Debug)]
struct VertexInputState {
create_info: vk::VkPipelineVertexInputStateCreateInfo,
// These are not read but the create_info struct keeps pointers to
// them so they need to be kept alive
_input_bindings: Vec::<vk::VkVertexInputBindingDescription>,
_attribs: Vec::<vk::VkVertexInputAttributeDescription>,
}
impl VertexInputState {
fn new(script: &Script, key: &pipeline_key::Key) -> VertexInputState {
let mut input_bindings = Vec::new();
let mut attribs = Vec::new();
match key.source() {
pipeline_key::Source::Rectangle => {
input_bindings.push(vk::VkVertexInputBindingDescription {
binding: 0,
stride: mem::size_of::<RectangleVertex>() as u32,
inputRate: vk::VK_VERTEX_INPUT_RATE_VERTEX,
});
attribs.push(vk::VkVertexInputAttributeDescription {
location: 0,
binding: 0,
format: vk::VK_FORMAT_R32G32B32_SFLOAT,
offset: 0,
});
},
pipeline_key::Source::VertexData => {
if let Some(vbo) = script.vertex_data() {
VertexInputState::set_up_vertex_data_attribs(
&mut input_bindings,
&mut attribs,
vbo
);
}
}
}
VertexInputState {
create_info: vk::VkPipelineVertexInputStateCreateInfo {
sType:
vk::VK_STRUCTURE_TYPE_PIPELINE_VERTEX_INPUT_STATE_CREATE_INFO,
flags: 0,
pNext: ptr::null(),
vertexBindingDescriptionCount: input_bindings.len() as u32,
pVertexBindingDescriptions: input_bindings.as_ptr(),
vertexAttributeDescriptionCount: attribs.len() as u32,
pVertexAttributeDescriptions: attribs.as_ptr(),
},
_input_bindings: input_bindings,
_attribs: attribs,
}
}
fn set_up_vertex_data_attribs(
input_bindings: &mut Vec::<vk::VkVertexInputBindingDescription>,
attribs: &mut Vec::<vk::VkVertexInputAttributeDescription>,
vbo: &Vbo,
) {
input_bindings.push(vk::VkVertexInputBindingDescription {
binding: 0,
stride: vbo.stride() as u32,
inputRate: vk::VK_VERTEX_INPUT_RATE_VERTEX,
});
for attrib in vbo.attribs().iter() {
attribs.push(vk::VkVertexInputAttributeDescription {
location: attrib.location(),
binding: 0,
format: attrib.format().vk_format,
offset: attrib.offset() as u32,
});
}
}
}
#[derive(Debug)]
struct PipelineVec {
handles: Vec<vk::VkPipeline>,
// needed for the destructor
window: Rc<Window>,
}
impl Drop for PipelineVec {
fn drop(&mut self) {
for &handle in self.handles.iter() {
unsafe {
self.window.device().vkDestroyPipeline.unwrap()(
self.window.vk_device(),
handle,
ptr::null(), // allocator
);
}
}
}
}
impl PipelineVec {
fn new(
window: Rc<Window>,
script: &Script,
pipeline_cache: vk::VkPipelineCache,
layout: vk::VkPipelineLayout,
modules: &[Option<ShaderModule>],
) -> Result<PipelineVec, Error> {
let mut vec = PipelineVec { window, handles: Vec::new() };
let mut first_graphics_pipeline: Option<vk::VkPipeline> = None;
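// Illustrative note: when the script produces more than one graphics
// pipeline, the first one is created with
// VK_PIPELINE_CREATE_ALLOW_DERIVATIVES_BIT and each later graphics
// pipeline is created with VK_PIPELINE_CREATE_DERIVATIVE_BIT and its
// basePipelineHandle set to that first pipeline (see the flag
// assertions in the `base` test).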
for key in script.pipeline_keys().iter() {
let pipeline = match key.pipeline_type() {
pipeline_key::Type::Graphics => {
let allow_derivatives =
first_graphics_pipeline.is_none()
&& script.pipeline_keys().len() > 1;
let pipeline = PipelineVec::create_graphics_pipeline(
&vec.window,
script,
key,
pipeline_cache,
layout,
modules,
allow_derivatives,
first_graphics_pipeline,
)?;
first_graphics_pipeline.get_or_insert(pipeline);
pipeline
},
pipeline_key::Type::Compute => {
PipelineVec::create_compute_pipeline(
&vec.window,
key,
pipeline_cache,
layout,
modules[shader_stage::Stage::Compute as usize]
.as_ref()
.map(|m| m.handle)
.unwrap_or(vk::null_handle()),
script.requirements().required_subgroup_size,
)?
},
};
vec.handles.push(pipeline);
}
Ok(vec)
}
fn null_terminated_entrypoint(
key: &pipeline_key::Key,
stage: shader_stage::Stage,
) -> String {
let mut entrypoint = key.entrypoint(stage).to_string();
entrypoint.push('\0');
entrypoint
}
fn create_stages(
modules: &[Option<ShaderModule>],
entrypoints: &[String],
) -> Vec<vk::VkPipelineShaderStageCreateInfo> {
let mut stages = Vec::new();
for &stage in shader_stage::ALL_STAGES.iter() {
if stage == shader_stage::Stage::Compute {
continue;
}
let module = match &modules[stage as usize] {
Some(module) => module.handle,
None => continue,
};
stages.push(vk::VkPipelineShaderStageCreateInfo {
sType: vk::VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO,
flags: 0,
pNext: ptr::null(),
stage: stage.flag(),
module,
pName: entrypoints[stage as usize].as_ptr().cast(),
pSpecializationInfo: ptr::null(),
});
}
stages
}
fn create_graphics_pipeline(
window: &Window,
script: &Script,
key: &pipeline_key::Key,
pipeline_cache: vk::VkPipelineCache,
layout: vk::VkPipelineLayout,
modules: &[Option<ShaderModule>],
allow_derivatives: bool,
parent_pipeline: Option<vk::VkPipeline>,
) -> Result<vk::VkPipeline, Error> {
let create_info_buf = key.to_create_info();
let mut create_info = vk::VkGraphicsPipelineCreateInfo::default();
unsafe {
ptr::copy_nonoverlapping(
create_info_buf.as_ptr().cast(),
ptr::addr_of_mut!(create_info),
1, // count
);
}
let entrypoints = shader_stage::ALL_STAGES
.iter()
.map(|&s| PipelineVec::null_terminated_entrypoint(key, s))
.collect::<Vec<String>>();
let stages = PipelineVec::create_stages(modules, &entrypoints);
let window_format = window.format();
let viewport = vk::VkViewport {
x: 0.0,
y: 0.0,
width: window_format.width as f32,
height: window_format.height as f32,
minDepth: 0.0,
maxDepth: 1.0,
};
let scissor = vk::VkRect2D {
offset: vk::VkOffset2D {
x: 0,
y: 0,
},
extent: vk::VkExtent2D {
width: window_format.width as u32,
height: window_format.height as u32,
},
};
let viewport_state = vk::VkPipelineViewportStateCreateInfo {
sType: vk::VK_STRUCTURE_TYPE_PIPELINE_VIEWPORT_STATE_CREATE_INFO,
flags: 0,
pNext: ptr::null(),
viewportCount: 1,
pViewports: ptr::addr_of!(viewport),
scissorCount: 1,
pScissors: ptr::addr_of!(scissor),
};
let multisample_state = vk::VkPipelineMultisampleStateCreateInfo {
sType: vk::VK_STRUCTURE_TYPE_PIPELINE_MULTISAMPLE_STATE_CREATE_INFO,
pNext: ptr::null(),
flags: 0,
rasterizationSamples: vk::VK_SAMPLE_COUNT_1_BIT,
sampleShadingEnable: vk::VK_FALSE,
minSampleShading: 0.0,
pSampleMask: ptr::null(),
alphaToCoverageEnable: vk::VK_FALSE,
alphaToOneEnable: vk::VK_FALSE,
};
create_info.pViewportState = ptr::addr_of!(viewport_state);
create_info.pMultisampleState = ptr::addr_of!(multisample_state);
create_info.subpass = 0;
create_info.basePipelineHandle =
parent_pipeline.unwrap_or(vk::null_handle());
create_info.basePipelineIndex = -1;
create_info.stageCount = stages.len() as u32;
create_info.pStages = stages.as_ptr();
create_info.layout = layout;
create_info.renderPass = window.render_passes()[0];
if allow_derivatives {
create_info.flags |= vk::VK_PIPELINE_CREATE_ALLOW_DERIVATIVES_BIT;
}
if parent_pipeline.is_some() {
create_info.flags |= vk::VK_PIPELINE_CREATE_DERIVATIVE_BIT;
}
if modules[shader_stage::Stage::TessCtrl as usize].is_none()
&& modules[shader_stage::Stage::TessEval as usize].is_none()
{
create_info.pTessellationState = ptr::null();
}
let vertex_input_state = VertexInputState::new(script, key);
create_info.pVertexInputState =
ptr::addr_of!(vertex_input_state.create_info);
let mut handle = vk::null_handle();
let res = unsafe {
window.device().vkCreateGraphicsPipelines.unwrap()(
window.vk_device(),
pipeline_cache,
1, // nCreateInfos
ptr::addr_of!(create_info),
ptr::null(), // allocator
ptr::addr_of_mut!(handle),
)
};
if res == vk::VK_SUCCESS {
Ok(handle)
} else {
Err(Error::CreatePipelineFailed)
}
}
fn create_compute_pipeline(
window: &Window,
key: &pipeline_key::Key,
pipeline_cache: vk::VkPipelineCache,
layout: vk::VkPipelineLayout,
module: vk::VkShaderModule,
required_subgroup_size: Option<u32>,
) -> Result<vk::VkPipeline, Error> {
let entrypoint = PipelineVec::null_terminated_entrypoint(
key,
shader_stage::Stage::Compute,
);
let rss_info = required_subgroup_size.map(|size|
vk::VkPipelineShaderStageRequiredSubgroupSizeCreateInfo {
sType: vk::VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_REQUIRED_SUBGROUP_SIZE_CREATE_INFO,
pNext: ptr::null_mut(),
requiredSubgroupSize: size,
});
let create_info = vk::VkComputePipelineCreateInfo {
sType: vk::VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO,
pNext: ptr::null(),
flags: 0,
stage: vk::VkPipelineShaderStageCreateInfo {
sType: vk::VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO,
pNext: if let Some(ref rss_info) = rss_info {
ptr::addr_of!(*rss_info).cast()
} else {
ptr::null_mut()
},
flags: 0,
stage: vk::VK_SHADER_STAGE_COMPUTE_BIT,
module,
pName: entrypoint.as_ptr().cast(),
pSpecializationInfo: ptr::null(),
},
layout,
basePipelineHandle: vk::null_handle(),
basePipelineIndex: -1,
};
let mut handle = vk::null_handle();
let res = unsafe {
window.device().vkCreateComputePipelines.unwrap()(
window.vk_device(),
pipeline_cache,
1, // nCreateInfos
ptr::addr_of!(create_info),
ptr::null(), // allocator
ptr::addr_of_mut!(handle),
)
};
if res == vk::VK_SUCCESS {
Ok(handle)
} else {
Err(Error::CreatePipelineFailed)
}
}
}
impl PipelineSet {
pub fn new(
logger: &mut Logger,
window: Rc<Window>,
script: &Script,
show_disassembly: bool,
) -> Result<PipelineSet, Error> {
let modules = compile_shaders(
logger,
&window,
script,
show_disassembly
)?;
let pipeline_cache = PipelineCache::new(Rc::clone(&window))?;
let stages = stage_flags(script);
let descriptor_data = if script.buffers().is_empty() {
None
} else {
Some(DescriptorData::new(&window, script.buffers(), stages)?)
};
let layout = PipelineLayout::new(
Rc::clone(&window),
script,
stages,
descriptor_data.as_ref(),
)?;
let pipelines = PipelineVec::new(
Rc::clone(&window),
script,
pipeline_cache.handle,
layout.handle,
&modules,
)?;
Ok(PipelineSet {
_modules: modules,
_pipeline_cache: pipeline_cache,
stages,
descriptor_data,
layout,
pipelines,
})
}
pub fn descriptor_set_layouts(&self) -> &[vk::VkDescriptorSetLayout] {
match self.descriptor_data.as_ref() {
Some(data) => &data.layouts.handles,
None => &[],
}
}
pub fn stages(&self) -> vk::VkShaderStageFlagBits {
self.stages
}
pub fn layout(&self) -> vk::VkPipelineLayout {
self.layout.handle
}
pub fn pipelines(&self) -> &[vk::VkPipeline] {
&self.pipelines.handles
}
pub fn descriptor_pool(&self) -> Option<vk::VkDescriptorPool> {
self.descriptor_data.as_ref().map(|data| data.pool.handle)
}
}
#[cfg(test)]
mod test {
use super::*;
use crate::fake_vulkan::{FakeVulkan, HandleType, PipelineCreateInfo};
use crate::fake_vulkan::GraphicsPipelineCreateInfo;
use crate::fake_vulkan::PipelineLayoutCreateInfo;
use crate::context::Context;
use crate::requirements::Requirements;
use crate::source::Source;
use crate::config::Config;
#[derive(Debug)]
struct TestData {
pipeline_set: PipelineSet,
window: Rc<Window>,
fake_vulkan: Box<FakeVulkan>,
}
impl TestData {
fn new_with_errors<F: FnOnce(&mut FakeVulkan)>(
source: &str,
queue_errors: F
) -> Result<TestData, Error> {
let mut fake_vulkan = FakeVulkan::new();
fake_vulkan.physical_devices.push(Default::default());
fake_vulkan.physical_devices[0].format_properties.insert(
vk::VK_FORMAT_B8G8R8A8_UNORM,
vk::VkFormatProperties {
linearTilingFeatures: 0,
optimalTilingFeatures:
vk::VK_FORMAT_FEATURE_COLOR_ATTACHMENT_BIT
| vk::VK_FORMAT_FEATURE_BLIT_SRC_BIT,
bufferFeatures: 0,
},
);
let memory_properties =
&mut fake_vulkan.physical_devices[0].memory_properties;
memory_properties.memoryTypes[0].propertyFlags =
vk::VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT;
memory_properties.memoryTypeCount = 1;
fake_vulkan.memory_requirements.memoryTypeBits = 1;
fake_vulkan.set_override();
let context = Rc::new(Context::new(
&Requirements::new(),
None, // device_id
).unwrap());
let window = Rc::new(Window::new(
Rc::clone(&context),
&Default::default(), // format
).unwrap());
queue_errors(&mut fake_vulkan);
let mut logger = Logger::new(None, ptr::null_mut());
let source = Source::from_string(source.to_string());
let script = Script::load(&Config::new(), &source).unwrap();
let pipeline_set = PipelineSet::new(
&mut logger,
Rc::clone(&window),
&script,
false, // show_disassembly
)?;
Ok(TestData { fake_vulkan, window, pipeline_set })
}
fn new(source: &str) -> Result<TestData, Error> {
TestData::new_with_errors(source, |_| ())
}
fn graphics_create_info(
&mut self,
pipeline_num: usize,
) -> GraphicsPipelineCreateInfo {
let pipeline = self.pipeline_set.pipelines.handles[pipeline_num];
match &self.fake_vulkan.get_handle(pipeline).data {
HandleType::Pipeline(PipelineCreateInfo::Graphics(info)) => {
info.clone()
},
handle @ _ => {
unreachable!("unexpected handle type: {:?}", handle)
},
}
}
fn compute_create_info(
&mut self,
pipeline_num: usize,
) -> vk::VkComputePipelineCreateInfo {
let pipeline = self.pipeline_set.pipelines.handles[pipeline_num];
match &self.fake_vulkan.get_handle(pipeline).data {
HandleType::Pipeline(PipelineCreateInfo::Compute(info)) => {
info.clone()
},
handle @ _ => {
unreachable!("unexpected handle type: {:?}", handle)
},
}
}
fn shader_module_code(
&mut self,
stage: shader_stage::Stage,
) -> Vec<u32> {
let module = self.pipeline_set._modules[stage as usize]
.as_ref()
.unwrap()
.handle;
match &self.fake_vulkan.get_handle(module).data {
HandleType::ShaderModule { code } => code.clone(),
_ => unreachable!("Unexpected Vulkan handle type"),
}
}
fn descriptor_set_layout_bindings(
&mut self,
desc_set: usize,
) -> Vec<vk::VkDescriptorSetLayoutBinding> {
let desc_set_layout = self.pipeline_set
.descriptor_set_layouts()
[desc_set];
match &self.fake_vulkan.get_handle(desc_set_layout).data {
HandleType::DescriptorSetLayout { bindings } => {
bindings.clone()
},
_ => unreachable!("Unexpected Vulkan handle type"),
}
}
fn pipeline_layout_create_info(&mut self) -> PipelineLayoutCreateInfo {
let layout = self.pipeline_set.layout();
match &self.fake_vulkan.get_handle(layout).data {
HandleType::PipelineLayout(create_info) => create_info.clone(),
_ => unreachable!("Unexpected Vulkan handle type"),
}
}
}
#[test]
fn base() {
let mut test_data = TestData::new(
"[vertex shader passthrough]\n\
[fragment shader]\n\
03 02 23 07\n\
fe ca fe ca\n\
[vertex data]\n\
0/R32G32_SFLOAT 1/R32_SFLOAT\n\
-0.5 -0.5 1.52\n\
0.5 -0.5 1.55\n\
[compute shader]\n\
03 02 23 07\n\
ca fe ca fe\n\
[test]\n\
ubo 0 1024\n\
ssbo 1 1024\n\
compute 1 1 1\n\
push float 6 42.0\n\
draw rect 0 0 1 1\n\
draw arrays TRIANGLE_LIST 0 2\n"
).unwrap();
assert_eq!(
test_data.pipeline_set.stages(),
vk::VK_SHADER_STAGE_VERTEX_BIT
| vk::VK_SHADER_STAGE_FRAGMENT_BIT
| vk::VK_SHADER_STAGE_COMPUTE_BIT,
);
let descriptor_pool = test_data.pipeline_set.descriptor_pool();
assert!(matches!(
test_data.fake_vulkan.get_handle(descriptor_pool.unwrap()).data,
HandleType::DescriptorPool,
));
assert_eq!(test_data.pipeline_set.pipelines().len(), 3);
let compute_create_info = test_data.compute_create_info(0);
assert_eq!(compute_create_info.sType, vk::VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO);
let create_data = test_data.graphics_create_info(1);
let create_info = &create_data.create_info;
assert_eq!(
create_info.flags,
vk::VK_PIPELINE_CREATE_ALLOW_DERIVATIVES_BIT,
);
assert_eq!(create_info.stageCount, 2);
assert_eq!(create_info.layout, test_data.pipeline_set.layout());
assert_eq!(
create_info.renderPass,
test_data.window.render_passes()[0]
);
assert_eq!(create_info.basePipelineHandle, vk::null_handle());
assert_eq!(create_data.bindings.len(), 1);
assert_eq!(create_data.bindings[0].binding, 0);
assert_eq!(
create_data.bindings[0].stride as usize,
mem::size_of::<RectangleVertex>(),
);
assert_eq!(
create_data.bindings[0].inputRate,
vk::VK_VERTEX_INPUT_RATE_VERTEX,
);
assert_eq!(create_data.attribs.len(), 1);
assert_eq!(create_data.attribs[0].location, 0);
assert_eq!(create_data.attribs[0].binding, 0);
assert_eq!(
create_data.attribs[0].format,
vk::VK_FORMAT_R32G32B32_SFLOAT
);
assert_eq!(create_data.attribs[0].offset, 0);
let create_data = test_data.graphics_create_info(2);
let create_info = &create_data.create_info;
assert_eq!(
create_info.flags,
vk::VK_PIPELINE_CREATE_DERIVATIVE_BIT,
);
assert_eq!(create_info.stageCount, 2);
assert_eq!(create_info.layout, test_data.pipeline_set.layout());
assert_eq!(
create_info.renderPass,
test_data.window.render_passes()[0]
);
assert_eq!(
create_info.basePipelineHandle,
test_data.pipeline_set.pipelines()[1]
);
assert_eq!(create_data.bindings.len(), 1);
assert_eq!(create_data.bindings[0].binding, 0);
assert_eq!(
create_data.bindings[0].stride as usize,
mem::size_of::<f32>() * 3,
);
assert_eq!(
create_data.bindings[0].inputRate,
vk::VK_VERTEX_INPUT_RATE_VERTEX,
);
assert_eq!(create_data.attribs.len(), 2);
assert_eq!(create_data.attribs[0].location, 0);
assert_eq!(create_data.attribs[0].binding, 0);
assert_eq!(
create_data.attribs[0].format,
vk::VK_FORMAT_R32G32_SFLOAT
);
assert_eq!(create_data.attribs[0].offset, 0);
assert_eq!(create_data.attribs[1].location, 1);
assert_eq!(create_data.attribs[1].binding, 0);
assert_eq!(
create_data.attribs[1].format,
vk::VK_FORMAT_R32_SFLOAT
);
assert_eq!(create_data.attribs[1].offset, 8);
let code = test_data.shader_module_code(shader_stage::Stage::Vertex);
assert!(code.len() > 2);
let code = test_data.shader_module_code(shader_stage::Stage::Fragment);
assert_eq!(&code, &[0x07230203, 0xcafecafe]);
let code = test_data.shader_module_code(shader_stage::Stage::Compute);
assert_eq!(&code, &[0x07230203, 0xfecafeca]);
assert_eq!(test_data.pipeline_set.descriptor_set_layouts().len(), 1);
let bindings = test_data.descriptor_set_layout_bindings(0);
assert_eq!(bindings.len(), 2);
assert_eq!(bindings[0].binding, 0);
assert_eq!(
bindings[0].descriptorType,
vk::VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER
);
assert_eq!(bindings[0].descriptorCount, 1);
assert_eq!(
bindings[0].stageFlags,
vk::VK_SHADER_STAGE_VERTEX_BIT
| vk::VK_SHADER_STAGE_FRAGMENT_BIT
| vk::VK_SHADER_STAGE_COMPUTE_BIT,
);
assert_eq!(bindings[1].binding, 1);
assert_eq!(
bindings[1].descriptorType,
vk::VK_DESCRIPTOR_TYPE_STORAGE_BUFFER
);
assert_eq!(bindings[1].descriptorCount, 1);
assert_eq!(
bindings[1].stageFlags,
vk::VK_SHADER_STAGE_VERTEX_BIT
| vk::VK_SHADER_STAGE_FRAGMENT_BIT
| vk::VK_SHADER_STAGE_COMPUTE_BIT,
);
let layout = test_data.pipeline_layout_create_info();
assert_eq!(layout.create_info.sType, vk::VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO);
assert_eq!(layout.push_constant_ranges.len(), 1);
assert_eq!(
layout.push_constant_ranges[0].stageFlags,
vk::VK_SHADER_STAGE_VERTEX_BIT
| vk::VK_SHADER_STAGE_FRAGMENT_BIT
| vk::VK_SHADER_STAGE_COMPUTE_BIT,
);
assert_eq!(layout.push_constant_ranges[0].offset, 0);
assert_eq!(
layout.push_constant_ranges[0].size as usize,
6 + mem::size_of::<f32>()
);
assert_eq!(
&layout.layouts,
test_data.pipeline_set.descriptor_set_layouts()
);
}
#[test]
fn descriptor_set_gaps() {
let mut test_data = TestData::new(
"[vertex shader passthrough]\n\
[fragment shader]\n\
03 02 23 07\n\
fe ca fe ca\n\
[test]\n\
ubo 0 1024\n\
ssbo 2:1 1024\n\
ssbo 3:1 1024\n\
ubo 5:5 1024\n\
draw rect 0 0 1 1\n"
).unwrap();
assert_eq!(test_data.pipeline_set.descriptor_set_layouts().len(), 6);
let bindings = test_data.descriptor_set_layout_bindings(0);
assert_eq!(bindings.len(), 1);
assert_eq!(bindings[0].binding, 0);
assert_eq!(
bindings[0].descriptorType,
vk::VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER
);
assert_eq!(bindings[0].descriptorCount, 1);
assert_eq!(
bindings[0].stageFlags,
vk::VK_SHADER_STAGE_VERTEX_BIT
| vk::VK_SHADER_STAGE_FRAGMENT_BIT,
);
assert_eq!(test_data.descriptor_set_layout_bindings(1).len(), 0);
let bindings = test_data.descriptor_set_layout_bindings(2);
assert_eq!(bindings.len(), 1);
assert_eq!(bindings[0].binding, 1);
assert_eq!(
bindings[0].descriptorType,
vk::VK_DESCRIPTOR_TYPE_STORAGE_BUFFER
);
assert_eq!(bindings[0].descriptorCount, 1);
assert_eq!(
bindings[0].stageFlags,
vk::VK_SHADER_STAGE_VERTEX_BIT
| vk::VK_SHADER_STAGE_FRAGMENT_BIT,
);
let bindings = test_data.descriptor_set_layout_bindings(3);
assert_eq!(bindings.len(), 1);
assert_eq!(bindings[0].binding, 1);
assert_eq!(
bindings[0].descriptorType,
vk::VK_DESCRIPTOR_TYPE_STORAGE_BUFFER
);
assert_eq!(bindings[0].descriptorCount, 1);
assert_eq!(
bindings[0].stageFlags,
vk::VK_SHADER_STAGE_VERTEX_BIT
| vk::VK_SHADER_STAGE_FRAGMENT_BIT,
);
assert_eq!(test_data.descriptor_set_layout_bindings(4).len(), 0);
let bindings = test_data.descriptor_set_layout_bindings(5);
assert_eq!(bindings.len(), 1);
assert_eq!(bindings[0].binding, 5);
assert_eq!(
bindings[0].descriptorType,
vk::VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER
);
assert_eq!(bindings[0].descriptorCount, 1);
assert_eq!(
bindings[0].stageFlags,
vk::VK_SHADER_STAGE_VERTEX_BIT
| vk::VK_SHADER_STAGE_FRAGMENT_BIT,
);
}
#[test]
fn compile_error() {
let error = TestData::new(
"[fragment shader]\n\
12 34 56 78\n"
).unwrap_err();
assert_eq!(
&error.to_string(),
"The compiler or assembler generated an invalid SPIR-V binary"
);
}
#[test]
fn create_errors() {
let commands = [
"vkCreatePipelineCache",
"vkCreateDescriptorPool",
"vkCreateDescriptorSetLayout",
"vkCreatePipelineLayout",
];
for command in commands {
let error = TestData::new_with_errors(
"[test]\n\
ubo 0 10\n\
draw rect 0 0 1 1\n",
|fake_vulkan| fake_vulkan.queue_result(
command.to_string(),
vk::VK_ERROR_UNKNOWN,
),
).unwrap_err();
assert_eq!(
error.to_string(),
format!("{} failed", command),
);
}
}
#[test]
fn create_pipeline_errors() {
let commands = [
"vkCreateGraphicsPipelines",
"vkCreateComputePipelines",
];
for command in commands {
let error = TestData::new_with_errors(
"[compute shader]\n\
03 02 23 07\n\
ca fe ca fe\n\
[test]\n\
compute 1 1 1\n\
draw rect 0 0 1 1\n",
|fake_vulkan| fake_vulkan.queue_result(
command.to_string(),
vk::VK_ERROR_UNKNOWN,
),
).unwrap_err();
assert_eq!(
&error.to_string(),
"Pipeline creation function failed",
);
}
}
#[test]
fn no_buffers() {
let test_data = TestData::new("").unwrap();
assert!(test_data.pipeline_set.descriptor_pool().is_none());
assert_eq!(test_data.pipeline_set.descriptor_set_layouts().len(), 0);
}
}
```
--------------------------------------------------------------------------------
/vkrunner/vkrunner/slot.rs:
--------------------------------------------------------------------------------
```rust
// vkrunner
//
// Copyright (C) 2018 Intel Corporation
// Copyright 2023 Neil Roberts
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice (including the next
// paragraph) shall be included in all copies or substantial portions of the
// Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
// THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
//! Contains utilities for accessing values stored in a buffer
//! according to the GLSL layout rules. These are described in detail
//! in the OpenGL specs here:
//! <https://registry.khronos.org/OpenGL/specs/gl/glspec45.core.pdf#page=159>
use crate::util;
use crate::tolerance::Tolerance;
use crate::half_float;
use std::mem;
use std::convert::TryInto;
use std::fmt;
/// Describes which layout standard is being used. The only difference
/// is that with `Std140` the offset between members of an
/// array or between elements of the minor axis of a matrix is rounded
/// up to be a multiple of 16 bytes.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum LayoutStd {
Std140,
Std430,
}
/// For matrix types, describes which axis is the major one. The
/// elements of the major axis are stored consecutively in memory. For
/// example, if the numbers indicate the order in memory of the
/// components of a matrix, the two orderings would look like this:
///
/// ```text
/// Column major Row major
/// [0 3 6] [0 1 2]
/// [1 4 7] [3 4 5]
/// [2 5 8] [6 7 8]
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MajorAxis {
Column,
Row
}
/// Combines the [MajorAxis] and the [LayoutStd] to provide a complete
/// description of the layout options used for accessing the
/// components of a type.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Layout {
pub std: LayoutStd,
pub major: MajorAxis,
}
/// An enum representing all of the types that can be stored in a
/// slot. These correspond to the types in GLSL.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Type {
Int = 0,
UInt,
Int8,
UInt8,
Int16,
UInt16,
Int64,
UInt64,
Float16,
Float,
Double,
F16Vec2,
F16Vec3,
F16Vec4,
Vec2,
Vec3,
Vec4,
DVec2,
DVec3,
DVec4,
IVec2,
IVec3,
IVec4,
UVec2,
UVec3,
UVec4,
I8Vec2,
I8Vec3,
I8Vec4,
U8Vec2,
U8Vec3,
U8Vec4,
I16Vec2,
I16Vec3,
I16Vec4,
U16Vec2,
U16Vec3,
U16Vec4,
I64Vec2,
I64Vec3,
I64Vec4,
U64Vec2,
U64Vec3,
U64Vec4,
Mat2,
Mat2x3,
Mat2x4,
Mat3x2,
Mat3,
Mat3x4,
Mat4x2,
Mat4x3,
Mat4,
DMat2,
DMat2x3,
DMat2x4,
DMat3x2,
DMat3,
DMat3x4,
DMat4x2,
DMat4x3,
DMat4,
}
/// The type of a component of the types in [Type].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BaseType {
Int,
UInt,
Int8,
UInt8,
Int16,
UInt16,
Int64,
UInt64,
Float16,
Float,
Double,
}
/// A base type combined with a slice of bytes. This implements
/// [Display](std::fmt::Display) so it can be used to extract a base
/// type from a byte array and present it as a human-readable string.
pub struct BaseTypeInSlice<'a> {
base_type: BaseType,
slice: &'a [u8],
}
/// A type of comparison that can be used to compare values stored in
/// a slot with the chosen criteria. Use the
/// [compare](Comparison::compare) method to perform the comparison.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Comparison {
/// The comparison passes if the values are exactly equal.
Equal,
/// The comparison passes if the floating-point values are
/// approximately equal using the given [Tolerance] settings. For
/// integer types this is the same as [Equal](Comparison::Equal).
FuzzyEqual,
/// The comparison passes if the values are different.
NotEqual,
/// The comparison passes if all of the values in the first slot
/// are less than the values in the second slot.
Less,
/// The comparison passes if all of the values in the first slot
/// are greater than or equal to the values in the second slot.
GreaterEqual,
/// The comparison passes if all of the values in the first slot
/// are greater than the values in the second slot.
Greater,
/// The comparison passes if all of the values in the first slot
/// are less than or equal to the values in the second slot.
LessEqual,
}
#[derive(Debug)]
struct TypeInfo {
base_type: BaseType,
columns: usize,
rows: usize,
}
static TYPE_INFOS: [TypeInfo; 62] = [
TypeInfo { base_type: BaseType::Int, columns: 1, rows: 1 }, // Int
TypeInfo { base_type: BaseType::UInt, columns: 1, rows: 1 }, // UInt
TypeInfo { base_type: BaseType::Int8, columns: 1, rows: 1 }, // Int8
TypeInfo { base_type: BaseType::UInt8, columns: 1, rows: 1 }, // UInt8
TypeInfo { base_type: BaseType::Int16, columns: 1, rows: 1 }, // Int16
TypeInfo { base_type: BaseType::UInt16, columns: 1, rows: 1 }, // UInt16
TypeInfo { base_type: BaseType::Int64, columns: 1, rows: 1 }, // Int64
TypeInfo { base_type: BaseType::UInt64, columns: 1, rows: 1 }, // UInt64
TypeInfo { base_type: BaseType::Float16, columns: 1, rows: 1 }, // Float16
TypeInfo { base_type: BaseType::Float, columns: 1, rows: 1 }, // Float
TypeInfo { base_type: BaseType::Double, columns: 1, rows: 1 }, // Double
TypeInfo { base_type: BaseType::Float16, columns: 1, rows: 2 }, // F16Vec2
TypeInfo { base_type: BaseType::Float16, columns: 1, rows: 3 }, // F16Vec3
TypeInfo { base_type: BaseType::Float16, columns: 1, rows: 4 }, // F16Vec4
TypeInfo { base_type: BaseType::Float, columns: 1, rows: 2 }, // Vec2
TypeInfo { base_type: BaseType::Float, columns: 1, rows: 3 }, // Vec3
TypeInfo { base_type: BaseType::Float, columns: 1, rows: 4 }, // Vec4
TypeInfo { base_type: BaseType::Double, columns: 1, rows: 2 }, // DVec2
TypeInfo { base_type: BaseType::Double, columns: 1, rows: 3 }, // DVec3
TypeInfo { base_type: BaseType::Double, columns: 1, rows: 4 }, // DVec4
TypeInfo { base_type: BaseType::Int, columns: 1, rows: 2 }, // IVec2
TypeInfo { base_type: BaseType::Int, columns: 1, rows: 3 }, // IVec3
TypeInfo { base_type: BaseType::Int, columns: 1, rows: 4 }, // IVec4
TypeInfo { base_type: BaseType::UInt, columns: 1, rows: 2 }, // UVec2
TypeInfo { base_type: BaseType::UInt, columns: 1, rows: 3 }, // UVec3
TypeInfo { base_type: BaseType::UInt, columns: 1, rows: 4 }, // UVec4
TypeInfo { base_type: BaseType::Int8, columns: 1, rows: 2 }, // I8Vec2
TypeInfo { base_type: BaseType::Int8, columns: 1, rows: 3 }, // I8Vec3
TypeInfo { base_type: BaseType::Int8, columns: 1, rows: 4 }, // I8Vec4
TypeInfo { base_type: BaseType::UInt8, columns: 1, rows: 2 }, // U8Vec2
TypeInfo { base_type: BaseType::UInt8, columns: 1, rows: 3 }, // U8Vec3
TypeInfo { base_type: BaseType::UInt8, columns: 1, rows: 4 }, // U8Vec4
TypeInfo { base_type: BaseType::Int16, columns: 1, rows: 2 }, // I16Vec2
TypeInfo { base_type: BaseType::Int16, columns: 1, rows: 3 }, // I16Vec3
TypeInfo { base_type: BaseType::Int16, columns: 1, rows: 4 }, // I16Vec4
TypeInfo { base_type: BaseType::UInt16, columns: 1, rows: 2 }, // U16Vec2
TypeInfo { base_type: BaseType::UInt16, columns: 1, rows: 3 }, // U16Vec3
TypeInfo { base_type: BaseType::UInt16, columns: 1, rows: 4 }, // U16Vec4
TypeInfo { base_type: BaseType::Int64, columns: 1, rows: 2 }, // I64Vec2
TypeInfo { base_type: BaseType::Int64, columns: 1, rows: 3 }, // I64Vec3
TypeInfo { base_type: BaseType::Int64, columns: 1, rows: 4 }, // I64Vec4
TypeInfo { base_type: BaseType::UInt64, columns: 1, rows: 2 }, // U64Vec2
TypeInfo { base_type: BaseType::UInt64, columns: 1, rows: 3 }, // U64Vec3
TypeInfo { base_type: BaseType::UInt64, columns: 1, rows: 4 }, // U64Vec4
TypeInfo { base_type: BaseType::Float, columns: 2, rows: 2 }, // Mat2
TypeInfo { base_type: BaseType::Float, columns: 2, rows: 3 }, // Mat2x3
TypeInfo { base_type: BaseType::Float, columns: 2, rows: 4 }, // Mat2x4
TypeInfo { base_type: BaseType::Float, columns: 3, rows: 2 }, // Mat3x2
TypeInfo { base_type: BaseType::Float, columns: 3, rows: 3 }, // Mat3
TypeInfo { base_type: BaseType::Float, columns: 3, rows: 4 }, // Mat3x4
TypeInfo { base_type: BaseType::Float, columns: 4, rows: 2 }, // Mat4x2
TypeInfo { base_type: BaseType::Float, columns: 4, rows: 3 }, // Mat4x3
TypeInfo { base_type: BaseType::Float, columns: 4, rows: 4 }, // Mat4
TypeInfo { base_type: BaseType::Double, columns: 2, rows: 2 }, // DMat2
TypeInfo { base_type: BaseType::Double, columns: 2, rows: 3 }, // DMat2x3
TypeInfo { base_type: BaseType::Double, columns: 2, rows: 4 }, // DMat2x4
TypeInfo { base_type: BaseType::Double, columns: 3, rows: 2 }, // DMat3x2
TypeInfo { base_type: BaseType::Double, columns: 3, rows: 3 }, // DMat3
TypeInfo { base_type: BaseType::Double, columns: 3, rows: 4 }, // DMat3x4
TypeInfo { base_type: BaseType::Double, columns: 4, rows: 2 }, // DMat4x2
TypeInfo { base_type: BaseType::Double, columns: 4, rows: 3 }, // DMat4x3
TypeInfo { base_type: BaseType::Double, columns: 4, rows: 4 }, // DMat4
];
// Mapping from GLSL type name to slot type. Sorted alphabetically so
// we can do a binary search.
static GLSL_TYPE_NAMES: [(&'static str, Type); 68] = [
("dmat2", Type::DMat2),
("dmat2x2", Type::DMat2),
("dmat2x3", Type::DMat2x3),
("dmat2x4", Type::DMat2x4),
("dmat3", Type::DMat3),
("dmat3x2", Type::DMat3x2),
("dmat3x3", Type::DMat3),
("dmat3x4", Type::DMat3x4),
("dmat4", Type::DMat4),
("dmat4x2", Type::DMat4x2),
("dmat4x3", Type::DMat4x3),
("dmat4x4", Type::DMat4),
("double", Type::Double),
("dvec2", Type::DVec2),
("dvec3", Type::DVec3),
("dvec4", Type::DVec4),
("f16vec2", Type::F16Vec2),
("f16vec3", Type::F16Vec3),
("f16vec4", Type::F16Vec4),
("float", Type::Float),
("float16_t", Type::Float16),
("i16vec2", Type::I16Vec2),
("i16vec3", Type::I16Vec3),
("i16vec4", Type::I16Vec4),
("i64vec2", Type::I64Vec2),
("i64vec3", Type::I64Vec3),
("i64vec4", Type::I64Vec4),
("i8vec2", Type::I8Vec2),
("i8vec3", Type::I8Vec3),
("i8vec4", Type::I8Vec4),
("int", Type::Int),
("int16_t", Type::Int16),
("int64_t", Type::Int64),
("int8_t", Type::Int8),
("ivec2", Type::IVec2),
("ivec3", Type::IVec3),
("ivec4", Type::IVec4),
("mat2", Type::Mat2),
("mat2x2", Type::Mat2),
("mat2x3", Type::Mat2x3),
("mat2x4", Type::Mat2x4),
("mat3", Type::Mat3),
("mat3x2", Type::Mat3x2),
("mat3x3", Type::Mat3),
("mat3x4", Type::Mat3x4),
("mat4", Type::Mat4),
("mat4x2", Type::Mat4x2),
("mat4x3", Type::Mat4x3),
("mat4x4", Type::Mat4),
("u16vec2", Type::U16Vec2),
("u16vec3", Type::U16Vec3),
("u16vec4", Type::U16Vec4),
("u64vec2", Type::U64Vec2),
("u64vec3", Type::U64Vec3),
("u64vec4", Type::U64Vec4),
("u8vec2", Type::U8Vec2),
("u8vec3", Type::U8Vec3),
("u8vec4", Type::U8Vec4),
("uint", Type::UInt),
("uint16_t", Type::UInt16),
("uint64_t", Type::UInt64),
("uint8_t", Type::UInt8),
("uvec2", Type::UVec2),
("uvec3", Type::UVec3),
("uvec4", Type::UVec4),
("vec2", Type::Vec2),
("vec3", Type::Vec3),
("vec4", Type::Vec4),
];
// Mapping from operator name to Comparison. Sorted by ASCII values so
// we can do a binary search.
static COMPARISON_NAMES: [(&'static str, Comparison); 7] = [
("!=", Comparison::NotEqual),
("<", Comparison::Less),
("<=", Comparison::LessEqual),
("==", Comparison::Equal),
(">", Comparison::Greater),
(">=", Comparison::GreaterEqual),
("~=", Comparison::FuzzyEqual),
];
impl BaseType {
/// Returns the size in bytes of a variable of this base type.
pub fn size(self) -> usize {
match self {
BaseType::Int => mem::size_of::<i32>(),
BaseType::UInt => mem::size_of::<u32>(),
BaseType::Int8 => mem::size_of::<i8>(),
BaseType::UInt8 => mem::size_of::<u8>(),
BaseType::Int16 => mem::size_of::<i16>(),
BaseType::UInt16 => mem::size_of::<u16>(),
BaseType::Int64 => mem::size_of::<i64>(),
BaseType::UInt64 => mem::size_of::<u64>(),
BaseType::Float16 => mem::size_of::<u16>(),
BaseType::Float => mem::size_of::<u32>(),
BaseType::Double => mem::size_of::<u64>(),
}
}
}
impl<'a> BaseTypeInSlice<'a> {
/// Combines the given base type and slice into a
/// `BaseTypeInSlice` so that it can be displayed with the
/// [Display](std::fmt::Display) trait. The slice is expected to
/// contain a representation of the base type in native-endian
/// order. The slice needs to have the same size as
/// `base_type.size()` or the constructor will panic.
pub fn new(base_type: BaseType, slice: &'a [u8]) -> BaseTypeInSlice<'a> {
assert_eq!(slice.len(), base_type.size());
BaseTypeInSlice { base_type, slice }
}
}
impl<'a> fmt::Display for BaseTypeInSlice<'a> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self.base_type {
BaseType::Int => {
let v = i32::from_ne_bytes(self.slice.try_into().unwrap());
write!(f, "{}", v)
},
BaseType::UInt => {
let v = u32::from_ne_bytes(self.slice.try_into().unwrap());
write!(f, "{}", v)
},
BaseType::Int8 => {
let v = i8::from_ne_bytes(self.slice.try_into().unwrap());
write!(f, "{}", v)
},
BaseType::UInt8 => {
let v = u8::from_ne_bytes(self.slice.try_into().unwrap());
write!(f, "{}", v)
},
BaseType::Int16 => {
let v = i16::from_ne_bytes(self.slice.try_into().unwrap());
write!(f, "{}", v)
},
BaseType::UInt16 => {
let v = u16::from_ne_bytes(self.slice.try_into().unwrap());
write!(f, "{}", v)
},
BaseType::Int64 => {
let v = i64::from_ne_bytes(self.slice.try_into().unwrap());
write!(f, "{}", v)
},
BaseType::UInt64 => {
let v = u64::from_ne_bytes(self.slice.try_into().unwrap());
write!(f, "{}", v)
},
BaseType::Float16 => {
let v = u16::from_ne_bytes(self.slice.try_into().unwrap());
let v = half_float::to_f64(v);
write!(f, "{}", v)
},
BaseType::Float => {
let v = f32::from_ne_bytes(self.slice.try_into().unwrap());
write!(f, "{}", v)
},
BaseType::Double => {
let v = f64::from_ne_bytes(self.slice.try_into().unwrap());
write!(f, "{}", v)
},
}
}
}
impl Type {
/// Returns the type of the components of this slot type.
#[inline]
pub fn base_type(self) -> BaseType {
TYPE_INFOS[self as usize].base_type
}
/// Returns the number of rows in this slot type. For primitive types
/// this will be 1, for vectors it will be the size of the vector
/// and for matrix types it will be the number of rows.
#[inline]
pub fn rows(self) -> usize {
TYPE_INFOS[self as usize].rows
}
/// Return the number of columns in this slot type. For non-matrix
/// types this will be 1.
#[inline]
pub fn columns(self) -> usize {
TYPE_INFOS[self as usize].columns
}
/// Returns whether the type is a matrix type
#[inline]
pub fn is_matrix(self) -> bool {
self.columns() > 1
}
// Returns the effective major axis setting that should be used
// when calculating offsets. The setting should only affect matrix
// types.
fn effective_major_axis(self, layout: Layout) -> MajorAxis {
if self.is_matrix() {
layout.major
} else {
MajorAxis::Column
}
}
// Return a tuple containing the size of the major and minor axes
// (in that order) given the type and layout
fn major_minor(self, layout: Layout) -> (usize, usize) {
match self.effective_major_axis(layout) {
MajorAxis::Column => (self.columns(), self.rows()),
MajorAxis::Row => (self.rows(), self.columns()),
}
}
/// Returns the matrix stride of the slot type. This is the offset
/// in bytes between consecutive components of the minor axis. For
/// example, in a column-major layout, the matrix stride of a vec4
/// will be 16, because to move to the next column you need to add
/// an offset of 16 bytes to the address of the previous column.
pub fn matrix_stride(self, layout: Layout) -> usize {
let component_size = self.base_type().size();
let (_major, minor) = self.major_minor(layout);
let base_stride = if minor == 3 {
component_size * 4
} else {
component_size * minor
};
match layout.std {
LayoutStd::Std140 => {
// According to std140 the size is rounded up to a vec4
util::align(base_stride, 16)
},
LayoutStd::Std430 => base_stride,
}
}
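// Worked example (illustrative): for a column-major mat3 the minor
// axis has 3 rows, so base_stride is 4 bytes * 4 = 16 (a 3-component
// column is padded to the size of a vec4) and both std140 and std430
// give a stride of 16. For a column-major mat2, std430 gives
// 4 * 2 = 8 while std140 rounds the stride up to 16.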
/// Returns the offset between members of the elements of an array
/// of this type of slot. There may be padding between the members
/// to fulfill the alignment criteria of the layout.
pub fn array_stride(self, layout: Layout) -> usize {
let matrix_stride = self.matrix_stride(layout);
let (major, _minor) = self.major_minor(layout);
matrix_stride * major
}
/// Returns the size of the slot type. Note that unlike
/// [array_stride](Type::array_stride), this won’t include the
/// padding to align the values in an array according to the GLSL
/// layout rules. It will however always have a size that is a
/// multiple of the size of the base type so it is suitable as an
/// array stride for packed values stored internally that aren’t
/// passed to Vulkan.
pub fn size(self, layout: Layout) -> usize {
let matrix_stride = self.matrix_stride(layout);
let base_size = self.base_type().size();
let (major, minor) = self.major_minor(layout);
(major - 1) * matrix_stride + base_size * minor
}
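// Worked example (illustrative): a column-major mat2x3 under std140
// has a matrix stride of 16, so array_stride() is 2 * 16 = 32 while
// size() is (2 - 1) * 16 + 4 * 3 = 28; the padding after the final
// column is not counted.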
/// Returns an iterator that iterates over the offsets of the
/// components of this slot. This can be used to extract the
/// values out of an array of bytes containing the slot. The
/// offsets are returned in column-major order (ie, all of the
/// offsets for a column are returned before moving to the next
/// column), but the offsets are calculated to take into account
/// the major axis setting in the layout.
pub fn offsets(self, layout: Layout) -> OffsetsIter {
OffsetsIter::new(self, layout)
}
/// Returns a [Type] given the GLSL type name, for example `dmat2`
/// or `i64vec3`. `None` will be returned if the type name isn’t
/// recognised.
pub fn from_glsl_type(type_name: &str) -> Option<Type> {
match GLSL_TYPE_NAMES.binary_search_by(
|&(name, _)| name.cmp(type_name)
) {
Ok(pos) => Some(GLSL_TYPE_NAMES[pos].1),
Err(_) => None,
}
}
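// Usage example (illustrative): Type::from_glsl_type("dmat3x2")
// returns Some(Type::DMat3x2), while a name that is not in the sorted
// GLSL_TYPE_NAMES table, such as "texture2D", returns None.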
}
/// Iterator over the offsets into an array to extract the components
/// of a slot. Created by [Type::offsets].
#[derive(Debug, Clone)]
pub struct OffsetsIter {
slot_type: Type,
row: usize,
column: usize,
offset: usize,
column_offset: isize,
row_offset: isize,
}
impl OffsetsIter {
fn new(slot_type: Type, layout: Layout) -> OffsetsIter {
let column_offset;
let row_offset;
let stride = slot_type.matrix_stride(layout);
let base_size = slot_type.base_type().size();
match slot_type.effective_major_axis(layout) {
MajorAxis::Column => {
column_offset =
(stride - base_size * slot_type.rows()) as isize;
row_offset = base_size as isize;
},
MajorAxis::Row => {
column_offset =
base_size as isize - (stride * slot_type.rows()) as isize;
row_offset = stride as isize;
},
}
OffsetsIter {
slot_type,
row: 0,
column: 0,
offset: 0,
column_offset,
row_offset,
}
}
}
impl Iterator for OffsetsIter {
type Item = usize;
fn next(&mut self) -> Option<usize> {
if self.column >= self.slot_type.columns() {
return None;
}
let result = self.offset;
self.offset = ((self.offset as isize) + self.row_offset) as usize;
self.row += 1;
if self.row >= self.slot_type.rows() {
self.row = 0;
self.column += 1;
self.offset =
((self.offset as isize) + self.column_offset) as usize;
}
Some(result)
}
fn size_hint(&self) -> (usize, Option<usize>) {
let n_columns = self.slot_type.columns();
let size = if self.column >= n_columns {
0
} else {
(self.slot_type.rows() - self.row)
+ ((n_columns - 1 - self.column) * self.slot_type.rows())
};
(size, Some(size))
}
fn count(self) -> usize {
self.size_hint().0
}
}
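// Worked example (illustrative): for a mat2 under std430 the iterator
// yields [0, 4, 8, 12] with a column-major layout and [0, 8, 4, 12]
// with a row-major layout; the components are still visited column by
// column, but within a column each step advances by the matrix stride
// (8) instead of the component size (4).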
impl Comparison {
fn compare_without_fuzzy<T: PartialOrd>(self, a: T, b: T) -> bool {
match self {
Comparison::Equal | Comparison::FuzzyEqual => a.eq(&b),
Comparison::NotEqual => a.ne(&b),
Comparison::Less => a.lt(&b),
Comparison::GreaterEqual => a.ge(&b),
Comparison::Greater => a.gt(&b),
Comparison::LessEqual => a.le(&b),
}
}
fn compare_value(
self,
tolerance: &Tolerance,
component: usize,
base_type: BaseType,
a: &[u8],
b: &[u8]
) -> bool {
match base_type {
BaseType::Int => {
let a = i32::from_ne_bytes(a.try_into().unwrap());
let b = i32::from_ne_bytes(b.try_into().unwrap());
self.compare_without_fuzzy(a, b)
},
BaseType::UInt => {
let a = u32::from_ne_bytes(a.try_into().unwrap());
let b = u32::from_ne_bytes(b.try_into().unwrap());
self.compare_without_fuzzy(a, b)
},
BaseType::Int8 => {
let a = i8::from_ne_bytes(a.try_into().unwrap());
let b = i8::from_ne_bytes(b.try_into().unwrap());
self.compare_without_fuzzy(a, b)
},
BaseType::UInt8 => {
let a = u8::from_ne_bytes(a.try_into().unwrap());
let b = u8::from_ne_bytes(b.try_into().unwrap());
self.compare_without_fuzzy(a, b)
},
BaseType::Int16 => {
let a = i16::from_ne_bytes(a.try_into().unwrap());
let b = i16::from_ne_bytes(b.try_into().unwrap());
self.compare_without_fuzzy(a, b)
},
BaseType::UInt16 => {
let a = u16::from_ne_bytes(a.try_into().unwrap());
let b = u16::from_ne_bytes(b.try_into().unwrap());
self.compare_without_fuzzy(a, b)
},
BaseType::Int64 => {
let a = i64::from_ne_bytes(a.try_into().unwrap());
let b = i64::from_ne_bytes(b.try_into().unwrap());
self.compare_without_fuzzy(a, b)
},
BaseType::UInt64 => {
let a = u64::from_ne_bytes(a.try_into().unwrap());
let b = u64::from_ne_bytes(b.try_into().unwrap());
self.compare_without_fuzzy(a, b)
},
BaseType::Float16 => {
let a = u16::from_ne_bytes(a.try_into().unwrap());
let a = half_float::to_f64(a);
let b = u16::from_ne_bytes(b.try_into().unwrap());
let b = half_float::to_f64(b);
match self {
Comparison::FuzzyEqual => tolerance.equal(component, a, b),
Comparison::Equal
| Comparison::NotEqual
| Comparison::Less
| Comparison::GreaterEqual
| Comparison::Greater
| Comparison::LessEqual => {
self.compare_without_fuzzy(a, b)
},
}
},
BaseType::Float => {
let a = f32::from_ne_bytes(a.try_into().unwrap());
let b = f32::from_ne_bytes(b.try_into().unwrap());
match self {
Comparison::FuzzyEqual => {
tolerance.equal(component, a as f64, b as f64)
},
Comparison::Equal
| Comparison::NotEqual
| Comparison::Less
| Comparison::GreaterEqual
| Comparison::Greater
| Comparison::LessEqual => {
self.compare_without_fuzzy(a, b)
},
}
},
BaseType::Double => {
let a = f64::from_ne_bytes(a.try_into().unwrap());
let b = f64::from_ne_bytes(b.try_into().unwrap());
match self {
Comparison::FuzzyEqual => {
tolerance.equal(component, a, b)
},
Comparison::Equal
| Comparison::NotEqual
| Comparison::Less
| Comparison::GreaterEqual
| Comparison::Greater
| Comparison::LessEqual => {
self.compare_without_fuzzy(a, b)
},
}
},
}
}
/// Compares the values in `a` with the values in `b` using the
/// chosen criteria. The values are extracted from the two byte
/// slices using the layout and type provided. The tolerance
/// parameter is only used if the comparison is
/// [FuzzyEqual](Comparison::FuzzyEqual).
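///
/// For example (mirroring `test_compare` below), an exact comparison of
/// two single-byte `UInt8` slots:
///
/// ```ignore
/// let layout = Layout { std: LayoutStd::Std140, major: MajorAxis::Column };
/// assert!(Comparison::Equal.compare(&Default::default(), Type::UInt8, layout, &[5], &[5]));
/// ```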
pub fn compare(
self,
tolerance: &Tolerance,
slot_type: Type,
layout: Layout,
a: &[u8],
b: &[u8],
) -> bool {
let base_type = slot_type.base_type();
let base_type_size = base_type.size();
let rows = slot_type.rows();
for (i, offset) in slot_type.offsets(layout).enumerate() {
let a = &a[offset..offset + base_type_size];
let b = &b[offset..offset + base_type_size];
if !self.compare_value(tolerance, i % rows, base_type, a, b) {
return false;
}
}
true
}
/// Get the comparison operator given the name in GLSL like `"<"`
/// or `">="`. It also accepts `"~="` to return the fuzzy equal
/// comparison. If the operator isn’t recognised then it will
/// return `None`.
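///
/// For example (mirroring `test_comparison_from_operator` below):
///
/// ```ignore
/// assert_eq!(Comparison::from_operator(">="), Some(Comparison::GreaterEqual));
/// assert_eq!(Comparison::from_operator("<=>"), None);
/// ```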
pub fn from_operator(operator: &str) -> Option<Comparison> {
match COMPARISON_NAMES.binary_search_by(
|&(probe, _value)| probe.cmp(operator)
) {
Ok(pos) => Some(COMPARISON_NAMES[pos].1),
Err(_) => None,
}
}
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn test_type_info() {
// Test a few types
// First
assert!(matches!(Type::Int.base_type(), BaseType::Int));
assert_eq!(Type::Int.columns(), 1);
assert_eq!(Type::Int.rows(), 1);
assert!(!Type::Int.is_matrix());
// Last
assert!(matches!(Type::DMat4.base_type(), BaseType::Double));
assert_eq!(Type::DMat4.columns(), 4);
assert_eq!(Type::DMat4.rows(), 4);
assert!(Type::DMat4.is_matrix());
// Vector
assert!(matches!(Type::Vec4.base_type(), BaseType::Float));
assert_eq!(Type::Vec4.columns(), 1);
assert_eq!(Type::Vec4.rows(), 4);
assert!(!Type::Vec4.is_matrix());
// Non-square matrix
assert!(matches!(Type::DMat4x3.base_type(), BaseType::Double));
assert_eq!(Type::DMat4x3.columns(), 4);
assert_eq!(Type::DMat4x3.rows(), 3);
assert!(Type::DMat4x3.is_matrix());
}
#[test]
fn test_base_type_size() {
assert_eq!(BaseType::Int.size(), 4);
assert_eq!(BaseType::UInt.size(), 4);
assert_eq!(BaseType::Int8.size(), 1);
assert_eq!(BaseType::UInt8.size(), 1);
assert_eq!(BaseType::Int16.size(), 2);
assert_eq!(BaseType::UInt16.size(), 2);
assert_eq!(BaseType::Int64.size(), 8);
assert_eq!(BaseType::UInt64.size(), 8);
assert_eq!(BaseType::Float16.size(), 2);
assert_eq!(BaseType::Float.size(), 4);
assert_eq!(BaseType::Double.size(), 8);
}
#[test]
fn test_matrix_stride() {
assert_eq!(
Type::Mat4.matrix_stride(
Layout { std: LayoutStd::Std430, major: MajorAxis::Column }
),
16,
);
// In Std140 the vecs along the minor axis are aligned to the
// size of a vec4
assert_eq!(
Type::Mat4x2.matrix_stride(
Layout { std: LayoutStd::Std140, major: MajorAxis::Column }
),
16,
);
// In Std430 they are not
assert_eq!(
Type::Mat4x2.matrix_stride(
Layout { std: LayoutStd::Std430, major: MajorAxis::Column }
),
8,
);
// 3-component columns are always aligned to the size of a vec4
// regardless of the layout standard
assert_eq!(
Type::Mat4x3.matrix_stride(
Layout { std: LayoutStd::Std430, major: MajorAxis::Column }
),
16,
);
// For row-major layout the axes used for the stride are reversed
assert_eq!(
Type::Mat4x2.matrix_stride(
Layout { std: LayoutStd::Std430, major: MajorAxis::Row }
),
16,
);
assert_eq!(
Type::Mat2x4.matrix_stride(
Layout { std: LayoutStd::Std430, major: MajorAxis::Row }
),
8,
);
// For non-matrix types the matrix stride is the array stride
assert_eq!(
Type::Vec3.matrix_stride(
Layout { std: LayoutStd::Std430, major: MajorAxis::Column }
),
16,
);
assert_eq!(
Type::Float.matrix_stride(
Layout { std: LayoutStd::Std430, major: MajorAxis::Column }
),
4,
);
// Row-major layout does not affect vectors
assert_eq!(
Type::Vec2.matrix_stride(
Layout { std: LayoutStd::Std430, major: MajorAxis::Row }
),
8,
);
}
#[test]
fn test_array_stride() {
assert_eq!(
Type::UInt8.array_stride(
Layout { std: LayoutStd::Std430, major: MajorAxis::Column }
),
1,
);
assert_eq!(
Type::I8Vec2.array_stride(
Layout { std: LayoutStd::Std430, major: MajorAxis::Column }
),
2,
);
assert_eq!(
Type::I8Vec3.array_stride(
Layout { std: LayoutStd::Std430, major: MajorAxis::Column }
),
4,
);
assert_eq!(
Type::Mat4x3.array_stride(
Layout { std: LayoutStd::Std430, major: MajorAxis::Column }
),
64,
);
}
#[test]
fn test_size() {
assert_eq!(
Type::UInt8.size(
Layout { std: LayoutStd::Std430, major: MajorAxis::Column }
),
1,
);
assert_eq!(
Type::I8Vec2.size(
Layout { std: LayoutStd::Std430, major: MajorAxis::Column }
),
2,
);
assert_eq!(
Type::I8Vec3.size(
Layout { std: LayoutStd::Std430, major: MajorAxis::Column }
),
3,
);
assert_eq!(
Type::Mat4x3.size(
Layout { std: LayoutStd::Std430, major: MajorAxis::Column }
),
3 * 4 * 4 + 3 * 4,
);
}
#[test]
fn test_iterator() {
let default_layout = Layout {
std: LayoutStd::Std140,
major: MajorAxis::Column,
};
// The components of a vec4 are packed tightly
let vec4_offsets =
Type::Vec4.offsets(default_layout).collect::<Vec<usize>>();
assert_eq!(vec4_offsets, vec![0, 4, 8, 12]);
// The major axis setting shouldn’t affect vectors
let row_major_layout = Layout {
std: LayoutStd::Std140,
major: MajorAxis::Row,
};
let vec4_offsets =
Type::Vec4.offsets(row_major_layout).collect::<Vec<usize>>();
assert_eq!(vec4_offsets, vec![0, 4, 8, 12]);
// Components of a mat4 are packed tightly
let mat4_offsets =
Type::Mat4.offsets(default_layout).collect::<Vec<usize>>();
assert_eq!(
mat4_offsets,
vec![0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60],
);
// If there are 3 rows then they are padded out to align
// each column to the size of a vec4
let mat4x3_offsets =
Type::Mat4x3.offsets(default_layout).collect::<Vec<usize>>();
assert_eq!(
mat4x3_offsets,
vec![0, 4, 8, 16, 20, 24, 32, 36, 40, 48, 52, 56],
);
// If row-major order is being used then the offsets are still
// reported in column-major order but the addresses that they
// point to are in row-major order.
// [0 4 8]
// [16 20 24]
// [32 36 40]
// [48 52 56]
let mat3x4_offsets =
Type::Mat3x4.offsets(row_major_layout).collect::<Vec<usize>>();
assert_eq!(
mat3x4_offsets,
vec![0, 16, 32, 48, 4, 20, 36, 52, 8, 24, 40, 56],
);
// Check the size_hint and count. The size hint is always
// exact and should match the count.
let mut iter = Type::Mat4.offsets(default_layout);
for i in 0..16 {
let expected_count = 16 - i;
assert_eq!(
iter.size_hint(),
(expected_count, Some(expected_count))
);
assert_eq!(iter.clone().count(), expected_count);
assert!(matches!(iter.next(), Some(_)));
}
assert_eq!(iter.size_hint(), (0, Some(0)));
assert_eq!(iter.clone().count(), 0);
assert!(matches!(iter.next(), None));
}
#[test]
fn test_compare() {
// Test every comparison mode
assert!(Comparison::Equal.compare(
&Default::default(), // tolerance
Type::UInt8,
Layout { std: LayoutStd::Std140, major: MajorAxis::Column },
&[5],
&[5],
));
assert!(!Comparison::Equal.compare(
&Default::default(), // tolerance
Type::UInt8,
Layout { std: LayoutStd::Std140, major: MajorAxis::Column },
&[6],
&[5],
));
assert!(Comparison::NotEqual.compare(
&Default::default(), // tolerance
Type::UInt8,
Layout { std: LayoutStd::Std140, major: MajorAxis::Column },
&[6],
&[5],
));
assert!(!Comparison::NotEqual.compare(
&Default::default(), // tolerance
Type::UInt8,
Layout { std: LayoutStd::Std140, major: MajorAxis::Column },
&[5],
&[5],
));
assert!(Comparison::Less.compare(
&Default::default(), // tolerance
Type::UInt8,
Layout { std: LayoutStd::Std140, major: MajorAxis::Column },
&[4],
&[5],
));
assert!(!Comparison::Less.compare(
&Default::default(), // tolerance
Type::UInt8,
Layout { std: LayoutStd::Std140, major: MajorAxis::Column },
&[5],
&[5],
));
assert!(!Comparison::Less.compare(
&Default::default(), // tolerance
Type::UInt8,
Layout { std: LayoutStd::Std140, major: MajorAxis::Column },
&[6],
&[5],
));
assert!(Comparison::LessEqual.compare(
&Default::default(), // tolerance
Type::UInt8,
Layout { std: LayoutStd::Std140, major: MajorAxis::Column },
&[4],
&[5],
));
assert!(Comparison::LessEqual.compare(
&Default::default(), // tolerance
Type::UInt8,
Layout { std: LayoutStd::Std140, major: MajorAxis::Column },
&[5],
&[5],
));
assert!(!Comparison::LessEqual.compare(
&Default::default(), // tolerance
Type::UInt8,
Layout { std: LayoutStd::Std140, major: MajorAxis::Column },
&[6],
&[5],
));
assert!(Comparison::Greater.compare(
&Default::default(), // tolerance
Type::UInt8,
Layout { std: LayoutStd::Std140, major: MajorAxis::Column },
&[5],
&[4],
));
assert!(!Comparison::Greater.compare(
&Default::default(), // tolerance
Type::UInt8,
Layout { std: LayoutStd::Std140, major: MajorAxis::Column },
&[5],
&[5],
));
assert!(!Comparison::Greater.compare(
&Default::default(), // tolerance
Type::UInt8,
Layout { std: LayoutStd::Std140, major: MajorAxis::Column },
&[5],
&[6],
));
assert!(Comparison::GreaterEqual.compare(
&Default::default(), // tolerance
Type::UInt8,
Layout { std: LayoutStd::Std140, major: MajorAxis::Column },
&[5],
&[4],
));
assert!(Comparison::GreaterEqual.compare(
&Default::default(), // tolerance
Type::UInt8,
Layout { std: LayoutStd::Std140, major: MajorAxis::Column },
&[5],
&[5],
));
assert!(!Comparison::GreaterEqual.compare(
&Default::default(), // tolerance
Type::UInt8,
Layout { std: LayoutStd::Std140, major: MajorAxis::Column },
&[5],
&[6],
));
assert!(Comparison::FuzzyEqual.compare(
&Tolerance::new([1.0; 4], false),
Type::Float,
Layout { std: LayoutStd::Std140, major: MajorAxis::Column },
&5.0f32.to_ne_bytes(),
&5.9f32.to_ne_bytes(),
));
assert!(!Comparison::FuzzyEqual.compare(
&Tolerance::new([1.0; 4], false),
Type::Float,
Layout { std: LayoutStd::Std140, major: MajorAxis::Column },
&5.0f32.to_ne_bytes(),
&6.1f32.to_ne_bytes(),
));
// Test all types
macro_rules! test_compare_type_with_values {
($type_enum:expr, $low_value:expr, $high_value:expr) => {
assert!(Comparison::Less.compare(
&Default::default(),
$type_enum,
Layout { std: LayoutStd::Std140, major: MajorAxis::Column },
&$low_value.to_ne_bytes(),
&$high_value.to_ne_bytes(),
));
};
}
macro_rules! test_compare_simple_type {
($type_num:expr, $rust_type:ty) => {
test_compare_type_with_values!(
$type_num,
0 as $rust_type,
<$rust_type>::MAX
);
};
}
test_compare_simple_type!(Type::Int, i32);
test_compare_simple_type!(Type::UInt, u32);
test_compare_simple_type!(Type::Int8, i8);
test_compare_simple_type!(Type::UInt8, u8);
test_compare_simple_type!(Type::Int16, i16);
test_compare_simple_type!(Type::UInt16, u16);
test_compare_simple_type!(Type::Int64, i64);
test_compare_simple_type!(Type::UInt64, u64);
test_compare_simple_type!(Type::Float, f32);
test_compare_simple_type!(Type::Double, f64);
test_compare_type_with_values!(
Type::Float16,
half_float::from_f32(-100.0),
half_float::from_f32(300.0)
);
macro_rules! test_compare_big_type {
($type_enum:expr, $bytes:expr) => {
let layout = Layout {
std: LayoutStd::Std140,
major: MajorAxis::Column
};
let mut bytes = Vec::new();
assert_eq!($type_enum.size(layout) % $bytes.len(), 0);
for _ in 0..($type_enum.size(layout) / $bytes.len()) {
bytes.extend_from_slice($bytes);
}
assert!(Comparison::FuzzyEqual.compare(
&Default::default(),
$type_enum,
layout,
&bytes,
&bytes,
));
};
}
test_compare_big_type!(
Type::F16Vec2,
&half_float::from_f32(1.0).to_ne_bytes()
);
test_compare_big_type!(
Type::F16Vec3,
&half_float::from_f32(1.0).to_ne_bytes()
);
test_compare_big_type!(
Type::F16Vec4,
&half_float::from_f32(1.0).to_ne_bytes()
);
test_compare_big_type!(Type::Vec2, &42.0f32.to_ne_bytes());
test_compare_big_type!(Type::Vec3, &24.0f32.to_ne_bytes());
test_compare_big_type!(Type::Vec4, &6.0f32.to_ne_bytes());
test_compare_big_type!(Type::DVec2, &42.0f64.to_ne_bytes());
test_compare_big_type!(Type::DVec3, &24.0f64.to_ne_bytes());
test_compare_big_type!(Type::DVec4, &6.0f64.to_ne_bytes());
test_compare_big_type!(Type::IVec2, &(-42i32).to_ne_bytes());
test_compare_big_type!(Type::IVec3, &(-24i32).to_ne_bytes());
test_compare_big_type!(Type::IVec4, &(-6i32).to_ne_bytes());
test_compare_big_type!(Type::UVec2, &42u32.to_ne_bytes());
test_compare_big_type!(Type::UVec3, &24u32.to_ne_bytes());
test_compare_big_type!(Type::UVec4, &6u32.to_ne_bytes());
test_compare_big_type!(Type::I8Vec2, &(-42i8).to_ne_bytes());
test_compare_big_type!(Type::I8Vec3, &(-24i8).to_ne_bytes());
test_compare_big_type!(Type::I8Vec4, &(-6i8).to_ne_bytes());
test_compare_big_type!(Type::U8Vec2, &42u8.to_ne_bytes());
test_compare_big_type!(Type::U8Vec3, &24u8.to_ne_bytes());
test_compare_big_type!(Type::U8Vec4, &6u8.to_ne_bytes());
test_compare_big_type!(Type::I16Vec2, &(-42i16).to_ne_bytes());
test_compare_big_type!(Type::I16Vec3, &(-24i16).to_ne_bytes());
test_compare_big_type!(Type::I16Vec4, &(-6i16).to_ne_bytes());
test_compare_big_type!(Type::U16Vec2, &42u16.to_ne_bytes());
test_compare_big_type!(Type::U16Vec3, &24u16.to_ne_bytes());
test_compare_big_type!(Type::U16Vec4, &6u16.to_ne_bytes());
test_compare_big_type!(Type::I64Vec2, &(-42i64).to_ne_bytes());
test_compare_big_type!(Type::I64Vec3, &(-24i64).to_ne_bytes());
test_compare_big_type!(Type::I64Vec4, &(-6i64).to_ne_bytes());
test_compare_big_type!(Type::U64Vec2, &42u64.to_ne_bytes());
test_compare_big_type!(Type::U64Vec3, &24u64.to_ne_bytes());
test_compare_big_type!(Type::U64Vec4, &6u64.to_ne_bytes());
test_compare_big_type!(Type::Mat2, &42.0f32.to_ne_bytes());
test_compare_big_type!(Type::Mat2x3, &42.0f32.to_ne_bytes());
test_compare_big_type!(Type::Mat2x4, &42.0f32.to_ne_bytes());
test_compare_big_type!(Type::Mat3, &24.0f32.to_ne_bytes());
test_compare_big_type!(Type::Mat3x2, &24.0f32.to_ne_bytes());
test_compare_big_type!(Type::Mat3x4, &24.0f32.to_ne_bytes());
test_compare_big_type!(Type::Mat4, &6.0f32.to_ne_bytes());
test_compare_big_type!(Type::Mat4x2, &6.0f32.to_ne_bytes());
test_compare_big_type!(Type::Mat4x3, &6.0f32.to_ne_bytes());
test_compare_big_type!(Type::DMat2, &42.0f64.to_ne_bytes());
test_compare_big_type!(Type::DMat2x3, &42.0f64.to_ne_bytes());
test_compare_big_type!(Type::DMat2x4, &42.0f64.to_ne_bytes());
test_compare_big_type!(Type::DMat3, &24.0f64.to_ne_bytes());
test_compare_big_type!(Type::DMat3x2, &24.0f64.to_ne_bytes());
test_compare_big_type!(Type::DMat3x4, &24.0f64.to_ne_bytes());
test_compare_big_type!(Type::DMat4, &6.0f64.to_ne_bytes());
test_compare_big_type!(Type::DMat4x2, &6.0f64.to_ne_bytes());
test_compare_big_type!(Type::DMat4x3, &6.0f64.to_ne_bytes());
// Check that the comparison fails if a value other than
// the first one is different
assert!(!Comparison::Equal.compare(
&Default::default(), // tolerance
Type::U8Vec4,
Layout { std: LayoutStd::Std140, major: MajorAxis::Column },
&[5, 6, 7, 8],
&[5, 6, 7, 9],
));
}
#[test]
fn test_from_glsl_type() {
assert_eq!(Type::from_glsl_type("vec3"), Some(Type::Vec3));
assert_eq!(Type::from_glsl_type("dvec4"), Some(Type::DVec4));
assert_eq!(Type::from_glsl_type("i64vec4"), Some(Type::I64Vec4));
assert_eq!(Type::from_glsl_type("uint"), Some(Type::UInt));
assert_eq!(Type::from_glsl_type("uint16_t"), Some(Type::UInt16));
assert_eq!(Type::from_glsl_type("dvec5"), None);
// Check that all types can be found with the binary search
for &(type_name, slot_type) in GLSL_TYPE_NAMES.iter() {
assert_eq!(Type::from_glsl_type(type_name), Some(slot_type));
}
}
#[test]
fn test_comparison_from_operator() {
// If a comparison is added then please add it to this test too
assert_eq!(COMPARISON_NAMES.len(), 7);
assert_eq!(
Comparison::from_operator("!="),
Some(Comparison::NotEqual),
);
assert_eq!(
Comparison::from_operator("<"),
Some(Comparison::Less),
);
assert_eq!(
Comparison::from_operator("<="),
Some(Comparison::LessEqual),
);
assert_eq!(
Comparison::from_operator("=="),
Some(Comparison::Equal),
);
assert_eq!(
Comparison::from_operator(">"),
Some(Comparison::Greater),
);
assert_eq!(
Comparison::from_operator(">="),
Some(Comparison::GreaterEqual),
);
assert_eq!(
Comparison::from_operator("~="),
Some(Comparison::FuzzyEqual),
);
assert_eq!(Comparison::from_operator("<=>"), None);
}
#[track_caller]
fn check_base_type_format(
base_type: BaseType,
bytes: &[u8],
expected: &str,
) {
assert_eq!(
&BaseTypeInSlice::new(base_type, bytes).to_string(),
expected
);
}
#[test]
fn base_type_format() {
check_base_type_format(BaseType::Int, &1234i32.to_ne_bytes(), "1234");
check_base_type_format(
BaseType::UInt,
&u32::MAX.to_ne_bytes(),
"4294967295",
);
check_base_type_format(BaseType::Int8, &(-4i8).to_ne_bytes(), "-4");
check_base_type_format(BaseType::UInt8, &129u8.to_ne_bytes(), "129");
check_base_type_format(BaseType::Int16, &(-4i16).to_ne_bytes(), "-4");
check_base_type_format(
BaseType::UInt16,
&32769u16.to_ne_bytes(),
"32769"
);
check_base_type_format(BaseType::Int64, &(-4i64).to_ne_bytes(), "-4");
check_base_type_format(
BaseType::UInt64,
&u64::MAX.to_ne_bytes(),
"18446744073709551615"
);
check_base_type_format(
BaseType::Float16,
&0xc000u16.to_ne_bytes(),
"-2"
);
check_base_type_format(
BaseType::Float,
&2.0f32.to_ne_bytes(),
"2"
);
check_base_type_format(
BaseType::Double,
&2.0f64.to_ne_bytes(),
"2"
);
}
}
```
--------------------------------------------------------------------------------
/src/main.rs:
--------------------------------------------------------------------------------
```rust
use anyhow::Result;
use clap::Parser;
use image::codecs::pnm::PnmDecoder;
use image::{DynamicImage, ImageError, RgbImage};
use rmcp::{
Error as McpError, RoleServer, ServerHandler, ServiceExt, const_string, model::*, schemars,
service::RequestContext, tool, transport::stdio,
};
use serde_json::json;
use shaderc::{self, CompileOptions, Compiler, OptimizationLevel, ShaderKind};
use std::fs::File;
use std::io::BufReader;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use tokio::sync::Mutex;
use tracing_subscriber::{self, EnvFilter};
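/// Reads a PNM/PPM image from `path` and decodes it into an 8-bit RGB
/// image (used below to convert vkrunner's PPM output before saving).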
pub fn read_and_decode_ppm_file<P: AsRef<Path>>(path: P) -> Result<RgbImage, ImageError> {
let file = File::open(path)?;
let reader = BufReader::new(file);
let decoder: PnmDecoder<BufReader<File>> = PnmDecoder::new(reader)?;
let dynamic_image = DynamicImage::from_decoder(decoder)?;
let rgb_image = dynamic_image.into_rgb8();
Ok(rgb_image)
}
#[derive(Debug, serde::Serialize, serde::Deserialize, schemars::JsonSchema)]
pub enum ShaderStage {
#[schemars(description = "Vertex processing stage (transforms vertices)")]
Vert,
#[schemars(description = "Fragment processing stage (determines pixel colors)")]
Frag,
#[schemars(description = "Tessellation control stage for patch control points")]
Tesc,
#[schemars(description = "Tessellation evaluation stage for computing tessellated geometry")]
Tese,
#[schemars(description = "Geometry stage for generating/modifying primitives")]
Geom,
#[schemars(description = "Compute stage for general-purpose computation")]
Comp,
}
#[derive(Debug, serde::Serialize, serde::Deserialize, schemars::JsonSchema)]
pub enum ShaderRunnerRequire {
#[schemars(
description = "Enables cooperative matrix operations with specified dimensions and data type"
)]
CooperativeMatrix {
m: u32,
n: u32,
component_type: String,
},
#[schemars(
description = "Specifies depth/stencil format for depth testing and stencil operations"
)]
DepthStencil(String),
#[schemars(description = "Specifies framebuffer format for render target output")]
Framebuffer(String),
#[schemars(description = "Enables double-precision floating point operations in shaders")]
ShaderFloat64,
#[schemars(description = "Enables geometry shader stage for primitive manipulation")]
GeometryShader,
#[schemars(description = "Enables lines with width greater than 1.0 pixel")]
WideLines,
#[schemars(description = "Enables bitwise logical operations on framebuffer contents")]
LogicOp,
#[schemars(description = "Specifies required subgroup size for shader execution")]
SubgroupSize(u32),
#[schemars(description = "Enables memory stores and atomic operations in fragment shaders")]
FragmentStoresAndAtomics,
#[schemars(description = "Enables shaders to use raw buffer addresses")]
BufferDeviceAddress,
}
#[derive(Debug, serde::Serialize, serde::Deserialize, schemars::JsonSchema)]
pub enum ShaderRunnerPass {
#[schemars(
description = "Use a built-in pass-through vertex shader (forwards attributes to fragment shader)"
)]
VertPassthrough,
#[schemars(description = "Use compiled vertex shader from specified SPIR-V assembly file")]
VertSpirv {
#[schemars(
description = "Path to the compiled vertex shader SPIR-V assembly (.spvasm) file"
)]
vert_spvasm_path: String,
},
#[schemars(description = "Use compiled fragment shader from specified SPIR-V assembly file")]
FragSpirv {
#[schemars(
description = "Path to the compiled fragment shader SPIR-V assembly (.spvasm) file"
)]
frag_spvasm_path: String,
},
#[schemars(description = "Use compiled compute shader from specified SPIR-V assembly file")]
CompSpirv {
#[schemars(
description = "Path to the compiled compute shader SPIR-V assembly (.spvasm) file"
)]
comp_spvasm_path: String,
},
#[schemars(description = "Use compiled geometry shader from specified SPIR-V assembly file")]
GeomSpirv {
#[schemars(
description = "Path to the compiled geometry shader SPIR-V assembly (.spvasm) file"
)]
geom_spvasm_path: String,
},
#[schemars(
description = "Use compiled tessellation control shader from specified SPIR-V assembly file"
)]
TescSpirv {
#[schemars(
description = "Path to the compiled tessellation control shader SPIR-V assembly (.spvasm) file"
)]
tesc_spvasm_path: String,
},
#[schemars(
description = "Use compiled tessellation evaluation shader from specified SPIR-V assembly file"
)]
TeseSpirv {
#[schemars(
description = "Path to the compiled tessellation evaluation shader SPIR-V assembly (.spvasm) file"
)]
tese_spvasm_path: String,
},
}
#[derive(Debug, serde::Serialize, serde::Deserialize, schemars::JsonSchema)]
pub enum ShaderRunnerVertexData {
#[schemars(description = "Defines an attribute format at a shader location/binding")]
AttributeFormat {
#[schemars(description = "Location/binding number in the shader")]
location: u32,
#[schemars(description = "Format name (e.g., R32G32_SFLOAT, A8B8G8R8_UNORM_PACK32)")]
format: String,
},
#[schemars(description = "2D position data (for R32G32_SFLOAT format)")]
Vec2 {
#[schemars(description = "X coordinate value")]
x: f32,
#[schemars(description = "Y coordinate value")]
y: f32,
},
#[schemars(description = "3D position data (for R32G32B32_SFLOAT format)")]
Vec3 {
#[schemars(description = "X coordinate value")]
x: f32,
#[schemars(description = "Y coordinate value")]
y: f32,
#[schemars(description = "Z coordinate value")]
z: f32,
},
#[schemars(description = "4D position/color data (for R32G32B32A32_SFLOAT format)")]
Vec4 {
#[schemars(description = "X coordinate or Red component")]
x: f32,
#[schemars(description = "Y coordinate or Green component")]
y: f32,
#[schemars(description = "Z coordinate or Blue component")]
z: f32,
#[schemars(description = "W coordinate or Alpha component")]
w: f32,
},
#[schemars(description = "RGB color data (for R8G8B8_UNORM format)")]
RGB {
#[schemars(description = "Red component (0-255)")]
r: u8,
#[schemars(description = "Green component (0-255)")]
g: u8,
#[schemars(description = "Blue component (0-255)")]
b: u8,
},
#[schemars(description = "ARGB color as hex (for A8B8G8R8_UNORM_PACK32 format)")]
Hex {
#[schemars(description = "Color as hex string (0xAARRGGBB format)")]
value: String,
},
#[schemars(description = "Generic data components for custom formats")]
GenericComponents {
#[schemars(description = "Component values as strings, interpreted by format")]
components: Vec<String>,
},
}
#[derive(Debug, serde::Serialize, serde::Deserialize, schemars::JsonSchema)]
pub enum ShaderRunnerTest {
#[schemars(description = "Set fragment shader entrypoint function name")]
FragmentEntrypoint {
#[schemars(description = "Function name to use as entrypoint")]
name: String,
},
#[schemars(description = "Set vertex shader entrypoint function name")]
VertexEntrypoint {
#[schemars(description = "Function name to use as entrypoint")]
name: String,
},
#[schemars(description = "Set compute shader entrypoint function name")]
ComputeEntrypoint {
#[schemars(description = "Function name to use as entrypoint")]
name: String,
},
#[schemars(description = "Set geometry shader entrypoint function name")]
GeometryEntrypoint {
#[schemars(description = "Function name to use as entrypoint")]
name: String,
},
#[schemars(description = "Draw a rectangle using normalized device coordinates")]
DrawRect {
#[schemars(description = "X coordinate of lower-left corner")]
x: f32,
#[schemars(description = "Y coordinate of lower-left corner")]
y: f32,
#[schemars(description = "Width of the rectangle")]
width: f32,
#[schemars(description = "Height of the rectangle")]
height: f32,
},
#[schemars(description = "Draw primitives using vertex data")]
DrawArrays {
#[schemars(description = "Primitive type (TRIANGLE_LIST, POINT_LIST, etc.)")]
primitive_type: String,
#[schemars(description = "Index of first vertex")]
first: u32,
#[schemars(description = "Number of vertices to draw")]
count: u32,
},
#[schemars(description = "Draw primitives using indexed vertex data")]
DrawArraysIndexed {
#[schemars(description = "Primitive type (TRIANGLE_LIST, etc.)")]
primitive_type: String,
#[schemars(description = "Index of first index")]
first: u32,
#[schemars(description = "Number of indices to use")]
count: u32,
},
#[schemars(description = "Create or initialize a Shader Storage Buffer Object (SSBO)")]
SSBO {
#[schemars(description = "Binding point in the shader")]
binding: u32,
#[schemars(description = "Size in bytes (if data not provided)")]
size: Option<u32>,
#[schemars(description = "Initial buffer contents")]
data: Option<Vec<u8>>,
#[schemars(description = "Descriptor set number (default: 0)")]
descriptor_set: Option<u32>,
},
#[schemars(description = "Update a portion of an SSBO with new data")]
SSBOSubData {
#[schemars(description = "Binding point in the shader")]
binding: u32,
#[schemars(description = "Data type (float, vec4, etc.)")]
data_type: String,
#[schemars(description = "Byte offset into the buffer")]
offset: u32,
#[schemars(description = "Values to write (as strings)")]
values: Vec<String>,
#[schemars(description = "Descriptor set number (default: 0)")]
descriptor_set: Option<u32>,
},
#[schemars(description = "Create or initialize a Uniform Buffer Object (UBO)")]
UBO {
#[schemars(description = "Binding point in the shader")]
binding: u32,
#[schemars(description = "Buffer contents")]
data: Vec<u8>,
#[schemars(description = "Descriptor set number (default: 0)")]
descriptor_set: Option<u32>,
},
#[schemars(description = "Update a portion of a UBO with new data")]
UBOSubData {
#[schemars(description = "Binding point in the shader")]
binding: u32,
#[schemars(description = "Data type (float, vec4, etc.)")]
data_type: String,
#[schemars(description = "Byte offset into the buffer")]
offset: u32,
#[schemars(description = "Values to write (as strings)")]
values: Vec<String>,
#[schemars(description = "Descriptor set number (default: 0)")]
descriptor_set: Option<u32>,
},
#[schemars(description = "Set memory layout for buffer data")]
BufferLayout {
#[schemars(description = "Buffer type (ubo, ssbo)")]
buffer_type: String,
#[schemars(description = "Layout specification (std140, std430, row_major, column_major)")]
layout_type: String,
},
#[schemars(description = "Set push constant values")]
Push {
#[schemars(description = "Data type (float, vec4, etc.)")]
data_type: String,
#[schemars(description = "Byte offset into push constant block")]
offset: u32,
#[schemars(description = "Values to write (as strings)")]
values: Vec<String>,
},
#[schemars(description = "Set memory layout for push constants")]
PushLayout {
#[schemars(description = "Layout specification (std140, std430)")]
layout_type: String,
},
#[schemars(description = "Execute compute shader with specified workgroup counts")]
Compute {
#[schemars(description = "Number of workgroups in X dimension")]
x: u32,
#[schemars(description = "Number of workgroups in Y dimension")]
y: u32,
#[schemars(description = "Number of workgroups in Z dimension")]
z: u32,
},
#[schemars(description = "Verify framebuffer or buffer contents match expected values")]
Probe {
#[schemars(description = "Probe type (all, rect, ssbo, etc.)")]
probe_type: String,
#[schemars(description = "Component format (rgba, rgb, etc.)")]
format: String,
#[schemars(description = "Parameters (coordinates, expected values)")]
args: Vec<String>,
},
#[schemars(description = "Verify contents using normalized (0-1) coordinates")]
RelativeProbe {
#[schemars(description = "Probe type (rect, etc.)")]
probe_type: String,
#[schemars(description = "Component format (rgba, rgb, etc.)")]
format: String,
#[schemars(description = "Parameters (coordinates, expected values)")]
args: Vec<String>,
},
#[schemars(description = "Set acceptable error margin for value comparisons")]
Tolerance {
#[schemars(description = "Error margins (absolute or percentage with % suffix)")]
values: Vec<f32>,
},
#[schemars(description = "Clear the framebuffer to default values")]
Clear,
#[schemars(description = "Enable/disable depth testing")]
DepthTestEnable {
#[schemars(description = "True to enable depth testing")]
enable: bool,
},
#[schemars(description = "Enable/disable writing to depth buffer")]
DepthWriteEnable {
#[schemars(description = "True to enable depth writes")]
enable: bool,
},
#[schemars(description = "Set depth comparison function")]
DepthCompareOp {
#[schemars(description = "Function name (VK_COMPARE_OP_LESS, etc.)")]
op: String,
},
#[schemars(description = "Enable/disable stencil testing")]
StencilTestEnable {
#[schemars(description = "True to enable stencil testing")]
enable: bool,
},
#[schemars(description = "Define which winding order is front-facing")]
FrontFace {
#[schemars(description = "Mode (VK_FRONT_FACE_CLOCKWISE, etc.)")]
mode: String,
},
#[schemars(description = "Configure stencil operation for a face")]
StencilOp {
#[schemars(description = "Face (front, back)")]
face: String,
#[schemars(description = "Operation name (passOp, failOp, etc.)")]
op_name: String,
#[schemars(description = "Value (VK_STENCIL_OP_REPLACE, etc.)")]
value: String,
},
#[schemars(description = "Set reference value for stencil comparisons")]
StencilReference {
#[schemars(description = "Face (front, back)")]
face: String,
#[schemars(description = "Reference value")]
value: u32,
},
#[schemars(description = "Set stencil comparison function")]
StencilCompareOp {
#[schemars(description = "Face (front, back)")]
face: String,
#[schemars(description = "Function (VK_COMPARE_OP_EQUAL, etc.)")]
op: String,
},
#[schemars(description = "Control which color channels can be written")]
ColorWriteMask {
#[schemars(
description = "Bit flags (VK_COLOR_COMPONENT_R_BIT | VK_COLOR_COMPONENT_G_BIT, etc.)"
)]
mask: String,
},
#[schemars(description = "Enable/disable logical operations on colors")]
LogicOpEnable {
#[schemars(description = "True to enable logic operations")]
enable: bool,
},
#[schemars(description = "Set type of logical operation on colors")]
LogicOp {
#[schemars(description = "Operation (VK_LOGIC_OP_XOR, etc.)")]
op: String,
},
#[schemars(description = "Set face culling mode")]
CullMode {
#[schemars(description = "Mode (VK_CULL_MODE_BACK_BIT, etc.)")]
mode: String,
},
#[schemars(description = "Set width for line primitives")]
LineWidth {
#[schemars(description = "Width in pixels")]
width: f32,
},
#[schemars(description = "Specify a feature required by the test")]
Require {
#[schemars(description = "Feature name (subgroup_size, depthstencil, etc.)")]
feature: String,
#[schemars(description = "Feature parameters")]
parameters: Vec<String>,
},
}
#[derive(Debug, serde::Serialize, serde::Deserialize, schemars::JsonSchema)]
pub struct CompileRequest {
#[schemars(description = "The shader stage to compile (vert, frag, comp, geom, tesc, tese)")]
pub stage: ShaderStage,
#[schemars(description = "GLSL shader source code to compile")]
pub source: String,
#[schemars(description = "Path where compiled SPIR-V assembly (.spvasm) will be saved")]
pub tmp_output_path: String,
}
#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
pub struct CompileShadersRequest {
#[schemars(description = "List of shader compile requests (produces SPIR-V assemblies)")]
pub requests: Vec<CompileRequest>,
}
#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
pub struct CompileRunShadersRequest {
#[schemars(
description = "List of shader compile requests - each produces a SPIR-V assembly file"
)]
pub requests: Vec<CompileRequest>,
#[schemars(description = "Optional hardware/feature requirements needed for shader execution")]
pub requirements: Option<Vec<ShaderRunnerRequire>>,
#[schemars(
description = "Shader pipeline configuration (references compiled SPIR-V files by path)"
)]
pub passes: Vec<ShaderRunnerPass>,
#[schemars(description = "Optional vertex data for rendering geometry")]
pub vertex_data: Option<Vec<ShaderRunnerVertexData>>,
#[schemars(description = "Test commands to execute (drawing, compute, verification, etc.)")]
pub tests: Vec<ShaderRunnerTest>,
#[schemars(description = "Optional path to save output image (PNG format)")]
pub output_path: Option<String>,
}
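// Illustrative request shape only (not taken from this repository): with
// serde's default externally tagged enum representation, which the derives
// above use, a minimal call to this tool might look roughly like:
//
// {
//   "requests": [
//     { "stage": "Frag", "source": "<GLSL source>", "tmp_output_path": "frag.spvasm" }
//   ],
//   "passes": [
//     "VertPassthrough",
//     { "FragSpirv": { "frag_spvasm_path": "frag.spvasm" } }
//   ],
//   "tests": [
//     { "DrawRect": { "x": -1.0, "y": -1.0, "width": 2.0, "height": 2.0 } }
//   ],
//   "output_path": "out.png"
// }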
#[derive(Clone)]
pub struct ShadercVkrunnerMcp {}
#[tool(tool_box)]
impl ShadercVkrunnerMcp {
#[must_use]
pub fn new() -> Self {
Self {}
}
#[tool(
description = "REQUIRES: (1) Shader compile requests that produce SPIR-V assembly files, (2) Passes referencing these files by path, and (3) Test commands for drawing/computation. Workflow: First compile GLSL to SPIR-V, then reference compiled files in passes, then execute drawing commands in tests to render/compute. Tests MUST include draw/compute commands to produce visible output. Optional: requirements for hardware features, vertex data for geometry, output path for saving image. Every shader MUST have its compiled output referenced in passes."
)]
fn compile_run_shaders(
&self,
#[tool(aggr)] request: CompileRunShadersRequest,
) -> Result<CallToolResult, McpError> {
use std::fs::File;
use std::io::{Read, Write};
use std::path::Path;
use std::process::{Command, Stdio}; // Still needed for vkrunner
fn io_err(e: std::io::Error) -> McpError {
McpError::internal_error("IO operation failed", Some(json!({"error": e.to_string()})))
}
for req in &request.requests {
let shader_kind = match req.stage {
ShaderStage::Vert => ShaderKind::Vertex,
ShaderStage::Frag => ShaderKind::Fragment,
ShaderStage::Tesc => ShaderKind::TessControl,
ShaderStage::Tese => ShaderKind::TessEvaluation,
ShaderStage::Geom => ShaderKind::Geometry,
ShaderStage::Comp => ShaderKind::Compute,
};
let stage_flag = match req.stage {
ShaderStage::Vert => "vert",
ShaderStage::Frag => "frag",
ShaderStage::Tesc => "tesc",
ShaderStage::Tese => "tese",
ShaderStage::Geom => "geom",
ShaderStage::Comp => "comp",
};
let tmp_output_path = if req.tmp_output_path.starts_with("/tmp") {
req.tmp_output_path.clone()
} else {
format!("/tmp/{}", req.tmp_output_path)
};
if let Some(parent) = Path::new(&tmp_output_path).parent() {
std::fs::create_dir_all(parent).map_err(|e| {
McpError::internal_error(
"Failed to create temporary output directory",
Some(json!({"error": e.to_string()})),
)
})?;
}
// Create compiler and options
let mut compiler = Compiler::new().ok().ok_or_else(|| {
McpError::internal_error(
"Failed to create shaderc compiler",
Some(json!({"error": "Could not instantiate shaderc compiler"})),
)
})?;
let mut options = CompileOptions::new().ok().ok_or_else(|| {
McpError::internal_error(
"Failed to create shaderc compile options",
Some(json!({"error": "Could not create compiler options"})),
)
})?;
// Set options equivalent to the CLI flags
options.set_target_env(
shaderc::TargetEnv::Vulkan,
shaderc::EnvVersion::Vulkan1_4 as u32,
);
options.set_optimization_level(OptimizationLevel::Performance);
options.set_generate_debug_info();
// Compile to SPIR-V assembly
let artifact = match compiler.compile_into_spirv_assembly(
&req.source,
shader_kind,
"shader.glsl", // source name for error reporting
"main", // entry point
Some(&options),
) {
Ok(artifact) => artifact,
Err(e) => {
let stage_name = match req.stage {
ShaderStage::Vert => "Vertex",
ShaderStage::Frag => "Fragment",
ShaderStage::Tesc => "Tessellation Control",
ShaderStage::Tese => "Tessellation Evaluation",
ShaderStage::Geom => "Geometry",
ShaderStage::Comp => "Compute",
};
let error_details = format!("{}", e);
return Ok(CallToolResult::success(vec![Content::text(format!(
"Shader compilation failed for {} shader:\n\nError:\n{}\n\nShader Source:\n{}\n",
stage_name, error_details, req.source
))]));
}
};
// Write the compiled SPIR-V assembly to the output file, filtering unsupported lines
let spv_text = artifact.as_text();
let filtered_spv = spv_text
.lines()
.filter(|l| !l.trim_start().starts_with("OpModuleProcessed"))
.collect::<Vec<_>>()
.join("\n");
std::fs::write(&tmp_output_path, filtered_spv).map_err(|e| {
McpError::internal_error(
"Failed to write compiled shader to file",
Some(json!({"error": e.to_string()})),
)
})?;
}
let shader_test_path = "/tmp/vkrunner_test.shader_test";
let mut shader_test_file = File::create(shader_test_path).map_err(io_err)?;
if let Some(requirements) = &request.requirements {
if !requirements.is_empty() {
writeln!(shader_test_file, "[require]").map_err(io_err)?;
for req in requirements {
match req {
ShaderRunnerRequire::CooperativeMatrix {
m,
n,
component_type,
} => {
writeln!(
shader_test_file,
"cooperative_matrix m={m} n={n} c={component_type}"
)
.map_err(io_err)?;
}
ShaderRunnerRequire::DepthStencil(format) => {
writeln!(shader_test_file, "depthstencil {format}").map_err(io_err)?;
}
ShaderRunnerRequire::Framebuffer(format) => {
writeln!(shader_test_file, "framebuffer {format}").map_err(io_err)?;
}
ShaderRunnerRequire::ShaderFloat64 => {
writeln!(shader_test_file, "shaderFloat64").map_err(io_err)?;
}
ShaderRunnerRequire::GeometryShader => {
writeln!(shader_test_file, "geometryShader").map_err(io_err)?;
}
ShaderRunnerRequire::WideLines => {
writeln!(shader_test_file, "wideLines").map_err(io_err)?;
}
ShaderRunnerRequire::LogicOp => {
writeln!(shader_test_file, "logicOp").map_err(io_err)?;
}
ShaderRunnerRequire::SubgroupSize(size) => {
writeln!(shader_test_file, "subgroup_size {size}").map_err(io_err)?;
}
ShaderRunnerRequire::FragmentStoresAndAtomics => {
writeln!(shader_test_file, "fragmentStoresAndAtomics")
.map_err(io_err)?;
}
ShaderRunnerRequire::BufferDeviceAddress => {
writeln!(shader_test_file, "bufferDeviceAddress").map_err(io_err)?;
}
}
}
writeln!(shader_test_file).map_err(io_err)?;
}
}
for pass in &request.passes {
match pass {
ShaderRunnerPass::VertPassthrough => {
writeln!(shader_test_file, "[vertex shader passthrough]").map_err(io_err)?;
}
ShaderRunnerPass::VertSpirv { vert_spvasm_path } => {
writeln!(shader_test_file, "[vertex shader spirv]").map_err(io_err)?;
let mut spvasm = String::new();
let path = if vert_spvasm_path.starts_with("/tmp") {
vert_spvasm_path.clone()
} else {
format!("/tmp/{vert_spvasm_path}")
};
File::open(&path)
.map_err(|e| {
McpError::internal_error(
format!("Failed to open vertex shader SPIR-V file at {path}"),
Some(json!({"error": e.to_string()})),
)
})?
.read_to_string(&mut spvasm)
.map_err(|e| {
McpError::internal_error(
"Failed to read vertex shader SPIR-V file",
Some(json!({"error": e.to_string()})),
)
})?;
writeln!(shader_test_file, "{spvasm}").map_err(io_err)?;
}
ShaderRunnerPass::FragSpirv { frag_spvasm_path } => {
writeln!(shader_test_file, "[fragment shader spirv]").map_err(io_err)?;
let mut spvasm = String::new();
let path = if frag_spvasm_path.starts_with("/tmp") {
frag_spvasm_path.clone()
} else {
format!("/tmp/{frag_spvasm_path}")
};
File::open(&path)
.map_err(|e| {
McpError::internal_error(
format!("Failed to open fragment shader SPIR-V file at {path}"),
Some(json!({"error": e.to_string()})),
)
})?
.read_to_string(&mut spvasm)
.map_err(|e| {
McpError::internal_error(
"Failed to read fragment shader SPIR-V file",
Some(json!({"error": e.to_string()})),
)
})?;
writeln!(shader_test_file, "{spvasm}").map_err(io_err)?;
}
ShaderRunnerPass::CompSpirv { comp_spvasm_path } => {
writeln!(shader_test_file, "[compute shader spirv]").map_err(io_err)?;
let mut spvasm = String::new();
let path = if comp_spvasm_path.starts_with("/tmp") {
comp_spvasm_path.clone()
} else {
format!("/tmp/{comp_spvasm_path}")
};
File::open(&path)
.map_err(|e| {
McpError::internal_error(
format!("Failed to open compute shader SPIR-V file at {path}"),
Some(json!({"error": e.to_string()})),
)
})?
.read_to_string(&mut spvasm)
.map_err(|e| {
McpError::internal_error(
"Failed to read compute shader SPIR-V file",
Some(json!({"error": e.to_string()})),
)
})?;
writeln!(shader_test_file, "{spvasm}").map_err(io_err)?;
}
ShaderRunnerPass::GeomSpirv { geom_spvasm_path } => {
writeln!(shader_test_file, "[geometry shader spirv]").map_err(io_err)?;
let mut spvasm = String::new();
let path = if geom_spvasm_path.starts_with("/tmp") {
geom_spvasm_path.clone()
} else {
format!("/tmp/{geom_spvasm_path}")
};
File::open(&path)
.map_err(|e| {
McpError::internal_error(
format!("Failed to open geometry shader SPIR-V file at {path}"),
Some(json!({"error": e.to_string()})),
)
})?
.read_to_string(&mut spvasm)
.map_err(|e| {
McpError::internal_error(
"Failed to read geometry shader SPIR-V file",
Some(json!({"error": e.to_string()})),
)
})?;
writeln!(shader_test_file, "{spvasm}").map_err(io_err)?;
}
ShaderRunnerPass::TescSpirv { tesc_spvasm_path } => {
writeln!(shader_test_file, "[tessellation control shader spirv]")
.map_err(io_err)?;
let mut spvasm = String::new();
let path = if tesc_spvasm_path.starts_with("/tmp") {
tesc_spvasm_path.clone()
} else {
format!("/tmp/{tesc_spvasm_path}")
};
File::open(&path)
.map_err(|e| {
McpError::internal_error(
format!(
"Failed to open tessellation control shader SPIR-V file at {path}"
),
Some(json!({"error": e.to_string()})),
)
})?
.read_to_string(&mut spvasm)
.map_err(|e| {
McpError::internal_error(
"Failed to read tessellation control shader SPIR-V file",
Some(json!({"error": e.to_string()})),
)
})?;
writeln!(shader_test_file, "{spvasm}").map_err(io_err)?;
}
ShaderRunnerPass::TeseSpirv { tese_spvasm_path } => {
writeln!(shader_test_file, "[tessellation evaluation shader spirv]")
.map_err(io_err)?;
let mut spvasm = String::new();
let path = if tese_spvasm_path.starts_with("/tmp") {
tese_spvasm_path.clone()
} else {
format!("/tmp/{tese_spvasm_path}")
};
File::open(&path)
.map_err(|e| {
McpError::internal_error(
format!(
"Failed to open tessellation evaluation shader SPIR-V file at {path}"
),
Some(json!({"error": e.to_string()})),
)
})?
.read_to_string(&mut spvasm)
.map_err(|e| {
McpError::internal_error(
"Failed to read tessellation evaluation shader SPIR-V file",
Some(json!({"error": e.to_string()})),
)
})?;
writeln!(shader_test_file, "{spvasm}").map_err(io_err)?;
}
}
writeln!(shader_test_file).map_err(io_err)?;
}
if let Some(vertex_data) = &request.vertex_data {
writeln!(shader_test_file, "[vertex data]").map_err(io_err)?;
for data in vertex_data {
if let ShaderRunnerVertexData::AttributeFormat { location, format } = data {
writeln!(shader_test_file, "{location}/{format}").map_err(io_err)?;
} else {
match data {
ShaderRunnerVertexData::Vec2 { x, y } => {
write!(shader_test_file, "{x} {y}").map_err(io_err)?;
}
ShaderRunnerVertexData::Vec3 { x, y, z } => {
write!(shader_test_file, "{x} {y} {z}").map_err(io_err)?;
}
ShaderRunnerVertexData::Vec4 { x, y, z, w } => {
write!(shader_test_file, "{x} {y} {z} {w}").map_err(io_err)?;
}
ShaderRunnerVertexData::RGB { r, g, b } => {
write!(shader_test_file, "{r} {g} {b}").map_err(io_err)?;
}
ShaderRunnerVertexData::Hex { value } => {
write!(shader_test_file, "{value}").map_err(io_err)?;
}
ShaderRunnerVertexData::GenericComponents { components } => {
for component in components {
write!(shader_test_file, "{component} ").map_err(io_err)?;
}
}
_ => {}
}
writeln!(shader_test_file).map_err(io_err)?;
}
}
writeln!(shader_test_file).map_err(io_err)?;
}
writeln!(shader_test_file, "[test]").map_err(io_err)?;
for test_cmd in &request.tests {
match test_cmd {
ShaderRunnerTest::FragmentEntrypoint { name } => {
writeln!(shader_test_file, "fragment entrypoint {name}").map_err(io_err)?;
}
ShaderRunnerTest::VertexEntrypoint { name } => {
writeln!(shader_test_file, "vertex entrypoint {name}").map_err(io_err)?;
}
ShaderRunnerTest::ComputeEntrypoint { name } => {
writeln!(shader_test_file, "compute entrypoint {name}").map_err(io_err)?;
}
ShaderRunnerTest::GeometryEntrypoint { name } => {
writeln!(shader_test_file, "geometry entrypoint {name}").map_err(io_err)?;
}
ShaderRunnerTest::DrawRect {
x,
y,
width,
height,
} => {
writeln!(shader_test_file, "draw rect {x} {y} {width} {height}")
.map_err(io_err)?;
}
ShaderRunnerTest::DrawArrays {
primitive_type,
first,
count,
} => {
writeln!(
shader_test_file,
"draw arrays {primitive_type} {first} {count}"
)
.map_err(io_err)?;
}
ShaderRunnerTest::DrawArraysIndexed {
primitive_type,
first,
count,
} => {
writeln!(
shader_test_file,
"draw arrays indexed {primitive_type} {first} {count}"
)
.map_err(io_err)?;
}
ShaderRunnerTest::SSBO {
binding,
size,
data,
descriptor_set,
} => {
let set_prefix = if let Some(set) = descriptor_set {
format!("{set}:")
} else {
String::new()
};
if let Some(size) = size {
writeln!(shader_test_file, "ssbo {set_prefix}{binding} {size}")
.map_err(io_err)?;
} else if let Some(_data) = data {
writeln!(shader_test_file, "ssbo {set_prefix}{binding} data")
.map_err(io_err)?;
}
}
ShaderRunnerTest::SSBOSubData {
binding,
data_type,
offset,
values,
descriptor_set,
} => {
let set_prefix = if let Some(set) = descriptor_set {
format!("{set}:")
} else {
String::new()
};
write!(
shader_test_file,
"ssbo {set_prefix}{binding} subdata {data_type} {offset}"
)
.map_err(io_err)?;
for value in values {
write!(shader_test_file, " {value}").map_err(io_err)?;
}
writeln!(shader_test_file).map_err(io_err)?;
}
ShaderRunnerTest::UBO {
binding,
data: _,
descriptor_set,
} => {
let set_prefix = if let Some(set) = descriptor_set {
format!("{set}:")
} else {
String::new()
};
writeln!(shader_test_file, "ubo {set_prefix}{binding} data").map_err(io_err)?;
}
ShaderRunnerTest::UBOSubData {
binding,
data_type,
offset,
values,
descriptor_set,
} => {
let set_prefix = if let Some(set) = descriptor_set {
format!("{set}:")
} else {
String::new()
};
write!(
shader_test_file,
"ubo {set_prefix}{binding} subdata {data_type} {offset}"
)
.map_err(io_err)?;
for value in values {
write!(shader_test_file, " {value}").map_err(io_err)?;
}
writeln!(shader_test_file).map_err(io_err)?;
}
ShaderRunnerTest::BufferLayout {
buffer_type,
layout_type,
} => {
writeln!(shader_test_file, "{buffer_type} layout {layout_type}")
.map_err(io_err)?;
}
ShaderRunnerTest::Push {
data_type,
offset,
values,
} => {
write!(shader_test_file, "push {data_type} {offset}").map_err(io_err)?;
for value in values {
write!(shader_test_file, " {value}").map_err(io_err)?;
}
writeln!(shader_test_file).map_err(io_err)?;
}
ShaderRunnerTest::PushLayout { layout_type } => {
writeln!(shader_test_file, "push layout {layout_type}").map_err(io_err)?;
}
ShaderRunnerTest::Compute { x, y, z } => {
writeln!(shader_test_file, "compute {x} {y} {z}").map_err(io_err)?;
}
ShaderRunnerTest::Probe {
probe_type,
format,
args,
} => {
write!(shader_test_file, "probe {probe_type} {format}").map_err(io_err)?;
for arg in args {
write!(shader_test_file, " {arg}").map_err(io_err)?;
}
writeln!(shader_test_file).map_err(io_err)?;
}
ShaderRunnerTest::RelativeProbe {
probe_type,
format,
args,
} => {
write!(shader_test_file, "relative probe {probe_type} {format}")
.map_err(io_err)?;
for arg in args {
write!(shader_test_file, " {arg}").map_err(io_err)?;
}
writeln!(shader_test_file).map_err(io_err)?;
}
ShaderRunnerTest::Tolerance { values } => {
write!(shader_test_file, "tolerance").map_err(io_err)?;
for value in values {
write!(shader_test_file, " {value}").map_err(io_err)?;
}
writeln!(shader_test_file).map_err(io_err)?;
}
ShaderRunnerTest::Clear => {
writeln!(shader_test_file, "clear").map_err(io_err)?;
}
ShaderRunnerTest::DepthTestEnable { enable } => {
writeln!(shader_test_file, "depthTestEnable {enable}").map_err(io_err)?;
}
ShaderRunnerTest::DepthWriteEnable { enable } => {
writeln!(shader_test_file, "depthWriteEnable {enable}").map_err(io_err)?;
}
ShaderRunnerTest::DepthCompareOp { op } => {
writeln!(shader_test_file, "depthCompareOp {op}").map_err(io_err)?;
}
ShaderRunnerTest::StencilTestEnable { enable } => {
writeln!(shader_test_file, "stencilTestEnable {enable}").map_err(io_err)?;
}
ShaderRunnerTest::FrontFace { mode } => {
writeln!(shader_test_file, "frontFace {mode}").map_err(io_err)?;
}
ShaderRunnerTest::StencilOp {
face,
op_name,
value,
} => {
writeln!(shader_test_file, "{face}.{op_name} {value}",).map_err(io_err)?;
}
ShaderRunnerTest::StencilReference { face, value } => {
writeln!(shader_test_file, "{face}.reference {value}",).map_err(io_err)?;
}
ShaderRunnerTest::StencilCompareOp { face, op } => {
writeln!(shader_test_file, "{face}.compareOp {op}",).map_err(io_err)?;
}
ShaderRunnerTest::ColorWriteMask { mask } => {
writeln!(shader_test_file, "colorWriteMask {mask}",).map_err(io_err)?;
}
ShaderRunnerTest::LogicOpEnable { enable } => {
writeln!(shader_test_file, "logicOpEnable {enable}",).map_err(io_err)?;
}
ShaderRunnerTest::LogicOp { op } => {
writeln!(shader_test_file, "logicOp {op}",).map_err(io_err)?;
}
ShaderRunnerTest::CullMode { mode } => {
writeln!(shader_test_file, "cullMode {mode}",).map_err(io_err)?;
}
ShaderRunnerTest::LineWidth { width } => {
writeln!(shader_test_file, "lineWidth {width}").map_err(io_err)?;
}
ShaderRunnerTest::Require {
feature,
parameters,
} => {
write!(shader_test_file, "require {feature}").map_err(io_err)?;
for param in parameters {
write!(shader_test_file, " {param}").map_err(io_err)?;
}
writeln!(shader_test_file).map_err(io_err)?;
}
}
}
shader_test_file.flush().map_err(io_err)?;
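// At this point the generated file follows vkrunner's shader_test layout:
// an optional [require] block, one section per pass (for example
// "[vertex shader passthrough]" or "[fragment shader spirv]" followed by
// the SPIR-V assembly), an optional [vertex data] block, and the [test]
// commands written above.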
// vkrunner always writes its PPM output to this fixed temporary path;
// if an output path was requested, the image is converted and saved
// there after vkrunner has run (see below).
let tmp_image_path = "/tmp/vkrunner_output.ppm";
let mut vkrunner_args = vec![shader_test_path];
if request.output_path.is_some() {
vkrunner_args.push("--image");
vkrunner_args.push(tmp_image_path);
}
let vkrunner_output = Command::new("vkrunner")
.args(&vkrunner_args)
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.output()
.map_err(|e| {
McpError::internal_error(
"Failed to run vkrunner",
Some(json!({"error": e.to_string()})),
)
})?;
let stdout = String::from_utf8_lossy(&vkrunner_output.stdout).to_string();
let stderr = String::from_utf8_lossy(&vkrunner_output.stderr).to_string();
let mut result_message = if vkrunner_output.status.success() {
format!(
"Shader compilation successful using shaderc-rs.\nVkRunner execution successful.\n\nOutput:\n{stdout}\n\n"
)
} else {
format!(
"Shader compilation successful using shaderc-rs.\nVkRunner execution failed.\n\nOutput:\n{stdout}\n\nError:\n{stderr}\n\n",
)
};
if let Some(output_path) = &request.output_path {
if vkrunner_output.status.success() && Path::new(tmp_image_path).exists() {
match read_and_decode_ppm_file(tmp_image_path) {
Ok(img) => {
if let Some(parent) = Path::new(output_path).parent() {
if !parent.as_os_str().is_empty() {
std::fs::create_dir_all(parent).map_err(|e| {
McpError::internal_error(
"Failed to create output directory",
Some(json!({"error": e.to_string()})),
)
})?;
}
}
img.save(output_path).map_err(|e| {
McpError::internal_error(
"Failed to save output image",
Some(json!({"error": e.to_string()})),
)
})?;
result_message.push_str(&format!("Image saved to: {output_path}\n"));
}
Err(e) => {
result_message.push_str(&format!("Failed to convert output image: {e}\n"));
}
}
} else if vkrunner_output.status.success() {
result_message.push_str("No output image was generated by VkRunner.\n");
}
}
result_message.push_str("\nShader Test File Contents:\n");
result_message.push_str(
&std::fs::read_to_string(shader_test_path)
.unwrap_or_else(|_| "Failed to read shader test file".to_string()),
);
Ok(CallToolResult::success(vec![Content::text(result_message)]))
}
}
impl Default for ShadercVkrunnerMcp {
fn default() -> Self {
Self::new()
}
}
const_string!(Echo = "echo");
#[tool(tool_box)]
impl ServerHandler for ShadercVkrunnerMcp {
fn get_info(&self) -> ServerInfo {
ServerInfo {
protocol_version: ProtocolVersion::V_2024_11_05,
capabilities: ServerCapabilities::builder()
.enable_tools()
.build(),
server_info: Implementation::from_build_env(),
instructions: Some("This server provides tools for compiling and running GLSL shaders using Vulkan infrastructure. The typical workflow is:
1. Compile GLSL source code to SPIR-V assembly using the 'compile_run_shaders' tool
2. Specify hardware/feature requirements if needed (e.g., geometry shaders, floating-point formats)
3. Reference compiled SPIR-V files in shader passes
4. Define vertex data if rendering geometry
5. Set up test commands to draw or compute
6. Optionally save the rendered output as an image".to_string()),
}
}
}
#[derive(Parser, Debug)]
#[command(about)]
struct Args {
#[clap(short, long, value_parser)]
work_dir: Option<PathBuf>,
}
#[tokio::main]
async fn main() -> Result<()> {
let args = Args::parse();
if let Some(work_dir) = args.work_dir {
std::env::set_current_dir(&work_dir)
.map_err(|e| {
eprintln!("Failed to set working directory to {work_dir:?}: {e}");
std::process::exit(1);
})
.unwrap();
}
tracing_subscriber::fmt()
.with_env_filter(EnvFilter::from_default_env().add_directive(tracing::Level::DEBUG.into()))
.with_writer(std::io::stderr)
.with_ansi(false)
.init();
tracing::info!("Starting MCP server");
let service = ShadercVkrunnerMcp::new()
.serve(stdio())
.await
.inspect_err(|e| {
tracing::error!("serving error: {:?}", e);
})?;
service.waiting().await?;
Ok(())
}
```
--------------------------------------------------------------------------------
/vkrunner/vkrunner/requirements.rs:
--------------------------------------------------------------------------------
```rust
// vkrunner
//
// Copyright (C) 2019 Intel Corporation
// Copyright 2023 Neil Roberts
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice (including the next
// paragraph) shall be included in all copies or substantial portions of the
// Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
// THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
use crate::vk;
use crate::vulkan_funcs;
use crate::vulkan_funcs::{NEXT_PTR_OFFSET, FIRST_FEATURE_OFFSET};
use crate::result;
use std::mem;
use std::collections::{HashMap, HashSet};
use std::ffi::CStr;
use std::convert::TryInto;
use std::fmt;
use std::cell::UnsafeCell;
use std::ptr;
#[derive(Debug)]
struct Extension {
// The name of the extension that provides this set of features
// stored as a null-terminated byte sequence. It is stored this
// way because that is how bindgen generates string literals from
// headers.
name_bytes: &'static [u8],
// The size of the corresponding features struct
struct_size: usize,
// The enum for this struct
struct_type: vk::VkStructureType,
// List of feature names in this extension in the order they
// appear in the features struct. The names are as written in the
// struct definition.
features: &'static [&'static str],
}
impl Extension {
fn name(&self) -> &str {
// The name comes from static data so it should always be
// valid
CStr::from_bytes_with_nul(self.name_bytes)
.unwrap()
.to_str()
.unwrap()
}
}
include!("features.rs");
// Lazily generated extensions data used for passing to Vulkan
struct LazyExtensions {
// All of the null-terminated strings concatenated into a buffer
strings: Vec<u8>,
// Pointers to the start of each string
pointers: Vec<* const u8>,
}
// Lazily generated structures data used for passing to Vulkan
struct LazyStructures {
// A buffer used to return a linked chain of structs that can be
// accessed from C.
list: Vec<u8>,
}
// Number of features in VkPhysicalDeviceFeatures. The struct is just
// a series of VkBool32 so we should be able to safely calculate the
// number based on the struct size.
const N_PHYSICAL_DEVICE_FEATURES: usize =
mem::size_of::<vk::VkPhysicalDeviceFeatures>()
/ mem::size_of::<vk::VkBool32>();
// Lazily generated VkPhysicalDeviceFeatures struct to pass to Vulkan
struct LazyBaseFeatures {
    // Array big enough to store a VkPhysicalDeviceFeatures struct.
// It’s easier to manipulate as an array instead of the actual
// struct because we want to update the bools by index.
bools: [vk::VkBool32; N_PHYSICAL_DEVICE_FEATURES],
}
#[derive(Debug, Default, Clone)]
pub struct CooperativeMatrix {
pub m_size: Option<u32>,
pub n_size: Option<u32>,
pub k_size: Option<u32>,
pub a_type: Option<vk::VkComponentTypeKHR>,
pub b_type: Option<vk::VkComponentTypeKHR>,
pub c_type: Option<vk::VkComponentTypeKHR>,
pub result_type: Option<vk::VkComponentTypeKHR>,
pub saturating_accumulation: Option<vk::VkBool32>,
pub scope: Option<vk::VkScopeKHR>,
pub line: String,
}
#[derive(Debug)]
pub struct Requirements {
// Minimum vulkan version
version: u32,
// Set of extension names required
extensions: HashSet<String>,
// A map indexed by extension number. The value is an array of
// bools representing whether we need each feature in the
// extension’s feature struct.
features: HashMap<usize, Box<[bool]>>,
// An array of bools corresponding to each feature in the base
// VkPhysicalDeviceFeatures struct
base_features: [bool; N_BASE_FEATURES],
// Required subgroup size to be used with
// VkPipelineShaderStageRequiredSubgroupSizeCreateInfo.
pub required_subgroup_size: Option<u32>,
cooperative_matrix_reqs: Vec<CooperativeMatrix>,
// The rest of the struct is lazily created from the above data
// and shouldn’t be part of the PartialEq implementation.
lazy_extensions: UnsafeCell<Option<LazyExtensions>>,
lazy_structures: UnsafeCell<Option<LazyStructures>>,
lazy_base_features: UnsafeCell<Option<LazyBaseFeatures>>,
}
/// Error returned by [Requirements::check]
#[derive(Debug)]
pub enum Error {
EnumerateDeviceExtensionPropertiesFailed,
ExtensionMissingNullTerminator,
ExtensionInvalidUtf8,
/// A required base feature from VkPhysicalDeviceFeatures is missing.
MissingBaseFeature(usize),
/// A required extension is missing. The string is the name of the
/// extension.
MissingExtension(String),
/// A required feature is missing.
MissingFeature { extension: usize, feature: usize },
/// The API version reported by the driver is too low
VersionTooLow { required_version: u32, actual_version: u32 },
/// Required subgroup size out of supported range.
RequiredSubgroupSizeInvalid { size: u32, min: u32, max: u32 },
MissingCooperativeMatrixProperties(String),
}
impl Error {
pub fn result(&self) -> result::Result {
match self {
Error::EnumerateDeviceExtensionPropertiesFailed => {
result::Result::Fail
},
Error::ExtensionMissingNullTerminator => result::Result::Fail,
Error::ExtensionInvalidUtf8 => result::Result::Fail,
Error::MissingBaseFeature(_) => result::Result::Skip,
Error::MissingExtension(_) => result::Result::Skip,
Error::MissingFeature { .. } => result::Result::Skip,
Error::VersionTooLow { .. } => result::Result::Skip,
Error::RequiredSubgroupSizeInvalid { .. } => result::Result::Skip,
Error::MissingCooperativeMatrixProperties(_) => result::Result::Skip,
}
}
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Error::EnumerateDeviceExtensionPropertiesFailed => {
write!(f, "vkEnumerateDeviceExtensionProperties failed")
},
Error::ExtensionMissingNullTerminator => {
write!(
f,
"NULL terminator missing in string returned from \
vkEnumerateDeviceExtensionProperties"
)
},
Error::ExtensionInvalidUtf8 => {
write!(
f,
"Invalid UTF-8 in string returned from \
vkEnumerateDeviceExtensionProperties"
)
},
&Error::MissingBaseFeature(feature_num) => {
write!(
f,
"Missing required feature: {}",
BASE_FEATURES[feature_num],
)
},
Error::MissingExtension(s) => {
write!(f, "Missing required extension: {}", s)
},
&Error::MissingFeature { extension, feature } => {
write!(
f,
"Missing required feature “{}” from extension “{}”",
EXTENSIONS[extension].features[feature],
EXTENSIONS[extension].name(),
)
},
&Error::VersionTooLow { required_version, actual_version } => {
let (req_major, req_minor, req_patch) =
extract_version(required_version);
let (actual_major, actual_minor, actual_patch) =
extract_version(actual_version);
write!(
f,
"Vulkan API version {}.{}.{} required but the driver \
reported {}.{}.{}",
req_major, req_minor, req_patch,
actual_major, actual_minor, actual_patch,
)
},
&Error::RequiredSubgroupSizeInvalid { size, min, max } => {
write!(
f,
"Required subgroup size {} is not in the supported range {}-{}",
size, min, max
)
},
Error::MissingCooperativeMatrixProperties(s) => {
write!(f, "Physical device can't fulfill cooperative matrix requirements: {}", s.trim())
},
}
}
}
/// Convert a decomposed Vulkan version into an integer. This is the
/// same as the `VK_MAKE_VERSION` macro in the Vulkan headers.
pub const fn make_version(major: u32, minor: u32, patch: u32) -> u32 {
(major << 22) | (minor << 12) | patch
}
/// Decompose a Vulkan version into its component major, minor and
/// patch parts. This is the equivalent of the `VK_VERSION_MAJOR`,
/// `VK_VERSION_MINOR` and `VK_VERSION_PATCH` macros in the Vulkan
/// header.
pub const fn extract_version(version: u32) -> (u32, u32, u32) {
(version >> 22, (version >> 12) & 0x3ff, version & 0xfff)
}
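// Worked example (illustrative): make_version(1, 2, 0)
// == (1 << 22) | (2 << 12) | 0 == 0x0040_2000, and
// extract_version(0x0040_2000) == (1, 2, 0).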
impl Requirements {
pub fn new() -> Requirements {
Requirements {
version: make_version(1, 0, 0),
extensions: HashSet::new(),
features: HashMap::new(),
base_features: [false; N_BASE_FEATURES],
required_subgroup_size: None,
lazy_extensions: UnsafeCell::new(None),
lazy_structures: UnsafeCell::new(None),
lazy_base_features: UnsafeCell::new(None),
cooperative_matrix_reqs: Vec::new(),
}
}
/// Get the required Vulkan version that was previously set with
/// [add_version](Requirements::add_version).
pub fn version(&self) -> u32 {
self.version
}
/// Set the minimum required Vulkan version.
pub fn add_version(&mut self, major: u32, minor: u32, patch: u32) {
self.version = std::cmp::max(self.version, make_version(major, minor, patch));
}
pub fn add_cooperative_matrix_req(&mut self, req: CooperativeMatrix) {
self.cooperative_matrix_reqs.push(req);
}
fn update_lazy_extensions(&self) -> &[* const u8] {
// SAFETY: The only case where lazy_extensions will be
// modified is if it is None and the only way for that to
// happen is if something modifies the requirements through a
// mutable reference. If that is the case then at that point
// there were no other references to the requirements so
// nothing can be holding a reference to the lazy data.
let extensions = unsafe {
match &mut *self.lazy_extensions.get() {
Some(data) => return data.pointers.as_ref(),
option @ None => {
option.insert(LazyExtensions {
strings: Vec::new(),
pointers: Vec::new(),
})
},
}
};
        // Store a list of offsets into the strings buffer for each
        // extension. We can’t directly store the pointers yet
        // because the Vec will probably be reallocated while we are
        // adding to it.
let mut offsets = Vec::<usize>::new();
for extension in self.extensions.iter() {
offsets.push(extensions.strings.len());
extensions.strings.extend_from_slice(extension.as_bytes());
// Add the null terminator
extensions.strings.push(0);
}
let base_ptr = extensions.strings.as_ptr();
extensions.pointers.reserve(offsets.len());
for offset in offsets {
            // SAFETY: These are all valid offsets into the
            // strings Vec so they should all be in the same
            // allocation and no overflow is possible.
extensions.pointers.push(unsafe { base_ptr.add(offset) });
}
extensions.pointers.as_ref()
}
/// Return a reference to an array of pointers to C-style strings
/// that can be passed to `vkCreateDevice`.
pub fn c_extensions(&self) -> &[* const u8] {
self.update_lazy_extensions()
}
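    // Sketch of the intended consumer (illustrative; the device-creation code
    // lives elsewhere in this crate): the slice returned by c_extensions()
    // becomes VkDeviceCreateInfo::ppEnabledExtensionNames, with
    // enabledExtensionCount set to its length.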
// Make a linked list of features structures that is suitable for
// passing to VkPhysicalDeviceFeatures2 or vkCreateDevice. All of
// the bools are set to 0. The vec with the structs is returned as
// well as a list of offsets to the bools along with the extension
// number.
fn make_empty_structures(
&self
) -> (Vec<u8>, Vec<(usize, usize)>)
{
let mut structures = Vec::new();
// Keep an array of offsets and corresponding extension num in a
// vec for each structure that will be returned. The offset is
// also used to update the pNext pointers. We have to do this
// after filling the vec because the vec will probably
// reallocate while we are adding to it which would invalidate
// the pointers.
let mut offsets = Vec::<(usize, usize)>::new();
for (&extension_num, features) in self.features.iter() {
let extension = &EXTENSIONS[extension_num];
// Make sure the struct size is big enough to hold all of
// the bools
assert!(
extension.struct_size
>= (FIRST_FEATURE_OFFSET
+ features.len() * mem::size_of::<vk::VkBool32>())
);
let offset = structures.len();
structures.resize(offset + extension.struct_size, 0);
// The structure type is the first field in the structure
let type_end = offset + mem::size_of::<vk::VkStructureType>();
structures[offset..type_end]
.copy_from_slice(&extension.struct_type.to_ne_bytes());
offsets.push((offset + FIRST_FEATURE_OFFSET, extension_num));
}
let mut last_offset = 0;
for &(offset, _) in &offsets[1..] {
let offset = offset - FIRST_FEATURE_OFFSET;
// SAFETY: The offsets are all valid offsets into the
// structures vec so they should all be in the same
// allocation and no overflow should occur.
let ptr = unsafe { structures.as_ptr().add(offset) };
let ptr_start = last_offset + NEXT_PTR_OFFSET;
let ptr_end = ptr_start + mem::size_of::<*const u8>();
structures[ptr_start..ptr_end]
.copy_from_slice(&(ptr as usize).to_ne_bytes());
last_offset = offset;
}
(structures, offsets)
}
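    // Illustrative layout of the buffer built above for two required feature
    // structs A and B (all bools zeroed, offsets into the single Vec<u8>):
    //
    //   [A.sType][A.pNext -> B][A bools ...][B.sType][B.pNext = NULL][B bools ...]
    //
    // The pNext links are written as native-endian usizes at NEXT_PTR_OFFSET
    // once the Vec has stopped growing, as explained in the comments above.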
fn update_lazy_structures(&self) -> &[u8] {
// SAFETY: The only case where lazy_structures will be
// modified is if it is None and the only way for that to
// happen is if something modifies the requirements through a
// mutable reference. If that is the case then at that point
// there were no other references to the requirements so
// nothing can be holding a reference to the lazy data.
let (structures, offsets) = unsafe {
match &mut *self.lazy_structures.get() {
Some(data) => return data.list.as_slice(),
option @ None => {
let (list, offsets) = self.make_empty_structures();
(option.insert(LazyStructures { list }), offsets)
},
}
};
for (offset, extension_num) in offsets {
let features = &self.features[&extension_num];
for (feature_num, &feature) in features.iter().enumerate() {
let feature_start =
offset
+ mem::size_of::<vk::VkBool32>()
* feature_num;
let feature_end =
feature_start + mem::size_of::<vk::VkBool32>();
structures.list[feature_start..feature_end]
.copy_from_slice(&(feature as vk::VkBool32).to_ne_bytes());
}
}
structures.list.as_slice()
}
/// Return a pointer to a linked list of feature structures that
/// can be passed to `vkCreateDevice`, or `None` if no feature
/// structs are required.
pub fn c_structures(&self) -> Option<&[u8]> {
if self.features.is_empty() {
None
} else {
Some(self.update_lazy_structures())
}
}
fn update_lazy_base_features(&self) -> &[vk::VkBool32] {
        // SAFETY: The only case where lazy_base_features will be
// modified is if it is None and the only way for that to
// happen is if something modifies the requirements through a
// mutable reference. If that is the case then at that point
// there were no other references to the requirements so
// nothing can be holding a reference to the lazy data.
let base_features = unsafe {
match &mut *self.lazy_base_features.get() {
Some(data) => return &data.bools,
option @ None => {
option.insert(LazyBaseFeatures {
bools: [0; N_PHYSICAL_DEVICE_FEATURES],
})
},
}
};
for (feature_num, &feature) in self.base_features.iter().enumerate() {
base_features.bools[feature_num] = feature as vk::VkBool32;
}
&base_features.bools
}
/// Return a pointer to a `VkPhysicalDeviceFeatures` struct that
/// can be passed to `vkCreateDevice`.
pub fn c_base_features(&self) -> &vk::VkPhysicalDeviceFeatures {
unsafe {
&*self.update_lazy_base_features().as_ptr().cast()
}
}
fn add_extension_name(&mut self, name: &str) {
// It would be nice to use get_or_insert_owned here if it
// becomes stable. It’s probably better not to use
// HashSet::replace directly because if it’s already in the
// set then we’ll pointlessly copy the str slice and
// immediately free it.
if !self.extensions.contains(name) {
self.extensions.insert(name.to_owned());
            // SAFETY: self is borrowed mutably here so there can be
            // no other reference to the lazy data
unsafe {
*self.lazy_extensions.get() = None;
}
}
}
/// Adds a requirement to the list of requirements.
///
/// The name can be either a feature as written in the
/// corresponding features struct or the name of an extension. If
/// it is a feature it needs to be either the name of a field in
/// the `VkPhysicalDeviceFeatures` struct or a field in any of the
/// features structs of the extensions that vkrunner knows about.
/// In the latter case the name of the corresponding extension
/// will be automatically added as a requirement.
pub fn add(&mut self, name: &str) {
if let Some((extension_num, feature_num)) = find_feature(name) {
let extension = &EXTENSIONS[extension_num];
self.add_extension_name(extension.name());
let features = self
.features
.entry(extension_num)
.or_insert_with(|| {
vec![false; extension.features.len()].into_boxed_slice()
});
if !features[feature_num] {
                // SAFETY: self is borrowed mutably here so there can be
                // no other reference to the lazy data
unsafe {
*self.lazy_structures.get() = None;
}
features[feature_num] = true;
}
} else if let Some(num) = find_base_feature(name) {
if !self.base_features[num] {
self.base_features[num] = true;
                // SAFETY: self is borrowed mutably here so there can be
                // no other reference to the lazy data
unsafe {
*self.lazy_structures.get() = None;
}
}
} else {
self.add_extension_name(name);
}
}
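    // Usage sketch (names taken from the unit tests below):
    //
    //   let mut reqs = Requirements::new();
    //   reqs.add("wideLines");               // field of VkPhysicalDeviceFeatures
    //   reqs.add("multiviewGeometryShader"); // also pulls in VK_KHR_multiview
    //   reqs.add("fake_extension");          // unknown names are treated as extensions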
fn check_base_features(
&self,
vkinst: &vulkan_funcs::Instance,
device: vk::VkPhysicalDevice
) -> Result<(), Error> {
let mut actual_features: [vk::VkBool32; N_BASE_FEATURES] =
[0; N_BASE_FEATURES];
unsafe {
vkinst.vkGetPhysicalDeviceFeatures.unwrap()(
device,
actual_features.as_mut_ptr().cast()
);
}
for (feature_num, &required) in self.base_features.iter().enumerate() {
if required && actual_features[feature_num] == 0 {
return Err(Error::MissingBaseFeature(feature_num));
}
}
Ok(())
}
fn get_device_extensions(
&self,
vkinst: &vulkan_funcs::Instance,
device: vk::VkPhysicalDevice
) -> Result<HashSet::<String>, Error> {
let mut property_count = 0u32;
let res = unsafe {
vkinst.vkEnumerateDeviceExtensionProperties.unwrap()(
device,
std::ptr::null(), // layerName
&mut property_count as *mut u32,
std::ptr::null_mut(), // properties
)
};
if res != vk::VK_SUCCESS {
return Err(Error::EnumerateDeviceExtensionPropertiesFailed);
}
let mut extensions = Vec::<vk::VkExtensionProperties>::with_capacity(
property_count as usize
);
unsafe {
let res = vkinst.vkEnumerateDeviceExtensionProperties.unwrap()(
device,
std::ptr::null(), // layerName
&mut property_count as *mut u32,
extensions.as_mut_ptr(),
);
if res != vk::VK_SUCCESS {
return Err(Error::EnumerateDeviceExtensionPropertiesFailed);
}
// SAFETY: The FFI call to
// vkEnumerateDeviceExtensionProperties should have filled
// the extensions array with valid values so we can safely
// set the length to the capacity we allocated earlier
extensions.set_len(property_count as usize);
};
let mut extensions_set = HashSet::new();
for extension in extensions.iter() {
let name = &extension.extensionName;
// Make sure it has a NULL terminator
if let None = name.iter().find(|&&b| b == 0) {
return Err(Error::ExtensionMissingNullTerminator);
}
// SAFETY: we just checked that the array has a null terminator
let name = unsafe { CStr::from_ptr(name.as_ptr()) };
let name = match name.to_str() {
Err(_) => {
return Err(Error::ExtensionInvalidUtf8);
},
Ok(s) => s,
};
extensions_set.insert(name.to_owned());
}
Ok(extensions_set)
}
fn check_extensions(
&self,
vkinst: &vulkan_funcs::Instance,
device: vk::VkPhysicalDevice
) -> Result<(), Error> {
if self.extensions.is_empty() {
return Ok(());
}
let actual_extensions = self.get_device_extensions(vkinst, device)?;
for extension in self.extensions.iter() {
if !actual_extensions.contains(extension) {
return Err(Error::MissingExtension(extension.to_string()));
}
}
Ok(())
}
fn check_structures(
&self,
vkinst: &vulkan_funcs::Instance,
device: vk::VkPhysicalDevice
) -> Result<(), Error> {
if self.features.is_empty() {
return Ok(());
}
// If vkGetPhysicalDeviceFeatures2KHR isn’t available then we
// can probably assume that none of the extensions are
// available.
let get_features = match vkinst.vkGetPhysicalDeviceFeatures2KHR {
None => {
// Find the first feature and report that as missing
for (&extension, features) in self.features.iter() {
for (feature, &enabled) in features.iter().enumerate() {
if enabled {
return Err(Error::MissingFeature {
extension,
feature,
});
}
}
}
unreachable!("Requirements::features should be empty if no \
features are required");
},
Some(func) => func,
};
let (mut actual_structures, offsets) = self.make_empty_structures();
let mut features_query = vk::VkPhysicalDeviceFeatures2 {
sType: vk::VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2,
pNext: actual_structures.as_mut_ptr().cast(),
features: Default::default(),
};
unsafe {
get_features(
device,
&mut features_query as *mut vk::VkPhysicalDeviceFeatures2,
);
}
for (offset, extension_num) in offsets {
let features = &self.features[&extension_num];
for (feature_num, &feature) in features.iter().enumerate() {
if !feature {
continue;
}
let feature_start =
offset
+ mem::size_of::<vk::VkBool32>()
* feature_num;
let feature_end =
feature_start + mem::size_of::<vk::VkBool32>();
let actual_value = vk::VkBool32::from_ne_bytes(
actual_structures[feature_start..feature_end]
.try_into()
.unwrap()
);
if actual_value == 0 {
return Err(Error::MissingFeature {
extension: extension_num,
feature: feature_num,
});
}
}
}
Ok(())
}
fn check_version(
&self,
vkinst: &vulkan_funcs::Instance,
device: vk::VkPhysicalDevice,
) -> Result<(), Error> {
let actual_version = if self.version() >= make_version(1, 1, 0) {
let mut subgroup_size_props = vk::VkPhysicalDeviceSubgroupSizeControlProperties {
sType: vk::VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_SIZE_CONTROL_PROPERTIES,
..Default::default()
};
let mut props = vk::VkPhysicalDeviceProperties2 {
sType: vk::VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2,
pNext: std::ptr::addr_of_mut!(subgroup_size_props).cast(),
..Default::default()
};
unsafe {
vkinst.vkGetPhysicalDeviceProperties2.unwrap()(
device,
&mut props as *mut vk::VkPhysicalDeviceProperties2,
);
}
if let Some(size) = self.required_subgroup_size {
let min = subgroup_size_props.minSubgroupSize;
let max = subgroup_size_props.maxSubgroupSize;
if size < min || size > max {
return Err(Error::RequiredSubgroupSizeInvalid {
size, min, max
});
}
}
props.properties.apiVersion
} else {
let mut props = vk::VkPhysicalDeviceProperties::default();
unsafe {
vkinst.vkGetPhysicalDeviceProperties.unwrap()(
device,
&mut props as *mut vk::VkPhysicalDeviceProperties,
);
}
props.apiVersion
};
if actual_version < self.version() {
Err(Error::VersionTooLow {
required_version: self.version(),
actual_version,
})
} else {
Ok(())
}
}
fn check_cooperative_matrix(&self, vkinst: &vulkan_funcs::Instance, device: vk::VkPhysicalDevice) -> Result<(), Error> {
if self.cooperative_matrix_reqs.len() == 0 {
return Ok(());
}
let mut count = 0u32;
unsafe {
vkinst.vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR.unwrap()(
device,
ptr::addr_of_mut!(count),
ptr::null_mut(),
);
}
let mut props = Vec::<vk::VkCooperativeMatrixPropertiesKHR>::new();
props.resize_with(count as usize, Default::default);
for prop in &mut props {
prop.sType = vk::VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_PROPERTIES_KHR;
}
unsafe {
vkinst.vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR.unwrap()(
device,
ptr::addr_of_mut!(count),
props.as_mut_ptr(),
);
}
for req in &self.cooperative_matrix_reqs {
let mut found = false;
for prop in &props {
if let Some(m) = req.m_size {
if m != prop.MSize {
continue
}
}
if let Some(n) = req.n_size {
if n != prop.NSize {
continue
}
}
if let Some(k) = req.k_size {
if k != prop.KSize {
continue
}
}
if let Some(a) = req.a_type {
if a != prop.AType {
continue
}
}
if let Some(b) = req.b_type {
if b != prop.BType {
continue
}
}
if let Some(c) = req.c_type {
if c != prop.CType {
continue
}
}
if let Some(result) = req.result_type {
if result != prop.ResultType {
continue
}
}
if let Some(sa) = req.saturating_accumulation {
if sa != prop.saturatingAccumulation {
continue
}
}
if let Some(scope) = req.scope {
if scope != prop.scope {
continue
}
}
found = true;
break
}
if !found {
return Err(Error::MissingCooperativeMatrixProperties(req.line.clone()));
}
}
Ok(())
}
pub fn check(
&self,
vkinst: &vulkan_funcs::Instance,
device: vk::VkPhysicalDevice
) -> Result<(), Error> {
self.check_base_features(vkinst, device)?;
self.check_extensions(vkinst, device)?;
self.check_structures(vkinst, device)?;
self.check_version(vkinst, device)?;
self.check_cooperative_matrix(vkinst, device)?;
Ok(())
}
}
// Looks for a feature with the given name. If found it returns the
// index of the extension it was found in and the index of the feature
// name.
fn find_feature(name: &str) -> Option<(usize, usize)> {
for (extension_num, extension) in EXTENSIONS.iter().enumerate() {
for (feature_num, &feature) in extension.features.iter().enumerate() {
if feature == name {
return Some((extension_num, feature_num));
}
}
}
None
}
fn find_base_feature(name: &str) -> Option<usize> {
BASE_FEATURES.iter().position(|&f| f == name)
}
impl PartialEq for Requirements {
fn eq(&self, other: &Requirements) -> bool {
self.version == other.version
&& self.extensions == other.extensions
&& self.features == other.features
&& self.base_features == other.base_features
&& self.required_subgroup_size == other.required_subgroup_size
}
}
impl Eq for Requirements {
}
// Manual implementation of clone to avoid copying the lazy state
impl Clone for Requirements {
fn clone(&self) -> Requirements {
Requirements {
version: self.version,
extensions: self.extensions.clone(),
features: self.features.clone(),
base_features: self.base_features.clone(),
required_subgroup_size: self.required_subgroup_size,
lazy_extensions: UnsafeCell::new(None),
lazy_structures: UnsafeCell::new(None),
lazy_base_features: UnsafeCell::new(None),
cooperative_matrix_reqs: self.cooperative_matrix_reqs.clone(),
}
}
fn clone_from(&mut self, source: &Requirements) {
self.version = source.version;
self.extensions.clone_from(&source.extensions);
self.features.clone_from(&source.features);
self.base_features.clone_from(&source.base_features);
self.required_subgroup_size = source.required_subgroup_size;
self.cooperative_matrix_reqs.clone_from(&source.cooperative_matrix_reqs);
        // SAFETY: self is borrowed mutably here so there can be
        // no other reference to the lazy data
unsafe {
*self.lazy_extensions.get() = None;
*self.lazy_structures.get() = None;
*self.lazy_base_features.get() = None;
}
}
}
#[cfg(test)]
mod test {
use super::*;
use std::ffi::{c_char, c_void};
use crate::fake_vulkan;
fn get_struct_type(structure: &[u8]) -> vk::VkStructureType {
let slice = &structure[0..mem::size_of::<vk::VkStructureType>()];
let mut bytes = [0; mem::size_of::<vk::VkStructureType>()];
bytes.copy_from_slice(slice);
vk::VkStructureType::from_ne_bytes(bytes)
}
fn get_next_structure(structure: &[u8]) -> &[u8] {
let slice = &structure[
NEXT_PTR_OFFSET..NEXT_PTR_OFFSET + mem::size_of::<usize>()
];
let mut bytes = [0; mem::size_of::<usize>()];
bytes.copy_from_slice(slice);
let pointer = usize::from_ne_bytes(bytes);
let offset = pointer - structure.as_ptr() as usize;
&structure[offset..]
}
fn find_bools_for_extension<'a, 'b>(
mut structures: &'a [u8],
extension: &'b Extension
) -> &'a [vk::VkBool32] {
while !structures.is_empty() {
let struct_type = get_struct_type(structures);
if struct_type == extension.struct_type {
unsafe {
let bools_ptr = structures
.as_ptr()
.add(FIRST_FEATURE_OFFSET);
return std::slice::from_raw_parts(
bools_ptr as *const vk::VkBool32,
extension.features.len()
);
}
}
structures = get_next_structure(structures);
}
unreachable!("No structure found for extension “{}”", extension.name());
}
unsafe fn extension_in_c_extensions(
reqs: &mut Requirements,
ext: &str
) -> bool {
for &p in reqs.c_extensions().iter() {
if CStr::from_ptr(p as *const c_char).to_str().unwrap() == ext {
return true;
}
}
false
}
#[test]
fn test_all_features() {
let mut reqs = Requirements::new();
for extension in EXTENSIONS.iter() {
for feature in extension.features.iter() {
reqs.add(feature);
}
}
for feature in BASE_FEATURES.iter() {
reqs.add(feature);
}
for (extension_num, extension) in EXTENSIONS.iter().enumerate() {
// All of the extensions should be in the set
assert!(reqs.extensions.contains(extension.name()));
// All of the features of every extension should be true
assert!(reqs.features[&extension_num].iter().all(|&b| b));
}
// All of the base features should be enabled
assert!(reqs.base_features.iter().all(|&b| b));
let base_features = reqs.c_base_features();
assert_eq!(base_features.robustBufferAccess, 1);
assert_eq!(base_features.fullDrawIndexUint32, 1);
assert_eq!(base_features.imageCubeArray, 1);
assert_eq!(base_features.independentBlend, 1);
assert_eq!(base_features.geometryShader, 1);
assert_eq!(base_features.tessellationShader, 1);
assert_eq!(base_features.sampleRateShading, 1);
assert_eq!(base_features.dualSrcBlend, 1);
assert_eq!(base_features.logicOp, 1);
assert_eq!(base_features.multiDrawIndirect, 1);
assert_eq!(base_features.drawIndirectFirstInstance, 1);
assert_eq!(base_features.depthClamp, 1);
assert_eq!(base_features.depthBiasClamp, 1);
assert_eq!(base_features.fillModeNonSolid, 1);
assert_eq!(base_features.depthBounds, 1);
assert_eq!(base_features.wideLines, 1);
assert_eq!(base_features.largePoints, 1);
assert_eq!(base_features.alphaToOne, 1);
assert_eq!(base_features.multiViewport, 1);
assert_eq!(base_features.samplerAnisotropy, 1);
assert_eq!(base_features.textureCompressionETC2, 1);
assert_eq!(base_features.textureCompressionASTC_LDR, 1);
assert_eq!(base_features.textureCompressionBC, 1);
assert_eq!(base_features.occlusionQueryPrecise, 1);
assert_eq!(base_features.pipelineStatisticsQuery, 1);
assert_eq!(base_features.vertexPipelineStoresAndAtomics, 1);
assert_eq!(base_features.fragmentStoresAndAtomics, 1);
assert_eq!(base_features.shaderTessellationAndGeometryPointSize, 1);
assert_eq!(base_features.shaderImageGatherExtended, 1);
assert_eq!(base_features.shaderStorageImageExtendedFormats, 1);
assert_eq!(base_features.shaderStorageImageMultisample, 1);
assert_eq!(base_features.shaderStorageImageReadWithoutFormat, 1);
assert_eq!(base_features.shaderStorageImageWriteWithoutFormat, 1);
assert_eq!(base_features.shaderUniformBufferArrayDynamicIndexing, 1);
assert_eq!(base_features.shaderSampledImageArrayDynamicIndexing, 1);
assert_eq!(base_features.shaderStorageBufferArrayDynamicIndexing, 1);
assert_eq!(base_features.shaderStorageImageArrayDynamicIndexing, 1);
assert_eq!(base_features.shaderClipDistance, 1);
assert_eq!(base_features.shaderCullDistance, 1);
assert_eq!(base_features.shaderFloat64, 1);
assert_eq!(base_features.shaderInt64, 1);
assert_eq!(base_features.shaderInt16, 1);
assert_eq!(base_features.shaderResourceResidency, 1);
assert_eq!(base_features.shaderResourceMinLod, 1);
assert_eq!(base_features.sparseBinding, 1);
assert_eq!(base_features.sparseResidencyBuffer, 1);
assert_eq!(base_features.sparseResidencyImage2D, 1);
assert_eq!(base_features.sparseResidencyImage3D, 1);
assert_eq!(base_features.sparseResidency2Samples, 1);
assert_eq!(base_features.sparseResidency4Samples, 1);
assert_eq!(base_features.sparseResidency8Samples, 1);
assert_eq!(base_features.sparseResidency16Samples, 1);
assert_eq!(base_features.sparseResidencyAliased, 1);
assert_eq!(base_features.variableMultisampleRate, 1);
assert_eq!(base_features.inheritedQueries, 1);
// All of the values should be set in the C structs
for extension in EXTENSIONS.iter() {
let structs = reqs.c_structures().unwrap();
assert!(
find_bools_for_extension(structs, extension)
.iter()
.all(|&b| b == 1)
);
assert!(unsafe { &*reqs.lazy_structures.get() }.is_some());
}
// All of the extensions should be in the c_extensions
for extension in EXTENSIONS.iter() {
assert!(unsafe {
extension_in_c_extensions(&mut reqs, extension.name())
});
assert!(unsafe { &*reqs.lazy_extensions.get() }.is_some());
}
// Sanity check that a made-up extension isn’t in c_extensions
assert!(!unsafe {
extension_in_c_extensions(&mut reqs, "not_a_real_ext")
});
}
#[test]
fn test_version() {
let mut reqs = Requirements::new();
reqs.add_version(2, 1, 5);
assert_eq!(reqs.version(), 0x801005);
}
#[test]
fn test_empty() {
let reqs = Requirements::new();
assert!(unsafe { &*reqs.lazy_extensions.get() }.is_none());
assert!(unsafe { &*reqs.lazy_structures.get() }.is_none());
let base_features_ptr = reqs.c_base_features()
as *const vk::VkPhysicalDeviceFeatures
as *const vk::VkBool32;
unsafe {
let base_features = std::slice::from_raw_parts(
base_features_ptr,
N_BASE_FEATURES
);
assert!(base_features.iter().all(|&b| b == 0));
}
}
#[test]
fn test_eq() {
let mut reqs_a = Requirements::new();
let mut reqs_b = Requirements::new();
assert_eq!(reqs_a, reqs_b);
reqs_a.add("advancedBlendCoherentOperations");
assert_ne!(reqs_a, reqs_b);
reqs_b.add("advancedBlendCoherentOperations");
assert_eq!(reqs_a, reqs_b);
// Getting the C data shouldn’t affect the equality
reqs_a.c_structures();
reqs_a.c_base_features();
reqs_a.c_extensions();
assert_eq!(reqs_a, reqs_b);
// The order of adding shouldn’t matter
reqs_a.add("fake_extension");
reqs_a.add("another_fake_extension");
reqs_b.add("another_fake_extension");
reqs_b.add("fake_extension");
assert_eq!(reqs_a, reqs_b);
reqs_a.add("wideLines");
assert_ne!(reqs_a, reqs_b);
reqs_b.add("wideLines");
assert_eq!(reqs_a, reqs_b);
reqs_a.add_version(3, 1, 2);
assert_ne!(reqs_a, reqs_b);
reqs_b.add_version(3, 1, 2);
assert_eq!(reqs_a, reqs_b);
}
#[test]
fn test_clone() {
let mut reqs = Requirements::new();
reqs.add("wideLines");
reqs.add("fake_extension");
reqs.add("advancedBlendCoherentOperations");
assert_eq!(reqs.clone(), reqs);
assert!(unsafe { &*reqs.lazy_structures.get() }.is_none());
assert_eq!(reqs.c_extensions().len(), 2);
assert_eq!(reqs.c_base_features().wideLines, 1);
let empty = Requirements::new();
reqs.clone_from(&empty);
assert!(unsafe { &*reqs.lazy_structures.get() }.is_none());
assert_eq!(reqs.c_extensions().len(), 0);
assert_eq!(reqs.c_base_features().wideLines, 0);
assert_eq!(reqs, empty);
}
struct FakeVulkanData {
fake_vulkan: Box<fake_vulkan::FakeVulkan>,
_vklib: vulkan_funcs::Library,
vkinst: vulkan_funcs::Instance,
instance: vk::VkInstance,
device: vk::VkPhysicalDevice,
}
impl FakeVulkanData {
fn new() -> FakeVulkanData {
let mut fake_vulkan = fake_vulkan::FakeVulkan::new();
fake_vulkan.physical_devices.push(Default::default());
fake_vulkan.set_override();
let vklib = vulkan_funcs::Library::new().unwrap();
let mut instance = std::ptr::null_mut();
unsafe {
let res = vklib.vkCreateInstance.unwrap()(
std::ptr::null(), // pCreateInfo
std::ptr::null(), // pAllocator
std::ptr::addr_of_mut!(instance),
);
assert_eq!(res, vk::VK_SUCCESS);
}
extern "C" fn get_instance_proc(
func_name: *const c_char,
user_data: *const c_void,
) -> *const c_void {
unsafe {
let fake_vulkan =
&*user_data.cast::<fake_vulkan::FakeVulkan>();
std::mem::transmute(
fake_vulkan.get_function(func_name.cast())
)
}
}
let vkinst = unsafe {
vulkan_funcs::Instance::new(
get_instance_proc,
fake_vulkan.as_ref()
as *const fake_vulkan::FakeVulkan
as *const c_void,
)
};
let mut device = std::ptr::null_mut();
unsafe {
let mut count = 1;
let res = vkinst.vkEnumeratePhysicalDevices.unwrap()(
instance,
std::ptr::addr_of_mut!(count),
std::ptr::addr_of_mut!(device),
);
assert_eq!(res, vk::VK_SUCCESS);
}
FakeVulkanData {
fake_vulkan,
_vklib: vklib,
vkinst,
instance,
device,
}
}
}
impl Drop for FakeVulkanData {
fn drop(&mut self) {
if !std::thread::panicking() {
unsafe {
self.vkinst.vkDestroyInstance.unwrap()(
self.instance,
std::ptr::null(), // allocator
);
}
}
}
}
fn check_base_features<'a>(
reqs: &'a Requirements,
features: &vk::VkPhysicalDeviceFeatures
) -> Result<(), Error> {
let mut data = FakeVulkanData::new();
data.fake_vulkan.physical_devices[0].features = features.clone();
reqs.check(&data.vkinst, data.device)
}
#[test]
fn test_check_base_features() {
let mut features = Default::default();
assert!(matches!(
check_base_features(&Requirements::new(), &features),
Ok(()),
));
features.geometryShader = vk::VK_TRUE;
assert!(matches!(
check_base_features(&Requirements::new(), &features),
Ok(()),
));
let mut reqs = Requirements::new();
reqs.add("geometryShader");
assert!(matches!(
check_base_features(&reqs, &features),
Ok(()),
));
reqs.add("depthBounds");
match check_base_features(&reqs, &features) {
Ok(()) => unreachable!("Requirements::check was supposed to fail"),
Err(e) => {
assert!(matches!(e, Error::MissingBaseFeature(_)));
assert_eq!(
e.to_string(),
"Missing required feature: depthBounds"
);
assert_eq!(e.result(), result::Result::Skip);
},
}
}
#[test]
fn test_check_extensions() {
let mut data = FakeVulkanData::new();
let mut reqs = Requirements::new();
assert!(matches!(
reqs.check(&data.vkinst, data.device),
Ok(()),
));
reqs.add("fake_extension");
match reqs.check(&data.vkinst, data.device) {
Ok(()) => unreachable!("expected extensions check to fail"),
Err(e) => {
assert!(matches!(
e,
Error::MissingExtension(_),
));
assert_eq!(
e.to_string(),
"Missing required extension: fake_extension"
);
assert_eq!(e.result(), result::Result::Skip);
},
};
data.fake_vulkan.physical_devices[0].add_extension("fake_extension");
assert!(matches!(
reqs.check(&data.vkinst, data.device),
Ok(()),
));
// Add an extension via a feature
reqs.add("multiviewGeometryShader");
match reqs.check(&data.vkinst, data.device) {
Ok(()) => unreachable!("expected extensions check to fail"),
Err(e) => {
assert!(matches!(
e,
Error::MissingExtension(_)
));
assert_eq!(
e.to_string(),
"Missing required extension: VK_KHR_multiview",
);
assert_eq!(e.result(), result::Result::Skip);
},
};
data.fake_vulkan.physical_devices[0].add_extension("VK_KHR_multiview");
match reqs.check(&data.vkinst, data.device) {
Ok(()) => unreachable!("expected extensions check to fail"),
Err(e) => {
assert!(matches!(
e,
Error::MissingFeature { .. },
));
assert_eq!(
e.to_string(),
"Missing required feature “multiviewGeometryShader” from \
extension “VK_KHR_multiview”",
);
assert_eq!(e.result(), result::Result::Skip);
},
};
// Make an unterminated UTF-8 character
let extension_name = &mut data
.fake_vulkan
.physical_devices[0]
.extensions[0]
.extensionName;
extension_name[0] = -1i8 as c_char;
extension_name[1] = 0;
match reqs.check(&data.vkinst, data.device) {
Ok(()) => unreachable!("expected extensions check to fail"),
Err(e) => {
assert_eq!(
e.to_string(),
"Invalid UTF-8 in string returned from \
vkEnumerateDeviceExtensionProperties"
);
assert!(matches!(e, Error::ExtensionInvalidUtf8));
assert_eq!(e.result(), result::Result::Fail);
},
};
// No null-terminator in the extension
extension_name.fill(32);
match reqs.check(&data.vkinst, data.device) {
Ok(()) => unreachable!("expected extensions check to fail"),
Err(e) => {
assert_eq!(
e.to_string(),
"NULL terminator missing in string returned from \
vkEnumerateDeviceExtensionProperties"
);
assert!(matches!(e, Error::ExtensionMissingNullTerminator));
assert_eq!(e.result(), result::Result::Fail);
},
};
}
#[test]
fn test_check_structures() {
let mut data = FakeVulkanData::new();
let mut reqs = Requirements::new();
reqs.add("multiview");
data.fake_vulkan.physical_devices[0].add_extension("VK_KHR_multiview");
match reqs.check(&data.vkinst, data.device) {
Ok(()) => unreachable!("expected features check to fail"),
Err(e) => {
assert_eq!(
e.to_string(),
"Missing required feature “multiview” \
from extension “VK_KHR_multiview”",
);
assert!(matches!(
e,
Error::MissingFeature { .. },
));
assert_eq!(e.result(), result::Result::Skip);
},
};
data.fake_vulkan.physical_devices[0].multiview.multiview = vk::VK_TRUE;
assert!(matches!(
reqs.check(&data.vkinst, data.device),
Ok(()),
));
reqs.add("shaderBufferInt64Atomics");
data.fake_vulkan.physical_devices[0].add_extension(
"VK_KHR_shader_atomic_int64"
);
match reqs.check(&data.vkinst, data.device) {
Ok(()) => unreachable!("expected features check to fail"),
Err(e) => {
assert_eq!(
e.to_string(),
"Missing required feature “shaderBufferInt64Atomics” \
from extension “VK_KHR_shader_atomic_int64”",
);
assert!(matches!(
e,
Error::MissingFeature { .. },
));
assert_eq!(e.result(), result::Result::Skip);
},
};
data.fake_vulkan
.physical_devices[0]
.shader_atomic
.shaderBufferInt64Atomics =
vk::VK_TRUE;
assert!(matches!(
reqs.check(&data.vkinst, data.device),
Ok(()),
));
}
#[test]
fn test_check_version() {
let mut data = FakeVulkanData::new();
let mut reqs = Requirements::new();
reqs.add_version(1, 2, 0);
data.fake_vulkan.physical_devices[0].properties.apiVersion =
make_version(1, 1, 0);
match reqs.check(&data.vkinst, data.device) {
Ok(()) => unreachable!("expected version check to fail"),
Err(e) => {
                // The fake driver reports 1.1.0, which is lower than
                // the required 1.2.0
assert_eq!(
e.to_string(),
"Vulkan API version 1.2.0 required but the driver \
reported 1.1.0",
);
assert!(matches!(
e,
Error::VersionTooLow { .. },
));
assert_eq!(e.result(), result::Result::Skip);
},
};
// Set a valid version
data.fake_vulkan.physical_devices[0].properties.apiVersion =
make_version(1, 3, 0);
assert!(matches!(
reqs.check(&data.vkinst, data.device),
Ok(()),
));
}
}
```