diff options
Diffstat (limited to 'app')
| -rw-r--r-- | app/main.go | 16 | ||||
| -rw-r--r-- | app/message.go | 80 |
2 files changed, 60 insertions, 36 deletions
diff --git a/app/main.go b/app/main.go index 69aefae..b02972d 100644 --- a/app/main.go +++ b/app/main.go @@ -31,7 +31,7 @@ func main() { receivedData := string(buf[:size]) fmt.Printf("Received %d bytes from %s: %s\n", size, source, receivedData) requestHeader := ParseHeader(buf[:size]) - question := ParseQuestion(buf[:size]) + questions := ParseQuestions(buf[:size], requestHeader.QDCOUNT) rcode := uint8(4) if requestHeader.OPCODE == 0 { @@ -48,17 +48,19 @@ func main() { RA: 0, Z: 0, RCODE: rcode, - QDCOUNT: 1, + QDCOUNT: 0, ANCOUNT: 0, - NSCOUNT: 0, - ARCOUNT: 0, + NSCOUNT: requestHeader.NSCOUNT, + ARCOUNT: requestHeader.ARCOUNT, } response := MakeMessage(header) - response.AddQuestion(question) - answer := MakeAnswer(question.Name, []byte("\x08\x08\x08\x08")) - response.AddAnswer(answer) + for _, question := range questions { + response.AddQuestion(question) + answer := MakeAnswer(question.Name, []byte("\x08\x08\x08\x08")) + response.AddAnswer(answer) + } _, err = udpConn.WriteToUDP(response.Bytes(), source) if err != nil { diff --git a/app/message.go b/app/message.go index c00346e..3d523ab 100644 --- a/app/message.go +++ b/app/message.go @@ -1,6 +1,7 @@ package main import ( + "bytes" "encoding/binary" "strings" ) @@ -22,9 +23,9 @@ type DNSHeader struct { } type DNSMessage struct { - Header DNSHeader - Question []byte - Answer DNSAnswer + Header DNSHeader + Questions []DNSQuestion + Answers []DNSAnswer } type DNSAnswer struct { @@ -60,16 +61,24 @@ func (m *DNSHeader) combineFlags() uint16 { } func (m *DNSMessage) AddQuestion(q DNSQuestion) { - m.Question = q.Bytes() + m.Questions = append(m.Questions, q) + m.Header.QDCOUNT = m.Header.QDCOUNT + 1 } func (m *DNSMessage) Bytes() []byte { headerBytes := m.Header.Bytes() - answerBytes := m.Answer.Bytes() bytes := []byte{} bytes = append(bytes, headerBytes...) - bytes = append(bytes, m.Question...) - bytes = append(bytes, answerBytes...) + + for _, question := range m.Questions { + questionBytes := question.Bytes() + bytes = append(bytes, questionBytes...) + } + + for _, answer := range m.Answers { + answerBytes := answer.Bytes() + bytes = append(bytes, answerBytes...) + } return bytes } @@ -85,8 +94,8 @@ func (a *DNSAnswer) Bytes() []byte { } func (m *DNSMessage) AddAnswer(a DNSAnswer) { - m.Answer = a - m.Header.ANCOUNT = 1 + m.Answers = append(m.Answers, a) + m.Header.ANCOUNT = m.Header.ANCOUNT + 1 } func (q *DNSQuestion) Bytes() []byte { @@ -98,7 +107,7 @@ func (q *DNSQuestion) Bytes() []byte { } func MakeMessage(header DNSHeader) DNSMessage { - return DNSMessage{Header: header, Question: []byte{}, Answer: DNSAnswer{}} + return DNSMessage{Header: header, Questions: []DNSQuestion{}, Answers: []DNSAnswer{}} } func MakeAnswer(name []byte, rdata []byte) DNSAnswer { @@ -127,13 +136,23 @@ func MakeQuestion(name []byte) DNSQuestion { return DNSQuestion{Name: name, Type: 1, Class: 1} } -func ParseQuestion(buf []byte) DNSQuestion { - return MakeQuestion(ParseDomain(buf)) +func ParseQuestions(buf []byte, questionCount uint16) []DNSQuestion { + questions := []DNSQuestion{} + offset := 12 + + for i := 0; i < int(questionCount); i++ { + len := bytes.Index(buf[offset:], []byte{0}) + label := ParseDomain(buf[offset:offset+len+1], buf) + question := MakeQuestion(label) + questions = append(questions, question) + offset += len + 1 + 4 + } + + return questions } -func ParseDomain(data []byte) []byte { - domainByte := data[12:] - domain := decodeDNSPacket(domainByte) +func ParseDomain(data []byte, source []byte) []byte { + domain := decodeDNSPacket(data, source) segments := strings.Split(domain, ".") var encodedDomain []byte @@ -146,25 +165,28 @@ func ParseDomain(data []byte) []byte { return encodedDomain } -func decodeDNSPacket(packet []byte) string { - var domain string - i := 0 - - for i < len(packet) && packet[i] != 0 { - labelLength := int(packet[i]) +func decodeDNSPacket(packet []byte, source []byte) string { + offset := 0 + labels := []string{} - i++ - if i+labelLength > len(packet) { + for { + if packet[offset] == 0 { break } - domain += string(packet[i : i+labelLength]) - - i += labelLength - if i < len(packet) && packet[i] != 0 { - domain += "." + if (packet[offset]&0xc0)>>6 == 0b11 { + pointer := int(binary.BigEndian.Uint16(packet[offset:offset+2]) << 2 >> 2) + length := bytes.Index(source[pointer:], []byte{0}) + labels = append(labels, decodeDNSPacket(source[pointer:pointer+length+1], source)) + offset += 2 + continue } + + length := int(packet[offset]) + substring := packet[offset+1 : offset+1+length] + labels = append(labels, string(substring)) + offset += length + 1 } - return domain + return strings.Join(labels, ".") } |