diff --git a/ssh.odin b/ssh.odin index 2b9dff9..5b7df39 100644 --- a/ssh.odin +++ b/ssh.odin @@ -1,5 +1,6 @@ package main +import "base:runtime" import "core:encoding/base64" import "core:fmt" import "core:os" @@ -115,8 +116,11 @@ parse_ssh_private_key :: proc(priv_path: string) -> (kp: Ed25519Keypair, ok: boo if offset + 4 > len(decoded) { return } - num_keys := u32(decoded[offset]) << 24 | u32(decoded[offset + 1]) << 16 | - u32(decoded[offset + 2]) << 8 | u32(decoded[offset + 3]) + num_keys := + u32(decoded[offset]) << 24 | + u32(decoded[offset + 1]) << 16 | + u32(decoded[offset + 2]) << 8 | + u32(decoded[offset + 3]) offset += 4 if num_keys != 1 { @@ -137,11 +141,17 @@ parse_ssh_private_key :: proc(priv_path: string) -> (kp: Ed25519Keypair, ok: boo if inner_offset + 8 > len(priv_blob) { return } - check1 := u32(priv_blob[inner_offset]) << 24 | u32(priv_blob[inner_offset + 1]) << 16 | - u32(priv_blob[inner_offset + 2]) << 8 | u32(priv_blob[inner_offset + 3]) + check1 := + u32(priv_blob[inner_offset]) << 24 | + u32(priv_blob[inner_offset + 1]) << 16 | + u32(priv_blob[inner_offset + 2]) << 8 | + u32(priv_blob[inner_offset + 3]) inner_offset += 4 - check2 := u32(priv_blob[inner_offset]) << 24 | u32(priv_blob[inner_offset + 1]) << 16 | - u32(priv_blob[inner_offset + 2]) << 8 | u32(priv_blob[inner_offset + 3]) + check2 := + u32(priv_blob[inner_offset]) << 24 | + u32(priv_blob[inner_offset + 1]) << 16 | + u32(priv_blob[inner_offset + 2]) << 8 | + u32(priv_blob[inner_offset + 3]) inner_offset += 4 if check1 != check2 { @@ -173,83 +183,35 @@ parse_ssh_private_key :: proc(priv_path: string) -> (kp: Ed25519Keypair, ok: boo return } -is_ed25519_key :: proc(priv_path: string) -> bool { - pub_path, _ := strings.concatenate([]string{priv_path, ".pub"}, context.temp_allocator) - _, ok := parse_ssh_public_key(pub_path) - return ok -} - -is_encrypted_key :: proc(priv_path: string) -> bool { - data, err := os.read_entire_file_from_path(priv_path, context.temp_allocator) - if err != nil { - return true - } - - if !strings.contains(string(data), "BEGIN OPENSSH PRIVATE KEY") { - return true - } - - text := string(data) - lines := strings.split(text, "\n", context.temp_allocator) - - b2: strings.Builder - strings.builder_init(&b2, context.temp_allocator) - defer strings.builder_destroy(&b2) - - in_block := false - for line in lines { - trimmed := strings.trim_space(line) - if trimmed == "-----BEGIN OPENSSH PRIVATE KEY-----" { - in_block = true - continue - } - if trimmed == "-----END OPENSSH PRIVATE KEY-----" { - break - } - if in_block && len(trimmed) > 0 { - fmt.sbprintf(&b2, "%s", trimmed) - } - } - - b64_str := strings.to_string(b2) - decoded, decode_err := base64.decode(b64_str, allocator = context.temp_allocator) - if decode_err != nil { - return true - } - - magic := "openssh-key-v1\x00" - if len(decoded) < len(magic) { - return true - } - for i in 0 ..< len(magic) { - if decoded[i] != u8(magic[i]) { - return true - } - } - - offset := len(magic) - ciphername, cipher_ok := read_wire_string(decoded, &offset) - if !cipher_ok { - return true - } - - return ciphername != "none" +is_ed25519_key :: proc( + priv_path: string, +) -> ( + ok: bool, + err: runtime.Allocator_Error, +) #optional_allocator_error { + pub_path := strings.concatenate([]string{priv_path, ".pub"}, context.temp_allocator) or_return + _, ok = parse_ssh_public_key(pub_path) + return ok, nil } read_wire_string :: proc(data: []u8, offset: ^int) -> (s: string, ok: bool) { if offset^ + 4 > len(data) { return } - length := u32(data[offset^]) << 24 | u32(data[offset^ + 1]) << 16 | - u32(data[offset^ + 2]) << 8 | u32(data[offset^ + 3]) + length := + u32(data[offset^]) << 24 | + u32(data[offset^ + 1]) << 16 | + u32(data[offset^ + 2]) << 8 | + u32(data[offset^ + 3]) offset^ += 4 if offset^ + int(length) > len(data) { return } - s = string(data[offset^ : offset^ + int(length)]) + s = string(data[offset^:offset^ + int(length)]) offset^ += int(length) ok = true return } + diff --git a/ssh_test.odin b/ssh_test.odin index 595826a..4ee1d63 100644 --- a/ssh_test.odin +++ b/ssh_test.odin @@ -72,39 +72,4 @@ test_read_wire_string :: proc(t: ^testing.T) { testing.expect(t, s2 == "", "expected empty string") } -@(test) -test_is_encrypted_key_encrypted :: proc(t: ^testing.T) { - testing.expect( - t, - is_encrypted_key(TEST_KEY_DIR + "/test_ed25519_encrypted"), - "encrypted key should be detected as encrypted", - ) -} - -@(test) -test_is_encrypted_key_unencrypted :: proc(t: ^testing.T) { - testing.expect( - t, - !is_encrypted_key(TEST_KEY_DIR + "/test_ed25519"), - "unencrypted key should not be detected as encrypted", - ) -} - -@(test) -test_is_encrypted_key_rsa_unencrypted :: proc(t: ^testing.T) { - testing.expect( - t, - !is_encrypted_key(TEST_KEY_DIR + "/test_rsa"), - "unencrypted RSA key should not be detected as encrypted", - ) -} - -@(test) -test_is_encrypted_key_missing_file :: proc(t: ^testing.T) { - testing.expect( - t, - is_encrypted_key(TEST_KEY_DIR + "/nonexistent"), - "missing file should be treated as encrypted (fail-safe)", - ) -}