From 47233e25fa21af56d0f89aeb40be8f3a781d6ad2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Torbjo=CC=88rn=20Einarsson?= Date: Tue, 30 Sep 2025 17:08:39 +0200 Subject: [PATCH 1/3] docs: CHANGELOG line about fix of issue 444 --- CHANGELOG.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e9e32ee7..632381aa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,7 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] -- Nothing yet +### Fixes + +- Proper handling of trailing bytes in avc1 (VisualSampleEntryBox). Issue 444 ## [0.50.0] - 2025-09-05 From 22e6eb4465b721fff2397d19ca4f000311c71c0f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Torbjo=CC=88rn=20Einarsson?= Date: Tue, 30 Sep 2025 13:25:19 +0200 Subject: [PATCH 2/3] chore: moved and renamed test file --- examples/add-sidx/main_test.go | 2 +- .../testdata/v300_multiple_segments.mp4 | Bin 2 files changed, 1 insertion(+), 1 deletion(-) rename examples/resegmenter/testdata/testV300.mp4 => mp4/testdata/v300_multiple_segments.mp4 (100%) diff --git a/examples/add-sidx/main_test.go b/examples/add-sidx/main_test.go index 4d140ed7..a644cd5e 100644 --- a/examples/add-sidx/main_test.go +++ b/examples/add-sidx/main_test.go @@ -48,7 +48,7 @@ func TestCommandLine(t *testing.T) { }, { desc: "normal file with styp", - args: []string{appName, "../resegmenter/testdata/testV300.mp4", path.Join(tmpDir, "out4.mp4")}, + args: []string{appName, "../../mp4/testdata/v300_multiple_segments.mp4", path.Join(tmpDir, "out4.mp4")}, checkOutput: true, wantedNrSegs: 4, wantedFirstDur: 180000, diff --git a/examples/resegmenter/testdata/testV300.mp4 b/mp4/testdata/v300_multiple_segments.mp4 similarity index 100% rename from examples/resegmenter/testdata/testV300.mp4 rename to mp4/testdata/v300_multiple_segments.mp4 From 1d446edaa8f803bd0be68e97a807580c873071b0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Torbjo=CC=88rn=20Einarsson?= Date: Tue, 30 Sep 2025 17:07:39 +0200 Subject: [PATCH 3/3] feat: add streaming/incremental processing for fragmented MP4 files. Add Copy method for StypBox and FtypBox. Implements support for processing fragmented MP4 files incrementally without loading entire files into memory. This is essential for handling large media files, live streaming scenarios, and network streams where fragments arrive over time. Key components: 1. StreamFile API (stream.go) - DecodeStream/InitDecodeStream: Reads init segment (ftyp, moov) and stops - ProcessFragments: Iteratively processes fragments with callbacks - Sliding window memory management with configurable retention - Fragment lifecycle callbacks for processing and cleanup control 2. BoxSeekReader (boxseekreader.go) - Emulates seeking on non-seekable streams using buffered reading - Auto-growing buffer sized for current needs - Peek operations for looking ahead without consuming data 3. 
SampleAccessor interface - GetSample: Fetch individual samples (1-based indexing) - GetSampleRange: Fetch ranges of samples efficiently - GetSamples: Fetch all samples for a track Buffer management: - ReadFullBox returns a slice view of the buffer (no copy for performance) - Callers must use data immediately before next operation - ResetBuffer() explicitly clears buffer after parsing is complete - Buffer clearing happens after box parsing, not during ReadFullBox - For mdat seeks, ResetBuffer() called after seek to avoid buffer corruption Testing: - Comprehensive test suite using external black-box testing (mp4_test package) - Tests cover streaming, callbacks, sliding windows, retention, sample access - All tests use public API only per project guidelines --- CHANGELOG.md | 10 + Makefile | 1 + examples/stream-encrypt/README.md | 128 +++++ examples/stream-encrypt/encryptor.go | 89 +++ examples/stream-encrypt/main.go | 173 ++++++ examples/stream-encrypt/main_test.go | 774 ++++++++++++++++++++++++++ examples/stream-encrypt/refragment.go | 137 +++++ mp4/boxseekreader.go | 357 ++++++++++++ mp4/boxseekreader_test.go | 760 +++++++++++++++++++++++++ mp4/ftyp.go | 7 + mp4/stream.go | 628 +++++++++++++++++++++ mp4/stream_test.go | 489 ++++++++++++++++ mp4/styp.go | 7 + 13 files changed, 3560 insertions(+) create mode 100644 examples/stream-encrypt/README.md create mode 100644 examples/stream-encrypt/encryptor.go create mode 100644 examples/stream-encrypt/main.go create mode 100644 examples/stream-encrypt/main_test.go create mode 100644 examples/stream-encrypt/refragment.go create mode 100644 mp4/boxseekreader.go create mode 100644 mp4/boxseekreader_test.go create mode 100644 mp4/stream.go create mode 100644 mp4/stream_test.go diff --git a/CHANGELOG.md b/CHANGELOG.md index 632381aa..cec8d399 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,9 +7,19 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Added + +- Copy methods for FtypBox and StypBox +- StreamFile for decoding a stream of fragmented mp4 +- BoxSeekReader to make an io.Reader available for lazy mdat processing +- examples/stream-encrypt showing how to read and process a multi-segment file + - On an HTTP request, a file is read, optionally further fragmented, and then encrypted + + ### Fixes - Proper handling of trailing bytes in avc1 (VisualSampleEntryBox). Issue 444 +- Proper removal of boxes when decrypting PR 464 ## [0.50.0] - 2025-09-05 diff --git a/Makefile b/Makefile index 81118e19..3223b650 100644 --- a/Makefile +++ b/Makefile @@ -8,6 +8,7 @@ build: mp4ff-crop mp4ff-decrypt mp4ff-encrypt mp4ff-info mp4ff-nallister mp4ff-p prepare: go mod tidy +.PHONY: mp4ff-crop mp4ff-decrypt mp4ff-encrypt mp4ff-info mp4ff-nallister mp4ff-pslister mp4ff-subslister mp4ff-crop mp4ff-decrypt mp4ff-encrypt mp4ff-info mp4ff-nallister mp4ff-pslister mp4ff-subslister: go build -ldflags "-X github.com/Eyevinn/mp4ff/mp4.commitVersion=$$(git describe --tags HEAD) -X github.com/Eyevinn/mp4ff/mp4.commitDate=$$(git log -1 --format=%ct)" -o out/$@ ./cmd/$@/main.go diff --git a/examples/stream-encrypt/README.md b/examples/stream-encrypt/README.md new file mode 100644 index 00000000..d3ea082d --- /dev/null +++ b/examples/stream-encrypt/README.md @@ -0,0 +1,128 @@ +# stream-encrypt + +HTTP streaming server that encrypts and refragments MP4 files on-the-fly using the `mp4.StreamFile` API. 
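+
+For orientation, the core of the handler (see `main.go`) boils down to the sketch below. The helper name `streamCopy` and the bare pass-through callback are illustrative only; the real handler adds refragmentation and encryption inside the callback before writing.
+
+```go
+package main
+
+import (
+	"io"
+
+	"github.com/Eyevinn/mp4ff/mp4"
+)
+
+// streamCopy is an illustrative sketch: it reads a fragmented MP4 from r and
+// writes the init segment followed by every fragment to w, using the
+// streaming API introduced in this PR.
+func streamCopy(r io.Reader, w io.Writer) error {
+	sf, err := mp4.InitDecodeStream(r,
+		mp4.WithFragmentCallback(func(frag *mp4.Fragment, sa mp4.SampleAccessor) error {
+			// Refragmentation and encryption would hook in here before writing.
+			return frag.Encode(w)
+		}))
+	if err != nil {
+		return err
+	}
+	// Write the init segment first, then drive the callback for every fragment.
+	if sf.Init != nil {
+		if err := sf.Init.Encode(w); err != nil {
+			return err
+		}
+	}
+	return sf.ProcessFragments()
+}
+```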
+ +## Features + +- **HTTP Streaming**: Serves MP4 files via HTTP with chunked transfer encoding +- **Refragmentation**: Splits input fragments into smaller output fragments with configurable sample count +- **Encryption**: Encrypts fragments using Common Encryption (CENC or CBCS) +- **Low Latency**: Uses `GetSampleRange()` for minimal buffering and immediate delivery +- **Sequence Number Preservation**: Sub-fragments maintain the same sequence number as their parent + +## Usage + +### Basic Streaming (No Encryption, No Refragmentation) + +```bash +go run *.go +curl http://localhost:8080/enc.mp4 -o output.mp4 +``` + +### Using a Custom Input File + +```bash +go run *.go -input /path/to/your/video.mp4 +curl http://localhost:8080/enc.mp4 -o output.mp4 +``` + +### Refragmentation + +Split fragments to 30 samples each: + +```bash +go run *.go -samples 30 +curl http://localhost:8080/enc.mp4 -o refragmented.mp4 +``` + +### Encryption with Refragmentation + +```bash +go run *.go \ + -samples 30 \ + -key 11223344556677889900aabbccddeeff \ + -keyid 00112233445566778899aabbccddeeff \ + -iv 00000000000000000000000000000000 \ + -scheme cenc + +curl http://localhost:8080/enc.mp4 -o encrypted.mp4 +``` + +### Command-Line Options + +``` + -input string + Input MP4 file path (default "../../mp4/testdata/v300_multiple_segments.mp4") + -port int + HTTP server port (default 8080) + -samples int + Samples per fragment (0=no refrag) (default 0) + -key string + Encryption key (hex) + -keyid string + Key ID (hex) + -iv string + IV (hex) + -scheme string + Encryption scheme (cenc/cbcs) (default "cenc") +``` + +## How It Works + +### Streaming Pipeline + +``` +Input File → StreamFile.InitDecodeStream() → +[Refragment?] → [Encrypt?] → HTTP Response (Chunked) → Client +``` + +1. **Init Segment**: Read and optionally modify for encryption, write immediately +2. **Fragment Processing**: For each input fragment: + - Use `GetSampleRange()` to fetch only needed samples + - Create sub-fragments if refragmentation enabled + - Encrypt if encryption configured + - Write and flush immediately + +### Refragmentation Strategy + +- **Input**: Fragment with 60 samples +- **Output** (samplesPerFrag=30): Two fragments with 30 samples each +- **Sequence Numbers**: Both sub-fragments keep the parent's sequence number +- **Benefits**: Lower latency, smaller chunk sizes for adaptive streaming + +### Encryption + +- **Schemes**: CENC (AES-CTR) or CBCS (AES-CBC) +- **IV Derivation**: Incremental IV per fragment based on fragment number +- **Metadata**: Adds `senc`, `saiz`, `saio` boxes to fragments +- **Init Modification**: Converts sample entries to `encv`/`enca`, adds `sinf` structure + +## Testing + +Run all tests: + +```bash +go test -v +``` + +Individual test steps: + +```bash +go test -v -run TestStep1 # Basic streaming +go test -v -run TestStep2 # Refragmentation +go test -v -run TestStep3 # Encryption +``` + +## Implementation Files + +- **main.go**: HTTP server and request handling +- **refragment.go**: Fragment splitting using `GetSampleRange()` +- **encryptor.go**: Encryption setup and per-fragment encryption +- **main_test.go**: Integration tests for all features + +## Design Principles + +1. **Streaming-First**: Never buffer entire fragments, use `GetSampleRange()` instead of `GetSamples()` +2. **Memory Efficient**: Process and deliver each sub-fragment immediately +3. **Sequence Number Preservation**: All sub-fragments from same input share sequence number +4. 
**Reuse Existing Code**: Leverages `mp4.InitProtect()` and `mp4.EncryptFragment()` diff --git a/examples/stream-encrypt/encryptor.go b/examples/stream-encrypt/encryptor.go new file mode 100644 index 00000000..128fc782 --- /dev/null +++ b/examples/stream-encrypt/encryptor.go @@ -0,0 +1,89 @@ +package main + +import ( + "encoding/hex" + "fmt" + + "github.com/Eyevinn/mp4ff/mp4" +) + +type EncryptConfig struct { + Key []byte + KeyID []byte + IV []byte + Scheme string +} + +type StreamEncryptor struct { + config EncryptConfig + ipd *mp4.InitProtectData + fragNum uint32 + encryptedInit *mp4.InitSegment +} + +func NewStreamEncryptor(init *mp4.InitSegment, config EncryptConfig) (*StreamEncryptor, error) { + if len(config.IV) != 16 && len(config.IV) != 8 { + return nil, fmt.Errorf("IV must be 8 or 16 bytes") + } + if len(config.Key) != 16 { + return nil, fmt.Errorf("key must be 16 bytes") + } + if len(config.KeyID) != 16 { + return nil, fmt.Errorf("keyID must be 16 bytes") + } + + kidHex := hex.EncodeToString(config.KeyID) + kidUUID, err := mp4.NewUUIDFromString(kidHex) + if err != nil { + return nil, fmt.Errorf("invalid key ID: %w", err) + } + + ipd, err := mp4.InitProtect(init, config.Key, config.IV, config.Scheme, kidUUID, nil) + if err != nil { + return nil, fmt.Errorf("init protect: %w", err) + } + + return &StreamEncryptor{ + config: config, + ipd: ipd, + encryptedInit: init, + }, nil +} + +func (se *StreamEncryptor) GetEncryptedInit() *mp4.InitSegment { + return se.encryptedInit +} + +func (se *StreamEncryptor) EncryptFragment(frag *mp4.Fragment) error { + se.fragNum++ + + iv := se.deriveIV(se.fragNum) + + err := mp4.EncryptFragment(frag, se.config.Key, iv, se.ipd) + if err != nil { + return fmt.Errorf("encrypt fragment %d: %w", se.fragNum, err) + } + + return nil +} + +func (se *StreamEncryptor) deriveIV(fragNum uint32) []byte { + iv := make([]byte, 16) + copy(iv, se.config.IV) + + for i := 15; i >= 12; i-- { + carry := fragNum & 0xFF + sum := uint32(iv[i]) + carry + iv[i] = byte(sum & 0xFF) + fragNum = fragNum >> 8 + if fragNum == 0 { + break + } + } + + return iv +} + +func ParseHexKey(s string) ([]byte, error) { + return hex.DecodeString(s) +} diff --git a/examples/stream-encrypt/main.go b/examples/stream-encrypt/main.go new file mode 100644 index 00000000..f56f65d0 --- /dev/null +++ b/examples/stream-encrypt/main.go @@ -0,0 +1,173 @@ +package main + +import ( + "errors" + "flag" + "fmt" + "log" + "net/http" + "os" + + "github.com/Eyevinn/mp4ff/mp4" +) + +const ( + appName = "stream-encrypt" +) + +var usg = `%s is an HTTP streaming server that encrypts and refragments MP4 files on-the-fly. + +It serves the specified input MP4 file at /enc.mp4 with optional encryption and refragmentation. 
+ +Usage of %s: +` + +type options struct { + port int + samplesPerFrag int + key string + keyID string + iv string + scheme string + inputFile string +} + +func parseOptions(fs *flag.FlagSet, args []string) (*options, error) { + fs.Usage = func() { + fmt.Fprintf(os.Stderr, usg, appName, appName) + fmt.Fprintf(os.Stderr, "\n%s [options]\n\noptions:\n", appName) + fs.PrintDefaults() + } + + opts := options{} + + fs.IntVar(&opts.port, "port", 8080, "HTTP server port") + fs.IntVar(&opts.samplesPerFrag, "samples", 0, "Samples per fragment (0=no refrag)") + fs.StringVar(&opts.key, "key", "", "Encryption key (hex)") + fs.StringVar(&opts.keyID, "keyid", "", "Key ID (hex)") + fs.StringVar(&opts.iv, "iv", "", "IV (hex)") + fs.StringVar(&opts.scheme, "scheme", "cenc", "Encryption scheme (cenc/cbcs)") + fs.StringVar(&opts.inputFile, "input", "../../mp4/testdata/v300_multiple_segments.mp4", "Input MP4 file path") + + err := fs.Parse(args[1:]) + return &opts, err +} + +func makeStreamHandler(opts options) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + f, err := os.Open(opts.inputFile) + if err != nil { + http.Error(w, fmt.Sprintf("Failed to open input file: %v", err), http.StatusInternalServerError) + return + } + defer f.Close() + + w.Header().Set("Content-Type", "video/mp4") + w.Header().Set("Transfer-Encoding", "chunked") + + config := RefragmentConfig{ + SamplesPerFrag: uint32(opts.samplesPerFrag), + } + + var encryptor *StreamEncryptor + + sf, err := mp4.InitDecodeStream(f, + mp4.WithFragmentCallback(func(frag *mp4.Fragment, sa mp4.SampleAccessor) error { + return processFragment(frag, sa, config, func(outFrag *mp4.Fragment) error { + if encryptor != nil { + if err := encryptor.EncryptFragment(outFrag); err != nil { + return err + } + } + + if err := outFrag.Encode(w); err != nil { + return err + } + if flusher, ok := w.(http.Flusher); ok { + flusher.Flush() + } + return nil + }) + })) + + if err != nil { + log.Printf("InitDecodeStream failed: %v", err) + return + } + + if opts.key != "" && opts.keyID != "" && opts.iv != "" { + keyBytes, err := ParseHexKey(opts.key) + if err != nil { + http.Error(w, fmt.Sprintf("Invalid key: %v", err), http.StatusBadRequest) + return + } + keyIDBytes, err := ParseHexKey(opts.keyID) + if err != nil { + http.Error(w, fmt.Sprintf("Invalid keyID: %v", err), http.StatusBadRequest) + return + } + ivBytes, err := ParseHexKey(opts.iv) + if err != nil { + http.Error(w, fmt.Sprintf("Invalid IV: %v", err), http.StatusBadRequest) + return + } + + encConfig := EncryptConfig{ + Key: keyBytes, + KeyID: keyIDBytes, + IV: ivBytes, + Scheme: opts.scheme, + } + + encryptor, err = NewStreamEncryptor(sf.Init, encConfig) + if err != nil { + http.Error(w, fmt.Sprintf("Failed to create encryptor: %v", err), http.StatusInternalServerError) + return + } + + sf.Init = encryptor.GetEncryptedInit() + } + + if sf.Init != nil { + if err := sf.Init.Encode(w); err != nil { + log.Printf("Write init failed: %v", err) + return + } + if flusher, ok := w.(http.Flusher); ok { + flusher.Flush() + } + } + + if err := sf.ProcessFragments(); err != nil { + trailingBoxes := &mp4.TrailingBoxesErrror{} + if errors.As(err, &trailingBoxes) { + log.Printf("ProcessFragments warning: %v", err) + } else { + log.Printf("ProcessFragments failed: %v", err) + } + } + } +} + +func main() { + if err := run(os.Args); err != nil { + fmt.Fprintf(os.Stderr, "error: %v\n", err) + os.Exit(1) + } +} + +func run(args []string) error { + fs := flag.NewFlagSet(appName, flag.ContinueOnError) + opts, 
err := parseOptions(fs, args) + if err != nil { + if errors.Is(err, flag.ErrHelp) { + return nil + } + return err + } + + http.HandleFunc("/enc.mp4", makeStreamHandler(*opts)) + addr := fmt.Sprintf(":%d", opts.port) + log.Printf("Server starting on %s, serving %s at /enc.mp4", addr, opts.inputFile) + return http.ListenAndServe(addr, nil) +} diff --git a/examples/stream-encrypt/main_test.go b/examples/stream-encrypt/main_test.go new file mode 100644 index 00000000..30938118 --- /dev/null +++ b/examples/stream-encrypt/main_test.go @@ -0,0 +1,774 @@ +package main + +import ( + "bytes" + "errors" + "flag" + "io" + "net/http" + "net/http/httptest" + "os" + "testing" + + "github.com/Eyevinn/mp4ff/mp4" +) + +func TestStep1_BasicHTTPStreaming(t *testing.T) { + inputFile := "../../mp4/testdata/v300_multiple_segments.mp4" + opts := options{inputFile: inputFile} + server := httptest.NewServer(makeStreamHandler(opts)) + defer server.Close() + + resp, err := http.Get(server.URL + "/enc.mp4") + if err != nil { + t.Fatalf("Failed to GET: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("Expected status 200, got %d", resp.StatusCode) + } + + if ct := resp.Header.Get("Content-Type"); ct != "video/mp4" { + t.Errorf("Expected Content-Type video/mp4, got %s", ct) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("Failed to read body: %v", err) + } + + if len(body) == 0 { + t.Fatal("Response body is empty") + } + + parsedFile, err := mp4.DecodeFile(bytes.NewReader(body)) + if err != nil { + t.Fatalf("Failed to parse MP4: %v", err) + } + + if parsedFile.Init == nil { + t.Error("Init segment is nil") + } + + if len(parsedFile.Segments) == 0 { + t.Error("No segments found") + } + + originalData, err := os.ReadFile(inputFile) + if err != nil { + t.Fatalf("Failed to read original file: %v", err) + } + + originalFile, err := mp4.DecodeFile(bytes.NewReader(originalData)) + if err != nil { + t.Fatalf("Failed to parse original: %v", err) + } + + originalFragCount := 0 + for _, seg := range originalFile.Segments { + originalFragCount += len(seg.Fragments) + } + + outputFragCount := 0 + for _, seg := range parsedFile.Segments { + outputFragCount += len(seg.Fragments) + } + + if outputFragCount != originalFragCount { + t.Errorf("Fragment count mismatch: got %d, expected %d", outputFragCount, originalFragCount) + } + + t.Logf("Successfully streamed MP4 with %d fragments", outputFragCount) +} + +func TestStep2_Refragmentation(t *testing.T) { + inputFile := "../../mp4/testdata/v300_multiple_segments.mp4" + opts := options{ + inputFile: inputFile, + samplesPerFrag: 30, + } + server := httptest.NewServer(makeStreamHandler(opts)) + defer server.Close() + + resp, err := http.Get(server.URL + "/enc.mp4") + if err != nil { + t.Fatalf("Failed to GET: %v", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("Failed to read body: %v", err) + } + + parsedFile, err := mp4.DecodeFile(bytes.NewReader(body)) + if err != nil { + t.Fatalf("Failed to parse MP4: %v", err) + } + + originalData, err := os.ReadFile(inputFile) + if err != nil { + t.Fatalf("Failed to read original file: %v", err) + } + + originalFile, err := mp4.DecodeFile(bytes.NewReader(originalData)) + if err != nil { + t.Fatalf("Failed to parse original: %v", err) + } + + originalFragCount := 0 + for _, seg := range originalFile.Segments { + originalFragCount += len(seg.Fragments) + } + + outputFragCount := 0 + for _, seg := range parsedFile.Segments { + 
outputFragCount += len(seg.Fragments) + } + + if outputFragCount <= originalFragCount { + t.Errorf("Expected more fragments after refragmentation: got %d, original %d", outputFragCount, originalFragCount) + } + + for _, seg := range parsedFile.Segments { + for _, frag := range seg.Fragments { + for _, traf := range frag.Moof.Trafs { + sampleCount := traf.Trun.SampleCount() + if sampleCount > 30 { + t.Errorf("Fragment has %d samples, expected <= 30", sampleCount) + } + } + } + } + + for _, seg := range parsedFile.Segments { + for i, frag := range seg.Fragments { + if i > 0 { + prevFrag := seg.Fragments[i-1] + if frag.Moof.Mfhd.SequenceNumber != prevFrag.Moof.Mfhd.SequenceNumber { + prevSamples := prevFrag.Moof.Traf.Trun.SampleCount() + if prevSamples <= 30 { + t.Logf("Sequence number changed from %d to %d (previous frag had %d samples)", + prevFrag.Moof.Mfhd.SequenceNumber, frag.Moof.Mfhd.SequenceNumber, prevSamples) + } + } + } + } + } + + t.Logf("Successfully refragmented: %d → %d fragments, max %d samples per fragment", + originalFragCount, outputFragCount, opts.samplesPerFrag) +} + +func TestStep3_Encryption(t *testing.T) { + inputFile := "../../mp4/testdata/v300_multiple_segments.mp4" + opts := options{ + inputFile: inputFile, + samplesPerFrag: 30, + key: "11223344556677889900aabbccddeeff", + keyID: "00112233445566778899aabbccddeeff", + iv: "00000000000000000000000000000000", + scheme: "cenc", + } + + originalData, err := os.ReadFile(inputFile) + if err != nil { + t.Fatalf("Failed to read original file: %v", err) + } + + originalFile, err := mp4.DecodeFile(bytes.NewReader(originalData)) + if err != nil { + t.Fatalf("Failed to parse original: %v", err) + } + + server := httptest.NewServer(makeStreamHandler(opts)) + defer server.Close() + + resp, err := http.Get(server.URL + "/enc.mp4") + if err != nil { + t.Fatalf("Failed to GET: %v", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("Failed to read body: %v", err) + } + + parsedFile, err := mp4.DecodeFile(bytes.NewReader(body)) + if err != nil { + t.Fatalf("Failed to parse MP4: %v", err) + } + + if parsedFile.Init == nil { + t.Fatal("No init segment") + } + + hasPssh := false + for _, child := range parsedFile.Init.Moov.Children { + if child.Type() == "pssh" { + hasPssh = true + break + } + } + + stsd := parsedFile.Init.Moov.Trak.Mdia.Minf.Stbl.Stsd + if len(stsd.Children) == 0 { + t.Fatal("No sample entries in stsd") + } + + sampleEntry := stsd.Children[0] + isEncrypted := sampleEntry.Type() == "encv" || sampleEntry.Type() == "enca" + if !isEncrypted { + t.Errorf("Sample entry not encrypted: %s", sampleEntry.Type()) + } + + for _, seg := range parsedFile.Segments { + for _, frag := range seg.Fragments { + hasSenc := false + for _, child := range frag.Moof.Traf.Children { + if child.Type() == "senc" { + hasSenc = true + break + } + } + if !hasSenc { + t.Error("Fragment missing senc box") + } + } + } + + t.Logf("Successfully encrypted: hasPssh=%v, sampleEntry=%s, fragments=%d", + hasPssh, sampleEntry.Type(), len(parsedFile.Segments[0].Fragments)) + + keyBytes, err := ParseHexKey(opts.key) + if err != nil { + t.Fatalf("Failed to parse key: %v", err) + } + + decInfo, err := mp4.DecryptInit(parsedFile.Init) + if err != nil { + t.Fatalf("Failed to get decrypt info: %v", err) + } + + for _, seg := range parsedFile.Segments { + err := mp4.DecryptSegment(seg, decInfo, keyBytes) + if err != nil { + t.Fatalf("Failed to decrypt segment: %v", err) + } + } + + var origAllSamples [][]byte + var 
decAllSamples [][]byte + + trackID := uint32(0) + if len(originalFile.Segments) > 0 && len(originalFile.Segments[0].Fragments) > 0 { + trackID = originalFile.Segments[0].Fragments[0].Moof.Traf.Tfhd.TrackID + } + + origTrex, ok := originalFile.Init.Moov.Mvex.GetTrex(trackID) + if !ok { + t.Fatalf("Failed to get trex for original track %d", trackID) + } + + for segIdx, origSeg := range originalFile.Segments { + for fragIdx, origFrag := range origSeg.Fragments { + origSamples, err := origFrag.GetFullSamples(origTrex) + if err != nil { + t.Fatalf("Failed to get original samples from segment %d fragment %d: %v", + segIdx, fragIdx, err) + } + for _, sample := range origSamples { + origAllSamples = append(origAllSamples, sample.Data) + } + } + } + + decTrex, ok := parsedFile.Init.Moov.Mvex.GetTrex(trackID) + if !ok { + t.Fatalf("Failed to get trex for decrypted track %d", trackID) + } + + for segIdx, decSeg := range parsedFile.Segments { + for fragIdx, decFrag := range decSeg.Fragments { + decSamples, err := decFrag.GetFullSamples(decTrex) + if err != nil { + t.Fatalf("Failed to get decrypted samples from segment %d fragment %d: %v", + segIdx, fragIdx, err) + } + for _, sample := range decSamples { + decAllSamples = append(decAllSamples, sample.Data) + } + } + } + + if len(origAllSamples) != len(decAllSamples) { + t.Fatalf("Total sample count mismatch: original=%d, decrypted=%d", + len(origAllSamples), len(decAllSamples)) + } + + for i := range origAllSamples { + if !bytes.Equal(origAllSamples[i], decAllSamples[i]) { + t.Errorf("Sample data mismatch at sample %d: size original=%d, decrypted=%d", + i, len(origAllSamples[i]), len(decAllSamples[i])) + } + } + + t.Logf("Successfully decrypted and verified all samples match original") +} + +func TestFtypStypPreservation(t *testing.T) { + inputFile := "../../mp4/testdata/v300_multiple_segments.mp4" + opts := options{ + inputFile: inputFile, + samplesPerFrag: 30, + key: "11223344556677889900aabbccddeeff", + keyID: "00112233445566778899aabbccddeeff", + iv: "00000000000000000000000000000000", + scheme: "cenc", + } + + originalData, err := os.ReadFile(inputFile) + if err != nil { + t.Fatalf("Failed to read original file: %v", err) + } + + originalFile, err := mp4.DecodeFile(bytes.NewReader(originalData)) + if err != nil { + t.Fatalf("Failed to parse original: %v", err) + } + + server := httptest.NewServer(makeStreamHandler(opts)) + defer server.Close() + + resp, err := http.Get(server.URL + "/enc.mp4") + if err != nil { + t.Fatalf("Failed to GET: %v", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("Failed to read body: %v", err) + } + + outputFile, err := mp4.DecodeFile(bytes.NewReader(body)) + if err != nil { + t.Fatalf("Failed to parse output: %v", err) + } + + if outputFile.Ftyp.MajorBrand() != originalFile.Ftyp.MajorBrand() { + t.Errorf("Ftyp major brand mismatch: got %s, expected %s", + outputFile.Ftyp.MajorBrand(), originalFile.Ftyp.MajorBrand()) + } + + if outputFile.Ftyp.MinorVersion() != originalFile.Ftyp.MinorVersion() { + t.Errorf("Ftyp minor version mismatch: got %d, expected %d", + outputFile.Ftyp.MinorVersion(), originalFile.Ftyp.MinorVersion()) + } + + if len(outputFile.Ftyp.CompatibleBrands()) != len(originalFile.Ftyp.CompatibleBrands()) { + t.Errorf("Ftyp compatible brands count mismatch: got %d, expected %d", + len(outputFile.Ftyp.CompatibleBrands()), len(originalFile.Ftyp.CompatibleBrands())) + } + + originalStypCount := len(originalFile.Segments) + outputStypCount := 0 + for _, seg 
:= range outputFile.Segments { + if seg.Styp != nil { + outputStypCount++ + } + } + + if outputStypCount != originalStypCount { + t.Errorf("Styp count mismatch: got %d, expected %d", outputStypCount, originalStypCount) + } + + if len(outputFile.Segments) > 0 && len(outputFile.Segments[0].Fragments) > 0 { + outputStyp := outputFile.Segments[0].Styp + originalStyp := originalFile.Segments[0].Styp + + if outputStyp == nil || originalStyp == nil { + t.Fatal("Missing styp box") + } + + if outputStyp.MajorBrand() != originalStyp.MajorBrand() { + t.Errorf("Styp major brand mismatch: got %s, expected %s", + outputStyp.MajorBrand(), originalStyp.MajorBrand()) + } + + if outputStyp.MinorVersion() != originalStyp.MinorVersion() { + t.Errorf("Styp minor version mismatch: got %d, expected %d", + outputStyp.MinorVersion(), originalStyp.MinorVersion()) + } + } + + t.Logf("Ftyp and Styp preserved correctly (styp count: %d)", outputStypCount) +} + +func TestEncryptorErrors(t *testing.T) { + inputFile := "../../mp4/testdata/v300_multiple_segments.mp4" + originalData, err := os.ReadFile(inputFile) + if err != nil { + t.Fatalf("Failed to read test file: %v", err) + } + + originalFile, err := mp4.DecodeFile(bytes.NewReader(originalData)) + if err != nil { + t.Fatalf("Failed to parse test file: %v", err) + } + + validKey := []byte{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0x00, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff} + validKeyID := []byte{0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff} + validIV := []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00} + + tests := []struct { + name string + key []byte + keyID []byte + iv []byte + scheme string + expectErr string + }{ + { + name: "invalid IV length (too short)", + key: validKey, + keyID: validKeyID, + iv: []byte{0x00, 0x00, 0x00}, + scheme: "cenc", + expectErr: "IV must be 8 or 16 bytes", + }, + { + name: "invalid key length", + key: []byte{0x11, 0x22}, + keyID: validKeyID, + iv: validIV, + scheme: "cenc", + expectErr: "key must be 16 bytes", + }, + { + name: "invalid keyID length", + key: validKey, + keyID: []byte{0x00, 0x11}, + iv: validIV, + scheme: "cenc", + expectErr: "keyID must be 16 bytes", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + config := EncryptConfig{ + Key: tc.key, + KeyID: tc.keyID, + IV: tc.iv, + Scheme: tc.scheme, + } + + _, err := NewStreamEncryptor(originalFile.Init, config) + if err == nil { + t.Fatalf("Expected error, got nil") + } + if tc.expectErr != "" && !bytes.Contains([]byte(err.Error()), []byte(tc.expectErr)) { + t.Errorf("Expected error containing %q, got %q", tc.expectErr, err.Error()) + } + }) + } +} + +func TestInvalidHexKeys(t *testing.T) { + inputFile := "../../mp4/testdata/v300_multiple_segments.mp4" + tests := []struct { + name string + opts options + expect string + }{ + { + name: "invalid key hex", + opts: options{ + inputFile: inputFile, + key: "invalid-hex", + keyID: "00112233445566778899aabbccddeeff", + iv: "00000000000000000000000000000000", + }, + expect: "Invalid key", + }, + { + name: "invalid keyID hex", + opts: options{ + inputFile: inputFile, + key: "11223344556677889900aabbccddeeff", + keyID: "zzz", + iv: "00000000000000000000000000000000", + }, + expect: "Invalid keyID", + }, + { + name: "invalid IV hex", + opts: options{ + inputFile: inputFile, + key: "11223344556677889900aabbccddeeff", + keyID: "00112233445566778899aabbccddeeff", + iv: "notvalidhex", + }, + expect: 
"Invalid IV", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + server := httptest.NewServer(makeStreamHandler(tc.opts)) + defer server.Close() + + resp, err := http.Get(server.URL + "/enc.mp4") + if err != nil { + t.Fatalf("Failed to GET: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusBadRequest { + t.Errorf("Expected status 400, got %d", resp.StatusCode) + } + + body, _ := io.ReadAll(resp.Body) + if !bytes.Contains(body, []byte(tc.expect)) { + t.Errorf("Expected error message containing %q, got %q", tc.expect, string(body)) + } + }) + } +} + +func TestEncryptorWithInvalidKeyID(t *testing.T) { + inputFile := "../../mp4/testdata/v300_multiple_segments.mp4" + originalData, err := os.ReadFile(inputFile) + if err != nil { + t.Fatalf("Failed to read test file: %v", err) + } + + originalFile, err := mp4.DecodeFile(bytes.NewReader(originalData)) + if err != nil { + t.Fatalf("Failed to parse test file: %v", err) + } + + config := EncryptConfig{ + Key: []byte{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0x00, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff}, + KeyID: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, + IV: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + Scheme: "cenc", + } + + _, err = NewStreamEncryptor(originalFile.Init, config) + if err != nil { + t.Logf("NewStreamEncryptor returned expected error: %v", err) + } +} + +func TestParseOptions(t *testing.T) { + tests := []struct { + name string + args []string + expectError bool + }{ + { + name: "default options", + args: []string{"stream-encrypt"}, + expectError: false, + }, + { + name: "with port", + args: []string{"stream-encrypt", "-port", "9090"}, + expectError: false, + }, + { + name: "with all encryption options", + args: []string{"stream-encrypt", "-key", "abc", "-keyid", "def", "-iv", "ghi", "-scheme", "cbcs"}, + expectError: false, + }, + { + name: "with samples", + args: []string{"stream-encrypt", "-samples", "60"}, + expectError: false, + }, + { + name: "invalid flag", + args: []string{"stream-encrypt", "-invalid"}, + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + fs := flag.NewFlagSet("test", flag.ContinueOnError) + opts, err := parseOptions(fs, tc.args) + + if tc.expectError { + if err == nil { + t.Errorf("Expected error, got nil") + } + } else { + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if opts == nil { + t.Errorf("Expected options, got nil") + } + } + }) + } +} + +func TestRunWithInvalidFlags(t *testing.T) { + tests := []struct { + name string + args []string + expectError bool + }{ + { + name: "invalid flag", + args: []string{"stream-encrypt", "-input", "a.mp4", "-nonexistent"}, + expectError: true, + }, + { + name: "help flag", + args: []string{"stream-encrypt", "-h"}, + expectError: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + err := run(tc.args) + + if tc.expectError { + if err == nil { + t.Errorf("Expected error, got nil") + } + } else { + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + } + }) + } +} + +func TestTrailingBoxHandling(t *testing.T) { + // Read the original test file + originalFile := "../../mp4/testdata/v300_multiple_segments.mp4" + originalData, err := os.ReadFile(originalFile) + if err != nil { + t.Fatalf("Failed to read original file: %v", err) + } + + // Create a temp file with the original content 
plus a trailing skip box + tempFile, err := os.CreateTemp("", "test_trailing_*.mp4") + if err != nil { + t.Fatalf("Failed to create temp file: %v", err) + } + defer os.Remove(tempFile.Name()) + defer tempFile.Close() + + // Write original content + _, err = tempFile.Write(originalData) + if err != nil { + t.Fatalf("Failed to write original data: %v", err) + } + + // Create and append a skip box + // Skip box format: size (4 bytes) + type (4 bytes) + data + skipBox := mp4.NewSkipBox([]byte("trailing data")) + buf := &bytes.Buffer{} + err = skipBox.Encode(buf) + if err != nil { + t.Fatalf("Failed to encode skip box: %v", err) + } + + _, err = tempFile.Write(buf.Bytes()) + if err != nil { + t.Fatalf("Failed to write skip box: %v", err) + } + + tempFile.Close() + + // Test with the file containing trailing box directly + testData, err := os.ReadFile(tempFile.Name()) + if err != nil { + t.Fatalf("Failed to read temp file: %v", err) + } + + // First, let's verify the trailing box detection works with InitDecodeStream + reader := bytes.NewReader(testData) + var detectedTrailingBox bool + sf, err := mp4.InitDecodeStream(reader, + mp4.WithFragmentCallback(func(frag *mp4.Fragment, sa mp4.SampleAccessor) error { + // Just process the fragment + return nil + })) + + if err != nil { + t.Fatalf("InitDecodeStream failed: %v", err) + } + + // Process fragments and check for trailing boxes error + err = sf.ProcessFragments() + if err != nil { + trailingBoxes := &mp4.TrailingBoxesErrror{} + if errors.As(err, &trailingBoxes) { + detectedTrailingBox = true + t.Logf("Detected trailing boxes: %v", trailingBoxes.BoxNames) + + // Verify that we have exactly one box and it's "skip" + if len(trailingBoxes.BoxNames) != 1 { + t.Errorf("Expected exactly 1 trailing box, got %d: %v", len(trailingBoxes.BoxNames), trailingBoxes.BoxNames) + } + + if len(trailingBoxes.BoxNames) > 0 && trailingBoxes.BoxNames[0] != "skip" { + t.Errorf("Expected trailing box to be 'skip', got: %s", trailingBoxes.BoxNames[0]) + } + } else { + t.Fatalf("ProcessFragments failed with unexpected error: %v", err) + } + } + + if !detectedTrailingBox { + t.Error("Expected TrailingBoxesErrror but didn't get one") + } + + // Now test with the actual HTTP handler + opts := options{ + inputFile: tempFile.Name(), + } + + server := httptest.NewServer(makeStreamHandler(opts)) + defer server.Close() + + // Make the request + resp, err := http.Get(server.URL + "/enc.mp4") + if err != nil { + t.Fatalf("Failed to GET: %v", err) + } + defer resp.Body.Close() + + // Read the response body + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("Failed to read body: %v", err) + } + + // Verify we got valid output despite the trailing box + parsedFile, err := mp4.DecodeFile(bytes.NewReader(body)) + if err != nil { + t.Fatalf("Failed to parse output MP4: %v", err) + } + + if parsedFile.Init == nil { + t.Error("Output file has no init segment") + } + + if len(parsedFile.Segments) == 0 { + t.Error("Output file has no segments") + } + + t.Logf("Successfully handled file with trailing skip box - TrailingBoxesErrror detected and handled gracefully") +} diff --git a/examples/stream-encrypt/refragment.go b/examples/stream-encrypt/refragment.go new file mode 100644 index 00000000..2af07eb5 --- /dev/null +++ b/examples/stream-encrypt/refragment.go @@ -0,0 +1,137 @@ +package main + +import ( + "fmt" + + "github.com/Eyevinn/mp4ff/mp4" +) + +type RefragmentConfig struct { + SamplesPerFrag uint32 +} + +func processFragment( + inputFrag *mp4.Fragment, + sa 
mp4.SampleAccessor, + config RefragmentConfig, + writeFunc func(*mp4.Fragment) error, +) error { + trackID := inputFrag.Moof.Traf.Tfhd.TrackID + totalSamples := getTotalSampleCount(inputFrag.Moof.Traf.Trun) + + if config.SamplesPerFrag == 0 || totalSamples <= config.SamplesPerFrag { + samples, err := sa.GetSamples(trackID) + if err != nil { + return err + } + + inputFrag.Mdat.Data = nil + inputFrag.Mdat.SetLazyDataSize(0) + for _, s := range samples { + inputFrag.Mdat.AddSampleData(s.Data) + } + + return writeFunc(inputFrag) + } + + isFirstSubFrag := true + for startNr := uint32(1); startNr <= totalSamples; { + endNr := min(startNr+config.SamplesPerFrag-1, totalSamples) + + samples, err := sa.GetSampleRange(trackID, startNr, endNr) + if err != nil { + return fmt.Errorf("GetSampleRange(%d, %d): %w", startNr, endNr, err) + } + + outFrag, err := createFragmentFromSamples(inputFrag, samples, startNr, endNr, isFirstSubFrag) + if err != nil { + return fmt.Errorf("createFragmentFromSamples: %w", err) + } + + if err := writeFunc(outFrag); err != nil { + return err + } + + startNr = endNr + 1 + isFirstSubFrag = false + } + + return nil +} + +func getTotalSampleCount(trun *mp4.TrunBox) uint32 { + return trun.SampleCount() +} + +func createFragmentFromSamples( + inputFrag *mp4.Fragment, + samples []mp4.FullSample, + _startNr, _endNr uint32, + isFirstSubFrag bool, +) (*mp4.Fragment, error) { + if len(samples) == 0 { + return nil, fmt.Errorf("no samples provided") + } + + newFrag := mp4.NewFragment() + + for _, child := range inputFrag.Children { + switch child.Type() { + case "styp": + if isFirstSubFrag { + newFrag.AddChild(child) + } + case "sidx", "emsg", "prft": + newFrag.AddChild(child) + } + } + + seqNum := inputFrag.Moof.Mfhd.SequenceNumber + tfhd := inputFrag.Moof.Traf.Tfhd + trackID := tfhd.TrackID + + moof := &mp4.MoofBox{} + mfhd := mp4.CreateMfhd(seqNum) + _ = moof.AddChild(mfhd) + + traf := &mp4.TrafBox{} + _ = moof.AddChild(traf) + + newTfhd := mp4.CreateTfhd(trackID) + if tfhd.HasDefaultSampleDuration() { + newTfhd.DefaultSampleDuration = tfhd.DefaultSampleDuration + } + if tfhd.HasDefaultSampleSize() { + newTfhd.DefaultSampleSize = tfhd.DefaultSampleSize + } + if tfhd.HasDefaultSampleFlags() { + newTfhd.DefaultSampleFlags = tfhd.DefaultSampleFlags + } + _ = traf.AddChild(newTfhd) + + tfdt := mp4.CreateTfdt(samples[0].DecodeTime) + _ = traf.AddChild(tfdt) + + trun := mp4.CreateTrun(0) + + mdat := &mp4.MdatBox{} + + for _, fullSample := range samples { + trun.AddSample(fullSample.Sample) + mdat.AddSampleData(fullSample.Data) + } + + _ = traf.AddChild(trun) + + newFrag.AddChild(moof) + newFrag.AddChild(mdat) + + return newFrag, nil +} + +func min(a, b uint32) uint32 { + if a < b { + return a + } + return b +} diff --git a/mp4/boxseekreader.go b/mp4/boxseekreader.go new file mode 100644 index 00000000..5795e6b0 --- /dev/null +++ b/mp4/boxseekreader.go @@ -0,0 +1,357 @@ +package mp4 + +import ( + "encoding/binary" + "fmt" + "io" +) + +// BoxSeekReader wraps an io.Reader and provides limited io.ReadSeeker functionality. +// It maintains a single growing buffer that's reused for: +// 1. Reading entire top-level boxes into memory for parsing with DecodeBoxSR +// 2. Buffering mdat payload data on-demand when samples are accessed +// +// The buffer grows to accommodate the largest box seen and is reused across boxes. 
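+//
+// Typical call pattern for a top-level box (an illustrative sketch; see
+// stream.go for the real call sites, error handling omitted):
+//
+//	hdr, startPos, _ := bsr.PeekBoxHeader()   // header bytes stay in the buffer
+//	data, _ := bsr.ReadFullBox(hdr.Size)      // slice view, only valid until the next operation
+//	box, _ := DecodeBoxSR(startPos, bits.NewFixedSliceReader(data))
+//	bsr.ResetBuffer()                         // release the buffer once parsing is done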
+type BoxSeekReader struct { + reader io.Reader + buffer []byte // Single reusable buffer that grows as needed + bufferPos uint64 // Absolute position of first byte in buffer + currentPos uint64 // Current read position in stream + mdatStart uint64 // Start of current mdat payload (when mdatActive) + mdatSize uint64 // Size of current mdat payload + mdatActive bool // Whether we're within mdat and doing lazy reading +} + +// NewBoxSeekReader creates a BoxSeekReader with initial buffer capacity. +func NewBoxSeekReader(r io.Reader, initialSize int) *BoxSeekReader { + if initialSize <= 0 { + initialSize = 64 * 1024 // Default 64KB, will grow as needed + } + return &BoxSeekReader{ + reader: r, + buffer: make([]byte, 0, initialSize), + bufferPos: 0, + currentPos: 0, + mdatActive: false, + } +} + +// ReadFullBox reads an entire box into the buffer and returns a slice view. +// Should be called after PeekBoxHeader, which already has the header in the buffer. +// Reads the remaining payload and returns the complete box data. +func (bsr *BoxSeekReader) ReadFullBox(boxSize uint64) ([]byte, error) { + if boxSize > uint64(2<<30) { // Sanity check: 2GB limit + return nil, fmt.Errorf("box size %d too large", boxSize) + } + + size := int(boxSize) + headerLen := len(bsr.buffer) // Header bytes already in buffer from PeekBoxHeader + payloadLen := size - headerLen + + if payloadLen < 0 { + return nil, fmt.Errorf("box size %d smaller than header %d", size, headerLen) + } + + // Ensure buffer has enough capacity for full box + if cap(bsr.buffer) < size { + // Need to grow - copy header to new buffer + newBuf := make([]byte, size) + copy(newBuf, bsr.buffer[:headerLen]) + bsr.buffer = newBuf + } else { + // Reuse existing buffer, resize to full box size + bsr.buffer = bsr.buffer[:size] + } + + // Read payload into buffer after header + if payloadLen > 0 { + n, err := io.ReadFull(bsr.reader, bsr.buffer[headerLen:]) + if err != nil { + return nil, err + } + if n != payloadLen { + return nil, fmt.Errorf("read %d payload bytes, expected %d", n, payloadLen) + } + } + + // Update position tracking + bsr.currentPos = bsr.bufferPos + uint64(size) + + // Return slice view of buffer - caller must use before next operation + return bsr.buffer[:size], nil +} + +// SetMdatBounds configures the emulator for lazy reading of an mdat box. +// Sets up the bounds but does NOT read data into buffer yet - data is read on-demand +// when samples are accessed via Read operations. +// mdatPayloadStart is the absolute file position where mdat payload begins. +// mdatPayloadSize is the size of the mdat payload in bytes. +func (bsr *BoxSeekReader) SetMdatBounds(mdatPayloadStart, mdatPayloadSize uint64) { + bsr.mdatStart = mdatPayloadStart + bsr.mdatSize = mdatPayloadSize + bsr.mdatActive = true + + // Clear buffer and reset position to start of mdat payload + // Buffer will be filled on-demand when Read is called + bsr.buffer = bsr.buffer[:0] + bsr.bufferPos = mdatPayloadStart + bsr.currentPos = mdatPayloadStart +} + +// ResetBuffer clears the buffer and mdat state. +// Buffer capacity is preserved for reuse. +func (bsr *BoxSeekReader) ResetBuffer() { + bsr.buffer = bsr.buffer[:0] + bsr.bufferPos = bsr.currentPos + bsr.mdatActive = false + bsr.mdatStart = 0 + bsr.mdatSize = 0 +} + +// Read reads data from the underlying reader, updating the buffer as needed. +// When mdatActive is true, enforces bounds checking to stay within mdat payload. +// Note that n may be less than len(p) if hitting mdat bounds or EOF. 
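+//
+// Together with Seek, this enables lazy reads of sample data from an mdat
+// payload. An illustrative sketch (the offset and size values would come from
+// the trun/tfhd metadata in practice):
+//
+//	bsr.SetMdatBounds(mdatPayloadStart, mdatPayloadSize)
+//	_, _ = bsr.Seek(int64(sampleOffset), io.SeekStart) // absolute position of the sample
+//	buf := make([]byte, sampleSize)
+//	_, _ = io.ReadFull(bsr, buf)                       // read exactly one sample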
+func (bsr *BoxSeekReader) Read(p []byte) (n int, err error) { + // Bounds check if within mdat + if bsr.mdatActive { + if bsr.currentPos < bsr.mdatStart { + return 0, fmt.Errorf("read position %d before mdat start %d", bsr.currentPos, bsr.mdatStart) + } + mdatEnd := bsr.mdatStart + bsr.mdatSize + if bsr.currentPos >= mdatEnd { + return 0, io.EOF + } + // Limit read to mdat bounds + maxRead := mdatEnd - bsr.currentPos + if uint64(len(p)) > maxRead { + p = p[:maxRead] + } + } + + // Check if we can read from buffer + if bsr.currentPos >= bsr.bufferPos && bsr.currentPos < bsr.bufferPos+uint64(len(bsr.buffer)) { + // Read from buffer + offsetInBuffer := int(bsr.currentPos - bsr.bufferPos) + availableInBuffer := len(bsr.buffer) - offsetInBuffer + toCopy := len(p) + if toCopy > availableInBuffer { + toCopy = availableInBuffer + } + copy(p[:toCopy], bsr.buffer[offsetInBuffer:offsetInBuffer+toCopy]) + bsr.currentPos += uint64(toCopy) + + if toCopy == len(p) { + return toCopy, nil + } + + // Need more data from underlying reader + remaining := p[toCopy:] + n2, err := bsr.reader.Read(remaining) + if n2 > 0 { + bsr.buffer = append(bsr.buffer, remaining[:n2]...) + bsr.currentPos += uint64(n2) + } + return toCopy + n2, err + } + + // Read from underlying reader + n, err = bsr.reader.Read(p) + if n > 0 { + bsr.buffer = append(bsr.buffer, p[:n]...) + bsr.currentPos += uint64(n) + } + return n, err +} + +// Seek moves the read position within the current mdat or buffered data. +// When mdatActive is true, seeks are restricted to the mdat payload bounds. +// Only supports limited backward seeks within the buffer. +func (bsr *BoxSeekReader) Seek(offset int64, whence int) (int64, error) { + var newPos int64 + + switch whence { + case io.SeekStart: + newPos = offset + case io.SeekCurrent: + newPos = int64(bsr.currentPos) + offset + case io.SeekEnd: + return 0, fmt.Errorf("seek from end not supported in stream mode") + default: + return 0, fmt.Errorf("invalid whence value: %d", whence) + } + + if newPos < 0 { + return 0, fmt.Errorf("seek to negative position: %d", newPos) + } + + // Bounds check if within mdat + if bsr.mdatActive { + if newPos < int64(bsr.mdatStart) { + return 0, fmt.Errorf("seek position %d before mdat start %d", newPos, bsr.mdatStart) + } + mdatEnd := int64(bsr.mdatStart + bsr.mdatSize) + if newPos > mdatEnd { + return 0, fmt.Errorf("seek position %d beyond mdat end %d", newPos, mdatEnd) + } + } + + // Check if target position is within buffer + bufferStart := int64(bsr.bufferPos) + bufferEnd := int64(bsr.bufferPos) + int64(len(bsr.buffer)) + + if newPos >= bufferStart && newPos <= bufferEnd { + // Seeking within buffer + bsr.currentPos = uint64(newPos) + return newPos, nil + } + + if newPos < bufferStart { + return 0, fmt.Errorf("seek position %d is before buffer start %d (buffer size: %d bytes)", + newPos, bufferStart, len(bsr.buffer)) + } + + // Forward seek beyond buffer - read data directly into buffer + if newPos > int64(bsr.currentPos) { + toRead := newPos - int64(bsr.currentPos) + + // Grow buffer to accommodate the data we need to read + currentLen := len(bsr.buffer) + neededLen := currentLen + int(toRead) + if cap(bsr.buffer) < neededLen { + // Need to grow capacity + newBuf := make([]byte, neededLen) + copy(newBuf, bsr.buffer) + bsr.buffer = newBuf + } else { + // Have enough capacity, just extend length + bsr.buffer = bsr.buffer[:neededLen] + } + + // Read directly into the buffer at the current position using ReadFull + n, err := io.ReadFull(bsr.reader, 
bsr.buffer[currentLen:neededLen]) + if n > 0 { + bsr.currentPos += uint64(n) + } + if err != nil { + // Adjust buffer to actual size read + bsr.buffer = bsr.buffer[:currentLen+n] + return int64(bsr.currentPos), err + } + return newPos, nil + } + + return 0, fmt.Errorf("seek position %d not reachable from current position %d", + newPos, bsr.currentPos) +} + +// GetBufferInfo returns current buffer state for debugging. +func (bsr *BoxSeekReader) GetBufferInfo() (bufferStart uint64, bufferLen int, currentPos uint64) { + return bsr.bufferPos, len(bsr.buffer), bsr.currentPos +} + +// GetBufferCapacity returns the current buffer capacity. +func (bsr *BoxSeekReader) GetBufferCapacity() int { + return cap(bsr.buffer) +} + +// GetCurrentPos returns the current read position in the stream. +func (bsr *BoxSeekReader) GetCurrentPos() uint64 { + return bsr.currentPos +} + +// IsMdatActive returns whether mdat bounds are currently active. +func (bsr *BoxSeekReader) IsMdatActive() bool { + return bsr.mdatActive +} + +// GetMdatBounds returns the current mdat bounds if active. +func (bsr *BoxSeekReader) GetMdatBounds() (start, size uint64, active bool) { + return bsr.mdatStart, bsr.mdatSize, bsr.mdatActive +} + +// PeekBoxHeader reads just enough to determine the box type and size. +// The read header bytes are stored in the buffer so ReadFullBox can include them. +// Returns the header and the absolute position where the box starts. +func (bsr *BoxSeekReader) PeekBoxHeader() (BoxHeader, uint64, error) { + boxStartPos := bsr.currentPos + + // Check if we already have a header in the buffer from a previous peek + // currentPos might be at box start OR already advanced past the header from a previous peek + if bsr.currentPos >= bsr.bufferPos && + bsr.currentPos <= bsr.bufferPos+uint64(len(bsr.buffer)) && + len(bsr.buffer) >= boxHeaderSize { + // If currentPos is past bufferPos, this is a second peek - use bufferPos as box start + if bsr.currentPos > bsr.bufferPos { + boxStartPos = bsr.bufferPos + } + + // Parse header from buffer + size := uint64(binary.BigEndian.Uint32(bsr.buffer[0:4])) + boxType := string(bsr.buffer[4:8]) + headerLen := boxHeaderSize + + if size == 1 && len(bsr.buffer) >= boxHeaderSize+largeSizeLen { + size = binary.BigEndian.Uint64(bsr.buffer[boxHeaderSize:]) + headerLen += largeSizeLen + } + + if size == 0 { + return BoxHeader{}, 0, fmt.Errorf("size 0, meaning to end of file, not supported") + } + + if uint64(headerLen) > size { + return BoxHeader{}, 0, fmt.Errorf("box header size %d exceeds box size %d", headerLen, size) + } + + // Update position to after header + bsr.currentPos = boxStartPos + uint64(headerLen) + + return BoxHeader{boxType, size, headerLen}, boxStartPos, nil + } + + // Need to read header from underlying reader + headerBuf := make([]byte, boxHeaderSize) + n, err := io.ReadFull(bsr.reader, headerBuf) + if err != nil { + return BoxHeader{}, 0, err + } + if n != boxHeaderSize { + return BoxHeader{}, 0, io.ErrUnexpectedEOF + } + + size := uint64(binary.BigEndian.Uint32(headerBuf[0:4])) + boxType := string(headerBuf[4:8]) + headerLen := boxHeaderSize + + // Check for large size + if size == 1 { + largeSizeBuf := make([]byte, largeSizeLen) + n, err := io.ReadFull(bsr.reader, largeSizeBuf) + if err != nil { + return BoxHeader{}, 0, err + } + if n != largeSizeLen { + return BoxHeader{}, 0, io.ErrUnexpectedEOF + } + size = binary.BigEndian.Uint64(largeSizeBuf) + headerLen += largeSizeLen + // Append large size bytes to header + headerBuf = append(headerBuf, largeSizeBuf...) 
+ } + + if size == 0 { + return BoxHeader{}, 0, fmt.Errorf("size 0, meaning to end of file, not supported") + } + + if uint64(headerLen) > size { + return BoxHeader{}, 0, fmt.Errorf("box header size %d exceeds box size %d", headerLen, size) + } + + // Store peeked header in buffer and update position + bsr.buffer = bsr.buffer[:0] // Clear buffer + bsr.buffer = append(bsr.buffer, headerBuf...) + bsr.bufferPos = boxStartPos + bsr.currentPos = boxStartPos + uint64(headerLen) + + return BoxHeader{boxType, size, headerLen}, boxStartPos, nil +} diff --git a/mp4/boxseekreader_test.go b/mp4/boxseekreader_test.go new file mode 100644 index 00000000..c5f4689c --- /dev/null +++ b/mp4/boxseekreader_test.go @@ -0,0 +1,760 @@ +package mp4_test + +import ( + "bytes" + "encoding/binary" + "io" + "testing" + + "github.com/Eyevinn/mp4ff/mp4" +) + +func TestNewBoxSeekReader(t *testing.T) { + data := []byte("test data") + reader := bytes.NewReader(data) + + bsr := mp4.NewBoxSeekReader(reader, 1024) + if bsr == nil { + t.Fatal("mp4.NewBoxSeekReader returned nil") + } + + if bsr.GetBufferCapacity() != 1024 { + t.Errorf("buffer capacity: got %d, expected 1024", bsr.GetBufferCapacity()) + } + + if bsr.GetCurrentPos() != 0 { + t.Errorf("currentPos: got %d, expected 0", bsr.GetCurrentPos()) + } + + if bsr.IsMdatActive() { + t.Error("mdatActive should be false initially") + } +} + +func TestNewBoxSeekReaderDefaultSize(t *testing.T) { + reader := bytes.NewReader([]byte("test")) + bsr := mp4.NewBoxSeekReader(reader, 0) + + if bsr.GetBufferCapacity() != 64*1024 { + t.Errorf("default buffer capacity: got %d, expected %d", bsr.GetBufferCapacity(), 64*1024) + } +} + +func TestPeekBoxHeader(t *testing.T) { + // Create a simple box: size (4) + type (4) = 8 byte header + 4 byte payload = 12 total + boxData := make([]byte, 12) + binary.BigEndian.PutUint32(boxData[0:4], 12) // size + copy(boxData[4:8], "test") // type + binary.BigEndian.PutUint32(boxData[8:12], 0x12345678) // payload + + reader := bytes.NewReader(boxData) + bsr := mp4.NewBoxSeekReader(reader, 64) + + hdr, startPos, err := bsr.PeekBoxHeader() + if err != nil { + t.Fatalf("PeekBoxHeader failed: %v", err) + } + + if hdr.Name != "test" { + t.Errorf("box type: got %s, expected test", hdr.Name) + } + + if hdr.Size != 12 { + t.Errorf("box size: got %d, expected 12", hdr.Size) + } + + if startPos != 0 { + t.Errorf("start position: got %d, expected 0", startPos) + } + + if hdr.Hdrlen != 8 { + t.Errorf("header length: got %d, expected 8", hdr.Hdrlen) + } + + // currentPos should be at end of header + if bsr.GetCurrentPos() != 8 { + t.Errorf("currentPos after peek: got %d, expected 8", bsr.GetCurrentPos()) + } +} + +func TestPeekBoxHeaderLargeSize(t *testing.T) { + // Create box with large size (size=1 means use next 8 bytes for actual size) + boxData := make([]byte, 24) + binary.BigEndian.PutUint32(boxData[0:4], 1) // size=1 means largesize follows + copy(boxData[4:8], "test") // type + binary.BigEndian.PutUint64(boxData[8:16], 24) // actual size + // 8 bytes payload + binary.BigEndian.PutUint64(boxData[16:24], 0x1234567890ABCDEF) + + reader := bytes.NewReader(boxData) + bsr := mp4.NewBoxSeekReader(reader, 64) + + hdr, startPos, err := bsr.PeekBoxHeader() + if err != nil { + t.Fatalf("PeekBoxHeader with large size failed: %v", err) + } + + if hdr.Name != "test" { + t.Errorf("box type: got %s, expected test", hdr.Name) + } + + if hdr.Size != 24 { + t.Errorf("box size: got %d, expected 24", hdr.Size) + } + + if hdr.Hdrlen != 16 { + t.Errorf("header length: got %d, 
expected 16", hdr.Hdrlen) + } + + if startPos != 0 { + t.Errorf("start position: got %d, expected 0", startPos) + } + + if bsr.GetCurrentPos() != 16 { + t.Errorf("currentPos: got %d, expected 16", bsr.GetCurrentPos()) + } +} + +func TestPeekBoxHeaderTwice(t *testing.T) { + // Test that peeking twice without consuming returns same header + boxData := make([]byte, 12) + binary.BigEndian.PutUint32(boxData[0:4], 12) + copy(boxData[4:8], "test") + + reader := bytes.NewReader(boxData) + bsr := mp4.NewBoxSeekReader(reader, 64) + + hdr1, pos1, err := bsr.PeekBoxHeader() + if err != nil { + t.Fatalf("First peek failed: %v", err) + } + + hdr2, pos2, err := bsr.PeekBoxHeader() + if err != nil { + t.Fatalf("Second peek failed: %v", err) + } + + if hdr1.Name != hdr2.Name || hdr1.Size != hdr2.Size { + t.Error("Second peek returned different header") + } + + if pos1 != pos2 { + t.Errorf("position mismatch: first=%d, second=%d", pos1, pos2) + } +} + +func TestReadFullBox(t *testing.T) { + boxData := make([]byte, 20) + binary.BigEndian.PutUint32(boxData[0:4], 20) + copy(boxData[4:8], "test") + for i := 8; i < 20; i++ { + boxData[i] = byte(i) + } + + reader := bytes.NewReader(boxData) + bsr := mp4.NewBoxSeekReader(reader, 64) + + // Peek first + hdr, _, err := bsr.PeekBoxHeader() + if err != nil { + t.Fatalf("PeekBoxHeader failed: %v", err) + } + + // Now read full box + fullBox, err := bsr.ReadFullBox(hdr.Size) + if err != nil { + t.Fatalf("ReadFullBox failed: %v", err) + } + + if len(fullBox) != 20 { + t.Errorf("full box length: got %d, expected 20", len(fullBox)) + } + + if !bytes.Equal(fullBox, boxData) { + t.Error("full box data doesn't match original") + } + + if bsr.GetCurrentPos() != 20 { + t.Errorf("currentPos: got %d, expected 20", bsr.GetCurrentPos()) + } +} + +func TestReadFullBoxGrowsBuffer(t *testing.T) { + // Create a box larger than initial buffer + boxSize := 200 + boxData := make([]byte, boxSize) + binary.BigEndian.PutUint32(boxData[0:4], uint32(boxSize)) + copy(boxData[4:8], "bigg") + + reader := bytes.NewReader(boxData) + bsr := mp4.NewBoxSeekReader(reader, 64) // Small initial buffer + + hdr, _, err := bsr.PeekBoxHeader() + if err != nil { + t.Fatalf("PeekBoxHeader failed: %v", err) + } + + fullBox, err := bsr.ReadFullBox(hdr.Size) + if err != nil { + t.Fatalf("ReadFullBox failed: %v", err) + } + + if len(fullBox) != boxSize { + t.Errorf("full box length: got %d, expected %d", len(fullBox), boxSize) + } + + if bsr.GetBufferCapacity() < boxSize { + t.Errorf("buffer should have grown to at least %d, got %d", boxSize, bsr.GetBufferCapacity()) + } +} + +func TestSetMdatBounds(t *testing.T) { + reader := bytes.NewReader(make([]byte, 1000)) + bsr := mp4.NewBoxSeekReader(reader, 64) + + mdatStart := uint64(100) + mdatSize := uint64(500) + + bsr.SetMdatBounds(mdatStart, mdatSize) + + if !bsr.IsMdatActive() { + t.Error("mdatActive should be true") + } + + if bsr.GetCurrentPos() != mdatStart { + t.Errorf("currentPos: got %d, expected %d", bsr.GetCurrentPos(), mdatStart) + } +} + +func TestResetBuffer(t *testing.T) { + reader := bytes.NewReader(make([]byte, 1000)) + bsr := mp4.NewBoxSeekReader(reader, 64) + + // Set some state + bsr.SetMdatBounds(100, 500) + + // Reset + bsr.ResetBuffer() + + if bsr.IsMdatActive() { + t.Error("mdatActive should be false after reset") + } + + _, bufferLen, _ := bsr.GetBufferInfo() + if bufferLen != 0 { + t.Errorf("buffer length: got %d, expected 0", bufferLen) + } +} + +func TestReadWithinBuffer(t *testing.T) { + testData := []byte("0123456789abcdef") + reader := 
bytes.NewReader(testData) + bsr := mp4.NewBoxSeekReader(reader, 64) + + // Read first 8 bytes + buf1 := make([]byte, 8) + n, err := bsr.Read(buf1) + if err != nil { + t.Fatalf("First read failed: %v", err) + } + if n != 8 { + t.Errorf("First read: got %d bytes, expected 8", n) + } + if string(buf1) != "01234567" { + t.Errorf("First read data: got %s, expected 01234567", string(buf1)) + } + + // Read next 8 bytes + buf2 := make([]byte, 8) + n, err = bsr.Read(buf2) + if err != nil { + t.Fatalf("Second read failed: %v", err) + } + if n != 8 { + t.Errorf("Second read: got %d bytes, expected 8", n) + } + if string(buf2) != "89abcdef" { + t.Errorf("Second read data: got %s, expected 89abcdef", string(buf2)) + } + + if bsr.GetCurrentPos() != 16 { + t.Errorf("currentPos: got %d, expected 16", bsr.GetCurrentPos()) + } +} + +func TestReadBeyondBuffer(t *testing.T) { + testData := []byte("0123456789abcdef") + reader := bytes.NewReader(testData) + bsr := mp4.NewBoxSeekReader(reader, 64) + + // Read all data at once + buf := make([]byte, 16) + n, err := bsr.Read(buf) + if err != nil { + t.Fatalf("Read failed: %v", err) + } + if n != 16 { + t.Errorf("Read: got %d bytes, expected 16", n) + } + if !bytes.Equal(buf, testData) { + t.Error("Read data doesn't match") + } +} + +func TestReadWithMdatBounds(t *testing.T) { + testData := make([]byte, 200) + for i := range testData { + testData[i] = byte(i) + } + reader := bytes.NewReader(testData) + bsr := mp4.NewBoxSeekReader(reader, 64) + + // Set mdat bounds: start at 50, size 100 + bsr.SetMdatBounds(50, 100) + + // Try to read within bounds + buf := make([]byte, 50) + n, err := bsr.Read(buf) + if err != nil { + t.Fatalf("Read within bounds failed: %v", err) + } + if n != 50 { + t.Errorf("Read: got %d bytes, expected 50", n) + } + + // Try to read beyond mdat end (should be limited) + buf2 := make([]byte, 100) + n, err = bsr.Read(buf2) + if err != nil && err != io.EOF { + t.Fatalf("Read beyond bounds failed: %v", err) + } + if n != 50 { + t.Errorf("Read should be limited to 50 bytes, got %d", n) + } + + // Next read should hit EOF + buf3 := make([]byte, 10) + _, err = bsr.Read(buf3) + if err != io.EOF { + t.Errorf("Expected EOF, got %v", err) + } +} + +func TestSeekWithinBuffer(t *testing.T) { + testData := []byte("0123456789") + reader := bytes.NewReader(testData) + bsr := mp4.NewBoxSeekReader(reader, 64) + + // Read some data to fill buffer + buf := make([]byte, 10) + _, _ = bsr.Read(buf) + + // Seek back to position 5 + newPos, err := bsr.Seek(5, io.SeekStart) + if err != nil { + t.Fatalf("Seek failed: %v", err) + } + if newPos != 5 { + t.Errorf("Seek returned %d, expected 5", newPos) + } + + // Read from new position + buf2 := make([]byte, 3) + n, err := bsr.Read(buf2) + if err != nil { + t.Fatalf("Read after seek failed: %v", err) + } + if n != 3 { + t.Errorf("Read: got %d bytes, expected 3", n) + } + if string(buf2) != "567" { + t.Errorf("Read after seek: got %s, expected 567", string(buf2)) + } +} + +func TestSeekCurrent(t *testing.T) { + testData := []byte("0123456789") + reader := bytes.NewReader(testData) + bsr := mp4.NewBoxSeekReader(reader, 64) + + // Read to position 5 + buf := make([]byte, 5) + _, _ = bsr.Read(buf) + + // Seek forward 2 from current + newPos, err := bsr.Seek(2, io.SeekCurrent) + if err != nil { + t.Fatalf("Seek current failed: %v", err) + } + if newPos != 7 { + t.Errorf("Seek returned %d, expected 7", newPos) + } + + if bsr.GetCurrentPos() != 7 { + t.Errorf("currentPos: got %d, expected 7", bsr.GetCurrentPos()) + } +} + +func 
TestSeekForwardBeyondBuffer(t *testing.T) { + testData := make([]byte, 100) + for i := range testData { + testData[i] = byte(i) + } + reader := bytes.NewReader(testData) + bsr := mp4.NewBoxSeekReader(reader, 64) + + // Seek to position 50 + newPos, err := bsr.Seek(50, io.SeekStart) + if err != nil { + t.Fatalf("Seek forward failed: %v", err) + } + if newPos != 50 { + t.Errorf("Seek returned %d, expected 50", newPos) + } + + // Read from new position + buf := make([]byte, 10) + n, err := bsr.Read(buf) + if err != nil { + t.Fatalf("Read after forward seek failed: %v", err) + } + if n != 10 { + t.Errorf("Read: got %d bytes, expected 10", n) + } + + // Verify data + expected := testData[50:60] + if !bytes.Equal(buf, expected) { + t.Error("Data after forward seek doesn't match") + } +} + +func TestSeekWithMdatBounds(t *testing.T) { + testData := make([]byte, 200) + reader := bytes.NewReader(testData) + bsr := mp4.NewBoxSeekReader(reader, 64) + + // Set mdat bounds + bsr.SetMdatBounds(50, 100) + + // Seek within bounds + newPos, err := bsr.Seek(75, io.SeekStart) + if err != nil { + t.Fatalf("Seek within bounds failed: %v", err) + } + if newPos != 75 { + t.Errorf("Seek returned %d, expected 75", newPos) + } + + // Try to seek before mdat start + _, err = bsr.Seek(25, io.SeekStart) + if err == nil { + t.Error("Expected error for seek before mdat start") + } + + // Try to seek beyond mdat end + _, err = bsr.Seek(200, io.SeekStart) + if err == nil { + t.Error("Expected error for seek beyond mdat end") + } +} + +func TestGetBufferInfo(t *testing.T) { + testData := []byte("0123456789") + reader := bytes.NewReader(testData) + bsr := mp4.NewBoxSeekReader(reader, 64) + + // Read some data + buf := make([]byte, 5) + _, _ = bsr.Read(buf) + + bufStart, bufLen, currentPos := bsr.GetBufferInfo() + + if bufStart != 0 { + t.Errorf("buffer start: got %d, expected 0", bufStart) + } + + if bufLen != 5 { + t.Errorf("buffer length: got %d, expected 5", bufLen) + } + + if currentPos != 5 { + t.Errorf("current position: got %d, expected 5", currentPos) + } +} + +func TestPeekBoxHeaderInsufficientData(t *testing.T) { + // Only 6 bytes - not enough for full header + shortData := []byte{0, 0, 0, 10, 't', 'e'} + reader := bytes.NewReader(shortData) + bsr := mp4.NewBoxSeekReader(reader, 64) + + _, _, err := bsr.PeekBoxHeader() + if err == nil { + t.Error("Expected error for insufficient data") + } +} + +func TestReadFullBoxSizeTooLarge(t *testing.T) { + boxData := make([]byte, 12) + binary.BigEndian.PutUint32(boxData[0:4], 12) + copy(boxData[4:8], "test") + + reader := bytes.NewReader(boxData) + bsr := mp4.NewBoxSeekReader(reader, 64) + + _, _, _ = bsr.PeekBoxHeader() + + // Try to read a box claiming to be 3GB + _, err := bsr.ReadFullBox(3 * 1024 * 1024 * 1024) + if err == nil { + t.Error("Expected error for box size too large") + } +} + +func TestGetMdatBounds(t *testing.T) { + reader := bytes.NewReader([]byte("test data")) + bsr := mp4.NewBoxSeekReader(reader, 64) + + start, size, active := bsr.GetMdatBounds() + if active { + t.Error("mdat should not be active initially") + } + if start != 0 || size != 0 { + t.Errorf("expected zero bounds, got start=%d, size=%d", start, size) + } + + bsr.SetMdatBounds(100, 500) + start, size, active = bsr.GetMdatBounds() + if !active { + t.Error("mdat should be active after SetMdatBounds") + } + if start != 100 { + t.Errorf("mdat start: got %d, expected 100", start) + } + if size != 500 { + t.Errorf("mdat size: got %d, expected 500", size) + } + + bsr.ResetBuffer() + _, _, active = 
bsr.GetMdatBounds() + if active { + t.Error("mdat should not be active after ResetBuffer") + } +} + +func TestReadAtMdatEnd(t *testing.T) { + data := make([]byte, 100) + for i := range data { + data[i] = byte(i) + } + reader := bytes.NewReader(data) + bsr := mp4.NewBoxSeekReader(reader, 64) + + bsr.SetMdatBounds(50, 30) + + buf := make([]byte, 30) + n, err := bsr.Read(buf) + if err != nil { + t.Fatalf("Read failed: %v", err) + } + if n != 30 { + t.Errorf("Read count: got %d, expected 30", n) + } + + buf2 := make([]byte, 10) + n, err = bsr.Read(buf2) + if err != io.EOF { + t.Errorf("Expected EOF at mdat end, got: %v", err) + } + if n != 0 { + t.Errorf("Expected 0 bytes read at mdat end, got %d", n) + } +} + +func TestReadPartialFromBuffer(t *testing.T) { + data := []byte("0123456789abcdefghijklmnop") + reader := bytes.NewReader(data) + bsr := mp4.NewBoxSeekReader(reader, 64) + + buf1 := make([]byte, 10) + n, err := bsr.Read(buf1) + if err != nil { + t.Fatalf("First read failed: %v", err) + } + if n != 10 { + t.Errorf("First read: got %d bytes, expected 10", n) + } + + _, err = bsr.Seek(5, io.SeekStart) + if err != nil { + t.Fatalf("Seek failed: %v", err) + } + + buf2 := make([]byte, 15) + n, err = bsr.Read(buf2) + if err != nil { + t.Fatalf("Second read failed: %v", err) + } + if n != 15 { + t.Errorf("Second read: got %d bytes, expected 15", n) + } + + expected := "56789abcdefghij" + if string(buf2) != expected { + t.Errorf("Data mismatch: got %q, expected %q", string(buf2), expected) + } +} + +func TestSeekFromEnd(t *testing.T) { + reader := bytes.NewReader([]byte("test data")) + bsr := mp4.NewBoxSeekReader(reader, 64) + + _, err := bsr.Seek(0, io.SeekEnd) + if err == nil { + t.Error("Seek from end should not be supported") + } +} + +func TestSeekInvalidWhence(t *testing.T) { + reader := bytes.NewReader([]byte("test data")) + bsr := mp4.NewBoxSeekReader(reader, 64) + + _, err := bsr.Seek(0, 999) + if err == nil { + t.Error("Seek with invalid whence should fail") + } +} + +func TestSeekNegativePosition(t *testing.T) { + reader := bytes.NewReader([]byte("test data")) + bsr := mp4.NewBoxSeekReader(reader, 64) + + _, err := bsr.Seek(-10, io.SeekStart) + if err == nil { + t.Error("Seek to negative position should fail") + } +} + +func TestSeekBeforeBuffer(t *testing.T) { + data := make([]byte, 100) + reader := bytes.NewReader(data) + bsr := mp4.NewBoxSeekReader(reader, 64) + + buf := make([]byte, 50) + _, err := bsr.Read(buf) + if err != nil { + t.Fatalf("Read failed: %v", err) + } + + bsr.ResetBuffer() + + _, err = bsr.Seek(10, io.SeekStart) + if err == nil { + t.Error("Seek before buffer start should fail") + } +} + +func TestSeekBeyondMdatEnd(t *testing.T) { + data := make([]byte, 100) + reader := bytes.NewReader(data) + bsr := mp4.NewBoxSeekReader(reader, 64) + + bsr.SetMdatBounds(20, 30) + + _, err := bsr.Seek(60, io.SeekStart) + if err == nil { + t.Error("Seek beyond mdat end should fail when mdat is active") + } +} + +func TestSeekBeforeMdatStart(t *testing.T) { + data := make([]byte, 100) + reader := bytes.NewReader(data) + bsr := mp4.NewBoxSeekReader(reader, 64) + + bsr.SetMdatBounds(50, 30) + + _, err := bsr.Seek(40, io.SeekStart) + if err == nil { + t.Error("Seek before mdat start should fail when mdat is active") + } +} + +func TestPeekBoxHeaderSize0(t *testing.T) { + boxData := make([]byte, 12) + binary.BigEndian.PutUint32(boxData[0:4], 0) + copy(boxData[4:8], "test") + + reader := bytes.NewReader(boxData) + bsr := mp4.NewBoxSeekReader(reader, 64) + + _, _, err := 
bsr.PeekBoxHeader() + if err == nil { + t.Error("PeekBoxHeader should fail for size=0") + } +} + +func TestPeekBoxHeaderInvalidSize(t *testing.T) { + boxData := make([]byte, 12) + binary.BigEndian.PutUint32(boxData[0:4], 4) + copy(boxData[4:8], "test") + + reader := bytes.NewReader(boxData) + bsr := mp4.NewBoxSeekReader(reader, 64) + + _, _, err := bsr.PeekBoxHeader() + if err == nil { + t.Error("PeekBoxHeader should fail when header size exceeds box size") + } +} + +func TestReadFullBoxHeaderSmallerThanBox(t *testing.T) { + boxData := make([]byte, 12) + binary.BigEndian.PutUint32(boxData[0:4], 12) + copy(boxData[4:8], "test") + binary.BigEndian.PutUint32(boxData[8:12], 0x12345678) + + reader := bytes.NewReader(boxData) + bsr := mp4.NewBoxSeekReader(reader, 64) + + _, _, err := bsr.PeekBoxHeader() + if err != nil { + t.Fatalf("PeekBoxHeader failed: %v", err) + } + + _, err = bsr.ReadFullBox(6) + if err == nil { + t.Error("ReadFullBox should fail when box size is smaller than header") + } +} + +func TestPeekBoxHeaderLargeSize0(t *testing.T) { + boxData := make([]byte, 16) + binary.BigEndian.PutUint32(boxData[0:4], 1) + copy(boxData[4:8], "test") + binary.BigEndian.PutUint64(boxData[8:16], 0) + + reader := bytes.NewReader(boxData) + bsr := mp4.NewBoxSeekReader(reader, 64) + + _, _, err := bsr.PeekBoxHeader() + if err == nil { + t.Error("PeekBoxHeader should fail for largesize=0") + } +} + +func TestPeekBoxHeaderLargeSizeInvalid(t *testing.T) { + boxData := make([]byte, 16) + binary.BigEndian.PutUint32(boxData[0:4], 1) + copy(boxData[4:8], "test") + binary.BigEndian.PutUint64(boxData[8:16], 10) + + reader := bytes.NewReader(boxData) + bsr := mp4.NewBoxSeekReader(reader, 64) + + _, _, err := bsr.PeekBoxHeader() + if err == nil { + t.Error("PeekBoxHeader should fail when largesize is less than header length") + } +} diff --git a/mp4/ftyp.go b/mp4/ftyp.go index f043fdb8..dbd0fdb2 100644 --- a/mp4/ftyp.go +++ b/mp4/ftyp.go @@ -12,6 +12,13 @@ type FtypBox struct { data []byte } +// Copy - deep copy of Ftyp box. +func (b *FtypBox) Copy() *FtypBox { + data := make([]byte, len(b.data)) + copy(data, b.data) + return &FtypBox{data: data} +} + // MajorBrand - major brand (4 chars) func (b *FtypBox) MajorBrand() string { return string(b.data[:4]) diff --git a/mp4/stream.go b/mp4/stream.go new file mode 100644 index 00000000..204d58d8 --- /dev/null +++ b/mp4/stream.go @@ -0,0 +1,628 @@ +package mp4 + +import ( + "fmt" + "io" + + "github.com/Eyevinn/mp4ff/bits" +) + +// TrailingBoxesErrror indicates that there are unexpected boxes after the last fragment. +type TrailingBoxesErrror struct { + BoxNames []string +} + +func (e *TrailingBoxesErrror) Error() string { + return fmt.Sprintf("trailing boxes found after last fragment: %v", e.BoxNames) +} + +// InitDecodeStream reads and parses only the init segment. +// Stops as soon as it peeks a box that belongs to a fragment (styp, sidx, moof, emsg, prft). +// Returns a StreamFile ready for ProcessFragments to consume fragments. 
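+//
+// A minimal usage sketch (the reader source and the callback body are
+// illustrative, not part of this API):
+//
+//	sf, err := InitDecodeStream(r,
+//		WithFragmentCallback(func(f *Fragment, sa SampleAccessor) error {
+//			// inspect f.Moof or fetch sample data via sa here
+//			return nil
+//		}))
+//	if err != nil {
+//		return err
+//	}
+//	return sf.ProcessFragments()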
+func InitDecodeStream(r io.Reader, options ...StreamOption) (*StreamFile, error) { + f := NewFile() + f.fileDecMode = DecModeLazyMdat + + bsr := NewBoxSeekReader(r, 64*1024) // Start with 64KB, will grow as needed + + sf := &StreamFile{ + File: f, + reader: r, + boxSeekReader: bsr, + maxFragments: 3, + } + + for _, opt := range options { + opt(sf) + } + + for { + // Peek at next box header to see what's coming + hdr, boxStartPos, err := bsr.PeekBoxHeader() + if err == io.EOF { + // Reached EOF before any fragments - file may be init-only + sf.streamPos = boxStartPos + break + } + if err != nil { + return nil, fmt.Errorf("peek box header at %d: %w", boxStartPos, err) + } + + boxType := hdr.Name + boxSize := hdr.Size + + // Check if this box belongs to fragments - if so, stop here + // The header is in the buffer, leave it there for ProcessFragments + switch boxType { + case "styp", "moof", "sidx", "emsg", "prft": + // These boxes indicate start of fragments + // Header bytes are in buffer, currentPos points after header + // Reset currentPos to boxStartPos so ProcessFragments can re-peek + bsr.currentPos = boxStartPos + f.isFragmented = true + if f.Init == nil && f.Moov != nil { + f.Init = NewMP4Init() + if f.Ftyp != nil { + f.Init.AddChild(f.Ftyp) + } + f.Init.AddChild(f.Moov) + } + sf.streamPos = boxStartPos + return sf, nil + case "mdat": + return nil, fmt.Errorf("unexpected mdat box at position %d before fragments", boxStartPos) + } + + // This box is part of the init segment - read and parse it + boxData, err := bsr.ReadFullBox(boxSize) + if err != nil { + return nil, fmt.Errorf("read %s box at %d: %w", boxType, boxStartPos, err) + } + + // Parse box from buffer using DecodeBoxSR + sr := bits.NewFixedSliceReader(boxData) + box, err := DecodeBoxSR(boxStartPos, sr) + if err != nil { + return nil, fmt.Errorf("decode %s box at %d: %w", boxType, boxStartPos, err) + } + + // Clear buffer for next box now that we're done parsing + bsr.ResetBuffer() + + switch boxType { + case "ftyp": + ftypBox := box.(*FtypBox) + f.Ftyp = ftypBox.Copy() + f.Children = append(f.Children, f.Ftyp) + case "moov": + f.Moov = box.(*MoovBox) + f.Children = append(f.Children, box) + if len(f.Moov.Trak.Mdia.Minf.Stbl.Stts.SampleCount) == 0 { + f.isFragmented = true + } else { + return nil, fmt.Errorf("file is progressive, not supported for streaming") + } + default: + // Unknown boxes in init segment - keep them + f.Children = append(f.Children, box) + } + + // Update stream position + sf.streamPos = boxStartPos + boxSize + } + + return sf, nil +} + +// StreamFile wraps File with streaming capabilities for processing fragments incrementally. +type StreamFile struct { + *File + reader io.Reader + boxSeekReader *BoxSeekReader + onFragmentReady FragmentCallback + onFragmentDone FragmentDoneCallback + maxFragments int + streamPos uint64 +} + +// FragmentCallback is called when a fragment's moof box has been parsed and mdat is ready to be accessed. +// The SampleAccessor provides lazy access to sample data. +type FragmentCallback func(f *Fragment, sa SampleAccessor) error + +// FragmentDoneCallback is called after a fragment has been fully processed. +type FragmentDoneCallback func(f *Fragment) error + +// SampleAccessor provides access to samples within a fragment. 
+type SampleAccessor interface { + GetSample(trackID uint32, sampleNr uint32) (*FullSample, error) + GetSampleRange(trackID uint32, startSampleNr, endSampleNr uint32) ([]FullSample, error) + GetSamples(trackID uint32) ([]FullSample, error) +} + +// StreamOption configures streaming behavior. +type StreamOption func(*StreamFile) + +// WithFragmentCallback sets the callback invoked when a fragment is ready for processing. +// This corresponds to the point after the moof box has been parsed and mdat is ready to be accessed. +func WithFragmentCallback(cb FragmentCallback) StreamOption { + return func(sf *StreamFile) { sf.onFragmentReady = cb } +} + +// WithFragmentDone sets the callback invoked after fragment processing completes. +func WithFragmentDone(cb FragmentDoneCallback) StreamOption { + return func(sf *StreamFile) { sf.onFragmentDone = cb } +} + +// WithMaxFragments sets the maximum number of fragments to retain in memory (sliding window). +// Default is 3. Set to 0 to keep all fragments. +func WithMaxFragments(max int) StreamOption { + return func(sf *StreamFile) { sf.maxFragments = max } +} + +// fragmentSampleAccessor implements SampleAccessor for a fragment using the boxSeekReader. +type fragmentSampleAccessor struct { + fragment *Fragment + boxSeekReader io.ReadSeeker + trex *TrexBox +} + +// GetSample retrieves a specific sample by track ID and sample number (1-based). +func (fsa *fragmentSampleAccessor) GetSample(trackID uint32, sampleNr uint32) (*FullSample, error) { + moof := fsa.fragment.Moof + var traf *TrafBox + for _, tr := range moof.Trafs { + if tr.Tfhd.TrackID == trackID { + traf = tr + break + } + } + if traf == nil { + return nil, fmt.Errorf("track %d not found in fragment", trackID) + } + + if sampleNr < 1 { + return nil, fmt.Errorf("sample number must be >= 1") + } + + tfhd := traf.Tfhd + var baseTime uint64 + if traf.Tfdt != nil { + baseTime = traf.Tfdt.BaseMediaDecodeTime() + } + moofStartPos := moof.StartPos + mdat := fsa.fragment.Mdat + + // Find which trun contains this sample and the sample's position + sampleIdx := uint32(1) + for _, trun := range traf.Truns { + trun.AddSampleDefaultValues(tfhd, fsa.trex) + samples := trun.GetSamples() + + if sampleIdx+uint32(len(samples)) <= sampleNr { + // This sample is in a later trun + for _, s := range samples { + baseTime += uint64(s.Dur) + } + sampleIdx += uint32(len(samples)) + continue + } + + // Sample is in this trun + offsetInTrun := sampleNr - sampleIdx + if offsetInTrun >= uint32(len(samples)) { + return nil, fmt.Errorf("sample number %d out of range", sampleNr) + } + + sample := samples[offsetInTrun] + + // Accumulate decode time for samples before this one in the trun + for i := uint32(0); i < offsetInTrun; i++ { + baseTime += uint64(samples[i].Dur) + } + + // Calculate file offset for this sample + baseOffset := moofStartPos + if tfhd.HasBaseDataOffset() { + baseOffset = tfhd.BaseDataOffset + } else if tfhd.DefaultBaseIfMoof() { + baseOffset = moofStartPos + } + if trun.HasDataOffset() { + baseOffset = uint64(int64(trun.DataOffset) + int64(baseOffset)) + } + + // Add size of samples before this one in the trun + for i := uint32(0); i < offsetInTrun; i++ { + baseOffset += uint64(samples[i].Size) + } + + // Read just this sample's data + data, err := mdat.ReadData(int64(baseOffset), int64(sample.Size), fsa.boxSeekReader) + if err != nil { + return nil, fmt.Errorf("read sample data: %w", err) + } + return &FullSample{ + Sample: sample, + DecodeTime: baseTime, + Data: data, + }, nil + } + + return nil, 
fmt.Errorf("sample number %d not found in fragment", sampleNr) +} + +func (fsa *fragmentSampleAccessor) GetSampleRange(trackID uint32, startSampleNr, endSampleNr uint32) ([]FullSample, error) { + if startSampleNr < 1 { + return nil, fmt.Errorf("start sample number must be >= 1") + } + if endSampleNr < startSampleNr { + return nil, fmt.Errorf("end sample number %d must be >= start sample number %d", endSampleNr, startSampleNr) + } + + moof := fsa.fragment.Moof + var traf *TrafBox + for _, tr := range moof.Trafs { + if tr.Tfhd.TrackID == trackID { + traf = tr + break + } + } + if traf == nil { + return nil, fmt.Errorf("track %d not found in fragment", trackID) + } + + tfhd := traf.Tfhd + var baseTime uint64 + if traf.Tfdt != nil { + baseTime = traf.Tfdt.BaseMediaDecodeTime() + } + moofStartPos := moof.StartPos + mdat := fsa.fragment.Mdat + + var result []FullSample + sampleIdx := uint32(1) + rangeStarted := false + + for _, trun := range traf.Truns { + trun.AddSampleDefaultValues(tfhd, fsa.trex) + samples := trun.GetSamples() + + // Calculate base offset for this trun + baseOffset := moofStartPos + if tfhd.HasBaseDataOffset() { + baseOffset = tfhd.BaseDataOffset + } else if tfhd.DefaultBaseIfMoof() { + baseOffset = moofStartPos + } + if trun.HasDataOffset() { + baseOffset = uint64(int64(trun.DataOffset) + int64(baseOffset)) + } + + for i, sample := range samples { + currentSampleNr := sampleIdx + uint32(i) + + // If we're past the end of the range, we're done + if currentSampleNr > endSampleNr { + return result, nil + } + + // If we haven't reached the start yet, skip this sample + if currentSampleNr < startSampleNr { + baseTime += uint64(sample.Dur) + baseOffset += uint64(sample.Size) + continue + } + + rangeStarted = true + + // Read this sample's data + data, err := mdat.ReadData(int64(baseOffset), int64(sample.Size), fsa.boxSeekReader) + if err != nil { + return nil, fmt.Errorf("read sample %d data: %w", currentSampleNr, err) + } + result = append(result, FullSample{ + Sample: sample, + DecodeTime: baseTime, + Data: data, + }) + + baseTime += uint64(sample.Dur) + baseOffset += uint64(sample.Size) + } + + sampleIdx += uint32(len(samples)) + } + + if !rangeStarted { + return nil, fmt.Errorf("start sample %d not found in fragment", startSampleNr) + } + + return result, nil +} + +// GetSamples retrieves all samples for a given track ID in the fragment. +// Will not return until the full mdat box has been read. 
+func (fsa *fragmentSampleAccessor) GetSamples(trackID uint32) ([]FullSample, error) { + moof := fsa.fragment.Moof + var traf *TrafBox + for _, tr := range moof.Trafs { + if tr.Tfhd.TrackID == trackID { + traf = tr + break + } + } + if traf == nil { + return nil, fmt.Errorf("track %d not found in fragment", trackID) + } + + tfhd := traf.Tfhd + var baseTime uint64 + if traf.Tfdt != nil { + baseTime = traf.Tfdt.BaseMediaDecodeTime() + } + moofStartPos := moof.StartPos + mdat := fsa.fragment.Mdat + + var samples []FullSample + for _, trun := range traf.Truns { + trun.AddSampleDefaultValues(tfhd, fsa.trex) + baseOffset := moofStartPos + if tfhd.HasBaseDataOffset() { + baseOffset = tfhd.BaseDataOffset + } else if tfhd.DefaultBaseIfMoof() { + baseOffset = moofStartPos + } + if trun.HasDataOffset() { + baseOffset = uint64(int64(trun.DataOffset) + int64(baseOffset)) + } + + offsetInFile := baseOffset + for _, sample := range trun.GetSamples() { + data, err := mdat.ReadData(int64(offsetInFile), int64(sample.Size), fsa.boxSeekReader) + if err != nil { + return nil, fmt.Errorf("read sample data: %w", err) + } + samples = append(samples, FullSample{ + Sample: sample, + DecodeTime: baseTime, + Data: data, + }) + baseTime += uint64(sample.Dur) + offsetInFile += uint64(sample.Size) + } + } + + return samples, nil +} + +// ProcessFragments reads and processes fragments from the stream until EOF. +// Returns a TrailingBoxesErrror if there are unexpected boxes after the last fragment. +func (sf *StreamFile) ProcessFragments() error { + // Collect boxes between fragments (styp, sidx, emsg, etc.) + var preFragmentBoxes []Box + + for { + // Peek at next box header to get type and size + hdr, boxStartPos, err := sf.boxSeekReader.PeekBoxHeader() + if err == io.EOF { + break + } + if err != nil { + // Check if this might be trailing data or end of stream + if boxStartPos > 0 { + // We successfully read some fragments, this might just be EOF + break + } + return fmt.Errorf("peek box header at %d: %w", boxStartPos, err) + } + + boxType := hdr.Name + boxSize := hdr.Size + + // For non-moof boxes, collect them to include with the next fragment + if boxType != "moof" { + if boxType == "mdat" { + return fmt.Errorf("unexpected mdat box without preceding moof at position %d", boxStartPos) + } + + // Read entire box into buffer + boxData, err := sf.boxSeekReader.ReadFullBox(boxSize) + if err != nil { + return fmt.Errorf("read %s box at %d: %w", boxType, boxStartPos, err) + } + + // Parse box from buffer using DecodeBoxSR + sr := bits.NewFixedSliceReader(boxData) + box, err := DecodeBoxSR(boxStartPos, sr) + if err != nil { + return fmt.Errorf("decode %s box at %d: %w", boxType, boxStartPos, err) + } + sf.boxSeekReader.ResetBuffer() + + // Copy styp boxes to avoid shared mutable state + if boxType == "styp" { + if stypBox, ok := box.(*StypBox); ok { + box = stypBox.Copy() + } + } + + preFragmentBoxes = append(preFragmentBoxes, box) + sf.streamPos = boxStartPos + boxSize + continue + } + + // Read entire moof box into buffer + moofData, err := sf.boxSeekReader.ReadFullBox(boxSize) + if err != nil { + return fmt.Errorf("read moof box at %d: %w", boxStartPos, err) + } + + // Parse moof from buffer using DecodeBoxSR + sr := bits.NewFixedSliceReader(moofData) + moofBox, err := DecodeBoxSR(boxStartPos, sr) + if err != nil { + return fmt.Errorf("decode moof box at %d: %w", boxStartPos, err) + } + sf.boxSeekReader.ResetBuffer() + + // Process the fragment (moof + mdat) + err = sf.processFragment(moofBox.(*MoofBox), boxStartPos, 
preFragmentBoxes) + if err != nil { + return fmt.Errorf("process fragment: %w", err) + } + + // Clear pre-fragment boxes for next fragment + preFragmentBoxes = nil + + // processFragment positions stream at end of mdat, ready for next box + sf.streamPos, _, _ = sf.boxSeekReader.GetBufferInfo() + } + + if len(preFragmentBoxes) > 0 { + return &TrailingBoxesErrror{BoxNames: func() []string { + names := make([]string, 0, len(preFragmentBoxes)) + for _, box := range preFragmentBoxes { + names = append(names, box.Type()) + } + return names + }()} + } + + return nil +} + +// processFragment handles a complete fragment (moof + mdat). +// moofStartPos is the start position of the moof box. +// preFragmentBoxes are boxes that appeared before the moof (sidx, emsg, styp, etc.) +func (sf *StreamFile) processFragment(moof *MoofBox, moofStartPos uint64, preFragmentBoxes []Box) error { + moof.StartPos = moofStartPos + + // Peek at mdat box header + // Stream should already be positioned at moofEndPos (right after moof box) + hdr, mdatStartPos, err := sf.boxSeekReader.PeekBoxHeader() + if err != nil { + return fmt.Errorf("peek mdat header: %w", err) + } + if hdr.Name != "mdat" { + return fmt.Errorf("expected mdat box after moof, got %s", hdr.Name) + } + + // Create lazy mdat box and skip the header in stream + mdat, err := DecodeMdatLazily(hdr, mdatStartPos) + if err != nil { + return fmt.Errorf("decode mdat lazily: %w", err) + } + mdatBox := mdat.(*MdatBox) + + // Skip past mdat header to position at payload start + mdatPayloadStart := mdatStartPos + uint64(hdr.Hdrlen) + _, err = sf.boxSeekReader.Seek(int64(mdatPayloadStart), io.SeekStart) + if err != nil { + return fmt.Errorf("seek to mdat payload: %w", err) + } + + mdatPayloadSize := mdatBox.GetLazyDataSize() + + // Configure boxSeekReader for this mdat's bounds + // This also pre-allocates buffer if mdat is small enough + sf.boxSeekReader.SetMdatBounds(mdatPayloadStart, mdatPayloadSize) + + // Stream is now positioned at start of mdat payload, ready for sample reads + // Verify position is correct + if mdatBox.PayloadAbsoluteOffset() != mdatPayloadStart { + return fmt.Errorf("mdat payload position mismatch: expected %d, got %d", + mdatPayloadStart, mdatBox.PayloadAbsoluteOffset()) + } + + // Create fragment with all boxes (pre-fragment boxes + moof + mdat) + children := make([]Box, 0, len(preFragmentBoxes)+2) + children = append(children, preFragmentBoxes...) 
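+ // moof and mdat are appended after any styp/sidx/emsg boxes so that
+ // Children preserves the on-wire box order of the fragment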
+ children = append(children, moof, mdatBox) + + frag := &Fragment{ + Moof: moof, + Mdat: mdatBox, + Children: children, + StartPos: moofStartPos, + } + + // Invoke callback if set + if sf.onFragmentReady != nil { + var trex *TrexBox + if sf.Moov != nil && sf.Moov.Mvex != nil { + trex = sf.Moov.Mvex.Trex + } + accessor := &fragmentSampleAccessor{ + fragment: frag, + boxSeekReader: sf.boxSeekReader, + trex: trex, + } + err = sf.onFragmentReady(frag, accessor) + if err != nil { + return fmt.Errorf("fragment callback: %w", err) + } + } + + // Add to file structure + if len(sf.Segments) == 0 { + sf.AddMediaSegment(&MediaSegment{StartPos: moofStartPos}) + } + lastSeg := sf.LastSegment() + lastSeg.AddFragment(frag) + + // Invoke done callback and handle cleanup + if sf.onFragmentDone != nil { + err = sf.onFragmentDone(frag) + if err != nil { + return fmt.Errorf("fragment done callback: %w", err) + } + } + + // Drop old fragments if sliding window is enabled + if sf.maxFragments > 0 { + totalFragments := 0 + for _, seg := range sf.Segments { + totalFragments += len(seg.Fragments) + } + if totalFragments > sf.maxFragments { + sf.dropOldestFragment() + } + } + + // Skip to end of mdat box to continue to next box + // mdatBox.Size() includes both header and payload + // So end position is mdatStartPos + mdatBox.Size() + mdatEndPos := mdatStartPos + mdatBox.Size() + _, err = sf.boxSeekReader.Seek(int64(mdatEndPos), io.SeekStart) + if err != nil { + return fmt.Errorf("seek past mdat to position %d: %w", mdatEndPos, err) + } + + // Reset mdat-specific state after seeking (clears mdatActive flag and buffer) + sf.boxSeekReader.ResetBuffer() + + // Stream is now positioned at mdatEndPos, ready to read next box header + return nil +} + +// dropOldestFragment removes the oldest fragment from the file structure. +func (sf *StreamFile) dropOldestFragment() { + for i, seg := range sf.Segments { + if len(seg.Fragments) > 0 { + seg.Fragments = seg.Fragments[1:] + if len(seg.Fragments) == 0 { + sf.Segments = sf.Segments[i+1:] + } + return + } + } +} + +// GetActiveFragments returns the currently retained fragments. +func (sf *StreamFile) GetActiveFragments() []*Fragment { + var frags []*Fragment + for _, seg := range sf.Segments { + frags = append(frags, seg.Fragments...) 
+ } + return frags +} diff --git a/mp4/stream_test.go b/mp4/stream_test.go new file mode 100644 index 00000000..bb752f80 --- /dev/null +++ b/mp4/stream_test.go @@ -0,0 +1,489 @@ +package mp4_test + +import ( + "bytes" + "errors" + "os" + "testing" + + "github.com/Eyevinn/mp4ff/mp4" +) + +func TestDecodeStreamBasic(t *testing.T) { + testFile := "testdata/v300_multiple_segments.mp4" + + data, err := os.ReadFile(testFile) + if err != nil { + t.Fatalf("Failed to read test file: %v", err) + } + + reader := bytes.NewReader(data) + sf, err := mp4.InitDecodeStream(reader) + if err != nil { + t.Fatalf("DecodeStream failed: %v", err) + } + + if sf == nil { + t.Fatal("StreamFile is nil") + } + + if sf.File == nil { + t.Fatal("File is nil") + } +} + +func TestProcessFragmentsWithCallback(t *testing.T) { + testFile := "testdata/v300_multiple_segments.mp4" + if _, err := os.Stat(testFile); os.IsNotExist(err) { + t.Skip("Test file not found, skipping") + } + + data, err := os.ReadFile(testFile) + if err != nil { + t.Fatalf("Failed to read test file: %v", err) + } + + fragmentCount := 0 + sampleCount := 0 + + reader := bytes.NewReader(data) + sf, err := mp4.InitDecodeStream(reader, + mp4.WithFragmentCallback(func(f *mp4.Fragment, sa mp4.SampleAccessor) error { + fragmentCount++ + if f.Moof == nil { + t.Error("Fragment moof is nil") + } + if f.Mdat == nil { + t.Error("Fragment mdat is nil") + } + + // Try to get samples + trackID := f.Moof.Trafs[0].Tfhd.TrackID + samples, err := sa.GetSamples(trackID) + if err != nil { + t.Errorf("GetSamples failed: %v", err) + } + sampleCount += len(samples) + + return nil + }), + ) + if err != nil { + t.Fatalf("DecodeStream failed: %v", err) + } + + err = sf.ProcessFragments() + if err != nil { + t.Fatalf("ProcessFragments failed: %v", err) + } + + if fragmentCount == 0 { + t.Error("No fragments processed") + } + + t.Logf("Processed %d fragments with %d total samples", fragmentCount, sampleCount) +} + +func TestStreamFileSlidingWindow(t *testing.T) { + testFile := "testdata/v300_multiple_segments.mp4" + if _, err := os.Stat(testFile); os.IsNotExist(err) { + t.Skip("Test file not found, skipping") + } + + data, err := os.ReadFile(testFile) + if err != nil { + t.Fatalf("Failed to read test file: %v", err) + } + + reader := bytes.NewReader(data) + sf, err := mp4.InitDecodeStream(reader, + mp4.WithMaxFragments(2), + ) + if err != nil { + t.Fatalf("DecodeStream failed: %v", err) + } + + err = sf.ProcessFragments() + if err != nil { + t.Fatalf("ProcessFragments failed: %v", err) + } + + activeFrags := sf.GetActiveFragments() + if len(activeFrags) > 2 { + t.Errorf("Expected at most 2 active fragments, got %d", len(activeFrags)) + } +} + +func TestStreamFileRetainFragment(t *testing.T) { + testFile := "testdata/v300_multiple_segments.mp4" + if _, err := os.Stat(testFile); os.IsNotExist(err) { + t.Skip("Test file not found, skipping") + } + + data, err := os.ReadFile(testFile) + if err != nil { + t.Fatalf("Failed to read test file: %v", err) + } + + reader := bytes.NewReader(data) + sf, err := mp4.InitDecodeStream(reader, + mp4.WithMaxFragments(1), + mp4.WithFragmentDone(func(f *mp4.Fragment) error { + return nil + }), + ) + if err != nil { + t.Fatalf("DecodeStream failed: %v", err) + } + + err = sf.ProcessFragments() + if err != nil { + t.Fatalf("ProcessFragments failed: %v", err) + } + + // Should have exactly 1 active fragment for sliding window == 1 + activeFrags := sf.GetActiveFragments() + if len(activeFrags) != 1 { + t.Errorf("%d not 1 active fragments in test file", 
len(activeFrags)) + } +} + +func TestFragmentIncludesPreFragmentBoxes(t *testing.T) { + testFile := "testdata/v300_multiple_segments.mp4" + data, err := os.ReadFile(testFile) + if err != nil { + t.Fatalf("Failed to read test file: %v", err) + } + + reader := bytes.NewReader(data) + sf, err := mp4.InitDecodeStream(reader, + mp4.WithFragmentCallback(func(f *mp4.Fragment, sa mp4.SampleAccessor) error { + // Check that fragment includes boxes before moof + // v300_multiple_segments.mp4 has styp boxes before each moof + hasStypBox := false + for _, child := range f.Children { + if child.Type() == "styp" { + hasStypBox = true + break + } + } + if !hasStypBox { + t.Errorf("Fragment at pos %d missing styp box in children", f.StartPos) + } + + // Verify moof and mdat are present + hasMoof := false + hasMdat := false + for _, child := range f.Children { + if child.Type() == "moof" { + hasMoof = true + } + if child.Type() == "mdat" { + hasMdat = true + } + } + if !hasMoof || !hasMdat { + t.Errorf("Fragment missing moof or mdat in children") + } + + return nil + }), + ) + if err != nil { + t.Fatalf("DecodeStream failed: %v", err) + } + + err = sf.ProcessFragments() + if err != nil { + t.Fatalf("ProcessFragments failed: %v", err) + } +} + +func TestSingleSampleAccess(t *testing.T) { + testFile := "testdata/v300_multiple_segments.mp4" + data, err := os.ReadFile(testFile) + if err != nil { + t.Fatalf("Failed to read test file: %v", err) + } + + // First, read with traditional DecodeFile to get reference samples + tradReader := bytes.NewReader(data) + tradFile, err := mp4.DecodeFile(tradReader) + if err != nil { + t.Fatalf("DecodeFile failed: %v", err) + } + + // Collect samples fragment by fragment from traditional parsing + type FragmentSamples struct { + trackID uint32 + samples []mp4.FullSample + } + var referenceFragments []FragmentSamples + + for _, seg := range tradFile.Segments { + for _, frag := range seg.Fragments { + for _, traf := range frag.Moof.Trafs { + trackID := traf.Tfhd.TrackID + trex, ok := tradFile.Init.Moov.Mvex.GetTrex(trackID) + if !ok { + t.Fatalf("No trex found for track %d", trackID) + } + samples, err := frag.GetFullSamples(trex) + if err != nil { + t.Fatalf("GetFullSamples failed: %v", err) + } + referenceFragments = append(referenceFragments, FragmentSamples{ + trackID: trackID, + samples: samples, + }) + } + } + } + + // Now use streaming processing and compare fragment by fragment + streamReader := bytes.NewReader(data) + fragIdx := 0 + + sf, err := mp4.InitDecodeStream(streamReader, + mp4.WithFragmentCallback(func(f *mp4.Fragment, sa mp4.SampleAccessor) error { + trackID := f.Moof.Traf.Tfhd.TrackID + + // Get all samples from streaming accessor + allSamples, err := sa.GetSamples(trackID) + if err != nil { + return err + } + + if len(allSamples) == 0 { + return nil + } + + // Compare individual sample access with bulk access + for i := 1; i <= len(allSamples); i++ { + sample, err := sa.GetSample(trackID, uint32(i)) + if err != nil { + t.Errorf("GetSample(%d) failed: %v", i, err) + continue + } + + expectedSample := &allSamples[i-1] + if sample.Size != expectedSample.Size { + t.Errorf("Sample %d size mismatch: got %d, expected %d", i, sample.Size, expectedSample.Size) + } + if sample.DecodeTime != expectedSample.DecodeTime { + t.Errorf("Sample %d decode time mismatch: got %d, expected %d", i, sample.DecodeTime, expectedSample.DecodeTime) + } + if len(sample.Data) != len(expectedSample.Data) { + t.Errorf("Sample %d data length mismatch: got %d, expected %d", i, 
len(sample.Data), len(expectedSample.Data)) + } + if !bytes.Equal(sample.Data, expectedSample.Data) { + t.Errorf("Sample %d data mismatch", i) + } + } + + // Compare against traditional parsing fragment by fragment + if fragIdx >= len(referenceFragments) { + t.Errorf("Stream fragment index %d exceeds reference fragments count %d", fragIdx, len(referenceFragments)) + return nil + } + + refFrag := referenceFragments[fragIdx] + if trackID != refFrag.trackID { + t.Errorf("Fragment %d: track ID mismatch got %d, expected %d", fragIdx, trackID, refFrag.trackID) + } + + if len(allSamples) != len(refFrag.samples) { + t.Errorf("Fragment %d: sample count mismatch got %d, expected %d", fragIdx, len(allSamples), len(refFrag.samples)) + } + + for i, sample := range allSamples { + if i >= len(refFrag.samples) { + break + } + ref := &refFrag.samples[i] + if sample.Size != ref.Size { + t.Errorf("Fragment %d sample %d: size mismatch got %d, expected %d", + fragIdx, i, sample.Size, ref.Size) + } + if sample.DecodeTime != ref.DecodeTime { + t.Errorf("Fragment %d sample %d: decode time mismatch got %d, expected %d", + fragIdx, i, sample.DecodeTime, ref.DecodeTime) + } + if !bytes.Equal(sample.Data, ref.Data) { + t.Errorf("Fragment %d sample %d: data mismatch", fragIdx, i) + } + } + + fragIdx++ + return nil + }), + ) + if err != nil { + t.Fatalf("DecodeStream failed: %v", err) + } + + err = sf.ProcessFragments() + if err != nil { + t.Fatalf("ProcessFragments failed: %v", err) + } + + // Verify we processed all fragments + if fragIdx != len(referenceFragments) { + t.Errorf("Processed %d fragments but expected %d", fragIdx, len(referenceFragments)) + } +} +func TestSampleRangeAccess(t *testing.T) { + testFile := "testdata/v300_multiple_segments.mp4" + data, err := os.ReadFile(testFile) + if err != nil { + t.Fatalf("Failed to read test file: %v", err) + } + + reader := bytes.NewReader(data) + sf, err := mp4.InitDecodeStream(reader, + mp4.WithFragmentCallback(func(f *mp4.Fragment, sa mp4.SampleAccessor) error { + if len(f.Moof.Trafs) == 0 { + return nil + } + + trackID := f.Moof.Trafs[0].Tfhd.TrackID + + // Get all samples as reference + allSamples, err := sa.GetSamples(trackID) + if err != nil { + return err + } + + if len(allSamples) == 0 { + return nil + } + + // Test various ranges + testCases := []struct { + start, end uint32 + }{ + {1, 1}, // Single sample + {1, 5}, // First few samples + {3, 7}, // Middle range + {uint32(len(allSamples)), uint32(len(allSamples))}, // Last sample + {1, uint32(len(allSamples))}, // All samples + } + + for _, tc := range testCases { + if tc.end > uint32(len(allSamples)) { + continue + } + + rangeSamples, err := sa.GetSampleRange(trackID, tc.start, tc.end) + if err != nil { + t.Errorf("GetSampleRange(%d, %d) failed: %v", tc.start, tc.end, err) + continue + } + + expectedCount := tc.end - tc.start + 1 + if len(rangeSamples) != int(expectedCount) { + t.Errorf("GetSampleRange(%d, %d): got %d samples, expected %d", + tc.start, tc.end, len(rangeSamples), expectedCount) + continue + } + + // Verify each sample matches + for i, sample := range rangeSamples { + expectedIdx := int(tc.start) - 1 + i + expected := &allSamples[expectedIdx] + + if sample.Size != expected.Size { + t.Errorf("Range [%d,%d] sample %d: size mismatch got %d, expected %d", + tc.start, tc.end, i, sample.Size, expected.Size) + } + if sample.DecodeTime != expected.DecodeTime { + t.Errorf("Range [%d,%d] sample %d: decode time mismatch got %d, expected %d", + tc.start, tc.end, i, sample.DecodeTime, expected.DecodeTime) 
+ } + if !bytes.Equal(sample.Data, expected.Data) { + t.Errorf("Range [%d,%d] sample %d: data mismatch", tc.start, tc.end, i) + } + } + } + + // Test error cases + _, err = sa.GetSampleRange(trackID, 0, 1) + if err == nil { + t.Error("Expected error for start sample 0") + } + + _, err = sa.GetSampleRange(trackID, 5, 3) + if err == nil { + t.Error("Expected error for end < start") + } + + _, err = sa.GetSampleRange(trackID, uint32(len(allSamples)+10), uint32(len(allSamples)+20)) + if err == nil { + t.Error("Expected error for out of range samples") + } + + return nil + }), + ) + if err != nil { + t.Fatalf("DecodeStream failed: %v", err) + } + + err = sf.ProcessFragments() + if err != nil { + t.Fatalf("ProcessFragments failed: %v", err) + } +} + +func TestTrailingBoxesError(t *testing.T) { + // Read the test file + testFile := "testdata/v300_multiple_segments.mp4" + data, err := os.ReadFile(testFile) + if err != nil { + t.Fatalf("Failed to read test file: %v", err) + } + + // Create a buffer with the original data plus a trailing free box + buf := bytes.Buffer{} + buf.Write(data) + + // Append a free box: size (4 bytes) + type (4 bytes) + data + freeBox := mp4.NewFreeBox([]byte("trailing")) + err = freeBox.Encode(&buf) + if err != nil { + t.Fatalf("Failed to encode free box: %v", err) + } + + // Process the stream + reader := bytes.NewReader(buf.Bytes()) + sf, err := mp4.InitDecodeStream(reader, + mp4.WithFragmentCallback(func(f *mp4.Fragment, sa mp4.SampleAccessor) error { + // Just process normally + return nil + }), + ) + if err != nil { + t.Fatalf("InitDecodeStream failed: %v", err) + } + + // ProcessFragments should return TrailingBoxesErrror + err = sf.ProcessFragments() + + // Verify we get the expected error type + var trailingErr *mp4.TrailingBoxesErrror + if !errors.As(err, &trailingErr) { + t.Fatalf("Expected TrailingBoxesErrror, got: %v", err) + } + + // Verify the error contains the free box + if len(trailingErr.BoxNames) != 1 { + t.Errorf("Expected 1 trailing box, got %d: %v", len(trailingErr.BoxNames), trailingErr.BoxNames) + } + + wantedErrMsg := "trailing boxes found after last fragment: [free]" + if err.Error() != wantedErrMsg { + t.Errorf("Unexpected error message: %q, wanted %q", err.Error(), wantedErrMsg) + } + + t.Logf("Successfully detected trailing box: %v", trailingErr.BoxNames) +} diff --git a/mp4/styp.go b/mp4/styp.go index cf0ba3dd..a67e8235 100644 --- a/mp4/styp.go +++ b/mp4/styp.go @@ -12,6 +12,13 @@ type StypBox struct { data []byte } +// Copy - deep copy of Styp box. +func (b *StypBox) Copy() *StypBox { + data := make([]byte, len(b.data)) + copy(data, b.data) + return &StypBox{data: data} +} + // MajorBrand - major brand (4 chars) func (b *StypBox) MajorBrand() string { return string(b.data[:4])