commit a077651f7aff6361629792a63fc5c9be40e2200b Author: TheK0tYaRa Date: Tue Feb 24 05:50:17 2026 +0300 added livekit v1.9.11 sources diff --git a/livekit/.goreleaser.yaml b/livekit/.goreleaser.yaml new file mode 100644 index 0000000..44b9829 --- /dev/null +++ b/livekit/.goreleaser.yaml @@ -0,0 +1,61 @@ +# Copyright 2023 LiveKit, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +version: 2 + +before: + hooks: + - go mod tidy + - go generate ./... +builds: + - id: livekit + env: + - CGO_ENABLED=0 + main: ./cmd/server + binary: livekit-server + goarm: + - "7" + goarch: + - amd64 + - arm64 + - arm + goos: + - linux + - windows + +archives: + - format_overrides: + - goos: windows + format: zip + files: + - LICENSE +release: + github: + owner: livekit + name: livekit + draft: true + prerelease: auto +changelog: + sort: asc + filters: + exclude: + - '^docs:' + - '^test:' +gomod: + proxy: true + mod: mod +checksum: + name_template: 'checksums.txt' +snapshot: + name_template: "{{ incpatch .Version }}-next" diff --git a/livekit/CHANGELOG.md b/livekit/CHANGELOG.md new file mode 100644 index 0000000..5e4d377 --- /dev/null +++ b/livekit/CHANGELOG.md @@ -0,0 +1,1501 @@ +# Changelog + +This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## [1.9.11] - 2026-01-15 + +## PLEASE NOTE: The previous release tag v1.9.10 hit a panic under some conditions. Sincerely regret the inconvenience caused. Although we do test rigorously, it is not guaranteed to cover all scenarios. We request you to report any issues you encounter. Thank you. + +### Added +- Support OpenTelemetry tracing. Add Jaeger support. (#4222) +- Add option to force simuclast codec. (#4226) +- Log timeout in API (#4231, #4232) +- Add participant option for data track auto-subscribe. (#4240) + +### Changed +- Remove enable arrival time forwarding method. (#4217) +- sfu/receiver and sfu/buffer refactor (#4221, #4224, #4225) +- Change some logs to debugw (#4229) +- Changing field naming of data track packet (#4235) +- Update Pion transport package. (#4237) +- Wrapping the invalid request errors for CreateSipParticipant (#4239) + +### Fixed +- Swap result sink atomically rather than closing and setting (#4216) +- Address crash in v1.9.10 (#4219, #4220) +- Return on SDP fragment read error. (#4228) + +## [1.9.10] - 2026-01-01 + +## WARNING: Please do not use this release. There is a run time issue which causes the server to panic. The issue has been addressed in #4219 and #4220. + +### Added +- add explicit room exists servicestore op (#4175) +- Add support for TURN static auth secret credentials (#3796) +- Make new path for signalling v1.5 support. (#4180) +- report video size from media data for whip (#4211) +- Support preserving external supplied time. (#4212) + +### Changed +- Use published track for model access in data down track. (#4176) +- Refactor receiver and buffer into Base and higher layer. (#4185, #4186, #4187, #4189, #4196, #4198, #4207) +- Update pion/webrtc to v4.2.1 (#4191) +- Receiver restart related changes. (#4192, #4200, #4202, #4208) +- Do not warn about track not bound if participant is not ready. (#4205, #4206) + +### Fixed +- Flush ext packets on restart/close and release packets. (#4179) +- Resolve RTX pair via OnTrack also. (#4190) +- Handle repair SSRC of simulcast tracks during migration. (#4193) +- return iceservers for whip (#4210) + +## [1.9.9] - 2025-12-18 + +### Added +- Add support for RTP stream restart. (#4161) + +### Changed +- Avoid duplicate track add to room track manager. (#4152, #4153) +- Consistently undo update to sequence number and timestamp when the incoming packet cannot be sequenced. (#4156) +- deregister observability function when participant is closed (#4157) +- Ensure subscribe data track handles are unique (#4162) +- move delete to oss service store (#4164) +- clean up manual roomservice log redaction (#4165) +- skip lost sequence number ranges in getIntervalStats (#4166, #4169) + +### Fixed +- chore: fix a large number of spelling issues (#4147) +- Handle case of sequence number jump just after start. (#4150) +- Drop run away receiver reports. (#4170) +- Publish/Unpublish counter match. (#4173) + +## [1.9.8] - 2025-12-10 + +### Added +- Mark RTCP buffer Write as noinline. (for better heap attribution) (#4138) +- add debug metric for tracking references (#4134) + +### Changed +- Use isEnding to indicate if down track could be resumed. (#4132) +- switch participant callbacks to room to listener interface (#4136) +- protocol deps to get inactive file adjusted memory usage. (#4137) +- update webrtc to 4.1.8 to pick up DTLS fingerprint check during handshake (#4140) + +### Fixed +- Do not pause rid in SDP to prevent race with adaptive streaming (#4129) +- leak fixes (#4131, #4141, #4142, #4143, #4144) + +## [1.9.7] - 2025-12-05 + +### Added +- Data tracks (experimental and not ready for use) (#4089) + +### Changed +- log bucket growth (#4122) +- Update pion/ice to stop gather first on close (#4123) +- move utils.WrapAround to mediatransportutil (#4124) +- Let participant close remove the published tracks. (#4125) + +### Fixed +- Fix concurrent map access for https://github.com/livekit/livekit/issues/4126. (#4127) + +## [1.9.6] - 2025-12-01 + +### Added +- Control latency of lossy data channel (#4088) +- logger proto redaction. (#4090) +- Record join/publish/subscribe cancellations (#4102, #4104) + +### Fixed +- Fix "address" typo in transport logs (addddress → address) (#4097) +- Clear stereo=1 if stereo is not enabled. (#4101) +- Participant session close deadlock fixes (#4107, #4113, #4116) + +### Changed +- Switch forwarding latency log to Debugw (#4098) +- Update mediatransportutil to get OWD estimator relocation (#4115) + +## [1.9.5] - 2025-12-01 - scratched + +## [1.9.4] - 2025-11-15 + +### Added +- Log reason for subscriber not being able to determine codec. (#4071) +- Kind details for connector (#4072) + +### Fixed +- Prevent invalid track access while peer connection is shutting down. (#4054) + +### Changed +- Update PsRPC to get redis pipeliner implementation (#4055) +- Forwarding latency measurement. (#4056. #4057, #4059, #4061, #4062, #4067, #4080) +- Update pion/transport to v3.1.1 (to get batch I/O ping-pong buffer) (#4070) +- Use sync.Pool for objects in packet path. (#4066) +- Bump protocol to pull sip validation changes and error mapping (#4081) + +## [1.9.3] - 2025-11-02 + +### Added +- Opportunistic video layer allocation on setting max spatial layer. (#4003, #4030, #4031, #4033) +- use env var for GOARCH. (#4012) +- Use simulcast codec as default policy for audio track. (#4040) +- Enable AbsCaptureTimeURI in RTC configuration. (#4043) +- Add prom histogram for forwarding latency and jitter. (#4044, #4045) + +### Fixed +- Correct direction for request/response for prom counters. (#4027) +- Do not bind buffer if codec is invalid. (#4028) +- Remove ~ from rid which indicates disabled layer to get the actual rid. (#4032) +- Prevent leakage of previous codec after codec regression. (#4035, #4037) +- fix: add missing Unlock() in AddReceiver. (#4036) + +### Changed +- Some golang modernisation bits. (#4106) +- Use rtp converter from protocol/utils. (#4019, #4020) +- High forwarding latency. (#4034, #4038) +- if RingingTimeout is provided, deadline should be set to that timeout. (#4018) +- Don't warn 0 payload type for PCMU. (#4039) + +## [1.9.2] - 2025-10-17 + +### Added +- Use gzip reader pool (#3903) +- Rpcs for ingress proxy WHIP (#3911) +- Include agent_name as a participant attribute (#3914) +- Clean code as there is no oss sweeper for ingress (#3918) +- Support simulcasting of audio (#3920) +- Subscrbed audio codecs - update from remote nodes. (#3921) +- Log some information around high forwarding latency. (#3944) +- feat: server rpc apis (#3904) +- short circuit participant broadcast filter in livestream mode (#3955) +- Adjust for hold time when fowarding RTCP report. (#3956) +- Add node_ip to config-sample.yaml (#3960) +- add idempotent reference count to telemetry stats worker (#3964) +- add config for user data recording (#3966) +- Provide the InputVideo/AudioState to Ingress in WHIPRTCConnectionNotify (#3982) +- Add encryption datapacket type (#3869) +- Allow passing inline trunk for outbound calls. (#3987) +- Log RPC details. (#3991) +- "Power of Two Random Choices" option for node selection (#3785) +- Adding ProviderInfo to GetSIPTrunkAuthenticationResponse (#3993) +- Use answer with mid -> trackID mapping when in single peer connection (#4005) +- Include mid -> trackID in both SDP offer and answer. (#4007) + +### Fixed +- add incoming request id to request response message (#3912) +- Simulcast audio fixes (#3925) +- Fix dynacast subscriber node clearing on move participant. (#3926) +- mediatransportutil crash fix for logging local address (#3930) +- Do DD restart only if DD structure is present. (#3935) +- Avoid matching on empty track id. (#3937) +- fix stats worker closed condition (#3965) +- Update deps to fix redis issue when 1 cluster address is provided (#3969) +- Revert unintentional change to not handle transport fallback (#3970) +- Do not panic of redis is not configured (#3981) +- Sort codec layers when adding track (#3998) +- Resort to full search for requested quality is not available. (#4000) +- Do not try to read stats from peer connection after close. (#4002) +- Update pion/webrtc to prevent GetStats panic. (#4004) + +### Changed +- update protocol for sip api change (#3902) +- Refactor subscribedTrack + mediaTrackSubscriptions. (#3908) +- Set publisher codec preferences after setting remote description (#3913) +- update protocol for psrpc (#3915) +- Wait for `SetRemoteDescription` before configuring senders. (#3924) +- Update mediatransportutil to log external IP found via STUN. (#3929) +- Add debugging from DD frame number wrap around. (#3933) +- More debugging of DD jump (#3934) +- Use difference in key frame counter to stop seeder. (#3936) +- Update protocol for SipCreateParticipant (#3939) +- mediatransportutil to log local address when validating external IP (#3942) +- Use microseconds for forwarding stats. (#3943) +- Tweaks tresholds for logging high forwarding latency/jitter. (#3945) +- Flush stats when there are no packets. (#3947) +- handle terminated job requests (#3948) +- Switch ops queue a singly linked list. (#3949) +- Revert "Switch ops queue a singly linked list." (#3950) +- Adjust stream allocator ping interval based on state. (#3951) +- avoid logging on small values (#3958) +- Update protocol for EventKey helper. (#3963) +- Do not force codec regression between opus and red. (#3980) +- Do not start forawarding on out-of-order packet. (#3985) +- Use padding only packets for dummy start of audio. (#3984) +- Support Opus mixed with RED when encrypted. (#3986) +- Limit check to red + opus when looking for primary codec match. (#3988) +- Increment RTP timestamp on padding when using dummy start. (#3989) +- Revert to using silence packets for audio dummy start. (#3999) +- Count request/response packets on both client and server side. (#4001) +- Do not call receiver methods under settings lock. (#4006) +- counterfeiter needs an older version of x/tools (#4009) + +## [1.9.1] - 2025-09-05 + +### Fixed + +- swap pub/sub track metrics (#3717) +- Fix bug with SDP rid, clear only overflow. (#3723) +- Don't check bindState on downtrack.Bind (#3726) +- Return highest available layer if requested quality is higher than max (#3729) +- Fix data packet ParticipantIdentity override logic in participant.go (#3735) +- Fix svc encoding for chrome mobile on iOS (#3751) +- offer could be nil when migrating. (#3752) +- fix(deps): update module github.com/livekit/protocol to v1.39.3 (#3733) +- bounds check layer index (#3768) +- Do not send leave if nil (to older clients) (#3817) +- Fix: RingingTimeout was being skipped for transferParticipant (#3831) +- Handle no codecs in track info (#3859) +- Fix missed unlock (#3861) +- Fix timeout handing in StopEgress (#3876) +- fix: ensure the participant kind is set on refresh tokens (#3881) +- Do not advertise NACK for RED. (#3889) +- Do not send both asb-send-time and twcc. (#3890) +- Prevent race in determining BWE type. (#3891) + +### Added + +- Adds Devin to readme so it auto updates DeepWiki weekly (#3699) +- Allow passing extra attributes to RTC endpoint. (#3693) +- warn about credentials when used in tokens (#3705) +- protocol dep for webhook stats buckets (#3706) +- for real, pick up protocol change for webhooks queue length buckets (#3707) +- implement observability for room metrics (#3712) +- e2e reliability for data channel (#3716) +- Add simulcast support for WHIP. (#3719) +- Add Id to SDP signalling messages. (#3722) +- Set and use rid/spatial layer in TrackInfo. (#3724) +- Add log for dropping out of order reliable message (#3728) +- chore: set workerid on job creation (#3737) +- return error when moving egress/agent participant (#3741) +- SVC with RID -> spatial layer mapping (#3754) +- feat(cli-flags): add option for cpu profiling (#3765) +- Enable H265 by default (#3773) +- Signalling V2 protocol implementation start (#3794) +- Signal v2: envelope and fragments as wire message format. (#3800) +- Grouping all signal messages into participant_signal. (#3801) +- starting signaller interface (#3802) +- Signal handling interfaces and participant specific HTTP PATCH. (#3804) +- Split signal segmenter and reassembler. (#3805) +- Filling out messages unlikely to change in v2. (#3806) +- Use signalling utils from protocol (#3807) +- Validation end point for v2 signalling. (#3811) +- More v2 signalling changes (#3814) +- Minor tweak to keep RPC type at service level. (#3815) +- Add country label to edge prom stats. (#3816) +- HTTP DELETE of participant session (#3819) +- Get to the point of establishing subscriber peer connection. (#3821) +- Get to the point of connecting publisher PC and using it for async signalling (#3822) +- Support join request as proto + base64 encoded query param (#3836) +- Use wrapped join request to be able to support compressed and uncompressed. (#3838) +- handle SyncState in join request (#3839) +- Support per simulcast codec layers. (#3840) +- Support video layer mode from client and make most of the code mime aware (#3843) +- Send `participant_connection_aborted` when participant session is closed (#3848) +- Support G.711 A-law and U-law (#3849) +- Extract video size from media stream (#3856) +- update mediatransport util for ice port 3478 (#3877) +- Single peer connection mode (#3873) +- handle frame number wrap back in svc (#3885) +- Use departure timeout from room preset. (#3888) +- Use `RequestResponse` to report protocol handling errors (#3895) + +### Changed + +- Add a trend check before declaring joint queuing region. (#3701) +- Small changes to add/use helper functions for length checks. (#3704) +- remove unused ws signal read loop (#3709) +- Flush stats on close (#3713) +- Do not require create permission for WHIP participant. (#3715) +- Create client config manager in room manager constructor. (#3718) +- Clear rids from default for layers not published. (#3721) +- Clear rids if not present in SDP. (#3731) +- Revert clearing RIDs. (#3732) +- Take ClientInfo from request. (#3738) +- remove unused code (#3740) +- reuse compiled client config scripts (#3743) +- feat(cli): update to urfave/cli/v3 (#3745) +- move egress roomID load to launcher (#3748) +- Log previous allocation to see changes. (#3759) +- Do not need to just clean up receivers. Remove that interface. (#3760) +- ClearAllReceivers interface is used to pause relay tracks. (#3761) +- Temporary change: use pre-defined rids (#3767) +- Revert "Temporary change: use pre-defined rids" (#3769) +- Log SDP rids to understand the mapping better. (#3770) +- Limit taking rids from SDP only in WHIP path. (#3771) +- Set rids for all codecs. (#3772) +- Return default layer for invalid rid + track info combination. (#3778) +- Normalize known rids. (#3779) +- forward agent id to job state (3786) +- Map ErrNoResponse to ErrRequestTimedOut in StopEgress to avoid returning 503 (#3788) +- Set participant active when peerconnection connected (#3790) +- Handle Metadata field from RoomConfig (#3798) +- [🤖 readme-manager] Update README (#3808) +- [🤖 readme-manager] Update README (#3809) +- Rename RTCRest -> WHIP (#3829) +- Delete v2 signalling (#3835) +- Clean up missed v2 pieces (#3837) +- Update go deps (#3849) +- Populate SDP cid in track info when available. (#3845) +- Log signal messages as debug. (#3851) +- Log signal messages on media node. (#3852) +- Log track settings more. (#3853) +- Update pion deps (#3863) +- Update golang Docker tag to v1.25 (#3864) +- Update module github.com/livekit/protocol to v1.40.0 (#3865) +- Remove unnecessary check (#3870) +- chunk room updates (#3880) +- Switch known rids from 012 -> 210, used by OBS. (#3882) +- init ua parser once (#3883) +- Revert to using answer for migration case. (#3884) +- Handle migration better in single peer connection case. (#3886) + +## [1.9.0] - 2025-06-02 + +### Added + +- Add pID and connID to log context to make it easier to search using pID. (#3518) +- add server agent load threshold config (#3520) +- Add a key frame seeder in up track. (#3524) +- Implement SIP update API. (#3141) +- Add option to use different pacer with send side bwe. (#3552) +- Allow specifying extra webhooks with egress requests (#3597) + +### Fixed + +- Fix missing RTCP sender report when forwarding RED as Opus. (#3480) +- Take RTT and jitter from receiver view while reporting track stats for (#3483) +- Fix receiver rtt/jitter. (#3487) +- fix: fix the wrong error return value (#3493) +- load mime type before calling writeBlankFrameRTP (#3502) +- Prevent bind lock deadlock on muted. (#3504) +- Handle subscribe race with track close better. (#3526) +- Do not instantiate 0 sized sequencer. (#3529) +- Fix: Return NotFoundErr instead of Unavailable when the participant does not exist in UpdateParticipant. (#3543) +- skip out of order participant state updates (#3583) +- Exclude RED from enabled codecs for Flutter + 2.4.2 + Android. (#3587) +- protocol update to fix IPv6 SDP fragment parsing (#3603) +- Forward transfer headers to internal request (#3615) +- Do not use Redis pipeline for SIP delete. Fixes Redis clustering support. (#3694) + +### Changed + +- Use a RED transformer to consolidate both RED -> Opus OR Opus -> RED (#3481) +- refactor: using slices.Contains to simplify the code (#3495) +- Do not bind lock across flush which could take time (#3501) +- Log packet drops/forward. (#3510) +- Clean up published track on participant removal. (#3527) +- Do not accept unsupported track type in AddTrack (#3530) +- Use cgroup for memstats. (#3573) +- Replace Promise with Fuse. (#3580) +- Do not drop audio codecs (#3590) +- map PEER_CONNECTION_DISCONNECTED -> CONNECTION_TIMEOUT (#3591) +- Update mediatransportutil for max sctp message (65K) (#3611) +- Disable vp9 for safari 18.4 due to compatibility (#3631) +- Avoid synthesising duplicate feature. (#3632) +- Take AudioFeatures from AddTrack. (#3635) +- Use unordered for lossy data channel. (#3653) +- Send self participant update immediately. (#3656) +- update mediatransportutil for sctp congestion control (#3673) + +## [1.8.4] - 2025-03-01 + +### Added + +- Add support for datastream trailer (#3329) +- Reject ingress if Enabled flag is false (#3293) +- Use nonce in data messages to de-dupe SendData API. (#3366) +- H265 supoort and codec regression (#3358) +- Pass error details and timeouts. (#3402) +- Webhook analytics event. (#3423) +- add participant job type (#3443) +- add datapacket stream metrics (#3450) +- Implement SIP iterators. (#3332) +- Add ice candidates logs for failed peerconnection (#3473) + +### Fixed + +- Disable SCTP zero checksum for old go client (#3319) +- disable sctp zero checksum for unknown sdk (#3321) +- remove code that deletes state from the store for rooms older than 24 hours (#3320) +- Correct off-by-one lost count on a restart. (#3337) +- Do not send DD extension if ID is 0. (#3339) +- allocate node for autocreated room in agent dispatch (#3344) +- Do not seed if stream is already writable. (#3347) +- Clone pending tracks to prevent concurrent update. (#3359) +- Resolve newer participant using higher precision join time. (#3360) +- Resolve FromAsCasing warning in Dockerfile (#3356) +- pass RoomConfig along when creating a new dispatch rule (#3367) +- Reduce chances of metadata cache overflow. (#3369, #3370) +- ReconnectResponse getting mutated due to mutation of client conf. (#3379) +- fire TrackSubscribed event only when subscriber is visible (#3378) +- fix internal signal protocol backward compatibility with 1.7.x (#3384) +- Correct reason for poor/lost score. (#3397) +- Do not skip due to large RR interval. (#3398) +- Update config.go to properly process bool env vars (#3382) +- consolidated mime type checks (#3407, #3418) +- Ignore unknown mime in dynacast manager. (#3419) +- Fix timing issue between track republish (#3428) +- Catch up if the diff is exactly (1 << 16) also. (#3433) +- Don't drop message if calculate duration is too small (#3442) +- Dependent participants should not trigger count towards FirstJoinedAt (#3448) +- Fix codec regression failed after migration (#3455) +- Do not revoke track subscription on permission update for exempt participants. (#3458) + +### Changed + +- Remove duplicate SSRC get. (#3318) +- Exempt egress participant from track permissions. (#3322) +- Use nano time for easier (and hopefully) faster checks/calculations. (#3323) +- move unrolled mime type check for broader use (#3326) +- Request key frame if subscribed is higher than max seen and not congested. (#3348) +- Request key frame on subscription change. (#3349) +- Room creation time with ms resolution (#3362) +- close signal session is request messages are undeliverable (#3364) +- Declare congestion none only if both methods are in DQR. (#3372) +- Clone TrackInfo to TrackPublishRequested event. (#3377) +- Run bandwidth estimation when congestion is relieved also (#3380) +- move ConnectedAt to Participant interface (#3383) +- Starting on padding for RTX stream is accepted. (#3390) +- Adjust receiver report sequence number to be within range of highest. (#3396) +- Split down stream snapshot into sender view and receiver view. (#3422) +- Seed on receiving forwarder state. (#3435) +- Give more cache for RTX. (#3438) +- Properly initialise DD layer selector. (#3467) + +## [1.8.3] - 2025-01-07 + +### Added + +- Allow requesting a dialtone during call transfer (#3122) +- Handle room configuration that's set in the grant itself (#3120) +- Update ICE to pick up accepting use-candidate unconditionally for ICE lite agents (#3150) +- auto create rooms during create agent dispatch api request (#3158) +- Annotate SIP errors with Twirp codes. (#3161) +- TWCC based congestion control (#3165 #3234 #3235 #3244 #3245 #3250 #3251 #3253 #3254 #3256 #3262 #3282) +- Loss based congestion signal detector. (#3168 #3169) +- Fix header size calculation in stats. (#3171) +- add per message deflate to signal ws (#3174) +- Add ResyncDownTracks API that can be used to resync all down tracks on (#3185) +- One shot signalling mode (#3188 #3192 #3194 #3223) +- Server side metrics (#3198) +- Add datastream packet type handling (#3210) +- Support SIP list filters. (#3240) +- Add RTX to downstream (#3247) +- Handle REMB on RTX RTCP (#3257) +- Thottle the publisher data channel sending when subscriber is slow (#3255 #3265 #3281) + +### Fixed + +- avoids NaN (#3119) +- reduce retransmit by seeding duplicate packets and bytes. (#3124) +- don't return video/rtx to client (#3142) +- ignore unexported fields in yaml lint (#3145) +- Fix incorrect computation of SecondsSinceNodeStatsUpdate (#3172) +- Attempt to fix missing participant left webhook. (#3173) +- Set down track connected flag in one-shot-signalling mode. (#3191) +- Don't SetCodecPreferences for video transceiver (#3249) +- Disable av1 for safari (#3284) +- fix completed job status updates causing workers to reconnect (#3294) + +### Changed + +- Display both pairs on selected candidate pair change (#3133) +- Maintain RTT marker for calculations. (#3139) +- Consolidate operations on LocalNode. (#3140) +- Use int64 nanoseconds and reduce conversion in a few places (#3159) +- De-centralize some configs to where they are used. (#3162) +- Split out audio level config. (#3163) +- Use int64 nanoseconds and reduce conversion in a few places (#3159) +- Reduce lock scope. (#3167) +- Clean up forwardRTP function a bit. (#3177) +- StreamAllocator (congestion controller) refactor (#3180) +- convert psprc error to http code in rtc service failure response (#3187) +- skip http request logging when the client aborts the request (#3195) +- Do not treat data publisher as publisher. (#3204) +- Publish data and signal bytes once every 30 seconds. (#3212) +- upgrade to pion/webrtc v4 (#3213) +- Don't wait rtp packet to fire track (#3246) +- Keep negotiated codec parameters in Downtrack.Bind (#3271) +- Structured logging of ParticipantInit (#3279) +- Start stream allocator after creating peer connection. (#3283) +- Reduce memory allocation in WritePaddingRTP / WriteProbePackets (#3288) +- add room/participant to logger context for SIP APIs (#3290) +- vp8 temporal layer selection with dependency descriptor (#3302) +- Use contiguous groups to determine queuing region. (#3308) + +## [1.8.0] - 2024-10-18 + +### Added + +- Support protocol 15 - send signal response for success responses (#2926) +- Add `DisconnectReason` to `ParticipantInfo`. (#2930) +- add roommanager service (#2931) +- Add tracksubscribed event on downtrack added (#2934) +- Speed up track publication (#2952) +- Add FastPublish in JoinResponse (#2964) +- Update protocol. Support SIP Callee dispatch rule type. (#2969) +- Record out-of-packet count/rate in prom. (#2980) +- Support passing SIP headers. (#2993) +- Update ICE via webrtc to get candidate pair stats RTT (#3009) +- Initial plumbing for metrics. (#2950) +- Allow agents to override sender identities on ChatMessage (#3022) +- Implement SIP TransferParticipant API (#3026) +- api for agent worker job count (#3068) +- Add counter for pub&sub time metrics (#3084) +- Support for attributes in initial agent token (#3097) + +### Fixed + +- Handle another old packet condition. (#2947) +- Properly exclude mDNS when not trickling also. (#2956) +- Panic fix for nil candidate check. (#2957) +- Skip ICE restart on unestablished peer connection. (#2967) +- Recreate stats worker on resume if needed. (#2982) +- Handle trailing slashes in URL (#2988) +- Do not take padding packets into account in max pps calculation (#2990) +- compute agent service affinity from available capacity (#2995) +- Do not remove from subscription map on unsubscribe. (#3002) +- Fix forwarder panic defer of nil senderReport (#3011) +- avoid race condition on downtrack.Codec (#3032) +- fix: copy attributes to refresh token (#3036) +- Set mime_type for tracks don't have simulcast_codecs (#3040) +- check data messages for nil payloads (#3062) +- Fix codec name normalisation. (#3081 #3103 #3104 #3106 #3113) +- Set FEC enabled properly in connection stats module. (#3098) +- Type safe IP checks for SIP Trunks. (#3108) +- Parse python, cpp, unity-web, node sdks in clientinfo (#3110) + +### Changed + +- Use monotonic clock in packet path. (#2940) +- Refactor propagation delay estimator. (#2941) +- Propagate SIP attributes from a Dispatch Rule. (#2943) +- Refactor sip create participant (#2949) +- Reduce threshold of out-of-order very old packet detection. (#2951) +- Standardize twirp hooks during server init (#2959) +- Don't remove DD extesion for simucalst codecs (#2960) +- Negotiate downttrack for subscriber before receiver is ready (#2970) +- Allow start streaming on an out-of-order packet. (#2971) +- exponential backoff when calling CreateRoom (#2977) +- Start negotiate immediately if last one is before debouce interval (#2979) +- Seed down track state on re-use. (#2985) +- Cache RTCP sender report in forwarder state. (#2994) +- Set SenderReport to nil on seeding if empty. (#3008) +- Use new track id for republishing (#3020) +- simplify agent registration (#3018) +- enable room creator service by default (#3043) +- Fix clock rate skew calculation. (#3055) +- Forward new disconnect reasons for SIP. (#3056) +- Use difference debounce interval in negotiation (#3078) +- Use lower case mime type in dynacast. (#3080) +- Drop quality a bit faster on score trending lower to be more responsive. (#3093) +- Protocol update to get more precise protoproxy timing (#3107) + +## [1.7.2] - 2024-08-10 + +### Added + +- Feat add prometheus auth (#2252) +- Support for Agent protocol v2 (#2786 #2837 #2872 #2886, #2919) +- Add track subscribed notification to publisher (#2834) +- Always forward DTMF data messages. (#2848) +- Send error response when update metadata fails. (#2849) +- Allow specifying room configuration in token (#2853) +- Make sender report pass through an option. (#2861) +- Add option to disable ice lite (#2862) +- mark final ice candidate (#2871) +- Send the correct room closed reason to clients (#2901) +- distribute load to agents probabilistically, inversely proportionate to load (#2902) + +### Fixed + +- Fixed participant attributes not broadcasted correctly (#2825) +- Handle cases of long mute/rollover of time stamp. (#2842) +- use correct payload type for red primary encoding (#2845) +- Forward correct payload type for mixed up red/primary payload (#2847) +- Check size limits on metadata and name set from client. (#2850) +- Fallback to primary encoding if redundant block overflow (#2858) +- Support updating local track features when pending. (#2863) +- don't send unknown signal message to rust sdk with protocol 9 (#2860) +- Fixed handling of different extensions across multiple media sections (#2878) +- Fix forced rollover of RTP time stamp. (#2896) +- Do not start forwarding on an out-of-order packet. (#2917) +- Reset DD tracker layers when muted. (#2920) + +### Changed + +- add handler interface to receive agent worker updates (#2830) +- log non-trickle candidate in details (#2832) +- RTP packet validity check. (#2833) +- Do not warn on padding (#2839) +- Check sender report against media path. (#2843) +- Do not create room in UpdateRoomMetadata (#2854) +- use atomic pointer for MediaTrackReceiver TrackInfo (#2870) +- Don't create DDParser for non-svc codec (#2883) + +## [1.7.0] - 2024-06-23 + +This version includes a breaking change for SIP service. SIP service now requires `sip.admin` in the token's permission grant +to interact with trunks and dispatch rules; and `sip.call` to dial out to phone numbers. +The latest versions of server SDKs will include the permission grants automatically. + +### Added + +- Support new SIP Trunk API. (#2799) +- Add participant session duration metric (#2801) +- Support for key/value attributes on Participants (#2806) +- Breaking: SIP service requires sip.admin or sip.call grants. (#2808) + +### Fixed + +- Fixed agent jobs not launching when using the CreateRoom API (#2796) + +### Changed + +- Indicate if track is expected to be resumed in `onClose` callback. (#2800) + +## [1.6.2] - 2024-06-15 + +### Added + +- Support for optional publisher datachannel (#2693) +- add room/participant name limit (#2704) +- Pass through timestamp in abs capture time (#2715) +- Support SIP transports. (#2724) + +### Fixed + +- add missing strings.EqualFold for some mimeType comparisons (#2701) +- connection reset without any closing handshake on clientside (#2709) +- Do not propagate RTCP if report is not processed. (#2739) +- Fix DD tracker addition. (#2751) +- Reset tracker on expected layer increase. (#2753) +- Do not add tracker for invalid layers. (#2759) +- Do not compare payload type before bind (#2775) +- fix agent jobs not launching when using the CreateRoom API (#2784) + +### Changed + +- Performance improvements to forwarding by using condition var. (#2691 #2699) +- Simplify time stamp calculation on switches. (#2688) +- Simplify layer roll back. (#2702) +- ensure room is running before attempting to delete (#2705) +- Redact egress object in CreateRoom request (#2710) +- reduce participant lock scope (#2732) +- Demote some less useful/noisy logs. (#2743) +- Stop probe on probe controller reset (#2744) +- initialize bucket size by publish bitrates (#2763) +- Validate RTP packets. (#2778) + +## [1.6.1] - 2024-04-26 + +This release changes the default behavior when creating or updating WHIP +ingress. WHIP ingress will now default to disabling transcoding and +forwarding media unchanged to the LiveKit subscribers. This behavior can +be changed by using the new `enable_transcoding` available in updated +SDKs. The behavior of existing ingresses is unchanged. + +### Added + +- Add support for "abs-capture-time" extension. (#2640) +- Add PropagationDelay API to sender report data (#2646) +- Add support for EnableTranscoding ingress option (#2681) +- Pass new SIP metadata. Update protocol. (#2683) +- Handle UpdateLocalAudioTrack and UpdateLocalVideoTrack. (#2684) +- Forward transcription data packets to the room (#2687) + +### Fixed + +- backwards compatability for IsRecorder (#2647) +- Reduce RED weight in half. (#2648) +- add disconnected chan to participant (#2650) +- add typed ops queue (#2655) +- ICE config cache module. (#2654) +- use typed ops queue in pctransport (#2656) +- Use the ingress state updated_at field to ensure that out of order RPC do not overwrite state (#2657) +- Log ICE candidates to debug TCP connection issues. (#2658) +- Debug logging addition of ICE candidate (#2659) +- fix participant, ensure room name matches (#2660) +- replace keyframe ticker with timer (#2661) +- fix key frame timer (#2662) +- Disable dynamic playout delay for screenshare track (#2663) +- Don't log dd invalid template index (#2664) +- Do codec munging when munging RTP header. (#2665) +- Connection quality LOST only if RTCP is also not available. (#2670) +- Handle large jumps in RTCP sender report timestamp. (#2674) +- Bump golang.org/x/net from 0.22.0 to 0.23.0 (#2673) +- do not capture pointers in ops queue closures (#2675) +- Fix SubParticipant twice when paticipant left (#2672) +- use ttlcache (#2677) +- Detach subscriber datachannel to save memory (#2680) +- Clean up UpdateVideoLayers (#2685) + +## [1.6.0] - 2024-04-10 + +### Added + +- Support for Participant.Kind. (#2505 #2626) +- Support XR request/response for rtt calculation (#2536) +- Added support for departureTimeout to keep the room open after participant depart (#2549) +- Added support for Egress Proxy (#2570) +- Added support for SIP DTMF data messages. (#2559) +- Add option to enable bitrate based scoring (#2600) +- Agent service: support for orchestration v2 & namespaces (#2545 #2641) +- Ability to disable audio loss proxying. (#2629) + +### Fixed + +- Prevent multiple debounce of quality downgrade. (#2499) +- fix pli throttle locking (#2521) +- Use the correct snapshot id for PPS. (#2528) +- Validate SIP trunks and rules when creating new ones. (#2535) +- Remove subscriber if track closed while adding subscriber. (#2537) +- fix #2539, do not kill the keepaliveWorker task when the ping timeout occurs (#2555) +- Improved A/V sync, proper RTCP report past mute. (#2588) +- Protect duplicate subscription. (#2596) +- Fix twcc has chance to miss for firefox simulcast rtx (#2601) +- Limit playout delay change for high jitter (#2635) + +### Changed + +- Replace reflect.Equal with generic sliceEqual (#2494) +- Some optimisations in the forwarding path. (#2035) +- Reduce heap for dependency descriptor in forwarding path. (#2496) +- Separate buffer size config for video and audio. (#2498) +- update pion/ice for tcpmux memory improvement (#2500) +- Close published track always. (#2508) +- use dynamic bucket size (#2524) +- Refactoring channel handling (#2532) +- Forward publisher sender report instead of generating. (#2572) +- Notify initial permissions (#2595) +- Replace sleep with sync.Cond to reduce jitter (#2603) +- Prevent large spikes in propagation delay (#2615) +- reduce gc from stream allocator rate monitor (#2638) + +## [1.5.3] - 2024-02-17 + +### Added + +- Added dynamic playout delay if PlayoutDelay enabled in the room (#2403) +- Allow creating SRT URL pull ingress (requires Ingress service release) (#2416) +- Use default max playout delay as chrome (#2411) +- RTX support on publisher transport (#2452) +- Add exponential backoff to room service check retries (#2462) +- Add support for ingress ParticipantMetadata (#2461) + +### Fixed + +- Prevent race of new track and new receiver. (#2345) +- Fixed race condition when applying metadata update. (#2363 #2478) +- Fixed race condition in DownTrack.Bind. (#2388) +- Improved PSRPC over redis reliability with keepalive (#2398) +- Fix race condition on Participant.updateState (#2401) +- Replace /bin/bash with env call for FreeBSD compatibility (#2409) +- Fix startup with -dev and -config (#2442) +- Fix published track leaks: close published tracks on participant close (#2446) +- Enforce empty SID for UserPacket from hidden participants (#2469) +- Ignore duplicate RID. (Fix for spec breakage by Firefox on Windows 10) (#2471) + +### Changed + +- Logging improvements (various PRs) +- Server shuts down after a second SIGINT to simplify development lifecycle (#2364) +- A/V sync improvements (#2369 #2437 #2472) +- Prometheus: larger max session start time bin size (#2380) +- Updated SIP protocol for creating participants. (requires latest SIP release) (#2404 #2474) +- Improved reliability of signal stream starts with retries (#2414) +- Use Deque instead of channels in internal communications to reduce memory usage. (#2418 #2419) +- Do not synthesise DISCONNECT on session change. (#2412) +- Prometheus: larger buckets for jitter histogram (#2468) +- Support for improved Ingress internal RPC (#2485) +- Let track events go through after participant close. (#2487) + +### Removed + +- Removed code related to legacy (pre 1.5.x) RPC protocol (#2384 #2385) + +## [1.5.2] - 2023-12-21 + +Support for LiveKit SIP Bridge + +### Added + +- Add SIP Support (#2240 #2241 #2244 #2250 #2263 #2291 #2293) +- Introduce `LOST` connection quality. (#2265 #2276) +- Expose detailed connection info with ICEConnectionDetails (#2287) +- Add Version to TrackInfo. (#2324 #2325) + +### Fixed + +- Guard against bad quality in trackInfo (#2271) +- Group SDES items for one SSRC in the same chunk. (#2280) +- Avoid dropping data packets on local router (#2270) +- Fix signal response delivery after session start failure (#2294) +- Populate disconnect updates with participant identity (#2310) +- Fix mid info lost when migrating multi-codec simulcast track (#2315) +- Store identity in participant update cache. (#2320) +- Fix panic occurs when starting livekit-server with key-file option (#2312) (#2313) + +### Changed + +- INFO logging reduction (#2243 #2273 #2275 #2281 #2283 #2285 #2322) +- Clean up restart a bit. (#2247) +- Use a worker to report signal/data stats. (#2260) +- Consolidate TrackInfo. (#2331) + +## [1.5.1] - 2023-11-09 + +Support for the Agent framework. + +### Added + +- PSRPC based room and participant service. disabled by default (#2171 #2205) +- Add configuration to limit MaxBufferedAmount for data channel (#2170) +- Agent framework worker support (#2203 #2227 #2230 #2231 #2232) + +### Fixed + +- Fixed panic in StreamTracker when SVC is used (#2147) +- fix CreateEgress not completing (#2156) +- Do not update highest time on padding packet. (#2157) +- Clear flags in packet metadata cache before setting them. (#2160) +- Drop not relevant packet only if contiguous. (#2167) +- Fixed edge cases in SVC codec support (#2176 #2185 #2191 #2196 #2197 #2214 #2215 #2216 #2218 #2219) +- Do not post to closed channels. (#2179) +- Only launch room egress once (#2175) +- Remove un-preferred codecs for android firefox (#2183) +- Fix pre-extended value on wrap back restart. (#2202) +- Declare audio inactive if stale. (#2229) + +### Changed + +- Defer close of source and sink to prevent error logs. (#2149) +- Continued AV Sync improvements (#2150 #2153) +- Egress store/IO cleanup (required for Egress 1.8.0) (#2152) +- More fine grained filtering NACKs after a key frame. (#2159) +- Don't filter out ipv6 address for client don't support prflx over relay (#2193) +- Disable h264 for android firefox (#2190) +- Do not block on down track close with flush. (#2201) +- Separate publish and subscribe enabled codecs for finer grained control. (#2217) +- improve participant hidden (#2220) +- Reject migration if codec mismatch with published tracks (#2225) + +## [1.5.0] - 2023-10-15 + +### Added + +- Add option to issue full reconnect on data channel error. (#2026) +- Support non-SVC AV1 track publishing (#2030) +- Add batch i/o to improve throughput (#2033) +- Integrate updated TWCC responder (#2038) +- Allow RoomService.SendData to use participant identities (#2051 #2058) +- Support for Participant Egress (#2070) +- Add max playout delay config (#2089) +- Enable SVC codecs by default (#2109) +- Add SyncStreams flag to Room, protocol 10 (#2110) + +### Fixed + +- Unlock pendingTracksLock when mid is empty (#1994) +- Do not offer H.264 high profile in subscriber offer, fixes negotiation failures (#1997) +- Prevent erroneous stream pause. (#2008) +- Handle duplicate padding packet in the up stream. (#2012) +- Do not process packets not processed by RTPStats. (#2015) +- Adjust extended sequence number to account for dropped packets (#2017) +- Do not force reconnect on resume if there is a pending track (#2081) +- Fix out-of-range access. (#2082) +- Start key frame requester on start. (#2111) +- Handle RED extended sequence number. (#2123) +- Handle playoutDelay for Firefox (#2135) +- Fix ICE connection fallback (#2144) + +### Changed + +- Drop padding only packets on publisher side. (#1990) +- Do not generate a stream key for URL pull ingress (#1993) +- RTPStats optimizations and improvements (#1999 #2000 #2001 #2002 #2003 #2004 #2078) +- Remove sender report warp logs. (#2007) +- Don't create new slice when return broadcast downtracks (#2013) +- Disconnect participant when signal proxy is closed (#2024) +- Use random NodeID instead of MAC based (#2029) +- Split RTPStats into receiver and sender. (#2055) +- Reduce packet meta data cache (#2073 #2078) +- Reduce ghost participant disconnect timeout (#2077) +- Per-session TURN credentials (#2080) +- Use marshal + unmarshal to ensure unmarshallable fields are not copied. (#2092) +- Allow playout delay even when sync stream isn't used. (#2133) +- Increase accuracy of delay since last sender report. (#2136) + +## [1.4.5] - 2023-08-22 + +### Added + +- Add ability to roll back video layer selection. (#1871) +- Allow listing ingress by id (#1874) +- E2EE trailer for server injected packets. (#1908) +- Add support for ingress URL pull (#1938 #1939) +- (experimental) Add control of playout delay (#1838 #1930) +- Add option to advertise external ip only (#1962) +- Allow data packet to be sent to participants by identity (#1982) + +### Fixed + +- Fix RTC IP when binding to 0.0.0.0 (#1862) +- Prevent anachronous sample reading in connection stats (#1863) +- Fixed resubscribe race due to desire changed before cleaning up (#1865) +- Fixed numPublisher computation by marking dirty after track published changes (#1878) +- Attempt to avoid out-of-order max subscribed layer notifications. (#1882) +- Improved packet loss handling for SVC codecs (#1912 ) +- Frame integrity check for SVC codecs (#1914) +- Issue full reconnect if subscriber PC is closed on ICERestart (#1919) +- Do not post max layer event for audio. (#1932) +- Never use dd tracker for non-svc codec (#1952) +- Fix race condition causing new participants to have stale room metadata (#1969) +- Fixed VP9 handling for non-SVC content. (#1973) +- Ensure older session does not clobber newer session. (#1974) +- Do not start RTPStats on a padding packet. (#1984) + +### Changed + +- Push track quality to poor on a bandwidth constrained pause (#1867) +- AV sync improvements (#1875 #1892 #1944 #1951 #1955 #1956 #1968 #1971 #1986) +- Do not send unnecessary room updates when content isn't changed (#1881) +- start reading signal messages before session handler finishes (#1883) +- changing key file permissions control to allow group readable (#1893) +- close disconnected participants when signal channel fails (#1895) +- Improved stream allocator handling during transitions and reallocation. (#1905 #1906) +- Stream allocator tweaks to reduce re-allocation (#1936) +- Reduce NACK traffic by delaying retransmission after first send. (#1918) +- Temper stream allocator more to avoid false negative downgrades (#1920) + +## [1.4.4] - 2023-07-08 + +### Added + +- Add dependency descriptor stream tracker for svc codecs (#1788) +- Full reconnect on publication mismatch on resume. (#1823) +- Pacer interface in down stream path. (#1835) +- retry egress on timeout/resource exhausted (#1852) + +### Fixed + +- Send Room metadata updates immediately after update (#1787) +- Do not send ParticipantJoined webhook if connection was resumed (#1795) +- Reduce memory leaks by avoiding references in closure. (#1809) +- Honor bind address passed as `--bind` also for RTC ports (#1815) +- Avoid dangling downtracks by always deleting them in receiver close. (#1842) +- Better cleanup of subscriptions with needsCleanup. (#1845) +- Fix nack issue for svc codecs (#1856) +- Fixed hidden participant update were still sent when track is published (#1857) +- Fixed Redis lockup when unlocking room with canceled request context (#1859) + +### Changed + +- Improvements to A/V sync (#1773 #1781 #1784 ) +- Improved probing to be less disruptive in low bandwidth scenarios (#1782 #1834 #1839) +- Do not mute forwarder when paused due to bandwidth congestion. (#1796) +- Improvements to congestion controller (#1800 #1802 ) +- Close participant on full reconnect. (#1818) +- Do not process events after participant close. (#1824) +- Improvements to dependency descriptor based selection forwarder (#1808) +- Discount out-of-order packets in downstream score. (#1831) +- Adaptive stream to select highest layer of equal dimensions (#1841) +- Return 404 with DeleteRoom/RemoveParticipant when deleting non-existent resources (#1860) + +## [1.4.3] - 2023-06-03 + +### Added + +- Send quality stats to prometheus. (#1708) +- Support for disabling publishing codec on specific devices (#1728) +- Add support for bypass_transcoding field in ingress (#1741) +- Include await_start_signal for Web Egress (#1759) + +### Fixed + +- Handle time stamp increment across mute for A/V sync (#1705) +- Additional A/V sync improvements (#1712 #1724 #1737 #1738 #1764) +- Check egress status on UpdateStream failure (#1716) +- Start signal relay sessions with the correct node (#1721) +- Fix unwrap for out-of-order packet (#1729) +- Fix dynacast for svc codec (#1742 #1743) +- Ignore receiver reports that have a sequence number before first packet (#1745) +- Fix node stats updates on Windows (#1748) +- Avoid reconnect loop for unsupported downtrack (#1754) +- Perform unsubscribe in parallel to avoid blocking (#1760) + +### Changed + +- Make signal close async. (#1711 #1722) +- Don't add nack if it is already present in track codec (#1714) +- Tweaked connection quality algorithm to be less sensitive to jitter (#1719) +- Adjust sender report time stamp for slow publishers (#1740) +- Split probe controller from StreamAllocator (#1751) + +## [1.4.2] - 2023-04-27 + +### Added + +- VP9 codec with SVC support (#1586) +- Support for source-specific permissions and client-initiated metadata updates (#1590) +- Batch support for signal relay (#1593 #1596) +- Support for simulating subscriber bandwidth (#1609) +- Support for subscription limits (#1629) +- Send Room updates when participant counts change (#1647) + +### Fixed + +- Fixed process return code to 0 (#1589) +- Fixed VP9 stutter when not using dependency descriptors (#1595) +- Fixed stutter when using dependency descriptors (#1600) +- Fixed Redis cluster support when using Egress or Ingress (#1606) +- Fixed simulcast parsing error for slower clients (camera and screenshare) (#1621) +- Don't close RTCP reader if Downtrack will be resumed (#1632) +- Restore VP8 munger state properly. (#1634) +- Fixed incorrect node routing when using signal relay (#1645) +- Do not send hidden participants to others after resume (#1689) +- Fix for potential webhook delivery delays (#1690) + +### Changed + +- Refactored video layer selector (#1588 #1591 #1592) +- Improved transport fallback when client is resuming (#1597) +- Improved webhook reliability with delivery retries (#1607 #1615) +- Congestion controller improvements (#1614 #1616 #1617 #1623 #1628 #1631 #1652) +- Reduced memory usage by releasing ParticipantInfo after JoinResponse is transmitted (#1619) +- Ensure safe access in sequencer (#1625) +- Run quality scorer when there are no streams. (#1633) +- Participant version is only incremented after updates (#1646) +- Connection quality attribution improvements (#1653 #1664) +- Remove disallowed subscriptions on close. (#1668) +- A/V sync improvements (#1681 #1684 #1687 #1693 #1695 #1696 #1698 #1704) +- RTCP sender reports every three seconds. (#1692) + +### Removed + +- Remove deprecated (non-psrpc) egress client (#1701) + +## [1.4.1] - 2023-04-05 + +### Added + +- Added prometheus metrics for internal signaling API #1571 + +### Fixed + +- Fix regressions in RTC when using redis with psrpc signaling #1584 #1582 #1580 #1567 +- Fix required bitrate assessment under channel congestion #1577 + +### Changed + +- Improve DTLS reliability in regions with internet filters #1568 +- Reduce memory usage from logging #1576 + +## [1.4.0] - 2023-03-27 + +### Added + +- Added config to disable active RED encoding. Use NACK instead #1476 #1477 +- Added option to skip TCP fallback if TCP RTT is high #1484 +- psrpc based signaling between signal and RTC #1485 +- Connection quality algorithm revamp #1490 #1491 #1493 #1496 #1497 #1500 #1505 #1507 #1509 #1516 #1520 #1521 #1527 #1528 #1536 +- Support for topics in data channel messages #1489 +- Added active filter to ListEgress #1517 +- Handling for React Native and Rust SDK ClientInfo #1544 + +### Fixed + +- Fixed unsubscribed speakers stuck as speaking to clients #1475 +- Do not include packet in RED if timestamp is too far back #1478 +- Prevent PLI layer lock getting stuck #1481 +- Fix a case of changing video quality not succeeding #1483 +- Resync on pub muted for audio to avoid jump in sequence numbers on unmute #1487 +- Fixed a case of data race #1492 +- Inform reconnecting participant about recently disconnected users #1495 +- Send room update that may be missed by reconnected participant #1499 +- Fixed regression for AV1 forwarding #1538 +- Ensure sequence number continuity #1539 +- Give proper grace period when recorder is still in the room #1547 +- Fix sequence number offset on packet drop #1556 +- Fix signal client message buffer size #1561 + +### Changed + +- Reduce lock scope getting RTCP sender reports #1473 +- Avoid duplicate queueReconcile in subscription manager #1474 +- Do not log TURN errors with prefix "error when handling datagram" #1494 +- Improvements to TCP fallback mode #1498 +- Unify forwarder between dependency descriptor and no DD case. #1543 +- Increase sequence number cache to handle high rate tracks #1560 + +## [1.3.5] - 2023-02-25 + +### Added + +- Allow for strict ACKs to be disabled or subscriber peer connections #1410 + +### Fixed + +- Don't error when get tc stats fails #1306 +- Fixed support for Redis cluster #1415 +- Fixed unpublished callback being skipped in certain cases #1418 +- Fixed panic when Egress request is sent with an empty output field #1420 +- Do not unsubscribe from track if it's been republished #1424 #1429 #1454 #1465 +- Fixed panic when closing room #1428 +- Use available layers in optimal allocation #1445 #1446 #1448 #1449 +- Fixed unable to notify webhook when egress ending with status EgressStatus_EGRESS_LIMIT_REACHED #1451 +- Reset subscription start timer on permission grant #1457 +- Avoid panic when server receives a token without a video grant #1463 + +### Changed + +- Updated various logging #1413 #1433 #1437 #1440 #1470 +- Do not force TCP when client left before DTLS handshake #1414 +- Improved performance of data packet forwarding by broadcasting in parallel #1425 +- Cleaning up `availableLayers` and `exemptedLayers` #1407 +- Send stream start on initial start #1456 +- Switch to TLS if ICE/TCP isn't working well #1458 + +### Removed + +- Removed signal de-duper as it has not proven to be reliable #1427 +- Remove deprecated ingress rpc #1439 (breaking change for Ingress, this will require Ingress v0.0.2+) + +## [1.3.4] - 2023-02-09 + +### Added + +- Memory used and total to node stats #1293 #1296 +- Reconnect response to update ICE servers after resume #1300 #1367 +- Additional prometheus stats #1291 +- Adopt psrpc for internal communication protocol #1295 +- Enable track-level audio nack config #1306 +- Telemetry events for ParticipantResumed, track requested actions #1308 +- Allow disabling mDNS, which degrades performance on certain networks #1311 #1393 +- Publish stream stats to prometheus #1313 #1347 +- Retry initial connection attempt if it fails #1335 #1409 +- Add reconnect reason and signal rtt calculation #1381 +- silent frame for muted audio downtrack #1389 + +### Fixed + +- Fixed TimedVersion handling of non-monotonic timestamps #1304 +- Persist participant before firing webhook #1340 +- Set IsPublisher to true for data-only publishers #1348 +- Ignore inactive media in SDP #1365 +- Ensure older participant session update does not go out after a newer #1372 +- Fix potentially nil access in buffer #1374 +- Ensure onPacket is not nil in RTCPReader callback #1390 +- Fix rare panic by CreateSenderReport before bind completed #1397 + +### Changed + +- A/V synchronization improvements #1297 #1315 #1318 #1321 #1351 +- IOInfo service to handle ingress/egress updates #1305 +- Subscription manager to improve subscription resilience #1317 #1358 #1369 #1379 #1382 +- Enable video at low res by default when adaptive stream is enabled #1341 +- Enable upstream nack for opus only audio track #1343 +- Allow /rtc/validate to return room not found message #1344 +- Improve connectivity check to detect DTLS failure #1366 +- Simplify forwarding logic #1349 #1376 #1398 +- Send stream state paused only when it is paused due to bandwidth limitation. #1391 +- Do not catch panics, exit instead to prevent lockup #1392 + +## [1.3.3] - 2023-01-06 + +### Added + +- Signal deduper: ignore duplicate signal messages #1243 #1247 #1257 +- FPS based stream tracker #1267 #1269 #1275 #1281 +- Support forwarding track encryption status #1265 +- Use publisher side sender report when forwarding - improves A/V sync #1286 + +### Fixed + +- When removing a participant, verify SID matches #1237 +- Fixed rare panic when GetSelectedICECandidatePair returns nil #1253 +- Prevent ParticipantUpdate to be sent before JoinResponse #1271 #1272 +- Fixed Firefox connectivity issues when using UDPMux #1270 #1277 +- Fixed subscribing muted track with Egress and Go SDK #1283 + +### Changed + +- ParticipantLeft webhook would not be sent unless connected successfully #1130 +- Updated to Go 1.18+ #1259 +- Updated Egress RPC framework - psrpc #1252 #1256 #1266 #1273 +- Track subscription operations per source track #1248 +- Egress participants do not count in max_participants #1279 + +## [1.3.2] - 2022-12-15 + +### Added + +- help-verbose subcommand to print out all flags #1171 #1180 +- Support for Redis cluster #1181 +- Allow loopback candidates to be used via config option #1185 +- Support for high bitrate audio #1188 +- Ability to detect publication errors and force reconnect #1214 +- API secrets are validated upon startup to ensure sufficient security #1217 + +### Fixed + +- Correctly suppress verbose pion logs #1163 +- Fixed memory leak on long running room/participants #1169 +- Force full reconnect when there is no previous answer #1168 +- Fixed potential SSRC collision between participants #1173 +- Prevent RTX buffer and forwarding path colliding #1174 +- Do not set forceRelay when unset #1184 +- Prevent subscription after participant close #1182 +- Fixed lost RTCP packets when incorrect buffer factory was used #1195 +- Fixed handling of high bitrate while adding Opus RED #1196 +- Fixes a rare timing issue leading to connection failure #1208 +- Fixed incorrect handling of | in participant identity #1220 #1223 +- Fixed regression causing Firefox to not connect over TURN #1226 + +### Changed + +- CreateRoom API to allocate the room on RTC node #1155 #1157 +- Check forwarder started when seeding #1191 +- Do not forward media until peer connection is connected #1194 +- Log sampler to reduce log spam #1222 + +## [1.3.1] - 2022-11-09 + +### Fixed + +- Fixed logging config causes server to fail to start #1154 + +## [1.3.0] - 2022-11-08 + +### Added + +- Ingress Service support #1125 +- Support for web egress #1126 +- Ability to set all configuration params via command line flags #1112 +- Server-side RED encoding for supported clients #1137 +- Opus RED active loss recovery #1139 +- Experimental: fallback to TCP when UDP is unstable #1119 +- Populate memory load in node stats #1121 + +### Fixed + +- Fixed dynacast pausing a layer due to clients (FF) not publishing layer 0 #1117 +- Room.activeRecording updated correctly after users rejoin #1132 +- Don't collect external candidate IP when it's filtered out #1135 +- Install script to use uname without assuming /usr/bin #1138 + +### Changed + +- Allocate packetMeta up front to reduce number of allocations #1108 +- Do not log duplicate packet error. #1116 +- Consolidate getMemoryStats #1122 +- Seed snapshots to avoid saving/restoring in downtrack #1128 +- Remove Dependency Descriptor extension when AV1 is not preferred #1129 +- Always send participant updates prior to negotiation #1147 +- Set track level codec settings for all pending tracks #1148 +- Use Redis universal client to support clustered redis #1149 + +## [1.2.5] - 2022-10-19 + +### Added + +- Ability to filter IP addresses from being used #1052 +- Allow TCP fallback on multiple connection failures #1077 +- Added support for track level stereo and RED setting #1086 + +### Fixed + +- Fixed stream allocator with SVC codecs #1053 +- Fixed UDPMux connectivity issues when machine has multiple interfaces #1081 +- Ensure sender reports are in sync after transceiver is re-used #1080 +- Fixed simulcast codec blocking track closure #1082 +- Prevents multiple transport fallback in the same session #1090 + +### Changed + +- Config validation has been enabled. Server will not start if there are invalid config values #1051 +- Improves NACK stats to count as a miss only if i t's not EOF #1061 +- Store track MIME type during publishing #1065 +- Minor cleanup of media track & friends module #1067 +- Split out shared media transport code into livekit/mediatransportutil #1071 +- Cleaned up logging, improved consistency of debug vs info #1073 +- Reduced memory usage with sequencer #1100 +- Improved IP address mapping, handling of multiple IPs #1094 +- Service API requests are logged #1091 +- Default HTTP handler responds with 404 for unknown paths #1088 + +## [1.2.3] - 2022-09-13 + +### Added + +- Supervisor framework to improve edge case & error handling #1005 #1006 #1010 #1017 +- Support for stereo Opus tracks #1013 +- Allow CORS responses to be cached to allow faster initial connection #1027 + +### Fixed + +- Fixed SSRC mix-up for simulcasted tracks during session resume #1014 +- Fixed screen corruption for non-simulcasted tracks, caused by probing packets #1020 +- Fixed Handling of Simple NALU keyframes for H.264 #1016 +- Fixed TCPMux & UDPMux mixup when multiple host candidates are offered #1036 + +### Changed + +- Webhook requests are now using Content-Type application/webhook+json to avoid eager JSON parsing #1025 +- Don't automatically add STUN servers when explicit Node IP has been set #1023 +- Automatic TCP and TURN/TLS fallback is now enabled by default #1033 + +### Removed + +- Fully removed references to VP9. LiveKit is focused on AV1. #1004 + +## [1.2.1] - 2022-09-13 + +### Added + +- Accepts existing participant ID on reconnection attempts #988 + +### Fixed + +- Fixed ICE restart during candidate gathering #963 +- Ensure TrackInfoAvailable is fired after information is known to be ready #967 +- Fixed layer handling when publisher pauses layer 0 (FireFox is has a tendency to pause lowest layer) #984 +- Fixed inaccurate participant count due to storing stale data #992 + +### Changed + +- Protect against looking up dimensions for invalid spatial layer #977 +- Improvements around migration handling #979 #981 #982 #995 +- Consistent mapping between VideoQuality, rid, and video layers #986 +- Only enable TCP/TURN fallback for supported clients #997 + +## [1.2.0] - 2022-08-25 + +### Added + +- Support for NACK with audio tracks (#829) +- Allow binding HTTP server to specific address, binds to localhost in dev mode(#831) +- Packet stats from TC (#832) +- Automatic connectivity fallback to TCP & TURN (#872 #873 #874 #901 #950) +- Support for client-side ping/pong messages (#871) +- Support for setCodecPreferences for clients that don't implement it (#916) +- Opus/RED support: redundant audio transmission is enabled by default (#938 #940) + +### Fixed + +- Fixed timing issue in DownTrack.Bind/Close (#833) +- Fixed TCPMux potentially blocking operations (#840) +- Fixed ICE restart while still in ICE gathering (#895) +- Fixed Websocket connection hanging if node isn't available to accept connection (#923) +- Fixed ICE restart/resume in single node mode (#930) +- Fixed client disconnected in certain conditions after ICE restart (#932) + +### Changed + +- Move to synchronously handle subscriber dynacast status (#834) +- Retransmit DD extension in case packets were missed (#837) +- Clean up stats workers (#836) +- Use TimedVersion for subscription permission updates (#839) +- Cleaned up logging (#843 #865 #910 #921) +- track_published event now includes the participant's ID and identity (#846) +- Improve synchronization of track publishing/unpublish path (#857) +- Don't re-use transceiver when pending negotiation (#862) +- Dynacast and media loss proxy refactor (#894 #902) +- PCTransport refactor (#907 #944) +- Improve accuracy of connection quality score (#912 #913) +- Docker image now builds with Go v1.19 + +## [1.1.2] - 2022-07-11 + +### Added + +- Returns reason when server disconnects a client (#801 #806) +- Allow livekit-server to start without keys configuration (#788) +- Added recovery from negotiation failures (#807) + +### Fixed + +- Fixed synchronization issues with Dynacast (#779 #802) +- Fixed panic due to timing in Pion's ICE agent (#780) +- ICELite is disabled by default, improving connectivity behind NAT (#784) +- Fixed EgressService UpdateLayout (#782) +- Fixed synchronization bugs with selective subscriptions & permissions (#796 #797 #805 #813 #814 #816) +- Correctly recover from ICE Restart during an negotiation attempt (#798) + +### Changed + +- Improved Transceiver re-use to avoid renegotiation (#785) +- Close room if recorder is the only participant left (#787) +- Improved connection quality score stability & computation (#793 #795) +- Set layer state to stopped when paused (#818) + +### Removed + +- Removed deprecated RecordingService - Egress should be used instead (#811) + +## [1.1.0] - 2022-06-21 + +### Added + +- Add support for Redis Sentinel (#707) +- Track participant join total + rate in node stats (#741) +- Protocol 8 - fast connection support (#747) +- Simulate switch candidate for network connectivity with poor UDP performance (#754) +- Allow server to disable codec for certain devices (#755) +- Support for on-demand multi-codec publishing (#762) + +### Fixed + +- Fixed unclean DownTrack close when removed before bound. (#736) +- Do not munge VP8 header in place - fixes video corruption (#763) + +### Changed + +- Reintroduce audio-level quantization to dampen small changes (#732) +- Allow overshooting maximum when there are no bandwidth constraints. (#739) +- Improvements to upcoming multi-codec simulcast (#740) +- Send layer dimensions when max subscribed layers change (#746) +- Use stable TrackID after unpublishing & republishing (#751) +- Update egress RPC handler (#759) +- Improved connection quality metrics (#766 #767 #770 #771 #773 #774 #775) + +## [1.0.2] - 2022-05-27 + +### Changed + +- Fixed edge cases where streams were not allocated (#701) +- Fixed panic caused by concurrent modifications to stats worker map (#702 #704) +- Batched subscriber updates to reduce noise in large rooms (#703 #729) +- Fixed potential data race conditions (#706 #709 #711 #713 #715 #716 #717 #724 #727 +- /debug/pprof endpoint when running in development mode (#708) +- When audio tracks are muted, send blank frames to induce silence (#710) +- Fixed stream allocator not upgrading streams after downgrading (#719) +- Fixed repeated AddSubscriber potentially ignored (#723) +- Fixed ListEgress API sometimes returning not found (#722) + +## [1.0.1] - 2022-05-19 + +### Changed + +- Update Egress details when changed, fixed Egress APIs (#694) + +## [1.0.0] - 2022-05-17 + +### Added + +- Improved stats around NACKs (#664) +- Internal structures in preparation for AV1 SVC support (#669) + +### Changed + +- Supports participant identity in permissions API (#633) +- Fixed concurrent access of stats worker map (#666 #670) +- Do not count padding packets in stream tracker (#667) +- Fixed TWCC panic under heavy packet loss (#668) +- Change state to JOINED before sending JoinResponse (#674) +- Improved frequency of stats update (#673) +- Send active speaker update during initial subscription (#676) +- Updated DTLS library to incorporate security fixes (#678) +- Improved list-nodes command (#681) +- Improved screen-share handling in StreamTracker (#683) +- Inject slience opus packets when muted (#682) diff --git a/livekit/Dockerfile b/livekit/Dockerfile new file mode 100644 index 0000000..d2ad082 --- /dev/null +++ b/livekit/Dockerfile @@ -0,0 +1,44 @@ +# Copyright 2023 LiveKit, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +FROM golang:1.25-alpine AS builder + +ARG TARGETPLATFORM +ARG TARGETARCH +RUN echo building for "$TARGETPLATFORM" + +WORKDIR /workspace + +# Copy the Go Modules manifests +COPY go.mod go.mod +COPY go.sum go.sum +# cache deps before building and copying source so that we don't need to re-download as much +# and so that source changes don't invalidate our downloaded layer +RUN go mod download + +# Copy the go source +COPY cmd/ cmd/ +COPY pkg/ pkg/ +COPY test/ test/ +COPY tools/ tools/ +COPY version/ version/ + +RUN CGO_ENABLED=0 GOOS=linux GOARCH=$TARGETARCH GO111MODULE=on go build -a -o livekit-server ./cmd/server + +FROM alpine + +COPY --from=builder /workspace/livekit-server /livekit-server + +# Run the binary. +ENTRYPOINT ["/livekit-server"] diff --git a/livekit/LICENSE b/livekit/LICENSE new file mode 100644 index 0000000..d645695 --- /dev/null +++ b/livekit/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/livekit/NOTICE b/livekit/NOTICE new file mode 100644 index 0000000..692adc9 --- /dev/null +++ b/livekit/NOTICE @@ -0,0 +1,13 @@ +Copyright 2023 LiveKit, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/livekit/README.md b/livekit/README.md new file mode 100644 index 0000000..b41b634 --- /dev/null +++ b/livekit/README.md @@ -0,0 +1,319 @@ + + + + + + The LiveKit icon, the name of the repository and some sample code in the background. + + + + +# LiveKit: Real-time video, audio and data for developers + +[LiveKit](https://livekit.io) is an open source project that provides scalable, multi-user conferencing based on WebRTC. +It's designed to provide everything you need to build real-time video audio data capabilities in your applications. + +LiveKit's server is written in Go, using the awesome [Pion WebRTC](https://github.com/pion/webrtc) implementation. + +[![GitHub stars](https://img.shields.io/github/stars/livekit/livekit?style=social&label=Star&maxAge=2592000)](https://github.com/livekit/livekit/stargazers/) +[![Slack community](https://img.shields.io/endpoint?url=https%3A%2F%2Flivekit.io%2Fbadges%2Fslack)](https://livekit.io/join-slack) +[![Twitter Follow](https://img.shields.io/twitter/follow/livekit)](https://twitter.com/livekit) +[![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/livekit/livekit) +[![GitHub release (latest SemVer)](https://img.shields.io/github/v/release/livekit/livekit)](https://github.com/livekit/livekit/releases/latest) +[![GitHub Workflow Status](https://img.shields.io/github/actions/workflow/status/livekit/livekit/buildtest.yaml?branch=master)](https://github.com/livekit/livekit/actions/workflows/buildtest.yaml) +[![License](https://img.shields.io/github/license/livekit/livekit)](https://github.com/livekit/livekit/blob/master/LICENSE) + +## Features + +- Scalable, distributed WebRTC SFU (Selective Forwarding Unit) +- Modern, full-featured client SDKs +- Built for production, supports JWT authentication +- Robust networking and connectivity, UDP/TCP/TURN +- Easy to deploy: single binary, Docker or Kubernetes +- Advanced features including: + - [speaker detection](https://docs.livekit.io/home/client/tracks/subscribe/#speaker-detection) + - [simulcast](https://docs.livekit.io/home/client/tracks/publish/#video-simulcast) + - [end-to-end optimizations](https://blog.livekit.io/livekit-one-dot-zero/) + - [selective subscription](https://docs.livekit.io/home/client/tracks/subscribe/#selective-subscription) + - [moderation APIs](https://docs.livekit.io/home/server/managing-participants/) + - end-to-end encryption + - SVC codecs (VP9, AV1) + - [webhooks](https://docs.livekit.io/home/server/webhooks/) + - [distributed and multi-region](https://docs.livekit.io/home/self-hosting/distributed/) + +## Documentation & Guides + +https://docs.livekit.io + +## Live Demos + +- [LiveKit Meet](https://meet.livekit.io) ([source](https://github.com/livekit-examples/meet)) +- [Spatial Audio](https://spatial-audio-demo.livekit.io/) ([source](https://github.com/livekit-examples/spatial-audio)) +- Livestreaming from OBS Studio ([source](https://github.com/livekit-examples/livestream)) +- [AI voice assistant using ChatGPT](https://livekit.io/kitt) ([source](https://github.com/livekit-examples/kitt)) + +## Ecosystem + +- [Agents](https://github.com/livekit/agents): build real-time multimodal AI applications with programmable backend participants +- [Egress](https://github.com/livekit/egress): record or multi-stream rooms and export individual tracks +- [Ingress](https://github.com/livekit/ingress): ingest streams from external sources like RTMP, WHIP, HLS, or OBS Studio + +## SDKs & Tools + +### Client SDKs + +Client SDKs enable your frontend to include interactive, multi-user experiences. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
LanguageRepo + Declarative UI + Links
JavaScript (TypeScript) + client-sdk-js + + React + + docs + | + JS example + | + React example +
Swift (iOS / MacOS) + client-sdk-swift + Swift UI + docs + | + example +
Kotlin (Android) + client-sdk-android + Compose + docs + | + example + | + Compose example +
Flutter (all platforms) + client-sdk-flutter + native + docs + | + example +
Unity WebGL + client-sdk-unity-web + + docs +
React Native (beta) + client-sdk-react-native + native
Rust + client-sdk-rust +
+ +### Server SDKs + +Server SDKs enable your backend to generate [access tokens](https://docs.livekit.io/home/get-started/authentication/), +call [server APIs](https://docs.livekit.io/reference/server/server-apis/), and +receive [webhooks](https://docs.livekit.io/home/server/webhooks/). In addition, the Go SDK includes client capabilities, +enabling you to build automations that behave like end-users. + +| Language | Repo | Docs | +| :---------------------- | :-------------------------------------------------------------------------------------- | :---------------------------------------------------------- | +| Go | [server-sdk-go](https://github.com/livekit/server-sdk-go) | [docs](https://pkg.go.dev/github.com/livekit/server-sdk-go) | +| JavaScript (TypeScript) | [server-sdk-js](https://github.com/livekit/server-sdk-js) | [docs](https://docs.livekit.io/server-sdk-js/) | +| Ruby | [server-sdk-ruby](https://github.com/livekit/server-sdk-ruby) | | +| Java (Kotlin) | [server-sdk-kotlin](https://github.com/livekit/server-sdk-kotlin) | | +| Python (community) | [python-sdks](https://github.com/livekit/python-sdks) | | +| PHP (community) | [agence104/livekit-server-sdk-php](https://github.com/agence104/livekit-server-sdk-php) | | + +### Tools + +- [CLI](https://github.com/livekit/livekit-cli) - command line interface & load tester +- [Docker image](https://hub.docker.com/r/livekit/livekit-server) +- [Helm charts](https://github.com/livekit/livekit-helm) + +## Install + +> [!TIP] +> We recommend installing [LiveKit CLI](https://github.com/livekit/livekit-cli) along with the server. It lets you access +> server APIs, create tokens, and generate test traffic. + +The following will install LiveKit's media server: + +### MacOS + +```shell +brew install livekit +``` + +### Linux + +```shell +curl -sSL https://get.livekit.io | bash +``` + +### Windows + +Download the [latest release here](https://github.com/livekit/livekit/releases/latest) + +## Getting Started + +### Starting LiveKit + +Start LiveKit in development mode by running `livekit-server --dev`. It'll use a placeholder API key/secret pair. + +``` +API Key: devkey +API Secret: secret +``` + +To customize your setup for production, refer to our [deployment docs](https://docs.livekit.io/deploy/) + +### Creating access token + +A user connecting to a LiveKit room requires an [access token](https://docs.livekit.io/home/get-started/authentication/#creating-a-token). Access +tokens (JWT) encode the user's identity and the room permissions they've been granted. You can generate a token with our +CLI: + +```shell +lk token create \ + --api-key devkey --api-secret secret \ + --join --room my-first-room --identity user1 \ + --valid-for 24h +``` + +### Test with example app + +Head over to our [example app](https://example.livekit.io) and enter a generated token to connect to your LiveKit +server. This app is built with our [React SDK](https://github.com/livekit/livekit-react). + +Once connected, your video and audio are now being published to your new LiveKit instance! + +### Simulating a test publisher + +```shell +lk room join \ + --url ws://localhost:7880 \ + --api-key devkey --api-secret secret \ + --identity bot-user1 \ + --publish-demo \ + my-first-room +``` + +This command publishes a looped demo video to a room. Due to how the video clip was encoded (keyframes every 3s), +there's a slight delay before the browser has sufficient data to begin rendering frames. This is an artifact of the +simulation. + +## Deployment + +### Use LiveKit Cloud + +LiveKit Cloud is the fastest and most reliable way to run LiveKit. Every project gets free monthly bandwidth and +transcoding credits. + +Sign up for [LiveKit Cloud](https://cloud.livekit.io/). + +### Self-host + +Read our [deployment docs](https://docs.livekit.io/deploy/) for more information. + +## Building from source + +Pre-requisites: + +- Go 1.23+ is installed +- GOPATH/bin is in your PATH + +Then run + +```shell +git clone https://github.com/livekit/livekit +cd livekit +./bootstrap.sh +mage +``` + +## Contributing + +We welcome your contributions toward improving LiveKit! Please join us +[on Slack](http://livekit.io/join-slack) to discuss your ideas and/or PRs. + +## License + +LiveKit server is licensed under Apache License v2.0. + + +
+ + + + + + + + + +
LiveKit Ecosystem
LiveKit SDKsBrowser · iOS/macOS/visionOS · Android · Flutter · React Native · Rust · Node.js · Python · Unity · Unity (WebGL) · ESP32
Server APIsNode.js · Golang · Ruby · Java/Kotlin · Python · Rust · PHP (community) · .NET (community)
UI ComponentsReact · Android Compose · SwiftUI · Flutter
Agents FrameworksPython · Node.js · Playground
ServicesLiveKit server · Egress · Ingress · SIP
ResourcesDocs · Example apps · Cloud · Self-hosting · CLI
+ diff --git a/livekit/bootstrap.sh b/livekit/bootstrap.sh new file mode 100755 index 0000000..7511dba --- /dev/null +++ b/livekit/bootstrap.sh @@ -0,0 +1,33 @@ +#!/usr/bin/env bash +# Copyright 2023 LiveKit, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +if ! command -v mage &> /dev/null +then + pushd /tmp + git clone https://github.com/magefile/mage + cd mage + go run bootstrap.go + rm -rf /tmp/mage + popd +fi + +if ! command -v mage &> /dev/null +then + echo "Ensure `go env GOPATH`/bin is in your \$PATH" + exit 1 +fi + +go mod download diff --git a/livekit/cmd/server/commands.go b/livekit/cmd/server/commands.go new file mode 100644 index 0000000..dfbd0da --- /dev/null +++ b/livekit/cmd/server/commands.go @@ -0,0 +1,258 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "context" + "fmt" + "os" + "strconv" + "strings" + "time" + + "github.com/dustin/go-humanize" + "github.com/olekukonko/tablewriter" + "github.com/urfave/cli/v3" + "gopkg.in/yaml.v3" + + "github.com/livekit/protocol/auth" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/utils" + "github.com/livekit/protocol/utils/guid" + + "github.com/livekit/livekit-server/pkg/config" + "github.com/livekit/livekit-server/pkg/routing" + "github.com/livekit/livekit-server/pkg/service" +) + +func generateKeys(_ context.Context, _ *cli.Command) error { + apiKey := guid.New(utils.APIKeyPrefix) + secret := utils.RandomSecret() + fmt.Println("API Key: ", apiKey) + fmt.Println("API Secret: ", secret) + return nil +} + +func printPorts(_ context.Context, c *cli.Command) error { + conf, err := getConfig(c) + if err != nil { + return err + } + + udpPorts := make([]string, 0) + tcpPorts := make([]string, 0) + + tcpPorts = append(tcpPorts, fmt.Sprintf("%d - HTTP service", conf.Port)) + if conf.RTC.TCPPort != 0 { + tcpPorts = append(tcpPorts, fmt.Sprintf("%d - ICE/TCP", conf.RTC.TCPPort)) + } + if conf.RTC.UDPPort.Valid() { + portStr, _ := conf.RTC.UDPPort.MarshalYAML() + udpPorts = append(udpPorts, fmt.Sprintf("%s - ICE/UDP", portStr)) + } else { + udpPorts = append(udpPorts, fmt.Sprintf("%d-%d - ICE/UDP range", conf.RTC.ICEPortRangeStart, conf.RTC.ICEPortRangeEnd)) + } + + if conf.TURN.Enabled { + if conf.TURN.TLSPort > 0 { + tcpPorts = append(tcpPorts, fmt.Sprintf("%d - TURN/TLS", conf.TURN.TLSPort)) + } + if conf.TURN.UDPPort > 0 { + udpPorts = append(udpPorts, fmt.Sprintf("%d - TURN/UDP", conf.TURN.UDPPort)) + } + } + + fmt.Println("TCP Ports") + for _, p := range tcpPorts { + fmt.Println(p) + } + + fmt.Println("UDP Ports") + for _, p := range udpPorts { + fmt.Println(p) + } + return nil +} + +func helpVerbose(_ context.Context, c *cli.Command) error { + generatedFlags, err := config.GenerateCLIFlags(baseFlags, false) + if err != nil { + return err + } + + c.Flags = append(baseFlags, generatedFlags...) + return cli.ShowAppHelp(c) +} + +func createToken(_ context.Context, c *cli.Command) error { + room := c.String("room") + identity := c.String("identity") + + conf, err := getConfig(c) + if err != nil { + return err + } + + // use the first API key from config + if len(conf.Keys) == 0 { + // try to load from file + if _, err := os.Stat(conf.KeyFile); err != nil { + return err + } + f, err := os.Open(conf.KeyFile) + if err != nil { + return err + } + defer func() { + _ = f.Close() + }() + decoder := yaml.NewDecoder(f) + if err = decoder.Decode(conf.Keys); err != nil { + return err + } + + if len(conf.Keys) == 0 { + return fmt.Errorf("keys are not configured") + } + } + + var apiKey string + var apiSecret string + for k, v := range conf.Keys { + apiKey = k + apiSecret = v + break + } + + grant := &auth.VideoGrant{ + RoomJoin: true, + Room: room, + } + if c.Bool("recorder") { + grant.Hidden = true + grant.Recorder = true + grant.SetCanPublish(false) + grant.SetCanPublishData(false) + } + + at := auth.NewAccessToken(apiKey, apiSecret). + AddGrant(grant). + SetIdentity(identity). + SetValidFor(30 * 24 * time.Hour) + + token, err := at.ToJWT() + if err != nil { + return err + } + + fmt.Println("Token:", token) + + return nil +} + +func listNodes(_ context.Context, c *cli.Command) error { + conf, err := getConfig(c) + if err != nil { + return err + } + + currentNode, err := routing.NewLocalNode(conf) + if err != nil { + return err + } + + router, err := service.InitializeRouter(conf, currentNode) + if err != nil { + return err + } + + nodes, err := router.ListNodes() + if err != nil { + return err + } + + table := tablewriter.NewWriter(os.Stdout) + table.SetRowLine(true) + table.SetAutoWrapText(false) + table.SetHeader([]string{ + "ID", "IP Address", "Region", + "CPUs", "CPU Usage\nLoad Avg", + "Memory Used/Total", + "Rooms", "Clients\nTracks In/Out", + "Bytes/s In/Out\nBytes Total", "Packets/s In/Out\nPackets Total", "System Dropped Pkts/s\nPkts/s Out/Dropped", + "Nack/s\nNack Total", "Retrans/s\nRetrans Total", + "Started At\nUpdated At", + }) + table.SetColumnAlignment([]int{ + tablewriter.ALIGN_CENTER, tablewriter.ALIGN_CENTER, tablewriter.ALIGN_CENTER, + tablewriter.ALIGN_RIGHT, tablewriter.ALIGN_RIGHT, + tablewriter.ALIGN_RIGHT, + tablewriter.ALIGN_RIGHT, tablewriter.ALIGN_RIGHT, + tablewriter.ALIGN_RIGHT, tablewriter.ALIGN_RIGHT, tablewriter.ALIGN_RIGHT, + tablewriter.ALIGN_RIGHT, tablewriter.ALIGN_RIGHT, + tablewriter.ALIGN_CENTER, + }) + + for _, node := range nodes { + stats := node.Stats + rate := &livekit.NodeStatsRate{} + if len(stats.Rates) > 0 { + rate = stats.Rates[0] + } + + // Id and state + idAndState := fmt.Sprintf("%s\n(%s)", node.Id, node.State.Enum().String()) + + // System stats + cpus := strconv.Itoa(int(stats.NumCpus)) + cpuUsageAndLoadAvg := fmt.Sprintf("%.2f %%\n%.2f %.2f %.2f", stats.CpuLoad*100, + stats.LoadAvgLast1Min, stats.LoadAvgLast5Min, stats.LoadAvgLast15Min) + memUsage := fmt.Sprintf("%s / %s", humanize.Bytes(stats.MemoryUsed), humanize.Bytes(stats.MemoryTotal)) + + // Room stats + rooms := strconv.Itoa(int(stats.NumRooms)) + clientsAndTracks := fmt.Sprintf("%d\n%d / %d", stats.NumClients, stats.NumTracksIn, stats.NumTracksOut) + + // Packet stats + bytes := fmt.Sprintf("%sps / %sps\n%s / %s", humanize.Bytes(uint64(rate.BytesIn)), humanize.Bytes(uint64(rate.BytesOut)), + humanize.Bytes(stats.BytesIn), humanize.Bytes(stats.BytesOut)) + packets := fmt.Sprintf("%s / %s\n%s / %s", humanize.Comma(int64(rate.PacketsIn)), humanize.Comma(int64(rate.PacketsOut)), + strings.TrimSpace(humanize.SIWithDigits(float64(stats.PacketsIn), 2, "")), strings.TrimSpace(humanize.SIWithDigits(float64(stats.PacketsOut), 2, ""))) + sysPacketsDroppedPct := float32(0) + if rate.SysPacketsOut+rate.SysPacketsDropped > 0 { + sysPacketsDroppedPct = float32(rate.SysPacketsDropped) / float32(rate.SysPacketsDropped+rate.SysPacketsOut) + } + sysPackets := fmt.Sprintf("%.2f %%\n%v / %v", sysPacketsDroppedPct*100, float64(rate.SysPacketsOut), float64(rate.SysPacketsDropped)) + nacks := fmt.Sprintf("%.2f\n%s", rate.NackTotal, strings.TrimSpace(humanize.SIWithDigits(float64(stats.NackTotal), 2, ""))) + retransmit := fmt.Sprintf("%.2f\n%s", rate.RetransmitPacketsOut, strings.TrimSpace(humanize.SIWithDigits(float64(stats.RetransmitPacketsOut), 2, ""))) + + // Date + startedAndUpdated := fmt.Sprintf("%s\n%s", time.Unix(stats.StartedAt, 0).UTC().UTC().Format("2006-01-02 15:04:05"), + time.Unix(stats.UpdatedAt, 0).UTC().Format("2006-01-02 15:04:05")) + + table.Append([]string{ + idAndState, node.Ip, node.Region, + cpus, cpuUsageAndLoadAvg, + memUsage, + rooms, clientsAndTracks, + bytes, packets, sysPackets, + nacks, retransmit, + startedAndUpdated, + }) + } + table.Render() + + return nil +} diff --git a/livekit/cmd/server/main.go b/livekit/cmd/server/main.go new file mode 100644 index 0000000..b5e9741 --- /dev/null +++ b/livekit/cmd/server/main.go @@ -0,0 +1,333 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "context" + "fmt" + "math/rand" + "net" + "os" + "os/signal" + "runtime" + "runtime/pprof" + "syscall" + "time" + + "github.com/urfave/cli/v3" + + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/tracer/jaeger" + + "github.com/livekit/livekit-server/pkg/config" + "github.com/livekit/livekit-server/pkg/routing" + "github.com/livekit/livekit-server/pkg/rtc" + "github.com/livekit/livekit-server/pkg/service" + "github.com/livekit/livekit-server/pkg/telemetry/prometheus" + "github.com/livekit/livekit-server/version" +) + +var baseFlags = []cli.Flag{ + &cli.StringSliceFlag{ + Name: "bind", + Usage: "IP address to listen on, use flag multiple times to specify multiple addresses", + }, + &cli.StringFlag{ + Name: "config", + Usage: "path to LiveKit config file", + }, + &cli.StringFlag{ + Name: "config-body", + Usage: "LiveKit config in YAML, typically passed in as an environment var in a container", + Sources: cli.EnvVars("LIVEKIT_CONFIG"), + }, + &cli.StringFlag{ + Name: "key-file", + Usage: "path to file that contains API keys/secrets", + }, + &cli.StringFlag{ + Name: "keys", + Usage: "api keys (key: secret\\n)", + Sources: cli.EnvVars("LIVEKIT_KEYS"), + }, + &cli.StringFlag{ + Name: "region", + Usage: "region of the current node. Used by regionaware node selector", + Sources: cli.EnvVars("LIVEKIT_REGION"), + }, + &cli.StringFlag{ + Name: "node-ip", + Usage: "IP address of the current node, used to advertise to clients. Automatically determined by default", + Sources: cli.EnvVars("NODE_IP"), + }, + &cli.StringFlag{ + Name: "udp-port", + Usage: "UDP port(s) to use for WebRTC traffic", + Sources: cli.EnvVars("UDP_PORT"), + }, + &cli.StringFlag{ + Name: "redis-host", + Usage: "host (incl. port) to redis server", + Sources: cli.EnvVars("REDIS_HOST"), + }, + &cli.StringFlag{ + Name: "redis-password", + Usage: "password to redis", + Sources: cli.EnvVars("REDIS_PASSWORD"), + }, + &cli.StringFlag{ + Name: "turn-cert", + Usage: "tls cert file for TURN server", + Sources: cli.EnvVars("LIVEKIT_TURN_CERT"), + }, + &cli.StringFlag{ + Name: "turn-key", + Usage: "tls key file for TURN server", + Sources: cli.EnvVars("LIVEKIT_TURN_KEY"), + }, + &cli.StringFlag{ + Name: "cpuprofile", + Usage: "write CPU profile to `file`", + }, + &cli.StringFlag{ + Name: "memprofile", + Usage: "write memory profile to `file`", + }, + &cli.BoolFlag{ + Name: "dev", + Usage: "sets log-level to debug, console formatter, and /debug/pprof. insecure for production", + }, + &cli.BoolFlag{ + Name: "disable-strict-config", + Usage: "disables strict config parsing", + Hidden: true, + }, +} + +func init() { + rand.Seed(time.Now().Unix()) +} + +func main() { + defer func() { + if rtc.Recover(logger.GetLogger()) != nil { + os.Exit(1) + } + }() + + generatedFlags, err := config.GenerateCLIFlags(baseFlags, true) + if err != nil { + fmt.Println(err) + } + + cmd := &cli.Command{ + Name: "livekit-server", + Usage: "High performance WebRTC server", + Description: "run without subcommands to start the server", + Flags: append(baseFlags, generatedFlags...), + Action: startServer, + Commands: []*cli.Command{ + { + Name: "generate-keys", + Usage: "generates an API key and secret pair", + Action: generateKeys, + }, + { + Name: "ports", + Usage: "print ports that server is configured to use", + Action: printPorts, + }, + { + // this subcommand is deprecated, token generation is provided by CLI + Name: "create-join-token", + Hidden: true, + Usage: "create a room join token for development use", + Action: createToken, + Flags: []cli.Flag{ + &cli.StringFlag{ + Name: "room", + Usage: "name of room to join", + Required: true, + }, + &cli.StringFlag{ + Name: "identity", + Usage: "identity of participant that holds the token", + Required: true, + }, + &cli.BoolFlag{ + Name: "recorder", + Usage: "creates a hidden participant that can only subscribe", + Required: false, + }, + }, + }, + { + Name: "list-nodes", + Usage: "list all nodes", + Action: listNodes, + }, + { + Name: "help-verbose", + Usage: "prints app help, including all generated configuration flags", + Action: helpVerbose, + }, + }, + Version: version.Version, + } + + if err := cmd.Run(context.Background(), os.Args); err != nil { + fmt.Println(err) + } +} + +func getConfig(c *cli.Command) (*config.Config, error) { + confString, err := getConfigString(c.String("config"), c.String("config-body")) + if err != nil { + return nil, err + } + + strictMode := true + if c.Bool("disable-strict-config") { + strictMode = false + } + + conf, err := config.NewConfig(confString, strictMode, c, baseFlags) + if err != nil { + return nil, err + } + config.InitLoggerFromConfig(&conf.Logging) + + if conf.Development { + logger.Infow("starting in development mode") + + if len(conf.Keys) == 0 { + logger.Infow("no keys provided, using placeholder keys", + "API Key", "devkey", + "API Secret", "secret", + ) + conf.Keys = map[string]string{ + "devkey": "secret", + } + shouldMatchRTCIP := false + // when dev mode and using shared keys, we'll bind to localhost by default + if conf.BindAddresses == nil { + conf.BindAddresses = []string{ + "127.0.0.1", + "::1", + } + } else { + // if non-loopback addresses are provided, then we'll match RTC IP to bind address + // our IP discovery ignores loopback addresses + for _, addr := range conf.BindAddresses { + ip := net.ParseIP(addr) + if ip != nil && !ip.IsLoopback() && !ip.IsUnspecified() { + shouldMatchRTCIP = true + } + } + } + if shouldMatchRTCIP { + for _, bindAddr := range conf.BindAddresses { + conf.RTC.IPs.Includes = append(conf.RTC.IPs.Includes, bindAddr+"/24") + } + } + } + } + return conf, nil +} + +func startServer(ctx context.Context, c *cli.Command) error { + conf, err := getConfig(c) + if err != nil { + return err + } + if url := conf.Trace.JaegerURL; url != "" { + jaeger.Configure(ctx, url, "livekit") + } + + // validate API key length + err = conf.ValidateKeys() + if err != nil { + return err + } + + if cpuProfile := c.String("cpuprofile"); cpuProfile != "" { + if f, err := os.Create(cpuProfile); err != nil { + return err + } else { + if err := pprof.StartCPUProfile(f); err != nil { + f.Close() + return err + } + defer func() { + pprof.StopCPUProfile() + f.Close() + }() + } + } + + if memProfile := c.String("memprofile"); memProfile != "" { + if f, err := os.Create(memProfile); err != nil { + return err + } else { + defer func() { + // run memory profile at termination + runtime.GC() + _ = pprof.WriteHeapProfile(f) + _ = f.Close() + }() + } + } + + currentNode, err := routing.NewLocalNode(conf) + if err != nil { + return err + } + + if err := prometheus.Init(string(currentNode.NodeID()), currentNode.NodeType()); err != nil { + return err + } + + server, err := service.InitializeServer(conf, currentNode) + if err != nil { + return err + } + + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT) + + go func() { + for i := range 2 { + sig := <-sigChan + force := i > 0 + logger.Infow("exit requested, shutting down", "signal", sig, "force", force) + go server.Stop(force) + } + }() + + return server.Start() +} + +func getConfigString(configFile string, inConfigBody string) (string, error) { + if inConfigBody != "" || configFile == "" { + return inConfigBody, nil + } + + outConfigBody, err := os.ReadFile(configFile) + if err != nil { + return "", err + } + + return string(outConfigBody), nil +} diff --git a/livekit/cmd/server/main_test.go b/livekit/cmd/server/main_test.go new file mode 100644 index 0000000..dc82bf0 --- /dev/null +++ b/livekit/cmd/server/main_test.go @@ -0,0 +1,63 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "os" + "testing" + + "github.com/stretchr/testify/require" +) + +type testStruct struct { + configFileName string + configBody string + + expectedError error + expectedConfigBody string +} + +func TestGetConfigString(t *testing.T) { + tests := []testStruct{ + {"", "", nil, ""}, + {"", "configBody", nil, "configBody"}, + {"file", "configBody", nil, "configBody"}, + {"file", "", nil, "fileContent"}, + } + for _, test := range tests { + func() { + writeConfigFile(test, t) + defer os.Remove(test.configFileName) + + configBody, err := getConfigString(test.configFileName, test.configBody) + require.Equal(t, test.expectedError, err) + require.Equal(t, test.expectedConfigBody, configBody) + }() + } +} + +func TestShouldReturnErrorIfConfigFileDoesNotExist(t *testing.T) { + configBody, err := getConfigString("notExistingFile", "") + require.Error(t, err) + require.Empty(t, configBody) +} + +func writeConfigFile(test testStruct, t *testing.T) { + if test.configFileName != "" { + d1 := []byte(test.expectedConfigBody) + err := os.WriteFile(test.configFileName, d1, 0o644) + require.NoError(t, err) + } +} diff --git a/livekit/config-sample.yaml b/livekit/config-sample.yaml new file mode 100644 index 0000000..189a8f2 --- /dev/null +++ b/livekit/config-sample.yaml @@ -0,0 +1,330 @@ +# Copyright 2024 LiveKit, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# main TCP port for RoomService and RTC endpoint +# for production setups, this port should be placed behind a load balancer with TLS +port: 7880 + +# when redis is set, LiveKit will automatically operate in a fully distributed fashion +# clients could connect to any node and be routed to the same room +redis: + address: redis.host:6379 + # db: 0 + # username: myuser + # password: mypassword + # To use sentinel remove the address key above and add the following + # sentinel_master_name: livekit + # sentinel_addresses: + # - livekit-redis-node-0.livekit-redis-headless:26379 + # - livekit-redis-node-1.livekit-redis-headless:26379 + # If you use a different set of credentials for sentinel add + # sentinel_username: user + # sentinel_password: pass + # + # To use TLS with redis + # tls: + # enabled: true + # # when set to true, LiveKit will not verify the server's certificate, defaults to true + # insecure: false + # server_name: myserver.com + # # file containing trusted root certificates for verification + # ca_cert_file: /path/to/ca.crt + # client_cert_file: /path/to/client.crt + # client_key_file: /path/to/client.key + # + # To use cluster remove the address key above and add the following + # cluster_addresses: + # - livekit-redis-node-0.livekit-redis-headless:6379 + # - livekit-redis-node-1.livekit-redis-headless:6380 + # And it will use the password key above as cluster password + # And the db key will not be used due to cluster mode not support it. + +# WebRTC configuration +rtc: + # UDP ports to use for client traffic. + # this port range should be open for inbound traffic on the firewall + port_range_start: 50000 + port_range_end: 60000 + # when set, LiveKit enable WebRTC ICE over TCP when UDP isn't available + # this port *cannot* be behind load balancer or TLS, and must be exposed on the node + # WebRTC transports are encrypted and do not require additional encryption + # only 80/443 on public IP are allowed if less than 1024 + tcp_port: 7881 + # when set to true, attempts to discover the host's public IP via STUN + # this is useful for cloud environments such as AWS & Google where hosts have an internal IP + # that maps to an external one + use_external_ip: true + # # there are cases where the public IP determined via STUN is not the correct one + # # in such cases, use this setting to set the public IP of the node + # # use_external_ip takes precedence, for this to take effect, set use_external_ip to false + # node_ip: + # # when set, LiveKit will attempt to use a UDP mux so all UDP traffic goes through + # # listed port(s). To maximize system performance, we recommend using a range of ports + # # greater or equal to the number of vCPUs on the machine. + # # port_range_start & end must not be set for this config to take effect + # udp_port: 7882-7892 + # # when set to true, server will use a lite ice agent, that will speed up ice connection, but + # # might cause connect issue if server running behind NAT. + # use_ice_lite: true + # # optional STUN servers for LiveKit clients to use. Clients will be configured to use these STUN servers automatically. + # # by default LiveKit clients use Google's public STUN servers + # stun_servers: + # - server1 + # # optional TURN servers for clients. This isn't necessary if using embedded TURN server (see below). + # turn_servers: + # - host: myhost.com + # port: 443 + # # tls, tcp, or udp + # protocol: tls + # # Shared secret for TURN server authentication + # secret: "" + # ttl: 14400 # seconds + # # Insecure username/password authentication + # username: "" + # credential: "" + # # allows LiveKit to monitor congestion when sending streams and automatically + # # manage bandwidth utilization to avoid congestion/loss. Enabled by default + # congestion_control: + # enabled: true + # # in the unlikely event of highly congested networks, SFU may choose to pause some tracks + # # in order to allow others to stream smoothly. You can disable this behavior here + # allow_pause: true + # # allows automatic connection fallback to TCP and TURN/TLS (if configured) when UDP has been unstable, default true + # allow_tcp_fallback: true + # # number of packets to buffer in the SFU for video, defaults to 500 + # packet_buffer_size_video: 500 + # # number of packets to buffer in the SFU for audio, defaults to 200 + # packet_buffer_size_audio: 200 + # # minimum amount of time between pli/fir rtcp packets being sent to an individual + # # producer. Increasing these times can lead to longer black screens when new participants join, + # # while reducing them can lead to higher stream bitrate. + # pli_throttle: + # low_quality: 500ms + # mid_quality: 1s + # high_quality: 1s + # # when set, Livekit will collect loopback candidates, it is useful for some VM have public address mapped to its loopback interface. + # enable_loopback_candidate: true + # # network interface filter. If the machine has more than one network interface and you'd like it to use or skip specific interfaces + # # both inclusion and exclusion filters can be used together. If neither is defined (default), all interfaces on the machine will be used. + # # If both of them are set, then only include takes effect. + # interfaces: + # includes: + # - en0 + # excludes: + # - docker0 + # # ip address filter. If the machine has more than one ip address and you'd like it to use or skip specific ips, + # # both inclusion and exclusion CIDR filters can be used together. If neither is defined (default), all ip on the machine will be used. + # # If both of them are set, then only include takes effect. + # ips: + # includes: + # - 10.0.0.0/16 + # excludes: + # - 192.168.1.0/24 + # # Set to true to enable mDNS name candidate. This should be left disabled for most users. + # # when enabled, it will impact performance since each PeerConnection will process the same mDNS message independently + # use_mdns: true + # # Set to false to disable strict ACKs for peer connections where LiveKit is the dialing side, + # # ie. subscriber peer connections. Disabling strict ACKs will prevent clients that do not ACK + # # peer connections from getting kicked out of rooms by the monitor. Note that if strict ACKs + # # are disabled and clients don't ACK opened peer connections, only reliable, ordered delivery + # # will be available. + # strict_acks: true + # # enable batch write to merge network write system calls to reduce cpu usage. Outgoing packets + # # will be queued until length of queue equal to `batch_size` or time elapsed since last write exceeds `max_flush_interval`. + # batch_io: + # batch_size: 128 + # max_flush_interval: 2ms + # # max number of bytes to buffer for data channel. 0 means unlimited. + # # when this limit is breached, data messages will be dropped till the buffered amount drops below this limit. + # data_channel_max_buffered_amount: 0 + +# when enabled, LiveKit will expose prometheus metrics on :6789/metrics +# prometheus_port: 6789 + +# API key / secret pairs. +# Keys are used for JWT authentication, server APIs would require a keypair in order to generate access tokens +# and make calls to the server +keys: + key1: secret1 + key2: secret2 +# Logging config +# logging: +# # log level, valid values: debug, info, warn, error +# level: info +# # log level for pion, default error +# pion_level: error +# # when set to true, emit json fields +# json: false +# # for production setups, enables sampling algorithm +# # https://github.com/uber-go/zap/blob/master/FAQ.md#why-sample-application-logs +# sample: false + +# Default room config +# Each room created will inherit these settings. If rooms are created explicitly with CreateRoom, they will take +# precedence over defaults +# room: +# # allow rooms to be automatically created when participants join, defaults to true +# # auto_create: false +# # number of seconds to keep the room open if no one joins +# empty_timeout: 300 +# # number of seconds to keep the room open after everyone leaves +# departure_timeout: 20 +# # limit number of participants that can be in a room, 0 for no limit +# max_participants: 0 +# # only accept specific codecs for clients publishing to this room +# # this is useful to standardize codecs across clients +# # other supported codecs are video/h264, video/vp9, video/av1, audio/red +# enabled_codecs: +# - mime: audio/opus +# - mime: video/vp8 +# # allow tracks to be unmuted remotely, defaults to false +# # tracks can always be muted from the Room Service APIs +# enable_remote_unmute: true +# # control playout delay in ms of video track (and associated audio track) +# playout_delay: +# enabled: true +# min: 100 +# max: 2000 +# # improves A/V sync when playout_delay set to a value larger than 200ms. It will disables transceiver re-use +# # so not recommended for rooms with frequent subscription changes +# sync_streams: true + +# Webhooks +# when configured, LiveKit notifies your URL handler with room events +# webhook: +# # the API key to use in order to sign the message +# # this must match one of the keys LiveKit is configured with +# api_key: +# # list of URLs to be notified of room events +# urls: +# - https://your-host.com/handler + +# Signal Relay +# since v1.4.0, a more reliable, psrpc based signal relay is available +# this gives us the ability to reliably proxy messages between a signal server and RTC node +# signal_relay: +# # amount of time a message delivery is tried before giving up +# retry_timeout: 30s +# # minimum amount of time to wait for RTC node to ack, +# # retries use exponentially increasing wait on every subsequent try +# # with an upper bound of max_retry_interval +# min_retry_interval: 500ms +# # maximum amount of time to wait for RTC node to ack +# max_retry_interval: 5s +# # number of messages to buffer before dropping +# stream_buffer_size: 1000 + +# PSRPC +# since v1.5.1, a more reliable, psrpc based internal rpc +# psrpc: +# # maximum number of rpc attempts +# max_attempts: 3 +# # initial time to wait for calls to complete +# timeout: 500ms +# # amount of time added to the timeout after each failure +# backoff: 500ms +# # number of messages to buffer before dropping +# buffer_size: 1000 + +# customize audio level sensitivity +# audio: +# # minimum level to be considered active, 0-127, where 0 is loudest +# # defaults to 30 +# active_level: 30 +# # percentile to measure, a participant is considered active if it has exceeded the +# # ActiveLevel more than MinPercentile% of the time +# # defaults to 40 +# min_percentile: 40 +# # frequency in ms to notify changes to clients, defaults to 500 +# update_interval: 500 +# # to prevent speaker updates from too jumpy, smooth out values over N samples +# smooth_intervals: 4 +# # enable red encoding downtrack for opus only audio up track +# active_red_encoding: true + +# turn server +# turn: +# # Uses TLS. Requires cert and key pem files by either: +# # - using turn.secretName if deploying with our helm chart, or +# # - setting LIVEKIT_TURN_CERT and LIVEKIT_TURN_KEY env vars with file locations, or +# # - using cert_file and key_file below +# # defaults to false +# enabled: false +# # defaults to 3478 - recommended to 443 if not running HTTP3/QUIC server +# # only 53/80/443 are allowed if less than 1024 +# udp_port: 3478 +# # defaults to 5349 - if not using a load balancer, this must be set to 443 +# tls_port: 5349 +# # set UDP port range for TURN relay to connect to LiveKit SFU, by default it uses a any available port +# relay_range_start: 1024 +# relay_range_end: 30000 +# # set external_tls to true if using a L4 load balancer to terminate TLS. when enabled, +# # LiveKit expects unencrypted traffic on tls_port, and still advertise tls_port as a TURN/TLS candidate. +# external_tls: true +# # needs to match tls cert domain +# domain: turn.myhost.com +# # optional (set only if not using external TLS termination) +# # cert_file: /path/to/cert.pem +# # key_file: /path/to/key.pem + +# ingress server +# ingress: +# # Prefix used to generate RTMP URLs for RTMP ingress. +# rtmp_base_url: "rtmp://my.domain.com/live" +# # Prefix used to generate WHIP URLs for WHIP ingress. +# whip_base_url: "http://my.domain.com/whip" + +# Region of the current node. Required if using regionaware node selector +# region: us-west-2 + +# # node selector +# node_selector: +# # default: any. valid values: any, sysload, cpuload, regionaware +# kind: sysload +# # priority used for selection of node when multiple are available +# # default: random. valid values: random, sysload, cpuload, rooms, clients, tracks, bytespersec +# sort_by: sysload +# # algorithm used to govern selecting from sorted nodes +# # default: lowest. valid values: lowest, twochoice +# algorithm: lowest +# # used in sysload and regionaware +# # do not assign room to node if load per CPU exceeds sysload_limit +# sysload_limit: 0.7 +# # used in regionaware +# # list of regions and their lat/lon coordinates +# regions: +# - name: us-west-2 +# lat: 44.19434095976287 +# lon: -123.0674908379146 + +# # node limits +# # set to -1 to disable a limit +# limit: +# # defaults to 400 tracks in & out per CPU, up to 8000 +# num_tracks: -1 +# # defaults to 1 GB/s, or just under 10 Gbps +# bytes_per_sec: 1_000_000_000 +# # how many tracks (audio / video) that a single participant can subscribe at same time. +# # if the limit is exceeded, subscriptions will be pending until any subscribed track has been unsubscribed. +# # value less or equal than 0 means no limit. +# subscription_limit_video: 0 +# subscription_limit_audio: 0 +# # limit size of room and participant's metadata, 0 for no limit +# max_metadata_size: 0 +# # limit size of participant attributes, 0 for no limit +# max_attributes_size: 0 +# # limit length of room names +# max_room_name_length: 0 +# # limit length of participant identity +# max_participant_identity_length: 0 diff --git a/livekit/deploy/README.md b/livekit/deploy/README.md new file mode 100644 index 0000000..c72854c --- /dev/null +++ b/livekit/deploy/README.md @@ -0,0 +1,8 @@ +# LiveKit Server Deployment + +Deployment Guides: + +- [Deploy to a VM](https://docs.livekit.io/deploy/vm) +- [Deploy to Kubernetes](https://docs.livekit.io/deploy/kubernetes) + +Also included are Grafana charts for metrics gathered in Prometheus. diff --git a/livekit/deploy/grafana/livekit-server-overview.json b/livekit/deploy/grafana/livekit-server-overview.json new file mode 100644 index 0000000..24f3757 --- /dev/null +++ b/livekit/deploy/grafana/livekit-server-overview.json @@ -0,0 +1,531 @@ +{ + "__inputs": [], + "__requires": [ + { + "type": "grafana", + "id": "grafana", + "name": "Grafana", + "version": "8.2.2" + }, + { + "type": "panel", + "id": "timeseries", + "name": "Time series", + "version": "" + } + ], + "annotations": { + "list": [ + { + "builtIn": 1, + "datasource": "-- Grafana --", + "enable": true, + "hide": true, + "iconColor": "rgba(0, 211, 255, 1)", + "name": "Annotations & Alerts", + "target": { + "limit": 100, + "matchAny": false, + "tags": [], + "type": "dashboard" + }, + "type": "dashboard" + } + ] + }, + "editable": true, + "fiscalYearStartMonth": 0, + "gnetId": null, + "graphTooltip": 0, + "id": null, + "links": [], + "liveNow": false, + "panels": [ + { + "datasource": null, + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 0, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + } + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 0 + }, + "id": 4, + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom" + }, + "tooltip": { + "mode": "single" + } + }, + "targets": [ + { + "exemplar": true, + "expr": "sum(livekit_room_total)", + "interval": "", + "legendFormat": "Rooms", + "refId": "A" + } + ], + "thresholds": [], + "title": "Rooms", + "type": "timeseries" + }, + { + "datasource": null, + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 0, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + } + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 0 + }, + "id": 6, + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom" + }, + "tooltip": { + "mode": "single" + } + }, + "targets": [ + { + "exemplar": true, + "expr": "sum(livekit_participant_total)", + "interval": "", + "legendFormat": "Participants", + "refId": "A" + } + ], + "title": "Participants", + "type": "timeseries" + }, + { + "datasource": null, + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 0, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + } + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 8 + }, + "id": 9, + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom" + }, + "tooltip": { + "mode": "single" + } + }, + "targets": [ + { + "exemplar": true, + "expr": "sum(irate(livekit_node_messages{type=\"signal\"}[5m]))", + "interval": "", + "legendFormat": "Signal", + "refId": "Signal" + }, + { + "exemplar": true, + "expr": "sum(irate(livekit_node_messages{type=\"rtc\"}[5m]))", + "hide": false, + "interval": "", + "legendFormat": "RTC", + "refId": "RTC" + }, + { + "exemplar": true, + "expr": "sum(irate(livekit_node_messages{status=\"failure\"}[5m]))", + "hide": false, + "interval": "", + "legendFormat": "Failure", + "refId": "Failure" + }, + { + "exemplar": true, + "expr": "sum(irate(livekit_messagebus_messages{type=\"out\"}[5m]))", + "hide": false, + "interval": "", + "legendFormat": "Out", + "refId": "Out" + }, + { + "exemplar": true, + "expr": "sum(irate(livekit_messagebus_messages{type=\"out\", status=\"failure\"}[5m]))", + "hide": false, + "interval": "", + "legendFormat": "Out Failure", + "refId": "Out Failure" + } + ], + "thresholds": [ + { + "colorMode": "critical", + "op": "gt", + "value": 0, + "visible": true + } + ], + "title": "Message Rate", + "type": "timeseries" + }, + { + "datasource": null, + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 0, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + } + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 8 + }, + "id": 11, + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom" + }, + "tooltip": { + "mode": "single" + } + }, + "targets": [ + { + "exemplar": true, + "expr": "sum(livekit_track_published_total)", + "interval": "", + "legendFormat": "Tracks", + "refId": "A" + } + ], + "title": "Tracks Published", + "type": "timeseries" + }, + { + "datasource": null, + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 0, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + } + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 16 + }, + "id": 13, + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom" + }, + "tooltip": { + "mode": "single" + } + }, + "targets": [ + { + "exemplar": true, + "expr": "sum(irate(livekit_packet_total[5m]))", + "interval": "", + "legendFormat": "Packets", + "refId": "Packets" + }, + { + "exemplar": true, + "expr": "sum(irate(livekit_nack_total[5m]))", + "hide": false, + "interval": "", + "legendFormat": "NACK", + "refId": "NACK" + }, + { + "exemplar": true, + "expr": "sum(irate(livekit_pli_total[5m]))", + "hide": false, + "interval": "", + "legendFormat": "PLI", + "refId": "PLI" + }, + { + "exemplar": true, + "expr": "sum(irate(livekit_fir_total[5m]))", + "hide": false, + "interval": "", + "legendFormat": "FIR", + "refId": "FIR" + } + ], + "title": "Network Rate", + "type": "timeseries" + } + ], + "refresh": "5m", + "schemaVersion": 31, + "style": "dark", + "tags": [], + "templating": { + "list": [] + }, + "time": { + "from": "now-1h", + "to": "now" + }, + "timepicker": {}, + "timezone": "", + "title": "LiveKit Server Overview", + "uid": "z_GO3t5nz", + "version": 2 +} \ No newline at end of file diff --git a/livekit/go.mod b/livekit/go.mod new file mode 100644 index 0000000..eb09152 --- /dev/null +++ b/livekit/go.mod @@ -0,0 +1,159 @@ +module github.com/livekit/livekit-server + +go 1.24.0 + +toolchain go1.24.6 + +require ( + github.com/bep/debounce v1.2.1 + github.com/d5/tengo/v2 v2.17.0 + github.com/dennwc/iters v1.2.2 + github.com/dustin/go-humanize v1.0.1 + github.com/elliotchance/orderedmap/v2 v2.7.0 + github.com/florianl/go-tc v0.4.5 + github.com/frostbyte73/core v0.1.1 + github.com/gammazero/deque v1.2.0 + github.com/gammazero/workerpool v1.1.3 + github.com/google/uuid v1.6.0 + github.com/google/wire v0.7.0 + github.com/gorilla/websocket v1.5.3 + github.com/hashicorp/go-version v1.7.0 + github.com/hashicorp/golang-lru/v2 v2.0.7 + github.com/jellydator/ttlcache/v3 v3.4.0 + github.com/jxskiss/base62 v1.1.0 + github.com/livekit/mageutil v0.0.0-20250511045019-0f1ff63f7731 + github.com/livekit/mediatransportutil v0.0.0-20260113174415-2e8ba344fca3 + github.com/livekit/protocol v1.43.5-0.20260114074149-a8bb8204ce69 + github.com/livekit/psrpc v0.7.1 + github.com/mackerelio/go-osstat v0.2.6 + github.com/magefile/mage v1.15.0 + github.com/maxbrunsfeld/counterfeiter/v6 v6.12.0 + github.com/mitchellh/go-homedir v1.1.0 + github.com/olekukonko/tablewriter v0.0.5 + github.com/ory/dockertest/v3 v3.12.0 + github.com/pion/datachannel v1.6.0 + github.com/pion/dtls/v3 v3.0.10 + github.com/pion/ice/v4 v4.2.0 + github.com/pion/interceptor v0.1.43 + github.com/pion/rtcp v1.2.16 + github.com/pion/rtp v1.10.0 + github.com/pion/sctp v1.9.2 + github.com/pion/sdp/v3 v3.0.17 + github.com/pion/transport/v4 v4.0.1 + github.com/pion/turn/v4 v4.1.4 + github.com/pion/webrtc/v4 v4.2.3 + github.com/pkg/errors v0.9.1 + github.com/prometheus/client_golang v1.23.0 + github.com/redis/go-redis/v9 v9.17.2 + github.com/rs/cors v1.11.1 + github.com/stretchr/testify v1.11.1 + github.com/thoas/go-funk v0.9.3 + github.com/tomnomnom/linkheader v0.0.0-20180905144013-02ca5825eb80 + github.com/twitchtv/twirp v8.1.3+incompatible + github.com/ua-parser/uap-go v0.0.0-20250326155420-f7f5a2f9f5bc + github.com/urfave/negroni/v3 v3.1.1 + go.uber.org/atomic v1.11.0 + go.uber.org/multierr v1.11.0 + go.uber.org/zap v1.27.1 + golang.org/x/exp v0.0.0-20260112195511-716be5621a96 + golang.org/x/mod v0.32.0 + golang.org/x/sync v0.19.0 + google.golang.org/protobuf v1.36.11 + gopkg.in/yaml.v3 v3.0.1 +) + +require ( + github.com/cenkalti/backoff/v5 v5.0.3 // indirect + github.com/go-logr/stdr v1.2.2 // indirect + github.com/go-viper/mapstructure/v2 v2.1.0 // indirect + github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 // indirect + github.com/moby/sys/user v0.3.0 // indirect + github.com/nyaruka/phonenumbers v1.6.5 // indirect + go.opentelemetry.io/auto/sdk v1.2.1 // indirect + go.opentelemetry.io/otel v1.39.0 // indirect + go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.39.0 // indirect + go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.39.0 // indirect + go.opentelemetry.io/otel/metric v1.39.0 // indirect + go.opentelemetry.io/otel/sdk v1.39.0 // indirect + go.opentelemetry.io/otel/trace v1.39.0 // indirect + go.opentelemetry.io/proto/otlp v1.9.0 // indirect + golang.org/x/time v0.14.0 // indirect +) + +require ( + buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go v1.36.11-20251209175733-2a1774d88802.1 // indirect + buf.build/go/protovalidate v1.1.0 // indirect + buf.build/go/protoyaml v0.6.0 // indirect + cel.dev/expr v0.25.1 // indirect + dario.cat/mergo v1.0.0 // indirect + github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect + github.com/Microsoft/go-winio v0.6.2 // indirect + github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5 // indirect + github.com/antlr4-go/antlr/v4 v4.13.1 // indirect + github.com/benbjohnson/clock v1.3.5 // indirect + github.com/beorn7/perks v1.0.1 // indirect + github.com/cenkalti/backoff/v4 v4.3.0 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/containerd/continuity v0.4.5 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + github.com/docker/cli v27.4.1+incompatible // indirect + github.com/docker/docker v27.1.1+incompatible // indirect + github.com/docker/go-connections v0.5.0 // indirect + github.com/docker/go-units v0.5.0 // indirect + github.com/fsnotify/fsnotify v1.9.0 // indirect + github.com/go-jose/go-jose/v3 v3.0.4 // indirect + github.com/go-logr/logr v1.4.3 // indirect + github.com/gogo/protobuf v1.3.2 // indirect + github.com/google/cel-go v0.26.1 // indirect + github.com/google/go-cmp v0.7.0 // indirect + github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect + github.com/google/subcommands v1.2.0 // indirect + github.com/hashicorp/go-cleanhttp v0.5.2 // indirect + github.com/hashicorp/go-retryablehttp v0.7.7 // indirect + github.com/hashicorp/golang-lru v1.0.2 // indirect + github.com/josharian/native v1.1.0 // indirect + github.com/klauspost/compress v1.18.2 // indirect + github.com/klauspost/cpuid/v2 v2.3.0 // indirect + github.com/lithammer/shortuuid/v4 v4.2.0 // indirect + github.com/mattn/go-runewidth v0.0.9 // indirect + github.com/mdlayher/netlink v1.7.1 // indirect + github.com/mdlayher/socket v0.4.0 // indirect + github.com/moby/docker-image-spec v1.3.1 // indirect + github.com/moby/term v0.5.0 // indirect + github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect + github.com/nats-io/nats.go v1.48.0 // indirect + github.com/nats-io/nkeys v0.4.12 // indirect + github.com/nats-io/nuid v1.0.1 // indirect + github.com/opencontainers/go-digest v1.0.0 // indirect + github.com/opencontainers/image-spec v1.1.0 // indirect + github.com/opencontainers/runc v1.2.3 // indirect + github.com/pion/logging v0.2.4 // indirect + github.com/pion/mdns/v2 v2.1.0 // indirect + github.com/pion/randutil v0.1.0 // indirect + github.com/pion/srtp/v3 v3.0.10 // indirect + github.com/pion/stun/v3 v3.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/prometheus/client_model v0.6.2 // indirect + github.com/prometheus/common v0.65.0 // indirect + github.com/prometheus/procfs v0.16.1 // indirect + github.com/puzpuzpuz/xsync/v3 v3.5.1 // indirect + github.com/sirupsen/logrus v1.9.3 // indirect + github.com/stoewer/go-strcase v1.3.1 // indirect + github.com/urfave/cli/v3 v3.3.9 + github.com/wlynxg/anet v0.0.5 // indirect + github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb // indirect + github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect + github.com/xeipuuv/gojsonschema v1.2.0 // indirect + github.com/zeebo/xxh3 v1.0.2 // indirect + go.uber.org/zap/exp v0.3.0 // indirect + golang.org/x/crypto v0.47.0 // indirect + golang.org/x/net v0.49.0 // indirect + golang.org/x/sys v0.40.0 // indirect + golang.org/x/text v0.33.0 // indirect + golang.org/x/tools v0.41.0 // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20260112192933-99fd39fd28a9 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20260112192933-99fd39fd28a9 // indirect + google.golang.org/grpc v1.78.0 // indirect + gopkg.in/yaml.v2 v2.4.0 // indirect +) diff --git a/livekit/go.sum b/livekit/go.sum new file mode 100644 index 0000000..70bd6af --- /dev/null +++ b/livekit/go.sum @@ -0,0 +1,507 @@ +buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go v1.36.11-20251209175733-2a1774d88802.1 h1:j9yeqTWEFrtimt8Nng2MIeRrpoCvQzM9/g25XTvqUGg= +buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go v1.36.11-20251209175733-2a1774d88802.1/go.mod h1:tvtbpgaVXZX4g6Pn+AnzFycuRK3MOz5HJfEGeEllXYM= +buf.build/go/protovalidate v1.1.0 h1:pQqEQRpOo4SqS60qkvmhLTTQU9JwzEvdyiqAtXa5SeY= +buf.build/go/protovalidate v1.1.0/go.mod h1:bGZcPiAQDC3ErCHK3t74jSoJDFOs2JH3d7LWuTEIdss= +buf.build/go/protoyaml v0.6.0 h1:Nzz1lvcXF8YgNZXk+voPPwdU8FjDPTUV4ndNTXN0n2w= +buf.build/go/protoyaml v0.6.0/go.mod h1:RgUOsBu/GYKLDSIRgQXniXbNgFlGEZnQpRAUdLAFV2Q= +cel.dev/expr v0.25.1 h1:1KrZg61W6TWSxuNZ37Xy49ps13NUovb66QLprthtwi4= +cel.dev/expr v0.25.1/go.mod h1:hrXvqGP6G6gyx8UAHSHJ5RGk//1Oj5nXQ2NI02Nrsg4= +dario.cat/mergo v1.0.0 h1:AGCNq9Evsj31mOgNPcLyXc+4PNABt905YmuqPYYpBWk= +dario.cat/mergo v1.0.0/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= +filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= +filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= +github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 h1:L/gRVlceqvL25UVaW/CKtUDjefjrs0SPonmDGUVOYP0= +github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E= +github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= +github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= +github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5 h1:TngWCqHvy9oXAN6lEVMRuU21PR1EtLVZJmdB18Gu3Rw= +github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5/go.mod h1:lmUJ/7eu/Q8D7ML55dXQrVaamCz2vxCfdQBasLZfHKk= +github.com/antlr4-go/antlr/v4 v4.13.1 h1:SqQKkuVZ+zWkMMNkjy5FZe5mr5WURWnlpmOuzYWrPrQ= +github.com/antlr4-go/antlr/v4 v4.13.1/go.mod h1:GKmUxMtwp6ZgGwZSva4eWPC5mS6vUAmOABFgjdkM7Nw= +github.com/benbjohnson/clock v1.3.5 h1:VvXlSJBzZpA/zum6Sj74hxwYI2DIxRWuNIoXAzHZz5o= +github.com/benbjohnson/clock v1.3.5/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= +github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= +github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/bep/debounce v1.2.1 h1:v67fRdBA9UQu2NhLFXrSg0Brw7CexQekrBwDMM8bzeY= +github.com/bep/debounce v1.2.1/go.mod h1:H8yggRPQKLUhUoqrJC1bO2xNya7vanpDl7xR3ISbCJ0= +github.com/brianvoe/gofakeit/v6 v6.28.0 h1:Xib46XXuQfmlLS2EXRuJpqcw8St6qSZz75OUo0tgAW4= +github.com/brianvoe/gofakeit/v6 v6.28.0/go.mod h1:Xj58BMSnFqcn/fAQeSK+/PLtC5kSb7FJIq4JyGa8vEs= +github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= +github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= +github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= +github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= +github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= +github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= +github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1xcsSM= +github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cilium/ebpf v0.5.0/go.mod h1:4tRaxcgiL706VnOzHOdBlY8IEAIdxINsQBcU4xJJXRs= +github.com/cilium/ebpf v0.7.0/go.mod h1:/oI2+1shJiTGAMgl6/RgJr36Eo1jzrRcAWbcXO2usCA= +github.com/cilium/ebpf v0.8.1/go.mod h1:f5zLIM0FSNuAkSyLAN7X+Hy6yznlF1mNiWUMfxMtrgk= +github.com/cilium/ebpf v0.16.0 h1:+BiEnHL6Z7lXnlGUsXQPPAE7+kenAd4ES8MQ5min0Ok= +github.com/cilium/ebpf v0.16.0/go.mod h1:L7u2Blt2jMM/vLAVgjxluxtBKlz3/GWjB0dMOEngfwE= +github.com/containerd/continuity v0.4.5 h1:ZRoN1sXq9u7V6QoHMcVWGhOwDFqZ4B9i5H6un1Wh0x4= +github.com/containerd/continuity v0.4.5/go.mod h1:/lNJvtJKUQStBzpVQ1+rasXO1LAWtUQssk28EZvJ3nE= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY= +github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4= +github.com/d5/tengo/v2 v2.17.0 h1:BWUN9NoJzw48jZKiYDXDIF3QrIVZRm1uV1gTzeZ2lqM= +github.com/d5/tengo/v2 v2.17.0/go.mod h1:XRGjEs5I9jYIKTxly6HCF8oiiilk5E/RYXOZ5b0DZC8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dennwc/iters v1.2.2 h1:XH2/Etihiy9ZvPOVCR+icQXeYlhbvS7k0qro4x/2qQo= +github.com/dennwc/iters v1.2.2/go.mod h1:M9KuuMBeyEXYTmB7EnI9SCyALFCmPWOIxn5W1L0CjGg= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/docker/cli v27.4.1+incompatible h1:VzPiUlRJ/xh+otB75gva3r05isHMo5wXDfPRi5/b4hI= +github.com/docker/cli v27.4.1+incompatible/go.mod h1:JLrzqnKDaYBop7H2jaqPtU4hHvMKP+vjCwu2uszcLI8= +github.com/docker/docker v27.1.1+incompatible h1:hO/M4MtV36kzKldqnA37IWhebRA+LnqqcqDja6kVaKY= +github.com/docker/docker v27.1.1+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= +github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj1Br63c= +github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc= +github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4= +github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/elliotchance/orderedmap/v2 v2.7.0 h1:WHuf0DRo63uLnldCPp9ojm3gskYwEdIIfAUVG5KhoOc= +github.com/elliotchance/orderedmap/v2 v2.7.0/go.mod h1:85lZyVbpGaGvHvnKa7Qhx7zncAdBIBq6u56Hb1PRU5Q= +github.com/fatih/color v1.16.0 h1:zmkK9Ngbjj+K0yRhTVONQh1p/HknKYSlNT+vZCzyokM= +github.com/fatih/color v1.16.0/go.mod h1:fL2Sau1YI5c0pdGEVCbKQbLXB6edEj1ZgiY4NijnWvE= +github.com/florianl/go-tc v0.4.5 h1:8lvecARs3c/vGee46j0ro8kco98ga9XjwWvXGwlzrXA= +github.com/florianl/go-tc v0.4.5/go.mod h1:uvp6pIlOw7Z8hhfnT5M4+V1hHVgZWRZwwMS8Z0JsRxc= +github.com/frankban/quicktest v1.11.3/go.mod h1:wRf/ReqHper53s+kmmSZizM8NamnL3IM0I9ntUbOk+k= +github.com/frankban/quicktest v1.14.0/go.mod h1:NeW+ay9A/U67EYXNFA1nPE8e/tnQv/09mUdL/ijj8og= +github.com/frostbyte73/core v0.1.1 h1:ChhJOR7bAKOCPbA+lqDLE2cGKlCG5JXsDvvQr4YaJIA= +github.com/frostbyte73/core v0.1.1/go.mod h1:mhfOtR+xWAvwXiwor7jnqPMnu4fxbv1F2MwZ0BEpzZo= +github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= +github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= +github.com/gammazero/deque v1.2.0 h1:scEFO8Uidhw6KDU5qg1HA5fYwM0+us2qdeJqm43bitU= +github.com/gammazero/deque v1.2.0/go.mod h1:JVrR+Bj1NMQbPnYclvDlvSX0nVGReLrQZ0aUMuWLctg= +github.com/gammazero/workerpool v1.1.3 h1:WixN4xzukFoN0XSeXF6puqEqFTl2mECI9S6W44HWy9Q= +github.com/gammazero/workerpool v1.1.3/go.mod h1:wPjyBLDbyKnUn2XwwyD3EEwo9dHutia9/fwNmSHWACc= +github.com/go-jose/go-jose/v3 v3.0.4 h1:Wp5HA7bLQcKnf6YYao/4kpRpVMp/yf6+pJKV8WFSaNY= +github.com/go-jose/go-jose/v3 v3.0.4/go.mod h1:5b+7YgP7ZICgJDBdfjZaIt+H/9L9T/YQrVfLAMboGkQ= +github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= +github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= +github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= +github.com/go-viper/mapstructure/v2 v2.1.0 h1:gHnMa2Y/pIxElCH2GlZZ1lZSsn6XMtufpGyP1XxdC/w= +github.com/go-viper/mapstructure/v2 v2.1.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= +github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= +github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= +github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= +github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= +github.com/google/cel-go v0.26.1 h1:iPbVVEdkhTX++hpe3lzSk7D3G3QSYqLGoHOcEio+UXQ= +github.com/google/cel-go v0.26.1/go.mod h1:A9O8OU9rdvrK5MQyrqfIxo1a0u4g3sF8KB6PUIaryMM= +github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= +github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4= +github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ= +github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE= +github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4= +github.com/google/wire v0.7.0/go.mod h1:n6YbUQD9cPKTnHXEBN2DXlOp/mVADhVErcMFb0v3J18= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 h1:NmZ1PKzSTQbuGHw9DGPFomqkkLWMC+vZCkfs+FHv1Vg= +github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3/go.mod h1:zQrxl1YP88HQlA6i9c63DSVPFklWpGX4OWAc9bFuaH4= +github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ= +github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48= +github.com/hashicorp/go-hclog v1.6.3 h1:Qr2kF+eVWjTiYmU7Y31tYlP1h0q/X3Nl3tPGdaB11/k= +github.com/hashicorp/go-hclog v1.6.3/go.mod h1:W4Qnvbt70Wk/zYJryRzDRU/4r0kIg0PVHBcfoyhpF5M= +github.com/hashicorp/go-retryablehttp v0.7.7 h1:C8hUCYzor8PIfXHa4UrZkU4VvK8o9ISHxT2Q8+VepXU= +github.com/hashicorp/go-retryablehttp v0.7.7/go.mod h1:pkQpWZeYWskR+D1tR2O5OcBFOxfA7DoAO6xtkuQnHTk= +github.com/hashicorp/go-version v1.7.0 h1:5tqGy27NaOTB8yJKUZELlFAS/LTKJkrmONwQKeRZfjY= +github.com/hashicorp/go-version v1.7.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA= +github.com/hashicorp/golang-lru v1.0.2 h1:dV3g9Z/unq5DpblPpw+Oqcv4dU/1omnb4Ok8iPY6p1c= +github.com/hashicorp/golang-lru v1.0.2/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= +github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= +github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= +github.com/jellydator/ttlcache/v3 v3.4.0 h1:YS4P125qQS0tNhtL6aeYkheEaB/m8HCqdMMP4mnWdTY= +github.com/jellydator/ttlcache/v3 v3.4.0/go.mod h1:Hw9EgjymziQD3yGsQdf1FqFdpp7YjFMd4Srg5EJlgD4= +github.com/josharian/native v0.0.0-20200817173448-b6b71def0850/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w= +github.com/josharian/native v1.0.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w= +github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA= +github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w= +github.com/jsimonetti/rtnetlink v0.0.0-20190606172950-9527aa82566a/go.mod h1:Oz+70psSo5OFh8DBl0Zv2ACw7Esh6pPUphlvZG9x7uw= +github.com/jsimonetti/rtnetlink v0.0.0-20200117123717-f846d4f6c1f4/go.mod h1:WGuG/smIU4J/54PblvSbh+xvCZmpJnFgr3ds6Z55XMQ= +github.com/jsimonetti/rtnetlink v0.0.0-20201009170750-9c6f07d100c1/go.mod h1:hqoO/u39cqLeBLebZ8fWdE96O7FxrAsRYhnVOdgHxok= +github.com/jsimonetti/rtnetlink v0.0.0-20201216134343-bde56ed16391/go.mod h1:cR77jAZG3Y3bsb8hF6fHJbFoyFukLFOkQ98S0pQz3xw= +github.com/jsimonetti/rtnetlink v0.0.0-20201220180245-69540ac93943/go.mod h1:z4c53zj6Eex712ROyh8WI0ihysb5j2ROyV42iNogmAs= +github.com/jsimonetti/rtnetlink v0.0.0-20210122163228-8d122574c736/go.mod h1:ZXpIyOK59ZnN7J0BV99cZUPmsqDRZ3eq5X+st7u/oSA= +github.com/jsimonetti/rtnetlink v0.0.0-20210212075122-66c871082f2b/go.mod h1:8w9Rh8m+aHZIG69YPGGem1i5VzoyRC8nw2kA8B+ik5U= +github.com/jsimonetti/rtnetlink v0.0.0-20210525051524-4cc836578190/go.mod h1:NmKSdU4VGSiv1bMsdqNALI4RSvvjtz65tTMCnD05qLo= +github.com/jsimonetti/rtnetlink v0.0.0-20211022192332-93da33804786 h1:N527AHMa793TP5z5GNAn/VLPzlc0ewzWdeP/25gDfgQ= +github.com/jsimonetti/rtnetlink v0.0.0-20211022192332-93da33804786/go.mod h1:v4hqbTdfQngbVSZJVWUhGE/lbTFf9jb+ygmNUDQMuOs= +github.com/jxskiss/base62 v1.1.0 h1:A5zbF8v8WXx2xixnAKD2w+abC+sIzYJX+nxmhA6HWFw= +github.com/jxskiss/base62 v1.1.0/go.mod h1:HhWAlUXvxKThfOlZbcuFzsqwtF5TcqS9ru3y5GfjWAc= +github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= +github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/klauspost/compress v1.18.2 h1:iiPHWW0YrcFgpBYhsA6D1+fqHssJscY/Tm/y2Uqnapk= +github.com/klauspost/compress v1.18.2/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4= +github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= +github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= +github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/lithammer/shortuuid/v4 v4.2.0 h1:LMFOzVB3996a7b8aBuEXxqOBflbfPQAiVzkIcHO0h8c= +github.com/lithammer/shortuuid/v4 v4.2.0/go.mod h1:D5noHZ2oFw/YaKCfGy0YxyE7M0wMbezmMjPdhyEFe6Y= +github.com/livekit/mageutil v0.0.0-20250511045019-0f1ff63f7731 h1:9x+U2HGLrSw5ATTo469PQPkqzdoU7be46ryiCDO3boc= +github.com/livekit/mageutil v0.0.0-20250511045019-0f1ff63f7731/go.mod h1:Rs3MhFwutWhGwmY1VQsygw28z5bWcnEYmS1OG9OxjOQ= +github.com/livekit/mediatransportutil v0.0.0-20260113174415-2e8ba344fca3 h1:v1Xc/q/547TjLX7Nw5y2vXNnmV0XYFAbhTJrtErQeDA= +github.com/livekit/mediatransportutil v0.0.0-20260113174415-2e8ba344fca3/go.mod h1:QBx/KHV6Vv00ggibg/WrOlqrkTciEA2Hc9DGWYr3Q9U= +github.com/livekit/protocol v1.43.5-0.20260114074149-a8bb8204ce69 h1:cD82r488SxGYL5MX1lLuLLjmdnNoC+u5TIepxQmSB40= +github.com/livekit/protocol v1.43.5-0.20260114074149-a8bb8204ce69/go.mod h1:BLJHYHErQTu3+fnmfGrzN6CbHxNYiooFIIYGYxXxotw= +github.com/livekit/psrpc v0.7.1 h1:ms37az0QTD3UXIWuUC5D/SkmKOlRMVRsI261eBWu/Vw= +github.com/livekit/psrpc v0.7.1/go.mod h1:bZ4iHFQptTkbPnB0LasvRNu/OBYXEu1NA6O5BMFo9kk= +github.com/mackerelio/go-osstat v0.2.6 h1:gs4U8BZeS1tjrL08tt5VUliVvSWP26Ai2Ob8Lr7f2i0= +github.com/mackerelio/go-osstat v0.2.6/go.mod h1:lRy8V9ZuHpuRVZh+vyTkODeDPl3/d5MgXHtLSaqG8bA= +github.com/magefile/mage v1.15.0 h1:BvGheCMAsG3bWUDbZ8AyXXpCNwU9u5CB6sM+HNb9HYg= +github.com/magefile/mage v1.15.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-runewidth v0.0.9 h1:Lm995f3rfxdpd6TSmuVCHVb/QhupuXlYr8sCI/QdE+0= +github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI= +github.com/maxbrunsfeld/counterfeiter/v6 v6.12.0 h1:aOeI7xAOVdK+R6xbVsZuU9HmCZYmQVmZgPf9xJUd2Sg= +github.com/maxbrunsfeld/counterfeiter/v6 v6.12.0/go.mod h1:0hZWbtfeCYUQeAQdPLUzETiBhUSns7O6LDj9vH88xKA= +github.com/mdlayher/ethtool v0.0.0-20210210192532-2b88debcdd43/go.mod h1:+t7E0lkKfbBsebllff1xdTmyJt8lH37niI6kwFk9OTo= +github.com/mdlayher/genetlink v1.0.0/go.mod h1:0rJ0h4itni50A86M2kHcgS85ttZazNt7a8H2a2cw0Gc= +github.com/mdlayher/netlink v0.0.0-20190409211403-11939a169225/go.mod h1:eQB3mZE4aiYnlUsyGGCOpPETfdQq4Jhsgf1fk3cwQaA= +github.com/mdlayher/netlink v1.0.0/go.mod h1:KxeJAFOFLG6AjpyDkQ/iIhxygIUKD+vcwqcnu43w/+M= +github.com/mdlayher/netlink v1.1.0/go.mod h1:H4WCitaheIsdF9yOYu8CFmCgQthAPIWZmcKp9uZHgmY= +github.com/mdlayher/netlink v1.1.1/go.mod h1:WTYpFb/WTvlRJAyKhZL5/uy69TDDpHHu2VZmb2XgV7o= +github.com/mdlayher/netlink v1.2.0/go.mod h1:kwVW1io0AZy9A1E2YYgaD4Cj+C+GPkU6klXCMzIJ9p8= +github.com/mdlayher/netlink v1.2.1/go.mod h1:bacnNlfhqHqqLo4WsYeXSqfyXkInQ9JneWI68v1KwSU= +github.com/mdlayher/netlink v1.2.2-0.20210123213345-5cc92139ae3e/go.mod h1:bacnNlfhqHqqLo4WsYeXSqfyXkInQ9JneWI68v1KwSU= +github.com/mdlayher/netlink v1.3.0/go.mod h1:xK/BssKuwcRXHrtN04UBkwQ6dY9VviGGuriDdoPSWys= +github.com/mdlayher/netlink v1.4.0/go.mod h1:dRJi5IABcZpBD2A3D0Mv/AiX8I9uDEu5oGkAVrekmf8= +github.com/mdlayher/netlink v1.4.1/go.mod h1:e4/KuJ+s8UhfUpO9z00/fDZZmhSrs+oxyqAS9cNgn6Q= +github.com/mdlayher/netlink v1.6.0/go.mod h1:0o3PlBmGst1xve7wQ7j/hwpNaFaH4qCRyWCdcZk8/vA= +github.com/mdlayher/netlink v1.7.1 h1:FdUaT/e33HjEXagwELR8R3/KL1Fq5x3G5jgHLp/BTmg= +github.com/mdlayher/netlink v1.7.1/go.mod h1:nKO5CSjE/DJjVhk/TNp6vCE1ktVxEA8VEh8drhZzxsQ= +github.com/mdlayher/socket v0.0.0-20210307095302-262dc9984e00/go.mod h1:GAFlyu4/XV68LkQKYzKhIo/WW7j3Zi0YRAz/BOoanUc= +github.com/mdlayher/socket v0.1.1/go.mod h1:mYV5YIZAfHh4dzDVzI8x8tWLWCliuX8Mon5Awbj+qDs= +github.com/mdlayher/socket v0.4.0 h1:280wsy40IC9M9q1uPGcLBwXpcTQDtoGwVt+BNoITxIw= +github.com/mdlayher/socket v0.4.0/go.mod h1:xxFqz5GRCUN3UEOm9CZqEJsAbe1C8OwSK46NlmWuVoc= +github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y= +github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= +github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0= +github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo= +github.com/moby/sys/user v0.3.0 h1:9ni5DlcW5an3SvRSx4MouotOygvzaXbaSrc/wGDFWPo= +github.com/moby/sys/user v0.3.0/go.mod h1:bG+tYYYJgaMtRKgEmuueC0hJEAZWwtIbZTB+85uoHjs= +github.com/moby/term v0.5.0 h1:xt8Q1nalod/v7BqbG21f8mQPqH+xAaC9C3N3wfWbVP0= +github.com/moby/term v0.5.0/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y= +github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= +github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= +github.com/nats-io/nats.go v1.48.0 h1:pSFyXApG+yWU/TgbKCjmm5K4wrHu86231/w84qRVR+U= +github.com/nats-io/nats.go v1.48.0/go.mod h1:iRWIPokVIFbVijxuMQq4y9ttaBTMe0SFdlZfMDd+33g= +github.com/nats-io/nkeys v0.4.12 h1:nssm7JKOG9/x4J8II47VWCL1Ds29avyiQDRn0ckMvDc= +github.com/nats-io/nkeys v0.4.12/go.mod h1:MT59A1HYcjIcyQDJStTfaOY6vhy9XTUjOFo+SVsvpBg= +github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw= +github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c= +github.com/nyaruka/phonenumbers v1.6.5 h1:aBCaUhfpRA7hU6fsXk+p7KF1aNx4nQlq9hGeo2qdFg8= +github.com/nyaruka/phonenumbers v1.6.5/go.mod h1:7gjs+Lchqm49adhAKB5cdcng5ZXgt6x7Jgvi0ZorUtU= +github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec= +github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY= +github.com/onsi/gomega v1.38.2 h1:eZCjf2xjZAqe+LeWvKb5weQ+NcPwX84kqJ0cZNxok2A= +github.com/onsi/gomega v1.38.2/go.mod h1:W2MJcYxRGV63b418Ai34Ud0hEdTVXq9NW9+Sx6uXf3k= +github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= +github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= +github.com/opencontainers/image-spec v1.1.0 h1:8SG7/vwALn54lVB/0yZ/MMwhFrPYtpEHQb2IpWsCzug= +github.com/opencontainers/image-spec v1.1.0/go.mod h1:W4s4sFTMaBeK1BQLXbG4AdM2szdn85PY75RI83NrTrM= +github.com/opencontainers/runc v1.2.3 h1:fxE7amCzfZflJO2lHXf4y/y8M1BoAqp+FVmG19oYB80= +github.com/opencontainers/runc v1.2.3/go.mod h1:nSxcWUydXrsBZVYNSkTjoQ/N6rcyTtn+1SD5D4+kRIM= +github.com/ory/dockertest/v3 v3.12.0 h1:3oV9d0sDzlSQfHtIaB5k6ghUCVMVLpAY8hwrqoCyRCw= +github.com/ory/dockertest/v3 v3.12.0/go.mod h1:aKNDTva3cp8dwOWwb9cWuX84aH5akkxXRvO7KCwWVjE= +github.com/pion/datachannel v1.6.0 h1:XecBlj+cvsxhAMZWFfFcPyUaDZtd7IJvrXqlXD/53i0= +github.com/pion/datachannel v1.6.0/go.mod h1:ur+wzYF8mWdC+Mkis5Thosk+u/VOL287apDNEbFpsIk= +github.com/pion/dtls/v3 v3.0.10 h1:k9ekkq1kaZoxnNEbyLKI8DI37j/Nbk1HWmMuywpQJgg= +github.com/pion/dtls/v3 v3.0.10/go.mod h1:YEmmBYIoBsY3jmG56dsziTv/Lca9y4Om83370CXfqJ8= +github.com/pion/ice/v4 v4.2.0 h1:jJC8S+CvXCCvIQUgx+oNZnoUpt6zwc34FhjWwCU4nlw= +github.com/pion/ice/v4 v4.2.0/go.mod h1:EgjBGxDgmd8xB0OkYEVFlzQuEI7kWSCFu+mULqaisy4= +github.com/pion/interceptor v0.1.43 h1:6hmRfnmjogSs300xfkR0JxYFZ9k5blTEvCD7wxEDuNQ= +github.com/pion/interceptor v0.1.43/go.mod h1:BSiC1qKIJt1XVr3l3xQ2GEmCFStk9tx8fwtCZxxgR7M= +github.com/pion/logging v0.2.4 h1:tTew+7cmQ+Mc1pTBLKH2puKsOvhm32dROumOZ655zB8= +github.com/pion/logging v0.2.4/go.mod h1:DffhXTKYdNZU+KtJ5pyQDjvOAh/GsNSyv1lbkFbe3so= +github.com/pion/mdns/v2 v2.1.0 h1:3IJ9+Xio6tWYjhN6WwuY142P/1jA0D5ERaIqawg/fOY= +github.com/pion/mdns/v2 v2.1.0/go.mod h1:pcez23GdynwcfRU1977qKU0mDxSeucttSHbCSfFOd9A= +github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA= +github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8= +github.com/pion/rtcp v1.2.16 h1:fk1B1dNW4hsI78XUCljZJlC4kZOPk67mNRuQ0fcEkSo= +github.com/pion/rtcp v1.2.16/go.mod h1:/as7VKfYbs5NIb4h6muQ35kQF/J0ZVNz2Z3xKoCBYOo= +github.com/pion/rtp v1.10.0 h1:XN/xca4ho6ZEcijpdF2VGFbwuHUfiIMf3ew8eAAE43w= +github.com/pion/rtp v1.10.0/go.mod h1:rF5nS1GqbR7H/TCpKwylzeq6yDM+MM6k+On5EgeThEM= +github.com/pion/sctp v1.9.2 h1:HxsOzEV9pWoeggv7T5kewVkstFNcGvhMPx0GvUOUQXo= +github.com/pion/sctp v1.9.2/go.mod h1:OTOlsQ5EDQ6mQ0z4MUGXt2CgQmKyafBEXhUVqLRB6G8= +github.com/pion/sdp/v3 v3.0.17 h1:9SfLAW/fF1XC8yRqQ3iWGzxkySxup4k4V7yN8Fs8nuo= +github.com/pion/sdp/v3 v3.0.17/go.mod h1:9tyKzznud3qiweZcD86kS0ff1pGYB3VX+Bcsmkx6IXo= +github.com/pion/srtp/v3 v3.0.10 h1:tFirkpBb3XccP5VEXLi50GqXhv5SKPxqrdlhDCJlZrQ= +github.com/pion/srtp/v3 v3.0.10/go.mod h1:3mOTIB0cq9qlbn59V4ozvv9ClW/BSEbRp4cY0VtaR7M= +github.com/pion/stun/v3 v3.1.1 h1:CkQxveJ4xGQjulGSROXbXq94TAWu8gIX2dT+ePhUkqw= +github.com/pion/stun/v3 v3.1.1/go.mod h1:qC1DfmcCTQjl9PBaMa5wSn3x9IPmKxSdcCsxBcDBndM= +github.com/pion/transport/v3 v3.1.1 h1:Tr684+fnnKlhPceU+ICdrw6KKkTms+5qHMgw6bIkYOM= +github.com/pion/transport/v3 v3.1.1/go.mod h1:+c2eewC5WJQHiAA46fkMMzoYZSuGzA/7E2FPrOYHctQ= +github.com/pion/transport/v4 v4.0.1 h1:sdROELU6BZ63Ab7FrOLn13M6YdJLY20wldXW2Cu2k8o= +github.com/pion/transport/v4 v4.0.1/go.mod h1:nEuEA4AD5lPdcIegQDpVLgNoDGreqM/YqmEx3ovP4jM= +github.com/pion/turn/v4 v4.1.4 h1:EU11yMXKIsK43FhcUnjLlrhE4nboHZq+TXBIi3QpcxQ= +github.com/pion/turn/v4 v4.1.4/go.mod h1:ES1DXVFKnOhuDkqn9hn5VJlSWmZPaRJLyBXoOeO/BmQ= +github.com/pion/webrtc/v4 v4.2.3 h1:RtdWDnkenNQGxUrZqWa5gSkTm5ncsLg5d+zu0M4cXt4= +github.com/pion/webrtc/v4 v4.2.3/go.mod h1:7vsyFzRzaKP5IELUnj8zLcglPyIT6wWwqTppBZ1k6Kc= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/prometheus/client_golang v1.23.0 h1:ust4zpdl9r4trLY/gSjlm07PuiBq2ynaXXlptpfy8Uc= +github.com/prometheus/client_golang v1.23.0/go.mod h1:i/o0R9ByOnHX0McrTMTyhYvKE4haaf2mW08I+jGAjEE= +github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk= +github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE= +github.com/prometheus/common v0.65.0 h1:QDwzd+G1twt//Kwj/Ww6E9FQq1iVMmODnILtW1t2VzE= +github.com/prometheus/common v0.65.0/go.mod h1:0gZns+BLRQ3V6NdaerOhMbwwRbNh9hkGINtQAsP5GS8= +github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg= +github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is= +github.com/puzpuzpuz/xsync/v3 v3.5.1 h1:GJYJZwO6IdxN/IKbneznS6yPkVC+c3zyY/j19c++5Fg= +github.com/puzpuzpuz/xsync/v3 v3.5.1/go.mod h1:VjzYrABPabuM4KyBh1Ftq6u8nhwY5tBPKP9jpmh0nnA= +github.com/redis/go-redis/v9 v9.17.2 h1:P2EGsA4qVIM3Pp+aPocCJ7DguDHhqrXNhVcEp4ViluI= +github.com/redis/go-redis/v9 v9.17.2/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370= +github.com/rodaine/protogofakeit v0.1.1 h1:ZKouljuRM3A+TArppfBqnH8tGZHOwM/pjvtXe9DaXH8= +github.com/rodaine/protogofakeit v0.1.1/go.mod h1:pXn/AstBYMaSfc1/RqH3N82pBuxtWgejz1AlYpY1mI0= +github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= +github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= +github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= +github.com/rs/cors v1.11.1 h1:eU3gRzXLRK57F5rKMGMZURNdIG4EoAmX8k94r9wXWHA= +github.com/rs/cors v1.11.1/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU= +github.com/sclevine/spec v1.4.0 h1:z/Q9idDcay5m5irkZ28M7PtQM4aOISzOpj4bUPkDee8= +github.com/sclevine/spec v1.4.0/go.mod h1:LvpgJaFyvQzRvc1kaDs0bulYwzC70PbiYjC4QnFHkOM= +github.com/shoenig/test v1.7.0 h1:eWcHtTXa6QLnBvm0jgEabMRN/uJ4DMV3M8xUGgRkZmk= +github.com/shoenig/test v1.7.0/go.mod h1:UxJ6u/x2v/TNs/LoLxBNJRV9DiwBBKYxXSyczsBHFoI= +github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= +github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/stoewer/go-strcase v1.3.1 h1:iS0MdW+kVTxgMoE1LAZyMiYJFKlOzLooE4MxjirtkAs= +github.com/stoewer/go-strcase v1.3.1/go.mod h1:fAH5hQ5pehh+j3nZfvwdk2RgEgQjAoM8wodgtPmh1xo= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/thoas/go-funk v0.9.3 h1:7+nAEx3kn5ZJcnDm2Bh23N2yOtweO14bi//dvRtgLpw= +github.com/thoas/go-funk v0.9.3/go.mod h1:+IWnUfUmFO1+WVYQWQtIJHeRRdaIyyYglZN7xzUPe4Q= +github.com/tomnomnom/linkheader v0.0.0-20180905144013-02ca5825eb80 h1:nrZ3ySNYwJbSpD6ce9duiP+QkD3JuLCcWkdaehUS/3Y= +github.com/tomnomnom/linkheader v0.0.0-20180905144013-02ca5825eb80/go.mod h1:iFyPdL66DjUD96XmzVL3ZntbzcflLnznH0fr99w5VqE= +github.com/twitchtv/twirp v8.1.3+incompatible h1:+F4TdErPgSUbMZMwp13Q/KgDVuI7HJXP61mNV3/7iuU= +github.com/twitchtv/twirp v8.1.3+incompatible/go.mod h1:RRJoFSAmTEh2weEqWtpPE3vFK5YBhA6bqp2l1kfCC5A= +github.com/ua-parser/uap-go v0.0.0-20250326155420-f7f5a2f9f5bc h1:reH9QQKGFOq39MYOvU9+SYrB8uzXtWNo51fWK3g0gGc= +github.com/ua-parser/uap-go v0.0.0-20250326155420-f7f5a2f9f5bc/go.mod h1:gwANdYmo9R8LLwGnyDFWK2PMsaXXX2HhAvCnb/UhZsM= +github.com/urfave/cli/v3 v3.3.9 h1:54roEDJcTWuucl6MSQ3B+pQqt1ePh/xOQokhEYl5Gfs= +github.com/urfave/cli/v3 v3.3.9/go.mod h1:FJSKtM/9AiiTOJL4fJ6TbMUkxBXn7GO9guZqoZtpYpo= +github.com/urfave/negroni/v3 v3.1.1 h1:6MS4nG9Jk/UuCACaUlNXCbiKa0ywF9LXz5dGu09v8hw= +github.com/urfave/negroni/v3 v3.1.1/go.mod h1:jWvnX03kcSjDBl/ShB0iHvx5uOs7mAzZXW+JvJ5XYAs= +github.com/wlynxg/anet v0.0.5 h1:J3VJGi1gvo0JwZ/P1/Yc/8p63SoW98B5dHkYDmpgvvU= +github.com/wlynxg/anet v0.0.5/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA= +github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= +github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb h1:zGWFAtiMcyryUHoUjUJX0/lt1H2+i2Ka2n+D3DImSNo= +github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= +github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 h1:EzJWgHovont7NscjpAxXsDA8S8BMYve8Y5+7cuRE7R0= +github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1:GwrjFmJcFw6At/Gs6z4yjiIwzuJ1/+UwLxMQDVQXShQ= +github.com/xeipuuv/gojsonschema v1.2.0 h1:LhYJRs+L4fBtjZUfuSZIKGeVu0QRy8e5Xi7D17UxZ74= +github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQluxsYJ78Id3Y= +github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +github.com/zeebo/assert v1.3.0 h1:g7C04CbJuIDKNPFHmsk4hwZDO5O+kntRxzaUoNXj+IQ= +github.com/zeebo/assert v1.3.0/go.mod h1:Pq9JiuJQpG8JLJdtkwrJESF0Foym2/D9XMU5ciN/wJ0= +github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0= +github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA= +go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= +go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= +go.opentelemetry.io/otel v1.39.0 h1:8yPrr/S0ND9QEfTfdP9V+SiwT4E0G7Y5MO7p85nis48= +go.opentelemetry.io/otel v1.39.0/go.mod h1:kLlFTywNWrFyEdH0oj2xK0bFYZtHRYUdv1NklR/tgc8= +go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.39.0 h1:f0cb2XPmrqn4XMy9PNliTgRKJgS5WcL/u0/WRYGz4t0= +go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.39.0/go.mod h1:vnakAaFckOMiMtOIhFI2MNH4FYrZzXCYxmb1LlhoGz8= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.39.0 h1:Ckwye2FpXkYgiHX7fyVrN1uA/UYd9ounqqTuSNAv0k4= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.39.0/go.mod h1:teIFJh5pW2y+AN7riv6IBPX2DuesS3HgP39mwOspKwU= +go.opentelemetry.io/otel/metric v1.39.0 h1:d1UzonvEZriVfpNKEVmHXbdf909uGTOQjA0HF0Ls5Q0= +go.opentelemetry.io/otel/metric v1.39.0/go.mod h1:jrZSWL33sD7bBxg1xjrqyDjnuzTUB0x1nBERXd7Ftcs= +go.opentelemetry.io/otel/sdk v1.39.0 h1:nMLYcjVsvdui1B/4FRkwjzoRVsMK8uL/cj0OyhKzt18= +go.opentelemetry.io/otel/sdk v1.39.0/go.mod h1:vDojkC4/jsTJsE+kh+LXYQlbL8CgrEcwmt1ENZszdJE= +go.opentelemetry.io/otel/sdk/metric v1.39.0 h1:cXMVVFVgsIf2YL6QkRF4Urbr/aMInf+2WKg+sEJTtB8= +go.opentelemetry.io/otel/sdk/metric v1.39.0/go.mod h1:xq9HEVH7qeX69/JnwEfp6fVq5wosJsY1mt4lLfYdVew= +go.opentelemetry.io/otel/trace v1.39.0 h1:2d2vfpEDmCJ5zVYz7ijaJdOF59xLomrvj7bjt6/qCJI= +go.opentelemetry.io/otel/trace v1.39.0/go.mod h1:88w4/PnZSazkGzz/w84VHpQafiU4EtqqlVdxWy+rNOA= +go.opentelemetry.io/proto/otlp v1.9.0 h1:l706jCMITVouPOqEnii2fIAuO3IVGBRPV5ICjceRb/A= +go.opentelemetry.io/proto/otlp v1.9.0/go.mod h1:xE+Cx5E/eEHw+ISFkwPLwCZefwVjY+pqKg1qcK03+/4= +go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= +go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= +go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= +go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= +go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= +go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= +go.uber.org/zap v1.27.1 h1:08RqriUEv8+ArZRYSTXy1LeBScaMpVSTBhCeaZYfMYc= +go.uber.org/zap v1.27.1/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= +go.uber.org/zap/exp v0.3.0 h1:6JYzdifzYkGmTdRR59oYH+Ng7k49H9qVpWwNSsGJj3U= +go.uber.org/zap/exp v0.3.0/go.mod h1:5I384qq7XGxYyByIhHm6jg5CHkGY0nsTfbDLgDDlgJQ= +go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= +go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= +golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8= +golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A= +golang.org/x/exp v0.0.0-20260112195511-716be5621a96 h1:Z/6YuSHTLOHfNFdb8zVZomZr7cqNgTJvA8+Qz75D8gU= +golang.org/x/exp v0.0.0-20260112195511-716be5621a96/go.mod h1:nzimsREAkjBCIEFtHiYkrJyT+2uy9YZJB7H1k68CXZU= +golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.32.0 h1:9F4d3PHLljb6x//jOyokMv3eX+YDeepZSEo3mFJy93c= +golang.org/x/mod v0.32.0/go.mod h1:SgipZ/3h2Ci89DlEtEXWUk/HteuRin+HHhN+WbNhguU= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20190827160401-ba9fcec4b297/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20191007182048-72f939374954/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20201010224723-4f7140c49acb/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/net v0.0.0-20201216054612-986b41b23924/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20201224014010-6772e930b67b/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20210119194325-5f4716e94777/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20210525063256-abc453219eb5/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.0.0-20210928044308-7d9f5e0b762b/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= +golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= +golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o= +golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= +golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190411185658-b44545bcd369/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190826190057-c7b8b68b1456/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191008105621-543471e840be/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201009025420-dfb3f7c4e634/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201118182958-a01c418693c7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201218084310-7d0127a74742/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210110051926-789bb1bd4061/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210123111255-9b0068b26619/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210216163648-f7da38b97c65/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210305230114-8fe3ee5dd75b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210525143221-35b2ab0089ea/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210906170528-6f6e22806c34/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220128215802-99c3d69c2c27/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= +golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= +golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= +golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE= +golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8= +golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= +golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= +golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= +golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc= +golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= +gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= +google.golang.org/genproto/googleapis/api v0.0.0-20260112192933-99fd39fd28a9 h1:4DKBrmaqeptdEzp21EfrOEh8LE7PJ5ywH6wydSbOfGY= +google.golang.org/genproto/googleapis/api v0.0.0-20260112192933-99fd39fd28a9/go.mod h1:dd646eSK+Dk9kxVBl1nChEOhJPtMXriCcVb4x3o6J+E= +google.golang.org/genproto/googleapis/rpc v0.0.0-20260112192933-99fd39fd28a9 h1:IY6/YYRrFUk0JPp0xOVctvFIVuRnjccihY5kxf5g0TE= +google.golang.org/genproto/googleapis/rpc v0.0.0-20260112192933-99fd39fd28a9/go.mod h1:j9x/tPzZkyxcgEFkiKEEGxfvyumM01BEtsW8xzOahRQ= +google.golang.org/grpc v1.78.0 h1:K1XZG/yGDJnzMdd/uZHAkVqJE+xIDOcmdSFZkBUicNc= +google.golang.org/grpc v1.78.0/go.mod h1:I47qjTo4OKbMkjA/aOOwxDIiPSBofUtQUI5EfpWvW7U= +google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= +google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gotest.tools/v3 v3.5.1 h1:EENdUnS3pdur5nybKYIh2Vfgc8IUNBjxDPSjtiJcOzU= +gotest.tools/v3 v3.5.1/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU= diff --git a/livekit/install-livekit.sh b/livekit/install-livekit.sh new file mode 100755 index 0000000..e0b243a --- /dev/null +++ b/livekit/install-livekit.sh @@ -0,0 +1,101 @@ +#!/usr/bin/env bash +# Copyright 2023 LiveKit, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# LiveKit install script for Linux + +set -u +set -o errtrace +set -o errexit +set -o pipefail + +REPO="livekit" +INSTALL_PATH="/usr/local/bin" + +log() { printf "%b\n" "$*"; } +abort() { + printf "%s\n" "$@" >&2 + exit 1 +} + +# returns the latest version according to GH +# i.e. 1.0.0 +get_latest_version() +{ + latest_version=$(curl -s https://api.github.com/repos/livekit/$REPO/releases/latest | grep -oP '"tarball_url": ".*/tarball/v\K([^/]*)(?=")') + printf "%s" "$latest_version" +} + +# Ensure bash is used +if [ -z "${BASH_VERSION:-}" ] +then + abort "This script requires bash" +fi + +# Check if $INSTALL_PATH exists +if [ ! -d ${INSTALL_PATH} ] +then + abort "Could not install, ${INSTALL_PATH} doesn't exist" +fi + +# Needs SUDO if no permissions to write +SUDO_PREFIX="" +if [ ! -w ${INSTALL_PATH} ] +then + SUDO_PREFIX="sudo" + log "sudo is required to install to ${INSTALL_PATH}" +fi + +# Check cURL is installed +if ! command -v curl >/dev/null +then + abort "cURL is required and is not found" +fi + +# OS check +OS="$(uname)" +if [[ "${OS}" == "Darwin" ]] +then + abort "Installer not supported on MacOS, please install using Homebrew." +elif [[ "${OS}" != "Linux" ]] +then + abort "Installer is only supported on Linux." +fi + +ARCH="$(uname -m)" + +# fix arch on linux +if [[ "${ARCH}" == "aarch64" ]] +then + ARCH="arm64" +elif [[ "${ARCH}" == "x86_64" ]] +then + ARCH="amd64" +fi + +VERSION=$(get_latest_version) +ARCHIVE_URL="https://github.com/livekit/$REPO/releases/download/v${VERSION}/${REPO}_${VERSION}_linux_${ARCH}.tar.gz" + +# Ensure version follows SemVer +if ! [[ "${VERSION}" =~ ^[0-9]+\.[0-9]+\.[0-9]+$ ]] +then + abort "Invalid version: ${VERSION}" +fi + +log "Installing ${REPO} ${VERSION}" +log "Downloading from ${ARCHIVE_URL}..." + +curl -s -L "${ARCHIVE_URL}" | ${SUDO_PREFIX} tar xzf - -C "${INSTALL_PATH}" --wildcards --no-anchored "$REPO*" + +log "\nlivekit-server is installed to $INSTALL_PATH\n" diff --git a/livekit/magefile.go b/livekit/magefile.go new file mode 100644 index 0000000..4b924b4 --- /dev/null +++ b/livekit/magefile.go @@ -0,0 +1,233 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build mage +// +build mage + +package main + +import ( + "context" + "errors" + "fmt" + "os" + "os/exec" + "strings" + + "github.com/magefile/mage/mg" + + "github.com/livekit/livekit-server/version" + "github.com/livekit/mageutil" + _ "github.com/livekit/psrpc" +) + +const ( + goChecksumFile = ".checksumgo" + imageName = "livekit/livekit-server" +) + +// Default target to run when none is specified +// If not set, running mage will list available targets +var ( + Default = Build + checksummer = mageutil.NewChecksummer(".", goChecksumFile, ".go", ".mod") +) + +func init() { + checksummer.IgnoredPaths = []string{ + "pkg/service/wire_gen.go", + "pkg/rtc/types/typesfakes", + } +} + +// explicitly reinstall all deps +func Deps() error { + return installTools(true) +} + +// builds LiveKit server +func Build() error { + mg.Deps(generateWire) + if !checksummer.IsChanged() { + fmt.Println("up to date") + return nil + } + + fmt.Println("building...") + if err := os.MkdirAll("bin", 0755); err != nil { + return err + } + if err := mageutil.RunDir(context.Background(), "cmd/server", "go build -o ../../bin/livekit-server"); err != nil { + return err + } + + checksummer.WriteChecksum() + return nil +} + +// builds binary that runs on linux +func BuildLinux() error { + mg.Deps(generateWire) + if !checksummer.IsChanged() { + fmt.Println("up to date") + return nil + } + + fmt.Println("building...") + if err := os.MkdirAll("bin", 0755); err != nil { + return err + } + buildArch := os.Getenv("GOARCH") + if len(buildArch) == 0 { + buildArch = "amd64" + } + cmd := mageutil.CommandDir(context.Background(), "cmd/server", "go build -buildvcs=false -o ../../bin/livekit-server-" + buildArch) + cmd.Env = []string{ + "GOOS=linux", + "GOARCH=" + buildArch, + "HOME=" + os.Getenv("HOME"), + "GOPATH=" + os.Getenv("GOPATH"), + } + if err := cmd.Run(); err != nil { + return err + } + + checksummer.WriteChecksum() + return nil +} + +func Deadlock() error { + ctx := context.Background() + if err := mageutil.InstallTool("golang.org/x/tools/cmd/goimports", "latest", false); err != nil { + return err + } + if err := mageutil.Run(ctx, "go get github.com/sasha-s/go-deadlock"); err != nil { + return err + } + if err := mageutil.Pipe("grep -rl sync.Mutex ./pkg", "xargs sed -i -e s/sync.Mutex/deadlock.Mutex/g"); err != nil { + return err + } + if err := mageutil.Pipe("grep -rl sync.RWMutex ./pkg", "xargs sed -i -e s/sync.RWMutex/deadlock.RWMutex/g"); err != nil { + return err + } + if err := mageutil.Pipe("grep -rl deadlock.Mutex\\|deadlock.RWMutex ./pkg", "xargs goimports -w"); err != nil { + return err + } + if err := mageutil.Run(ctx, "go mod tidy"); err != nil { + return err + } + return nil +} + +func Sync() error { + if err := mageutil.Pipe("grep -rl deadlock.Mutex ./pkg", "xargs sed -i -e s/deadlock.Mutex/sync.Mutex/g"); err != nil { + return err + } + if err := mageutil.Pipe("grep -rl deadlock.RWMutex ./pkg", "xargs sed -i -e s/deadlock.RWMutex/sync.RWMutex/g"); err != nil { + return err + } + if err := mageutil.Pipe("grep -rl sync.Mutex\\|sync.RWMutex ./pkg", "xargs goimports -w"); err != nil { + return err + } + if err := mageutil.Run(context.Background(), "go mod tidy"); err != nil { + return err + } + return nil +} + +// builds and publish snapshot docker image +func PublishDocker() error { + // don't publish snapshot versions as latest or minor version + if !strings.Contains(version.Version, "SNAPSHOT") { + return errors.New("Cannot publish non-snapshot versions") + } + + versionImg := fmt.Sprintf("%s:v%s", imageName, version.Version) + cmd := exec.Command("docker", "buildx", "build", + "--push", "--platform", "linux/amd64,linux/arm64", + "--tag", versionImg, + ".") + mageutil.ConnectStd(cmd) + if err := cmd.Run(); err != nil { + return err + } + return nil +} + +// run unit tests, skipping integration +func Test() error { + mg.Deps(generateWire, setULimit) + return mageutil.Run(context.Background(), "go test -short ./... -count=1") +} + +// run all tests including integration +func TestAll() error { + mg.Deps(generateWire, setULimit) + return mageutil.Run(context.Background(), "go test ./... -count=1 -timeout=4m -v") +} + +// cleans up builds +func Clean() { + fmt.Println("cleaning...") + os.RemoveAll("bin") + os.Remove(goChecksumFile) +} + +// regenerate code +func Generate() error { + mg.Deps(installDeps, generateWire) + + fmt.Println("generating...") + return mageutil.Run(context.Background(), "go generate ./...") +} + +// code generation for wiring +func generateWire() error { + mg.Deps(installDeps) + if !checksummer.IsChanged() { + return nil + } + + fmt.Println("wiring...") + + wire, err := mageutil.GetToolPath("wire") + if err != nil { + return err + } + cmd := exec.Command(wire) + cmd.Dir = "pkg/service" + mageutil.ConnectStd(cmd) + if err := cmd.Run(); err != nil { + return err + } + + return nil +} + +// implicitly install deps +func installDeps() error { + return installTools(false) +} + +func installTools(force bool) error { + tools := map[string]string{ + "github.com/google/wire/cmd/wire": "latest", + } + for t, v := range tools { + if err := mageutil.InstallTool(t, v, force); err != nil { + return err + } + } + return nil +} diff --git a/livekit/magefile_unix.go b/livekit/magefile_unix.go new file mode 100644 index 0000000..186f6f4 --- /dev/null +++ b/livekit/magefile_unix.go @@ -0,0 +1,34 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build mage && !windows +// +build mage,!windows + +package main + +import ( + "syscall" +) + +func setULimit() error { + // raise ulimit on unix + var rLimit syscall.Rlimit + err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &rLimit) + if err != nil { + return err + } + rLimit.Max = 10000 + rLimit.Cur = 10000 + return syscall.Setrlimit(syscall.RLIMIT_NOFILE, &rLimit) +} diff --git a/livekit/magefile_windows.go b/livekit/magefile_windows.go new file mode 100644 index 0000000..9e25fe7 --- /dev/null +++ b/livekit/magefile_windows.go @@ -0,0 +1,22 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build mage +// +build mage + +package main + +func setULimit() error { + return nil +} diff --git a/livekit/pkg/agent/agent_test.go b/livekit/pkg/agent/agent_test.go new file mode 100644 index 0000000..40ae7a7 --- /dev/null +++ b/livekit/pkg/agent/agent_test.go @@ -0,0 +1,190 @@ +package agent_test + +import ( + "context" + "fmt" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" + "go.uber.org/atomic" + + "github.com/livekit/livekit-server/pkg/agent" + "github.com/livekit/livekit-server/pkg/agent/testutils" + "github.com/livekit/protocol/auth" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/rpc" + "github.com/livekit/protocol/utils/guid" + "github.com/livekit/protocol/utils/must" + "github.com/livekit/psrpc" +) + +func TestAgent(t *testing.T) { + testAgentName := "test_agent" + t.Run("dispatched jobs are assigned to a worker", func(t *testing.T) { + bus := psrpc.NewLocalMessageBus() + + client := must.Get(rpc.NewAgentInternalClient(bus)) + server := testutils.NewTestServer(bus) + t.Cleanup(server.Close) + + worker := server.SimulateAgentWorker() + worker.Register(testAgentName, livekit.JobType_JT_ROOM) + jobAssignments := worker.JobAssignments.Observe() + + job := &livekit.Job{ + Id: guid.New(guid.AgentJobPrefix), + DispatchId: guid.New(guid.AgentDispatchPrefix), + Type: livekit.JobType_JT_ROOM, + Room: &livekit.Room{}, + AgentName: testAgentName, + } + _, err := client.JobRequest(context.Background(), testAgentName, agent.RoomAgentTopic, job) + require.NoError(t, err) + + select { + case a := <-jobAssignments.Events(): + require.EqualValues(t, job.Id, a.Job.Id) + v, err := auth.ParseAPIToken(a.Token) + require.NoError(t, err) + _, claims, err := v.Verify(server.TestAPISecret) + require.NoError(t, err) + require.Equal(t, testAgentName, claims.Attributes[agent.AgentNameAttributeKey]) + case <-time.After(time.Second): + require.Fail(t, "job assignment timeout") + } + }) +} + +func testBatchJobRequest(t require.TestingT, batchSize int, totalJobs int, client rpc.AgentInternalClient, workers []*testutils.AgentWorker) <-chan struct{} { + var assigned atomic.Uint32 + done := make(chan struct{}) + for _, w := range workers { + assignments := w.JobAssignments.Observe() + go func() { + defer assignments.Stop() + for { + select { + case <-done: + case <-assignments.Events(): + if assigned.Inc() == uint32(totalJobs) { + close(done) + } + } + } + }() + } + + // wait for agent registration + time.Sleep(100 * time.Millisecond) + + var wg sync.WaitGroup + for i := 0; i < totalJobs; i += batchSize { + wg.Add(1) + go func(start int) { + defer wg.Done() + for j := start; j < start+batchSize && j < totalJobs; j++ { + job := &livekit.Job{ + Id: guid.New(guid.AgentJobPrefix), + DispatchId: guid.New(guid.AgentDispatchPrefix), + Type: livekit.JobType_JT_ROOM, + Room: &livekit.Room{}, + AgentName: "test", + } + _, err := client.JobRequest(context.Background(), "test", agent.RoomAgentTopic, job) + require.NoError(t, err) + } + }(i) + } + wg.Wait() + + return done +} + +func TestAgentLoadBalancing(t *testing.T) { + t.Run("jobs are distributed normally with baseline worker load", func(t *testing.T) { + totalWorkers := 5 + totalJobs := 100 + + bus := psrpc.NewLocalMessageBus() + + client := must.Get(rpc.NewAgentInternalClient(bus)) + t.Cleanup(client.Close) + server := testutils.NewTestServer(bus) + t.Cleanup(server.Close) + + agents := make([]*testutils.AgentWorker, totalWorkers) + for i := range totalWorkers { + agents[i] = server.SimulateAgentWorker( + testutils.WithLabel(fmt.Sprintf("agent-%d", i)), + testutils.WithJobLoad(testutils.NewStableJobLoad(0.01)), + ) + agents[i].Register("test", livekit.JobType_JT_ROOM) + } + + select { + case <-testBatchJobRequest(t, 10, totalJobs, client, agents): + case <-time.After(time.Second): + require.Fail(t, "job assignment timeout") + } + + jobCount := make(map[string]int) + for _, w := range agents { + jobCount[w.Label] = len(w.Jobs()) + } + + // check that jobs are distributed normally + for i := range totalWorkers { + label := fmt.Sprintf("agent-%d", i) + require.GreaterOrEqual(t, jobCount[label], 0) + require.Less(t, jobCount[label], 35) // three std deviations from the mean is 32 + } + }) + + t.Run("jobs are distributed with variable and overloaded worker load", func(t *testing.T) { + totalWorkers := 4 + totalJobs := 15 + + bus := psrpc.NewLocalMessageBus() + + client := must.Get(rpc.NewAgentInternalClient(bus)) + t.Cleanup(client.Close) + server := testutils.NewTestServer(bus) + t.Cleanup(server.Close) + + agents := make([]*testutils.AgentWorker, totalWorkers) + for i := range totalWorkers { + label := fmt.Sprintf("agent-%d", i) + if i%2 == 0 { + // make sure we have some workers that can accept jobs + agents[i] = server.SimulateAgentWorker(testutils.WithLabel(label)) + } else { + agents[i] = server.SimulateAgentWorker(testutils.WithLabel(label), testutils.WithDefaultWorkerLoad(0.9)) + } + agents[i].Register("test", livekit.JobType_JT_ROOM) + } + + select { + case <-testBatchJobRequest(t, 1, totalJobs, client, agents): + case <-time.After(time.Second): + require.Fail(t, "job assignment timeout") + } + + jobCount := make(map[string]int) + for _, w := range agents { + jobCount[w.Label] = len(w.Jobs()) + } + + for i := range totalWorkers { + label := fmt.Sprintf("agent-%d", i) + + if i%2 == 0 { + require.GreaterOrEqual(t, jobCount[label], 2) + } else { + require.Equal(t, 0, jobCount[label]) + } + require.GreaterOrEqual(t, jobCount[label], 0) + } + }) +} diff --git a/livekit/pkg/agent/client.go b/livekit/pkg/agent/client.go new file mode 100644 index 0000000..4a46525 --- /dev/null +++ b/livekit/pkg/agent/client.go @@ -0,0 +1,339 @@ +// Copyright 2024 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package agent + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/gammazero/workerpool" + "google.golang.org/protobuf/types/known/emptypb" + + serverutils "github.com/livekit/livekit-server/pkg/utils" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/rpc" + "github.com/livekit/protocol/utils" + "github.com/livekit/psrpc" +) + +const ( + EnabledCacheTTL = 1 * time.Minute + RoomAgentTopic = "room" + PublisherAgentTopic = "publisher" + ParticipantAgentTopic = "participant" + DefaultHandlerNamespace = "" + + CheckEnabledTimeout = 5 * time.Second +) + +var jobTypeTopics = map[livekit.JobType]string{ + livekit.JobType_JT_ROOM: RoomAgentTopic, + livekit.JobType_JT_PUBLISHER: PublisherAgentTopic, + livekit.JobType_JT_PARTICIPANT: ParticipantAgentTopic, +} + +type Client interface { + // LaunchJob starts a room or participant job on an agent. + // it will launch a job once for each worker in each namespace + LaunchJob(ctx context.Context, desc *JobRequest) *serverutils.IncrementalDispatcher[*livekit.Job] + TerminateJob(ctx context.Context, jobID string, reason rpc.JobTerminateReason) (*livekit.JobState, error) + Stop() error +} + +type JobRequest struct { + DispatchId string + JobType livekit.JobType + Room *livekit.Room + // only set for participant jobs + Participant *livekit.ParticipantInfo + Metadata string + AgentName string +} + +type agentClient struct { + client rpc.AgentInternalClient + config Config + + mu sync.RWMutex + + // cache response to avoid constantly checking with controllers + // cache is invalidated with AgentRegistered updates + roomNamespaces *serverutils.IncrementalDispatcher[string] // deprecated + publisherNamespaces *serverutils.IncrementalDispatcher[string] // deprecated + participantNamespaces *serverutils.IncrementalDispatcher[string] // deprecated + roomAgentNames *serverutils.IncrementalDispatcher[string] + publisherAgentNames *serverutils.IncrementalDispatcher[string] + participantAgentNames *serverutils.IncrementalDispatcher[string] + + enabledExpiresAt time.Time + + workers *workerpool.WorkerPool + + invalidateSub psrpc.Subscription[*emptypb.Empty] + subDone chan struct{} +} + +func NewAgentClient(bus psrpc.MessageBus, config Config) (Client, error) { + client, err := rpc.NewAgentInternalClient(bus) + if err != nil { + return nil, err + } + + c := &agentClient{ + client: client, + config: config, + workers: workerpool.New(50), + subDone: make(chan struct{}), + } + + sub, err := c.client.SubscribeWorkerRegistered(context.Background(), DefaultHandlerNamespace) + if err != nil { + return nil, err + } + + c.invalidateSub = sub + + go func() { + // invalidate cache + for range sub.Channel() { + c.mu.Lock() + c.roomNamespaces = nil + c.publisherNamespaces = nil + c.participantNamespaces = nil + c.roomAgentNames = nil + c.publisherAgentNames = nil + c.participantAgentNames = nil + c.mu.Unlock() + } + + c.subDone <- struct{}{} + }() + + return c, nil +} + +func (c *agentClient) LaunchJob(ctx context.Context, desc *JobRequest) *serverutils.IncrementalDispatcher[*livekit.Job] { + var wg sync.WaitGroup + ret := serverutils.NewIncrementalDispatcher[*livekit.Job]() + defer func() { + c.workers.Submit(func() { + wg.Wait() + ret.Done() + }) + }() + + jobTypeTopic, ok := jobTypeTopics[desc.JobType] + if !ok { + return ret + } + + dispatcher := c.getDispatcher(desc.AgentName, desc.JobType) + + if dispatcher == nil { + logger.Infow("not dispatching agent job since no worker is available", + "agentName", desc.AgentName, + "jobType", desc.JobType, + "room", desc.Room.Name, + "roomID", desc.Room.Sid) + return ret + } + + dispatcher.ForEach(func(curNs string) { + topic := GetAgentTopic(desc.AgentName, curNs) + + wg.Add(1) + c.workers.Submit(func() { + defer wg.Done() + // The cached agent parameters do not provide the exact combination of available job type/agent name/namespace, so some of the JobRequest RPC may not trigger any worker + job := &livekit.Job{ + Id: utils.NewGuid(utils.AgentJobPrefix), + DispatchId: desc.DispatchId, + Type: desc.JobType, + Room: desc.Room, + Participant: desc.Participant, + Namespace: curNs, + AgentName: desc.AgentName, + Metadata: desc.Metadata, + EnableRecording: c.config.EnableUserDataRecording, + } + resp, err := c.client.JobRequest(context.Background(), topic, jobTypeTopic, job) + if err != nil { + logger.Infow("failed to send job request", "error", err, "namespace", curNs, "jobType", desc.JobType, "agentName", desc.AgentName) + return + } + job.State = resp.State + ret.Add(job) + }) + }) + + return ret +} + +func (c *agentClient) TerminateJob(ctx context.Context, jobID string, reason rpc.JobTerminateReason) (*livekit.JobState, error) { + resp, err := c.client.JobTerminate(context.Background(), jobID, &rpc.JobTerminateRequest{ + JobId: jobID, + Reason: reason, + }) + if err != nil { + logger.Infow("failed to send job request", "error", err, "jobID", jobID) + return nil, err + } + + return resp.State, nil +} + +func (c *agentClient) getDispatcher(agName string, jobType livekit.JobType) *serverutils.IncrementalDispatcher[string] { + c.mu.Lock() + + if time.Since(c.enabledExpiresAt) > EnabledCacheTTL || c.roomNamespaces == nil || + c.publisherNamespaces == nil || c.participantNamespaces == nil || c.roomAgentNames == nil || c.publisherAgentNames == nil || c.participantAgentNames == nil { + c.enabledExpiresAt = time.Now() + c.roomNamespaces = serverutils.NewIncrementalDispatcher[string]() + c.publisherNamespaces = serverutils.NewIncrementalDispatcher[string]() + c.participantNamespaces = serverutils.NewIncrementalDispatcher[string]() + c.roomAgentNames = serverutils.NewIncrementalDispatcher[string]() + c.publisherAgentNames = serverutils.NewIncrementalDispatcher[string]() + c.participantAgentNames = serverutils.NewIncrementalDispatcher[string]() + + go c.checkEnabled(c.roomNamespaces, c.publisherNamespaces, c.participantNamespaces, c.roomAgentNames, c.publisherAgentNames, c.participantAgentNames) + } + + var target *serverutils.IncrementalDispatcher[string] + var agentNames *serverutils.IncrementalDispatcher[string] + switch jobType { + case livekit.JobType_JT_ROOM: + target = c.roomNamespaces + agentNames = c.roomAgentNames + case livekit.JobType_JT_PUBLISHER: + target = c.publisherNamespaces + agentNames = c.publisherAgentNames + case livekit.JobType_JT_PARTICIPANT: + target = c.participantNamespaces + agentNames = c.participantAgentNames + } + c.mu.Unlock() + + if agName == "" { + // if no agent name is given, we would need to dispatch backwards compatible mode + // which means dispatching to each of the namespaces + return target + } + + done := make(chan *serverutils.IncrementalDispatcher[string], 1) + c.workers.Submit(func() { + agentNames.ForEach(func(ag string) { + if ag == agName { + select { + case done <- target: + default: + } + } + }) + select { + case done <- nil: + default: + } + }) + + return <-done +} + +func (c *agentClient) checkEnabled(roomNamespaces, publisherNamespaces, participantNamespaces, roomAgentNames, publisherAgentNames, participantAgentNames *serverutils.IncrementalDispatcher[string]) { + defer roomNamespaces.Done() + defer publisherNamespaces.Done() + defer participantNamespaces.Done() + defer roomAgentNames.Done() + defer publisherAgentNames.Done() + defer participantAgentNames.Done() + + resChan, err := c.client.CheckEnabled(context.Background(), &rpc.CheckEnabledRequest{}, psrpc.WithRequestTimeout(CheckEnabledTimeout)) + if err != nil { + logger.Errorw("failed to check enabled", err) + return + } + + roomNSMap := make(map[string]bool) + publisherNSMap := make(map[string]bool) + participantNSMap := make(map[string]bool) + roomAgMap := make(map[string]bool) + publisherAgMap := make(map[string]bool) + participantAgMap := make(map[string]bool) + + for r := range resChan { + if r.Result.GetRoomEnabled() { + for _, ns := range r.Result.GetNamespaces() { + if _, ok := roomNSMap[ns]; !ok { + roomNamespaces.Add(ns) + roomNSMap[ns] = true + } + } + for _, ag := range r.Result.GetAgentNames() { + if _, ok := roomAgMap[ag]; !ok { + roomAgentNames.Add(ag) + roomAgMap[ag] = true + } + } + } + if r.Result.GetPublisherEnabled() { + for _, ns := range r.Result.GetNamespaces() { + if _, ok := publisherNSMap[ns]; !ok { + publisherNamespaces.Add(ns) + publisherNSMap[ns] = true + } + } + for _, ag := range r.Result.GetAgentNames() { + if _, ok := publisherAgMap[ag]; !ok { + publisherAgentNames.Add(ag) + publisherAgMap[ag] = true + } + } + } + if r.Result.GetParticipantEnabled() { + for _, ns := range r.Result.GetNamespaces() { + if _, ok := participantNSMap[ns]; !ok { + participantNamespaces.Add(ns) + participantNSMap[ns] = true + } + } + for _, ag := range r.Result.GetAgentNames() { + if _, ok := participantAgMap[ag]; !ok { + participantAgentNames.Add(ag) + participantAgMap[ag] = true + } + } + } + } +} + +func (c *agentClient) Stop() error { + _ = c.invalidateSub.Close() + <-c.subDone + return nil +} + +func GetAgentTopic(agentName, namespace string) string { + if agentName == "" { + // Backward compatibility + return namespace + } else if namespace == "" { + // Forward compatibility once the namespace field is removed from the worker SDK + return agentName + } else { + return fmt.Sprintf("%s_%s", agentName, namespace) + } +} diff --git a/livekit/pkg/agent/config.go b/livekit/pkg/agent/config.go new file mode 100644 index 0000000..c4d0fa8 --- /dev/null +++ b/livekit/pkg/agent/config.go @@ -0,0 +1,5 @@ +package agent + +type Config struct { + EnableUserDataRecording bool `yaml:"enable_user_data_recording"` +} diff --git a/livekit/pkg/agent/testutils/server.go b/livekit/pkg/agent/testutils/server.go new file mode 100644 index 0000000..3eb2b0e --- /dev/null +++ b/livekit/pkg/agent/testutils/server.go @@ -0,0 +1,485 @@ +package testutils + +import ( + "context" + "errors" + "io" + "math" + "math/rand/v2" + "sync" + "time" + + "github.com/frostbyte73/core" + "github.com/gammazero/deque" + "golang.org/x/exp/maps" + + "github.com/livekit/livekit-server/pkg/agent" + "github.com/livekit/livekit-server/pkg/config" + "github.com/livekit/livekit-server/pkg/routing" + "github.com/livekit/livekit-server/pkg/service" + "github.com/livekit/protocol/auth" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/utils/events" + "github.com/livekit/protocol/utils/guid" + "github.com/livekit/protocol/utils/must" + "github.com/livekit/protocol/utils/options" + "github.com/livekit/psrpc" +) + +type AgentService interface { + HandleConnection(context.Context, agent.SignalConn, agent.WorkerRegistration) + DrainConnections(time.Duration) +} + +type TestServer struct { + AgentService + TestAPIKey string + TestAPISecret string +} + +func NewTestServer(bus psrpc.MessageBus) *TestServer { + localNode, _ := routing.NewLocalNode(nil) + return NewTestServerWithService(must.Get(service.NewAgentService( + &config.Config{Region: "test"}, + localNode, + bus, + auth.NewSimpleKeyProvider("test", "verysecretsecret"), + ))) +} + +func NewTestServerWithService(s AgentService) *TestServer { + return &TestServer{s, "test", "verysecretsecret"} +} + +type SimulatedWorkerOptions struct { + Context context.Context + Label string + SupportResume bool + DefaultJobLoad float32 + JobLoadThreshold float32 + DefaultWorkerLoad float32 + HandleAvailability func(AgentJobRequest) + HandleAssignment func(*livekit.Job) JobLoad +} + +type SimulatedWorkerOption func(*SimulatedWorkerOptions) + +func WithContext(ctx context.Context) SimulatedWorkerOption { + return func(o *SimulatedWorkerOptions) { + o.Context = ctx + } +} + +func WithLabel(label string) SimulatedWorkerOption { + return func(o *SimulatedWorkerOptions) { + o.Label = label + } +} + +func WithJobAvailabilityHandler(h func(AgentJobRequest)) SimulatedWorkerOption { + return func(o *SimulatedWorkerOptions) { + o.HandleAvailability = h + } +} + +func WithJobAssignmentHandler(h func(*livekit.Job) JobLoad) SimulatedWorkerOption { + return func(o *SimulatedWorkerOptions) { + o.HandleAssignment = h + } +} + +func WithJobLoad(l JobLoad) SimulatedWorkerOption { + return WithJobAssignmentHandler(func(j *livekit.Job) JobLoad { return l }) +} + +func WithDefaultWorkerLoad(load float32) SimulatedWorkerOption { + return func(o *SimulatedWorkerOptions) { + o.DefaultWorkerLoad = load + } +} + +func (h *TestServer) SimulateAgentWorker(opts ...SimulatedWorkerOption) *AgentWorker { + o := &SimulatedWorkerOptions{ + Context: context.Background(), + Label: guid.New("TEST_AGENT_"), + DefaultJobLoad: 0.1, + JobLoadThreshold: 0.8, + DefaultWorkerLoad: 0.0, + HandleAvailability: func(r AgentJobRequest) { r.Accept() }, + HandleAssignment: func(j *livekit.Job) JobLoad { return nil }, + } + options.Apply(o, opts) + + w := &AgentWorker{ + workerMessages: make(chan *livekit.WorkerMessage, 1), + jobs: map[string]*AgentJob{}, + SimulatedWorkerOptions: o, + + RegisterWorkerResponses: events.NewObserverList[*livekit.RegisterWorkerResponse](), + AvailabilityRequests: events.NewObserverList[*livekit.AvailabilityRequest](), + JobAssignments: events.NewObserverList[*livekit.JobAssignment](), + JobTerminations: events.NewObserverList[*livekit.JobTermination](), + WorkerPongs: events.NewObserverList[*livekit.WorkerPong](), + } + w.ctx, w.cancel = context.WithCancel(context.Background()) + + if o.DefaultWorkerLoad > 0.0 { + w.sendStatus() + } + + ctx := service.WithAPIKey(o.Context, &auth.ClaimGrants{}, "test") + go h.HandleConnection(ctx, w, agent.MakeWorkerRegistration()) + + return w +} + +func (h *TestServer) Close() { + h.DrainConnections(1) +} + +var _ agent.SignalConn = (*AgentWorker)(nil) + +type JobLoad interface { + Load() float32 +} + +type AgentJob struct { + *livekit.Job + JobLoad +} + +type AgentJobRequest struct { + w *AgentWorker + *livekit.AvailabilityRequest +} + +func (r AgentJobRequest) Accept() { + identity := guid.New("PI_") + r.w.SendAvailability(&livekit.AvailabilityResponse{ + JobId: r.Job.Id, + Available: true, + SupportsResume: r.w.SupportResume, + ParticipantName: identity, + ParticipantIdentity: identity, + }) +} + +func (r AgentJobRequest) Reject() { + r.w.SendAvailability(&livekit.AvailabilityResponse{ + JobId: r.Job.Id, + Available: false, + }) +} + +type AgentWorker struct { + *SimulatedWorkerOptions + + fuse core.Fuse + mu sync.Mutex + ctx context.Context + cancel context.CancelFunc + workerMessages chan *livekit.WorkerMessage + serverMessages deque.Deque[*livekit.ServerMessage] + jobs map[string]*AgentJob + + RegisterWorkerResponses *events.ObserverList[*livekit.RegisterWorkerResponse] + AvailabilityRequests *events.ObserverList[*livekit.AvailabilityRequest] + JobAssignments *events.ObserverList[*livekit.JobAssignment] + JobTerminations *events.ObserverList[*livekit.JobTermination] + WorkerPongs *events.ObserverList[*livekit.WorkerPong] +} + +func (w *AgentWorker) statusWorker() { + t := time.NewTicker(2 * time.Second) + defer t.Stop() + + for !w.fuse.IsBroken() { + w.sendStatus() + <-t.C + } +} + +func (w *AgentWorker) Close() error { + w.mu.Lock() + defer w.mu.Unlock() + w.fuse.Break() + return nil +} + +func (w *AgentWorker) SetReadDeadline(t time.Time) error { + w.mu.Lock() + defer w.mu.Unlock() + if !w.fuse.IsBroken() { + cancel := w.cancel + if t.IsZero() { + w.ctx, w.cancel = context.WithCancel(context.Background()) + } else { + w.ctx, w.cancel = context.WithDeadline(context.Background(), t) + } + cancel() + } + return nil +} + +func (w *AgentWorker) ReadWorkerMessage() (*livekit.WorkerMessage, int, error) { + for { + w.mu.Lock() + ctx := w.ctx + w.mu.Unlock() + + select { + case <-w.fuse.Watch(): + return nil, 0, io.EOF + case <-ctx.Done(): + if err := ctx.Err(); errors.Is(err, context.DeadlineExceeded) { + return nil, 0, err + } + case m := <-w.workerMessages: + return m, 0, nil + } + } +} + +func (w *AgentWorker) WriteServerMessage(m *livekit.ServerMessage) (int, error) { + w.mu.Lock() + defer w.mu.Unlock() + w.serverMessages.PushBack(m) + if w.serverMessages.Len() == 1 { + go w.handleServerMessages() + } + return 0, nil +} + +func (w *AgentWorker) handleServerMessages() { + w.mu.Lock() + for w.serverMessages.Len() != 0 { + m := w.serverMessages.Front() + w.mu.Unlock() + + switch m := m.Message.(type) { + case *livekit.ServerMessage_Register: + w.handleRegister(m.Register) + case *livekit.ServerMessage_Availability: + w.handleAvailability(m.Availability) + case *livekit.ServerMessage_Assignment: + w.handleAssignment(m.Assignment) + case *livekit.ServerMessage_Termination: + w.handleTermination(m.Termination) + case *livekit.ServerMessage_Pong: + w.handlePong(m.Pong) + } + + w.mu.Lock() + w.serverMessages.PopFront() + } + w.mu.Unlock() +} + +func (w *AgentWorker) handleRegister(m *livekit.RegisterWorkerResponse) { + w.RegisterWorkerResponses.Emit(m) +} + +func (w *AgentWorker) handleAvailability(m *livekit.AvailabilityRequest) { + w.AvailabilityRequests.Emit(m) + if w.HandleAvailability != nil { + w.HandleAvailability(AgentJobRequest{w, m}) + } else { + AgentJobRequest{w, m}.Accept() + } +} + +func (w *AgentWorker) handleAssignment(m *livekit.JobAssignment) { + w.JobAssignments.Emit(m) + + var load JobLoad + if w.HandleAssignment != nil { + load = w.HandleAssignment(m.Job) + } + + if load == nil { + load = NewStableJobLoad(w.DefaultJobLoad) + } + + w.mu.Lock() + defer w.mu.Unlock() + w.jobs[m.Job.Id] = &AgentJob{m.Job, load} +} + +func (w *AgentWorker) handleTermination(m *livekit.JobTermination) { + w.JobTerminations.Emit(m) + + w.mu.Lock() + defer w.mu.Unlock() + delete(w.jobs, m.JobId) +} + +func (w *AgentWorker) handlePong(m *livekit.WorkerPong) { + w.WorkerPongs.Emit(m) +} + +func (w *AgentWorker) sendMessage(m *livekit.WorkerMessage) { + select { + case <-w.fuse.Watch(): + case w.workerMessages <- m: + } +} + +func (w *AgentWorker) SendRegister(m *livekit.RegisterWorkerRequest) { + w.sendMessage(&livekit.WorkerMessage{Message: &livekit.WorkerMessage_Register{ + Register: m, + }}) +} + +func (w *AgentWorker) SendAvailability(m *livekit.AvailabilityResponse) { + w.sendMessage(&livekit.WorkerMessage{Message: &livekit.WorkerMessage_Availability{ + Availability: m, + }}) +} + +func (w *AgentWorker) SendUpdateWorker(m *livekit.UpdateWorkerStatus) { + w.sendMessage(&livekit.WorkerMessage{Message: &livekit.WorkerMessage_UpdateWorker{ + UpdateWorker: m, + }}) +} + +func (w *AgentWorker) SendUpdateJob(m *livekit.UpdateJobStatus) { + w.sendMessage(&livekit.WorkerMessage{Message: &livekit.WorkerMessage_UpdateJob{ + UpdateJob: m, + }}) +} + +func (w *AgentWorker) SendPing(m *livekit.WorkerPing) { + w.sendMessage(&livekit.WorkerMessage{Message: &livekit.WorkerMessage_Ping{ + Ping: m, + }}) +} + +func (w *AgentWorker) SendSimulateJob(m *livekit.SimulateJobRequest) { + w.sendMessage(&livekit.WorkerMessage{Message: &livekit.WorkerMessage_SimulateJob{ + SimulateJob: m, + }}) +} + +func (w *AgentWorker) SendMigrateJob(m *livekit.MigrateJobRequest) { + w.sendMessage(&livekit.WorkerMessage{Message: &livekit.WorkerMessage_MigrateJob{ + MigrateJob: m, + }}) +} + +func (w *AgentWorker) sendStatus() { + w.mu.Lock() + var load float32 + jobCount := len(w.jobs) + + if len(w.jobs) == 0 { + load = w.DefaultWorkerLoad + } else { + for _, j := range w.jobs { + load += j.Load() + } + } + w.mu.Unlock() + + status := livekit.WorkerStatus_WS_AVAILABLE + if load > w.JobLoadThreshold { + status = livekit.WorkerStatus_WS_FULL + } + + w.SendUpdateWorker(&livekit.UpdateWorkerStatus{ + Status: &status, + Load: load, + JobCount: uint32(jobCount), + }) +} + +func (w *AgentWorker) Register(agentName string, jobType livekit.JobType) { + w.SendRegister(&livekit.RegisterWorkerRequest{ + Type: jobType, + AgentName: agentName, + }) + go w.statusWorker() +} + +func (w *AgentWorker) SimulateRoomJob(roomName string) { + w.SendSimulateJob(&livekit.SimulateJobRequest{ + Type: livekit.JobType_JT_ROOM, + Room: &livekit.Room{ + Sid: guid.New(guid.RoomPrefix), + Name: roomName, + }, + }) +} + +func (w *AgentWorker) Jobs() []*AgentJob { + w.mu.Lock() + defer w.mu.Unlock() + return maps.Values(w.jobs) +} + +type stableJobLoad struct { + load float32 +} + +func NewStableJobLoad(load float32) JobLoad { + return stableJobLoad{load} +} + +func (s stableJobLoad) Load() float32 { + return s.load +} + +type periodicJobLoad struct { + amplitude float64 + period time.Duration + epoch time.Time +} + +func NewPeriodicJobLoad(max float32, period time.Duration) JobLoad { + return periodicJobLoad{ + amplitude: float64(max / 2), + period: period, + epoch: time.Now().Add(-time.Duration(rand.Int64N(int64(period)))), + } +} + +func (s periodicJobLoad) Load() float32 { + a := math.Sin(time.Since(s.epoch).Seconds() / s.period.Seconds() * math.Pi * 2) + return float32(s.amplitude + a*s.amplitude) +} + +type uniformRandomJobLoad struct { + min, max float32 + rng func() float64 +} + +func NewUniformRandomJobLoad(min, max float32) JobLoad { + return uniformRandomJobLoad{min, max, rand.Float64} +} + +func NewUniformRandomJobLoadWithRNG(min, max float32, rng *rand.Rand) JobLoad { + return uniformRandomJobLoad{min, max, rng.Float64} +} + +func (s uniformRandomJobLoad) Load() float32 { + return rand.Float32()*(s.max-s.min) + s.min +} + +type normalRandomJobLoad struct { + mean, stddev float64 + rng func() float64 +} + +func NewNormalRandomJobLoad(mean, stddev float64) JobLoad { + return normalRandomJobLoad{mean, stddev, rand.Float64} +} + +func NewNormalRandomJobLoadWithRNG(mean, stddev float64, rng *rand.Rand) JobLoad { + return normalRandomJobLoad{mean, stddev, rng.Float64} +} + +func (s normalRandomJobLoad) Load() float32 { + u := 1 - s.rng() + v := s.rng() + z := math.Sqrt(-2*math.Log(u)) * math.Cos(2*math.Pi*v) + return float32(max(0, z*s.stddev+s.mean)) +} diff --git a/livekit/pkg/agent/worker.go b/livekit/pkg/agent/worker.go new file mode 100644 index 0000000..f7aed2a --- /dev/null +++ b/livekit/pkg/agent/worker.go @@ -0,0 +1,578 @@ +// Copyright 2024 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package agent + +import ( + "context" + "errors" + "fmt" + "sync" + "time" + + "google.golang.org/protobuf/proto" + + pagent "github.com/livekit/protocol/agent" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/rpc" + "github.com/livekit/protocol/utils" + "github.com/livekit/protocol/utils/guid" + "github.com/livekit/psrpc" +) + +var ( + ErrUnimplementedWrorkerSignal = errors.New("unimplemented worker signal") + ErrUnknownWorkerSignal = errors.New("unknown worker signal") + ErrUnknownJobType = errors.New("unknown job type") + ErrJobNotFound = psrpc.NewErrorf(psrpc.NotFound, "no running job for given jobID") + ErrWorkerClosed = errors.New("worker closed") + ErrWorkerNotAvailable = errors.New("worker not available") + ErrAvailabilityTimeout = errors.New("agent worker availability timeout") + ErrDuplicateJobAssignment = errors.New("duplicate job assignment") +) + +const AgentNameAttributeKey = "lk.agent_name" + +type WorkerProtocolVersion int + +const CurrentProtocol = 1 + +const ( + RegisterTimeout = 10 * time.Second + AssignJobTimeout = 10 * time.Second +) + +type SignalConn interface { + WriteServerMessage(msg *livekit.ServerMessage) (int, error) + ReadWorkerMessage() (*livekit.WorkerMessage, int, error) + SetReadDeadline(time.Time) error + Close() error +} + +func JobStatusIsEnded(s livekit.JobStatus) bool { + return s == livekit.JobStatus_JS_SUCCESS || s == livekit.JobStatus_JS_FAILED +} + +type WorkerSignalHandler interface { + HandleRegister(*livekit.RegisterWorkerRequest) error + HandleAvailability(*livekit.AvailabilityResponse) error + HandleUpdateJob(*livekit.UpdateJobStatus) error + HandleSimulateJob(*livekit.SimulateJobRequest) error + HandlePing(*livekit.WorkerPing) error + HandleUpdateWorker(*livekit.UpdateWorkerStatus) error + HandleMigrateJob(*livekit.MigrateJobRequest) error +} + +func DispatchWorkerSignal(req *livekit.WorkerMessage, h WorkerSignalHandler) error { + switch m := req.Message.(type) { + case *livekit.WorkerMessage_Register: + return h.HandleRegister(m.Register) + case *livekit.WorkerMessage_Availability: + return h.HandleAvailability(m.Availability) + case *livekit.WorkerMessage_UpdateJob: + return h.HandleUpdateJob(m.UpdateJob) + case *livekit.WorkerMessage_SimulateJob: + return h.HandleSimulateJob(m.SimulateJob) + case *livekit.WorkerMessage_Ping: + return h.HandlePing(m.Ping) + case *livekit.WorkerMessage_UpdateWorker: + return h.HandleUpdateWorker(m.UpdateWorker) + case *livekit.WorkerMessage_MigrateJob: + return h.HandleMigrateJob(m.MigrateJob) + default: + return ErrUnknownWorkerSignal + } +} + +var _ WorkerSignalHandler = (*UnimplementedWorkerSignalHandler)(nil) + +type UnimplementedWorkerSignalHandler struct{} + +func (UnimplementedWorkerSignalHandler) HandleRegister(*livekit.RegisterWorkerRequest) error { + return fmt.Errorf("%w: Register", ErrUnimplementedWrorkerSignal) +} +func (UnimplementedWorkerSignalHandler) HandleAvailability(*livekit.AvailabilityResponse) error { + return fmt.Errorf("%w: Availability", ErrUnimplementedWrorkerSignal) +} +func (UnimplementedWorkerSignalHandler) HandleUpdateJob(*livekit.UpdateJobStatus) error { + return fmt.Errorf("%w: UpdateJob", ErrUnimplementedWrorkerSignal) +} +func (UnimplementedWorkerSignalHandler) HandleSimulateJob(*livekit.SimulateJobRequest) error { + return fmt.Errorf("%w: SimulateJob", ErrUnimplementedWrorkerSignal) +} +func (UnimplementedWorkerSignalHandler) HandlePing(*livekit.WorkerPing) error { + return fmt.Errorf("%w: Ping", ErrUnimplementedWrorkerSignal) +} +func (UnimplementedWorkerSignalHandler) HandleUpdateWorker(*livekit.UpdateWorkerStatus) error { + return fmt.Errorf("%w: UpdateWorker", ErrUnimplementedWrorkerSignal) +} +func (UnimplementedWorkerSignalHandler) HandleMigrateJob(*livekit.MigrateJobRequest) error { + return fmt.Errorf("%w: MigrateJob", ErrUnimplementedWrorkerSignal) +} + +type WorkerPingHandler struct { + UnimplementedWorkerSignalHandler + conn SignalConn +} + +func (h WorkerPingHandler) HandlePing(ping *livekit.WorkerPing) error { + _, err := h.conn.WriteServerMessage(&livekit.ServerMessage{ + Message: &livekit.ServerMessage_Pong{ + Pong: &livekit.WorkerPong{ + LastTimestamp: ping.Timestamp, + Timestamp: time.Now().UnixMilli(), + }, + }, + }) + return err +} + +type WorkerRegistration struct { + Protocol WorkerProtocolVersion + ID string + Version string + AgentID string + AgentName string + Namespace string + JobType livekit.JobType + Permissions *livekit.ParticipantPermission + ClientIP string +} + +func MakeWorkerRegistration() WorkerRegistration { + return WorkerRegistration{ + ID: guid.New(guid.AgentWorkerPrefix), + Protocol: CurrentProtocol, + } +} + +var _ WorkerSignalHandler = (*WorkerRegisterer)(nil) + +type WorkerRegisterer struct { + WorkerPingHandler + serverInfo *livekit.ServerInfo + deadline time.Time + + registration WorkerRegistration + registered bool +} + +func NewWorkerRegisterer(conn SignalConn, serverInfo *livekit.ServerInfo, base WorkerRegistration) *WorkerRegisterer { + return &WorkerRegisterer{ + WorkerPingHandler: WorkerPingHandler{conn: conn}, + serverInfo: serverInfo, + registration: base, + deadline: time.Now().Add(RegisterTimeout), + } +} + +func (h *WorkerRegisterer) Deadline() time.Time { + return h.deadline +} + +func (h *WorkerRegisterer) Registration() WorkerRegistration { + return h.registration +} + +func (h *WorkerRegisterer) Registered() bool { + return h.registered +} + +func (h *WorkerRegisterer) HandleRegister(req *livekit.RegisterWorkerRequest) error { + if !livekit.IsJobType(req.GetType()) { + return ErrUnknownJobType + } + + permissions := req.AllowedPermissions + if permissions == nil { + permissions = &livekit.ParticipantPermission{ + CanSubscribe: true, + CanPublish: true, + CanPublishData: true, + CanUpdateMetadata: true, + } + } + + h.registration.Version = req.Version + h.registration.AgentName = req.AgentName + h.registration.Namespace = req.GetNamespace() + h.registration.JobType = req.GetType() + h.registration.Permissions = permissions + h.registered = true + + _, err := h.conn.WriteServerMessage(&livekit.ServerMessage{ + Message: &livekit.ServerMessage_Register{ + Register: &livekit.RegisterWorkerResponse{ + WorkerId: h.registration.ID, + ServerInfo: h.serverInfo, + }, + }, + }) + return err +} + +var _ WorkerSignalHandler = (*Worker)(nil) + +type Worker struct { + WorkerPingHandler + WorkerRegistration + + apiKey string + apiSecret string + logger logger.Logger + + ctx context.Context + cancel context.CancelFunc + closed chan struct{} + + mu sync.Mutex + load float32 + status livekit.WorkerStatus + + runningJobs map[livekit.JobID]*livekit.Job + availability map[livekit.JobID]chan *livekit.AvailabilityResponse +} + +func NewWorker( + registration WorkerRegistration, + apiKey string, + apiSecret string, + conn SignalConn, + logger logger.Logger, +) *Worker { + ctx, cancel := context.WithCancel(context.Background()) + + return &Worker{ + WorkerPingHandler: WorkerPingHandler{conn: conn}, + WorkerRegistration: registration, + apiKey: apiKey, + apiSecret: apiSecret, + logger: logger.WithValues( + "workerID", registration.ID, + "agentName", registration.AgentName, + "jobType", registration.JobType.String(), + ), + + ctx: ctx, + cancel: cancel, + closed: make(chan struct{}), + + runningJobs: make(map[livekit.JobID]*livekit.Job), + availability: make(map[livekit.JobID]chan *livekit.AvailabilityResponse), + } +} + +func (w *Worker) sendRequest(req *livekit.ServerMessage) { + if _, err := w.conn.WriteServerMessage(req); err != nil { + w.logger.Warnw("error writing to websocket", err) + } +} + +func (w *Worker) Status() livekit.WorkerStatus { + w.mu.Lock() + defer w.mu.Unlock() + return w.status +} + +func (w *Worker) Load() float32 { + w.mu.Lock() + defer w.mu.Unlock() + return w.load +} + +func (w *Worker) Logger() logger.Logger { + return w.logger +} + +func (w *Worker) RunningJobs() map[livekit.JobID]*livekit.Job { + w.mu.Lock() + defer w.mu.Unlock() + jobs := make(map[livekit.JobID]*livekit.Job, len(w.runningJobs)) + for k, v := range w.runningJobs { + jobs[k] = v + } + return jobs +} + +func (w *Worker) RunningJobCount() int { + w.mu.Lock() + defer w.mu.Unlock() + return len(w.runningJobs) +} + +func (w *Worker) GetJobState(jobID livekit.JobID) (*livekit.JobState, error) { + w.mu.Lock() + defer w.mu.Unlock() + j, ok := w.runningJobs[jobID] + if !ok { + return nil, ErrJobNotFound + } + return utils.CloneProto(j.State), nil +} + +func (w *Worker) AssignJob(ctx context.Context, job *livekit.Job) (*livekit.JobState, error) { + availCh := make(chan *livekit.AvailabilityResponse, 1) + job = utils.CloneProto(job) + jobID := livekit.JobID(job.Id) + + w.mu.Lock() + if _, ok := w.availability[jobID]; ok { + w.mu.Unlock() + return nil, ErrDuplicateJobAssignment + } + + w.availability[jobID] = availCh + w.mu.Unlock() + + defer func() { + w.mu.Lock() + delete(w.availability, jobID) + w.mu.Unlock() + }() + + if job.State == nil { + job.State = &livekit.JobState{} + } + now := time.Now() + job.State.WorkerId = w.ID + job.State.AgentId = w.AgentID + job.State.UpdatedAt = now.UnixNano() + job.State.StartedAt = now.UnixNano() + job.State.Status = livekit.JobStatus_JS_RUNNING + + w.sendRequest(&livekit.ServerMessage{Message: &livekit.ServerMessage_Availability{ + Availability: &livekit.AvailabilityRequest{Job: job}, + }}) + + timeout := time.NewTimer(AssignJobTimeout) + defer timeout.Stop() + + // See handleAvailability for the response + select { + case res := <-availCh: + if res.Terminate { + job.State.EndedAt = now.UnixNano() + job.State.Status = livekit.JobStatus_JS_SUCCESS + return job.State, nil + } + + if !res.Available { + return nil, ErrWorkerNotAvailable + } + + job.State.ParticipantIdentity = res.ParticipantIdentity + attributes := res.ParticipantAttributes + if attributes == nil { + attributes = make(map[string]string) + } + attributes[AgentNameAttributeKey] = w.AgentName + + token, err := pagent.BuildAgentToken( + w.apiKey, + w.apiSecret, + job.Room.Name, + res.ParticipantIdentity, + res.ParticipantName, + res.ParticipantMetadata, + attributes, + w.Permissions, + ) + if err != nil { + w.logger.Errorw("failed to build agent token", err) + return nil, err + } + + // In OSS, Url is nil, and the used API Key is the same as the one used to connect the worker + w.sendRequest(&livekit.ServerMessage{Message: &livekit.ServerMessage_Assignment{ + Assignment: &livekit.JobAssignment{Job: job, Url: nil, Token: token}, + }}) + + state := utils.CloneProto(job.State) + + w.mu.Lock() + w.runningJobs[jobID] = job + w.mu.Unlock() + + // TODO sweep jobs that are never started. We can't do this until all SDKs actually update the the JOB state + + return state, nil + case <-timeout.C: + return nil, ErrAvailabilityTimeout + case <-w.ctx.Done(): + return nil, ErrWorkerClosed + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +func (w *Worker) TerminateJob(jobID livekit.JobID, reason rpc.JobTerminateReason) (*livekit.JobState, error) { + w.mu.Lock() + _, ok := w.runningJobs[jobID] + w.mu.Unlock() + + if !ok { + return nil, ErrJobNotFound + } + + w.sendRequest(&livekit.ServerMessage{Message: &livekit.ServerMessage_Termination{ + Termination: &livekit.JobTermination{ + JobId: string(jobID), + }, + }}) + + status := livekit.JobStatus_JS_SUCCESS + errorStr := "" + if reason == rpc.JobTerminateReason_AGENT_LEFT_ROOM { + status = livekit.JobStatus_JS_FAILED + errorStr = "agent worker left the room" + } + + return w.UpdateJobStatus(&livekit.UpdateJobStatus{ + JobId: string(jobID), + Status: status, + Error: errorStr, + }) +} + +func (w *Worker) UpdateMetadata(metadata string) { + w.logger.Debugw("worker metadata updated", nil, "metadata", metadata) +} + +func (w *Worker) IsClosed() bool { + select { + case <-w.closed: + return true + default: + return false + } +} + +func (w *Worker) Close() { + w.mu.Lock() + if w.IsClosed() { + w.mu.Unlock() + return + } + + w.logger.Infow("closing worker", "workerID", w.ID, "jobType", w.JobType, "agentName", w.AgentName) + + close(w.closed) + w.cancel() + _ = w.conn.Close() + w.mu.Unlock() +} + +func (w *Worker) HandleAvailability(res *livekit.AvailabilityResponse) error { + w.mu.Lock() + defer w.mu.Unlock() + + jobID := livekit.JobID(res.JobId) + availCh, ok := w.availability[jobID] + if !ok { + w.logger.Warnw("received availability response for unknown job", nil, "jobID", jobID) + return nil + } + + availCh <- res + delete(w.availability, jobID) + + return nil +} + +func (w *Worker) HandleUpdateJob(update *livekit.UpdateJobStatus) error { + _, err := w.UpdateJobStatus(update) + if err != nil { + // treating this as a debug message only + // this can happen if the Room closes first, which would delete the agent dispatch + // that would mark the job as successful. subsequent updates from the same worker + // would not be able to find the same jobID. + w.logger.Debugw("received job update for unknown job", "jobID", update.JobId) + } + return nil +} + +func (w *Worker) UpdateJobStatus(update *livekit.UpdateJobStatus) (*livekit.JobState, error) { + w.mu.Lock() + defer w.mu.Unlock() + + jobID := livekit.JobID(update.JobId) + job, ok := w.runningJobs[jobID] + if !ok { + return nil, psrpc.NewErrorf(psrpc.NotFound, "received job update for unknown job") + } + + now := time.Now() + job.State.UpdatedAt = now.UnixNano() + + if job.State.Status == livekit.JobStatus_JS_PENDING && update.Status != livekit.JobStatus_JS_PENDING { + job.State.StartedAt = now.UnixNano() + } + + job.State.Status = update.Status + job.State.Error = update.Error + + if JobStatusIsEnded(update.Status) { + job.State.EndedAt = now.UnixNano() + delete(w.runningJobs, jobID) + + w.logger.Infow("job ended", "jobID", update.JobId, "status", update.Status, "error", update.Error) + } + + return proto.Clone(job.State).(*livekit.JobState), nil +} + +func (w *Worker) HandleSimulateJob(simulate *livekit.SimulateJobRequest) error { + jobType := livekit.JobType_JT_ROOM + if simulate.Participant != nil { + jobType = livekit.JobType_JT_PUBLISHER + } + + job := &livekit.Job{ + Id: guid.New(guid.AgentJobPrefix), + Type: jobType, + Room: simulate.Room, + Participant: simulate.Participant, + Namespace: w.Namespace, + AgentName: w.AgentName, + } + + go func() { + _, err := w.AssignJob(w.ctx, job) + if err != nil { + w.logger.Errorw("unable to simulate job", err, "jobID", job.Id) + } + }() + + return nil +} + +func (w *Worker) HandleUpdateWorker(update *livekit.UpdateWorkerStatus) error { + w.mu.Lock() + defer w.mu.Unlock() + + if status := update.Status; status != nil && w.status != *status { + w.status = *status + w.Logger().Debugw("worker status changed", "status", w.status) + } + w.load = update.GetLoad() + + return nil +} + +func (w *Worker) HandleMigrateJob(req *livekit.MigrateJobRequest) error { + // TODO(theomonnom): On OSS this is not implemented + // We could maybe just move a specific job to another worker + return nil +} diff --git a/livekit/pkg/clientconfiguration/conf.go b/livekit/pkg/clientconfiguration/conf.go new file mode 100644 index 0000000..a8f8ff6 --- /dev/null +++ b/livekit/pkg/clientconfiguration/conf.go @@ -0,0 +1,64 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package clientconfiguration + +import ( + "github.com/livekit/livekit-server/pkg/sfu/mime" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/utils/must" +) + +// StaticConfigurations list specific device-side limitations that should be disabled at a global level +var StaticConfigurations = []ConfigurationItem{ + // { + // Match: must.Get(NewScriptMatch(`c.protocol <= 5 || c.browser == "firefox"`)), + // Configuration: &livekit.ClientConfiguration{ResumeConnection: livekit.ClientConfigSetting_DISABLED}, + // Merge: false, + // }, + { + Match: must.Get(NewScriptMatch(`c.browser == "safari"`)), + Configuration: &livekit.ClientConfiguration{ + DisabledCodecs: &livekit.DisabledCodecs{ + Codecs: []*livekit.Codec{ + {Mime: mime.MimeTypeAV1.String()}, + }, + }, + }, + Merge: true, + }, + { + Match: must.Get(NewScriptMatch(`c.browser == "safari" && c.browser_version > "18.3"`)), + Configuration: &livekit.ClientConfiguration{ + DisabledCodecs: &livekit.DisabledCodecs{ + Publish: []*livekit.Codec{ + {Mime: mime.MimeTypeVP9.String()}, + }, + }, + }, + Merge: true, + }, + { + Match: must.Get(NewScriptMatch(`(c.device_model == "xiaomi 2201117ti" && c.os == "android") || + ((c.browser == "firefox" || c.browser == "firefox mobile") && (c.os == "linux" || c.os == "android"))`)), + Configuration: &livekit.ClientConfiguration{ + DisabledCodecs: &livekit.DisabledCodecs{ + Publish: []*livekit.Codec{ + {Mime: mime.MimeTypeH264.String()}, + }, + }, + }, + Merge: false, + }, +} diff --git a/livekit/pkg/clientconfiguration/conf_test.go b/livekit/pkg/clientconfiguration/conf_test.go new file mode 100644 index 0000000..55d3812 --- /dev/null +++ b/livekit/pkg/clientconfiguration/conf_test.go @@ -0,0 +1,124 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package clientconfiguration + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/utils/must" +) + +func TestScriptMatchConfiguration(t *testing.T) { + t.Run("no merge", func(t *testing.T) { + confs := []ConfigurationItem{ + { + Match: must.Get(NewScriptMatch(`c.protocol > 5 && c.browser != "firefox"`)), + Configuration: &livekit.ClientConfiguration{ + ResumeConnection: livekit.ClientConfigSetting_ENABLED, + }, + }, + } + + cm := NewStaticClientConfigurationManager(confs) + + conf := cm.GetConfiguration(&livekit.ClientInfo{Protocol: 4}) + require.Nil(t, conf) + + conf = cm.GetConfiguration(&livekit.ClientInfo{Protocol: 6, Browser: "firefox"}) + require.Nil(t, conf) + + conf = cm.GetConfiguration(&livekit.ClientInfo{Protocol: 6, Browser: "chrome"}) + require.Equal(t, conf.ResumeConnection, livekit.ClientConfigSetting_ENABLED) + }) + + t.Run("merge", func(t *testing.T) { + confs := []ConfigurationItem{ + { + Match: must.Get(NewScriptMatch(`c.protocol > 5 && c.browser != "firefox"`)), + Configuration: &livekit.ClientConfiguration{ + ResumeConnection: livekit.ClientConfigSetting_ENABLED, + }, + Merge: true, + }, + { + Match: must.Get(NewScriptMatch(`c.sdk == "android"`)), + Configuration: &livekit.ClientConfiguration{ + Video: &livekit.VideoConfiguration{ + HardwareEncoder: livekit.ClientConfigSetting_DISABLED, + }, + }, + Merge: true, + }, + } + + cm := NewStaticClientConfigurationManager(confs) + + conf := cm.GetConfiguration(&livekit.ClientInfo{Protocol: 4}) + require.Nil(t, conf) + + conf = cm.GetConfiguration(&livekit.ClientInfo{Protocol: 6, Browser: "firefox"}) + require.Nil(t, conf) + + conf = cm.GetConfiguration(&livekit.ClientInfo{Protocol: 6, Browser: "chrome", Sdk: 3}) + require.Equal(t, conf.ResumeConnection, livekit.ClientConfigSetting_ENABLED) + require.Equal(t, conf.Video.HardwareEncoder, livekit.ClientConfigSetting_DISABLED) + }) +} + +func TestScriptMatch(t *testing.T) { + client := &livekit.ClientInfo{ + Protocol: 6, + Browser: "chrome", + Sdk: 3, // android + DeviceModel: "12345", + } + + type testcase struct { + name string + expr string + result bool + err bool + } + + cases := []testcase{ + {name: "simple match", expr: `c.protocol > 5`, result: true}, + {name: "invalid expr", expr: `cc.protocol > 5`, err: true}, + {name: "unexist field", expr: `c.protocols > 5`, err: true}, + {name: "combined condition", expr: `c.protocol > 5 && (c.sdk=="android" || c.sdk=="ios")`, result: true}, + {name: "combined condition2", expr: `(c.device_model == "xiaomi 2201117ti" && c.os == "android") || ((c.browser == "firefox" || c.browser == "firefox mobile") && (c.os == "linux" || c.os == "android"))`, result: false}, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + match, err := NewScriptMatch(c.expr) + if err != nil { + if !c.err { + require.NoError(t, err) + } + return + } + m, err := match.Match(client) + if c.err { + require.Error(t, err) + } else { + require.Equal(t, c.result, m) + } + }) + + } +} diff --git a/livekit/pkg/clientconfiguration/match.go b/livekit/pkg/clientconfiguration/match.go new file mode 100644 index 0000000..b79559f --- /dev/null +++ b/livekit/pkg/clientconfiguration/match.go @@ -0,0 +1,173 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package clientconfiguration + +import ( + "errors" + "fmt" + "strings" + + "github.com/d5/tengo/v2" + "github.com/d5/tengo/v2/token" + "golang.org/x/mod/semver" + + "github.com/livekit/protocol/livekit" +) + +type Match interface { + Match(clientInfo *livekit.ClientInfo) (bool, error) +} + +type ScriptMatch struct { + compiled *tengo.Compiled +} + +func NewScriptMatch(expr string) (*ScriptMatch, error) { + script := tengo.NewScript(fmt.Appendf(nil, "__res__ := (%s)", expr)) + if err := script.Add("c", &clientObject{}); err != nil { + return nil, err + } + compiled, err := script.Compile() + if err != nil { + return nil, err + } + return &ScriptMatch{compiled}, nil +} + +// use result of eval script expression for match. +// expression examples: +// protocol bigger than 5 : c.protocol > 5 +// browser if firefox: c.browser == "firefox" +// combined rule : c.protocol > 5 && c.browser == "firefox" +func (m *ScriptMatch) Match(clientInfo *livekit.ClientInfo) (bool, error) { + clone := m.compiled.Clone() + if err := clone.Set("c", &clientObject{info: clientInfo}); err != nil { + return false, err + } + if err := clone.Run(); err != nil { + return false, err + } + + res := clone.Get("__res__").Value() + if val, ok := res.(bool); ok { + return val, nil + } + return false, errors.New("invalid match expression result") +} + +// ------------------------------------------------ + +type clientObject struct { + tengo.ObjectImpl + info *livekit.ClientInfo +} + +func (c *clientObject) TypeName() string { + return "clientObject" +} + +func (c *clientObject) String() string { + return c.info.String() +} + +func (c *clientObject) IndexGet(index tengo.Object) (res tengo.Object, err error) { + field, ok := index.(*tengo.String) + if !ok { + return nil, tengo.ErrInvalidIndexType + } + + switch field.Value { + case "sdk": + return &tengo.String{Value: strings.ToLower(c.info.Sdk.String())}, nil + case "version": + return &ruleSdkVersion{sdkVersion: c.info.Version}, nil + case "protocol": + return &tengo.Int{Value: int64(c.info.Protocol)}, nil + case "os": + return &tengo.String{Value: strings.ToLower(c.info.Os)}, nil + case "os_version": + return &tengo.String{Value: c.info.OsVersion}, nil + case "device_model": + return &tengo.String{Value: strings.ToLower(c.info.DeviceModel)}, nil + case "browser": + return &tengo.String{Value: strings.ToLower(c.info.Browser)}, nil + case "browser_version": + return &ruleSdkVersion{sdkVersion: c.info.BrowserVersion}, nil + case "address": + return &tengo.String{Value: c.info.Address}, nil + } + return &tengo.Undefined{}, nil +} + +// ------------------------------------------ + +type ruleSdkVersion struct { + tengo.ObjectImpl + sdkVersion string +} + +func (r *ruleSdkVersion) TypeName() string { + return "sdkVersion" +} + +func (r *ruleSdkVersion) String() string { + return r.sdkVersion +} + +func (r *ruleSdkVersion) BinaryOp(op token.Token, rhs tengo.Object) (tengo.Object, error) { + if rhs, ok := rhs.(*tengo.String); ok { + cmp := r.compare(rhs.Value) + + isMatch := false + switch op { + case token.Greater: + isMatch = cmp > 0 + case token.GreaterEq: + isMatch = cmp >= 0 + default: + return nil, tengo.ErrInvalidOperator + } + + if isMatch { + return tengo.TrueValue, nil + } + return tengo.FalseValue, nil + } + + return nil, tengo.ErrInvalidOperator +} + +func (r *ruleSdkVersion) Equals(rhs tengo.Object) bool { + if rhs, ok := rhs.(*tengo.String); ok { + return r.compare(rhs.Value) == 0 + } + + return false +} + +func (r *ruleSdkVersion) compare(rhsSdkVersion string) int { + if !semver.IsValid("v"+r.sdkVersion) || !semver.IsValid("v"+rhsSdkVersion) { + // if not valid semver, do string compare + switch { + case r.sdkVersion < rhsSdkVersion: + return -1 + case r.sdkVersion > rhsSdkVersion: + return 1 + } + } else { + return semver.Compare("v"+r.sdkVersion, "v"+rhsSdkVersion) + } + return 0 +} diff --git a/livekit/pkg/clientconfiguration/staticconfiguration.go b/livekit/pkg/clientconfiguration/staticconfiguration.go new file mode 100644 index 0000000..d8145c2 --- /dev/null +++ b/livekit/pkg/clientconfiguration/staticconfiguration.go @@ -0,0 +1,71 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package clientconfiguration + +import ( + "google.golang.org/protobuf/proto" + + "github.com/livekit/livekit-server/pkg/utils" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + protoutils "github.com/livekit/protocol/utils" +) + +type ConfigurationItem struct { + Match + Configuration *livekit.ClientConfiguration + Merge bool +} + +type StaticClientConfigurationManager struct { + confs []ConfigurationItem +} + +func NewStaticClientConfigurationManager(confs []ConfigurationItem) *StaticClientConfigurationManager { + return &StaticClientConfigurationManager{confs: confs} +} + +func (s *StaticClientConfigurationManager) GetConfiguration(clientInfo *livekit.ClientInfo) *livekit.ClientConfiguration { + var matchedConf []*livekit.ClientConfiguration + for _, c := range s.confs { + matched, err := c.Match.Match(clientInfo) + if err != nil { + logger.Errorw("matchrule failed", err, + "clientInfo", logger.Proto(utils.ClientInfoWithoutAddress(clientInfo)), + ) + continue + } + if !matched { + continue + } + if !c.Merge { + return c.Configuration + } + matchedConf = append(matchedConf, c.Configuration) + } + + var conf *livekit.ClientConfiguration + for k, v := range matchedConf { + if k == 0 { + conf = protoutils.CloneProto(matchedConf[0]) + } else { + // TODO : there is a problem use protobuf merge, we don't have flag to indicate 'no value', + // don't override default behavior or other configuration's field. So a bool value = false or + // a int value = 0 will override same field in other configuration + proto.Merge(conf, v) + } + } + return conf +} diff --git a/livekit/pkg/clientconfiguration/types.go b/livekit/pkg/clientconfiguration/types.go new file mode 100644 index 0000000..b014518 --- /dev/null +++ b/livekit/pkg/clientconfiguration/types.go @@ -0,0 +1,23 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package clientconfiguration + +import ( + "github.com/livekit/protocol/livekit" +) + +type ClientConfigurationManager interface { + GetConfiguration(clientInfo *livekit.ClientInfo) *livekit.ClientConfiguration +} diff --git a/livekit/pkg/config/config.go b/livekit/pkg/config/config.go new file mode 100644 index 0000000..87aaafc --- /dev/null +++ b/livekit/pkg/config/config.go @@ -0,0 +1,796 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "fmt" + "os" + "reflect" + "strings" + "time" + + "github.com/mitchellh/go-homedir" + "github.com/pkg/errors" + "github.com/urfave/cli/v3" + "gopkg.in/yaml.v3" + + "github.com/livekit/mediatransportutil/pkg/rtcconfig" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + redisLiveKit "github.com/livekit/protocol/redis" + "github.com/livekit/protocol/rpc" + "github.com/livekit/protocol/webhook" + + "github.com/livekit/livekit-server/pkg/agent" + "github.com/livekit/livekit-server/pkg/metric" + "github.com/livekit/livekit-server/pkg/sfu" + "github.com/livekit/livekit-server/pkg/sfu/bwe/remotebwe" + "github.com/livekit/livekit-server/pkg/sfu/bwe/sendsidebwe" + "github.com/livekit/livekit-server/pkg/sfu/mime" + "github.com/livekit/livekit-server/pkg/sfu/pacer" + "github.com/livekit/livekit-server/pkg/sfu/streamallocator" +) + +const ( + generatedCLIFlagUsage = "generated" +) + +var ( + ErrKeyFileIncorrectPermission = errors.New("key file others permissions must be set to 0") + ErrKeysNotSet = errors.New("one of key-file or keys must be provided") +) + +type Config struct { + Port uint32 `yaml:"port,omitempty"` + BindAddresses []string `yaml:"bind_addresses,omitempty"` + // PrometheusPort is deprecated + PrometheusPort uint32 `yaml:"prometheus_port,omitempty"` + Prometheus PrometheusConfig `yaml:"prometheus,omitempty"` + RTC RTCConfig `yaml:"rtc,omitempty"` + Redis redisLiveKit.RedisConfig `yaml:"redis,omitempty"` + Audio sfu.AudioConfig `yaml:"audio,omitempty"` + Video VideoConfig `yaml:"video,omitempty"` + Room RoomConfig `yaml:"room,omitempty"` + TURN TURNConfig `yaml:"turn,omitempty"` + Ingress IngressConfig `yaml:"ingress,omitempty"` + SIP SIPConfig `yaml:"sip,omitempty"` + WebHook webhook.WebHookConfig `yaml:"webhook,omitempty"` + NodeSelector NodeSelectorConfig `yaml:"node_selector,omitempty"` + KeyFile string `yaml:"key_file,omitempty"` + Keys map[string]string `yaml:"keys,omitempty"` + Region string `yaml:"region,omitempty"` + SignalRelay SignalRelayConfig `yaml:"signal_relay,omitempty"` + PSRPC rpc.PSRPCConfig `yaml:"psrpc,omitempty"` + // Deprecated: LogLevel is deprecated + LogLevel string `yaml:"log_level,omitempty"` + Logging LoggingConfig `yaml:"logging,omitempty"` + Limit LimitConfig `yaml:"limit,omitempty"` + Agents agent.Config `yaml:"agents,omitempty"` + + Development bool `yaml:"development,omitempty"` + + Metric metric.MetricConfig `yaml:"metric,omitempty"` + Trace TracingConfig `yaml:"trace,omitempty"` + + NodeStats NodeStatsConfig `yaml:"node_stats,omitempty"` + + EnableDataTracks bool `yaml:"enable_data_tracks,omitempty"` +} + +type RTCConfig struct { + rtcconfig.RTCConfig `yaml:",inline"` + + TURNServers []TURNServer `yaml:"turn_servers,omitempty"` + + // Deprecated + StrictACKs bool `yaml:"strict_acks,omitempty"` + + // Deprecated: use PacketBufferSizeVideo and PacketBufferSizeAudio + PacketBufferSize int `yaml:"packet_buffer_size,omitempty"` + // Number of packets to buffer for NACK - video + PacketBufferSizeVideo int `yaml:"packet_buffer_size_video,omitempty"` + // Number of packets to buffer for NACK - audio + PacketBufferSizeAudio int `yaml:"packet_buffer_size_audio,omitempty"` + + // Throttle periods for pli/fir rtcp packets + PLIThrottle sfu.PLIThrottleConfig `yaml:"pli_throttle,omitempty"` + + CongestionControl CongestionControlConfig `yaml:"congestion_control,omitempty"` + + // allow TCP and TURN/TLS fallback + AllowTCPFallback *bool `yaml:"allow_tcp_fallback,omitempty"` + + // force a reconnect on a publication error + ReconnectOnPublicationError *bool `yaml:"reconnect_on_publication_error,omitempty"` + + // force a reconnect on a subscription error + ReconnectOnSubscriptionError *bool `yaml:"reconnect_on_subscription_error,omitempty"` + + // force a reconnect on a data channel error + ReconnectOnDataChannelError *bool `yaml:"reconnect_on_data_channel_error,omitempty"` + + // Deprecated + DataChannelMaxBufferedAmount uint64 `yaml:"data_channel_max_buffered_amount,omitempty"` + + // Threshold of data channel writing to be considered too slow, data packet could + // be dropped for a slow data channel to avoid blocking the room. + DatachannelSlowThreshold int `yaml:"datachannel_slow_threshold,omitempty"` + + // Target latency for lossy data channels, used to drop packets to reduce latency. + DatachannelLossyTargetLatency time.Duration `yaml:"datachannel_lossy_target_latency,omitempty"` + + ForwardStats ForwardStatsConfig `yaml:"forward_stats,omitempty"` + + // enable rtp stream restart detection for published tracks + EnableRTPStreamRestartDetection bool `yaml:"enable_rtp_stream_restart_detection,omitempty"` +} + +type TURNServer struct { + Host string `yaml:"host,omitempty"` + Port int `yaml:"port,omitempty"` + Protocol string `yaml:"protocol,omitempty"` + Username string `yaml:"username,omitempty"` + Credential string `yaml:"credential,omitempty"` + // Secret is used for TURN static auth secrets mechanism. When provided, + // dynamic credentials are generated using HMAC-SHA1 instead of static Username/Credential + Secret string `yaml:"secret,omitempty"` + // TTL is the time-to-live in seconds for generated credentials when using Secret. + // Defaults to 14400 seconds (4 hours) if not specified + TTL int `yaml:"ttl,omitempty"` +} + +type CongestionControlConfig struct { + Enabled bool `yaml:"enabled,omitempty"` + AllowPause bool `yaml:"allow_pause,omitempty"` + + StreamAllocator streamallocator.StreamAllocatorConfig `yaml:"stream_allocator,omitempty"` + + RemoteBWE remotebwe.RemoteBWEConfig `yaml:"remote_bwe,omitempty"` + + UseSendSideBWEInterceptor bool `yaml:"use_send_side_bwe_interceptor,omitempty"` + + UseSendSideBWE bool `yaml:"use_send_side_bwe,omitempty"` + SendSideBWEPacer string `yaml:"send_side_bwe_pacer,omitempty"` + SendSideBWE sendsidebwe.SendSideBWEConfig `yaml:"send_side_bwe,omitempty"` +} + +type PlayoutDelayConfig struct { + Enabled bool `yaml:"enabled,omitempty"` + Min int `yaml:"min,omitempty"` + Max int `yaml:"max,omitempty"` +} + +type VideoConfig struct { + DynacastPauseDelay time.Duration `yaml:"dynacast_pause_delay,omitempty"` + StreamTrackerManager sfu.StreamTrackerManagerConfig `yaml:"stream_tracker_manager,omitempty"` + + CodecRegressionThreshold int `yaml:"codec_regression_threshold,omitempty"` +} + +type RoomConfig struct { + // enable rooms to be automatically created + AutoCreate bool `yaml:"auto_create,omitempty"` + EnabledCodecs []CodecSpec `yaml:"enabled_codecs,omitempty"` + MaxParticipants uint32 `yaml:"max_participants,omitempty"` + EmptyTimeout uint32 `yaml:"empty_timeout,omitempty"` + DepartureTimeout uint32 `yaml:"departure_timeout,omitempty"` + EnableRemoteUnmute bool `yaml:"enable_remote_unmute,omitempty"` + PlayoutDelay PlayoutDelayConfig `yaml:"playout_delay,omitempty"` + SyncStreams bool `yaml:"sync_streams,omitempty"` + CreateRoomEnabled bool `yaml:"create_room_enabled,omitempty"` + CreateRoomTimeout time.Duration `yaml:"create_room_timeout,omitempty"` + CreateRoomAttempts int `yaml:"create_room_attempts,omitempty"` + // target room participant update batch chunk size in bytes + UpdateBatchTargetSize int `yaml:"update_batch_target_size,omitempty"` + // deprecated, moved to limits + MaxMetadataSize uint32 `yaml:"max_metadata_size,omitempty"` + // deprecated, moved to limits + MaxRoomNameLength int `yaml:"max_room_name_length,omitempty"` + // deprecated, moved to limits + MaxParticipantIdentityLength int `yaml:"max_participant_identity_length,omitempty"` + RoomConfigurations map[string]*livekit.RoomConfiguration `yaml:"room_configurations,omitempty"` +} + +type CodecSpec struct { + Mime string `yaml:"mime,omitempty"` + FmtpLine string `yaml:"fmtp_line,omitempty"` +} + +type LoggingConfig struct { + logger.Config `yaml:",inline"` + PionLevel string `yaml:"pion_level,omitempty"` +} + +type TURNConfig struct { + Enabled bool `yaml:"enabled,omitempty"` + Domain string `yaml:"domain,omitempty"` + CertFile string `yaml:"cert_file,omitempty"` + KeyFile string `yaml:"key_file,omitempty"` + TLSPort int `yaml:"tls_port,omitempty"` + UDPPort int `yaml:"udp_port,omitempty"` + RelayPortRangeStart uint16 `yaml:"relay_range_start,omitempty"` + RelayPortRangeEnd uint16 `yaml:"relay_range_end,omitempty"` + ExternalTLS bool `yaml:"external_tls,omitempty"` +} + +type NodeSelectorConfig struct { + Kind string `yaml:"kind,omitempty"` + SortBy string `yaml:"sort_by,omitempty"` + Algorithm string `yaml:"algorithm,omitempty"` + CPULoadLimit float32 `yaml:"cpu_load_limit,omitempty"` + SysloadLimit float32 `yaml:"sysload_limit,omitempty"` + Regions []RegionConfig `yaml:"regions,omitempty"` +} + +type SignalRelayConfig struct { + RetryTimeout time.Duration `yaml:"retry_timeout,omitempty"` + MinRetryInterval time.Duration `yaml:"min_retry_interval,omitempty"` + MaxRetryInterval time.Duration `yaml:"max_retry_interval,omitempty"` + StreamBufferSize int `yaml:"stream_buffer_size,omitempty"` + ConnectAttempts int `yaml:"connect_attempts,omitempty"` +} + +// RegionConfig lists available regions and their latitude/longitude, so the selector would prefer +// regions that are closer +type RegionConfig struct { + Name string `yaml:"name,omitempty"` + Lat float64 `yaml:"lat,omitempty"` + Lon float64 `yaml:"lon,omitempty"` +} + +type LimitConfig struct { + NumTracks int32 `yaml:"num_tracks,omitempty"` + BytesPerSec float32 `yaml:"bytes_per_sec,omitempty"` + SubscriptionLimitVideo int32 `yaml:"subscription_limit_video,omitempty"` + SubscriptionLimitAudio int32 `yaml:"subscription_limit_audio,omitempty"` + MaxMetadataSize uint32 `yaml:"max_metadata_size,omitempty"` + // total size of all attributes on a participant + MaxAttributesSize uint32 `yaml:"max_attributes_size,omitempty"` + MaxRoomNameLength int `yaml:"max_room_name_length,omitempty"` + MaxParticipantIdentityLength int `yaml:"max_participant_identity_length,omitempty"` + MaxParticipantNameLength int `yaml:"max_participant_name_length,omitempty"` +} + +func (l LimitConfig) CheckRoomNameLength(name string) bool { + return l.MaxRoomNameLength == 0 || len(name) <= l.MaxRoomNameLength +} + +func (l LimitConfig) CheckParticipantIdentityLength(identity string) bool { + return l.MaxParticipantIdentityLength == 0 || len(identity) <= l.MaxParticipantIdentityLength +} + +func (l LimitConfig) CheckParticipantNameLength(name string) bool { + return l.MaxParticipantNameLength == 0 || len(name) <= l.MaxParticipantNameLength +} + +func (l LimitConfig) CheckMetadataSize(metadata string) bool { + return l.MaxMetadataSize == 0 || uint32(len(metadata)) <= l.MaxMetadataSize +} + +func (l LimitConfig) CheckAttributesSize(attributes map[string]string) bool { + if l.MaxAttributesSize == 0 { + return true + } + + total := 0 + for k, v := range attributes { + total += len(k) + len(v) + } + return uint32(total) <= l.MaxAttributesSize +} + +type IngressConfig struct { + RTMPBaseURL string `yaml:"rtmp_base_url,omitempty"` + WHIPBaseURL string `yaml:"whip_base_url,omitempty"` +} + +type SIPConfig struct{} + +type APIConfig struct { + // amount of time to wait for API to execute, default 2s + ExecutionTimeout time.Duration `yaml:"execution_timeout,omitempty"` + + // min amount of time to wait before checking for operation complete + CheckInterval time.Duration `yaml:"check_interval,omitempty"` + + // max amount of time to wait before checking for operation complete + MaxCheckInterval time.Duration `yaml:"max_check_interval,omitempty"` +} + +type PrometheusConfig struct { + Port uint32 `yaml:"port,omitempty"` + Username string `yaml:"username,omitempty"` + Password string `yaml:"password,omitempty"` +} + +type ForwardStatsConfig struct { + SummaryInterval time.Duration `yaml:"summary_interval,omitempty"` + ReportInterval time.Duration `yaml:"report_interval,omitempty"` + ReportWindow time.Duration `yaml:"report_window,omitempty"` +} + +type TracingConfig struct { + // JaegerURL configures Jaeger as a global tracer. + // + // The following formats are supported: , :, http(s):/// + JaegerURL string `yaml:"jaeger_url,omitempty"` +} + +func DefaultAPIConfig() APIConfig { + return APIConfig{ + ExecutionTimeout: 2 * time.Second, + CheckInterval: 100 * time.Millisecond, + MaxCheckInterval: 300 * time.Second, + } +} + +type NodeStatsConfig struct { + StatsUpdateInterval time.Duration `yaml:"stats_update_interval,omitempty"` + StatsRateMeasurementIntervals []time.Duration `yaml:"stats_rate_measurement_intervals,omitempty"` + StatsMaxDelay time.Duration `yaml:"stats_max_delay,omitempty"` +} + +var DefaultNodeStatsConfig = NodeStatsConfig{ + StatsUpdateInterval: 2 * time.Second, + StatsRateMeasurementIntervals: []time.Duration{10 * time.Second}, + StatsMaxDelay: 30 * time.Second, +} + +var DefaultConfig = Config{ + Port: 7880, + RTC: RTCConfig{ + RTCConfig: rtcconfig.RTCConfig{ + UseExternalIP: false, + TCPPort: 7881, + ICEPortRangeStart: 0, + ICEPortRangeEnd: 0, + STUNServers: []string{}, + }, + PacketBufferSize: 500, + PacketBufferSizeVideo: 500, + PacketBufferSizeAudio: 200, + PLIThrottle: sfu.DefaultPLIThrottleConfig, + CongestionControl: CongestionControlConfig{ + Enabled: true, + AllowPause: false, + StreamAllocator: streamallocator.DefaultStreamAllocatorConfig, + RemoteBWE: remotebwe.DefaultRemoteBWEConfig, + UseSendSideBWEInterceptor: false, + UseSendSideBWE: false, + SendSideBWEPacer: string(pacer.PacerBehaviorNoQueue), + SendSideBWE: sendsidebwe.DefaultSendSideBWEConfig, + }, + }, + Audio: sfu.DefaultAudioConfig, + Video: VideoConfig{ + DynacastPauseDelay: 5 * time.Second, + StreamTrackerManager: sfu.DefaultStreamTrackerManagerConfig, + CodecRegressionThreshold: 5, + }, + Redis: redisLiveKit.RedisConfig{}, + Room: RoomConfig{ + AutoCreate: true, + EnabledCodecs: []CodecSpec{ + {Mime: mime.MimeTypePCMU.String()}, + {Mime: mime.MimeTypePCMA.String()}, + {Mime: mime.MimeTypeOpus.String()}, + {Mime: mime.MimeTypeRED.String()}, + {Mime: mime.MimeTypeVP8.String()}, + {Mime: mime.MimeTypeH264.String()}, + {Mime: mime.MimeTypeVP9.String()}, + {Mime: mime.MimeTypeAV1.String()}, + {Mime: mime.MimeTypeH265.String()}, + {Mime: mime.MimeTypeRTX.String()}, + }, + EmptyTimeout: 5 * 60, + DepartureTimeout: 20, + CreateRoomEnabled: true, + CreateRoomTimeout: 10 * time.Second, + CreateRoomAttempts: 3, + UpdateBatchTargetSize: 128 * 1024, + }, + Limit: LimitConfig{ + MaxMetadataSize: 64000, + MaxAttributesSize: 64000, + MaxRoomNameLength: 256, + MaxParticipantIdentityLength: 256, + MaxParticipantNameLength: 256, + }, + Logging: LoggingConfig{ + PionLevel: "error", + }, + TURN: TURNConfig{ + Enabled: false, + }, + NodeSelector: NodeSelectorConfig{ + Kind: "any", + SortBy: "random", + SysloadLimit: 0.9, + CPULoadLimit: 0.9, + Algorithm: "lowest", + }, + SignalRelay: SignalRelayConfig{ + RetryTimeout: 7500 * time.Millisecond, + MinRetryInterval: 500 * time.Millisecond, + MaxRetryInterval: 4 * time.Second, + StreamBufferSize: 1000, + ConnectAttempts: 3, + }, + PSRPC: rpc.DefaultPSRPCConfig, + Keys: map[string]string{}, + Metric: metric.DefaultMetricConfig, + WebHook: webhook.DefaultWebHookConfig, + NodeStats: DefaultNodeStatsConfig, +} + +func NewConfig(confString string, strictMode bool, c *cli.Command, baseFlags []cli.Flag) (*Config, error) { + // start with defaults + marshalled, err := yaml.Marshal(&DefaultConfig) + if err != nil { + return nil, err + } + + var conf Config + err = yaml.Unmarshal(marshalled, &conf) + if err != nil { + return nil, err + } + + if confString != "" { + decoder := yaml.NewDecoder(strings.NewReader(confString)) + decoder.KnownFields(strictMode) + if err := decoder.Decode(&conf); err != nil { + return nil, fmt.Errorf("could not parse config: %v", err) + } + } + + if c != nil { + if err := conf.updateFromCLI(c, baseFlags); err != nil { + return nil, err + } + } + + if err := conf.RTC.Validate(conf.Development); err != nil { + return nil, fmt.Errorf("could not validate RTC config: %v", err) + } + + // expand env vars in filenames + file, err := homedir.Expand(os.ExpandEnv(conf.KeyFile)) + if err != nil { + return nil, err + } + conf.KeyFile = file + + // set defaults for Turn relay if none are set + if conf.TURN.RelayPortRangeStart == 0 || conf.TURN.RelayPortRangeEnd == 0 { + // to make it easier to run in dev mode/docker, default to two ports + if conf.Development { + conf.TURN.RelayPortRangeStart = 30000 + conf.TURN.RelayPortRangeEnd = 30002 + } else { + conf.TURN.RelayPortRangeStart = 30000 + conf.TURN.RelayPortRangeEnd = 40000 + } + } + + if conf.LogLevel != "" { + conf.Logging.Level = conf.LogLevel + } + if conf.Logging.Level == "" && conf.Development { + conf.Logging.Level = "debug" + } + if conf.Logging.PionLevel != "" { + if conf.Logging.ComponentLevels == nil { + conf.Logging.ComponentLevels = map[string]string{} + } + conf.Logging.ComponentLevels["transport.pion"] = conf.Logging.PionLevel + conf.Logging.ComponentLevels["pion"] = conf.Logging.PionLevel + } + + // copy over legacy limits + if conf.Room.MaxMetadataSize != 0 { + conf.Limit.MaxMetadataSize = conf.Room.MaxMetadataSize + } + if conf.Room.MaxParticipantIdentityLength != 0 { + conf.Limit.MaxParticipantIdentityLength = conf.Room.MaxParticipantIdentityLength + } + if conf.Room.MaxRoomNameLength != 0 { + conf.Limit.MaxRoomNameLength = conf.Room.MaxRoomNameLength + } + + return &conf, nil +} + +func (conf *Config) IsTURNSEnabled() bool { + if conf.TURN.Enabled && conf.TURN.TLSPort != 0 { + return true + } + for _, s := range conf.RTC.TURNServers { + if s.Protocol == "tls" { + return true + } + } + return false +} + +type configNode struct { + TypeNode reflect.Value + TagPrefix string +} + +func (conf *Config) ToCLIFlagNames(existingFlags []cli.Flag) map[string]reflect.Value { + existingFlagNames := map[string]bool{} + for _, flag := range existingFlags { + for _, flagName := range flag.Names() { + existingFlagNames[flagName] = true + } + } + + flagNames := map[string]reflect.Value{} + var currNode configNode + nodes := []configNode{{reflect.ValueOf(conf).Elem(), ""}} + for len(nodes) > 0 { + currNode, nodes = nodes[0], nodes[1:] + for i := 0; i < currNode.TypeNode.NumField(); i++ { + // inspect yaml tag from struct field to get path + field := currNode.TypeNode.Type().Field(i) + yamlTagArray := strings.SplitN(field.Tag.Get("yaml"), ",", 2) + yamlTag := yamlTagArray[0] + isInline := false + if len(yamlTagArray) > 1 && yamlTagArray[1] == "inline" { + isInline = true + } + if (yamlTag == "" && (!isInline || currNode.TagPrefix == "")) || yamlTag == "-" { + continue + } + yamlPath := yamlTag + if currNode.TagPrefix != "" { + if isInline { + yamlPath = currNode.TagPrefix + } else { + yamlPath = fmt.Sprintf("%s.%s", currNode.TagPrefix, yamlTag) + } + } + if existingFlagNames[yamlPath] { + continue + } + + // map flag name to value + value := currNode.TypeNode.Field(i) + if value.Kind() == reflect.Struct { + nodes = append(nodes, configNode{value, yamlPath}) + } else { + flagNames[yamlPath] = value + } + } + } + + return flagNames +} + +func (conf *Config) ValidateKeys() error { + // prefer keyfile if set + if conf.KeyFile != "" { + var otherFilter os.FileMode = 0o007 + if st, err := os.Stat(conf.KeyFile); err != nil { + return err + } else if st.Mode().Perm()&otherFilter != 0o000 { + return ErrKeyFileIncorrectPermission + } + f, err := os.Open(conf.KeyFile) + if err != nil { + return err + } + defer func() { + _ = f.Close() + }() + decoder := yaml.NewDecoder(f) + conf.Keys = map[string]string{} + if err = decoder.Decode(conf.Keys); err != nil { + return err + } + } + + if len(conf.Keys) == 0 { + return ErrKeysNotSet + } + + if !conf.Development { + for key, secret := range conf.Keys { + if len(secret) < 32 { + logger.Errorw("secret is too short, should be at least 32 characters for security", nil, "apiKey", key) + } + } + } + return nil +} + +func GenerateCLIFlags(existingFlags []cli.Flag, hidden bool) ([]cli.Flag, error) { + blankConfig := &Config{} + flags := make([]cli.Flag, 0) + for name, value := range blankConfig.ToCLIFlagNames(existingFlags) { + kind := value.Kind() + if kind == reflect.Ptr { + kind = value.Type().Elem().Kind() + } + + var flag cli.Flag + envVar := fmt.Sprintf("LIVEKIT_%s", strings.ToUpper(strings.Replace(name, ".", "_", -1))) + + switch kind { + case reflect.Bool: + flag = &cli.BoolFlag{ + Name: name, + Sources: cli.EnvVars(envVar), + Usage: generatedCLIFlagUsage, + Hidden: hidden, + } + case reflect.String: + flag = &cli.StringFlag{ + Name: name, + Sources: cli.EnvVars(envVar), + Usage: generatedCLIFlagUsage, + Hidden: hidden, + } + case reflect.Int, reflect.Int32: + flag = &cli.IntFlag{ + Name: name, + Sources: cli.EnvVars(envVar), + Usage: generatedCLIFlagUsage, + Hidden: hidden, + } + case reflect.Int64: + flag = &cli.Int64Flag{ + Name: name, + Sources: cli.EnvVars(envVar), + Usage: generatedCLIFlagUsage, + Hidden: hidden, + } + case reflect.Uint8, reflect.Uint16, reflect.Uint32: + flag = &cli.UintFlag{ + Name: name, + Sources: cli.EnvVars(envVar), + Usage: generatedCLIFlagUsage, + Hidden: hidden, + } + case reflect.Uint64: + flag = &cli.Uint64Flag{ + Name: name, + Sources: cli.EnvVars(envVar), + Usage: generatedCLIFlagUsage, + Hidden: hidden, + } + case reflect.Float32: + flag = &cli.Float64Flag{ + Name: name, + Sources: cli.EnvVars(envVar), + Usage: generatedCLIFlagUsage, + Hidden: hidden, + } + case reflect.Float64: + flag = &cli.Float64Flag{ + Name: name, + Sources: cli.EnvVars(envVar), + Usage: generatedCLIFlagUsage, + Hidden: hidden, + } + case reflect.Slice: + // TODO + continue + case reflect.Map: + // TODO + continue + case reflect.Struct: + // TODO + continue + default: + return flags, fmt.Errorf("cli flag generation unsupported for config type: %s is a %s", name, kind.String()) + } + + flags = append(flags, flag) + } + + return flags, nil +} + +func (conf *Config) updateFromCLI(c *cli.Command, baseFlags []cli.Flag) error { + generatedFlagNames := conf.ToCLIFlagNames(baseFlags) + for _, flag := range c.Flags { + flagName := flag.Names()[0] + + if !c.IsSet(flagName) { + continue + } + + configValue, ok := generatedFlagNames[flagName] + if !ok { + continue + } + + if configValue.Kind() == reflect.Ptr { + configValue.Set(reflect.New(configValue.Type().Elem())) + configValue = configValue.Elem() + } + + value := reflect.ValueOf(c.Value(flagName)) + if value.CanConvert(configValue.Type()) { + configValue.Set(value.Convert(configValue.Type())) + } else { + return fmt.Errorf("unsupported generated cli flag type for config: %s (expected %s, got %s)", flagName, configValue.Type(), value.Type()) + } + } + + if c.IsSet("dev") { + conf.Development = c.Bool("dev") + } + if c.IsSet("key-file") { + conf.KeyFile = c.String("key-file") + } + if c.IsSet("keys") { + if err := conf.unmarshalKeys(c.String("keys")); err != nil { + return errors.New("Could not parse keys, it needs to be exactly, \"key: secret\", including the space") + } + } + if c.IsSet("region") { + conf.Region = c.String("region") + } + if c.IsSet("redis-host") { + conf.Redis.Address = c.String("redis-host") + } + if c.IsSet("redis-password") { + conf.Redis.Password = c.String("redis-password") + } + if c.IsSet("turn-cert") { + conf.TURN.CertFile = c.String("turn-cert") + } + if c.IsSet("turn-key") { + conf.TURN.KeyFile = c.String("turn-key") + } + if c.IsSet("node-ip") { + conf.RTC.NodeIP = c.String("node-ip") + } + if c.IsSet("udp-port") { + conf.RTC.UDPPort.UnmarshalString(c.String("udp-port")) + } + if c.IsSet("bind") { + conf.BindAddresses = c.StringSlice("bind") + } + return nil +} + +func (conf *Config) unmarshalKeys(keys string) error { + temp := make(map[string]any) + if err := yaml.Unmarshal([]byte(keys), temp); err != nil { + return err + } + + conf.Keys = make(map[string]string, len(temp)) + + for key, val := range temp { + if secret, ok := val.(string); ok { + conf.Keys[key] = secret + } + } + return nil +} + +// Note: only pass in logr.Logger with default depth +func SetLogger(l logger.Logger) { + logger.SetLogger(l, "livekit") +} + +func InitLoggerFromConfig(config *LoggingConfig) { + logger.InitFromConfig(&config.Config, "livekit") +} diff --git a/livekit/pkg/config/config_test.go b/livekit/pkg/config/config_test.go new file mode 100644 index 0000000..da2b2e6 --- /dev/null +++ b/livekit/pkg/config/config_test.go @@ -0,0 +1,85 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/urfave/cli/v3" + + "github.com/livekit/livekit-server/pkg/config/configtest" +) + +func TestConfig_UnmarshalKeys(t *testing.T) { + conf, err := NewConfig("", true, nil, nil) + require.NoError(t, err) + + require.NoError(t, conf.unmarshalKeys("key1: secret1")) + require.Equal(t, "secret1", conf.Keys["key1"]) +} + +func TestConfig_DefaultsKept(t *testing.T) { + const content = `room: + empty_timeout: 10` + conf, err := NewConfig(content, true, nil, nil) + require.NoError(t, err) + require.Equal(t, true, conf.Room.AutoCreate) + require.Equal(t, uint32(10), conf.Room.EmptyTimeout) +} + +func TestConfig_UnknownKeys(t *testing.T) { + const content = `unknown: 10 +room: + empty_timeout: 10` + _, err := NewConfig(content, true, nil, nil) + require.Error(t, err) +} + +func TestGeneratedFlags(t *testing.T) { + generatedFlags, err := GenerateCLIFlags(nil, false) + require.NoError(t, err) + + c := &cli.Command{} + c.Name = "test" + c.Flags = append(c.Flags, generatedFlags...) + + c.Set("rtc.use_ice_lite", "true") + c.Set("redis.address", "localhost:6379") + c.Set("prometheus.port", "9999") + c.Set("rtc.allow_tcp_fallback", "true") + c.Set("rtc.reconnect_on_publication_error", "true") + c.Set("rtc.reconnect_on_subscription_error", "false") + + conf, err := NewConfig("", true, c, nil) + require.NoError(t, err) + + require.True(t, conf.RTC.UseICELite) + require.Equal(t, "localhost:6379", conf.Redis.Address) + require.Equal(t, uint32(9999), conf.Prometheus.Port) + + require.NotNil(t, conf.RTC.AllowTCPFallback) + require.True(t, *conf.RTC.AllowTCPFallback) + + require.NotNil(t, conf.RTC.ReconnectOnPublicationError) + require.True(t, *conf.RTC.ReconnectOnPublicationError) + + require.NotNil(t, conf.RTC.ReconnectOnSubscriptionError) + require.False(t, *conf.RTC.ReconnectOnSubscriptionError) +} + +func TestYAMLTag(t *testing.T) { + require.NoError(t, configtest.CheckYAMLTags(Config{})) +} diff --git a/livekit/pkg/config/configtest/checkyamltag.go b/livekit/pkg/config/configtest/checkyamltag.go new file mode 100644 index 0000000..2458a5b --- /dev/null +++ b/livekit/pkg/config/configtest/checkyamltag.go @@ -0,0 +1,69 @@ +package configtest + +import ( + "fmt" + "reflect" + "slices" + "strings" + + "go.uber.org/multierr" + "google.golang.org/protobuf/proto" +) + +var protoMessageType = reflect.TypeOf((*proto.Message)(nil)).Elem() + +func checkYAMLTags(t reflect.Type, seen map[reflect.Type]struct{}) error { + if _, ok := seen[t]; ok { + return nil + } + seen[t] = struct{}{} + + switch t.Kind() { + case reflect.Array, reflect.Map, reflect.Slice, reflect.Pointer: + return checkYAMLTags(t.Elem(), seen) + case reflect.Struct: + if reflect.PointerTo(t).Implements(protoMessageType) { + // ignore protobuf messages + return nil + } + + var errs error + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + + if !field.IsExported() { + // ignore unexported fields + continue + } + + if field.Type.Kind() == reflect.Bool { + // ignore boolean fields + continue + } + + if field.Tag.Get("config") == "allowempty" { + // ignore configured exceptions + continue + } + + parts := strings.Split(field.Tag.Get("yaml"), ",") + if parts[0] == "-" { + // ignore unparsed fields + continue + } + + if !slices.Contains(parts, "omitempty") && !slices.Contains(parts, "inline") { + errs = multierr.Append(errs, fmt.Errorf("%s/%s.%s missing omitempty tag", t.PkgPath(), t.Name(), field.Name)) + } + + errs = multierr.Append(errs, checkYAMLTags(field.Type, seen)) + } + return errs + default: + return nil + } +} + +func CheckYAMLTags(config any) error { + return checkYAMLTags(reflect.TypeOf(config), map[reflect.Type]struct{}{}) +} diff --git a/livekit/pkg/metric/metric_config.go b/livekit/pkg/metric/metric_config.go new file mode 100644 index 0000000..48c9212 --- /dev/null +++ b/livekit/pkg/metric/metric_config.go @@ -0,0 +1,31 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package metric + +// ------------------------------------------------ + +type MetricConfig struct { + Timestamper MetricTimestamperConfig `yaml:"timestamper_config,omitempty"` + Collector MetricsCollectorConfig `yaml:"collector,omitempty"` + Reporter MetricsReporterConfig `yaml:"reporter,omitempty"` +} + +var ( + DefaultMetricConfig = MetricConfig{ + Timestamper: DefaultMetricTimestamperConfig, + Collector: DefaultMetricsCollectorConfig, + Reporter: DefaultMetricsReporterConfig, + } +) diff --git a/livekit/pkg/metric/metric_timestamper.go b/livekit/pkg/metric/metric_timestamper.go new file mode 100644 index 0000000..0ee0959 --- /dev/null +++ b/livekit/pkg/metric/metric_timestamper.go @@ -0,0 +1,119 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package metric + +import ( + "sync" + "time" + + "github.com/livekit/mediatransportutil/pkg/latency" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/utils/mono" + "google.golang.org/protobuf/types/known/timestamppb" +) + +// ------------------------------------------------ + +type MetricTimestamperConfig struct { + OneWayDelayEstimatorMinInterval time.Duration `yaml:"one_way_delay_estimator_min_interval,omitempty"` + OneWayDelayEstimatorMaxBatch int `yaml:"one_way_delay_estimator_max_batch,omitempty"` +} + +var ( + DefaultMetricTimestamperConfig = MetricTimestamperConfig{ + OneWayDelayEstimatorMinInterval: 5 * time.Second, + OneWayDelayEstimatorMaxBatch: 100, + } +) + +// ------------------------------------------------ + +type MetricTimestamperParams struct { + Config MetricTimestamperConfig + Logger logger.Logger +} + +type MetricTimestamper struct { + params MetricTimestamperParams + lock sync.Mutex + owdEstimator *latency.OWDEstimator + lastOWDEstimatorRunAt time.Time + batchesSinceLastOWDEstimatorRun int +} + +func NewMetricTimestamper(params MetricTimestamperParams) *MetricTimestamper { + return &MetricTimestamper{ + params: params, + owdEstimator: latency.NewOWDEstimator(latency.OWDEstimatorParamsDefault), + lastOWDEstimatorRunAt: time.Now().Add(-params.Config.OneWayDelayEstimatorMinInterval), + } +} + +func (m *MetricTimestamper) Process(batch *livekit.MetricsBatch) { + if m == nil { + return + } + + // run OWD estimation periodically + estimatedOWDNanos := m.maybeRunOWDEstimator(batch) + + // normalize all time stamps and add estimated OWD + // NOTE: all timestamps will be re-mapped. If the time series or event happened some time + // in the past and the OWD estimation has changed since, those samples will get the updated + // OWD estimation applied. So, they may have more uncertainty in addition to the uncertainty + // of OWD estimation process. + batch.NormalizedTimestamp = timestamppb.New(time.Unix(0, batch.TimestampMs*1e6+estimatedOWDNanos)) + + for _, ts := range batch.TimeSeries { + for _, sample := range ts.Samples { + sample.NormalizedTimestamp = timestamppb.New(time.Unix(0, sample.TimestampMs*1e6+estimatedOWDNanos)) + } + } + + for _, ev := range batch.Events { + ev.NormalizedStartTimestamp = timestamppb.New(time.Unix(0, ev.StartTimestampMs*1e6+estimatedOWDNanos)) + + endTimestampMs := ev.GetEndTimestampMs() + if endTimestampMs != 0 { + ev.NormalizedEndTimestamp = timestamppb.New(time.Unix(0, endTimestampMs*1e6+estimatedOWDNanos)) + } + } + + m.params.Logger.Debugw("timestamped metrics batch", "batch", logger.Proto(batch)) +} + +func (m *MetricTimestamper) maybeRunOWDEstimator(batch *livekit.MetricsBatch) int64 { + m.lock.Lock() + defer m.lock.Unlock() + + if time.Since(m.lastOWDEstimatorRunAt) < m.params.Config.OneWayDelayEstimatorMinInterval && + m.batchesSinceLastOWDEstimatorRun < m.params.Config.OneWayDelayEstimatorMaxBatch { + m.batchesSinceLastOWDEstimatorRun++ + return m.owdEstimator.EstimatedPropagationDelay() + } + + senderClockTime := batch.GetTimestampMs() + if senderClockTime == 0 { + m.batchesSinceLastOWDEstimatorRun++ + return m.owdEstimator.EstimatedPropagationDelay() + } + + m.lastOWDEstimatorRunAt = time.Now() + m.batchesSinceLastOWDEstimatorRun = 1 + + estimatedOWDNs, _ := m.owdEstimator.Update(senderClockTime*1e6, mono.UnixNano()) + return estimatedOWDNs +} diff --git a/livekit/pkg/metric/metrics_collector.go b/livekit/pkg/metric/metrics_collector.go new file mode 100644 index 0000000..bcd2809 --- /dev/null +++ b/livekit/pkg/metric/metrics_collector.go @@ -0,0 +1,225 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package metric + +import ( + "sync" + "time" + + "github.com/frostbyte73/core" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/utils/mono" + + "github.com/livekit/protocol/utils" +) + +type MetricsCollectorProvider interface { + MetricsCollectorTimeToCollectMetrics() + MetricsCollectorBatchReady(mb *livekit.MetricsBatch) +} + +// -------------------------------------------------------- + +type MetricsCollectorConfig struct { + SamplingIntervalMs uint32 `yaml:"sampling_interval_ms,omitempty" json:"sampling_interval_ms,omitempty"` + BatchIntervalMs uint32 `yaml:"batch_interval_ms,omitempty" json:"batch_interval_ms,omitempty"` +} + +var ( + DefaultMetricsCollectorConfig = MetricsCollectorConfig{ + SamplingIntervalMs: 3 * 1000, + BatchIntervalMs: 10 * 1000, + } +) + +// -------------------------------------------------------- + +type MetricsCollectorParams struct { + ParticipantIdentity livekit.ParticipantIdentity + Config MetricsCollectorConfig + Provider MetricsCollectorProvider + Logger logger.Logger +} + +type MetricsCollector struct { + params MetricsCollectorParams + + lock sync.RWMutex + mbb *utils.MetricsBatchBuilder + publisherRTTMetricId map[livekit.ParticipantIdentity]int + subscriberRTTMetricId int + relayRTTMetricId map[livekit.ParticipantIdentity]int + + stop core.Fuse +} + +func NewMetricsCollector(params MetricsCollectorParams) *MetricsCollector { + mc := &MetricsCollector{ + params: params, + } + mc.reset() + + go mc.worker() + return mc +} + +func (mc *MetricsCollector) Stop() { + if mc != nil { + mc.stop.Break() + } +} + +func (mc *MetricsCollector) AddPublisherRTT(participantIdentity livekit.ParticipantIdentity, rtt float32) { + mc.lock.Lock() + defer mc.lock.Unlock() + + metricId, ok := mc.publisherRTTMetricId[participantIdentity] + if !ok { + var err error + metricId, err = mc.createTimeSeriesMetric(livekit.MetricLabel_PUBLISHER_RTT, participantIdentity) + if err != nil { + mc.params.Logger.Warnw("could not add time series metric for publisher RTT", err) + return + } + + mc.publisherRTTMetricId[participantIdentity] = metricId + } + + mc.addTimeSeriesMetricSample(metricId, rtt) +} + +func (mc *MetricsCollector) AddSubscriberRTT(rtt float32) { + mc.lock.Lock() + defer mc.lock.Unlock() + + if mc.subscriberRTTMetricId == utils.MetricsBatchBuilderInvalidTimeSeriesMetricId { + var err error + mc.subscriberRTTMetricId, err = mc.createTimeSeriesMetric(livekit.MetricLabel_SUBSCRIBER_RTT, mc.params.ParticipantIdentity) + if err != nil { + mc.params.Logger.Warnw("could not add time series metric for publisher RTT", err) + return + } + } + + mc.addTimeSeriesMetricSample(mc.subscriberRTTMetricId, rtt) +} + +func (mc *MetricsCollector) AddRelayRTT(participantIdentity livekit.ParticipantIdentity, rtt float32) { + mc.lock.Lock() + defer mc.lock.Unlock() + + metricId, ok := mc.relayRTTMetricId[participantIdentity] + if !ok { + var err error + metricId, err = mc.createTimeSeriesMetric(livekit.MetricLabel_SERVER_MESH_RTT, participantIdentity) + if err != nil { + mc.params.Logger.Warnw("could not add time series metric for server mesh RTT", err) + return + } + + mc.relayRTTMetricId[participantIdentity] = metricId + } + + mc.addTimeSeriesMetricSample(metricId, rtt) +} + +func (mc *MetricsCollector) getMetricsBatchAndReset() *livekit.MetricsBatch { + mc.lock.Lock() + mbb := mc.mbb + + mc.reset() + mc.lock.Unlock() + + if mbb.IsEmpty() { + return nil + } + + now := mono.Now() + mbb.SetTime(now, now) + return mbb.ToProto() +} + +func (mc *MetricsCollector) reset() { + mc.mbb = utils.NewMetricsBatchBuilder() + mc.mbb.SetRestrictedLabels(utils.MetricRestrictedLabels{ + LabelRanges: []utils.MetricLabelRange{ + { + StartInclusive: livekit.MetricLabel_CLIENT_VIDEO_SUBSCRIBER_FREEZE_COUNT, + EndInclusive: livekit.MetricLabel_CLIENT_VIDEO_PUBLISHER_QUALITY_LIMITATION_DURATION_OTHER, + }, + }, + ParticipantIdentity: mc.params.ParticipantIdentity, + }) + + mc.publisherRTTMetricId = make(map[livekit.ParticipantIdentity]int) + mc.subscriberRTTMetricId = utils.MetricsBatchBuilderInvalidTimeSeriesMetricId + mc.relayRTTMetricId = make(map[livekit.ParticipantIdentity]int) +} + +func (mc *MetricsCollector) createTimeSeriesMetric( + label livekit.MetricLabel, + participantIdentity livekit.ParticipantIdentity, +) (int, error) { + return mc.mbb.AddTimeSeriesMetric(utils.TimeSeriesMetric{ + MetricLabel: label, + ParticipantIdentity: participantIdentity, + }, + ) +} + +func (mc *MetricsCollector) addTimeSeriesMetricSample(metricId int, value float32) { + now := mono.Now() + if err := mc.mbb.AddMetricSamplesToTimeSeriesMetric(metricId, []utils.MetricSample{ + { + At: now, + NormalizedAt: now, + Value: value, + }, + }); err != nil { + mc.params.Logger.Warnw("could not add metric sample", err, "metricId", metricId) + } +} + +func (mc *MetricsCollector) worker() { + samplingIntervalMs := mc.params.Config.SamplingIntervalMs + if samplingIntervalMs == 0 { + samplingIntervalMs = DefaultMetricsCollectorConfig.SamplingIntervalMs + } + samplingTicker := time.NewTicker(time.Duration(samplingIntervalMs) * time.Millisecond) + defer samplingTicker.Stop() + + batchIntervalMs := mc.params.Config.BatchIntervalMs + if batchIntervalMs < samplingIntervalMs { + batchIntervalMs = samplingIntervalMs + } + batchTicker := time.NewTicker(time.Duration(batchIntervalMs) * time.Millisecond) + defer batchTicker.Stop() + + for { + select { + case <-samplingTicker.C: + mc.params.Provider.MetricsCollectorTimeToCollectMetrics() + + case <-batchTicker.C: + if mb := mc.getMetricsBatchAndReset(); mb != nil { + mc.params.Provider.MetricsCollectorBatchReady(mb) + } + + case <-mc.stop.Watch(): + return + } + } +} diff --git a/livekit/pkg/metric/metrics_reporter.go b/livekit/pkg/metric/metrics_reporter.go new file mode 100644 index 0000000..78f4b4c --- /dev/null +++ b/livekit/pkg/metric/metrics_reporter.go @@ -0,0 +1,138 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package metric + +import ( + "sync" + "time" + + "github.com/frostbyte73/core" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/utils/mono" + + "github.com/livekit/protocol/utils" +) + +type MetricsReporterConsumer interface { + MetricsReporterBatchReady(mb *livekit.MetricsBatch) +} + +// -------------------------------------------------------- + +type MetricsReporterConfig struct { + ReportingIntervalMs uint32 `yaml:"reporting_interval_ms,omitempty" json:"reporting_interval_ms,omitempty"` +} + +var ( + DefaultMetricsReporterConfig = MetricsReporterConfig{ + ReportingIntervalMs: 10 * 1000, + } +) + +// -------------------------------------------------------- + +type MetricsReporterParams struct { + ParticipantIdentity livekit.ParticipantIdentity + Config MetricsReporterConfig + Consumer MetricsReporterConsumer + Logger logger.Logger +} + +type MetricsReporter struct { + params MetricsReporterParams + + lock sync.RWMutex + mbb *utils.MetricsBatchBuilder + + stop core.Fuse +} + +func NewMetricsReporter(params MetricsReporterParams) *MetricsReporter { + mr := &MetricsReporter{ + params: params, + } + mr.reset() + + go mr.worker() + return mr +} + +func (mr *MetricsReporter) Stop() { + if mr != nil { + mr.stop.Break() + } +} + +func (mr *MetricsReporter) Merge(other *livekit.MetricsBatch) { + if mr == nil { + return + } + + mr.lock.Lock() + defer mr.lock.Unlock() + + mr.mbb.Merge(other) +} + +func (mr *MetricsReporter) getMetricsBatchAndReset() *livekit.MetricsBatch { + mr.lock.Lock() + mbb := mr.mbb + + mr.reset() + mr.lock.Unlock() + + if mbb.IsEmpty() { + return nil + } + + now := mono.Now() + mbb.SetTime(now, now) + return mbb.ToProto() +} + +func (mr *MetricsReporter) reset() { + mr.mbb = utils.NewMetricsBatchBuilder() + mr.mbb.SetRestrictedLabels(utils.MetricRestrictedLabels{ + LabelRanges: []utils.MetricLabelRange{ + { + StartInclusive: livekit.MetricLabel_CLIENT_VIDEO_SUBSCRIBER_FREEZE_COUNT, + EndInclusive: livekit.MetricLabel_CLIENT_VIDEO_PUBLISHER_QUALITY_LIMITATION_DURATION_OTHER, + }, + }, + ParticipantIdentity: mr.params.ParticipantIdentity, + }) +} + +func (mr *MetricsReporter) worker() { + reportingIntervalMs := mr.params.Config.ReportingIntervalMs + if reportingIntervalMs == 0 { + reportingIntervalMs = DefaultMetricsReporterConfig.ReportingIntervalMs + } + reportingTicker := time.NewTicker(time.Duration(reportingIntervalMs) * time.Millisecond) + defer reportingTicker.Stop() + + for { + select { + case <-reportingTicker.C: + if mb := mr.getMetricsBatchAndReset(); mb != nil { + mr.params.Consumer.MetricsReporterBatchReady(mb) + } + + case <-mr.stop.Watch(): + return + } + } +} diff --git a/livekit/pkg/routing/errors.go b/livekit/pkg/routing/errors.go new file mode 100644 index 0000000..28a0dd1 --- /dev/null +++ b/livekit/pkg/routing/errors.go @@ -0,0 +1,36 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "errors" +) + +var ( + ErrNotFound = errors.New("could not find object") + ErrIPNotSet = errors.New("ip address is required and not set") + ErrHandlerNotDefined = errors.New("handler not defined") + ErrIncorrectRTCNode = errors.New("current node isn't the RTC node for the room") + ErrNodeNotFound = errors.New("could not locate the node") + ErrNodeLimitReached = errors.New("reached configured limit for node") + ErrInvalidRouterMessage = errors.New("invalid router message") + ErrChannelClosed = errors.New("channel closed") + ErrChannelFull = errors.New("channel is full") + + // errors when starting signal connection + ErrRequestChannelClosed = errors.New("request channel closed") + ErrCouldNotMigrateParticipant = errors.New("could not migrate participant") + ErrClientInfoNotSet = errors.New("client info not set") +) diff --git a/livekit/pkg/routing/interfaces.go b/livekit/pkg/routing/interfaces.go new file mode 100644 index 0000000..6b2e1b6 --- /dev/null +++ b/livekit/pkg/routing/interfaces.go @@ -0,0 +1,318 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "context" + "encoding/json" + + "github.com/redis/go-redis/v9" + "go.uber.org/atomic" + "go.uber.org/zap/zapcore" + "google.golang.org/protobuf/proto" + + "github.com/livekit/livekit-server/pkg/config" + "github.com/livekit/livekit-server/pkg/utils" + "github.com/livekit/protocol/auth" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/rpc" +) + +//go:generate go run github.com/maxbrunsfeld/counterfeiter/v6 -generate + +// MessageSink is an abstraction for writing protobuf messages and having them read by a MessageSource, +// potentially on a different node via a transport +// +//counterfeiter:generate . MessageSink +type MessageSink interface { + WriteMessage(msg proto.Message) error + IsClosed() bool + Close() + ConnectionID() livekit.ConnectionID +} + +// ---------- + +type NullMessageSink struct { + connID livekit.ConnectionID + isClosed atomic.Bool +} + +func NewNullMessageSink(connID livekit.ConnectionID) *NullMessageSink { + return &NullMessageSink{ + connID: connID, + } +} + +func (n *NullMessageSink) WriteMessage(_msg proto.Message) error { + return nil +} + +func (n *NullMessageSink) IsClosed() bool { + return n.isClosed.Load() +} + +func (n *NullMessageSink) Close() { + n.isClosed.Store(true) +} + +func (n *NullMessageSink) ConnectionID() livekit.ConnectionID { + return n.connID +} + +// ------------------------------------------------ + +//counterfeiter:generate . MessageSource +type MessageSource interface { + // ReadChan exposes a one way channel to make it easier to use with select + ReadChan() <-chan proto.Message + IsClosed() bool + Close() + ConnectionID() livekit.ConnectionID +} + +// ---------- + +type NullMessageSource struct { + connID livekit.ConnectionID + msgChan chan proto.Message + isClosed atomic.Bool +} + +func NewNullMessageSource(connID livekit.ConnectionID) *NullMessageSource { + return &NullMessageSource{ + connID: connID, + msgChan: make(chan proto.Message), + } +} + +func (n *NullMessageSource) ReadChan() <-chan proto.Message { + return n.msgChan +} + +func (n *NullMessageSource) IsClosed() bool { + return n.isClosed.Load() +} + +func (n *NullMessageSource) Close() { + if !n.isClosed.Swap(true) { + close(n.msgChan) + } +} + +func (n *NullMessageSource) ConnectionID() livekit.ConnectionID { + return n.connID +} + +// ------------------------------------------------ + +// Router allows multiple nodes to coordinate the participant session +// +//counterfeiter:generate . Router +type Router interface { + MessageRouter + + RegisterNode() error + UnregisterNode() error + RemoveDeadNodes() error + + ListNodes() ([]*livekit.Node, error) + + GetNodeForRoom(ctx context.Context, roomName livekit.RoomName) (*livekit.Node, error) + SetNodeForRoom(ctx context.Context, roomName livekit.RoomName, nodeId livekit.NodeID) error + ClearRoomState(ctx context.Context, roomName livekit.RoomName) error + + GetRegion() string + + Start() error + Drain() + Stop() +} + +type StartParticipantSignalResults struct { + ConnectionID livekit.ConnectionID + RequestSink MessageSink + ResponseSource MessageSource + NodeID livekit.NodeID + NodeSelectionReason string +} + +type MessageRouter interface { + // CreateRoom starts an rtc room + CreateRoom(ctx context.Context, req *livekit.CreateRoomRequest) (res *livekit.Room, err error) + + // StartParticipantSignal participant signal connection is ready to start + StartParticipantSignal( + ctx context.Context, + roomName livekit.RoomName, + pi ParticipantInit, + ) (res StartParticipantSignalResults, err error) +} + +func CreateRouter( + rc redis.UniversalClient, + node LocalNode, + signalClient SignalClient, + roomManagerClient RoomManagerClient, + kps rpc.KeepalivePubSub, + nodeStatsConfig config.NodeStatsConfig, +) Router { + lr := NewLocalRouter(node, signalClient, roomManagerClient, nodeStatsConfig) + + if rc != nil { + return NewRedisRouter(lr, rc, kps) + } + + // local routing and store + logger.Infow("using single-node routing") + return lr +} + +// ------------------------------------------------ + +type ParticipantInit struct { + Identity livekit.ParticipantIdentity + Name livekit.ParticipantName + Reconnect bool + ReconnectReason livekit.ReconnectReason + AutoSubscribe bool + AutoSubscribeDataTrack *bool + Client *livekit.ClientInfo + Grants *auth.ClaimGrants + Region string + AdaptiveStream bool + ID livekit.ParticipantID + SubscriberAllowPause *bool + DisableICELite bool + CreateRoom *livekit.CreateRoomRequest + AddTrackRequests []*livekit.AddTrackRequest + PublisherOffer *livekit.SessionDescription + SyncState *livekit.SyncState + UseSinglePeerConnection bool +} + +func (pi *ParticipantInit) MarshalLogObject(e zapcore.ObjectEncoder) error { + if pi == nil { + return nil + } + + logBoolPtr := func(prop string, val *bool) { + if val == nil { + e.AddString(prop, "not-set") + } else { + e.AddBool(prop, *val) + } + } + + e.AddString("Identity", string(pi.Identity)) + logBoolPtr("Reconnect", &pi.Reconnect) + e.AddString("ReconnectReason", pi.ReconnectReason.String()) + logBoolPtr("AutoSubscribe", &pi.AutoSubscribe) + logBoolPtr("AutoSubscribeDataTrack", pi.AutoSubscribeDataTrack) + e.AddObject("Client", logger.Proto(utils.ClientInfoWithoutAddress(pi.Client))) + e.AddObject("Grants", pi.Grants) + e.AddString("Region", pi.Region) + logBoolPtr("AdaptiveStream", &pi.AdaptiveStream) + e.AddString("ID", string(pi.ID)) + logBoolPtr("SubscriberAllowPause", pi.SubscriberAllowPause) + logBoolPtr("DisableICELite", &pi.DisableICELite) + e.AddObject("CreateRoom", logger.Proto(pi.CreateRoom)) + e.AddArray("AddTrackRequests", logger.ProtoSlice(pi.AddTrackRequests)) + e.AddObject("PublisherOffer", logger.Proto(pi.PublisherOffer)) + e.AddObject("SyncState", logger.Proto(pi.SyncState)) + logBoolPtr("UseSinglePeerConnection", &pi.UseSinglePeerConnection) + return nil +} + +func (pi *ParticipantInit) ToStartSession(roomName livekit.RoomName, connectionID livekit.ConnectionID) (*livekit.StartSession, error) { + claims, err := json.Marshal(pi.Grants) + if err != nil { + return nil, err + } + + ss := &livekit.StartSession{ + RoomName: string(roomName), + Identity: string(pi.Identity), + Name: string(pi.Name), + ConnectionId: string(connectionID), + Reconnect: pi.Reconnect, + ReconnectReason: pi.ReconnectReason, + AutoSubscribe: pi.AutoSubscribe, + Client: pi.Client, + GrantsJson: string(claims), + AdaptiveStream: pi.AdaptiveStream, + ParticipantId: string(pi.ID), + DisableIceLite: pi.DisableICELite, + CreateRoom: pi.CreateRoom, + AddTrackRequests: pi.AddTrackRequests, + PublisherOffer: pi.PublisherOffer, + SyncState: pi.SyncState, + UseSinglePeerConnection: pi.UseSinglePeerConnection, + } + if pi.AutoSubscribeDataTrack != nil { + autoSubscribeDataTrack := *pi.AutoSubscribeDataTrack + ss.AutoSubscribeDataTrack = &autoSubscribeDataTrack + } + if pi.SubscriberAllowPause != nil { + subscriberAllowPause := *pi.SubscriberAllowPause + ss.SubscriberAllowPause = &subscriberAllowPause + } + + return ss, nil +} + +func ParticipantInitFromStartSession(ss *livekit.StartSession, region string) (*ParticipantInit, error) { + claims := &auth.ClaimGrants{} + if err := json.Unmarshal([]byte(ss.GrantsJson), claims); err != nil { + return nil, err + } + + pi := &ParticipantInit{ + Identity: livekit.ParticipantIdentity(ss.Identity), + Name: livekit.ParticipantName(ss.Name), + Reconnect: ss.Reconnect, + ReconnectReason: ss.ReconnectReason, + Client: ss.Client, + AutoSubscribe: ss.AutoSubscribe, + Grants: claims, + Region: region, + AdaptiveStream: ss.AdaptiveStream, + ID: livekit.ParticipantID(ss.ParticipantId), + DisableICELite: ss.DisableIceLite, + CreateRoom: ss.CreateRoom, + AddTrackRequests: ss.AddTrackRequests, + PublisherOffer: ss.PublisherOffer, + SyncState: ss.SyncState, + UseSinglePeerConnection: ss.UseSinglePeerConnection, + } + if ss.AutoSubscribeDataTrack != nil { + autoSubscribeDataTrack := *ss.AutoSubscribeDataTrack + pi.AutoSubscribeDataTrack = &autoSubscribeDataTrack + } + if ss.SubscriberAllowPause != nil { + subscriberAllowPause := *ss.SubscriberAllowPause + pi.SubscriberAllowPause = &subscriberAllowPause + } + + // TODO: clean up after 1.7 eol + if pi.CreateRoom == nil { + pi.CreateRoom = &livekit.CreateRoomRequest{ + Name: ss.RoomName, + } + } + + return pi, nil +} diff --git a/livekit/pkg/routing/localrouter.go b/livekit/pkg/routing/localrouter.go new file mode 100644 index 0000000..b3e2db1 --- /dev/null +++ b/livekit/pkg/routing/localrouter.go @@ -0,0 +1,173 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "context" + "time" + + "go.uber.org/atomic" + + "github.com/livekit/livekit-server/pkg/config" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" +) + +var _ Router = (*LocalRouter)(nil) + +// a router of messages on the same node, basic implementation for local testing +type LocalRouter struct { + currentNode LocalNode + signalClient SignalClient + roomManagerClient RoomManagerClient + nodeStatsConfig config.NodeStatsConfig + + // channels for each participant + requestChannels map[string]*MessageChannel + responseChannels map[string]*MessageChannel + isStarted atomic.Bool +} + +func NewLocalRouter( + currentNode LocalNode, + signalClient SignalClient, + roomManagerClient RoomManagerClient, + nodeStatsConfig config.NodeStatsConfig, +) *LocalRouter { + return &LocalRouter{ + currentNode: currentNode, + signalClient: signalClient, + roomManagerClient: roomManagerClient, + nodeStatsConfig: nodeStatsConfig, + requestChannels: make(map[string]*MessageChannel), + responseChannels: make(map[string]*MessageChannel), + } +} + +func (r *LocalRouter) GetNodeForRoom(_ context.Context, _ livekit.RoomName) (*livekit.Node, error) { + return r.currentNode.Clone(), nil +} + +func (r *LocalRouter) SetNodeForRoom(_ context.Context, _ livekit.RoomName, _ livekit.NodeID) error { + return nil +} + +func (r *LocalRouter) ClearRoomState(_ context.Context, _ livekit.RoomName) error { + return nil +} + +func (r *LocalRouter) RegisterNode() error { + return nil +} + +func (r *LocalRouter) UnregisterNode() error { + return nil +} + +func (r *LocalRouter) RemoveDeadNodes() error { + return nil +} + +func (r *LocalRouter) GetNode(nodeID livekit.NodeID) (*livekit.Node, error) { + if nodeID == r.currentNode.NodeID() { + return r.currentNode.Clone(), nil + } + return nil, ErrNotFound +} + +func (r *LocalRouter) ListNodes() ([]*livekit.Node, error) { + return []*livekit.Node{ + r.currentNode.Clone(), + }, nil +} + +func (r *LocalRouter) CreateRoom(ctx context.Context, req *livekit.CreateRoomRequest) (res *livekit.Room, err error) { + return r.CreateRoomWithNodeID(ctx, req, r.currentNode.NodeID()) +} + +func (r *LocalRouter) CreateRoomWithNodeID(ctx context.Context, req *livekit.CreateRoomRequest, nodeID livekit.NodeID) (res *livekit.Room, err error) { + return r.roomManagerClient.CreateRoom(ctx, nodeID, req) +} + +func (r *LocalRouter) StartParticipantSignal(ctx context.Context, roomName livekit.RoomName, pi ParticipantInit) (res StartParticipantSignalResults, err error) { + return r.StartParticipantSignalWithNodeID(ctx, roomName, pi, r.currentNode.NodeID()) +} + +func (r *LocalRouter) StartParticipantSignalWithNodeID(ctx context.Context, roomName livekit.RoomName, pi ParticipantInit, nodeID livekit.NodeID) (res StartParticipantSignalResults, err error) { + connectionID, reqSink, resSource, err := r.signalClient.StartParticipantSignal(ctx, roomName, pi, nodeID) + if err != nil { + logger.Errorw( + "could not handle new participant", err, + "room", roomName, + "participant", pi.Identity, + "connID", connectionID, + ) + } else { + return StartParticipantSignalResults{ + ConnectionID: connectionID, + RequestSink: reqSink, + ResponseSource: resSource, + NodeID: nodeID, + }, nil + } + return +} + +func (r *LocalRouter) Start() error { + if r.isStarted.Swap(true) { + return nil + } + go r.statsWorker() + // go r.memStatsWorker() + return nil +} + +func (r *LocalRouter) Drain() { + r.currentNode.SetState(livekit.NodeState_SHUTTING_DOWN) +} + +func (r *LocalRouter) Stop() {} + +func (r *LocalRouter) GetRegion() string { + return r.currentNode.Region() +} + +func (r *LocalRouter) statsWorker() { + for { + if !r.isStarted.Load() { + return + } + <-time.After(r.nodeStatsConfig.StatsUpdateInterval) + r.currentNode.UpdateNodeStats() + } +} + +/* + func (r *LocalRouter) memStatsWorker() { + ticker := time.NewTicker(time.Second * 30) + defer ticker.Stop() + + for { + <-ticker.C + + var m runtime.MemStats + runtime.ReadMemStats(&m) + logger.Infow("memstats", + "mallocs", m.Mallocs, "frees", m.Frees, "m-f", m.Mallocs-m.Frees, + "hinuse", m.HeapInuse, "halloc", m.HeapAlloc, "frag", m.HeapInuse-m.HeapAlloc, + ) + } + } +*/ diff --git a/livekit/pkg/routing/messagechannel.go b/livekit/pkg/routing/messagechannel.go new file mode 100644 index 0000000..e761f3a --- /dev/null +++ b/livekit/pkg/routing/messagechannel.go @@ -0,0 +1,94 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "sync" + + "github.com/livekit/protocol/livekit" + "google.golang.org/protobuf/proto" +) + +const DefaultMessageChannelSize = 200 + +type MessageChannel struct { + connectionID livekit.ConnectionID + msgChan chan proto.Message + onClose func() + isClosed bool + lock sync.RWMutex +} + +func NewDefaultMessageChannel(connectionID livekit.ConnectionID) *MessageChannel { + return NewMessageChannel(connectionID, DefaultMessageChannelSize) +} + +func NewMessageChannel(connectionID livekit.ConnectionID, size int) *MessageChannel { + return &MessageChannel{ + connectionID: connectionID, + // allow some buffer to avoid blocked writes + msgChan: make(chan proto.Message, size), + } +} + +func (m *MessageChannel) OnClose(f func()) { + m.onClose = f +} + +func (m *MessageChannel) IsClosed() bool { + m.lock.RLock() + defer m.lock.RUnlock() + return m.isClosed +} + +func (m *MessageChannel) WriteMessage(msg proto.Message) error { + m.lock.RLock() + defer m.lock.RUnlock() + if m.isClosed { + return ErrChannelClosed + } + + select { + case m.msgChan <- msg: + // published + return nil + default: + // channel is full + return ErrChannelFull + } +} + +func (m *MessageChannel) ReadChan() <-chan proto.Message { + return m.msgChan +} + +func (m *MessageChannel) Close() { + m.lock.Lock() + if m.isClosed { + m.lock.Unlock() + return + } + m.isClosed = true + close(m.msgChan) + m.lock.Unlock() + + if m.onClose != nil { + m.onClose() + } +} + +func (m *MessageChannel) ConnectionID() livekit.ConnectionID { + return m.connectionID +} diff --git a/livekit/pkg/routing/messagechannel_test.go b/livekit/pkg/routing/messagechannel_test.go new file mode 100644 index 0000000..ab0ab81 --- /dev/null +++ b/livekit/pkg/routing/messagechannel_test.go @@ -0,0 +1,50 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing_test + +import ( + "sync" + "testing" + + "github.com/livekit/protocol/livekit" + + "github.com/livekit/livekit-server/pkg/routing" +) + +func TestMessageChannel_WriteMessageClosed(t *testing.T) { + // ensure it doesn't panic when written to after closing + m := routing.NewMessageChannel(livekit.ConnectionID("test"), routing.DefaultMessageChannelSize) + go func() { + for msg := range m.ReadChan() { + if msg == nil { + return + } + } + }() + + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + for range 100 { + _ = m.WriteMessage(&livekit.SignalRequest{}) + } + }() + _ = m.WriteMessage(&livekit.SignalRequest{}) + m.Close() + _ = m.WriteMessage(&livekit.SignalRequest{}) + + wg.Wait() +} diff --git a/livekit/pkg/routing/node.go b/livekit/pkg/routing/node.go new file mode 100644 index 0000000..183faff --- /dev/null +++ b/livekit/pkg/routing/node.go @@ -0,0 +1,158 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "runtime" + "sync" + "time" + + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/utils" + "github.com/livekit/protocol/utils/guid" + + "github.com/livekit/livekit-server/pkg/config" +) + +type LocalNode interface { + Clone() *livekit.Node + SetNodeID(nodeID livekit.NodeID) + NodeID() livekit.NodeID + NodeType() livekit.NodeType + NodeIP() string + Region() string + SetState(state livekit.NodeState) + SetStats(stats *livekit.NodeStats) + UpdateNodeStats() bool + SecondsSinceNodeStatsUpdate() float64 +} + +type LocalNodeImpl struct { + lock sync.RWMutex + node *livekit.Node + + nodeStats *NodeStats +} + +func NewLocalNode(conf *config.Config) (*LocalNodeImpl, error) { + nodeID := guid.New(utils.NodePrefix) + if conf != nil && conf.RTC.NodeIP == "" { + return nil, ErrIPNotSet + } + nowUnix := time.Now().Unix() + l := &LocalNodeImpl{ + node: &livekit.Node{ + Id: nodeID, + NumCpus: uint32(runtime.NumCPU()), + State: livekit.NodeState_SERVING, + Stats: &livekit.NodeStats{ + StartedAt: nowUnix, + UpdatedAt: nowUnix, + }, + }, + } + var nsc *config.NodeStatsConfig + if conf != nil { + l.node.Ip = conf.RTC.NodeIP + l.node.Region = conf.Region + + nsc = &conf.NodeStats + } + l.nodeStats = NewNodeStats(nsc, nowUnix) + + return l, nil +} + +func NewLocalNodeFromNodeProto(node *livekit.Node) (*LocalNodeImpl, error) { + return &LocalNodeImpl{node: utils.CloneProto(node)}, nil +} + +func (l *LocalNodeImpl) Clone() *livekit.Node { + l.lock.RLock() + defer l.lock.RUnlock() + + return utils.CloneProto(l.node) +} + +// for testing only +func (l *LocalNodeImpl) SetNodeID(nodeID livekit.NodeID) { + l.lock.Lock() + defer l.lock.Unlock() + + l.node.Id = string(nodeID) +} + +func (l *LocalNodeImpl) NodeID() livekit.NodeID { + l.lock.RLock() + defer l.lock.RUnlock() + + return livekit.NodeID(l.node.Id) +} + +func (l *LocalNodeImpl) NodeType() livekit.NodeType { + l.lock.RLock() + defer l.lock.RUnlock() + + return l.node.Type +} + +func (l *LocalNodeImpl) NodeIP() string { + l.lock.RLock() + defer l.lock.RUnlock() + + return l.node.Ip +} + +func (l *LocalNodeImpl) Region() string { + l.lock.RLock() + defer l.lock.RUnlock() + + return l.node.Region +} + +func (l *LocalNodeImpl) SetState(state livekit.NodeState) { + l.lock.Lock() + defer l.lock.Unlock() + + l.node.State = state +} + +// for testing only +func (l *LocalNodeImpl) SetStats(stats *livekit.NodeStats) { + l.lock.Lock() + defer l.lock.Unlock() + + l.node.Stats = utils.CloneProto(stats) +} + +func (l *LocalNodeImpl) UpdateNodeStats() bool { + l.lock.Lock() + defer l.lock.Unlock() + + stats, err := l.nodeStats.UpdateAndGetNodeStats() + if err != nil { + return false + } + + l.node.Stats = stats + return true +} + +func (l *LocalNodeImpl) SecondsSinceNodeStatsUpdate() float64 { + l.lock.RLock() + defer l.lock.RUnlock() + + return time.Since(time.Unix(l.node.Stats.UpdatedAt, 0)).Seconds() +} diff --git a/livekit/pkg/routing/nodestats.go b/livekit/pkg/routing/nodestats.go new file mode 100644 index 0000000..999ae9c --- /dev/null +++ b/livekit/pkg/routing/nodestats.go @@ -0,0 +1,82 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "sync" + "time" + + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + + "github.com/livekit/livekit-server/pkg/config" + "github.com/livekit/livekit-server/pkg/telemetry/prometheus" +) + +type NodeStats struct { + config config.NodeStatsConfig + startedAt int64 + + lock sync.Mutex + statsHistory []*livekit.NodeStats + statsHistoryWritePtr int +} + +func NewNodeStats(conf *config.NodeStatsConfig, startedAt int64) *NodeStats { + n := &NodeStats{ + startedAt: startedAt, + } + n.UpdateConfig(conf) + return n +} + +func (n *NodeStats) UpdateConfig(conf *config.NodeStatsConfig) { + n.lock.Lock() + defer n.lock.Unlock() + + if conf == nil { + conf = &config.DefaultNodeStatsConfig + } + n.config = *conf + + // set up stats history to be able to measure different rate windows + var maxInterval time.Duration + for _, rateInterval := range conf.StatsRateMeasurementIntervals { + if rateInterval > maxInterval { + maxInterval = rateInterval + } + } + n.statsHistory = make([]*livekit.NodeStats, (maxInterval+conf.StatsUpdateInterval-1)/conf.StatsUpdateInterval) + n.statsHistoryWritePtr = 0 +} + +func (n *NodeStats) UpdateAndGetNodeStats() (*livekit.NodeStats, error) { + n.lock.Lock() + defer n.lock.Unlock() + + stats, err := prometheus.GetNodeStats( + n.startedAt, + append(n.statsHistory[n.statsHistoryWritePtr:], n.statsHistory[0:n.statsHistoryWritePtr]...), + n.config.StatsRateMeasurementIntervals, + ) + if err != nil { + logger.Errorw("could not update node stats", err) + return nil, err + } + + n.statsHistory[n.statsHistoryWritePtr] = stats + n.statsHistoryWritePtr = (n.statsHistoryWritePtr + 1) % len(n.statsHistory) + return stats, nil +} diff --git a/livekit/pkg/routing/redisrouter.go b/livekit/pkg/routing/redisrouter.go new file mode 100644 index 0000000..44859b4 --- /dev/null +++ b/livekit/pkg/routing/redisrouter.go @@ -0,0 +1,252 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "bytes" + "context" + "runtime/pprof" + "time" + + "github.com/pkg/errors" + "github.com/redis/go-redis/v9" + "go.uber.org/atomic" + "google.golang.org/protobuf/proto" + + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/rpc" + + "github.com/livekit/livekit-server/pkg/routing/selector" +) + +const ( + // hash of node_id => Node proto + NodesKey = "nodes" + + // hash of room_name => node_id + NodeRoomKey = "room_node_map" +) + +var _ Router = (*RedisRouter)(nil) + +// RedisRouter uses Redis pub/sub to route signaling messages across different nodes +// It relies on the RTC node to be the primary driver of the participant connection. +// Because +type RedisRouter struct { + *LocalRouter + + rc redis.UniversalClient + kps rpc.KeepalivePubSub + ctx context.Context + isStarted atomic.Bool + + cancel func() +} + +func NewRedisRouter(lr *LocalRouter, rc redis.UniversalClient, kps rpc.KeepalivePubSub) *RedisRouter { + rr := &RedisRouter{ + LocalRouter: lr, + rc: rc, + kps: kps, + } + rr.ctx, rr.cancel = context.WithCancel(context.Background()) + return rr +} + +func (r *RedisRouter) RegisterNode() error { + data, err := proto.Marshal(r.currentNode.Clone()) + if err != nil { + return err + } + if err := r.rc.HSet(r.ctx, NodesKey, string(r.currentNode.NodeID()), data).Err(); err != nil { + return errors.Wrap(err, "could not register node") + } + return nil +} + +func (r *RedisRouter) UnregisterNode() error { + // could be called after Stop(), so we'd want to use an unrelated context + return r.rc.HDel(context.Background(), NodesKey, string(r.currentNode.NodeID())).Err() +} + +func (r *RedisRouter) RemoveDeadNodes() error { + nodes, err := r.ListNodes() + if err != nil { + return err + } + for _, n := range nodes { + if !selector.IsAvailable(n) { + if err := r.rc.HDel(context.Background(), NodesKey, n.Id).Err(); err != nil { + return err + } + } + } + return nil +} + +// GetNodeForRoom finds the node where the room is hosted at +func (r *RedisRouter) GetNodeForRoom(_ context.Context, roomName livekit.RoomName) (*livekit.Node, error) { + nodeID, err := r.rc.HGet(r.ctx, NodeRoomKey, string(roomName)).Result() + if err == redis.Nil { + return nil, ErrNotFound + } else if err != nil { + return nil, errors.Wrap(err, "could not get node for room") + } + + return r.GetNode(livekit.NodeID(nodeID)) +} + +func (r *RedisRouter) SetNodeForRoom(_ context.Context, roomName livekit.RoomName, nodeID livekit.NodeID) error { + return r.rc.HSet(r.ctx, NodeRoomKey, string(roomName), string(nodeID)).Err() +} + +func (r *RedisRouter) ClearRoomState(_ context.Context, roomName livekit.RoomName) error { + if err := r.rc.HDel(context.Background(), NodeRoomKey, string(roomName)).Err(); err != nil { + return errors.Wrap(err, "could not clear room state") + } + return nil +} + +func (r *RedisRouter) GetNode(nodeID livekit.NodeID) (*livekit.Node, error) { + data, err := r.rc.HGet(r.ctx, NodesKey, string(nodeID)).Result() + if err == redis.Nil { + return nil, ErrNotFound + } else if err != nil { + return nil, err + } + n := livekit.Node{} + if err = proto.Unmarshal([]byte(data), &n); err != nil { + return nil, err + } + return &n, nil +} + +func (r *RedisRouter) ListNodes() ([]*livekit.Node, error) { + items, err := r.rc.HVals(r.ctx, NodesKey).Result() + if err != nil { + return nil, errors.Wrap(err, "could not list nodes") + } + nodes := make([]*livekit.Node, 0, len(items)) + for _, item := range items { + n := livekit.Node{} + if err := proto.Unmarshal([]byte(item), &n); err != nil { + return nil, err + } + nodes = append(nodes, &n) + } + return nodes, nil +} + +func (r *RedisRouter) CreateRoom(ctx context.Context, req *livekit.CreateRoomRequest) (res *livekit.Room, err error) { + rtcNode, err := r.GetNodeForRoom(ctx, livekit.RoomName(req.Name)) + if err != nil { + return + } + + return r.CreateRoomWithNodeID(ctx, req, livekit.NodeID(rtcNode.Id)) +} + +// StartParticipantSignal signal connection sets up paths to the RTC node, and starts to route messages to that message queue +func (r *RedisRouter) StartParticipantSignal(ctx context.Context, roomName livekit.RoomName, pi ParticipantInit) (res StartParticipantSignalResults, err error) { + rtcNode, err := r.GetNodeForRoom(ctx, roomName) + if err != nil { + return + } + + return r.StartParticipantSignalWithNodeID(ctx, roomName, pi, livekit.NodeID(rtcNode.Id)) +} + +func (r *RedisRouter) Start() error { + if r.isStarted.Swap(true) { + return nil + } + + workerStarted := make(chan error) + go r.statsWorker() + go r.keepaliveWorker(workerStarted) + + // wait until worker is running + return <-workerStarted +} + +func (r *RedisRouter) Drain() { + r.currentNode.SetState(livekit.NodeState_SHUTTING_DOWN) + if err := r.RegisterNode(); err != nil { + logger.Errorw("failed to mark as draining", err, "nodeID", r.currentNode.NodeID()) + } +} + +func (r *RedisRouter) Stop() { + if !r.isStarted.Swap(false) { + return + } + logger.Debugw("stopping RedisRouter") + _ = r.UnregisterNode() + r.cancel() +} + +// update node stats and cleanup +func (r *RedisRouter) statsWorker() { + goroutineDumped := false + for r.ctx.Err() == nil { + // update periodically + select { + case <-time.After(r.nodeStatsConfig.StatsUpdateInterval): + r.kps.PublishPing(r.ctx, r.currentNode.NodeID(), &rpc.KeepalivePing{Timestamp: time.Now().Unix()}) + + delaySeconds := r.currentNode.SecondsSinceNodeStatsUpdate() + if delaySeconds > r.nodeStatsConfig.StatsMaxDelay.Seconds() { + if !goroutineDumped { + goroutineDumped = true + buf := bytes.NewBuffer(nil) + _ = pprof.Lookup("goroutine").WriteTo(buf, 2) + logger.Errorw("status update delayed, possible deadlock", nil, + "delay", delaySeconds, + "goroutines", buf.String()) + } + } else { + goroutineDumped = false + } + case <-r.ctx.Done(): + return + } + } +} + +func (r *RedisRouter) keepaliveWorker(startedChan chan error) { + pings, err := r.kps.SubscribePing(r.ctx, r.currentNode.NodeID()) + if err != nil { + startedChan <- err + return + } + close(startedChan) + + for ping := range pings.Channel() { + if time.Since(time.Unix(ping.Timestamp, 0)) > r.nodeStatsConfig.StatsUpdateInterval { + logger.Infow("keep alive too old, skipping", "timestamp", ping.Timestamp) + continue + } + + if !r.currentNode.UpdateNodeStats() { + continue + } + + // TODO: check stats against config.Limit values + if err := r.RegisterNode(); err != nil { + logger.Errorw("could not update node", err) + } + } +} diff --git a/livekit/pkg/routing/roommanager.go b/livekit/pkg/routing/roommanager.go new file mode 100644 index 0000000..796ffe7 --- /dev/null +++ b/livekit/pkg/routing/roommanager.go @@ -0,0 +1,63 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "context" + + "github.com/livekit/livekit-server/pkg/config" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/rpc" + "github.com/livekit/psrpc" + "github.com/livekit/psrpc/pkg/middleware" +) + +//counterfeiter:generate . RoomManagerClient +type RoomManagerClient interface { + rpc.TypedRoomManagerClient +} + +type roomManagerClient struct { + config config.RoomConfig + client rpc.TypedRoomManagerClient +} + +func NewRoomManagerClient(clientParams rpc.ClientParams, config config.RoomConfig) (RoomManagerClient, error) { + c, err := rpc.NewTypedRoomManagerClient( + clientParams.Bus, + psrpc.WithClientChannelSize(clientParams.BufferSize), + middleware.WithClientMetrics(clientParams.Observer), + rpc.WithClientLogger(clientParams.Logger), + ) + if err != nil { + return nil, err + } + + return &roomManagerClient{ + config: config, + client: c, + }, nil +} + +func (c *roomManagerClient) CreateRoom(ctx context.Context, nodeID livekit.NodeID, req *livekit.CreateRoomRequest, opts ...psrpc.RequestOption) (*livekit.Room, error) { + return c.client.CreateRoom(ctx, nodeID, req, append(opts, psrpc.WithRequestInterceptors(middleware.NewRPCRetryInterceptor(middleware.RetryOptions{ + MaxAttempts: c.config.CreateRoomAttempts, + Timeout: c.config.CreateRoomTimeout, + })))...) +} + +func (c *roomManagerClient) Close() { + c.client.Close() +} diff --git a/livekit/pkg/routing/routingfakes/fake_message_sink.go b/livekit/pkg/routing/routingfakes/fake_message_sink.go new file mode 100644 index 0000000..5069f8b --- /dev/null +++ b/livekit/pkg/routing/routingfakes/fake_message_sink.go @@ -0,0 +1,265 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package routingfakes + +import ( + "sync" + + "github.com/livekit/livekit-server/pkg/routing" + "github.com/livekit/protocol/livekit" + "google.golang.org/protobuf/proto" +) + +type FakeMessageSink struct { + CloseStub func() + closeMutex sync.RWMutex + closeArgsForCall []struct { + } + ConnectionIDStub func() livekit.ConnectionID + connectionIDMutex sync.RWMutex + connectionIDArgsForCall []struct { + } + connectionIDReturns struct { + result1 livekit.ConnectionID + } + connectionIDReturnsOnCall map[int]struct { + result1 livekit.ConnectionID + } + IsClosedStub func() bool + isClosedMutex sync.RWMutex + isClosedArgsForCall []struct { + } + isClosedReturns struct { + result1 bool + } + isClosedReturnsOnCall map[int]struct { + result1 bool + } + WriteMessageStub func(proto.Message) error + writeMessageMutex sync.RWMutex + writeMessageArgsForCall []struct { + arg1 proto.Message + } + writeMessageReturns struct { + result1 error + } + writeMessageReturnsOnCall map[int]struct { + result1 error + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeMessageSink) Close() { + fake.closeMutex.Lock() + fake.closeArgsForCall = append(fake.closeArgsForCall, struct { + }{}) + stub := fake.CloseStub + fake.recordInvocation("Close", []interface{}{}) + fake.closeMutex.Unlock() + if stub != nil { + fake.CloseStub() + } +} + +func (fake *FakeMessageSink) CloseCallCount() int { + fake.closeMutex.RLock() + defer fake.closeMutex.RUnlock() + return len(fake.closeArgsForCall) +} + +func (fake *FakeMessageSink) CloseCalls(stub func()) { + fake.closeMutex.Lock() + defer fake.closeMutex.Unlock() + fake.CloseStub = stub +} + +func (fake *FakeMessageSink) ConnectionID() livekit.ConnectionID { + fake.connectionIDMutex.Lock() + ret, specificReturn := fake.connectionIDReturnsOnCall[len(fake.connectionIDArgsForCall)] + fake.connectionIDArgsForCall = append(fake.connectionIDArgsForCall, struct { + }{}) + stub := fake.ConnectionIDStub + fakeReturns := fake.connectionIDReturns + fake.recordInvocation("ConnectionID", []interface{}{}) + fake.connectionIDMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeMessageSink) ConnectionIDCallCount() int { + fake.connectionIDMutex.RLock() + defer fake.connectionIDMutex.RUnlock() + return len(fake.connectionIDArgsForCall) +} + +func (fake *FakeMessageSink) ConnectionIDCalls(stub func() livekit.ConnectionID) { + fake.connectionIDMutex.Lock() + defer fake.connectionIDMutex.Unlock() + fake.ConnectionIDStub = stub +} + +func (fake *FakeMessageSink) ConnectionIDReturns(result1 livekit.ConnectionID) { + fake.connectionIDMutex.Lock() + defer fake.connectionIDMutex.Unlock() + fake.ConnectionIDStub = nil + fake.connectionIDReturns = struct { + result1 livekit.ConnectionID + }{result1} +} + +func (fake *FakeMessageSink) ConnectionIDReturnsOnCall(i int, result1 livekit.ConnectionID) { + fake.connectionIDMutex.Lock() + defer fake.connectionIDMutex.Unlock() + fake.ConnectionIDStub = nil + if fake.connectionIDReturnsOnCall == nil { + fake.connectionIDReturnsOnCall = make(map[int]struct { + result1 livekit.ConnectionID + }) + } + fake.connectionIDReturnsOnCall[i] = struct { + result1 livekit.ConnectionID + }{result1} +} + +func (fake *FakeMessageSink) IsClosed() bool { + fake.isClosedMutex.Lock() + ret, specificReturn := fake.isClosedReturnsOnCall[len(fake.isClosedArgsForCall)] + fake.isClosedArgsForCall = append(fake.isClosedArgsForCall, struct { + }{}) + stub := fake.IsClosedStub + fakeReturns := fake.isClosedReturns + fake.recordInvocation("IsClosed", []interface{}{}) + fake.isClosedMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeMessageSink) IsClosedCallCount() int { + fake.isClosedMutex.RLock() + defer fake.isClosedMutex.RUnlock() + return len(fake.isClosedArgsForCall) +} + +func (fake *FakeMessageSink) IsClosedCalls(stub func() bool) { + fake.isClosedMutex.Lock() + defer fake.isClosedMutex.Unlock() + fake.IsClosedStub = stub +} + +func (fake *FakeMessageSink) IsClosedReturns(result1 bool) { + fake.isClosedMutex.Lock() + defer fake.isClosedMutex.Unlock() + fake.IsClosedStub = nil + fake.isClosedReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeMessageSink) IsClosedReturnsOnCall(i int, result1 bool) { + fake.isClosedMutex.Lock() + defer fake.isClosedMutex.Unlock() + fake.IsClosedStub = nil + if fake.isClosedReturnsOnCall == nil { + fake.isClosedReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.isClosedReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + +func (fake *FakeMessageSink) WriteMessage(arg1 proto.Message) error { + fake.writeMessageMutex.Lock() + ret, specificReturn := fake.writeMessageReturnsOnCall[len(fake.writeMessageArgsForCall)] + fake.writeMessageArgsForCall = append(fake.writeMessageArgsForCall, struct { + arg1 proto.Message + }{arg1}) + stub := fake.WriteMessageStub + fakeReturns := fake.writeMessageReturns + fake.recordInvocation("WriteMessage", []interface{}{arg1}) + fake.writeMessageMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeMessageSink) WriteMessageCallCount() int { + fake.writeMessageMutex.RLock() + defer fake.writeMessageMutex.RUnlock() + return len(fake.writeMessageArgsForCall) +} + +func (fake *FakeMessageSink) WriteMessageCalls(stub func(proto.Message) error) { + fake.writeMessageMutex.Lock() + defer fake.writeMessageMutex.Unlock() + fake.WriteMessageStub = stub +} + +func (fake *FakeMessageSink) WriteMessageArgsForCall(i int) proto.Message { + fake.writeMessageMutex.RLock() + defer fake.writeMessageMutex.RUnlock() + argsForCall := fake.writeMessageArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeMessageSink) WriteMessageReturns(result1 error) { + fake.writeMessageMutex.Lock() + defer fake.writeMessageMutex.Unlock() + fake.WriteMessageStub = nil + fake.writeMessageReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeMessageSink) WriteMessageReturnsOnCall(i int, result1 error) { + fake.writeMessageMutex.Lock() + defer fake.writeMessageMutex.Unlock() + fake.WriteMessageStub = nil + if fake.writeMessageReturnsOnCall == nil { + fake.writeMessageReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.writeMessageReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeMessageSink) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeMessageSink) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ routing.MessageSink = new(FakeMessageSink) diff --git a/livekit/pkg/routing/routingfakes/fake_message_source.go b/livekit/pkg/routing/routingfakes/fake_message_source.go new file mode 100644 index 0000000..37eec2f --- /dev/null +++ b/livekit/pkg/routing/routingfakes/fake_message_source.go @@ -0,0 +1,256 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package routingfakes + +import ( + "sync" + + "github.com/livekit/livekit-server/pkg/routing" + "github.com/livekit/protocol/livekit" + "google.golang.org/protobuf/proto" +) + +type FakeMessageSource struct { + CloseStub func() + closeMutex sync.RWMutex + closeArgsForCall []struct { + } + ConnectionIDStub func() livekit.ConnectionID + connectionIDMutex sync.RWMutex + connectionIDArgsForCall []struct { + } + connectionIDReturns struct { + result1 livekit.ConnectionID + } + connectionIDReturnsOnCall map[int]struct { + result1 livekit.ConnectionID + } + IsClosedStub func() bool + isClosedMutex sync.RWMutex + isClosedArgsForCall []struct { + } + isClosedReturns struct { + result1 bool + } + isClosedReturnsOnCall map[int]struct { + result1 bool + } + ReadChanStub func() <-chan proto.Message + readChanMutex sync.RWMutex + readChanArgsForCall []struct { + } + readChanReturns struct { + result1 <-chan proto.Message + } + readChanReturnsOnCall map[int]struct { + result1 <-chan proto.Message + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeMessageSource) Close() { + fake.closeMutex.Lock() + fake.closeArgsForCall = append(fake.closeArgsForCall, struct { + }{}) + stub := fake.CloseStub + fake.recordInvocation("Close", []interface{}{}) + fake.closeMutex.Unlock() + if stub != nil { + fake.CloseStub() + } +} + +func (fake *FakeMessageSource) CloseCallCount() int { + fake.closeMutex.RLock() + defer fake.closeMutex.RUnlock() + return len(fake.closeArgsForCall) +} + +func (fake *FakeMessageSource) CloseCalls(stub func()) { + fake.closeMutex.Lock() + defer fake.closeMutex.Unlock() + fake.CloseStub = stub +} + +func (fake *FakeMessageSource) ConnectionID() livekit.ConnectionID { + fake.connectionIDMutex.Lock() + ret, specificReturn := fake.connectionIDReturnsOnCall[len(fake.connectionIDArgsForCall)] + fake.connectionIDArgsForCall = append(fake.connectionIDArgsForCall, struct { + }{}) + stub := fake.ConnectionIDStub + fakeReturns := fake.connectionIDReturns + fake.recordInvocation("ConnectionID", []interface{}{}) + fake.connectionIDMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeMessageSource) ConnectionIDCallCount() int { + fake.connectionIDMutex.RLock() + defer fake.connectionIDMutex.RUnlock() + return len(fake.connectionIDArgsForCall) +} + +func (fake *FakeMessageSource) ConnectionIDCalls(stub func() livekit.ConnectionID) { + fake.connectionIDMutex.Lock() + defer fake.connectionIDMutex.Unlock() + fake.ConnectionIDStub = stub +} + +func (fake *FakeMessageSource) ConnectionIDReturns(result1 livekit.ConnectionID) { + fake.connectionIDMutex.Lock() + defer fake.connectionIDMutex.Unlock() + fake.ConnectionIDStub = nil + fake.connectionIDReturns = struct { + result1 livekit.ConnectionID + }{result1} +} + +func (fake *FakeMessageSource) ConnectionIDReturnsOnCall(i int, result1 livekit.ConnectionID) { + fake.connectionIDMutex.Lock() + defer fake.connectionIDMutex.Unlock() + fake.ConnectionIDStub = nil + if fake.connectionIDReturnsOnCall == nil { + fake.connectionIDReturnsOnCall = make(map[int]struct { + result1 livekit.ConnectionID + }) + } + fake.connectionIDReturnsOnCall[i] = struct { + result1 livekit.ConnectionID + }{result1} +} + +func (fake *FakeMessageSource) IsClosed() bool { + fake.isClosedMutex.Lock() + ret, specificReturn := fake.isClosedReturnsOnCall[len(fake.isClosedArgsForCall)] + fake.isClosedArgsForCall = append(fake.isClosedArgsForCall, struct { + }{}) + stub := fake.IsClosedStub + fakeReturns := fake.isClosedReturns + fake.recordInvocation("IsClosed", []interface{}{}) + fake.isClosedMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeMessageSource) IsClosedCallCount() int { + fake.isClosedMutex.RLock() + defer fake.isClosedMutex.RUnlock() + return len(fake.isClosedArgsForCall) +} + +func (fake *FakeMessageSource) IsClosedCalls(stub func() bool) { + fake.isClosedMutex.Lock() + defer fake.isClosedMutex.Unlock() + fake.IsClosedStub = stub +} + +func (fake *FakeMessageSource) IsClosedReturns(result1 bool) { + fake.isClosedMutex.Lock() + defer fake.isClosedMutex.Unlock() + fake.IsClosedStub = nil + fake.isClosedReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeMessageSource) IsClosedReturnsOnCall(i int, result1 bool) { + fake.isClosedMutex.Lock() + defer fake.isClosedMutex.Unlock() + fake.IsClosedStub = nil + if fake.isClosedReturnsOnCall == nil { + fake.isClosedReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.isClosedReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + +func (fake *FakeMessageSource) ReadChan() <-chan proto.Message { + fake.readChanMutex.Lock() + ret, specificReturn := fake.readChanReturnsOnCall[len(fake.readChanArgsForCall)] + fake.readChanArgsForCall = append(fake.readChanArgsForCall, struct { + }{}) + stub := fake.ReadChanStub + fakeReturns := fake.readChanReturns + fake.recordInvocation("ReadChan", []interface{}{}) + fake.readChanMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeMessageSource) ReadChanCallCount() int { + fake.readChanMutex.RLock() + defer fake.readChanMutex.RUnlock() + return len(fake.readChanArgsForCall) +} + +func (fake *FakeMessageSource) ReadChanCalls(stub func() <-chan proto.Message) { + fake.readChanMutex.Lock() + defer fake.readChanMutex.Unlock() + fake.ReadChanStub = stub +} + +func (fake *FakeMessageSource) ReadChanReturns(result1 <-chan proto.Message) { + fake.readChanMutex.Lock() + defer fake.readChanMutex.Unlock() + fake.ReadChanStub = nil + fake.readChanReturns = struct { + result1 <-chan proto.Message + }{result1} +} + +func (fake *FakeMessageSource) ReadChanReturnsOnCall(i int, result1 <-chan proto.Message) { + fake.readChanMutex.Lock() + defer fake.readChanMutex.Unlock() + fake.ReadChanStub = nil + if fake.readChanReturnsOnCall == nil { + fake.readChanReturnsOnCall = make(map[int]struct { + result1 <-chan proto.Message + }) + } + fake.readChanReturnsOnCall[i] = struct { + result1 <-chan proto.Message + }{result1} +} + +func (fake *FakeMessageSource) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeMessageSource) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ routing.MessageSource = new(FakeMessageSource) diff --git a/livekit/pkg/routing/routingfakes/fake_room_manager_client.go b/livekit/pkg/routing/routingfakes/fake_room_manager_client.go new file mode 100644 index 0000000..6e7e652 --- /dev/null +++ b/livekit/pkg/routing/routingfakes/fake_room_manager_client.go @@ -0,0 +1,151 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package routingfakes + +import ( + "context" + "sync" + + "github.com/livekit/livekit-server/pkg/routing" + "github.com/livekit/protocol/livekit" + "github.com/livekit/psrpc" +) + +type FakeRoomManagerClient struct { + CloseStub func() + closeMutex sync.RWMutex + closeArgsForCall []struct { + } + CreateRoomStub func(context.Context, livekit.NodeID, *livekit.CreateRoomRequest, ...psrpc.RequestOption) (*livekit.Room, error) + createRoomMutex sync.RWMutex + createRoomArgsForCall []struct { + arg1 context.Context + arg2 livekit.NodeID + arg3 *livekit.CreateRoomRequest + arg4 []psrpc.RequestOption + } + createRoomReturns struct { + result1 *livekit.Room + result2 error + } + createRoomReturnsOnCall map[int]struct { + result1 *livekit.Room + result2 error + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeRoomManagerClient) Close() { + fake.closeMutex.Lock() + fake.closeArgsForCall = append(fake.closeArgsForCall, struct { + }{}) + stub := fake.CloseStub + fake.recordInvocation("Close", []interface{}{}) + fake.closeMutex.Unlock() + if stub != nil { + fake.CloseStub() + } +} + +func (fake *FakeRoomManagerClient) CloseCallCount() int { + fake.closeMutex.RLock() + defer fake.closeMutex.RUnlock() + return len(fake.closeArgsForCall) +} + +func (fake *FakeRoomManagerClient) CloseCalls(stub func()) { + fake.closeMutex.Lock() + defer fake.closeMutex.Unlock() + fake.CloseStub = stub +} + +func (fake *FakeRoomManagerClient) CreateRoom(arg1 context.Context, arg2 livekit.NodeID, arg3 *livekit.CreateRoomRequest, arg4 ...psrpc.RequestOption) (*livekit.Room, error) { + fake.createRoomMutex.Lock() + ret, specificReturn := fake.createRoomReturnsOnCall[len(fake.createRoomArgsForCall)] + fake.createRoomArgsForCall = append(fake.createRoomArgsForCall, struct { + arg1 context.Context + arg2 livekit.NodeID + arg3 *livekit.CreateRoomRequest + arg4 []psrpc.RequestOption + }{arg1, arg2, arg3, arg4}) + stub := fake.CreateRoomStub + fakeReturns := fake.createRoomReturns + fake.recordInvocation("CreateRoom", []interface{}{arg1, arg2, arg3, arg4}) + fake.createRoomMutex.Unlock() + if stub != nil { + return stub(arg1, arg2, arg3, arg4...) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeRoomManagerClient) CreateRoomCallCount() int { + fake.createRoomMutex.RLock() + defer fake.createRoomMutex.RUnlock() + return len(fake.createRoomArgsForCall) +} + +func (fake *FakeRoomManagerClient) CreateRoomCalls(stub func(context.Context, livekit.NodeID, *livekit.CreateRoomRequest, ...psrpc.RequestOption) (*livekit.Room, error)) { + fake.createRoomMutex.Lock() + defer fake.createRoomMutex.Unlock() + fake.CreateRoomStub = stub +} + +func (fake *FakeRoomManagerClient) CreateRoomArgsForCall(i int) (context.Context, livekit.NodeID, *livekit.CreateRoomRequest, []psrpc.RequestOption) { + fake.createRoomMutex.RLock() + defer fake.createRoomMutex.RUnlock() + argsForCall := fake.createRoomArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3, argsForCall.arg4 +} + +func (fake *FakeRoomManagerClient) CreateRoomReturns(result1 *livekit.Room, result2 error) { + fake.createRoomMutex.Lock() + defer fake.createRoomMutex.Unlock() + fake.CreateRoomStub = nil + fake.createRoomReturns = struct { + result1 *livekit.Room + result2 error + }{result1, result2} +} + +func (fake *FakeRoomManagerClient) CreateRoomReturnsOnCall(i int, result1 *livekit.Room, result2 error) { + fake.createRoomMutex.Lock() + defer fake.createRoomMutex.Unlock() + fake.CreateRoomStub = nil + if fake.createRoomReturnsOnCall == nil { + fake.createRoomReturnsOnCall = make(map[int]struct { + result1 *livekit.Room + result2 error + }) + } + fake.createRoomReturnsOnCall[i] = struct { + result1 *livekit.Room + result2 error + }{result1, result2} +} + +func (fake *FakeRoomManagerClient) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeRoomManagerClient) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ routing.RoomManagerClient = new(FakeRoomManagerClient) diff --git a/livekit/pkg/routing/routingfakes/fake_router.go b/livekit/pkg/routing/routingfakes/fake_router.go new file mode 100644 index 0000000..48fec0b --- /dev/null +++ b/livekit/pkg/routing/routingfakes/fake_router.go @@ -0,0 +1,867 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package routingfakes + +import ( + "context" + "sync" + + "github.com/livekit/livekit-server/pkg/routing" + "github.com/livekit/protocol/livekit" +) + +type FakeRouter struct { + ClearRoomStateStub func(context.Context, livekit.RoomName) error + clearRoomStateMutex sync.RWMutex + clearRoomStateArgsForCall []struct { + arg1 context.Context + arg2 livekit.RoomName + } + clearRoomStateReturns struct { + result1 error + } + clearRoomStateReturnsOnCall map[int]struct { + result1 error + } + CreateRoomStub func(context.Context, *livekit.CreateRoomRequest) (*livekit.Room, error) + createRoomMutex sync.RWMutex + createRoomArgsForCall []struct { + arg1 context.Context + arg2 *livekit.CreateRoomRequest + } + createRoomReturns struct { + result1 *livekit.Room + result2 error + } + createRoomReturnsOnCall map[int]struct { + result1 *livekit.Room + result2 error + } + DrainStub func() + drainMutex sync.RWMutex + drainArgsForCall []struct { + } + GetNodeForRoomStub func(context.Context, livekit.RoomName) (*livekit.Node, error) + getNodeForRoomMutex sync.RWMutex + getNodeForRoomArgsForCall []struct { + arg1 context.Context + arg2 livekit.RoomName + } + getNodeForRoomReturns struct { + result1 *livekit.Node + result2 error + } + getNodeForRoomReturnsOnCall map[int]struct { + result1 *livekit.Node + result2 error + } + GetRegionStub func() string + getRegionMutex sync.RWMutex + getRegionArgsForCall []struct { + } + getRegionReturns struct { + result1 string + } + getRegionReturnsOnCall map[int]struct { + result1 string + } + ListNodesStub func() ([]*livekit.Node, error) + listNodesMutex sync.RWMutex + listNodesArgsForCall []struct { + } + listNodesReturns struct { + result1 []*livekit.Node + result2 error + } + listNodesReturnsOnCall map[int]struct { + result1 []*livekit.Node + result2 error + } + RegisterNodeStub func() error + registerNodeMutex sync.RWMutex + registerNodeArgsForCall []struct { + } + registerNodeReturns struct { + result1 error + } + registerNodeReturnsOnCall map[int]struct { + result1 error + } + RemoveDeadNodesStub func() error + removeDeadNodesMutex sync.RWMutex + removeDeadNodesArgsForCall []struct { + } + removeDeadNodesReturns struct { + result1 error + } + removeDeadNodesReturnsOnCall map[int]struct { + result1 error + } + SetNodeForRoomStub func(context.Context, livekit.RoomName, livekit.NodeID) error + setNodeForRoomMutex sync.RWMutex + setNodeForRoomArgsForCall []struct { + arg1 context.Context + arg2 livekit.RoomName + arg3 livekit.NodeID + } + setNodeForRoomReturns struct { + result1 error + } + setNodeForRoomReturnsOnCall map[int]struct { + result1 error + } + StartStub func() error + startMutex sync.RWMutex + startArgsForCall []struct { + } + startReturns struct { + result1 error + } + startReturnsOnCall map[int]struct { + result1 error + } + StartParticipantSignalStub func(context.Context, livekit.RoomName, routing.ParticipantInit) (routing.StartParticipantSignalResults, error) + startParticipantSignalMutex sync.RWMutex + startParticipantSignalArgsForCall []struct { + arg1 context.Context + arg2 livekit.RoomName + arg3 routing.ParticipantInit + } + startParticipantSignalReturns struct { + result1 routing.StartParticipantSignalResults + result2 error + } + startParticipantSignalReturnsOnCall map[int]struct { + result1 routing.StartParticipantSignalResults + result2 error + } + StopStub func() + stopMutex sync.RWMutex + stopArgsForCall []struct { + } + UnregisterNodeStub func() error + unregisterNodeMutex sync.RWMutex + unregisterNodeArgsForCall []struct { + } + unregisterNodeReturns struct { + result1 error + } + unregisterNodeReturnsOnCall map[int]struct { + result1 error + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeRouter) ClearRoomState(arg1 context.Context, arg2 livekit.RoomName) error { + fake.clearRoomStateMutex.Lock() + ret, specificReturn := fake.clearRoomStateReturnsOnCall[len(fake.clearRoomStateArgsForCall)] + fake.clearRoomStateArgsForCall = append(fake.clearRoomStateArgsForCall, struct { + arg1 context.Context + arg2 livekit.RoomName + }{arg1, arg2}) + stub := fake.ClearRoomStateStub + fakeReturns := fake.clearRoomStateReturns + fake.recordInvocation("ClearRoomState", []interface{}{arg1, arg2}) + fake.clearRoomStateMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeRouter) ClearRoomStateCallCount() int { + fake.clearRoomStateMutex.RLock() + defer fake.clearRoomStateMutex.RUnlock() + return len(fake.clearRoomStateArgsForCall) +} + +func (fake *FakeRouter) ClearRoomStateCalls(stub func(context.Context, livekit.RoomName) error) { + fake.clearRoomStateMutex.Lock() + defer fake.clearRoomStateMutex.Unlock() + fake.ClearRoomStateStub = stub +} + +func (fake *FakeRouter) ClearRoomStateArgsForCall(i int) (context.Context, livekit.RoomName) { + fake.clearRoomStateMutex.RLock() + defer fake.clearRoomStateMutex.RUnlock() + argsForCall := fake.clearRoomStateArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeRouter) ClearRoomStateReturns(result1 error) { + fake.clearRoomStateMutex.Lock() + defer fake.clearRoomStateMutex.Unlock() + fake.ClearRoomStateStub = nil + fake.clearRoomStateReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeRouter) ClearRoomStateReturnsOnCall(i int, result1 error) { + fake.clearRoomStateMutex.Lock() + defer fake.clearRoomStateMutex.Unlock() + fake.ClearRoomStateStub = nil + if fake.clearRoomStateReturnsOnCall == nil { + fake.clearRoomStateReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.clearRoomStateReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeRouter) CreateRoom(arg1 context.Context, arg2 *livekit.CreateRoomRequest) (*livekit.Room, error) { + fake.createRoomMutex.Lock() + ret, specificReturn := fake.createRoomReturnsOnCall[len(fake.createRoomArgsForCall)] + fake.createRoomArgsForCall = append(fake.createRoomArgsForCall, struct { + arg1 context.Context + arg2 *livekit.CreateRoomRequest + }{arg1, arg2}) + stub := fake.CreateRoomStub + fakeReturns := fake.createRoomReturns + fake.recordInvocation("CreateRoom", []interface{}{arg1, arg2}) + fake.createRoomMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeRouter) CreateRoomCallCount() int { + fake.createRoomMutex.RLock() + defer fake.createRoomMutex.RUnlock() + return len(fake.createRoomArgsForCall) +} + +func (fake *FakeRouter) CreateRoomCalls(stub func(context.Context, *livekit.CreateRoomRequest) (*livekit.Room, error)) { + fake.createRoomMutex.Lock() + defer fake.createRoomMutex.Unlock() + fake.CreateRoomStub = stub +} + +func (fake *FakeRouter) CreateRoomArgsForCall(i int) (context.Context, *livekit.CreateRoomRequest) { + fake.createRoomMutex.RLock() + defer fake.createRoomMutex.RUnlock() + argsForCall := fake.createRoomArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeRouter) CreateRoomReturns(result1 *livekit.Room, result2 error) { + fake.createRoomMutex.Lock() + defer fake.createRoomMutex.Unlock() + fake.CreateRoomStub = nil + fake.createRoomReturns = struct { + result1 *livekit.Room + result2 error + }{result1, result2} +} + +func (fake *FakeRouter) CreateRoomReturnsOnCall(i int, result1 *livekit.Room, result2 error) { + fake.createRoomMutex.Lock() + defer fake.createRoomMutex.Unlock() + fake.CreateRoomStub = nil + if fake.createRoomReturnsOnCall == nil { + fake.createRoomReturnsOnCall = make(map[int]struct { + result1 *livekit.Room + result2 error + }) + } + fake.createRoomReturnsOnCall[i] = struct { + result1 *livekit.Room + result2 error + }{result1, result2} +} + +func (fake *FakeRouter) Drain() { + fake.drainMutex.Lock() + fake.drainArgsForCall = append(fake.drainArgsForCall, struct { + }{}) + stub := fake.DrainStub + fake.recordInvocation("Drain", []interface{}{}) + fake.drainMutex.Unlock() + if stub != nil { + fake.DrainStub() + } +} + +func (fake *FakeRouter) DrainCallCount() int { + fake.drainMutex.RLock() + defer fake.drainMutex.RUnlock() + return len(fake.drainArgsForCall) +} + +func (fake *FakeRouter) DrainCalls(stub func()) { + fake.drainMutex.Lock() + defer fake.drainMutex.Unlock() + fake.DrainStub = stub +} + +func (fake *FakeRouter) GetNodeForRoom(arg1 context.Context, arg2 livekit.RoomName) (*livekit.Node, error) { + fake.getNodeForRoomMutex.Lock() + ret, specificReturn := fake.getNodeForRoomReturnsOnCall[len(fake.getNodeForRoomArgsForCall)] + fake.getNodeForRoomArgsForCall = append(fake.getNodeForRoomArgsForCall, struct { + arg1 context.Context + arg2 livekit.RoomName + }{arg1, arg2}) + stub := fake.GetNodeForRoomStub + fakeReturns := fake.getNodeForRoomReturns + fake.recordInvocation("GetNodeForRoom", []interface{}{arg1, arg2}) + fake.getNodeForRoomMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeRouter) GetNodeForRoomCallCount() int { + fake.getNodeForRoomMutex.RLock() + defer fake.getNodeForRoomMutex.RUnlock() + return len(fake.getNodeForRoomArgsForCall) +} + +func (fake *FakeRouter) GetNodeForRoomCalls(stub func(context.Context, livekit.RoomName) (*livekit.Node, error)) { + fake.getNodeForRoomMutex.Lock() + defer fake.getNodeForRoomMutex.Unlock() + fake.GetNodeForRoomStub = stub +} + +func (fake *FakeRouter) GetNodeForRoomArgsForCall(i int) (context.Context, livekit.RoomName) { + fake.getNodeForRoomMutex.RLock() + defer fake.getNodeForRoomMutex.RUnlock() + argsForCall := fake.getNodeForRoomArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeRouter) GetNodeForRoomReturns(result1 *livekit.Node, result2 error) { + fake.getNodeForRoomMutex.Lock() + defer fake.getNodeForRoomMutex.Unlock() + fake.GetNodeForRoomStub = nil + fake.getNodeForRoomReturns = struct { + result1 *livekit.Node + result2 error + }{result1, result2} +} + +func (fake *FakeRouter) GetNodeForRoomReturnsOnCall(i int, result1 *livekit.Node, result2 error) { + fake.getNodeForRoomMutex.Lock() + defer fake.getNodeForRoomMutex.Unlock() + fake.GetNodeForRoomStub = nil + if fake.getNodeForRoomReturnsOnCall == nil { + fake.getNodeForRoomReturnsOnCall = make(map[int]struct { + result1 *livekit.Node + result2 error + }) + } + fake.getNodeForRoomReturnsOnCall[i] = struct { + result1 *livekit.Node + result2 error + }{result1, result2} +} + +func (fake *FakeRouter) GetRegion() string { + fake.getRegionMutex.Lock() + ret, specificReturn := fake.getRegionReturnsOnCall[len(fake.getRegionArgsForCall)] + fake.getRegionArgsForCall = append(fake.getRegionArgsForCall, struct { + }{}) + stub := fake.GetRegionStub + fakeReturns := fake.getRegionReturns + fake.recordInvocation("GetRegion", []interface{}{}) + fake.getRegionMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeRouter) GetRegionCallCount() int { + fake.getRegionMutex.RLock() + defer fake.getRegionMutex.RUnlock() + return len(fake.getRegionArgsForCall) +} + +func (fake *FakeRouter) GetRegionCalls(stub func() string) { + fake.getRegionMutex.Lock() + defer fake.getRegionMutex.Unlock() + fake.GetRegionStub = stub +} + +func (fake *FakeRouter) GetRegionReturns(result1 string) { + fake.getRegionMutex.Lock() + defer fake.getRegionMutex.Unlock() + fake.GetRegionStub = nil + fake.getRegionReturns = struct { + result1 string + }{result1} +} + +func (fake *FakeRouter) GetRegionReturnsOnCall(i int, result1 string) { + fake.getRegionMutex.Lock() + defer fake.getRegionMutex.Unlock() + fake.GetRegionStub = nil + if fake.getRegionReturnsOnCall == nil { + fake.getRegionReturnsOnCall = make(map[int]struct { + result1 string + }) + } + fake.getRegionReturnsOnCall[i] = struct { + result1 string + }{result1} +} + +func (fake *FakeRouter) ListNodes() ([]*livekit.Node, error) { + fake.listNodesMutex.Lock() + ret, specificReturn := fake.listNodesReturnsOnCall[len(fake.listNodesArgsForCall)] + fake.listNodesArgsForCall = append(fake.listNodesArgsForCall, struct { + }{}) + stub := fake.ListNodesStub + fakeReturns := fake.listNodesReturns + fake.recordInvocation("ListNodes", []interface{}{}) + fake.listNodesMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeRouter) ListNodesCallCount() int { + fake.listNodesMutex.RLock() + defer fake.listNodesMutex.RUnlock() + return len(fake.listNodesArgsForCall) +} + +func (fake *FakeRouter) ListNodesCalls(stub func() ([]*livekit.Node, error)) { + fake.listNodesMutex.Lock() + defer fake.listNodesMutex.Unlock() + fake.ListNodesStub = stub +} + +func (fake *FakeRouter) ListNodesReturns(result1 []*livekit.Node, result2 error) { + fake.listNodesMutex.Lock() + defer fake.listNodesMutex.Unlock() + fake.ListNodesStub = nil + fake.listNodesReturns = struct { + result1 []*livekit.Node + result2 error + }{result1, result2} +} + +func (fake *FakeRouter) ListNodesReturnsOnCall(i int, result1 []*livekit.Node, result2 error) { + fake.listNodesMutex.Lock() + defer fake.listNodesMutex.Unlock() + fake.ListNodesStub = nil + if fake.listNodesReturnsOnCall == nil { + fake.listNodesReturnsOnCall = make(map[int]struct { + result1 []*livekit.Node + result2 error + }) + } + fake.listNodesReturnsOnCall[i] = struct { + result1 []*livekit.Node + result2 error + }{result1, result2} +} + +func (fake *FakeRouter) RegisterNode() error { + fake.registerNodeMutex.Lock() + ret, specificReturn := fake.registerNodeReturnsOnCall[len(fake.registerNodeArgsForCall)] + fake.registerNodeArgsForCall = append(fake.registerNodeArgsForCall, struct { + }{}) + stub := fake.RegisterNodeStub + fakeReturns := fake.registerNodeReturns + fake.recordInvocation("RegisterNode", []interface{}{}) + fake.registerNodeMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeRouter) RegisterNodeCallCount() int { + fake.registerNodeMutex.RLock() + defer fake.registerNodeMutex.RUnlock() + return len(fake.registerNodeArgsForCall) +} + +func (fake *FakeRouter) RegisterNodeCalls(stub func() error) { + fake.registerNodeMutex.Lock() + defer fake.registerNodeMutex.Unlock() + fake.RegisterNodeStub = stub +} + +func (fake *FakeRouter) RegisterNodeReturns(result1 error) { + fake.registerNodeMutex.Lock() + defer fake.registerNodeMutex.Unlock() + fake.RegisterNodeStub = nil + fake.registerNodeReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeRouter) RegisterNodeReturnsOnCall(i int, result1 error) { + fake.registerNodeMutex.Lock() + defer fake.registerNodeMutex.Unlock() + fake.RegisterNodeStub = nil + if fake.registerNodeReturnsOnCall == nil { + fake.registerNodeReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.registerNodeReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeRouter) RemoveDeadNodes() error { + fake.removeDeadNodesMutex.Lock() + ret, specificReturn := fake.removeDeadNodesReturnsOnCall[len(fake.removeDeadNodesArgsForCall)] + fake.removeDeadNodesArgsForCall = append(fake.removeDeadNodesArgsForCall, struct { + }{}) + stub := fake.RemoveDeadNodesStub + fakeReturns := fake.removeDeadNodesReturns + fake.recordInvocation("RemoveDeadNodes", []interface{}{}) + fake.removeDeadNodesMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeRouter) RemoveDeadNodesCallCount() int { + fake.removeDeadNodesMutex.RLock() + defer fake.removeDeadNodesMutex.RUnlock() + return len(fake.removeDeadNodesArgsForCall) +} + +func (fake *FakeRouter) RemoveDeadNodesCalls(stub func() error) { + fake.removeDeadNodesMutex.Lock() + defer fake.removeDeadNodesMutex.Unlock() + fake.RemoveDeadNodesStub = stub +} + +func (fake *FakeRouter) RemoveDeadNodesReturns(result1 error) { + fake.removeDeadNodesMutex.Lock() + defer fake.removeDeadNodesMutex.Unlock() + fake.RemoveDeadNodesStub = nil + fake.removeDeadNodesReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeRouter) RemoveDeadNodesReturnsOnCall(i int, result1 error) { + fake.removeDeadNodesMutex.Lock() + defer fake.removeDeadNodesMutex.Unlock() + fake.RemoveDeadNodesStub = nil + if fake.removeDeadNodesReturnsOnCall == nil { + fake.removeDeadNodesReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.removeDeadNodesReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeRouter) SetNodeForRoom(arg1 context.Context, arg2 livekit.RoomName, arg3 livekit.NodeID) error { + fake.setNodeForRoomMutex.Lock() + ret, specificReturn := fake.setNodeForRoomReturnsOnCall[len(fake.setNodeForRoomArgsForCall)] + fake.setNodeForRoomArgsForCall = append(fake.setNodeForRoomArgsForCall, struct { + arg1 context.Context + arg2 livekit.RoomName + arg3 livekit.NodeID + }{arg1, arg2, arg3}) + stub := fake.SetNodeForRoomStub + fakeReturns := fake.setNodeForRoomReturns + fake.recordInvocation("SetNodeForRoom", []interface{}{arg1, arg2, arg3}) + fake.setNodeForRoomMutex.Unlock() + if stub != nil { + return stub(arg1, arg2, arg3) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeRouter) SetNodeForRoomCallCount() int { + fake.setNodeForRoomMutex.RLock() + defer fake.setNodeForRoomMutex.RUnlock() + return len(fake.setNodeForRoomArgsForCall) +} + +func (fake *FakeRouter) SetNodeForRoomCalls(stub func(context.Context, livekit.RoomName, livekit.NodeID) error) { + fake.setNodeForRoomMutex.Lock() + defer fake.setNodeForRoomMutex.Unlock() + fake.SetNodeForRoomStub = stub +} + +func (fake *FakeRouter) SetNodeForRoomArgsForCall(i int) (context.Context, livekit.RoomName, livekit.NodeID) { + fake.setNodeForRoomMutex.RLock() + defer fake.setNodeForRoomMutex.RUnlock() + argsForCall := fake.setNodeForRoomArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeRouter) SetNodeForRoomReturns(result1 error) { + fake.setNodeForRoomMutex.Lock() + defer fake.setNodeForRoomMutex.Unlock() + fake.SetNodeForRoomStub = nil + fake.setNodeForRoomReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeRouter) SetNodeForRoomReturnsOnCall(i int, result1 error) { + fake.setNodeForRoomMutex.Lock() + defer fake.setNodeForRoomMutex.Unlock() + fake.SetNodeForRoomStub = nil + if fake.setNodeForRoomReturnsOnCall == nil { + fake.setNodeForRoomReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.setNodeForRoomReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeRouter) Start() error { + fake.startMutex.Lock() + ret, specificReturn := fake.startReturnsOnCall[len(fake.startArgsForCall)] + fake.startArgsForCall = append(fake.startArgsForCall, struct { + }{}) + stub := fake.StartStub + fakeReturns := fake.startReturns + fake.recordInvocation("Start", []interface{}{}) + fake.startMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeRouter) StartCallCount() int { + fake.startMutex.RLock() + defer fake.startMutex.RUnlock() + return len(fake.startArgsForCall) +} + +func (fake *FakeRouter) StartCalls(stub func() error) { + fake.startMutex.Lock() + defer fake.startMutex.Unlock() + fake.StartStub = stub +} + +func (fake *FakeRouter) StartReturns(result1 error) { + fake.startMutex.Lock() + defer fake.startMutex.Unlock() + fake.StartStub = nil + fake.startReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeRouter) StartReturnsOnCall(i int, result1 error) { + fake.startMutex.Lock() + defer fake.startMutex.Unlock() + fake.StartStub = nil + if fake.startReturnsOnCall == nil { + fake.startReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.startReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeRouter) StartParticipantSignal(arg1 context.Context, arg2 livekit.RoomName, arg3 routing.ParticipantInit) (routing.StartParticipantSignalResults, error) { + fake.startParticipantSignalMutex.Lock() + ret, specificReturn := fake.startParticipantSignalReturnsOnCall[len(fake.startParticipantSignalArgsForCall)] + fake.startParticipantSignalArgsForCall = append(fake.startParticipantSignalArgsForCall, struct { + arg1 context.Context + arg2 livekit.RoomName + arg3 routing.ParticipantInit + }{arg1, arg2, arg3}) + stub := fake.StartParticipantSignalStub + fakeReturns := fake.startParticipantSignalReturns + fake.recordInvocation("StartParticipantSignal", []interface{}{arg1, arg2, arg3}) + fake.startParticipantSignalMutex.Unlock() + if stub != nil { + return stub(arg1, arg2, arg3) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeRouter) StartParticipantSignalCallCount() int { + fake.startParticipantSignalMutex.RLock() + defer fake.startParticipantSignalMutex.RUnlock() + return len(fake.startParticipantSignalArgsForCall) +} + +func (fake *FakeRouter) StartParticipantSignalCalls(stub func(context.Context, livekit.RoomName, routing.ParticipantInit) (routing.StartParticipantSignalResults, error)) { + fake.startParticipantSignalMutex.Lock() + defer fake.startParticipantSignalMutex.Unlock() + fake.StartParticipantSignalStub = stub +} + +func (fake *FakeRouter) StartParticipantSignalArgsForCall(i int) (context.Context, livekit.RoomName, routing.ParticipantInit) { + fake.startParticipantSignalMutex.RLock() + defer fake.startParticipantSignalMutex.RUnlock() + argsForCall := fake.startParticipantSignalArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeRouter) StartParticipantSignalReturns(result1 routing.StartParticipantSignalResults, result2 error) { + fake.startParticipantSignalMutex.Lock() + defer fake.startParticipantSignalMutex.Unlock() + fake.StartParticipantSignalStub = nil + fake.startParticipantSignalReturns = struct { + result1 routing.StartParticipantSignalResults + result2 error + }{result1, result2} +} + +func (fake *FakeRouter) StartParticipantSignalReturnsOnCall(i int, result1 routing.StartParticipantSignalResults, result2 error) { + fake.startParticipantSignalMutex.Lock() + defer fake.startParticipantSignalMutex.Unlock() + fake.StartParticipantSignalStub = nil + if fake.startParticipantSignalReturnsOnCall == nil { + fake.startParticipantSignalReturnsOnCall = make(map[int]struct { + result1 routing.StartParticipantSignalResults + result2 error + }) + } + fake.startParticipantSignalReturnsOnCall[i] = struct { + result1 routing.StartParticipantSignalResults + result2 error + }{result1, result2} +} + +func (fake *FakeRouter) Stop() { + fake.stopMutex.Lock() + fake.stopArgsForCall = append(fake.stopArgsForCall, struct { + }{}) + stub := fake.StopStub + fake.recordInvocation("Stop", []interface{}{}) + fake.stopMutex.Unlock() + if stub != nil { + fake.StopStub() + } +} + +func (fake *FakeRouter) StopCallCount() int { + fake.stopMutex.RLock() + defer fake.stopMutex.RUnlock() + return len(fake.stopArgsForCall) +} + +func (fake *FakeRouter) StopCalls(stub func()) { + fake.stopMutex.Lock() + defer fake.stopMutex.Unlock() + fake.StopStub = stub +} + +func (fake *FakeRouter) UnregisterNode() error { + fake.unregisterNodeMutex.Lock() + ret, specificReturn := fake.unregisterNodeReturnsOnCall[len(fake.unregisterNodeArgsForCall)] + fake.unregisterNodeArgsForCall = append(fake.unregisterNodeArgsForCall, struct { + }{}) + stub := fake.UnregisterNodeStub + fakeReturns := fake.unregisterNodeReturns + fake.recordInvocation("UnregisterNode", []interface{}{}) + fake.unregisterNodeMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeRouter) UnregisterNodeCallCount() int { + fake.unregisterNodeMutex.RLock() + defer fake.unregisterNodeMutex.RUnlock() + return len(fake.unregisterNodeArgsForCall) +} + +func (fake *FakeRouter) UnregisterNodeCalls(stub func() error) { + fake.unregisterNodeMutex.Lock() + defer fake.unregisterNodeMutex.Unlock() + fake.UnregisterNodeStub = stub +} + +func (fake *FakeRouter) UnregisterNodeReturns(result1 error) { + fake.unregisterNodeMutex.Lock() + defer fake.unregisterNodeMutex.Unlock() + fake.UnregisterNodeStub = nil + fake.unregisterNodeReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeRouter) UnregisterNodeReturnsOnCall(i int, result1 error) { + fake.unregisterNodeMutex.Lock() + defer fake.unregisterNodeMutex.Unlock() + fake.UnregisterNodeStub = nil + if fake.unregisterNodeReturnsOnCall == nil { + fake.unregisterNodeReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.unregisterNodeReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeRouter) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeRouter) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ routing.Router = new(FakeRouter) diff --git a/livekit/pkg/routing/routingfakes/fake_signal_client.go b/livekit/pkg/routing/routingfakes/fake_signal_client.go new file mode 100644 index 0000000..c7eeff1 --- /dev/null +++ b/livekit/pkg/routing/routingfakes/fake_signal_client.go @@ -0,0 +1,195 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package routingfakes + +import ( + "context" + "sync" + + "github.com/livekit/livekit-server/pkg/routing" + "github.com/livekit/protocol/livekit" +) + +type FakeSignalClient struct { + ActiveCountStub func() int + activeCountMutex sync.RWMutex + activeCountArgsForCall []struct { + } + activeCountReturns struct { + result1 int + } + activeCountReturnsOnCall map[int]struct { + result1 int + } + StartParticipantSignalStub func(context.Context, livekit.RoomName, routing.ParticipantInit, livekit.NodeID) (livekit.ConnectionID, routing.MessageSink, routing.MessageSource, error) + startParticipantSignalMutex sync.RWMutex + startParticipantSignalArgsForCall []struct { + arg1 context.Context + arg2 livekit.RoomName + arg3 routing.ParticipantInit + arg4 livekit.NodeID + } + startParticipantSignalReturns struct { + result1 livekit.ConnectionID + result2 routing.MessageSink + result3 routing.MessageSource + result4 error + } + startParticipantSignalReturnsOnCall map[int]struct { + result1 livekit.ConnectionID + result2 routing.MessageSink + result3 routing.MessageSource + result4 error + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeSignalClient) ActiveCount() int { + fake.activeCountMutex.Lock() + ret, specificReturn := fake.activeCountReturnsOnCall[len(fake.activeCountArgsForCall)] + fake.activeCountArgsForCall = append(fake.activeCountArgsForCall, struct { + }{}) + stub := fake.ActiveCountStub + fakeReturns := fake.activeCountReturns + fake.recordInvocation("ActiveCount", []interface{}{}) + fake.activeCountMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeSignalClient) ActiveCountCallCount() int { + fake.activeCountMutex.RLock() + defer fake.activeCountMutex.RUnlock() + return len(fake.activeCountArgsForCall) +} + +func (fake *FakeSignalClient) ActiveCountCalls(stub func() int) { + fake.activeCountMutex.Lock() + defer fake.activeCountMutex.Unlock() + fake.ActiveCountStub = stub +} + +func (fake *FakeSignalClient) ActiveCountReturns(result1 int) { + fake.activeCountMutex.Lock() + defer fake.activeCountMutex.Unlock() + fake.ActiveCountStub = nil + fake.activeCountReturns = struct { + result1 int + }{result1} +} + +func (fake *FakeSignalClient) ActiveCountReturnsOnCall(i int, result1 int) { + fake.activeCountMutex.Lock() + defer fake.activeCountMutex.Unlock() + fake.ActiveCountStub = nil + if fake.activeCountReturnsOnCall == nil { + fake.activeCountReturnsOnCall = make(map[int]struct { + result1 int + }) + } + fake.activeCountReturnsOnCall[i] = struct { + result1 int + }{result1} +} + +func (fake *FakeSignalClient) StartParticipantSignal(arg1 context.Context, arg2 livekit.RoomName, arg3 routing.ParticipantInit, arg4 livekit.NodeID) (livekit.ConnectionID, routing.MessageSink, routing.MessageSource, error) { + fake.startParticipantSignalMutex.Lock() + ret, specificReturn := fake.startParticipantSignalReturnsOnCall[len(fake.startParticipantSignalArgsForCall)] + fake.startParticipantSignalArgsForCall = append(fake.startParticipantSignalArgsForCall, struct { + arg1 context.Context + arg2 livekit.RoomName + arg3 routing.ParticipantInit + arg4 livekit.NodeID + }{arg1, arg2, arg3, arg4}) + stub := fake.StartParticipantSignalStub + fakeReturns := fake.startParticipantSignalReturns + fake.recordInvocation("StartParticipantSignal", []interface{}{arg1, arg2, arg3, arg4}) + fake.startParticipantSignalMutex.Unlock() + if stub != nil { + return stub(arg1, arg2, arg3, arg4) + } + if specificReturn { + return ret.result1, ret.result2, ret.result3, ret.result4 + } + return fakeReturns.result1, fakeReturns.result2, fakeReturns.result3, fakeReturns.result4 +} + +func (fake *FakeSignalClient) StartParticipantSignalCallCount() int { + fake.startParticipantSignalMutex.RLock() + defer fake.startParticipantSignalMutex.RUnlock() + return len(fake.startParticipantSignalArgsForCall) +} + +func (fake *FakeSignalClient) StartParticipantSignalCalls(stub func(context.Context, livekit.RoomName, routing.ParticipantInit, livekit.NodeID) (livekit.ConnectionID, routing.MessageSink, routing.MessageSource, error)) { + fake.startParticipantSignalMutex.Lock() + defer fake.startParticipantSignalMutex.Unlock() + fake.StartParticipantSignalStub = stub +} + +func (fake *FakeSignalClient) StartParticipantSignalArgsForCall(i int) (context.Context, livekit.RoomName, routing.ParticipantInit, livekit.NodeID) { + fake.startParticipantSignalMutex.RLock() + defer fake.startParticipantSignalMutex.RUnlock() + argsForCall := fake.startParticipantSignalArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3, argsForCall.arg4 +} + +func (fake *FakeSignalClient) StartParticipantSignalReturns(result1 livekit.ConnectionID, result2 routing.MessageSink, result3 routing.MessageSource, result4 error) { + fake.startParticipantSignalMutex.Lock() + defer fake.startParticipantSignalMutex.Unlock() + fake.StartParticipantSignalStub = nil + fake.startParticipantSignalReturns = struct { + result1 livekit.ConnectionID + result2 routing.MessageSink + result3 routing.MessageSource + result4 error + }{result1, result2, result3, result4} +} + +func (fake *FakeSignalClient) StartParticipantSignalReturnsOnCall(i int, result1 livekit.ConnectionID, result2 routing.MessageSink, result3 routing.MessageSource, result4 error) { + fake.startParticipantSignalMutex.Lock() + defer fake.startParticipantSignalMutex.Unlock() + fake.StartParticipantSignalStub = nil + if fake.startParticipantSignalReturnsOnCall == nil { + fake.startParticipantSignalReturnsOnCall = make(map[int]struct { + result1 livekit.ConnectionID + result2 routing.MessageSink + result3 routing.MessageSource + result4 error + }) + } + fake.startParticipantSignalReturnsOnCall[i] = struct { + result1 livekit.ConnectionID + result2 routing.MessageSink + result3 routing.MessageSource + result4 error + }{result1, result2, result3, result4} +} + +func (fake *FakeSignalClient) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeSignalClient) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ routing.SignalClient = new(FakeSignalClient) diff --git a/livekit/pkg/routing/selector/any.go b/livekit/pkg/routing/selector/any.go new file mode 100644 index 0000000..592e2ed --- /dev/null +++ b/livekit/pkg/routing/selector/any.go @@ -0,0 +1,34 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package selector + +import ( + "github.com/livekit/protocol/livekit" +) + +// AnySelector selects any available node with no limitations +type AnySelector struct { + SortBy string + Algorithm string +} + +func (s *AnySelector) SelectNode(nodes []*livekit.Node) (*livekit.Node, error) { + nodes = GetAvailableNodes(nodes) + if len(nodes) == 0 { + return nil, ErrNoAvailableNodes + } + + return SelectSortedNode(nodes, s.SortBy, s.Algorithm) +} diff --git a/livekit/pkg/routing/selector/any_test.go b/livekit/pkg/routing/selector/any_test.go new file mode 100644 index 0000000..e7c6d36 --- /dev/null +++ b/livekit/pkg/routing/selector/any_test.go @@ -0,0 +1,287 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package selector + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/livekit/protocol/livekit" +) + +func createTestNode(id string, cpuLoad float32, numRooms int32, numClients int32, state livekit.NodeState) *livekit.Node { + return &livekit.Node{ + Id: id, + State: state, + Stats: &livekit.NodeStats{ + UpdatedAt: time.Now().Unix() - 1, // Recent update to be considered available + CpuLoad: cpuLoad, + NumRooms: numRooms, + NumClients: numClients, + NumCpus: 4, + LoadAvgLast1Min: cpuLoad * 4, // Simulate system load + }, + } +} + +func TestAnySelector_SelectNode_TwoChoice(t *testing.T) { + tests := []struct { + name string + sortBy string + algorithm string + nodes []*livekit.Node + wantErr string + expected string + notExpected string + }{ + { + name: "successful selection with cpuload sorting", + sortBy: "cpuload", + algorithm: "twochoice", + nodes: []*livekit.Node{ + createTestNode("node1", 0.8, 5, 10, livekit.NodeState_SERVING), + createTestNode("node2", 0.3, 2, 5, livekit.NodeState_SERVING), + createTestNode("node3", 0.6, 3, 8, livekit.NodeState_SERVING), + createTestNode("node4", 0.9, 6, 12, livekit.NodeState_SERVING), + }, + wantErr: "", + expected: "", // Not determinstic selection, so no specific expected node + notExpected: "node4", // Node with highest load should not be selected + }, + { + name: "successful selection with rooms sorting", + sortBy: "rooms", + algorithm: "twochoice", + nodes: []*livekit.Node{ + createTestNode("node1", 0.5, 8, 15, livekit.NodeState_SERVING), + createTestNode("node2", 0.4, 2, 5, livekit.NodeState_SERVING), + createTestNode("node3", 0.6, 12, 20, livekit.NodeState_SERVING), + }, + wantErr: "", + expected: "", // Not determinstic selection, so no specific expected node + notExpected: "node3", // Node with highest room count should not be selected + }, + { + name: "successful selection with clients sorting", + sortBy: "clients", + algorithm: "twochoice", + nodes: []*livekit.Node{ + createTestNode("node1", 0.5, 3, 25, livekit.NodeState_SERVING), + createTestNode("node2", 0.4, 2, 5, livekit.NodeState_SERVING), + createTestNode("node3", 0.6, 4, 30, livekit.NodeState_SERVING), + }, + wantErr: "", + expected: "", // Not determinstic selection, so no specific expected node + notExpected: "node3", // Node with highest clients should not be selected + }, + { + name: "empty nodes list", + sortBy: "cpuload", + algorithm: "twochoice", + nodes: []*livekit.Node{}, + wantErr: "could not find any available nodes", + }, + { + name: "no available nodes - all unavailable", + sortBy: "cpuload", + algorithm: "twochoice", + nodes: []*livekit.Node{ + { + Id: "node1", + State: livekit.NodeState_SERVING, + Stats: &livekit.NodeStats{ + UpdatedAt: time.Now().Unix() - 10, // Too old + CpuLoad: 0.3, + }, + }, + }, + wantErr: "could not find any available nodes", + }, + { + name: "no available nodes - not serving", + sortBy: "cpuload", + algorithm: "twochoice", + nodes: []*livekit.Node{ + { + Id: "node1", + State: livekit.NodeState_SHUTTING_DOWN, + Stats: &livekit.NodeStats{ + UpdatedAt: time.Now().Unix() - 1, + CpuLoad: 0.3, + }, + }, + }, + wantErr: "could not find any available nodes", + }, + { + name: "single available node", + sortBy: "cpuload", + algorithm: "twochoice", + nodes: []*livekit.Node{ + createTestNode("node1", 0.5, 3, 10, livekit.NodeState_SERVING), + }, + wantErr: "", + expected: "node1", // Should select the only available node + notExpected: "", // No other nodes to compare against + }, + { + name: "two available nodes", + sortBy: "cpuload", + algorithm: "twochoice", + nodes: []*livekit.Node{ + createTestNode("node1", 0.8, 5, 15, livekit.NodeState_SERVING), + createTestNode("node2", 0.3, 2, 5, livekit.NodeState_SERVING), + }, + wantErr: "", + expected: "node2", // Should select the node with lower load + notExpected: "node1", // Should not select the node with higher load + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + selector := &AnySelector{ + SortBy: tt.sortBy, + Algorithm: tt.algorithm, + } + + node, err := selector.SelectNode(tt.nodes) + + if tt.wantErr != "" { + require.Error(t, err) + require.Contains(t, err.Error(), tt.wantErr) + require.Nil(t, node) + } else { + require.NoError(t, err) + require.NotNil(t, node) + require.NotEmpty(t, node.Id) + + // Verify the selected node is one of the available nodes + found := false + availableNodes := GetAvailableNodes(tt.nodes) + for _, availableNode := range availableNodes { + if availableNode.Id == node.Id { + found = true + break + } + } + require.True(t, found, "Selected node should be one of the available nodes") + + if tt.expected != "" { + require.Equal(t, tt.expected, node.Id, "Selected node should match expected") + } + if tt.notExpected != "" { + require.NotEqual(t, tt.notExpected, node.Id, "Selected node should not match not expected") + } + } + }) + } +} + +func TestAnySelector_SelectNode_TwoChoice_Probabilistic_Behavior(t *testing.T) { + // Test that two-choice algorithm favors nodes with lower metrics + // This test runs multiple iterations to increase confidence in the probabilistic behavior + selector := &AnySelector{ + SortBy: "cpuload", + Algorithm: "twochoice", + } + + // Create nodes where node2 has significantly lower CPU load + nodes := []*livekit.Node{ + createTestNode("node1", 0.95, 10, 20, livekit.NodeState_SERVING), // Very high load + createTestNode("node2", 0.1, 1, 2, livekit.NodeState_SERVING), // Low load + createTestNode("node3", 0.5, 9, 18, livekit.NodeState_SERVING), // Medium load + createTestNode("node4", 0.85, 8, 16, livekit.NodeState_SERVING), // High load + } + + // Run multiple selections and count how often the low-load node is selected + iterations := 1000 + lowLoadSelections := 0 + higestLoadSelections := 0 + + for range iterations { + node, err := selector.SelectNode(nodes) + require.NoError(t, err) + require.NotNil(t, node) + + if node.Id == "node2" { + lowLoadSelections++ + } + if node.Id == "node1" { + higestLoadSelections++ + } + } + + // The low-load node should be selected more often than pure random (25%) + // Due to the two-choice algorithm favoring the better node + selectionRate := float64(lowLoadSelections) / float64(iterations) + require.Greater(t, selectionRate, 0.4, "Two-choice algorithm should favor the low-load node more than random selection") + require.Equal(t, higestLoadSelections, 0, "Two-choice algorithm should never favor the highest load node") +} + +func TestAnySelector_SelectNode_InvalidParameters(t *testing.T) { + nodes := []*livekit.Node{ + createTestNode("node1", 0.5, 3, 10, livekit.NodeState_SERVING), + } + + tests := []struct { + name string + sortBy string + algorithm string + wantErr string + }{ + { + name: "empty sortBy", + sortBy: "", + algorithm: "twochoice", + wantErr: "sort by option cannot be blank", + }, + { + name: "empty algorithm", + sortBy: "cpuload", + algorithm: "", + wantErr: "node selector algorithm option cannot be blank", + }, + { + name: "unknown sortBy", + sortBy: "invalid", + algorithm: "twochoice", + wantErr: "unknown sort by option", + }, + { + name: "unknown algorithm", + sortBy: "cpuload", + algorithm: "invalid", + wantErr: "unknown node selector algorithm option", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + selector := &AnySelector{ + SortBy: tt.sortBy, + Algorithm: tt.algorithm, + } + + node, err := selector.SelectNode(nodes) + + require.Error(t, err) + require.Contains(t, err.Error(), tt.wantErr) + require.Nil(t, node) + }) + } +} diff --git a/livekit/pkg/routing/selector/cpuload.go b/livekit/pkg/routing/selector/cpuload.go new file mode 100644 index 0000000..f84b4a3 --- /dev/null +++ b/livekit/pkg/routing/selector/cpuload.go @@ -0,0 +1,55 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package selector + +import ( + "github.com/livekit/protocol/livekit" +) + +// CPULoadSelector eliminates nodes that have CPU usage higher than CPULoadLimit +// then selects a node from nodes that are not overloaded +type CPULoadSelector struct { + CPULoadLimit float32 + SortBy string + Algorithm string +} + +func (s *CPULoadSelector) filterNodes(nodes []*livekit.Node) ([]*livekit.Node, error) { + nodes = GetAvailableNodes(nodes) + if len(nodes) == 0 { + return nil, ErrNoAvailableNodes + } + + nodesLowLoad := make([]*livekit.Node, 0) + for _, node := range nodes { + stats := node.Stats + if stats.CpuLoad < s.CPULoadLimit { + nodesLowLoad = append(nodesLowLoad, node) + } + } + if len(nodesLowLoad) > 0 { + nodes = nodesLowLoad + } + return nodes, nil +} + +func (s *CPULoadSelector) SelectNode(nodes []*livekit.Node) (*livekit.Node, error) { + nodes, err := s.filterNodes(nodes) + if err != nil { + return nil, err + } + + return SelectSortedNode(nodes, s.SortBy, s.Algorithm) +} diff --git a/livekit/pkg/routing/selector/cpuload_test.go b/livekit/pkg/routing/selector/cpuload_test.go new file mode 100644 index 0000000..b267dcf --- /dev/null +++ b/livekit/pkg/routing/selector/cpuload_test.go @@ -0,0 +1,51 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package selector_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/livekit/protocol/livekit" + + "github.com/livekit/livekit-server/pkg/routing/selector" +) + +func TestCPULoadSelector_SelectNode(t *testing.T) { + sel := selector.CPULoadSelector{CPULoadLimit: 0.8, SortBy: "random", Algorithm: "lowest"} + + var nodes []*livekit.Node + _, err := sel.SelectNode(nodes) + require.Error(t, err, "should error no available nodes") + + // Select a node with high load when no nodes with low load are available + nodes = []*livekit.Node{nodeLoadHigh} + if _, err := sel.SelectNode(nodes); err != nil { + t.Error(err) + } + + // Select a node with low load when available + nodes = []*livekit.Node{nodeLoadLow, nodeLoadHigh} + for range 5 { + node, err := sel.SelectNode(nodes) + if err != nil { + t.Error(err) + } + if node != nodeLoadLow { + t.Error("selected the wrong node") + } + } +} diff --git a/livekit/pkg/routing/selector/errors.go b/livekit/pkg/routing/selector/errors.go new file mode 100644 index 0000000..690153a --- /dev/null +++ b/livekit/pkg/routing/selector/errors.go @@ -0,0 +1,27 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package selector + +import "errors" + +var ( + ErrNoAvailableNodes = errors.New("could not find any available nodes") + ErrCurrentRegionNotSet = errors.New("current region cannot be blank") + ErrCurrentRegionUnknownLatLon = errors.New("unknown lat and lon for the current region") + ErrSortByNotSet = errors.New("sort by option cannot be blank") + ErrAlgorithmNotSet = errors.New("node selector algorithm option cannot be blank") + ErrSortByUnknown = errors.New("unknown sort by option") + ErrAlgorithmUnknown = errors.New("unknown node selector algorithm option") +) diff --git a/livekit/pkg/routing/selector/interfaces.go b/livekit/pkg/routing/selector/interfaces.go new file mode 100644 index 0000000..6fc223e --- /dev/null +++ b/livekit/pkg/routing/selector/interfaces.go @@ -0,0 +1,66 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package selector + +import ( + "errors" + + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + + "github.com/livekit/livekit-server/pkg/config" +) + +var ErrUnsupportedSelector = errors.New("unsupported node selector") + +// NodeSelector selects an appropriate node to run the current session +type NodeSelector interface { + SelectNode(nodes []*livekit.Node) (*livekit.Node, error) +} + +func CreateNodeSelector(conf *config.Config) (NodeSelector, error) { + kind := conf.NodeSelector.Kind + if kind == "" { + kind = "any" + } + switch kind { + case "any": + return &AnySelector{conf.NodeSelector.SortBy, conf.NodeSelector.Algorithm}, nil + case "cpuload": + return &CPULoadSelector{ + CPULoadLimit: conf.NodeSelector.CPULoadLimit, + SortBy: conf.NodeSelector.SortBy, + Algorithm: conf.NodeSelector.Algorithm, + }, nil + case "sysload": + return &SystemLoadSelector{ + SysloadLimit: conf.NodeSelector.SysloadLimit, + SortBy: conf.NodeSelector.SortBy, + Algorithm: conf.NodeSelector.Algorithm, + }, nil + case "regionaware": + s, err := NewRegionAwareSelector(conf.Region, conf.NodeSelector.Regions, conf.NodeSelector.SortBy, conf.NodeSelector.Algorithm) + if err != nil { + return nil, err + } + s.SysloadLimit = conf.NodeSelector.SysloadLimit + return s, nil + case "random": + logger.Warnw("random node selector is deprecated, please switch to \"any\" or another selector", nil) + return &AnySelector{conf.NodeSelector.SortBy, conf.NodeSelector.Algorithm}, nil + default: + return nil, ErrUnsupportedSelector + } +} diff --git a/livekit/pkg/routing/selector/regionaware.go b/livekit/pkg/routing/selector/regionaware.go new file mode 100644 index 0000000..756781a --- /dev/null +++ b/livekit/pkg/routing/selector/regionaware.go @@ -0,0 +1,127 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package selector + +import ( + "math" + + "github.com/livekit/protocol/livekit" + + "github.com/livekit/livekit-server/pkg/config" +) + +// RegionAwareSelector prefers available nodes that are closest to the region of the current instance +type RegionAwareSelector struct { + SystemLoadSelector + CurrentRegion string + regionDistances map[string]float64 + regions []config.RegionConfig + SortBy string + Algorithm string +} + +func NewRegionAwareSelector(currentRegion string, regions []config.RegionConfig, sortBy string, algorithm string) (*RegionAwareSelector, error) { + if currentRegion == "" { + return nil, ErrCurrentRegionNotSet + } + // build internal map of distances + s := &RegionAwareSelector{ + CurrentRegion: currentRegion, + regionDistances: make(map[string]float64), + regions: regions, + SortBy: sortBy, + Algorithm: algorithm, + } + + var currentRC *config.RegionConfig + + for _, region := range regions { + if region.Name == currentRegion { + currentRC = ®ion + break + } + } + + if currentRC == nil && len(regions) > 0 { + return nil, ErrCurrentRegionUnknownLatLon + } + + if currentRC != nil { + for _, region := range regions { + s.regionDistances[region.Name] = distanceBetween(currentRC.Lat, currentRC.Lon, region.Lat, region.Lon) + } + } + + return s, nil +} + +func (s *RegionAwareSelector) SelectNode(nodes []*livekit.Node) (*livekit.Node, error) { + nodes, err := s.SystemLoadSelector.filterNodes(nodes) + if err != nil { + return nil, err + } + + // find nodes nearest to current region + var nearestNodes []*livekit.Node + nearestRegion := "" + minDist := math.MaxFloat64 + for _, node := range nodes { + if node.Region == nearestRegion { + nearestNodes = append(nearestNodes, node) + continue + } + if dist, ok := s.regionDistances[node.Region]; ok { + if dist < minDist { + minDist = dist + nearestRegion = node.Region + nearestNodes = nearestNodes[:0] + nearestNodes = append(nearestNodes, node) + } + } + } + + if len(nearestNodes) > 0 { + nodes = nearestNodes + } + + return SelectSortedNode(nodes, s.SortBy, s.Algorithm) +} + +// haversine(θ) function +func hsin(theta float64) float64 { + return math.Pow(math.Sin(theta/2), 2) +} + +var piBy180 = math.Pi / 180 + +// Haversine Distance Formula +// http://en.wikipedia.org/wiki/Haversine_formula +// from https://gist.github.com/cdipaolo/d3f8db3848278b49db68 +func distanceBetween(lat1, lon1, lat2, lon2 float64) float64 { + // convert to radians + // must cast radius as float to multiply later + var la1, lo1, la2, lo2, r float64 + la1 = lat1 * piBy180 + lo1 = lon1 * piBy180 + la2 = lat2 * piBy180 + lo2 = lon2 * piBy180 + + r = 6378100 // Earth radius in METERS + + // calculate + h := hsin(la2-la1) + math.Cos(la1)*math.Cos(la2)*hsin(lo2-lo1) + + return 2 * r * math.Asin(math.Sqrt(h)) +} diff --git a/livekit/pkg/routing/selector/regionaware_test.go b/livekit/pkg/routing/selector/regionaware_test.go new file mode 100644 index 0000000..7433101 --- /dev/null +++ b/livekit/pkg/routing/selector/regionaware_test.go @@ -0,0 +1,166 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package selector_test + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/utils" + "github.com/livekit/protocol/utils/guid" + + "github.com/livekit/livekit-server/pkg/config" + "github.com/livekit/livekit-server/pkg/routing/selector" +) + +const ( + loadLimit = 0.5 + regionWest = "us-west" + regionEast = "us-east" + regionSeattle = "seattle" + sortBy = "random" + algorithm = "lowest" +) + +func TestRegionAwareRouting(t *testing.T) { + rc := []config.RegionConfig{ + { + Name: regionWest, + Lat: 37.64046607830567, + Lon: -120.88026233189062, + }, + { + Name: regionEast, + Lat: 40.68914362140307, + Lon: -74.04445748616385, + }, + { + Name: regionSeattle, + Lat: 47.620426730945454, + Lon: -122.34938468973702, + }, + } + t.Run("works without region config", func(t *testing.T) { + nodes := []*livekit.Node{ + newTestNodeInRegion("", false), + } + s, err := selector.NewRegionAwareSelector(regionEast, nil, sortBy, algorithm) + require.NoError(t, err) + + node, err := s.SelectNode(nodes) + require.NoError(t, err) + require.NotNil(t, node) + }) + + t.Run("picks available nodes in same region", func(t *testing.T) { + expectedNode := newTestNodeInRegion(regionEast, true) + nodes := []*livekit.Node{ + newTestNodeInRegion(regionSeattle, true), + newTestNodeInRegion(regionWest, true), + expectedNode, + newTestNodeInRegion(regionEast, false), + } + s, err := selector.NewRegionAwareSelector(regionEast, rc, sortBy, algorithm) + require.NoError(t, err) + s.SysloadLimit = loadLimit + + node, err := s.SelectNode(nodes) + require.NoError(t, err) + require.Equal(t, expectedNode, node) + }) + + t.Run("picks available nodes in same region when current node is first in the list", func(t *testing.T) { + expectedNode := newTestNodeInRegion(regionEast, true) + nodes := []*livekit.Node{ + expectedNode, + newTestNodeInRegion(regionSeattle, true), + newTestNodeInRegion(regionWest, true), + newTestNodeInRegion(regionEast, false), + } + s, err := selector.NewRegionAwareSelector(regionEast, rc, sortBy, algorithm) + require.NoError(t, err) + s.SysloadLimit = loadLimit + + node, err := s.SelectNode(nodes) + require.NoError(t, err) + require.Equal(t, expectedNode, node) + }) + + t.Run("picks closest node in a diff region", func(t *testing.T) { + expectedNode := newTestNodeInRegion(regionWest, true) + nodes := []*livekit.Node{ + newTestNodeInRegion(regionSeattle, false), + expectedNode, + newTestNodeInRegion(regionEast, true), + } + s, err := selector.NewRegionAwareSelector(regionSeattle, rc, sortBy, algorithm) + require.NoError(t, err) + s.SysloadLimit = loadLimit + + node, err := s.SelectNode(nodes) + require.NoError(t, err) + require.Equal(t, expectedNode, node) + }) + + t.Run("handles multiple nodes in same region", func(t *testing.T) { + expectedNode := newTestNodeInRegion(regionWest, true) + nodes := []*livekit.Node{ + newTestNodeInRegion(regionSeattle, false), + newTestNodeInRegion(regionEast, true), + newTestNodeInRegion(regionEast, true), + expectedNode, + expectedNode, + } + s, err := selector.NewRegionAwareSelector(regionSeattle, rc, sortBy, algorithm) + require.NoError(t, err) + s.SysloadLimit = loadLimit + + node, err := s.SelectNode(nodes) + require.NoError(t, err) + require.Equal(t, expectedNode, node) + }) + + t.Run("functions when current region is full", func(t *testing.T) { + nodes := []*livekit.Node{ + newTestNodeInRegion(regionWest, true), + } + s, err := selector.NewRegionAwareSelector(regionEast, rc, sortBy, algorithm) + require.NoError(t, err) + + node, err := s.SelectNode(nodes) + require.NoError(t, err) + require.NotNil(t, node) + }) +} + +func newTestNodeInRegion(region string, available bool) *livekit.Node { + load := float32(0.4) + if !available { + load = 1.0 + } + return &livekit.Node{ + Id: guid.New(utils.NodePrefix), + Region: region, + State: livekit.NodeState_SERVING, + Stats: &livekit.NodeStats{ + UpdatedAt: time.Now().Unix(), + NumCpus: 1, + LoadAvgLast1Min: load, + }, + } +} diff --git a/livekit/pkg/routing/selector/sortby_test.go b/livekit/pkg/routing/selector/sortby_test.go new file mode 100644 index 0000000..57260e4 --- /dev/null +++ b/livekit/pkg/routing/selector/sortby_test.go @@ -0,0 +1,64 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package selector_test + +import ( + "testing" + + "github.com/livekit/protocol/livekit" + + "github.com/livekit/livekit-server/pkg/routing/selector" +) + +func SortByTest(t *testing.T, sortBy string) { + sel := selector.SystemLoadSelector{SortBy: sortBy, Algorithm: "lowest"} + nodes := []*livekit.Node{nodeLoadLow, nodeLoadMedium, nodeLoadHigh} + + for range 5 { + node, err := sel.SelectNode(nodes) + if err != nil { + t.Error(err) + } + if node != nodeLoadLow { + t.Error("selected the wrong node for SortBy:", sortBy) + } + } +} + +func TestSortByErrors(t *testing.T) { + sel := selector.SystemLoadSelector{Algorithm: "lowest"} + nodes := []*livekit.Node{nodeLoadLow, nodeLoadMedium, nodeLoadHigh} + + // Test unset sort by option error + _, err := sel.SelectNode(nodes) + if err != selector.ErrSortByNotSet { + t.Error("shouldn't allow empty sortBy") + } + + // Test unknown sort by option error + sel.SortBy = "testFail" + _, err = sel.SelectNode(nodes) + if err != selector.ErrSortByUnknown { + t.Error("shouldn't allow unknown sortBy") + } +} + +func TestSortBy(t *testing.T) { + sortByTests := []string{"sysload", "cpuload", "rooms", "clients", "tracks", "bytespersec"} + + for _, sortBy := range sortByTests { + SortByTest(t, sortBy) + } +} diff --git a/livekit/pkg/routing/selector/sysload.go b/livekit/pkg/routing/selector/sysload.go new file mode 100644 index 0000000..285fca4 --- /dev/null +++ b/livekit/pkg/routing/selector/sysload.go @@ -0,0 +1,54 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package selector + +import ( + "github.com/livekit/protocol/livekit" +) + +// SystemLoadSelector eliminates nodes that surpass has a per-cpu node higher than SysloadLimit +// then selects a node from nodes that are not overloaded +type SystemLoadSelector struct { + SysloadLimit float32 + SortBy string + Algorithm string +} + +func (s *SystemLoadSelector) filterNodes(nodes []*livekit.Node) ([]*livekit.Node, error) { + nodes = GetAvailableNodes(nodes) + if len(nodes) == 0 { + return nil, ErrNoAvailableNodes + } + + nodesLowLoad := make([]*livekit.Node, 0) + for _, node := range nodes { + if GetNodeSysload(node) < s.SysloadLimit { + nodesLowLoad = append(nodesLowLoad, node) + } + } + if len(nodesLowLoad) > 0 { + nodes = nodesLowLoad + } + return nodes, nil +} + +func (s *SystemLoadSelector) SelectNode(nodes []*livekit.Node) (*livekit.Node, error) { + nodes, err := s.filterNodes(nodes) + if err != nil { + return nil, err + } + + return SelectSortedNode(nodes, s.SortBy, s.Algorithm) +} diff --git a/livekit/pkg/routing/selector/sysload_test.go b/livekit/pkg/routing/selector/sysload_test.go new file mode 100644 index 0000000..f2dd342 --- /dev/null +++ b/livekit/pkg/routing/selector/sysload_test.go @@ -0,0 +1,114 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package selector_test + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/livekit/protocol/livekit" + + "github.com/livekit/livekit-server/pkg/routing/selector" +) + +var ( + nodeLoadLow = &livekit.Node{ + State: livekit.NodeState_SERVING, + Stats: &livekit.NodeStats{ + UpdatedAt: time.Now().Unix(), + NumCpus: 1, + CpuLoad: 0.1, + LoadAvgLast1Min: 0.0, + NumRooms: 1, + NumClients: 2, + NumTracksIn: 4, + NumTracksOut: 8, + Rates: []*livekit.NodeStatsRate{ + { + BytesIn: 1000, + BytesOut: 2000, + }, + }, + }, + } + + nodeLoadMedium = &livekit.Node{ + State: livekit.NodeState_SERVING, + Stats: &livekit.NodeStats{ + UpdatedAt: time.Now().Unix(), + NumCpus: 1, + CpuLoad: 0.5, + LoadAvgLast1Min: 0.5, + NumRooms: 5, + NumClients: 10, + NumTracksIn: 20, + NumTracksOut: 200, + Rates: []*livekit.NodeStatsRate{ + { + BytesIn: 5000, + BytesOut: 10000, + }, + }, + }, + } + + nodeLoadHigh = &livekit.Node{ + State: livekit.NodeState_SERVING, + Stats: &livekit.NodeStats{ + UpdatedAt: time.Now().Unix(), + NumCpus: 1, + CpuLoad: 0.99, + LoadAvgLast1Min: 2.0, + NumRooms: 10, + NumClients: 20, + NumTracksIn: 40, + NumTracksOut: 800, + Rates: []*livekit.NodeStatsRate{ + { + BytesIn: 10000, + BytesOut: 40000, + }, + }, + }, + } +) + +func TestSystemLoadSelector_SelectNode(t *testing.T) { + sel := selector.SystemLoadSelector{SysloadLimit: 1.0, SortBy: "random", Algorithm: "lowest"} + + var nodes []*livekit.Node + _, err := sel.SelectNode(nodes) + require.Error(t, err, "should error no available nodes") + + // Select a node with high load when no nodes with low load are available + nodes = []*livekit.Node{nodeLoadHigh} + if _, err := sel.SelectNode(nodes); err != nil { + t.Error(err) + } + + // Select a node with low load when available + nodes = []*livekit.Node{nodeLoadLow, nodeLoadHigh} + for range 5 { + node, err := sel.SelectNode(nodes) + if err != nil { + t.Error(err) + } + if node != nodeLoadLow { + t.Error("selected the wrong node") + } + } +} diff --git a/livekit/pkg/routing/selector/utils.go b/livekit/pkg/routing/selector/utils.go new file mode 100644 index 0000000..88b6f29 --- /dev/null +++ b/livekit/pkg/routing/selector/utils.go @@ -0,0 +1,178 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package selector + +import ( + "math/rand/v2" + "sort" + "time" + + "github.com/thoas/go-funk" + + "github.com/livekit/protocol/livekit" + + "github.com/livekit/livekit-server/pkg/config" +) + +const AvailableSeconds = 5 + +// checks if a node has been updated recently to be considered for selection +func IsAvailable(node *livekit.Node) bool { + if node.Stats == nil { + // available till stats are available + return true + } + + delta := time.Now().Unix() - node.Stats.UpdatedAt + return int(delta) < AvailableSeconds +} + +func GetAvailableNodes(nodes []*livekit.Node) []*livekit.Node { + return funk.Filter(nodes, func(node *livekit.Node) bool { + return IsAvailable(node) && node.State == livekit.NodeState_SERVING + }).([]*livekit.Node) +} + +func GetNodeSysload(node *livekit.Node) float32 { + stats := node.Stats + numCpus := stats.NumCpus + if numCpus == 0 { + numCpus = 1 + } + return stats.LoadAvgLast1Min / float32(numCpus) +} + +// TODO: check remote node configured limit, instead of this node's config +func LimitsReached(limitConfig config.LimitConfig, nodeStats *livekit.NodeStats) bool { + if nodeStats == nil { + return false + } + + if limitConfig.NumTracks > 0 && limitConfig.NumTracks <= nodeStats.NumTracksIn+nodeStats.NumTracksOut { + return true + } + + rate := &livekit.NodeStatsRate{} + if len(nodeStats.Rates) > 0 { + rate = nodeStats.Rates[0] + } + if limitConfig.BytesPerSec > 0 && limitConfig.BytesPerSec <= rate.BytesIn+rate.BytesOut { + return true + } + + return false +} + +func SelectSortedNode(nodes []*livekit.Node, sortBy string, algorithm string) (*livekit.Node, error) { + if sortBy == "" { + return nil, ErrSortByNotSet + } + if algorithm == "" { + return nil, ErrAlgorithmNotSet + } + + switch algorithm { + case "lowest": // examine all nodes and select the lowest based on sort criteria + return selectLowestSortedNode(nodes, sortBy) + case "twochoice": // randomly select two nodes and return the lowest based on sort criteria "Power of Two Random Choices" + return selectTwoChoiceSortedNode(nodes, sortBy) + default: + return nil, ErrAlgorithmUnknown + } +} + +func selectTwoChoiceSortedNode(nodes []*livekit.Node, sortBy string) (*livekit.Node, error) { + if len(nodes) <= 2 { + return selectLowestSortedNode(nodes, sortBy) + } + + // randomly select two nodes + node1, node2, err := selectTwoRandomNodes(nodes) + if err != nil { + return nil, err + } + + // compare the two nodes based on the sort criteria + if node1 == nil || node2 == nil { + return nil, ErrNoAvailableNodes + } + + selectedNode, err := selectLowestSortedNode([]*livekit.Node{node1, node2}, sortBy) + if err != nil { + return nil, err + } + + return selectedNode, nil +} + +func selectLowestSortedNode(nodes []*livekit.Node, sortBy string) (*livekit.Node, error) { + // Return a node based on what it should be sorted by for priority + switch sortBy { + case "random": + idx := funk.RandomInt(0, len(nodes)) + return nodes[idx], nil + case "sysload": + sort.Slice(nodes, func(i, j int) bool { + return GetNodeSysload(nodes[i]) < GetNodeSysload(nodes[j]) + }) + return nodes[0], nil + case "cpuload": + sort.Slice(nodes, func(i, j int) bool { + return nodes[i].Stats.CpuLoad < nodes[j].Stats.CpuLoad + }) + return nodes[0], nil + case "rooms": + sort.Slice(nodes, func(i, j int) bool { + return nodes[i].Stats.NumRooms < nodes[j].Stats.NumRooms + }) + return nodes[0], nil + case "clients": + sort.Slice(nodes, func(i, j int) bool { + return nodes[i].Stats.NumClients < nodes[j].Stats.NumClients + }) + return nodes[0], nil + case "tracks": + sort.Slice(nodes, func(i, j int) bool { + return nodes[i].Stats.NumTracksIn+nodes[i].Stats.NumTracksOut < nodes[j].Stats.NumTracksIn+nodes[j].Stats.NumTracksOut + }) + return nodes[0], nil + case "bytespersec": + sort.Slice(nodes, func(i, j int) bool { + ratei := &livekit.NodeStatsRate{} + if len(nodes[i].Stats.Rates) > 0 { + ratei = nodes[i].Stats.Rates[0] + } + + ratej := &livekit.NodeStatsRate{} + if len(nodes[j].Stats.Rates) > 0 { + ratej = nodes[j].Stats.Rates[0] + } + return ratei.BytesIn+ratei.BytesOut < ratej.BytesIn+ratej.BytesOut + }) + return nodes[0], nil + default: + return nil, ErrSortByUnknown + } +} + +func selectTwoRandomNodes(nodes []*livekit.Node) (*livekit.Node, *livekit.Node, error) { + if len(nodes) < 2 { + return nil, nil, ErrNoAvailableNodes + } + + shuffledIndices := rand.Perm(len(nodes)) + + return nodes[shuffledIndices[0]], nodes[shuffledIndices[1]], nil +} diff --git a/livekit/pkg/routing/selector/utils_test.go b/livekit/pkg/routing/selector/utils_test.go new file mode 100644 index 0000000..4f62f6d --- /dev/null +++ b/livekit/pkg/routing/selector/utils_test.go @@ -0,0 +1,46 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package selector_test + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/livekit/protocol/livekit" + + "github.com/livekit/livekit-server/pkg/routing/selector" +) + +func TestIsAvailable(t *testing.T) { + t.Run("still available", func(t *testing.T) { + n := &livekit.Node{ + Stats: &livekit.NodeStats{ + UpdatedAt: time.Now().Unix() - 3, + }, + } + require.True(t, selector.IsAvailable(n)) + }) + + t.Run("expired", func(t *testing.T) { + n := &livekit.Node{ + Stats: &livekit.NodeStats{ + UpdatedAt: time.Now().Unix() - 20, + }, + } + require.False(t, selector.IsAvailable(n)) + }) +} diff --git a/livekit/pkg/routing/signal.go b/livekit/pkg/routing/signal.go new file mode 100644 index 0000000..caec171 --- /dev/null +++ b/livekit/pkg/routing/signal.go @@ -0,0 +1,385 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "context" + "errors" + "sync" + "time" + + "go.uber.org/atomic" + "google.golang.org/protobuf/proto" + + "github.com/livekit/livekit-server/pkg/config" + "github.com/livekit/livekit-server/pkg/telemetry/prometheus" + "github.com/livekit/livekit-server/pkg/utils" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/rpc" + "github.com/livekit/protocol/utils/guid" + "github.com/livekit/psrpc" + "github.com/livekit/psrpc/pkg/middleware" +) + +var ErrSignalWriteFailed = errors.New("signal write failed") +var ErrSignalMessageDropped = errors.New("signal message dropped") + +//counterfeiter:generate . SignalClient +type SignalClient interface { + ActiveCount() int + StartParticipantSignal(ctx context.Context, roomName livekit.RoomName, pi ParticipantInit, nodeID livekit.NodeID) (connectionID livekit.ConnectionID, reqSink MessageSink, resSource MessageSource, err error) +} + +type signalClient struct { + nodeID livekit.NodeID + config config.SignalRelayConfig + client rpc.TypedSignalClient + active atomic.Int32 +} + +func NewSignalClient(nodeID livekit.NodeID, bus psrpc.MessageBus, config config.SignalRelayConfig) (SignalClient, error) { + client, err := rpc.NewTypedSignalClient( + nodeID, + bus, + middleware.WithClientMetrics(rpc.PSRPCMetricsObserver{}), + psrpc.WithClientChannelSize(config.StreamBufferSize), + ) + if err != nil { + return nil, err + } + + return &signalClient{ + nodeID: nodeID, + config: config, + client: client, + }, nil +} + +func (r *signalClient) ActiveCount() int { + return int(r.active.Load()) +} + +func (r *signalClient) StartParticipantSignal( + ctx context.Context, + roomName livekit.RoomName, + pi ParticipantInit, + nodeID livekit.NodeID, +) ( + connectionID livekit.ConnectionID, + reqSink MessageSink, + resSource MessageSource, + err error, +) { + connectionID = livekit.ConnectionID(guid.New("CO_")) + ss, err := pi.ToStartSession(roomName, connectionID) + if err != nil { + return + } + + l := utils.GetLogger(ctx).WithValues( + "room", roomName, + "reqNodeID", nodeID, + "participant", pi.Identity, + "connID", connectionID, + "participantInit", pi, + "startSession", logger.Proto(ss), + ) + + l.Debugw("starting signal connection") + + stream, err := r.client.RelaySignal(ctx, nodeID) + if err != nil { + prometheus.RecordSignalRequestFailure() + return + } + + err = stream.Send(&rpc.RelaySignalRequest{StartSession: ss}) + if err != nil { + stream.Close(err) + prometheus.RecordSignalRequestFailure() + return + } + + sink := NewSignalMessageSink(SignalSinkParams[*rpc.RelaySignalRequest, *rpc.RelaySignalResponse]{ + Logger: l, + Stream: stream, + Config: r.config, + Writer: signalRequestMessageWriter{}, + CloseOnFailure: true, + BlockOnClose: true, + ConnectionID: connectionID, + }) + resChan := NewDefaultMessageChannel(connectionID) + + go func() { + r.active.Inc() + defer r.active.Dec() + + err := CopySignalStreamToMessageChannel[*rpc.RelaySignalRequest, *rpc.RelaySignalResponse]( + stream, + resChan, + signalResponseMessageReader{}, + r.config, + prometheus.RecordSignalResponseSuccess, + prometheus.RecordSignalResponseFailure, + ) + l.Debugw("signal stream closed", "error", err) + + resChan.Close() + }() + + return connectionID, sink, resChan, nil +} + +// ------------------------------ + +type signalRequestMessageWriter struct{} + +func (e signalRequestMessageWriter) Write(seq uint64, close bool, msgs []proto.Message) *rpc.RelaySignalRequest { + r := &rpc.RelaySignalRequest{ + Seq: seq, + Requests: make([]*livekit.SignalRequest, 0, len(msgs)), + Close: close, + } + for _, m := range msgs { + r.Requests = append(r.Requests, m.(*livekit.SignalRequest)) + } + return r +} + +// ------------------------------- + +type signalResponseMessageReader struct{} + +func (e signalResponseMessageReader) Read(rm *rpc.RelaySignalResponse) ([]proto.Message, error) { + msgs := make([]proto.Message, 0, len(rm.Responses)) + for _, m := range rm.Responses { + msgs = append(msgs, m) + } + return msgs, nil +} + +// ----------------------------------------- + +type RelaySignalMessage interface { + proto.Message + GetSeq() uint64 + GetClose() bool +} + +type SignalMessageWriter[SendType RelaySignalMessage] interface { + Write(seq uint64, close bool, msgs []proto.Message) SendType +} + +type SignalMessageReader[RecvType RelaySignalMessage] interface { + Read(msg RecvType) ([]proto.Message, error) +} + +func CopySignalStreamToMessageChannel[SendType, RecvType RelaySignalMessage]( + stream psrpc.Stream[SendType, RecvType], + ch *MessageChannel, + reader SignalMessageReader[RecvType], + config config.SignalRelayConfig, + promSignalSuccess func(), + promSignalFailure func(), +) error { + r := &signalMessageReader[SendType, RecvType]{ + reader: reader, + config: config, + } + for msg := range stream.Channel() { + res, err := r.Read(msg) + if err != nil { + promSignalFailure() + return err + } + + for _, r := range res { + if err = ch.WriteMessage(r); err != nil { + promSignalFailure() + return err + } + promSignalSuccess() + } + + if msg.GetClose() { + return stream.Close(nil) + } + } + return stream.Err() +} + +// ---------------------------------------- + +type signalMessageReader[SendType, RecvType RelaySignalMessage] struct { + seq uint64 + reader SignalMessageReader[RecvType] + config config.SignalRelayConfig +} + +func (r *signalMessageReader[SendType, RecvType]) Read(msg RecvType) ([]proto.Message, error) { + res, err := r.reader.Read(msg) + if err != nil { + return nil, err + } + + if r.seq < msg.GetSeq() { + return nil, ErrSignalMessageDropped + } + if r.seq > msg.GetSeq() { + n := int(r.seq - msg.GetSeq()) + if n > len(res) { + n = len(res) + } + res = res[n:] + } + r.seq += uint64(len(res)) + + return res, nil +} + +// ---------------------------------------- + +type SignalSinkParams[SendType, RecvType RelaySignalMessage] struct { + Stream psrpc.Stream[SendType, RecvType] + Logger logger.Logger + Config config.SignalRelayConfig + Writer SignalMessageWriter[SendType] + CloseOnFailure bool + BlockOnClose bool + ConnectionID livekit.ConnectionID +} + +func NewSignalMessageSink[SendType, RecvType RelaySignalMessage](params SignalSinkParams[SendType, RecvType]) MessageSink { + return &signalMessageSink[SendType, RecvType]{ + SignalSinkParams: params, + } +} + +type signalMessageSink[SendType, RecvType RelaySignalMessage] struct { + SignalSinkParams[SendType, RecvType] + + mu sync.Mutex + seq uint64 + queue []proto.Message + writing bool + draining bool +} + +func (s *signalMessageSink[SendType, RecvType]) Close() { + s.mu.Lock() + s.draining = true + if !s.writing { + s.writing = true + go s.write() + } + s.mu.Unlock() + + // conditionally block while closing to wait for outgoing messages to drain + // + // on media the signal sink shares a goroutine with other signal connection + // attempts from the same participant so blocking delays establishing new + // sessions during reconnect. + // + // on controller closing without waiting for the outstanding messages to + // drain causes leave messages to be dropped from the write queue. when + // this happens other participants in the room aren't notified about the + // departure until the participant times out. + if s.BlockOnClose { + <-s.Stream.Context().Done() + } +} + +func (s *signalMessageSink[SendType, RecvType]) IsClosed() bool { + return s.Stream.Err() != nil +} + +func (s *signalMessageSink[SendType, RecvType]) write() { + interval := s.Config.MinRetryInterval + deadline := time.Now().Add(s.Config.RetryTimeout) + var err error + + s.mu.Lock() + for { + close := s.draining + if (!close && len(s.queue) == 0) || s.IsClosed() { + break + } + msg, n := s.Writer.Write(s.seq, close, s.queue), len(s.queue) + s.mu.Unlock() + + err = s.Stream.Send(msg, psrpc.WithTimeout(interval)) + if err != nil { + if time.Now().After(deadline) { + s.Logger.Warnw("could not send signal message", err) + + s.mu.Lock() + s.seq += uint64(len(s.queue)) + s.queue = nil + break + } + + interval *= 2 + if interval > s.Config.MaxRetryInterval { + interval = s.Config.MaxRetryInterval + } + } + + s.mu.Lock() + if err == nil { + interval = s.Config.MinRetryInterval + deadline = time.Now().Add(s.Config.RetryTimeout) + + s.seq += uint64(n) + s.queue = s.queue[n:] + + if close { + break + } + } + } + + s.writing = false + if s.draining { + s.Stream.Close(nil) + } + if err != nil && s.CloseOnFailure { + s.Stream.Close(ErrSignalWriteFailed) + } + s.mu.Unlock() +} + +func (s *signalMessageSink[SendType, RecvType]) WriteMessage(msg proto.Message) error { + s.mu.Lock() + defer s.mu.Unlock() + + if err := s.Stream.Err(); err != nil { + return err + } else if s.draining { + return psrpc.ErrStreamClosed + } + + s.queue = append(s.queue, msg) + if !s.writing { + s.writing = true + go s.write() + } + return nil +} + +func (s *signalMessageSink[SendType, RecvType]) ConnectionID() livekit.ConnectionID { + return s.SignalSinkParams.ConnectionID +} diff --git a/livekit/pkg/rtc/clientinfo.go b/livekit/pkg/rtc/clientinfo.go new file mode 100644 index 0000000..54bdb2f --- /dev/null +++ b/livekit/pkg/rtc/clientinfo.go @@ -0,0 +1,142 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rtc + +import ( + "strconv" + "strings" + + "github.com/livekit/protocol/livekit" +) + +type ClientInfo struct { + *livekit.ClientInfo +} + +func (c ClientInfo) isFirefox() bool { + return c.ClientInfo != nil && (strings.EqualFold(c.ClientInfo.Browser, "firefox") || strings.EqualFold(c.ClientInfo.Browser, "firefox mobile")) +} + +func (c ClientInfo) isSafari() bool { + return c.ClientInfo != nil && strings.EqualFold(c.ClientInfo.Browser, "safari") +} + +func (c ClientInfo) isGo() bool { + return c.ClientInfo != nil && c.ClientInfo.Sdk == livekit.ClientInfo_GO +} + +func (c ClientInfo) isLinux() bool { + return c.ClientInfo != nil && strings.EqualFold(c.ClientInfo.Os, "linux") +} + +func (c ClientInfo) isAndroid() bool { + return c.ClientInfo != nil && strings.EqualFold(c.ClientInfo.Os, "android") +} + +func (c ClientInfo) isOBS() bool { + return c.ClientInfo != nil && strings.Contains(c.ClientInfo.Browser, "OBS") +} + +func (c ClientInfo) SupportsAudioRED() bool { + return !c.isFirefox() && !c.isSafari() +} + +func (c ClientInfo) SupportsPrflxOverRelay() bool { + return !c.isFirefox() +} + +// GoSDK(pion) relies on rtp packets to fire ontrack event, browsers and native (libwebrtc) rely on sdp +func (c ClientInfo) FireTrackByRTPPacket() bool { + return c.isGo() +} + +func (c ClientInfo) SupportsCodecChange() bool { + return c.ClientInfo != nil && c.ClientInfo.Sdk != livekit.ClientInfo_GO && c.ClientInfo.Sdk != livekit.ClientInfo_UNKNOWN +} + +func (c ClientInfo) CanHandleReconnectResponse() bool { + if c.Sdk == livekit.ClientInfo_JS { + // JS handles Reconnect explicitly in 1.6.3, prior to 1.6.4 it could not handle unknown responses + if c.compareVersion("1.6.3") < 0 { + return false + } + } + return true +} + +func (c ClientInfo) SupportsICETCP() bool { + if c.ClientInfo == nil { + return false + } + if c.ClientInfo.Sdk == livekit.ClientInfo_GO { + // Go does not support active TCP + return false + } + if c.ClientInfo.Sdk == livekit.ClientInfo_SWIFT { + // ICE/TCP added in 1.0.5 + return c.compareVersion("1.0.5") >= 0 + } + // most SDKs support ICE/TCP + return true +} + +func (c ClientInfo) SupportsChangeRTPSenderEncodingActive() bool { + return !c.isFirefox() +} + +func (c ClientInfo) ComplyWithCodecOrderInSDPAnswer() bool { + return !((c.isLinux() || c.isAndroid()) && c.isFirefox()) +} + +// Rust SDK can't decode unknown signal message (TrackSubscribed and ErrorResponse) +func (c ClientInfo) SupportsTrackSubscribedEvent() bool { + return !(c.ClientInfo.GetSdk() == livekit.ClientInfo_RUST && c.ClientInfo.GetProtocol() < 10) +} + +func (c ClientInfo) SupportsRequestResponse() bool { + return c.SupportsTrackSubscribedEvent() +} + +func (c ClientInfo) SupportsSctpZeroChecksum() bool { + return !(c.ClientInfo.GetSdk() == livekit.ClientInfo_UNKNOWN || + (c.isGo() && c.compareVersion("2.4.0") < 0)) +} + +// compareVersion compares a semver against the current client SDK version +// returning 1 if current version is greater than version +// 0 if they are the same, and -1 if it's an earlier version +func (c ClientInfo) compareVersion(version string) int { + if c.ClientInfo == nil { + return -1 + } + parts0 := strings.Split(c.ClientInfo.Version, ".") + parts1 := strings.Split(version, ".") + ints0 := make([]int, 3) + ints1 := make([]int, 3) + for i := range 3 { + if len(parts0) > i { + ints0[i], _ = strconv.Atoi(parts0[i]) + } + if len(parts1) > i { + ints1[i], _ = strconv.Atoi(parts1[i]) + } + if ints0[i] > ints1[i] { + return 1 + } else if ints0[i] < ints1[i] { + return -1 + } + } + return 0 +} diff --git a/livekit/pkg/rtc/clientinfo_test.go b/livekit/pkg/rtc/clientinfo_test.go new file mode 100644 index 0000000..5664024 --- /dev/null +++ b/livekit/pkg/rtc/clientinfo_test.go @@ -0,0 +1,59 @@ +/* + * Copyright 2022 LiveKit, Inc + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package rtc + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/livekit/protocol/livekit" +) + +func TestClientInfo_CompareVersion(t *testing.T) { + c := ClientInfo{ + ClientInfo: &livekit.ClientInfo{ + Version: "1", + }, + } + require.Equal(t, 1, c.compareVersion("0.1.0")) + require.Equal(t, 0, c.compareVersion("1.0.0")) + require.Equal(t, -1, c.compareVersion("1.0.5")) +} + +func TestClientInfo_SupportsICETCP(t *testing.T) { + t.Run("GO SDK cannot support TCP", func(t *testing.T) { + c := ClientInfo{ + ClientInfo: &livekit.ClientInfo{ + Sdk: livekit.ClientInfo_GO, + }, + } + require.False(t, c.SupportsICETCP()) + }) + + t.Run("Swift SDK cannot support TCP before 1.0.5", func(t *testing.T) { + c := ClientInfo{ + ClientInfo: &livekit.ClientInfo{ + Sdk: livekit.ClientInfo_SWIFT, + Version: "1.0.4", + }, + } + require.False(t, c.SupportsICETCP()) + c.Version = "1.0.5" + require.True(t, c.SupportsICETCP()) + }) +} diff --git a/livekit/pkg/rtc/config.go b/livekit/pkg/rtc/config.go new file mode 100644 index 0000000..2f30060 --- /dev/null +++ b/livekit/pkg/rtc/config.go @@ -0,0 +1,207 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rtc + +import ( + "github.com/pion/sdp/v3" + "github.com/pion/webrtc/v4" + + "github.com/livekit/livekit-server/pkg/config" + "github.com/livekit/livekit-server/pkg/sfu/buffer" + act "github.com/livekit/livekit-server/pkg/sfu/rtpextension/abscapturetime" + dd "github.com/livekit/livekit-server/pkg/sfu/rtpextension/dependencydescriptor" + "github.com/livekit/mediatransportutil/pkg/rtcconfig" +) + +const ( + frameMarkingURI = "urn:ietf:params:rtp-hdrext:framemarking" + repairedRTPStreamIDURI = "urn:ietf:params:rtp-hdrext:sdes:repaired-rtp-stream-id" +) + +type WebRTCConfig struct { + rtcconfig.WebRTCConfig + + BufferFactory *buffer.Factory + Receiver ReceiverConfig + Publisher DirectionConfig + Subscriber DirectionConfig +} + +type ReceiverConfig struct { + PacketBufferSizeVideo int + PacketBufferSizeAudio int +} + +type RTPHeaderExtensionConfig struct { + Audio []string + Video []string +} + +type RTCPFeedbackConfig struct { + Audio []webrtc.RTCPFeedback + Video []webrtc.RTCPFeedback +} + +type DirectionConfig struct { + RTPHeaderExtension RTPHeaderExtensionConfig + RTCPFeedback RTCPFeedbackConfig +} + +func NewWebRTCConfig(conf *config.Config) (*WebRTCConfig, error) { + rtcConf := conf.RTC + + webRTCConfig, err := rtcconfig.NewWebRTCConfig(&rtcConf.RTCConfig, conf.Development) + if err != nil { + return nil, err + } + + // we don't want to use active TCP on a server, clients should be dialing + webRTCConfig.SettingEngine.DisableActiveTCP(true) + + if rtcConf.PacketBufferSize == 0 { + rtcConf.PacketBufferSize = 500 + } + if rtcConf.PacketBufferSizeVideo == 0 { + rtcConf.PacketBufferSizeVideo = rtcConf.PacketBufferSize + } + if rtcConf.PacketBufferSizeAudio == 0 { + rtcConf.PacketBufferSizeAudio = rtcConf.PacketBufferSize + } + + return &WebRTCConfig{ + WebRTCConfig: *webRTCConfig, + Receiver: ReceiverConfig{ + PacketBufferSizeVideo: rtcConf.PacketBufferSizeVideo, + PacketBufferSizeAudio: rtcConf.PacketBufferSizeAudio, + }, + Publisher: getPublisherConfig(false), + Subscriber: getSubscriberConfig(rtcConf.CongestionControl.UseSendSideBWEInterceptor || rtcConf.CongestionControl.UseSendSideBWE), + }, nil +} + +func (c *WebRTCConfig) UpdatePublisherConfig(consolidated bool) { + c.Publisher = getPublisherConfig(consolidated) +} + +func (c *WebRTCConfig) UpdateSubscriberConfig(ccConf config.CongestionControlConfig) { + c.Subscriber = getSubscriberConfig(ccConf.UseSendSideBWEInterceptor || ccConf.UseSendSideBWE) +} + +func (c *WebRTCConfig) SetBufferFactory(factory *buffer.Factory) { + c.BufferFactory = factory + c.SettingEngine.BufferFactory = factory.GetOrNew +} + +func getPublisherConfig(consolidated bool) DirectionConfig { + if consolidated { + return DirectionConfig{ + RTPHeaderExtension: RTPHeaderExtensionConfig{ + Audio: []string{ + sdp.SDESMidURI, + sdp.SDESRTPStreamIDURI, + sdp.AudioLevelURI, + act.AbsCaptureTimeURI, + }, + Video: []string{ + sdp.SDESMidURI, + sdp.SDESRTPStreamIDURI, + sdp.TransportCCURI, + sdp.ABSSendTimeURI, + frameMarkingURI, + dd.ExtensionURI, + repairedRTPStreamIDURI, + act.AbsCaptureTimeURI, + }, + }, + RTCPFeedback: RTCPFeedbackConfig{ + Audio: []webrtc.RTCPFeedback{ + {Type: webrtc.TypeRTCPFBNACK}, + }, + Video: []webrtc.RTCPFeedback{ + {Type: webrtc.TypeRTCPFBTransportCC}, + {Type: webrtc.TypeRTCPFBGoogREMB}, + {Type: webrtc.TypeRTCPFBCCM, Parameter: "fir"}, + {Type: webrtc.TypeRTCPFBNACK}, + {Type: webrtc.TypeRTCPFBNACK, Parameter: "pli"}, + }, + }, + } + } + + return DirectionConfig{ + RTPHeaderExtension: RTPHeaderExtensionConfig{ + Audio: []string{ + sdp.SDESMidURI, + sdp.SDESRTPStreamIDURI, + sdp.AudioLevelURI, + act.AbsCaptureTimeURI, + }, + Video: []string{ + sdp.SDESMidURI, + sdp.SDESRTPStreamIDURI, + sdp.TransportCCURI, + frameMarkingURI, + dd.ExtensionURI, + repairedRTPStreamIDURI, + act.AbsCaptureTimeURI, + }, + }, + RTCPFeedback: RTCPFeedbackConfig{ + Audio: []webrtc.RTCPFeedback{ + {Type: webrtc.TypeRTCPFBNACK}, + }, + Video: []webrtc.RTCPFeedback{ + {Type: webrtc.TypeRTCPFBTransportCC}, + {Type: webrtc.TypeRTCPFBCCM, Parameter: "fir"}, + {Type: webrtc.TypeRTCPFBNACK}, + {Type: webrtc.TypeRTCPFBNACK, Parameter: "pli"}, + }, + }, + } +} + +func getSubscriberConfig(enableTWCC bool) DirectionConfig { + subscriberConfig := DirectionConfig{ + RTPHeaderExtension: RTPHeaderExtensionConfig{ + Video: []string{ + dd.ExtensionURI, + act.AbsCaptureTimeURI, + }, + Audio: []string{ + act.AbsCaptureTimeURI, + }, + }, + RTCPFeedback: RTCPFeedbackConfig{ + Audio: []webrtc.RTCPFeedback{ + // always enable NACK for audio but disable it later for red enabled transceiver. https://github.com/pion/webrtc/pull/2972 + {Type: webrtc.TypeRTCPFBNACK}, + }, + Video: []webrtc.RTCPFeedback{ + {Type: webrtc.TypeRTCPFBCCM, Parameter: "fir"}, + {Type: webrtc.TypeRTCPFBNACK}, + {Type: webrtc.TypeRTCPFBNACK, Parameter: "pli"}, + }, + }, + } + if enableTWCC { + subscriberConfig.RTPHeaderExtension.Video = append(subscriberConfig.RTPHeaderExtension.Video, sdp.TransportCCURI) + subscriberConfig.RTCPFeedback.Video = append(subscriberConfig.RTCPFeedback.Video, webrtc.RTCPFeedback{Type: webrtc.TypeRTCPFBTransportCC}) + } else { + subscriberConfig.RTPHeaderExtension.Video = append(subscriberConfig.RTPHeaderExtension.Video, sdp.ABSSendTimeURI) + subscriberConfig.RTCPFeedback.Video = append(subscriberConfig.RTCPFeedback.Video, webrtc.RTCPFeedback{Type: webrtc.TypeRTCPFBGoogREMB}) + } + + return subscriberConfig +} diff --git a/livekit/pkg/rtc/datadowntrack.go b/livekit/pkg/rtc/datadowntrack.go new file mode 100644 index 0000000..098f940 --- /dev/null +++ b/livekit/pkg/rtc/datadowntrack.go @@ -0,0 +1,99 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rtc + +import ( + "fmt" + "time" + + "github.com/livekit/livekit-server/pkg/rtc/datatrack" + "github.com/livekit/livekit-server/pkg/rtc/types" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" +) + +var _ types.DataDownTrack = (*DataDownTrack)(nil) +var _ types.DataTrackSender = (*DataDownTrack)(nil) + +type DataDownTrackParams struct { + Logger logger.Logger + SubscriberID livekit.ParticipantID + PublishDataTrack types.DataTrack + Handle uint16 + Transport types.DataTrackTransport +} + +type DataDownTrack struct { + params DataDownTrackParams + createdAt int64 +} + +func NewDataDownTrack(params DataDownTrackParams) (*DataDownTrack, error) { + d := &DataDownTrack{ + params: params, + createdAt: time.Now().UnixNano(), + } + + if err := d.params.PublishDataTrack.AddDataDownTrack(d); err != nil { + d.params.Logger.Warnw("could not add data down track", err) + return nil, err + } + + d.params.Logger.Infow("created data down track", "name", d.Name()) + return d, nil +} + +func (d *DataDownTrack) Close() { + d.params.Logger.Infow("closing data down track", "name", d.Name()) + d.params.PublishDataTrack.DeleteDataDownTrack(d.SubscriberID()) +} + +func (d *DataDownTrack) Handle() uint16 { + return d.params.Handle +} + +func (d *DataDownTrack) PublishDataTrack() types.DataTrack { + return d.params.PublishDataTrack +} + +func (d *DataDownTrack) ID() livekit.TrackID { + return d.params.PublishDataTrack.ID() +} + +func (d *DataDownTrack) Name() string { + return d.params.PublishDataTrack.Name() +} + +func (d *DataDownTrack) SubscriberID() livekit.ParticipantID { + // add `createdAt` to ensure repeated subscriptions from same subscriber to same publisher does not collide + return livekit.ParticipantID(fmt.Sprintf("%s:%d", d.params.SubscriberID, d.createdAt)) +} + +func (d *DataDownTrack) WritePacket(data []byte, packet *datatrack.Packet, _arrivalTime int64) { + forwardedPacket := *packet + forwardedPacket.Handle = d.params.Handle + buf, err := forwardedPacket.Marshal() + if err != nil { + d.params.Logger.Warnw("could not marshal data track message", err) + return + } + if err := d.params.Transport.SendDataTrackMessage(buf); err != nil { + d.params.Logger.Warnw("could not send data track message", err, "handle", d.params.Handle) + } +} + +func (d *DataDownTrack) UpdateSubscriptionOptions(subscriptionOptions *livekit.DataTrackSubscriptionOptions) { + // DT-TODO +} diff --git a/livekit/pkg/rtc/datatrack.go b/livekit/pkg/rtc/datatrack.go new file mode 100644 index 0000000..0401e5d --- /dev/null +++ b/livekit/pkg/rtc/datatrack.go @@ -0,0 +1,162 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rtc + +import ( + "errors" + "sync" + + "github.com/frostbyte73/core" + "github.com/livekit/livekit-server/pkg/rtc/datatrack" + "github.com/livekit/livekit-server/pkg/rtc/types" + sfuutils "github.com/livekit/livekit-server/pkg/sfu/utils" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/utils" +) + +var ( + errReceiverClosed = errors.New("datatrack is closed") +) + +var _ types.DataTrack = (*DataTrack)(nil) + +type DataTrackParams struct { + Logger logger.Logger + ParticipantID func() livekit.ParticipantID + ParticipantIdentity livekit.ParticipantIdentity +} + +type DataTrack struct { + params DataTrackParams + + lock sync.Mutex + dti *livekit.DataTrackInfo + subscribedTracks map[livekit.ParticipantID]types.DataDownTrack + + downTrackSpreader *sfuutils.DownTrackSpreader[types.DataTrackSender] + + closed core.Fuse +} + +func NewDataTrack(params DataTrackParams, dti *livekit.DataTrackInfo) *DataTrack { + d := &DataTrack{ + params: params, + dti: dti, + subscribedTracks: make(map[livekit.ParticipantID]types.DataDownTrack), + downTrackSpreader: sfuutils.NewDownTrackSpreader[types.DataTrackSender](sfuutils.DownTrackSpreaderParams{ + Threshold: 20, + Logger: params.Logger, + }), + } + d.params.Logger.Infow("created data track", "name", d.Name()) + return d +} + +func (d *DataTrack) Close() { + d.params.Logger.Infow("closing data track", "name", d.Name()) + d.closed.Break() +} + +func (d *DataTrack) PublisherID() livekit.ParticipantID { + return d.params.ParticipantID() +} + +func (d *DataTrack) PublisherIdentity() livekit.ParticipantIdentity { + return d.params.ParticipantIdentity +} + +func (d *DataTrack) ToProto() *livekit.DataTrackInfo { + return utils.CloneProto(d.dti) +} + +func (d *DataTrack) PubHandle() uint16 { + return uint16(d.dti.PubHandle) +} + +func (d *DataTrack) ID() livekit.TrackID { + return livekit.TrackID(d.dti.Sid) +} + +func (d *DataTrack) Name() string { + return d.dti.Name +} + +func (d *DataTrack) AddSubscriber(sub types.LocalParticipant) (types.DataDownTrack, error) { + d.lock.Lock() + defer d.lock.Unlock() + + if _, ok := d.subscribedTracks[sub.ID()]; ok { + return nil, errAlreadySubscribed + } + + dataDownTrack, err := NewDataDownTrack(DataDownTrackParams{ + Logger: sub.GetLogger().WithValues("trackID", d.ID()), + SubscriberID: sub.ID(), + PublishDataTrack: d, + Handle: sub.GetNextSubscribedDataTrackHandle(), + Transport: sub.GetDataTrackTransport(), + }) + if err != nil { + return nil, err + } + + d.subscribedTracks[sub.ID()] = dataDownTrack + return dataDownTrack, nil +} + +func (d *DataTrack) RemoveSubscriber(subID livekit.ParticipantID) { + d.lock.Lock() + dataDownTrack, ok := d.subscribedTracks[subID] + delete(d.subscribedTracks, subID) + d.lock.Unlock() + + if ok { + dataDownTrack.Close() + } +} + +func (d *DataTrack) IsSubscriber(subID livekit.ParticipantID) bool { + d.lock.Lock() + defer d.lock.Unlock() + + _, ok := d.subscribedTracks[subID] + return ok +} + +func (d *DataTrack) AddDataDownTrack(dts types.DataTrackSender) error { + if d.closed.IsBroken() { + return errReceiverClosed + } + + if d.downTrackSpreader.HasDownTrack(dts.SubscriberID()) { + d.params.Logger.Infow("subscriberID already exists, replacing data downtrack", "subscriberID", dts.SubscriberID()) + } + + d.downTrackSpreader.Store(dts) + d.params.Logger.Infow("data downtrack added", "subscriberID", dts.SubscriberID()) + return nil +} + +func (d *DataTrack) DeleteDataDownTrack(subscriberID livekit.ParticipantID) { + d.downTrackSpreader.Free(subscriberID) + d.params.Logger.Infow("data downtrack deleted", "subscriberID", subscriberID) +} + +func (d *DataTrack) HandlePacket(data []byte, packet *datatrack.Packet, arrivalTime int64) { + d.downTrackSpreader.Broadcast(func(dts types.DataTrackSender) { + dts.WritePacket(data, packet, arrivalTime) + }) +} diff --git a/livekit/pkg/rtc/datatrack/extension_participant_sid.go b/livekit/pkg/rtc/datatrack/extension_participant_sid.go new file mode 100644 index 0000000..bfad08a --- /dev/null +++ b/livekit/pkg/rtc/datatrack/extension_participant_sid.go @@ -0,0 +1,59 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package datatrack + +import ( + "errors" + + "github.com/livekit/protocol/livekit" +) + +type ExtensionParticipantSid struct { + participantID livekit.ParticipantID +} + +func NewExtensionParticipantSid(participantID livekit.ParticipantID) (*ExtensionParticipantSid, error) { + if len(participantID) >= 65536 { + return nil, errors.New("participantID too long") + } + + return &ExtensionParticipantSid{participantID}, nil +} + +func (e *ExtensionParticipantSid) ParticipantID() livekit.ParticipantID { + return e.participantID +} + +func (e *ExtensionParticipantSid) Marshal() (Extension, error) { + data := make([]byte, len(e.participantID)) + copy(data, e.participantID) + return Extension{ + id: uint16(livekit.DataTrackExtensionID_DTEI_PARTICIPANT_SID), + data: data, + }, nil +} + +func (e *ExtensionParticipantSid) Unmarshal(ext Extension) error { + if ext.id != uint16(livekit.DataTrackExtensionID_DTEI_PARTICIPANT_SID) { + return errors.New("invalid extension ID") + } + + if len(ext.data) == 0 { + return errors.New("empty extension data") + } + + e.participantID = livekit.ParticipantID(ext.data) + return nil +} diff --git a/livekit/pkg/rtc/datatrack/extension_participant_sid_test.go b/livekit/pkg/rtc/datatrack/extension_participant_sid_test.go new file mode 100644 index 0000000..243bfc6 --- /dev/null +++ b/livekit/pkg/rtc/datatrack/extension_participant_sid_test.go @@ -0,0 +1,46 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package datatrack + +import ( + "testing" + + "github.com/livekit/protocol/livekit" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestExtensionParticipantSid(t *testing.T) { + longTestParticipantID := livekit.ParticipantID(make([]byte, 65536)) + extParticipantSid, err := NewExtensionParticipantSid(longTestParticipantID) + require.Error(t, err) + + testParticipantID := livekit.ParticipantID("test") + extParticipantSid, err = NewExtensionParticipantSid(testParticipantID) + require.NoError(t, err) + + expectedExt := Extension{ + id: uint16(livekit.DataTrackExtensionID_DTEI_PARTICIPANT_SID), + data: []byte{'t', 'e', 's', 't'}, + } + ext, err := extParticipantSid.Marshal() + require.NoError(t, err) + require.Equal(t, expectedExt, ext) + + var unmarshaled ExtensionParticipantSid + err = unmarshaled.Unmarshal(ext) + require.NoError(t, err) + assert.Equal(t, testParticipantID, unmarshaled.ParticipantID()) +} diff --git a/livekit/pkg/rtc/datatrack/packet.go b/livekit/pkg/rtc/datatrack/packet.go new file mode 100644 index 0000000..fe4da8f --- /dev/null +++ b/livekit/pkg/rtc/datatrack/packet.go @@ -0,0 +1,269 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package datatrack + +import ( + "encoding/binary" + "errors" + "fmt" +) + +var ( + errHeaderSizeInsufficient = errors.New("data track packet header size insufficient") + errBufferSizeInsufficient = errors.New("data track packet buffer size insufficient") + errExtensionSizeInsufficient = errors.New("data track packet extension size insufficient") + errExtensionNotFound = errors.New("data track packet extension not found") +) + +const ( + headerLength = 12 + + versionShift = 5 + versionMask = (1 << 3) - 1 + + startOfFrameShift = 4 + startOfFrameMask = (1 << 1) - 1 + + finalOfFrameShift = 3 + finalOfFrameMask = (1 << 1) - 1 + + extensionsShift = 2 + extensionsMask = (1 << 1) - 1 + + handleOffset = 2 + handleLength = 2 + + seqNumOffset = 4 + seqNumLength = 2 + + frameNumOffset = 6 + frameNumLength = 2 + + timestampOffset = 8 + timestampLength = 4 + + extensionsSizeOffset = headerLength + extensionsSizeLength = 2 + + extensionIDLength = 2 + extensionSizeLength = 2 +) + +type Extension struct { + id uint16 + data []byte +} + +type Header struct { + Version uint8 + IsStartOfFrame bool + IsFinalOfFrame bool + HasExtensions bool + Handle uint16 + SequenceNumber uint16 + FrameNumber uint16 + Timestamp uint32 + ExtensionsSize uint16 + Extensions []Extension +} + +/* + ┆* +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + ┆* 0 1 2 3 + ┆* 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + ┆* +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + ┆* |V |F|L|X| reserved | handle | + ┆* +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + ┆* | sequence number | frame number | + ┆* +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + ┆* | timestamp | + ┆* +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + |* Extensions Size if X=1 | Extensions... | + ┆* +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + + Each extension + ┆* +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + ┆* 0 1 2 3 + ┆* 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + ┆* +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + ┆* | Extension ID | Extension size | + ┆* +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + |* Extension payload (padded to 4 byte boundary) | + ┆* +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +*/ + +func (h *Header) Unmarshal(buf []byte) (int, error) { + if len(buf) < headerLength { + return 0, fmt.Errorf("%w: %d < %d", errHeaderSizeInsufficient, len(buf), headerLength) + } + + hdrSize := headerLength + h.Version = buf[0] >> versionShift & versionMask + h.IsStartOfFrame = (buf[0] >> startOfFrameShift & startOfFrameMask) > 0 + h.IsFinalOfFrame = (buf[0] >> finalOfFrameShift & finalOfFrameMask) > 0 + h.HasExtensions = (buf[0] >> extensionsShift & extensionsMask) > 0 + + h.Handle = binary.BigEndian.Uint16(buf[handleOffset : handleOffset+handleLength]) + h.SequenceNumber = binary.BigEndian.Uint16(buf[seqNumOffset : seqNumOffset+seqNumLength]) + h.FrameNumber = binary.BigEndian.Uint16(buf[frameNumOffset : frameNumOffset+frameNumLength]) + h.Timestamp = binary.BigEndian.Uint32(buf[timestampOffset : timestampOffset+timestampLength]) + + if h.HasExtensions { + h.ExtensionsSize = (binary.BigEndian.Uint16(buf[extensionsSizeOffset:extensionsSizeOffset+extensionsSizeLength]) + 1) * 4 + hdrSize += extensionsSizeLength + + remainingSize := int(h.ExtensionsSize) + idx := extensionsSizeOffset + extensionsSizeLength + for remainingSize != 0 { + if len(buf[idx:]) < 4 || remainingSize < 4 { + return 0, fmt.Errorf("%w: %d/%d < %d", errExtensionSizeInsufficient, remainingSize, len(buf[idx:]), 4) + } + + id := binary.BigEndian.Uint16(buf[idx : idx+2]) + size := int(binary.BigEndian.Uint16(buf[idx+2 : idx+4])) + remainingSize -= 4 + idx += 4 + hdrSize += 4 + + if len(buf[idx:]) < size || remainingSize < size { + return 0, fmt.Errorf("%w: %d/%d < %d", errExtensionSizeInsufficient, remainingSize, len(buf[idx:]), size) + } + h.Extensions = append(h.Extensions, Extension{id: id, data: buf[idx : idx+size]}) + + size = ((size + 3) / 4) * 4 + remainingSize -= size + idx += size + hdrSize += size + } + } + + return hdrSize, nil +} + +func (h *Header) MarshalSize() int { + size := headerLength + if h.HasExtensions { + size += 2 // extensions size field + for _, ext := range h.Extensions { + size += ((len(ext.data)+3)/4)*4 + 2 /* extension ID field */ + 2 /* extension length field */ + } + } + return size +} + +func (h *Header) MarshalTo(buf []byte) (int, error) { + if len(buf) < headerLength { + return 0, fmt.Errorf("%w: %d < %d", errHeaderSizeInsufficient, len(buf), headerLength) + } + + hdrSize := headerLength + buf[0] = h.Version << versionShift + if h.IsStartOfFrame { + buf[0] |= (1 << startOfFrameShift) + } + if h.IsFinalOfFrame { + buf[0] |= (1 << finalOfFrameShift) + } + if h.HasExtensions { + buf[0] |= (1 << extensionsShift) + } + + binary.BigEndian.PutUint16(buf[handleOffset:handleOffset+handleLength], h.Handle) + binary.BigEndian.PutUint16(buf[seqNumOffset:seqNumOffset+seqNumLength], h.SequenceNumber) + binary.BigEndian.PutUint16(buf[frameNumOffset:frameNumOffset+frameNumLength], h.FrameNumber) + binary.BigEndian.PutUint32(buf[timestampOffset:timestampOffset+timestampLength], h.Timestamp) + + if h.HasExtensions { + binary.BigEndian.PutUint16(buf[extensionsSizeOffset:extensionsSizeOffset+extensionsSizeLength], (h.ExtensionsSize/4)-1) + hdrSize += extensionsSizeLength + + idx := extensionsSizeOffset + extensionsSizeLength + for _, ext := range h.Extensions { + binary.BigEndian.PutUint16(buf[idx:idx+extensionIDLength], ext.id) + binary.BigEndian.PutUint16(buf[idx+extensionIDLength:idx+extensionIDLength+extensionSizeLength], uint16(len(ext.data))) + copy(buf[idx+extensionIDLength+extensionSizeLength:], ext.data) + + idx += ((len(ext.data)+3)/4)*4 + 2 /* extension ID field */ + 2 /* extension length field */ + hdrSize += ((len(ext.data)+3)/4)*4 + 2 /* extension ID field */ + 2 /* extension length field */ + } + } + + return hdrSize, nil +} + +func (h *Header) AddExtension(ext Extension) { + for i, existingExt := range h.Extensions { + if existingExt.id == ext.id { + h.ExtensionsSize -= uint16((len(existingExt.data)+3)/4*4 + 2 /* extension ID field */ + 2 /* extension length field */) + h.Extensions[i].data = ext.data + h.ExtensionsSize += uint16((len(h.Extensions[i].data)+3)/4*4 + 2 /* extension ID field */ + 2 /* extension length field */) + return + } + } + + h.Extensions = append(h.Extensions, ext) + h.ExtensionsSize += uint16((len(ext.data)+3)/4*4 + 2 /* extension ID field */ + 2 /* extension length field */) + h.HasExtensions = true +} + +func (h *Header) GetExtension(id uint16) (Extension, error) { + for _, ext := range h.Extensions { + if ext.id == id { + return ext, nil + } + } + return Extension{}, fmt.Errorf("%w, id: %d", errExtensionNotFound, id) +} + +// ---------------------------------------------------- + +type Packet struct { + Header + Payload []byte +} + +func (p *Packet) Unmarshal(buf []byte) error { + hdrSize, err := p.Header.Unmarshal(buf) + if err != nil { + return err + } + + p.Payload = buf[hdrSize:] + return nil +} + +func (p *Packet) Marshal() ([]byte, error) { + buf := make([]byte, p.Header.MarshalSize()+len(p.Payload)) + if err := p.MarshalTo(buf); err != nil { + return nil, err + } + + return buf, nil +} + +func (p *Packet) MarshalTo(buf []byte) error { + size := p.Header.MarshalSize() + len(p.Payload) + if len(buf) < size { + return fmt.Errorf("%w: %d < %d", errBufferSizeInsufficient, len(buf), size) + } + + hdrSize, err := p.Header.MarshalTo(buf) + if err != nil { + return err + } + + copy(buf[hdrSize:], p.Payload) + return nil +} diff --git a/livekit/pkg/rtc/datatrack/packet_test.go b/livekit/pkg/rtc/datatrack/packet_test.go new file mode 100644 index 0000000..a144ebf --- /dev/null +++ b/livekit/pkg/rtc/datatrack/packet_test.go @@ -0,0 +1,265 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package datatrack + +import ( + "testing" + + "github.com/livekit/protocol/livekit" + "github.com/stretchr/testify/require" +) + +func TestPacket(t *testing.T) { + t.Run("without extension", func(t *testing.T) { + payload := make([]byte, 6) + for i := range len(payload) { + payload[i] = byte(255 - i) + } + packet := &Packet{ + Header: Header{ + Version: 0, + IsStartOfFrame: true, + IsFinalOfFrame: true, + Handle: 3333, + SequenceNumber: 6666, + FrameNumber: 9999, + Timestamp: 0xdeadbeef, + }, + Payload: payload, + } + rawPacket, err := packet.Marshal() + require.NoError(t, err) + + expectedRawPacket := []byte{ + 0x18, 0x00, 0x0d, 0x05, 0x1a, 0x0a, 0x27, 0x0f, + 0xde, 0xad, 0xbe, 0xef, 0xff, 0xfe, 0xfd, 0xfc, + 0xfb, 0xfa, + } + require.Equal(t, expectedRawPacket, rawPacket) + + var unmarshaled Packet + err = unmarshaled.Unmarshal(rawPacket) + require.NoError(t, err) + require.Equal(t, packet, &unmarshaled) + }) + + t.Run("with extension", func(t *testing.T) { + payload := make([]byte, 4) + for i := range len(payload) { + payload[i] = byte(255 - i) + } + packet := &Packet{ + Header: Header{ + Version: 0, + IsStartOfFrame: true, + IsFinalOfFrame: false, + Handle: 3333, + SequenceNumber: 6666, + FrameNumber: 9999, + Timestamp: 0xdeadbeef, + }, + Payload: payload, + } + if extParticipantSid, err := NewExtensionParticipantSid("test_participant"); err == nil { + if ext, err := extParticipantSid.Marshal(); err == nil { + packet.AddExtension(ext) + } + } + rawPacket, err := packet.Marshal() + require.NoError(t, err) + + expectedRawPacket := []byte{ + 0x14, 0x00, 0x0d, 0x05, 0x1a, 0x0a, 0x27, 0x0f, + 0xde, 0xad, 0xbe, 0xef, 0x00, 0x04, 0x00, 0x01, + 0x00, 0x10, 0x74, 0x65, 0x73, 0x74, 0x5f, 0x70, + 0x61, 0x72, 0x74, 0x69, 0x63, 0x69, 0x70, 0x61, + 0x6e, 0x74, 0xff, 0xfe, 0xfd, 0xfc, + } + require.Equal(t, expectedRawPacket, rawPacket) + + var unmarshaled Packet + err = unmarshaled.Unmarshal(rawPacket) + require.NoError(t, err) + require.Equal(t, packet, &unmarshaled) + + ext, err := unmarshaled.GetExtension(uint16(livekit.DataTrackExtensionID_DTEI_PARTICIPANT_SID)) + require.NoError(t, err) + + var extParticipantSid ExtensionParticipantSid + require.NoError(t, extParticipantSid.Unmarshal(ext)) + require.Equal(t, livekit.ParticipantID("test_participant"), extParticipantSid.ParticipantID()) + }) + + t.Run("with extension padding", func(t *testing.T) { + payload := make([]byte, 4) + for i := range len(payload) { + payload[i] = byte(255 - i) + } + packet := &Packet{ + Header: Header{ + Version: 0, + IsStartOfFrame: true, + IsFinalOfFrame: false, + Handle: 3333, + SequenceNumber: 6666, + FrameNumber: 9999, + Timestamp: 0xdeadbeef, + }, + Payload: payload, + } + if extParticipantSid, err := NewExtensionParticipantSid("participant"); err == nil { + if ext, err := extParticipantSid.Marshal(); err == nil { + packet.AddExtension(ext) + } + } + rawPacket, err := packet.Marshal() + require.NoError(t, err) + + expectedRawPacket := []byte{ + 0x14, 0x00, 0x0d, 0x05, 0x1a, 0x0a, 0x27, 0x0f, + 0xde, 0xad, 0xbe, 0xef, 0x00, 0x03, 0x00, 0x01, + 0x00, 0x0b, 0x70, 0x61, 0x72, 0x74, 0x69, 0x63, + 0x69, 0x70, 0x61, 0x6e, 0x74, 0x00, 0xff, 0xfe, + 0xfd, 0xfc, + } + require.Equal(t, expectedRawPacket, rawPacket) + + var unmarshaled Packet + err = unmarshaled.Unmarshal(rawPacket) + require.NoError(t, err) + require.Equal(t, packet, &unmarshaled) + + ext, err := unmarshaled.GetExtension(uint16(livekit.DataTrackExtensionID_DTEI_PARTICIPANT_SID)) + require.NoError(t, err) + + var extParticipantSid ExtensionParticipantSid + require.NoError(t, extParticipantSid.Unmarshal(ext)) + require.Equal(t, livekit.ParticipantID("participant"), extParticipantSid.ParticipantID()) + }) + + t.Run("replace extension", func(t *testing.T) { + payload := make([]byte, 4) + for i := range len(payload) { + payload[i] = byte(255 - i) + } + packet := &Packet{ + Header: Header{ + Version: 0, + IsStartOfFrame: true, + IsFinalOfFrame: false, + Handle: 3333, + SequenceNumber: 6666, + FrameNumber: 9999, + Timestamp: 0xdeadbeef, + }, + Payload: payload, + } + if extParticipantSid, err := NewExtensionParticipantSid("participant"); err == nil { + if ext, err := extParticipantSid.Marshal(); err == nil { + packet.AddExtension(ext) + } + } + rawPacket, err := packet.Marshal() + require.NoError(t, err) + + expectedRawPacket := []byte{ + 0x14, 0x00, 0x0d, 0x05, 0x1a, 0x0a, 0x27, 0x0f, + 0xde, 0xad, 0xbe, 0xef, 0x00, 0x03, 0x00, 0x01, + 0x00, 0x0b, 0x70, 0x61, 0x72, 0x74, 0x69, 0x63, + 0x69, 0x70, 0x61, 0x6e, 0x74, 0x00, 0xff, 0xfe, + 0xfd, 0xfc, + } + require.Equal(t, expectedRawPacket, rawPacket) + + // replace existing extension ID and ensure that marshalled packet is updated + if extParticipantSid, err := NewExtensionParticipantSid("test_participant"); err == nil { + if ext, err := extParticipantSid.Marshal(); err == nil { + packet.AddExtension(ext) + } + } + rawPacket, err = packet.Marshal() + require.NoError(t, err) + + expectedRawPacket = []byte{ + 0x14, 0x00, 0x0d, 0x05, 0x1a, 0x0a, 0x27, 0x0f, + 0xde, 0xad, 0xbe, 0xef, 0x00, 0x04, 0x00, 0x01, + 0x00, 0x10, 0x74, 0x65, 0x73, 0x74, 0x5f, 0x70, + 0x61, 0x72, 0x74, 0x69, 0x63, 0x69, 0x70, 0x61, + 0x6e, 0x74, 0xff, 0xfe, 0xfd, 0xfc, + } + require.Equal(t, expectedRawPacket, rawPacket) + + var unmarshaled Packet + err = unmarshaled.Unmarshal(rawPacket) + require.NoError(t, err) + require.Equal(t, packet, &unmarshaled) + + ext, err := unmarshaled.GetExtension(uint16(livekit.DataTrackExtensionID_DTEI_PARTICIPANT_SID)) + require.NoError(t, err) + + var extParticipantSid ExtensionParticipantSid + require.NoError(t, extParticipantSid.Unmarshal(ext)) + require.Equal(t, livekit.ParticipantID("test_participant"), extParticipantSid.ParticipantID()) + }) + + t.Run("bad pcaket", func(t *testing.T) { + var unmarshaled Packet + // extensions size too small + badPacket := []byte{ + 0x14, 0x00, 0x0d, 0x05, 0x1a, 0x0a, 0x27, 0x0f, + 0xde, 0xad, 0xbe, 0xef, 0x00, 0x02, 0x00, 0x01, + 0x00, 0x0b, 0x70, 0x61, 0x72, 0x74, 0x69, 0x63, + 0x69, 0x70, 0x61, 0x6e, 0x74, 0x00, 0xff, 0xfe, + 0xfd, 0xfc, + } + err := unmarshaled.Unmarshal(badPacket) + require.Error(t, err) + + // get an invalid extension id + badPacket = []byte{ + 0x14, 0x00, 0x0d, 0x05, 0x1a, 0x0a, 0x27, 0x0f, + 0xde, 0xad, 0xbe, 0xef, 0x00, 0x03, 0x00, 0x02, + 0x00, 0x0b, 0x70, 0x61, 0x72, 0x74, 0x69, 0x63, + 0x69, 0x70, 0x61, 0x6e, 0x74, 0x00, 0xff, 0xfe, + 0xfd, 0xfc, + } + err = unmarshaled.Unmarshal(badPacket) + require.NoError(t, err) + _, err = unmarshaled.GetExtension(uint16(livekit.DataTrackExtensionID_DTEI_PARTICIPANT_SID)) + require.Error(t, err) + + // extension payload size bigger than payload + badPacket = []byte{ + 0x14, 0x00, 0x0d, 0x05, 0x1a, 0x0a, 0x27, 0x0f, + 0xde, 0xad, 0xbe, 0xef, 0x00, 0x03, 0x00, 0x01, + 0x00, 0x0d, 0x70, 0x61, 0x72, 0x74, 0x69, 0x63, + 0x69, 0x70, 0x61, 0x6e, 0x74, 0x00, 0xff, 0xfe, + 0xfd, 0xfc, + } + err = unmarshaled.Unmarshal(badPacket) + require.Error(t, err) + + // extension payload size smaller than payload + badPacket = []byte{ + 0x14, 0x00, 0x0d, 0x05, 0x1a, 0x0a, 0x27, 0x0f, + 0xde, 0xad, 0xbe, 0xef, 0x00, 0x03, 0x00, 0x01, + 0x00, 0x07, 0x70, 0x61, 0x72, 0x74, 0x69, 0x63, + 0x69, 0x70, 0x61, 0x6e, 0x74, 0x00, 0xff, 0xfe, + 0xfd, 0xfc, + } + err = unmarshaled.Unmarshal(badPacket) + require.Error(t, err) + }) +} diff --git a/livekit/pkg/rtc/datatrack/testutils.go b/livekit/pkg/rtc/datatrack/testutils.go new file mode 100644 index 0000000..e7df0bb --- /dev/null +++ b/livekit/pkg/rtc/datatrack/testutils.go @@ -0,0 +1,74 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package datatrack + +import ( + "math/rand" + "time" +) + +func GenerateRawDataPackets(handle uint16, seqNum uint16, frameNum uint16, numFrames int, frameSize int, frameDuration time.Duration) [][]byte { + if seqNum == 0 { + seqNum = uint16(rand.Intn(256) + 1) + } + if frameNum == 0 { + frameNum = uint16(rand.Intn(256) + 1) + } + timestamp := uint32(rand.Intn(1024)) + + packetsPerFrame := (frameSize + 255) / 256 // using 256 bytes of payload per packet + if packetsPerFrame == 0 { + return nil + } + numPackets := packetsPerFrame * numFrames + rawPackets := make([][]byte, 0, numPackets) + for range numFrames { + remainingSize := frameSize + for packetIdx := range packetsPerFrame { + payloadSize := min(remainingSize, 256) + payload := make([]byte, payloadSize) + for i := range len(payload) { + payload[i] = byte(255 - i) + } + packet := &Packet{ + Header: Header{ + Version: 0, + IsStartOfFrame: packetIdx == 0, + IsFinalOfFrame: packetIdx == packetsPerFrame-1, + Handle: handle, + SequenceNumber: seqNum, + FrameNumber: frameNum, + Timestamp: timestamp, + }, + Payload: payload, + } + if extParticipantSid, err := NewExtensionParticipantSid("test_participant"); err == nil { + if ext, err := extParticipantSid.Marshal(); err == nil { + packet.AddExtension(ext) + } + } + rawPacket, err := packet.Marshal() + if err == nil { + rawPackets = append(rawPackets, rawPacket) + } + seqNum++ + remainingSize -= payloadSize + } + frameNum++ + timestamp += uint32(90000 * frameDuration.Seconds()) + } + + return rawPackets +} diff --git a/livekit/pkg/rtc/dynacast/dynacastmanager_test.go b/livekit/pkg/rtc/dynacast/dynacastmanager_test.go new file mode 100644 index 0000000..9bb876d --- /dev/null +++ b/livekit/pkg/rtc/dynacast/dynacastmanager_test.go @@ -0,0 +1,516 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dynacast + +import ( + "sort" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/livekit/livekit-server/pkg/rtc/types" + "github.com/livekit/livekit-server/pkg/sfu/mime" + "github.com/livekit/protocol/livekit" +) + +type testDynacastManagerListener struct { + onSubscribedMaxQualityChange func(subscribedQualties []*livekit.SubscribedCodec) + onSubscribedAudioCodecChange func(subscribedCodecs []*livekit.SubscribedAudioCodec) +} + +func (t *testDynacastManagerListener) OnDynacastSubscribedMaxQualityChange( + subscribedQualities []*livekit.SubscribedCodec, + _maxSubscribedQualities []types.SubscribedCodecQuality, +) { + t.onSubscribedMaxQualityChange(subscribedQualities) +} + +func (t *testDynacastManagerListener) OnDynacastSubscribedAudioCodecChange( + codecs []*livekit.SubscribedAudioCodec, +) { + t.onSubscribedAudioCodecChange(codecs) +} + +func TestSubscribedMaxQuality(t *testing.T) { + t.Run("subscribers muted", func(t *testing.T) { + var lock sync.Mutex + actualSubscribedQualities := make([]*livekit.SubscribedCodec, 0) + + dm := NewDynacastManagerVideo(DynacastManagerVideoParams{ + Listener: &testDynacastManagerListener{ + onSubscribedMaxQualityChange: func(subscribedQualities []*livekit.SubscribedCodec) { + lock.Lock() + actualSubscribedQualities = subscribedQualities + lock.Unlock() + }, + }, + }) + + dm.NotifySubscriberMaxQuality("s1", mime.MimeTypeVP8, livekit.VideoQuality_HIGH) + dm.NotifySubscriberMaxQuality("s2", mime.MimeTypeAV1, livekit.VideoQuality_HIGH) + + // mute all subscribers of vp8 + dm.NotifySubscriberMaxQuality("s1", mime.MimeTypeVP8, livekit.VideoQuality_OFF) + + expectedSubscribedQualities := []*livekit.SubscribedCodec{ + { + Codec: mime.MimeTypeVP8.String(), + Qualities: []*livekit.SubscribedQuality{ + {Quality: livekit.VideoQuality_LOW, Enabled: false}, + {Quality: livekit.VideoQuality_MEDIUM, Enabled: false}, + {Quality: livekit.VideoQuality_HIGH, Enabled: false}, + }, + }, + { + Codec: mime.MimeTypeAV1.String(), + Qualities: []*livekit.SubscribedQuality{ + {Quality: livekit.VideoQuality_LOW, Enabled: true}, + {Quality: livekit.VideoQuality_MEDIUM, Enabled: true}, + {Quality: livekit.VideoQuality_HIGH, Enabled: true}, + }, + }, + } + require.Eventually(t, func() bool { + lock.Lock() + defer lock.Unlock() + + return subscribedCodecsAsString(expectedSubscribedQualities) == subscribedCodecsAsString(actualSubscribedQualities) + }, 10*time.Second, 100*time.Millisecond) + }) + + t.Run("subscribers max quality", func(t *testing.T) { + lock := sync.RWMutex{} + actualSubscribedQualities := make([]*livekit.SubscribedCodec, 0) + + dm := NewDynacastManagerVideo(DynacastManagerVideoParams{ + Listener: &testDynacastManagerListener{ + onSubscribedMaxQualityChange: func(subscribedQualities []*livekit.SubscribedCodec) { + lock.Lock() + actualSubscribedQualities = subscribedQualities + lock.Unlock() + }, + }, + }) + + dm.(*dynacastManagerVideo).maxSubscribedQuality = map[mime.MimeType]livekit.VideoQuality{ + mime.MimeTypeVP8: livekit.VideoQuality_LOW, + mime.MimeTypeAV1: livekit.VideoQuality_LOW, + } + dm.NotifySubscriberMaxQuality("s1", mime.MimeTypeVP8, livekit.VideoQuality_HIGH) + dm.NotifySubscriberMaxQuality("s2", mime.MimeTypeVP8, livekit.VideoQuality_MEDIUM) + dm.NotifySubscriberMaxQuality("s3", mime.MimeTypeAV1, livekit.VideoQuality_MEDIUM) + + expectedSubscribedQualities := []*livekit.SubscribedCodec{ + { + Codec: mime.MimeTypeVP8.String(), + Qualities: []*livekit.SubscribedQuality{ + {Quality: livekit.VideoQuality_LOW, Enabled: true}, + {Quality: livekit.VideoQuality_MEDIUM, Enabled: true}, + {Quality: livekit.VideoQuality_HIGH, Enabled: true}, + }, + }, + { + Codec: mime.MimeTypeAV1.String(), + Qualities: []*livekit.SubscribedQuality{ + {Quality: livekit.VideoQuality_LOW, Enabled: true}, + {Quality: livekit.VideoQuality_MEDIUM, Enabled: true}, + {Quality: livekit.VideoQuality_HIGH, Enabled: false}, + }, + }, + } + require.Eventually(t, func() bool { + lock.Lock() + defer lock.Unlock() + + return subscribedCodecsAsString(expectedSubscribedQualities) == subscribedCodecsAsString(actualSubscribedQualities) + }, 10*time.Second, 100*time.Millisecond) + + // "s1" dropping to MEDIUM should disable HIGH layer + dm.NotifySubscriberMaxQuality("s1", mime.MimeTypeVP8, livekit.VideoQuality_MEDIUM) + + expectedSubscribedQualities = []*livekit.SubscribedCodec{ + { + Codec: mime.MimeTypeVP8.String(), + Qualities: []*livekit.SubscribedQuality{ + {Quality: livekit.VideoQuality_LOW, Enabled: true}, + {Quality: livekit.VideoQuality_MEDIUM, Enabled: true}, + {Quality: livekit.VideoQuality_HIGH, Enabled: false}, + }, + }, + { + Codec: mime.MimeTypeAV1.String(), + Qualities: []*livekit.SubscribedQuality{ + {Quality: livekit.VideoQuality_LOW, Enabled: true}, + {Quality: livekit.VideoQuality_MEDIUM, Enabled: true}, + {Quality: livekit.VideoQuality_HIGH, Enabled: false}, + }, + }, + } + require.Eventually(t, func() bool { + lock.Lock() + defer lock.Unlock() + + return subscribedCodecsAsString(expectedSubscribedQualities) == subscribedCodecsAsString(actualSubscribedQualities) + }, 10*time.Second, 100*time.Millisecond) + + // "s1" , "s2" , "s3" dropping to LOW should disable HIGH & MEDIUM + dm.NotifySubscriberMaxQuality("s1", mime.MimeTypeVP8, livekit.VideoQuality_LOW) + dm.NotifySubscriberMaxQuality("s2", mime.MimeTypeVP8, livekit.VideoQuality_LOW) + dm.NotifySubscriberMaxQuality("s3", mime.MimeTypeAV1, livekit.VideoQuality_LOW) + + expectedSubscribedQualities = []*livekit.SubscribedCodec{ + { + Codec: mime.MimeTypeVP8.String(), + Qualities: []*livekit.SubscribedQuality{ + {Quality: livekit.VideoQuality_LOW, Enabled: true}, + {Quality: livekit.VideoQuality_MEDIUM, Enabled: false}, + {Quality: livekit.VideoQuality_HIGH, Enabled: false}, + }, + }, + { + Codec: mime.MimeTypeAV1.String(), + Qualities: []*livekit.SubscribedQuality{ + {Quality: livekit.VideoQuality_LOW, Enabled: true}, + {Quality: livekit.VideoQuality_MEDIUM, Enabled: false}, + {Quality: livekit.VideoQuality_HIGH, Enabled: false}, + }, + }, + } + require.Eventually(t, func() bool { + lock.Lock() + defer lock.Unlock() + + return subscribedCodecsAsString(expectedSubscribedQualities) == subscribedCodecsAsString(actualSubscribedQualities) + }, 10*time.Second, 100*time.Millisecond) + + // muting "s2" only should not disable all qualities of vp8, no change of expected qualities + dm.NotifySubscriberMaxQuality("s2", mime.MimeTypeVP8, livekit.VideoQuality_OFF) + + time.Sleep(100 * time.Millisecond) + require.Eventually(t, func() bool { + lock.Lock() + defer lock.Unlock() + + return subscribedCodecsAsString(expectedSubscribedQualities) == subscribedCodecsAsString(actualSubscribedQualities) + }, 10*time.Second, 100*time.Millisecond) + + // muting "s1" and s3 also should disable all qualities + dm.NotifySubscriberMaxQuality("s1", mime.MimeTypeVP8, livekit.VideoQuality_OFF) + dm.NotifySubscriberMaxQuality("s3", mime.MimeTypeAV1, livekit.VideoQuality_OFF) + + expectedSubscribedQualities = []*livekit.SubscribedCodec{ + { + Codec: mime.MimeTypeVP8.String(), + Qualities: []*livekit.SubscribedQuality{ + {Quality: livekit.VideoQuality_LOW, Enabled: false}, + {Quality: livekit.VideoQuality_MEDIUM, Enabled: false}, + {Quality: livekit.VideoQuality_HIGH, Enabled: false}, + }, + }, + { + Codec: mime.MimeTypeAV1.String(), + Qualities: []*livekit.SubscribedQuality{ + {Quality: livekit.VideoQuality_LOW, Enabled: false}, + {Quality: livekit.VideoQuality_MEDIUM, Enabled: false}, + {Quality: livekit.VideoQuality_HIGH, Enabled: false}, + }, + }, + } + require.Eventually(t, func() bool { + lock.Lock() + defer lock.Unlock() + + return subscribedCodecsAsString(expectedSubscribedQualities) == subscribedCodecsAsString(actualSubscribedQualities) + }, 10*time.Second, 100*time.Millisecond) + + // unmuting "s1" should enable vp8 previously set max quality + dm.NotifySubscriberMaxQuality("s1", mime.MimeTypeVP8, livekit.VideoQuality_LOW) + + expectedSubscribedQualities = []*livekit.SubscribedCodec{ + { + Codec: mime.MimeTypeVP8.String(), + Qualities: []*livekit.SubscribedQuality{ + {Quality: livekit.VideoQuality_LOW, Enabled: true}, + {Quality: livekit.VideoQuality_MEDIUM, Enabled: false}, + {Quality: livekit.VideoQuality_HIGH, Enabled: false}, + }, + }, + { + Codec: mime.MimeTypeAV1.String(), + Qualities: []*livekit.SubscribedQuality{ + {Quality: livekit.VideoQuality_LOW, Enabled: false}, + {Quality: livekit.VideoQuality_MEDIUM, Enabled: false}, + {Quality: livekit.VideoQuality_HIGH, Enabled: false}, + }, + }, + } + require.Eventually(t, func() bool { + lock.Lock() + defer lock.Unlock() + + return subscribedCodecsAsString(expectedSubscribedQualities) == subscribedCodecsAsString(actualSubscribedQualities) + }, 10*time.Second, 100*time.Millisecond) + + // a higher quality from a different node should trigger that quality + dm.NotifySubscriberNodeMaxQuality("n1", []types.SubscribedCodecQuality{ + {CodecMime: mime.MimeTypeVP8, Quality: livekit.VideoQuality_HIGH}, + {CodecMime: mime.MimeTypeAV1, Quality: livekit.VideoQuality_MEDIUM}, + }) + + expectedSubscribedQualities = []*livekit.SubscribedCodec{ + { + Codec: mime.MimeTypeVP8.String(), + Qualities: []*livekit.SubscribedQuality{ + {Quality: livekit.VideoQuality_LOW, Enabled: true}, + {Quality: livekit.VideoQuality_MEDIUM, Enabled: true}, + {Quality: livekit.VideoQuality_HIGH, Enabled: true}, + }, + }, + { + Codec: mime.MimeTypeAV1.String(), + Qualities: []*livekit.SubscribedQuality{ + {Quality: livekit.VideoQuality_LOW, Enabled: true}, + {Quality: livekit.VideoQuality_MEDIUM, Enabled: true}, + {Quality: livekit.VideoQuality_HIGH, Enabled: false}, + }, + }, + } + require.Eventually(t, func() bool { + lock.Lock() + defer lock.Unlock() + + return subscribedCodecsAsString(expectedSubscribedQualities) == subscribedCodecsAsString(actualSubscribedQualities) + }, 10*time.Second, 100*time.Millisecond) + }) +} + +func TestCodecRegression(t *testing.T) { + t.Run("codec regression video", func(t *testing.T) { + var lock sync.Mutex + actualSubscribedQualities := make([]*livekit.SubscribedCodec, 0) + + dm := NewDynacastManagerVideo(DynacastManagerVideoParams{ + Listener: &testDynacastManagerListener{ + onSubscribedMaxQualityChange: func(subscribedQualities []*livekit.SubscribedCodec) { + lock.Lock() + actualSubscribedQualities = subscribedQualities + lock.Unlock() + }, + }, + }) + + dm.NotifySubscriberMaxQuality("s1", mime.MimeTypeAV1, livekit.VideoQuality_HIGH) + + expectedSubscribedQualities := []*livekit.SubscribedCodec{ + { + Codec: mime.MimeTypeAV1.String(), + Qualities: []*livekit.SubscribedQuality{ + {Quality: livekit.VideoQuality_LOW, Enabled: true}, + {Quality: livekit.VideoQuality_MEDIUM, Enabled: true}, + {Quality: livekit.VideoQuality_HIGH, Enabled: true}, + }, + }, + } + require.Eventually(t, func() bool { + lock.Lock() + defer lock.Unlock() + + return subscribedCodecsAsString(expectedSubscribedQualities) == subscribedCodecsAsString(actualSubscribedQualities) + }, 10*time.Second, 100*time.Millisecond) + + dm.HandleCodecRegression(mime.MimeTypeAV1, mime.MimeTypeVP8) + + expectedSubscribedQualities = []*livekit.SubscribedCodec{ + { + Codec: mime.MimeTypeAV1.String(), + Qualities: []*livekit.SubscribedQuality{ + {Quality: livekit.VideoQuality_LOW, Enabled: false}, + {Quality: livekit.VideoQuality_MEDIUM, Enabled: false}, + {Quality: livekit.VideoQuality_HIGH, Enabled: false}, + }, + }, + { + Codec: mime.MimeTypeVP8.String(), + Qualities: []*livekit.SubscribedQuality{ + {Quality: livekit.VideoQuality_LOW, Enabled: true}, + {Quality: livekit.VideoQuality_MEDIUM, Enabled: true}, + {Quality: livekit.VideoQuality_HIGH, Enabled: true}, + }, + }, + } + require.Eventually(t, func() bool { + lock.Lock() + defer lock.Unlock() + + return subscribedCodecsAsString(expectedSubscribedQualities) == subscribedCodecsAsString(actualSubscribedQualities) + }, 10*time.Second, 100*time.Millisecond) + + // av1 quality change should be forwarded to vp8 + // av1 quality change of node should be ignored + dm.NotifySubscriberMaxQuality("s1", mime.MimeTypeAV1, livekit.VideoQuality_MEDIUM) + dm.NotifySubscriberNodeMaxQuality("n1", []types.SubscribedCodecQuality{ + {CodecMime: mime.MimeTypeAV1, Quality: livekit.VideoQuality_HIGH}, + }) + expectedSubscribedQualities = []*livekit.SubscribedCodec{ + { + Codec: mime.MimeTypeAV1.String(), + Qualities: []*livekit.SubscribedQuality{ + {Quality: livekit.VideoQuality_LOW, Enabled: false}, + {Quality: livekit.VideoQuality_MEDIUM, Enabled: false}, + {Quality: livekit.VideoQuality_HIGH, Enabled: false}, + }, + }, + { + Codec: mime.MimeTypeVP8.String(), + Qualities: []*livekit.SubscribedQuality{ + {Quality: livekit.VideoQuality_LOW, Enabled: true}, + {Quality: livekit.VideoQuality_MEDIUM, Enabled: true}, + {Quality: livekit.VideoQuality_HIGH, Enabled: false}, + }, + }, + } + require.Eventually(t, func() bool { + lock.Lock() + defer lock.Unlock() + + return subscribedCodecsAsString(expectedSubscribedQualities) == subscribedCodecsAsString(actualSubscribedQualities) + }, 10*time.Second, 100*time.Millisecond) + }) + + t.Run("codec regression audio", func(t *testing.T) { + var lock sync.Mutex + actualSubscribedCodecs := make([]*livekit.SubscribedAudioCodec, 0) + + dm := NewDynacastManagerAudio(DynacastManagerAudioParams{ + Listener: &testDynacastManagerListener{ + onSubscribedAudioCodecChange: func(subscribedCodecs []*livekit.SubscribedAudioCodec) { + lock.Lock() + actualSubscribedCodecs = subscribedCodecs + lock.Unlock() + }, + }, + }) + + dm.NotifySubscription("s1", mime.MimeTypeRED, true) + + expectedSubscribedCodecs := []*livekit.SubscribedAudioCodec{ + { + Codec: mime.MimeTypeRED.String(), + Enabled: true, + }, + } + require.Eventually(t, func() bool { + lock.Lock() + defer lock.Unlock() + + return subscribedAudioCodecsAsString(expectedSubscribedCodecs) == subscribedAudioCodecsAsString(actualSubscribedCodecs) + }, 10*time.Second, 100*time.Millisecond) + + dm.HandleCodecRegression(mime.MimeTypeRED, mime.MimeTypeOpus) + + expectedSubscribedCodecs = []*livekit.SubscribedAudioCodec{ + { + Codec: mime.MimeTypeRED.String(), + Enabled: false, + }, + { + Codec: mime.MimeTypeOpus.String(), + Enabled: true, + }, + } + require.Eventually(t, func() bool { + lock.Lock() + defer lock.Unlock() + + return subscribedAudioCodecsAsString(expectedSubscribedCodecs) == subscribedAudioCodecsAsString(actualSubscribedCodecs) + }, 10*time.Second, 100*time.Millisecond) + + // RED disable as subscriber or subscriber node should be ignored as it has been regressed + dm.NotifySubscription("s1", mime.MimeTypeRED, false) + dm.NotifySubscriptionNode("n1", []*livekit.SubscribedAudioCodec{ + {Codec: mime.MimeTypeRED.String(), Enabled: false}, + }) + require.Eventually(t, func() bool { + lock.Lock() + defer lock.Unlock() + + return subscribedAudioCodecsAsString(expectedSubscribedCodecs) == subscribedAudioCodecsAsString(actualSubscribedCodecs) + }, 10*time.Second, 100*time.Millisecond) + + // `s1` unsubscription should turn off `opus` + dm.NotifySubscription("s1", mime.MimeTypeOpus, false) + expectedSubscribedCodecs = []*livekit.SubscribedAudioCodec{ + { + Codec: mime.MimeTypeRED.String(), + Enabled: false, + }, + { + Codec: mime.MimeTypeOpus.String(), + Enabled: false, + }, + } + require.Eventually(t, func() bool { + lock.Lock() + defer lock.Unlock() + + return subscribedAudioCodecsAsString(expectedSubscribedCodecs) == subscribedAudioCodecsAsString(actualSubscribedCodecs) + }, 10*time.Second, 100*time.Millisecond) + + // a node subscription should turn `opus` back on + dm.NotifySubscriptionNode("n1", []*livekit.SubscribedAudioCodec{ + { + Codec: mime.MimeTypeOpus.String(), + Enabled: true, + }, + }) + expectedSubscribedCodecs = []*livekit.SubscribedAudioCodec{ + { + Codec: mime.MimeTypeRED.String(), + Enabled: false, + }, + { + Codec: mime.MimeTypeOpus.String(), + Enabled: true, + }, + } + require.Eventually(t, func() bool { + lock.Lock() + defer lock.Unlock() + + return subscribedAudioCodecsAsString(expectedSubscribedCodecs) == subscribedAudioCodecsAsString(actualSubscribedCodecs) + }, 10*time.Second, 100*time.Millisecond) + + }) +} + +func subscribedCodecsAsString(c1 []*livekit.SubscribedCodec) string { + sort.Slice(c1, func(i, j int) bool { return c1[i].Codec < c1[j].Codec }) + var s1 string + for _, c := range c1 { + s1 += c.String() + } + return s1 +} + +func subscribedAudioCodecsAsString(c1 []*livekit.SubscribedAudioCodec) string { + sort.Slice(c1, func(i, j int) bool { return c1[i].Codec < c1[j].Codec }) + var s1 string + for _, c := range c1 { + s1 += c.String() + } + return s1 +} diff --git a/livekit/pkg/rtc/dynacast/dynacastmanageraudio.go b/livekit/pkg/rtc/dynacast/dynacastmanageraudio.go new file mode 100644 index 0000000..542a58a --- /dev/null +++ b/livekit/pkg/rtc/dynacast/dynacastmanageraudio.go @@ -0,0 +1,198 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dynacast + +import ( + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + + "github.com/livekit/livekit-server/pkg/sfu/mime" +) + +var _ DynacastManager = (*dynacastManagerAudio)(nil) +var _ dynacastQualityListener = (*dynacastManagerAudio)(nil) + +type DynacastManagerAudioParams struct { + Listener DynacastManagerListener + Logger logger.Logger +} + +type dynacastManagerAudio struct { + params DynacastManagerAudioParams + + subscribedCodecs map[mime.MimeType]bool + committedSubscribedCodecs map[mime.MimeType]bool + + isClosed bool + + *dynacastManagerBase +} + +func NewDynacastManagerAudio(params DynacastManagerAudioParams) DynacastManager { + if params.Logger == nil { + params.Logger = logger.GetLogger() + } + d := &dynacastManagerAudio{ + params: params, + subscribedCodecs: make(map[mime.MimeType]bool), + committedSubscribedCodecs: make(map[mime.MimeType]bool), + } + d.dynacastManagerBase = newDynacastManagerBase(dynacastManagerBaseParams{ + Logger: params.Logger, + OpsQueueDepth: 4, + OnRestart: func() { + d.committedSubscribedCodecs = make(map[mime.MimeType]bool) + }, + OnDynacastQualityCreate: func(mimeType mime.MimeType) dynacastQuality { + dq := newDynacastQualityAudio(dynacastQualityAudioParams{ + MimeType: mimeType, + Listener: d, + Logger: d.params.Logger, + }) + return dq + }, + OnRegressCodec: func(fromMime, toMime mime.MimeType) { + d.subscribedCodecs[fromMime] = false + + // if the new codec is not added, notify the publisher to start publishing + if _, ok := d.subscribedCodecs[toMime]; !ok { + d.subscribedCodecs[toMime] = true + } + }, + OnUpdateNeeded: d.update, + }) + return d +} + +// It is possible for tracks to be in pending close state. When track +// is waiting to be closed, a node is not streaming a track. This can +// be used to force an update announcing that subscribed codec is disabled, +// i.e. indicating not pulling track any more. +func (d *dynacastManagerAudio) ForceEnable(enabled bool) { + d.lock.Lock() + defer d.lock.Unlock() + + for mime := range d.committedSubscribedCodecs { + d.committedSubscribedCodecs[mime] = enabled + } + + d.enqueueSubscribedChange() +} + +func (d *dynacastManagerAudio) NotifySubscription( + subscriberID livekit.ParticipantID, + mime mime.MimeType, + enabled bool, +) { + dq := d.getOrCreateDynacastQuality(mime) + if dq != nil { + dq.NotifySubscription(subscriberID, enabled) + } +} + +func (d *dynacastManagerAudio) NotifySubscriptionNode( + nodeID livekit.NodeID, + codecs []*livekit.SubscribedAudioCodec, +) { + for _, codec := range codecs { + dq := d.getOrCreateDynacastQuality(mime.NormalizeMimeType(codec.Codec)) + if dq != nil { + dq.NotifySubscriptionNode(nodeID, codec.Enabled) + } + } +} + +func (d *dynacastManagerAudio) OnUpdateAudioCodecForMime(mime mime.MimeType, enabled bool) { + d.lock.Lock() + if _, ok := d.regressedCodec[mime]; !ok { + d.subscribedCodecs[mime] = enabled + } + d.lock.Unlock() + + d.update(false) +} + +func (d *dynacastManagerAudio) update(force bool) { + d.lock.Lock() + + d.params.Logger.Debugw( + "processing subscribed codec change", + "force", force, + "committedSubscribedCodecs", d.committedSubscribedCodecs, + "subscribedCodecs", d.subscribedCodecs, + ) + + if len(d.subscribedCodecs) == 0 { + // no mime has been added, nothing to update + d.lock.Unlock() + return + } + + // add or remove of a mime triggers an update + changed := len(d.subscribedCodecs) != len(d.committedSubscribedCodecs) + if !changed { + for mime, enabled := range d.subscribedCodecs { + if ce, ok := d.committedSubscribedCodecs[mime]; ok { + if ce != enabled { + changed = true + break + } + } + } + } + + if !force && !changed { + d.lock.Unlock() + return + } + + d.params.Logger.Debugw( + "committing subscribed codec change", + "force", force, + "committedSubscribedCoecs", d.committedSubscribedCodecs, + "subscribedcodecs", d.subscribedCodecs, + ) + + // commit change + d.committedSubscribedCodecs = make(map[mime.MimeType]bool, len(d.subscribedCodecs)) + for mime, enabled := range d.subscribedCodecs { + d.committedSubscribedCodecs[mime] = enabled + } + + d.enqueueSubscribedChange() + d.lock.Unlock() +} + +func (d *dynacastManagerAudio) enqueueSubscribedChange() { + if d.isClosed || d.params.Listener == nil { + return + } + + subscribedCodecs := make([]*livekit.SubscribedAudioCodec, 0, len(d.committedSubscribedCodecs)) + for mime, enabled := range d.committedSubscribedCodecs { + subscribedCodecs = append(subscribedCodecs, &livekit.SubscribedAudioCodec{ + Codec: mime.String(), + Enabled: enabled, + }) + } + + d.params.Logger.Debugw( + "subscribedAudioCodecChange", + "subscribedCodecs", logger.ProtoSlice(subscribedCodecs), + ) + d.notifyOpsQueue.Enqueue(func() { + d.params.Listener.OnDynacastSubscribedAudioCodecChange(subscribedCodecs) + }) +} diff --git a/livekit/pkg/rtc/dynacast/dynacastmanagerbase.go b/livekit/pkg/rtc/dynacast/dynacastmanagerbase.go new file mode 100644 index 0000000..94812ac --- /dev/null +++ b/livekit/pkg/rtc/dynacast/dynacastmanagerbase.go @@ -0,0 +1,165 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dynacast + +import ( + "sync" + + "golang.org/x/exp/maps" + + "github.com/livekit/protocol/logger" + + "github.com/livekit/livekit-server/pkg/sfu/mime" + "github.com/livekit/livekit-server/pkg/utils" +) + +type dynacastManagerBaseParams struct { + Logger logger.Logger + OpsQueueDepth uint + OnRestart func() + OnDynacastQualityCreate func(mimeType mime.MimeType) dynacastQuality + OnRegressCodec func(fromMime, toMime mime.MimeType) + OnUpdateNeeded func(force bool) +} + +type dynacastManagerBase struct { + params dynacastManagerBaseParams + + lock sync.RWMutex + regressedCodec map[mime.MimeType]struct{} + dynacastQuality map[mime.MimeType]dynacastQuality + + notifyOpsQueue *utils.OpsQueue + + isClosed bool + + dynacastManagerNull + dynacastQualityListenerNull +} + +func newDynacastManagerBase(params dynacastManagerBaseParams) *dynacastManagerBase { + if params.OpsQueueDepth == 0 { + params.OpsQueueDepth = 4 + } + d := &dynacastManagerBase{ + params: params, + regressedCodec: make(map[mime.MimeType]struct{}), + dynacastQuality: make(map[mime.MimeType]dynacastQuality), + notifyOpsQueue: utils.NewOpsQueue(utils.OpsQueueParams{ + Name: "dynacast-notify", + MinSize: params.OpsQueueDepth, + FlushOnStop: true, + Logger: params.Logger, + }), + } + d.notifyOpsQueue.Start() + return d +} + +func (d *dynacastManagerBase) AddCodec(mime mime.MimeType) { + d.getOrCreateDynacastQuality(mime) +} + +func (d *dynacastManagerBase) HandleCodecRegression(fromMime, toMime mime.MimeType) { + fromDq := d.getOrCreateDynacastQuality(fromMime) + + d.lock.Lock() + if d.isClosed { + d.lock.Unlock() + return + } + + if fromDq == nil { + // should not happen as we have added the codec on setup receiver + d.params.Logger.Warnw("regression from codec not found", nil, "mime", fromMime, "toMime", toMime) + d.lock.Unlock() + return + } + d.regressedCodec[fromMime] = struct{}{} + d.params.OnRegressCodec(fromMime, toMime) + d.lock.Unlock() + + d.params.OnUpdateNeeded(false) + + fromDq.Stop() + fromDq.RegressTo(d.getOrCreateDynacastQuality(toMime)) +} + +func (d *dynacastManagerBase) Restart() { + d.lock.Lock() + d.params.OnRestart() + + dqs := d.getDynacastQualitiesLocked() + d.lock.Unlock() + + for _, dq := range dqs { + dq.Restart() + } +} + +func (d *dynacastManagerBase) Close() { + d.notifyOpsQueue.Stop() + + d.lock.Lock() + dqs := d.getDynacastQualitiesLocked() + d.dynacastQuality = make(map[mime.MimeType]dynacastQuality) + + d.isClosed = true + d.lock.Unlock() + + for _, dq := range dqs { + dq.Stop() + } +} + +// There are situations like track unmute or streaming from a different node +// where subscription changes needs to sent to the provider immediately. +// This bypasses any debouncing and forces a subscription change update +// with immediate effect. +func (d *dynacastManagerBase) ForceUpdate() { + d.params.OnUpdateNeeded(true) +} + +func (d *dynacastManagerBase) ClearSubscriberNodes() { + d.lock.Lock() + dqs := d.getDynacastQualitiesLocked() + d.lock.Unlock() + for _, dq := range dqs { + dq.ClearSubscriberNodes() + } +} + +func (d *dynacastManagerBase) getOrCreateDynacastQuality(mimeType mime.MimeType) dynacastQuality { + d.lock.Lock() + defer d.lock.Unlock() + + if d.isClosed || mimeType == mime.MimeTypeUnknown { + return nil + } + + if dq := d.dynacastQuality[mimeType]; dq != nil { + return dq + } + + dq := d.params.OnDynacastQualityCreate(mimeType) + dq.Start() + + d.dynacastQuality[mimeType] = dq + return dq +} + +func (d *dynacastManagerBase) getDynacastQualitiesLocked() []dynacastQuality { + return maps.Values(d.dynacastQuality) +} diff --git a/livekit/pkg/rtc/dynacast/dynacastmanagervideo.go b/livekit/pkg/rtc/dynacast/dynacastmanagervideo.go new file mode 100644 index 0000000..93120c8 --- /dev/null +++ b/livekit/pkg/rtc/dynacast/dynacastmanagervideo.go @@ -0,0 +1,272 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dynacast + +import ( + "time" + + "github.com/bep/debounce" + + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + + "github.com/livekit/livekit-server/pkg/rtc/types" + "github.com/livekit/livekit-server/pkg/sfu/mime" +) + +var _ DynacastManager = (*dynacastManagerVideo)(nil) +var _ dynacastQualityListener = (*dynacastManagerVideo)(nil) + +type DynacastManagerVideoParams struct { + DynacastPauseDelay time.Duration + Listener DynacastManagerListener + Logger logger.Logger +} + +type dynacastManagerVideo struct { + params DynacastManagerVideoParams + + maxSubscribedQuality map[mime.MimeType]livekit.VideoQuality + committedMaxSubscribedQuality map[mime.MimeType]livekit.VideoQuality + + maxSubscribedQualityDebounce func(func()) + maxSubscribedQualityDebouncePending bool + + isClosed bool + + *dynacastManagerBase +} + +func NewDynacastManagerVideo(params DynacastManagerVideoParams) DynacastManager { + if params.Logger == nil { + params.Logger = logger.GetLogger() + } + d := &dynacastManagerVideo{ + params: params, + maxSubscribedQuality: make(map[mime.MimeType]livekit.VideoQuality), + committedMaxSubscribedQuality: make(map[mime.MimeType]livekit.VideoQuality), + } + if params.DynacastPauseDelay > 0 { + d.maxSubscribedQualityDebounce = debounce.New(params.DynacastPauseDelay) + } + d.dynacastManagerBase = newDynacastManagerBase(dynacastManagerBaseParams{ + Logger: params.Logger, + OpsQueueDepth: 64, + OnRestart: func() { + d.committedMaxSubscribedQuality = make(map[mime.MimeType]livekit.VideoQuality) + }, + OnDynacastQualityCreate: func(mimeType mime.MimeType) dynacastQuality { + dq := newDynacastQualityVideo(dynacastQualityVideoParams{ + MimeType: mimeType, + Listener: d, + Logger: d.params.Logger, + }) + return dq + }, + OnRegressCodec: func(fromMime, toMime mime.MimeType) { + d.maxSubscribedQuality[fromMime] = livekit.VideoQuality_OFF + + // if the new codec is not added, notify the publisher to start publishing + if _, ok := d.maxSubscribedQuality[toMime]; !ok { + d.maxSubscribedQuality[toMime] = livekit.VideoQuality_HIGH + } + }, + OnUpdateNeeded: d.update, + }) + return d +} + +// It is possible for tracks to be in pending close state. When track +// is waiting to be closed, a node is not streaming a track. This can +// be used to force an update announcing that subscribed quality is OFF, +// i.e. indicating not pulling track any more. +func (d *dynacastManagerVideo) ForceQuality(quality livekit.VideoQuality) { + d.lock.Lock() + defer d.lock.Unlock() + + for mime := range d.committedMaxSubscribedQuality { + d.committedMaxSubscribedQuality[mime] = quality + } + + d.enqueueSubscribedQualityChange() +} + +func (d *dynacastManagerVideo) NotifySubscriberMaxQuality( + subscriberID livekit.ParticipantID, + mime mime.MimeType, + quality livekit.VideoQuality, +) { + dq := d.getOrCreateDynacastQuality(mime) + if dq != nil { + dq.NotifySubscriberMaxQuality(subscriberID, quality) + } +} + +func (d *dynacastManagerVideo) NotifySubscriberNodeMaxQuality( + nodeID livekit.NodeID, + qualities []types.SubscribedCodecQuality, +) { + for _, quality := range qualities { + dq := d.getOrCreateDynacastQuality(quality.CodecMime) + if dq != nil { + dq.NotifySubscriberNodeMaxQuality(nodeID, quality.Quality) + } + } +} + +func (d *dynacastManagerVideo) OnUpdateMaxQualityForMime( + mime mime.MimeType, + maxQuality livekit.VideoQuality, +) { + d.lock.Lock() + if _, ok := d.regressedCodec[mime]; !ok { + d.maxSubscribedQuality[mime] = maxQuality + } + d.lock.Unlock() + + d.update(false) +} + +func (d *dynacastManagerVideo) update(force bool) { + d.lock.Lock() + + d.params.Logger.Debugw( + "processing quality change", + "force", force, + "committedMaxSubscribedQuality", d.committedMaxSubscribedQuality, + "maxSubscribedQuality", d.maxSubscribedQuality, + ) + + if len(d.maxSubscribedQuality) == 0 { + // no mime has been added, nothing to update + d.lock.Unlock() + return + } + + // add or remove of a mime triggers an update + changed := len(d.maxSubscribedQuality) != len(d.committedMaxSubscribedQuality) + downgradesOnly := !changed + if !changed { + for mime, quality := range d.maxSubscribedQuality { + if cq, ok := d.committedMaxSubscribedQuality[mime]; ok { + if cq != quality { + changed = true + } + + if (cq == livekit.VideoQuality_OFF && quality != livekit.VideoQuality_OFF) || (cq != livekit.VideoQuality_OFF && quality != livekit.VideoQuality_OFF && cq < quality) { + downgradesOnly = false + } + } + } + } + + if !force { + if !changed { + d.lock.Unlock() + return + } + + if downgradesOnly && d.maxSubscribedQualityDebounce != nil { + if !d.maxSubscribedQualityDebouncePending { + d.params.Logger.Debugw( + "debouncing quality downgrade", + "committedMaxSubscribedQuality", d.committedMaxSubscribedQuality, + "maxSubscribedQuality", d.maxSubscribedQuality, + ) + d.maxSubscribedQualityDebounce(func() { + d.update(true) + }) + d.maxSubscribedQualityDebouncePending = true + } else { + d.params.Logger.Debugw( + "quality downgrade waiting for debounce", + "committedMaxSubscribedQuality", d.committedMaxSubscribedQuality, + "maxSubscribedQuality", d.maxSubscribedQuality, + ) + } + d.lock.Unlock() + return + } + } + + // clear debounce on send + if d.maxSubscribedQualityDebounce != nil { + d.maxSubscribedQualityDebounce(func() {}) + d.maxSubscribedQualityDebouncePending = false + } + + d.params.Logger.Debugw( + "committing quality change", + "force", force, + "committedMaxSubscribedQuality", d.committedMaxSubscribedQuality, + "maxSubscribedQuality", d.maxSubscribedQuality, + ) + + // commit change + d.committedMaxSubscribedQuality = make(map[mime.MimeType]livekit.VideoQuality, len(d.maxSubscribedQuality)) + for mime, quality := range d.maxSubscribedQuality { + d.committedMaxSubscribedQuality[mime] = quality + } + + d.enqueueSubscribedQualityChange() + d.lock.Unlock() +} + +func (d *dynacastManagerVideo) enqueueSubscribedQualityChange() { + if d.isClosed || d.params.Listener == nil { + return + } + + subscribedCodecs := make([]*livekit.SubscribedCodec, 0, len(d.committedMaxSubscribedQuality)) + maxSubscribedQualities := make([]types.SubscribedCodecQuality, 0, len(d.committedMaxSubscribedQuality)) + for mime, quality := range d.committedMaxSubscribedQuality { + maxSubscribedQualities = append(maxSubscribedQualities, types.SubscribedCodecQuality{ + CodecMime: mime, + Quality: quality, + }) + + if quality == livekit.VideoQuality_OFF { + subscribedCodecs = append(subscribedCodecs, &livekit.SubscribedCodec{ + Codec: mime.String(), + Qualities: []*livekit.SubscribedQuality{ + {Quality: livekit.VideoQuality_LOW, Enabled: false}, + {Quality: livekit.VideoQuality_MEDIUM, Enabled: false}, + {Quality: livekit.VideoQuality_HIGH, Enabled: false}, + }, + }) + } else { + var subscribedQualities []*livekit.SubscribedQuality + for q := livekit.VideoQuality_LOW; q <= livekit.VideoQuality_HIGH; q++ { + subscribedQualities = append(subscribedQualities, &livekit.SubscribedQuality{ + Quality: q, + Enabled: q <= quality, + }) + } + subscribedCodecs = append(subscribedCodecs, &livekit.SubscribedCodec{ + Codec: mime.String(), + Qualities: subscribedQualities, + }) + } + } + + d.params.Logger.Debugw( + "subscribedMaxQualityChange", + "subscribedCodecs", subscribedCodecs, + "maxSubscribedQualities", maxSubscribedQualities, + ) + d.notifyOpsQueue.Enqueue(func() { + d.params.Listener.OnDynacastSubscribedMaxQualityChange(subscribedCodecs, maxSubscribedQualities) + }) +} diff --git a/livekit/pkg/rtc/dynacast/dynacastqualityaudio.go b/livekit/pkg/rtc/dynacast/dynacastqualityaudio.go new file mode 100644 index 0000000..2414b5b --- /dev/null +++ b/livekit/pkg/rtc/dynacast/dynacastqualityaudio.go @@ -0,0 +1,168 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dynacast + +import ( + "sync" + + "github.com/livekit/livekit-server/pkg/sfu/mime" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" +) + +var _ dynacastQuality = (*dynacastQualityAudio)(nil) + +type dynacastQualityAudioParams struct { + MimeType mime.MimeType + Listener dynacastQualityListener + Logger logger.Logger +} + +// dynacastQualityAudio manages enable a single receiver of a media track +type dynacastQualityAudio struct { + params dynacastQualityAudioParams + + // quality level enable/disable + lock sync.RWMutex + initialized bool + subscriberEnables map[livekit.ParticipantID]bool + subscriberNodeEnables map[livekit.NodeID]bool + enabled bool + regressTo dynacastQuality + + dynacastQualityNull +} + +func newDynacastQualityAudio(params dynacastQualityAudioParams) dynacastQuality { + return &dynacastQualityAudio{ + params: params, + subscriberEnables: make(map[livekit.ParticipantID]bool), + subscriberNodeEnables: make(map[livekit.NodeID]bool), + } +} + +func (d *dynacastQualityAudio) Start() { + d.reset() +} + +func (d *dynacastQualityAudio) Restart() { + d.reset() +} + +func (d *dynacastQualityAudio) Stop() { +} + +func (d *dynacastQualityAudio) NotifySubscription(subscriberID livekit.ParticipantID, enabled bool) { + d.params.Logger.Debugw( + "setting subscriber codec enable", + "mime", d.params.MimeType, + "subscriberID", subscriberID, + "enabled", enabled, + ) + + d.lock.Lock() + if r := d.regressTo; r != nil { + d.lock.Unlock() + return + } + + if !enabled { + delete(d.subscriberEnables, subscriberID) + } else { + d.subscriberEnables[subscriberID] = true + } + d.lock.Unlock() + + d.updateQualityChange(false) +} + +func (d *dynacastQualityAudio) NotifySubscriptionNode(nodeID livekit.NodeID, enabled bool) { + d.params.Logger.Debugw( + "setting subscriber node codec enabled", + "mime", d.params.MimeType, + "subscriberNodeID", nodeID, + "enabled", enabled, + ) + + d.lock.Lock() + if r := d.regressTo; r != nil { + // the downstream node will synthesize correct enable (its dynacast manager has codec regression), just ignore it + d.params.Logger.Debugw( + "ignoring node codec change, regressed to another dynacast quality", + "mime", d.params.MimeType, + "regressedMime", d.regressTo.Mime(), + ) + d.lock.Unlock() + return + } + + if !enabled { + delete(d.subscriberNodeEnables, nodeID) + } else { + d.subscriberNodeEnables[nodeID] = true + } + d.lock.Unlock() + + d.updateQualityChange(false) +} + +func (d *dynacastQualityAudio) ClearSubscriberNodes() { + d.lock.Lock() + d.subscriberNodeEnables = make(map[livekit.NodeID]bool) + d.lock.Unlock() + + d.updateQualityChange(false) +} + +func (d *dynacastQualityAudio) Mime() mime.MimeType { + return d.params.MimeType +} + +func (d *dynacastQualityAudio) RegressTo(other dynacastQuality) { + d.lock.Lock() + d.regressTo = other + d.lock.Unlock() + + other.Restart() +} + +func (d *dynacastQualityAudio) reset() { + d.lock.Lock() + d.initialized = false + d.lock.Unlock() +} + +func (d *dynacastQualityAudio) updateQualityChange(force bool) { + d.lock.Lock() + enabled := len(d.subscriberEnables) != 0 || len(d.subscriberNodeEnables) != 0 + if enabled == d.enabled && d.initialized && !force { + d.lock.Unlock() + return + } + + d.initialized = true + d.enabled = enabled + d.params.Logger.Debugw( + "notifying enabled change", + "mime", d.params.MimeType, + "enabled", d.enabled, + "subscriberNodeEnables", d.subscriberNodeEnables, + "subscribedEnables", d.subscriberEnables, + "force", force, + ) + d.lock.Unlock() + + d.params.Listener.OnUpdateAudioCodecForMime(d.params.MimeType, enabled) +} diff --git a/livekit/pkg/rtc/dynacast/dynacastqualityvideo.go b/livekit/pkg/rtc/dynacast/dynacastqualityvideo.go new file mode 100644 index 0000000..47f9f38 --- /dev/null +++ b/livekit/pkg/rtc/dynacast/dynacastqualityvideo.go @@ -0,0 +1,249 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dynacast + +import ( + "sync" + "time" + + "github.com/livekit/livekit-server/pkg/sfu/mime" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" +) + +var _ dynacastQuality = (*dynacastQualityVideo)(nil) + +const ( + initialQualityUpdateWait = 10 * time.Second +) + +type dynacastQualityVideoParams struct { + MimeType mime.MimeType + Listener dynacastQualityListener + Logger logger.Logger +} + +// dynacastQualityVideo manages max subscribed quality of a single receiver of a media track +type dynacastQualityVideo struct { + params dynacastQualityVideoParams + + // quality level enable/disable + lock sync.RWMutex + initialized bool + maxSubscriberQuality map[livekit.ParticipantID]livekit.VideoQuality + maxSubscriberNodeQuality map[livekit.NodeID]livekit.VideoQuality + maxSubscribedQuality livekit.VideoQuality + maxQualityTimer *time.Timer + regressTo dynacastQuality + + dynacastQualityNull +} + +func newDynacastQualityVideo(params dynacastQualityVideoParams) dynacastQuality { + return &dynacastQualityVideo{ + params: params, + maxSubscriberQuality: make(map[livekit.ParticipantID]livekit.VideoQuality), + maxSubscriberNodeQuality: make(map[livekit.NodeID]livekit.VideoQuality), + } +} + +func (d *dynacastQualityVideo) Start() { + d.reset() +} + +func (d *dynacastQualityVideo) Restart() { + d.reset() +} + +func (d *dynacastQualityVideo) Stop() { + d.stopMaxQualityTimer() +} + +func (d *dynacastQualityVideo) NotifySubscriberMaxQuality(subscriberID livekit.ParticipantID, quality livekit.VideoQuality) { + d.params.Logger.Debugw( + "setting subscriber max quality", + "mime", d.params.MimeType, + "subscriberID", subscriberID, + "quality", quality.String(), + ) + + d.lock.Lock() + if r := d.regressTo; r != nil { + d.lock.Unlock() + r.NotifySubscriberMaxQuality(subscriberID, quality) + return + } + + if quality == livekit.VideoQuality_OFF { + delete(d.maxSubscriberQuality, subscriberID) + } else { + d.maxSubscriberQuality[subscriberID] = quality + } + d.lock.Unlock() + + d.updateQualityChange(false) +} + +func (d *dynacastQualityVideo) NotifySubscriberNodeMaxQuality(nodeID livekit.NodeID, quality livekit.VideoQuality) { + d.params.Logger.Debugw( + "setting subscriber node max quality", + "mime", d.params.MimeType, + "subscriberNodeID", nodeID, + "quality", quality.String(), + ) + + d.lock.Lock() + if r := d.regressTo; r != nil { + // the downstream node will synthesize correct quality notify (its dynacast manager has codec regression), just ignore it + d.params.Logger.Debugw( + "ignoring node quality change, regressed to another dynacast quality", + "mime", d.params.MimeType, + "regressedMime", d.regressTo.Mime(), + ) + d.lock.Unlock() + return + } + + if quality == livekit.VideoQuality_OFF { + delete(d.maxSubscriberNodeQuality, nodeID) + } else { + d.maxSubscriberNodeQuality[nodeID] = quality + } + d.lock.Unlock() + + d.updateQualityChange(false) +} + +func (d *dynacastQualityVideo) ClearSubscriberNodes() { + d.lock.Lock() + d.maxSubscriberNodeQuality = make(map[livekit.NodeID]livekit.VideoQuality) + d.lock.Unlock() + + d.updateQualityChange(false) +} + +func (d *dynacastQualityVideo) Mime() mime.MimeType { + return d.params.MimeType +} + +func (d *dynacastQualityVideo) RegressTo(other dynacastQuality) { + d.lock.Lock() + d.regressTo = other + maxSubscriberQuality := d.maxSubscriberQuality + maxSubscriberNodeQuality := d.maxSubscriberNodeQuality + d.maxSubscriberQuality = make(map[livekit.ParticipantID]livekit.VideoQuality) + d.maxSubscriberNodeQuality = make(map[livekit.NodeID]livekit.VideoQuality) + d.lock.Unlock() + + other.Replace(maxSubscriberQuality, maxSubscriberNodeQuality) +} + +func (d *dynacastQualityVideo) Replace( + maxSubscriberQuality map[livekit.ParticipantID]livekit.VideoQuality, + maxSubscriberNodeQuality map[livekit.NodeID]livekit.VideoQuality, +) { + d.lock.Lock() + for subID, quality := range maxSubscriberQuality { + if oldQuality, ok := d.maxSubscriberQuality[subID]; ok { + // no QUALITY_OFF in the map + if quality > oldQuality { + d.maxSubscriberQuality[subID] = quality + } + } else { + d.maxSubscriberQuality[subID] = quality + } + } + + for nodeID, quality := range maxSubscriberNodeQuality { + if oldQuality, ok := d.maxSubscriberNodeQuality[nodeID]; ok { + // no QUALITY_OFF in the map + if quality > oldQuality { + d.maxSubscriberNodeQuality[nodeID] = quality + } + } else { + d.maxSubscriberNodeQuality[nodeID] = quality + } + } + d.lock.Unlock() + + d.Restart() +} + +func (d *dynacastQualityVideo) reset() { + d.lock.Lock() + d.initialized = false + d.lock.Unlock() + + d.startMaxQualityTimer() +} + +func (d *dynacastQualityVideo) updateQualityChange(force bool) { + d.lock.Lock() + maxSubscribedQuality := livekit.VideoQuality_OFF + for _, subQuality := range d.maxSubscriberQuality { + if maxSubscribedQuality == livekit.VideoQuality_OFF || (subQuality != livekit.VideoQuality_OFF && subQuality > maxSubscribedQuality) { + maxSubscribedQuality = subQuality + } + } + for _, nodeQuality := range d.maxSubscriberNodeQuality { + if maxSubscribedQuality == livekit.VideoQuality_OFF || (nodeQuality != livekit.VideoQuality_OFF && nodeQuality > maxSubscribedQuality) { + maxSubscribedQuality = nodeQuality + } + } + + if maxSubscribedQuality == d.maxSubscribedQuality && d.initialized && !force { + d.lock.Unlock() + return + } + + d.initialized = true + d.maxSubscribedQuality = maxSubscribedQuality + d.params.Logger.Debugw( + "notifying quality change", + "mime", d.params.MimeType, + "maxSubscriberQuality", d.maxSubscriberQuality, + "maxSubscriberNodeQuality", d.maxSubscriberNodeQuality, + "maxSubscribedQuality", d.maxSubscribedQuality, + "force", force, + ) + d.lock.Unlock() + + d.params.Listener.OnUpdateMaxQualityForMime(d.params.MimeType, maxSubscribedQuality) +} + +func (d *dynacastQualityVideo) startMaxQualityTimer() { + d.lock.Lock() + defer d.lock.Unlock() + + if d.maxQualityTimer != nil { + d.maxQualityTimer.Stop() + d.maxQualityTimer = nil + } + + d.maxQualityTimer = time.AfterFunc(initialQualityUpdateWait, func() { + d.stopMaxQualityTimer() + d.updateQualityChange(true) + }) +} + +func (d *dynacastQualityVideo) stopMaxQualityTimer() { + d.lock.Lock() + defer d.lock.Unlock() + + if d.maxQualityTimer != nil { + d.maxQualityTimer.Stop() + d.maxQualityTimer = nil + } +} diff --git a/livekit/pkg/rtc/dynacast/interfaces.go b/livekit/pkg/rtc/dynacast/interfaces.go new file mode 100644 index 0000000..42d1147 --- /dev/null +++ b/livekit/pkg/rtc/dynacast/interfaces.go @@ -0,0 +1,185 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dynacast + +import ( + "github.com/livekit/protocol/livekit" + + "github.com/livekit/livekit-server/pkg/rtc/types" + "github.com/livekit/livekit-server/pkg/sfu/mime" +) + +type DynacastManagerListener interface { + OnDynacastSubscribedMaxQualityChange( + subscribedQualities []*livekit.SubscribedCodec, + maxSubscribedQualities []types.SubscribedCodecQuality, + ) + + OnDynacastSubscribedAudioCodecChange(codecs []*livekit.SubscribedAudioCodec) +} + +var _ DynacastManagerListener = (*DynacastManagerListenerNull)(nil) + +type DynacastManagerListenerNull struct { +} + +func (d *DynacastManagerListenerNull) OnDynacastSubscribedMaxQualityChange( + subscribedQualities []*livekit.SubscribedCodec, + maxSubscribedQualities []types.SubscribedCodecQuality, +) { +} +func (d *DynacastManagerListenerNull) OnDynacastSubscribedAudioCodecChange( + codecs []*livekit.SubscribedAudioCodec, +) { +} + +// ----------------------------------------- + +type DynacastManager interface { + AddCodec(mime mime.MimeType) + HandleCodecRegression(fromMime, toMime mime.MimeType) + Restart() + Close() + ForceUpdate() + ForceQuality(quality livekit.VideoQuality) + ForceEnable(enabled bool) + + NotifySubscriberMaxQuality( + subscriberID livekit.ParticipantID, + mime mime.MimeType, + quality livekit.VideoQuality, + ) + NotifySubscription( + subscriberID livekit.ParticipantID, + mime mime.MimeType, + enabled bool, + ) + + NotifySubscriberNodeMaxQuality( + nodeID livekit.NodeID, + qualities []types.SubscribedCodecQuality, + ) + NotifySubscriptionNode( + nodeID livekit.NodeID, + codecs []*livekit.SubscribedAudioCodec, + ) + ClearSubscriberNodes() +} + +var _ DynacastManager = (*dynacastManagerNull)(nil) + +type dynacastManagerNull struct { +} + +func (d *dynacastManagerNull) AddCodec(mime mime.MimeType) {} +func (d *dynacastManagerNull) HandleCodecRegression(fromMime, toMime mime.MimeType) {} +func (d *dynacastManagerNull) Restart() {} +func (d *dynacastManagerNull) Close() {} +func (d *dynacastManagerNull) ForceUpdate() {} +func (d *dynacastManagerNull) ForceQuality(quality livekit.VideoQuality) {} +func (d *dynacastManagerNull) ForceEnable(enabled bool) {} +func (d *dynacastManagerNull) NotifySubscriberMaxQuality( + subscriberID livekit.ParticipantID, + mime mime.MimeType, + quality livekit.VideoQuality, +) { +} +func (d *dynacastManagerNull) NotifySubscription( + subscriberID livekit.ParticipantID, + mime mime.MimeType, + enabled bool, +) { +} +func (d *dynacastManagerNull) NotifySubscriberNodeMaxQuality( + nodeID livekit.NodeID, + qualities []types.SubscribedCodecQuality, +) { +} +func (d *dynacastManagerNull) NotifySubscriptionNode( + nodeID livekit.NodeID, + codecs []*livekit.SubscribedAudioCodec, +) { +} +func (d *dynacastManagerNull) ClearSubscriberNodes() {} + +// ------------------------------------------------ + +type dynacastQualityListener interface { + OnUpdateMaxQualityForMime(mimeType mime.MimeType, maxQuality livekit.VideoQuality) + OnUpdateAudioCodecForMime(mimeType mime.MimeType, enabled bool) +} + +var _ dynacastQualityListener = (*dynacastQualityListenerNull)(nil) + +type dynacastQualityListenerNull struct { +} + +func (d *dynacastQualityListenerNull) OnUpdateMaxQualityForMime( + mimeType mime.MimeType, + maxQuality livekit.VideoQuality, +) { +} + +func (d *dynacastQualityListenerNull) OnUpdateAudioCodecForMime( + mimeType mime.MimeType, + enabled bool, +) { +} + +// ------------------------------------------------ + +type dynacastQuality interface { + Start() + Restart() + Stop() + + NotifySubscriberMaxQuality(subscriberID livekit.ParticipantID, quality livekit.VideoQuality) + NotifySubscription(subscriberID livekit.ParticipantID, enabled bool) + + NotifySubscriberNodeMaxQuality(nodeID livekit.NodeID, quality livekit.VideoQuality) + NotifySubscriptionNode(nodeID livekit.NodeID, enabled bool) + ClearSubscriberNodes() + + Replace( + maxSubscriberQuality map[livekit.ParticipantID]livekit.VideoQuality, + maxSubscriberNodeQuality map[livekit.NodeID]livekit.VideoQuality, + ) + + Mime() mime.MimeType + RegressTo(other dynacastQuality) +} + +var _ dynacastQuality = (*dynacastQualityNull)(nil) + +type dynacastQualityNull struct { +} + +func (d *dynacastQualityNull) Start() {} +func (d *dynacastQualityNull) Restart() {} +func (d *dynacastQualityNull) Stop() {} +func (d *dynacastQualityNull) NotifySubscriberMaxQuality(subscriberID livekit.ParticipantID, quality livekit.VideoQuality) { +} +func (d *dynacastQualityNull) NotifySubscription(subscriberID livekit.ParticipantID, enabled bool) {} +func (d *dynacastQualityNull) NotifySubscriberNodeMaxQuality(nodeID livekit.NodeID, quality livekit.VideoQuality) { +} +func (d *dynacastQualityNull) NotifySubscriptionNode(nodeID livekit.NodeID, enabled bool) {} +func (d *dynacastQualityNull) ClearSubscriberNodes() {} +func (d *dynacastQualityNull) Replace( + maxSubscriberQuality map[livekit.ParticipantID]livekit.VideoQuality, + maxSubscriberNodeQuality map[livekit.NodeID]livekit.VideoQuality, +) { +} +func (d *dynacastQualityNull) Mime() mime.MimeType { return mime.MimeTypeUnknown } +func (d *dynacastQualityNull) RegressTo(other dynacastQuality) {} diff --git a/livekit/pkg/rtc/egress.go b/livekit/pkg/rtc/egress.go new file mode 100644 index 0000000..06ff4ea --- /dev/null +++ b/livekit/pkg/rtc/egress.go @@ -0,0 +1,175 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rtc + +import ( + "context" + "errors" + "fmt" + "strings" + + "github.com/livekit/livekit-server/pkg/rtc/types" + "github.com/livekit/livekit-server/pkg/telemetry" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/rpc" + "github.com/livekit/protocol/webhook" +) + +type EgressLauncher interface { + StartEgress(context.Context, *rpc.StartEgressRequest) (*livekit.EgressInfo, error) +} + +func StartParticipantEgress( + ctx context.Context, + launcher EgressLauncher, + ts telemetry.TelemetryService, + opts *livekit.AutoParticipantEgress, + identity livekit.ParticipantIdentity, + roomName livekit.RoomName, + roomID livekit.RoomID, +) error { + if req, err := startParticipantEgress(ctx, launcher, opts, identity, roomName, roomID); err != nil { + // send egress failed webhook + + info := &livekit.EgressInfo{ + RoomId: string(roomID), + RoomName: string(roomName), + Status: livekit.EgressStatus_EGRESS_FAILED, + Error: err.Error(), + Request: &livekit.EgressInfo_Participant{Participant: req}, + } + + ts.NotifyEgressEvent(ctx, webhook.EventEgressEnded, info) + + return err + } + return nil +} + +func startParticipantEgress( + ctx context.Context, + launcher EgressLauncher, + opts *livekit.AutoParticipantEgress, + identity livekit.ParticipantIdentity, + roomName livekit.RoomName, + roomID livekit.RoomID, +) (*livekit.ParticipantEgressRequest, error) { + req := &livekit.ParticipantEgressRequest{ + RoomName: string(roomName), + Identity: string(identity), + FileOutputs: opts.FileOutputs, + SegmentOutputs: opts.SegmentOutputs, + } + + switch o := opts.Options.(type) { + case *livekit.AutoParticipantEgress_Preset: + req.Options = &livekit.ParticipantEgressRequest_Preset{Preset: o.Preset} + case *livekit.AutoParticipantEgress_Advanced: + req.Options = &livekit.ParticipantEgressRequest_Advanced{Advanced: o.Advanced} + } + + if launcher == nil { + return req, errors.New("egress launcher not found") + } + + _, err := launcher.StartEgress(ctx, &rpc.StartEgressRequest{ + Request: &rpc.StartEgressRequest_Participant{ + Participant: req, + }, + RoomId: string(roomID), + }) + return req, err +} + +func StartTrackEgress( + ctx context.Context, + launcher EgressLauncher, + ts telemetry.TelemetryService, + opts *livekit.AutoTrackEgress, + track types.MediaTrack, + roomName livekit.RoomName, + roomID livekit.RoomID, +) error { + if req, err := startTrackEgress(ctx, launcher, opts, track, roomName, roomID); err != nil { + // send egress failed webhook + + info := &livekit.EgressInfo{ + RoomId: string(roomID), + RoomName: string(roomName), + Status: livekit.EgressStatus_EGRESS_FAILED, + Error: err.Error(), + Request: &livekit.EgressInfo_Track{Track: req}, + } + ts.NotifyEgressEvent(ctx, webhook.EventEgressEnded, info) + + return err + } + return nil +} + +func startTrackEgress( + ctx context.Context, + launcher EgressLauncher, + opts *livekit.AutoTrackEgress, + track types.MediaTrack, + roomName livekit.RoomName, + roomID livekit.RoomID, +) (*livekit.TrackEgressRequest, error) { + output := &livekit.DirectFileOutput{ + Filepath: getFilePath(opts.Filepath), + } + + switch out := opts.Output.(type) { + case *livekit.AutoTrackEgress_Azure: + output.Output = &livekit.DirectFileOutput_Azure{Azure: out.Azure} + case *livekit.AutoTrackEgress_Gcp: + output.Output = &livekit.DirectFileOutput_Gcp{Gcp: out.Gcp} + case *livekit.AutoTrackEgress_S3: + output.Output = &livekit.DirectFileOutput_S3{S3: out.S3} + } + + req := &livekit.TrackEgressRequest{ + RoomName: string(roomName), + TrackId: string(track.ID()), + Output: &livekit.TrackEgressRequest_File{ + File: output, + }, + } + + if launcher == nil { + return req, errors.New("egress launcher not found") + } + + _, err := launcher.StartEgress(ctx, &rpc.StartEgressRequest{ + Request: &rpc.StartEgressRequest_Track{ + Track: req, + }, + RoomId: string(roomID), + }) + return req, err +} + +func getFilePath(filepath string) string { + if filepath == "" || strings.HasSuffix(filepath, "/") || strings.Contains(filepath, "{track_id}") { + return filepath + } + + idx := strings.Index(filepath, ".") + if idx == -1 { + return fmt.Sprintf("%s-{track_id}", filepath) + } else { + return fmt.Sprintf("%s-%s%s", filepath[:idx], "{track_id}", filepath[idx:]) + } +} diff --git a/livekit/pkg/rtc/errors.go b/livekit/pkg/rtc/errors.go new file mode 100644 index 0000000..bbe3906 --- /dev/null +++ b/livekit/pkg/rtc/errors.go @@ -0,0 +1,44 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rtc + +import ( + "errors" +) + +var ( + ErrRoomClosed = errors.New("room has already closed") + ErrParticipantSessionClosed = errors.New("participant session is already closed") + ErrPermissionDenied = errors.New("no permissions to access the room") + ErrMaxParticipantsExceeded = errors.New("room has exceeded its max participants") + ErrLimitExceeded = errors.New("node has exceeded its configured limit") + ErrAlreadyJoined = errors.New("a participant with the same identity is already in the room") + ErrDataChannelUnavailable = errors.New("data channel is not available") + ErrDataChannelBufferFull = errors.New("data channel buffer is full") + ErrTransportFailure = errors.New("transport failure") + ErrEmptyIdentity = errors.New("participant identity cannot be empty") + ErrEmptyParticipantID = errors.New("participant ID cannot be empty") + ErrMissingGrants = errors.New("VideoGrant is missing") + ErrInternalError = errors.New("internal error") + + // Track subscription related + ErrNoTrackPermission = errors.New("participant is not allowed to subscribe to this track") + ErrNoSubscribePermission = errors.New("participant is not given permission to subscribe to tracks") + ErrTrackNotFound = errors.New("track cannot be found") + ErrTrackNotBound = errors.New("track not bound") + ErrSubscriptionLimitExceeded = errors.New("participant has exceeded its subscription limit") + + ErrNoSubscribeMetricsPermission = errors.New("participant is not given permission to subscribe to metrics") +) diff --git a/livekit/pkg/rtc/mediaengine.go b/livekit/pkg/rtc/mediaengine.go new file mode 100644 index 0000000..e1e7180 --- /dev/null +++ b/livekit/pkg/rtc/mediaengine.go @@ -0,0 +1,318 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rtc + +import ( + "fmt" + "strings" + + "github.com/pion/webrtc/v4" + + "github.com/livekit/livekit-server/pkg/sfu/mime" + "github.com/livekit/protocol/livekit" +) + +var ( + OpusCodecParameters = webrtc.RTPCodecParameters{ + RTPCodecCapability: webrtc.RTPCodecCapability{ + MimeType: mime.MimeTypeOpus.String(), + ClockRate: 48000, + Channels: 2, + SDPFmtpLine: "minptime=10;useinbandfec=1", + }, + PayloadType: 111, + } + + RedCodecParameters = webrtc.RTPCodecParameters{ + RTPCodecCapability: webrtc.RTPCodecCapability{ + MimeType: mime.MimeTypeRED.String(), + ClockRate: 48000, + Channels: 2, + SDPFmtpLine: "111/111", + }, + PayloadType: 63, + } + + PCMUCodecParameters = webrtc.RTPCodecParameters{ + RTPCodecCapability: webrtc.RTPCodecCapability{ + MimeType: mime.MimeTypePCMU.String(), + ClockRate: 8000, + }, + PayloadType: 0, + } + + PCMACodecParameters = webrtc.RTPCodecParameters{ + RTPCodecCapability: webrtc.RTPCodecCapability{ + MimeType: mime.MimeTypePCMA.String(), + ClockRate: 8000, + }, + PayloadType: 8, + } + + videoRTXCodecParameters = webrtc.RTPCodecParameters{ + RTPCodecCapability: webrtc.RTPCodecCapability{ + MimeType: mime.MimeTypeRTX.String(), + ClockRate: 90000, + }, + } + + vp8CodecParameters = webrtc.RTPCodecParameters{ + RTPCodecCapability: webrtc.RTPCodecCapability{ + MimeType: mime.MimeTypeVP8.String(), + ClockRate: 90000, + }, + PayloadType: 96, + } + + vp9ProfileId0CodecParameters = webrtc.RTPCodecParameters{ + RTPCodecCapability: webrtc.RTPCodecCapability{ + MimeType: mime.MimeTypeVP9.String(), + ClockRate: 90000, + SDPFmtpLine: "profile-id=0", + }, + PayloadType: 98, + } + + vp9ProfileId1CodecParameters = webrtc.RTPCodecParameters{ + RTPCodecCapability: webrtc.RTPCodecCapability{ + MimeType: mime.MimeTypeVP9.String(), + ClockRate: 90000, + SDPFmtpLine: "profile-id=1", + }, + PayloadType: 100, + } + + h264ProfileLevelId42e01fPacketizationMode0CodecParameters = webrtc.RTPCodecParameters{ + RTPCodecCapability: webrtc.RTPCodecCapability{ + MimeType: mime.MimeTypeH264.String(), + ClockRate: 90000, + SDPFmtpLine: "level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=42e01f", + }, + PayloadType: 125, + } + + h264ProfileLevelId42e01fPacketizationMode1CodecParameters = webrtc.RTPCodecParameters{ + RTPCodecCapability: webrtc.RTPCodecCapability{ + MimeType: mime.MimeTypeH264.String(), + ClockRate: 90000, + SDPFmtpLine: "level-asymmetry-allowed=1;packetization-mode=0;profile-level-id=42e01f", + }, + PayloadType: 108, + } + + h264HighProfileFmtp = "level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=640032" + h264HighProfileCodecParameters = webrtc.RTPCodecParameters{ + RTPCodecCapability: webrtc.RTPCodecCapability{ + MimeType: mime.MimeTypeH264.String(), + ClockRate: 90000, + SDPFmtpLine: h264HighProfileFmtp, + }, + PayloadType: 123, + } + + av1CodecParameters = webrtc.RTPCodecParameters{ + RTPCodecCapability: webrtc.RTPCodecCapability{ + MimeType: mime.MimeTypeAV1.String(), + ClockRate: 90000, + }, + PayloadType: 35, + } + + h265CodecParameters = webrtc.RTPCodecParameters{ + RTPCodecCapability: webrtc.RTPCodecCapability{ + MimeType: mime.MimeTypeH265.String(), + ClockRate: 90000, + }, + PayloadType: 116, + } + + videoCodecsParameters = []webrtc.RTPCodecParameters{ + vp8CodecParameters, + vp9ProfileId0CodecParameters, + vp9ProfileId1CodecParameters, + h264ProfileLevelId42e01fPacketizationMode0CodecParameters, + h264ProfileLevelId42e01fPacketizationMode1CodecParameters, + h264HighProfileCodecParameters, + av1CodecParameters, + h265CodecParameters, + } +) + +func registerCodecs(me *webrtc.MediaEngine, codecs []*livekit.Codec, rtcpFeedback RTCPFeedbackConfig, filterOutH264HighProfile bool) error { + // audio codecs + if IsCodecEnabled(codecs, OpusCodecParameters.RTPCodecCapability) { + cp := OpusCodecParameters + cp.RTPCodecCapability.RTCPFeedback = rtcpFeedback.Audio + if err := me.RegisterCodec(cp, webrtc.RTPCodecTypeAudio); err != nil { + return err + } + + if IsCodecEnabled(codecs, RedCodecParameters.RTPCodecCapability) { + if err := me.RegisterCodec(RedCodecParameters, webrtc.RTPCodecTypeAudio); err != nil { + return err + } + } + } + + for _, codec := range []webrtc.RTPCodecParameters{PCMUCodecParameters, PCMACodecParameters} { + if !IsCodecEnabled(codecs, codec.RTPCodecCapability) { + continue + } + + cp := codec + cp.RTPCodecCapability.RTCPFeedback = rtcpFeedback.Audio + if err := me.RegisterCodec(cp, webrtc.RTPCodecTypeAudio); err != nil { + return err + } + } + + // video codecs + rtxEnabled := IsCodecEnabled(codecs, videoRTXCodecParameters.RTPCodecCapability) + for _, codec := range videoCodecsParameters { + if filterOutH264HighProfile && codec.RTPCodecCapability.SDPFmtpLine == h264HighProfileFmtp { + continue + } + if mime.IsMimeTypeStringRTX(codec.MimeType) { + continue + } + if !IsCodecEnabled(codecs, codec.RTPCodecCapability) { + continue + } + + cp := codec + cp.RTPCodecCapability.RTCPFeedback = rtcpFeedback.Video + if err := me.RegisterCodec(cp, webrtc.RTPCodecTypeVideo); err != nil { + return err + } + + if !rtxEnabled { + continue + } + + cp = videoRTXCodecParameters + cp.RTPCodecCapability.SDPFmtpLine = fmt.Sprintf("apt=%d", codec.PayloadType) + cp.PayloadType = codec.PayloadType + 1 + if err := me.RegisterCodec(cp, webrtc.RTPCodecTypeVideo); err != nil { + return err + } + } + return nil +} + +func registerHeaderExtensions(me *webrtc.MediaEngine, rtpHeaderExtension RTPHeaderExtensionConfig) error { + for _, extension := range rtpHeaderExtension.Video { + if err := me.RegisterHeaderExtension(webrtc.RTPHeaderExtensionCapability{URI: extension}, webrtc.RTPCodecTypeVideo); err != nil { + return err + } + } + + for _, extension := range rtpHeaderExtension.Audio { + if err := me.RegisterHeaderExtension(webrtc.RTPHeaderExtensionCapability{URI: extension}, webrtc.RTPCodecTypeAudio); err != nil { + return err + } + } + + return nil +} + +func createMediaEngine(codecs []*livekit.Codec, config DirectionConfig, filterOutH264HighProfile bool) (*webrtc.MediaEngine, error) { + me := &webrtc.MediaEngine{} + if err := registerCodecs(me, codecs, config.RTCPFeedback, filterOutH264HighProfile); err != nil { + return nil, err + } + + if err := registerHeaderExtensions(me, config.RTPHeaderExtension); err != nil { + return nil, err + } + + return me, nil +} + +func IsCodecEnabled(codecs []*livekit.Codec, cap webrtc.RTPCodecCapability) bool { + for _, codec := range codecs { + if !mime.IsMimeTypeStringEqual(codec.Mime, cap.MimeType) { + continue + } + if codec.FmtpLine == "" || strings.EqualFold(codec.FmtpLine, cap.SDPFmtpLine) { + return true + } + } + return false +} + +func selectAlternativeVideoCodec(enabledCodecs []*livekit.Codec) string { + for _, c := range enabledCodecs { + if mime.IsMimeTypeStringVideo(c.Mime) { + return c.Mime + } + } + // no viable codec in the list of enabled codecs, fall back to the most widely supported codec + return mime.MimeTypeVP8.String() +} + +func selectAlternativeAudioCodec(enabledCodecs []*livekit.Codec) string { + for _, c := range enabledCodecs { + if mime.IsMimeTypeStringAudio(c.Mime) { + return c.Mime + } + } + // no viable codec in the list of enabled codecs, fall back to the most widely supported codec + return mime.MimeTypeOpus.String() +} + +func filterCodecs( + codecs []webrtc.RTPCodecParameters, + enabledCodecs []*livekit.Codec, + rtcpFeedbackConfig RTCPFeedbackConfig, + filterOutH264HighProfile bool, +) []webrtc.RTPCodecParameters { + filteredCodecs := make([]webrtc.RTPCodecParameters, 0, len(codecs)) + for _, c := range codecs { + if filterOutH264HighProfile && isH264HighProfile(c.RTPCodecCapability.SDPFmtpLine) { + continue + } + + for _, enabledCodec := range enabledCodecs { + if mime.NormalizeMimeType(enabledCodec.Mime) == mime.NormalizeMimeType(c.RTPCodecCapability.MimeType) { + if !mime.IsMimeTypeStringEqual(c.RTPCodecCapability.MimeType, mime.MimeTypeRTX.String()) { + if mime.IsMimeTypeStringVideo(c.RTPCodecCapability.MimeType) { + c.RTPCodecCapability.RTCPFeedback = rtcpFeedbackConfig.Video + } else { + c.RTPCodecCapability.RTCPFeedback = rtcpFeedbackConfig.Audio + } + } + filteredCodecs = append(filteredCodecs, c) + break + } + } + } + return filteredCodecs +} + +func isH264HighProfile(fmtp string) bool { + params := strings.Split(fmtp, ";") + for _, param := range params { + parts := strings.Split(param, "=") + if len(parts) == 2 { + if parts[0] == "profile-level-id" { + // https://datatracker.ietf.org/doc/html/rfc6184#section-8.1 + // hex value 0x64 for profile_idc is high profile + return strings.HasPrefix(parts[1], "64") + } + } + } + + return false +} diff --git a/livekit/pkg/rtc/mediaengine_test.go b/livekit/pkg/rtc/mediaengine_test.go new file mode 100644 index 0000000..122861a --- /dev/null +++ b/livekit/pkg/rtc/mediaengine_test.go @@ -0,0 +1,41 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rtc + +import ( + "testing" + + "github.com/pion/webrtc/v4" + "github.com/stretchr/testify/require" + + "github.com/livekit/livekit-server/pkg/sfu/mime" + "github.com/livekit/protocol/livekit" +) + +func TestIsCodecEnabled(t *testing.T) { + t.Run("empty fmtp requirement should match all", func(t *testing.T) { + enabledCodecs := []*livekit.Codec{{Mime: "video/h264"}} + require.True(t, IsCodecEnabled(enabledCodecs, webrtc.RTPCodecCapability{MimeType: mime.MimeTypeH264.String(), SDPFmtpLine: "special"})) + require.True(t, IsCodecEnabled(enabledCodecs, webrtc.RTPCodecCapability{MimeType: mime.MimeTypeH264.String()})) + require.False(t, IsCodecEnabled(enabledCodecs, webrtc.RTPCodecCapability{MimeType: mime.MimeTypeVP8.String()})) + }) + + t.Run("when fmtp is provided, require match", func(t *testing.T) { + enabledCodecs := []*livekit.Codec{{Mime: "video/h264", FmtpLine: "special"}} + require.True(t, IsCodecEnabled(enabledCodecs, webrtc.RTPCodecCapability{MimeType: mime.MimeTypeH264.String(), SDPFmtpLine: "special"})) + require.False(t, IsCodecEnabled(enabledCodecs, webrtc.RTPCodecCapability{MimeType: mime.MimeTypeH264.String()})) + require.False(t, IsCodecEnabled(enabledCodecs, webrtc.RTPCodecCapability{MimeType: mime.MimeTypeVP8.String()})) + }) +} diff --git a/livekit/pkg/rtc/medialossproxy.go b/livekit/pkg/rtc/medialossproxy.go new file mode 100644 index 0000000..0ac54b7 --- /dev/null +++ b/livekit/pkg/rtc/medialossproxy.go @@ -0,0 +1,105 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rtc + +import ( + "sync" + "time" + + "github.com/pion/rtcp" + + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + + "github.com/livekit/livekit-server/pkg/sfu" +) + +const ( + downLostUpdateDelta = time.Second +) + +type MediaLossProxyParams struct { + Logger logger.Logger +} + +type MediaLossProxy struct { + params MediaLossProxyParams + + lock sync.Mutex + maxDownFracLost uint8 + maxDownFracLostTs time.Time + maxDownFracLostValid bool + + onMediaLossUpdate func(fractionalLoss uint8) +} + +func NewMediaLossProxy(params MediaLossProxyParams) *MediaLossProxy { + return &MediaLossProxy{params: params} +} + +func (m *MediaLossProxy) OnMediaLossUpdate(f func(fractionalLoss uint8)) { + m.lock.Lock() + m.onMediaLossUpdate = f + m.lock.Unlock() +} + +func (m *MediaLossProxy) HandleMaxLossFeedback(_ *sfu.DownTrack, report *rtcp.ReceiverReport) { + m.lock.Lock() + for _, rr := range report.Reports { + m.maxDownFracLostValid = true + if m.maxDownFracLost < rr.FractionLost { + m.maxDownFracLost = rr.FractionLost + } + } + m.lock.Unlock() + + m.maybeUpdateLoss() +} + +func (m *MediaLossProxy) NotifySubscriberNodeMediaLoss(_nodeID livekit.NodeID, fractionalLoss uint8) { + m.lock.Lock() + m.maxDownFracLostValid = true + if m.maxDownFracLost < fractionalLoss { + m.maxDownFracLost = fractionalLoss + } + m.lock.Unlock() + + m.maybeUpdateLoss() +} + +func (m *MediaLossProxy) maybeUpdateLoss() { + var ( + shouldUpdate bool + maxLost uint8 + ) + + m.lock.Lock() + now := time.Now() + if now.Sub(m.maxDownFracLostTs) > downLostUpdateDelta && m.maxDownFracLostValid { + shouldUpdate = true + maxLost = m.maxDownFracLost + m.maxDownFracLost = 0 + m.maxDownFracLostTs = now + m.maxDownFracLostValid = false + } + onMediaLossUpdate := m.onMediaLossUpdate + m.lock.Unlock() + + if shouldUpdate { + if onMediaLossUpdate != nil { + onMediaLossUpdate(maxLost) + } + } +} diff --git a/livekit/pkg/rtc/mediatrack.go b/livekit/pkg/rtc/mediatrack.go new file mode 100644 index 0000000..cc6ad17 --- /dev/null +++ b/livekit/pkg/rtc/mediatrack.go @@ -0,0 +1,707 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rtc + +import ( + "context" + "math" + "sync" + "time" + + "github.com/pion/rtcp" + "github.com/pion/webrtc/v4" + "go.uber.org/atomic" + + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/observability/roomobs" + "github.com/livekit/protocol/utils/mono" + + "github.com/livekit/livekit-server/pkg/config" + "github.com/livekit/livekit-server/pkg/rtc/dynacast" + "github.com/livekit/livekit-server/pkg/rtc/types" + "github.com/livekit/livekit-server/pkg/sfu" + "github.com/livekit/livekit-server/pkg/sfu/buffer" + "github.com/livekit/livekit-server/pkg/sfu/connectionquality" + "github.com/livekit/livekit-server/pkg/sfu/interceptor" + "github.com/livekit/livekit-server/pkg/sfu/mime" + "github.com/livekit/livekit-server/pkg/telemetry" + util "github.com/livekit/mediatransportutil" +) + +var _ types.LocalMediaTrack = (*MediaTrack)(nil) + +// MediaTrack represents a WebRTC track that needs to be forwarded +// Implements MediaTrack and PublishedTrack interface +type MediaTrack struct { + params MediaTrackParams + buffer *buffer.Buffer + everSubscribed atomic.Bool + + *MediaTrackReceiver + *MediaLossProxy + + dynacastManager dynacast.DynacastManager + + lock sync.RWMutex + + rttFromXR atomic.Bool + + backupCodecPolicy livekit.BackupCodecPolicy + regressionTargetCodec mime.MimeType + regressionTargetCodecReceived bool + + onSubscribedMaxQualityChange func( + trackID livekit.TrackID, + trackInfo *livekit.TrackInfo, + subscribedQualities []*livekit.SubscribedCodec, + maxSubscribedQualities []types.SubscribedCodecQuality, + ) error + onSubscribedAudioCodecChange func( + trackID livekit.TrackID, + codecs []*livekit.SubscribedAudioCodec, + ) error +} + +type MediaTrackParams struct { + ParticipantID func() livekit.ParticipantID + ParticipantIdentity livekit.ParticipantIdentity + ParticipantVersion uint32 + ParticipantCountry string + BufferFactory *buffer.Factory + ReceiverConfig ReceiverConfig + SubscriberConfig DirectionConfig + PLIThrottleConfig sfu.PLIThrottleConfig + AudioConfig sfu.AudioConfig + VideoConfig config.VideoConfig + Telemetry telemetry.TelemetryService + Logger logger.Logger + Reporter roomobs.TrackReporter + SimTracks map[uint32]interceptor.SimulcastTrackInfo + OnRTCP func([]rtcp.Packet) + ForwardStats *sfu.ForwardStats + OnTrackEverSubscribed func(livekit.TrackID) + ShouldRegressCodec func() bool + PreferVideoSizeFromMedia bool + EnableRTPStreamRestartDetection bool + UpdateTrackInfoByVideoSizeChange bool + ForceBackupCodecPolicySimulcast bool +} + +func NewMediaTrack(params MediaTrackParams, ti *livekit.TrackInfo) *MediaTrack { + t := &MediaTrack{ + params: params, + backupCodecPolicy: ti.BackupCodecPolicy, + } + + if t.params.ForceBackupCodecPolicySimulcast { + t.backupCodecPolicy = livekit.BackupCodecPolicy_SIMULCAST + } + + if t.backupCodecPolicy != livekit.BackupCodecPolicy_SIMULCAST && len(ti.Codecs) > 1 { + t.regressionTargetCodec = mime.NormalizeMimeType(ti.Codecs[1].MimeType) + t.params.Logger.Debugw("track enabled codec regression", "regressionCodec", t.regressionTargetCodec) + } + + t.MediaTrackReceiver = NewMediaTrackReceiver(MediaTrackReceiverParams{ + MediaTrack: t, + IsRelayed: false, + ParticipantID: params.ParticipantID, + ParticipantIdentity: params.ParticipantIdentity, + ParticipantVersion: params.ParticipantVersion, + ReceiverConfig: params.ReceiverConfig, + SubscriberConfig: params.SubscriberConfig, + AudioConfig: params.AudioConfig, + Telemetry: params.Telemetry, + Logger: params.Logger, + RegressionTargetCodec: t.regressionTargetCodec, + PreferVideoSizeFromMedia: params.PreferVideoSizeFromMedia, + }, ti) + + if ti.Type == livekit.TrackType_AUDIO { + t.MediaLossProxy = NewMediaLossProxy(MediaLossProxyParams{ + Logger: params.Logger, + }) + t.MediaLossProxy.OnMediaLossUpdate(func(fractionalLoss uint8) { + if t.buffer != nil { + t.buffer.SetLastFractionLostReport(fractionalLoss) + } + }) + t.MediaTrackReceiver.OnMediaLossFeedback(t.MediaLossProxy.HandleMaxLossFeedback) + } + + switch ti.Type { + case livekit.TrackType_VIDEO: + t.dynacastManager = dynacast.NewDynacastManagerVideo(dynacast.DynacastManagerVideoParams{ + DynacastPauseDelay: params.VideoConfig.DynacastPauseDelay, + Listener: t, + Logger: params.Logger, + }) + + case livekit.TrackType_AUDIO: + if len(ti.Codecs) > 1 { + t.dynacastManager = dynacast.NewDynacastManagerAudio(dynacast.DynacastManagerAudioParams{ + Listener: t, + Logger: params.Logger, + }) + } + } + t.MediaTrackReceiver.OnSetupReceiver(func(mime mime.MimeType) { + if t.dynacastManager != nil { + t.dynacastManager.AddCodec(mime) + } + }) + t.MediaTrackReceiver.OnSubscriberMaxQualityChange( + func(subscriberID livekit.ParticipantID, mimeType mime.MimeType, layer int32) { + if t.dynacastManager != nil { + t.dynacastManager.NotifySubscriberMaxQuality( + subscriberID, + mimeType, + buffer.GetVideoQualityForSpatialLayer( + mimeType, + layer, + t.MediaTrackReceiver.TrackInfo(), + ), + ) + } + }, + ) + t.MediaTrackReceiver.OnSubscriberAudioCodecChange( + func(subscriberID livekit.ParticipantID, mimeType mime.MimeType, enabled bool) { + if t.dynacastManager != nil { + t.dynacastManager.NotifySubscription(subscriberID, mimeType, enabled) + } + }, + ) + t.MediaTrackReceiver.OnCodecRegression(func(old, new webrtc.RTPCodecParameters) { + if t.dynacastManager != nil { + t.dynacastManager.HandleCodecRegression( + mime.NormalizeMimeType(old.MimeType), + mime.NormalizeMimeType(new.MimeType), + ) + } + }) + + t.SetMuted(ti.Muted) + return t +} + +func (t *MediaTrack) OnSubscribedMaxQualityChange( + f func( + trackID livekit.TrackID, + trackInfo *livekit.TrackInfo, + subscribedQualities []*livekit.SubscribedCodec, + maxSubscribedQualities []types.SubscribedCodecQuality, + ) error, +) { + t.lock.Lock() + t.onSubscribedMaxQualityChange = f + t.lock.Unlock() +} + +func (t *MediaTrack) OnSubscribedAudioCodecChange( + f func( + trackID livekit.TrackID, + codecs []*livekit.SubscribedAudioCodec, + ) error, +) { + t.lock.Lock() + t.onSubscribedAudioCodecChange = f + t.lock.Unlock() +} + +func (t *MediaTrack) NotifySubscriberNodeMaxQuality(nodeID livekit.NodeID, qualities []types.SubscribedCodecQuality) { + if t.dynacastManager != nil { + t.dynacastManager.NotifySubscriberNodeMaxQuality(nodeID, qualities) + } +} + +func (t *MediaTrack) NotifySubscriptionNode(nodeID livekit.NodeID, codecs []*livekit.SubscribedAudioCodec) { + if t.dynacastManager != nil { + t.dynacastManager.NotifySubscriptionNode(nodeID, codecs) + } +} + +func (t *MediaTrack) ClearSubscriberNodes() { + if t.dynacastManager != nil { + t.dynacastManager.ClearSubscriberNodes() + } +} + +func (t *MediaTrack) HasSignalCid(cid string) bool { + if cid != "" { + ti := t.MediaTrackReceiver.TrackInfoClone() + for _, c := range ti.Codecs { + if c.Cid == cid { + return true + } + } + } + return false +} + +func (t *MediaTrack) HasSdpCid(cid string) bool { + if cid != "" { + ti := t.MediaTrackReceiver.TrackInfoClone() + for _, c := range ti.Codecs { + if c.Cid == cid || c.SdpCid == cid { + return true + } + } + } + return false +} + +func (t *MediaTrack) GetMimeTypeForSdpCid(cid string) mime.MimeType { + if cid != "" { + ti := t.MediaTrackReceiver.TrackInfoClone() + for _, c := range ti.Codecs { + if c.Cid == cid || c.SdpCid == cid { + return mime.NormalizeMimeType(c.MimeType) + } + } + } + return mime.MimeTypeUnknown +} + +func (t *MediaTrack) GetCidsForMimeType(mimeType mime.MimeType) (string, string) { + ti := t.MediaTrackReceiver.TrackInfoClone() + for _, c := range ti.Codecs { + if mime.NormalizeMimeType(c.MimeType) == mimeType { + return c.Cid, c.SdpCid + } + } + return "", "" +} + +func (t *MediaTrack) ToProto() *livekit.TrackInfo { + return t.MediaTrackReceiver.TrackInfoClone() +} + +// AddReceiver adds a new RTP receiver to the track, returns true when receiver represents a new codec +// and if a receiver was added successfully +func (t *MediaTrack) AddReceiver(receiver *webrtc.RTPReceiver, track sfu.TrackRemote, mid string) (bool, bool) { + var newCodec bool + ssrc := uint32(track.SSRC()) + buff, rtcpReader := t.params.BufferFactory.GetBufferPair(ssrc) + if buff == nil || rtcpReader == nil { + t.params.Logger.Errorw("could not retrieve buffer pair", nil) + return newCodec, false + } + + var lastRR uint32 + rtcpReader.OnPacket(func(bytes []byte) { + pkts, err := rtcp.Unmarshal(bytes) + if err != nil { + t.params.Logger.Errorw("could not unmarshal RTCP", err) + return + } + + for _, pkt := range pkts { + switch pkt := pkt.(type) { + case *rtcp.SourceDescription: + case *rtcp.SenderReport: + if pkt.SSRC == uint32(track.SSRC()) { + buff.SetSenderReportData(&livekit.RTCPSenderReportState{ + RtpTimestamp: pkt.RTPTime, + NtpTimestamp: pkt.NTPTime, + Packets: pkt.PacketCount, + Octets: uint64(pkt.OctetCount), + At: mono.UnixNano(), + }) + } + case *rtcp.ExtendedReport: + rttFromXR: + for _, report := range pkt.Reports { + if rr, ok := report.(*rtcp.DLRRReportBlock); ok { + for _, dlrrReport := range rr.Reports { + if dlrrReport.LastRR <= lastRR { + continue + } + nowNTP := util.ToNtpTime(time.Now()) + nowNTP32 := uint32(nowNTP >> 16) + ntpDiff := nowNTP32 - dlrrReport.LastRR - dlrrReport.DLRR + rtt := uint32(math.Ceil(float64(ntpDiff) * 1000.0 / 65536.0)) + buff.SetRTT(rtt) + t.rttFromXR.Store(true) + lastRR = dlrrReport.LastRR + break rttFromXR + } + } + } + } + } + }) + + ti := t.MediaTrackReceiver.TrackInfoClone() + t.lock.Lock() + var regressCodec bool + mimeType := mime.NormalizeMimeType(track.Codec().MimeType) + layer := buffer.GetSpatialLayerForRid(mimeType, track.RID(), ti) + if layer < 0 { + t.params.Logger.Warnw( + "AddReceiver failed due to negative layer", nil, + "rid", track.RID(), + "layer", layer, + "ssrc", track.SSRC(), + "codec", track.Codec(), + "trackInfo", logger.Proto(ti), + ) + t.lock.Unlock() + return newCodec, false + } + + t.params.Logger.Debugw( + "AddReceiver", + "rid", track.RID(), + "layer", layer, + "ssrc", track.SSRC(), + "codec", track.Codec(), + "trackInfo", logger.Proto(ti), + ) + wr := t.MediaTrackReceiver.Receiver(mimeType) + if wr == nil { + priority := -1 + for idx, c := range ti.Codecs { + if mime.IsMimeTypeStringEqual(track.Codec().MimeType, c.MimeType) { + priority = idx + break + } + } + if priority < 0 { + switch len(ti.Codecs) { + case 0: + // audio track + t.params.Logger.Warnw( + "unexpected 0 codecs in track info", nil, + "mime", mimeType, + "track", logger.Proto(ti), + ) + priority = 0 + case 1: + // older clients or non simulcast-codec, mime type only set later + if ti.Codecs[0].MimeType == "" { + priority = 0 + } + } + } + if priority < 0 { + t.params.Logger.Warnw( + "could not find codec for webrtc receiver", nil, + "mime", mimeType, + "track", logger.Proto(ti), + ) + t.lock.Unlock() + return newCodec, false + } + + newWR := sfu.NewWebRTCReceiver( + receiver, + track, + ti, + LoggerWithCodecMime(t.params.Logger, mimeType), + t.params.OnRTCP, + t.params.VideoConfig.StreamTrackerManager, + sfu.WithPliThrottleConfig(t.params.PLIThrottleConfig), + sfu.WithAudioConfig(t.params.AudioConfig), + sfu.WithLoadBalanceThreshold(20), + sfu.WithForwardStats(t.params.ForwardStats), + sfu.WithEnableRTPStreamRestartDetection(t.params.EnableRTPStreamRestartDetection), + ) + newWR.OnCloseHandler(func() { + t.MediaTrackReceiver.SetClosing(false) + t.MediaTrackReceiver.ClearReceiver(mimeType, false) + if t.MediaTrackReceiver.TryClose() { + if t.dynacastManager != nil { + t.dynacastManager.Close() + } + } + }) + + // SIMULCAST-CODEC-TODO: these need to be receiver/mime aware, setting it up only for primary now + statsKey := telemetry.StatsKeyForTrack( + t.params.ParticipantCountry, + livekit.StreamType_UPSTREAM, + t.PublisherID(), + t.ID(), + ti.Source, + ti.Type, + ) + newWR.OnStatsUpdate(func(_ *sfu.WebRTCReceiver, stat *livekit.AnalyticsStat) { + // send for only one codec, either primary (priority == 0) OR regressed codec + t.lock.RLock() + regressionTargetCodecReceived := t.regressionTargetCodecReceived + t.lock.RUnlock() + if priority == 0 || regressionTargetCodecReceived { + t.params.Telemetry.TrackStats(statsKey, stat) + + if cs, ok := telemetry.CondenseStat(stat); ok { + t.params.Reporter.Tx(func(tx roomobs.TrackTx) { + tx.ReportName(ti.Name) + tx.ReportKind(roomobs.TrackKindPub) + tx.ReportType(roomobs.TrackTypeFromProto(ti.Type)) + tx.ReportSource(roomobs.TrackSourceFromProto(ti.Source)) + tx.ReportMime(mime.NormalizeMimeType(ti.MimeType).ReporterType()) + tx.ReportLayer(roomobs.PackTrackLayer(ti.Height, ti.Width)) + tx.ReportDuration(uint16(cs.EndTime.Sub(cs.StartTime).Milliseconds())) + tx.ReportFrames(uint16(cs.Frames)) + tx.ReportRecvBytes(uint32(cs.Bytes)) + tx.ReportRecvPackets(cs.Packets) + tx.ReportPacketsLost(cs.PacketsLost) + tx.ReportScore(stat.Score) + }) + } + } + }) + + newWR.OnMaxLayerChange(func(mimeType mime.MimeType, maxLayer int32) { + // send for only one codec, either primary (priority == 0) OR regressed codec + t.lock.RLock() + regressionTargetCodecReceived := t.regressionTargetCodecReceived + t.lock.RUnlock() + if priority == 0 || regressionTargetCodecReceived { + t.MediaTrackReceiver.NotifyMaxLayerChange(mimeType, maxLayer) + } + }) + // SIMULCAST-CODEC-TODO END: these need to be receiver/mime aware, setting it up only for primary now + + if t.PrimaryReceiver() == nil { + // primary codec published, set potential codecs + potentialCodecs := make([]webrtc.RTPCodecParameters, 0, len(ti.Codecs)) + parameters := receiver.GetParameters() + for _, c := range ti.Codecs { + for _, nc := range parameters.Codecs { + if mime.IsMimeTypeStringEqual(nc.MimeType, c.MimeType) { + potentialCodecs = append(potentialCodecs, nc) + break + } + } + } + + if len(potentialCodecs) > 0 { + t.params.Logger.Debugw("primary codec published, set potential codecs", "potential", potentialCodecs) + t.MediaTrackReceiver.SetPotentialCodecs(potentialCodecs, parameters.HeaderExtensions) + } + } + + t.buffer = buff + + t.MediaTrackReceiver.SetupReceiver(newWR, priority, mid) + + for ssrc, info := range t.params.SimTracks { + if info.Mid == mid && !info.IsRepairStream { + t.MediaTrackReceiver.SetLayerSsrcsForRid(mimeType, info.StreamID, ssrc, info.RepairSSRC) + } + } + wr = newWR + newCodec = true + + newWR.AddOnCodecStateChange(func(codec webrtc.RTPCodecParameters, state sfu.ReceiverCodecState) { + t.MediaTrackReceiver.HandleReceiverCodecChange(newWR, codec, state) + }) + + // update subscriber video layers when video size changes + newWR.OnVideoSizeChanged(func() { + if t.params.UpdateTrackInfoByVideoSizeChange { + t.MediaTrackReceiver.UpdateVideoSize(mimeType, newWR.VideoSizes()) + } + + t.MediaTrackSubscriptions.UpdateVideoLayers() + }) + } + + if newCodec && t.enableRegression() { + if mimeType == t.regressionTargetCodec { + t.params.Logger.Infow("regression target codec received", "codec", mimeType) + t.regressionTargetCodecReceived = true + regressCodec = true + } else if t.regressionTargetCodecReceived { + regressCodec = true + } + } + t.lock.Unlock() + + if err := wr.(*sfu.WebRTCReceiver).AddUpTrack(track, buff); err != nil { + t.params.Logger.Warnw( + "adding up track failed", err, + "rid", track.RID(), + "layer", layer, + "ssrc", track.SSRC(), + "newCodec", newCodec, + ) + buff.Close() + return newCodec, false + } + + var expectedBitrate int + layers := buffer.GetVideoLayersForMimeType(mimeType, ti) + if layer >= 0 && len(layers) > int(layer) { + expectedBitrate = int(layers[layer].GetBitrate()) + } + if err := buff.Bind(receiver.GetParameters(), track.Codec().RTPCodecCapability, expectedBitrate); err != nil { + t.params.Logger.Warnw( + "binding buffer failed", err, + "rid", track.RID(), + "layer", layer, + "ssrc", track.SSRC(), + "newCodec", newCodec, + ) + buff.Close() + return newCodec, false + } + + t.MediaTrackReceiver.SetLayerSsrcsForRid(mimeType, track.RID(), uint32(track.SSRC()), 0) + + if regressCodec { + for _, c := range ti.Codecs { + if mime.NormalizeMimeType(c.MimeType) == t.regressionTargetCodec { + continue + } + + t.params.Logger.Debugw("suspending codec for codec regression", "codec", c.MimeType) + if r := t.MediaTrackReceiver.Receiver(mime.NormalizeMimeType(c.MimeType)); r != nil { + if rtcreceiver, ok := r.(*sfu.WebRTCReceiver); ok { + rtcreceiver.SetCodecState(sfu.ReceiverCodecStateSuspended) + } + } + } + } + + buff.OnNotifyRTX(t.MediaTrackReceiver.setLayerRtxInfo) + + // if subscriber request fps before fps calculated, update them after fps updated. + buff.OnFpsChanged(func() { + t.MediaTrackSubscriptions.UpdateVideoLayers() + }) + + buff.OnFinalRtpStats(func(stats *livekit.RTPStats) { + t.params.Telemetry.TrackPublishRTPStats( + context.Background(), + t.params.ParticipantID(), + t.ID(), + mimeType, + int(layer), + stats, + ) + }) + return newCodec, true +} + +func (t *MediaTrack) GetConnectionScoreAndQuality() (float32, livekit.ConnectionQuality) { + receiver := t.ActiveReceiver() + if rtcReceiver, ok := receiver.(*sfu.WebRTCReceiver); ok { + return rtcReceiver.GetConnectionScoreAndQuality() + } + + return connectionquality.MaxMOS, livekit.ConnectionQuality_EXCELLENT +} + +func (t *MediaTrack) SetRTT(rtt uint32) { + if !t.rttFromXR.Load() { + t.MediaTrackReceiver.SetRTT(rtt) + } +} + +func (t *MediaTrack) HasPendingCodec() bool { + return t.MediaTrackReceiver.PrimaryReceiver() == nil +} + +func (t *MediaTrack) Restart() { + t.MediaTrackReceiver.Restart() + + if t.dynacastManager != nil { + t.dynacastManager.Restart() + } +} + +func (t *MediaTrack) Close(isExpectedToResume bool) { + t.MediaTrackReceiver.SetClosing(isExpectedToResume) + if t.dynacastManager != nil { + t.dynacastManager.Close() + } + t.MediaTrackReceiver.Close(isExpectedToResume) +} + +func (t *MediaTrack) SetMuted(muted bool) { + // update quality based on subscription if unmuting. + // This will queue up the current state, but subscriber + // driven changes could update it. + if !muted && t.dynacastManager != nil { + t.dynacastManager.ForceUpdate() + } + + t.MediaTrackReceiver.SetMuted(muted) +} + +// OnTrackSubscribed is called when the track is subscribed by a non-hidden subscriber +// this allows the publisher to know when they should start sending data +func (t *MediaTrack) OnTrackSubscribed() { + if !t.everSubscribed.Swap(true) && t.params.OnTrackEverSubscribed != nil { + go t.params.OnTrackEverSubscribed(t.ID()) + } +} + +func (t *MediaTrack) enableRegression() bool { + return t.backupCodecPolicy == livekit.BackupCodecPolicy_REGRESSION || + (t.backupCodecPolicy == livekit.BackupCodecPolicy_PREFER_REGRESSION && t.params.ShouldRegressCodec()) +} + +func (t *MediaTrack) Logger() logger.Logger { + return t.params.Logger +} + +// dynacast.DynacastManagerListtener implementation +var _ dynacast.DynacastManagerListener = (*MediaTrack)(nil) + +func (t *MediaTrack) OnDynacastSubscribedMaxQualityChange( + subscribedQualities []*livekit.SubscribedCodec, + maxSubscribedQualities []types.SubscribedCodecQuality, +) { + t.lock.RLock() + onSubscribedMaxQualityChange := t.onSubscribedMaxQualityChange + t.lock.RUnlock() + + if onSubscribedMaxQualityChange != nil && !t.IsMuted() { + _ = onSubscribedMaxQualityChange( + t.ID(), + t.ToProto(), + subscribedQualities, + maxSubscribedQualities, + ) + } + + for _, q := range maxSubscribedQualities { + receiver := t.Receiver(q.CodecMime) + if receiver != nil { + receiver.SetMaxExpectedSpatialLayer( + buffer.GetSpatialLayerForVideoQuality( + q.CodecMime, + q.Quality, + t.MediaTrackReceiver.TrackInfo(), + ), + ) + } + } +} + +func (t *MediaTrack) OnDynacastSubscribedAudioCodecChange(codecs []*livekit.SubscribedAudioCodec) { + t.lock.RLock() + onSubscribedAudioCodecChange := t.onSubscribedAudioCodecChange + t.lock.RUnlock() + + if onSubscribedAudioCodecChange != nil { + _ = onSubscribedAudioCodecChange(t.ID(), codecs) + } +} diff --git a/livekit/pkg/rtc/mediatrack_test.go b/livekit/pkg/rtc/mediatrack_test.go new file mode 100644 index 0000000..7839a94 --- /dev/null +++ b/livekit/pkg/rtc/mediatrack_test.go @@ -0,0 +1,197 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rtc + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/livekit/livekit-server/pkg/sfu/mime" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" +) + +func TestTrackInfo(t *testing.T) { + // ensures that persisted trackinfo is being returned + ti := livekit.TrackInfo{ + Sid: "testsid", + Name: "testtrack", + Source: livekit.TrackSource_SCREEN_SHARE, + Type: livekit.TrackType_VIDEO, + Simulcast: false, + Width: 100, + Height: 80, + Muted: true, + } + + mt := NewMediaTrack(MediaTrackParams{}, &ti) + outInfo := mt.ToProto() + require.Equal(t, ti.Muted, outInfo.Muted) + require.Equal(t, ti.Name, outInfo.Name) + require.Equal(t, ti.Name, mt.Name()) + require.Equal(t, livekit.TrackID(ti.Sid), mt.ID()) + require.Equal(t, ti.Type, outInfo.Type) + require.Equal(t, ti.Type, mt.Kind()) + require.Equal(t, ti.Source, outInfo.Source) + require.Equal(t, ti.Width, outInfo.Width) + require.Equal(t, ti.Height, outInfo.Height) + require.Equal(t, ti.Simulcast, outInfo.Simulcast) +} + +func TestGetQualityForDimension(t *testing.T) { + t.Run("landscape source", func(t *testing.T) { + mt := NewMediaTrack(MediaTrackParams{ + Logger: logger.GetLogger(), + }, &livekit.TrackInfo{ + Type: livekit.TrackType_VIDEO, + Width: 1080, + Height: 720, + }) + + require.Equal(t, livekit.VideoQuality_LOW, mt.GetQualityForDimension(mime.MimeTypeVP8, 120, 120)) + require.Equal(t, livekit.VideoQuality_LOW, mt.GetQualityForDimension(mime.MimeTypeVP8, 300, 200)) + require.Equal(t, livekit.VideoQuality_MEDIUM, mt.GetQualityForDimension(mime.MimeTypeVP8, 200, 250)) + require.Equal(t, livekit.VideoQuality_HIGH, mt.GetQualityForDimension(mime.MimeTypeVP8, 700, 480)) + require.Equal(t, livekit.VideoQuality_HIGH, mt.GetQualityForDimension(mime.MimeTypeVP8, 500, 1000)) + }) + + t.Run("portrait source", func(t *testing.T) { + mt := NewMediaTrack(MediaTrackParams{ + Logger: logger.GetLogger(), + }, &livekit.TrackInfo{ + Type: livekit.TrackType_VIDEO, + Width: 540, + Height: 960, + }) + + require.Equal(t, livekit.VideoQuality_LOW, mt.GetQualityForDimension(mime.MimeTypeVP8, 200, 400)) + require.Equal(t, livekit.VideoQuality_MEDIUM, mt.GetQualityForDimension(mime.MimeTypeVP8, 400, 400)) + require.Equal(t, livekit.VideoQuality_MEDIUM, mt.GetQualityForDimension(mime.MimeTypeVP8, 400, 700)) + require.Equal(t, livekit.VideoQuality_HIGH, mt.GetQualityForDimension(mime.MimeTypeVP8, 600, 900)) + }) + + t.Run("layers provided", func(t *testing.T) { + mt := NewMediaTrack(MediaTrackParams{ + Logger: logger.GetLogger(), + }, &livekit.TrackInfo{ + Type: livekit.TrackType_VIDEO, + Width: 1080, + Height: 720, + Codecs: []*livekit.SimulcastCodecInfo{ + { + MimeType: mime.MimeTypeH264.String(), + Layers: []*livekit.VideoLayer{ + { + Quality: livekit.VideoQuality_LOW, + Width: 480, + Height: 270, + }, + { + Quality: livekit.VideoQuality_MEDIUM, + Width: 960, + Height: 540, + }, + { + Quality: livekit.VideoQuality_HIGH, + Width: 1080, + Height: 720, + }, + }, + }, + }, + }) + + require.Equal(t, livekit.VideoQuality_LOW, mt.GetQualityForDimension(mime.MimeTypeH264, 120, 120)) + require.Equal(t, livekit.VideoQuality_LOW, mt.GetQualityForDimension(mime.MimeTypeH264, 300, 300)) + require.Equal(t, livekit.VideoQuality_MEDIUM, mt.GetQualityForDimension(mime.MimeTypeH264, 800, 500)) + require.Equal(t, livekit.VideoQuality_HIGH, mt.GetQualityForDimension(mime.MimeTypeH264, 1000, 700)) + }) + + t.Run("highest layer with smallest dimensions", func(t *testing.T) { + mt := NewMediaTrack(MediaTrackParams{ + Logger: logger.GetLogger(), + }, &livekit.TrackInfo{ + Type: livekit.TrackType_VIDEO, + Width: 1080, + Height: 720, + Codecs: []*livekit.SimulcastCodecInfo{ + { + MimeType: mime.MimeTypeH264.String(), + Layers: []*livekit.VideoLayer{ + { + Quality: livekit.VideoQuality_LOW, + Width: 480, + Height: 270, + }, + { + Quality: livekit.VideoQuality_MEDIUM, + Width: 1080, + Height: 720, + }, + { + Quality: livekit.VideoQuality_HIGH, + Width: 1080, + Height: 720, + }, + }, + }, + }, + }) + + require.Equal(t, livekit.VideoQuality_LOW, mt.GetQualityForDimension(mime.MimeTypeH264, 120, 120)) + require.Equal(t, livekit.VideoQuality_LOW, mt.GetQualityForDimension(mime.MimeTypeH264, 300, 300)) + require.Equal(t, livekit.VideoQuality_HIGH, mt.GetQualityForDimension(mime.MimeTypeH264, 800, 500)) + require.Equal(t, livekit.VideoQuality_HIGH, mt.GetQualityForDimension(mime.MimeTypeH264, 1000, 700)) + require.Equal(t, livekit.VideoQuality_HIGH, mt.GetQualityForDimension(mime.MimeTypeH264, 1200, 800)) + + mt = NewMediaTrack(MediaTrackParams{ + Logger: logger.GetLogger(), + }, &livekit.TrackInfo{ + Type: livekit.TrackType_VIDEO, + Width: 1080, + Height: 720, + Codecs: []*livekit.SimulcastCodecInfo{ + { + MimeType: mime.MimeTypeH264.String(), + Layers: []*livekit.VideoLayer{ + { + Quality: livekit.VideoQuality_LOW, + Width: 480, + Height: 270, + }, + { + Quality: livekit.VideoQuality_MEDIUM, + Width: 480, + Height: 270, + }, + { + Quality: livekit.VideoQuality_HIGH, + Width: 1080, + Height: 720, + }, + }, + }, + }, + }) + + require.Equal(t, livekit.VideoQuality_MEDIUM, mt.GetQualityForDimension(mime.MimeTypeH264, 120, 120)) + require.Equal(t, livekit.VideoQuality_MEDIUM, mt.GetQualityForDimension(mime.MimeTypeH264, 300, 300)) + require.Equal(t, livekit.VideoQuality_HIGH, mt.GetQualityForDimension(mime.MimeTypeH264, 800, 500)) + require.Equal(t, livekit.VideoQuality_HIGH, mt.GetQualityForDimension(mime.MimeTypeH264, 1000, 700)) + require.Equal(t, livekit.VideoQuality_HIGH, mt.GetQualityForDimension(mime.MimeTypeH264, 1200, 800)) + }) + +} diff --git a/livekit/pkg/rtc/mediatrackreceiver.go b/livekit/pkg/rtc/mediatrackreceiver.go new file mode 100644 index 0000000..b109dfd --- /dev/null +++ b/livekit/pkg/rtc/mediatrackreceiver.go @@ -0,0 +1,1241 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rtc + +import ( + "context" + "errors" + "fmt" + "slices" + "sort" + "strings" + "sync" + + "github.com/pion/rtcp" + "github.com/pion/webrtc/v4" + "go.uber.org/atomic" + "google.golang.org/protobuf/proto" + + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/utils" + + "github.com/livekit/livekit-server/pkg/rtc/types" + "github.com/livekit/livekit-server/pkg/sfu" + "github.com/livekit/livekit-server/pkg/sfu/buffer" + "github.com/livekit/livekit-server/pkg/sfu/mime" + "github.com/livekit/livekit-server/pkg/sfu/rtpstats" + "github.com/livekit/livekit-server/pkg/telemetry" + sutils "github.com/livekit/livekit-server/pkg/utils" +) + +const ( + layerSelectionTolerance = 0.9 +) + +var ( + ErrNotOpen = errors.New("track is not open") + ErrNoReceiver = errors.New("cannot subscribe without a receiver in place") +) + +// ------------------------------------------------------ + +type mediaTrackReceiverState int + +const ( + mediaTrackReceiverStateOpen mediaTrackReceiverState = iota + mediaTrackReceiverStateClosing + mediaTrackReceiverStateClosed +) + +func (m mediaTrackReceiverState) String() string { + switch m { + case mediaTrackReceiverStateOpen: + return "OPEN" + case mediaTrackReceiverStateClosing: + return "CLOSING" + case mediaTrackReceiverStateClosed: + return "CLOSED" + default: + return fmt.Sprintf("%d", int(m)) + } +} + +// ----------------------------------------------------- + +type simulcastReceiver struct { + sfu.TrackReceiver + priority int + lock sync.Mutex + regressTo sfu.TrackReceiver +} + +func (r *simulcastReceiver) Priority() int { + return r.priority +} + +func (r *simulcastReceiver) AddDownTrack(track sfu.TrackSender) error { + r.lock.Lock() + if rt := r.regressTo; rt != nil { + r.lock.Unlock() + // AddDownTrack could be called in downtrack.OnBinding callback, use a goroutine to avoid deadlock + go track.SetReceiver(rt) + return nil + } + err := r.TrackReceiver.AddDownTrack(track) + r.lock.Unlock() + return err +} + +func (r *simulcastReceiver) RegressTo(receiver sfu.TrackReceiver) { + r.lock.Lock() + r.regressTo = receiver + dts := r.GetDownTracks() + r.lock.Unlock() + + for _, dt := range dts { + dt.SetReceiver(receiver) + } +} + +func (r *simulcastReceiver) IsRegressed() bool { + r.lock.Lock() + defer r.lock.Unlock() + return r.regressTo != nil +} + +// ----------------------------------------------------- + +type MediaTrackReceiverParams struct { + MediaTrack types.MediaTrack + IsRelayed bool + ParticipantID func() livekit.ParticipantID + ParticipantIdentity livekit.ParticipantIdentity + ParticipantVersion uint32 + ReceiverConfig ReceiverConfig + SubscriberConfig DirectionConfig + AudioConfig sfu.AudioConfig + Telemetry telemetry.TelemetryService + Logger logger.Logger + RegressionTargetCodec mime.MimeType + PreferVideoSizeFromMedia bool +} + +type MediaTrackReceiver struct { + params MediaTrackReceiverParams + + lock sync.RWMutex + receivers []*simulcastReceiver + trackInfo atomic.Pointer[livekit.TrackInfo] + potentialCodecs []webrtc.RTPCodecParameters + state mediaTrackReceiverState + isExpectedToResume bool + + onSetupReceiver func(mime mime.MimeType) + onMediaLossFeedback func(dt *sfu.DownTrack, report *rtcp.ReceiverReport) + onClose []func(isExpectedToResume bool) + onCodecRegression func(old, new webrtc.RTPCodecParameters) + + *MediaTrackSubscriptions +} + +func NewMediaTrackReceiver(params MediaTrackReceiverParams, ti *livekit.TrackInfo) *MediaTrackReceiver { + t := &MediaTrackReceiver{ + params: params, + state: mediaTrackReceiverStateOpen, + } + t.trackInfo.Store(utils.CloneProto(ti)) + + t.MediaTrackSubscriptions = NewMediaTrackSubscriptions(MediaTrackSubscriptionsParams{ + MediaTrack: params.MediaTrack, + IsRelayed: params.IsRelayed, + ReceiverConfig: params.ReceiverConfig, + SubscriberConfig: params.SubscriberConfig, + Telemetry: params.Telemetry, + Logger: params.Logger, + }) + t.MediaTrackSubscriptions.OnDownTrackCreated(t.onDownTrackCreated) + return t +} + +func (t *MediaTrackReceiver) Restart() { + for _, receiver := range t.loadReceivers() { + hq := buffer.GetSpatialLayerForVideoQuality(receiver.Mime(), livekit.VideoQuality_HIGH, t.TrackInfo()) + receiver.SetMaxExpectedSpatialLayer(hq) + } +} + +func (t *MediaTrackReceiver) OnSetupReceiver(f func(mime mime.MimeType)) { + t.lock.Lock() + t.onSetupReceiver = f + t.lock.Unlock() +} + +func (t *MediaTrackReceiver) SetupReceiver(receiver sfu.TrackReceiver, priority int, mid string) { + t.lock.Lock() + if t.state != mediaTrackReceiverStateOpen { + t.params.Logger.Warnw("cannot set up receiver on a track not open", nil) + t.lock.Unlock() + return + } + + receivers := slices.Clone(t.receivers) + + // codec position maybe taken by DummyReceiver, check and upgrade to WebRTCReceiver + var existingReceiver bool + for _, r := range receivers { + if r.Mime() == receiver.Mime() { + existingReceiver = true + if d, ok := r.TrackReceiver.(*DummyReceiver); ok { + d.Upgrade(receiver) + } else { + t.params.Logger.Errorw("receiver already exists, setup failed", nil, "mime", receiver.Mime()) + } + break + } + } + if !existingReceiver { + receivers = append(receivers, &simulcastReceiver{TrackReceiver: receiver, priority: priority}) + } + + sort.Slice(receivers, func(i, j int) bool { + return receivers[i].Priority() < receivers[j].Priority() + }) + + if mid != "" { + trackInfo := t.TrackInfoClone() + if priority == 0 { + trackInfo.MimeType = receiver.Mime().String() + trackInfo.Mid = mid + } + + for i, ci := range trackInfo.Codecs { + if i == priority { + ci.MimeType = receiver.Mime().String() + ci.Mid = mid + break + } + } + t.trackInfo.Store(trackInfo) + } + + t.receivers = receivers + onSetupReceiver := t.onSetupReceiver + t.lock.Unlock() + + var receiverCodecs []string + for _, r := range receivers { + receiverCodecs = append(receiverCodecs, r.Mime().String()) + } + t.params.Logger.Debugw( + "setup receiver", + "mime", receiver.Mime(), + "priority", priority, + "receivers", receiverCodecs, + "mid", mid, + ) + + if onSetupReceiver != nil { + onSetupReceiver(receiver.Mime()) + } +} + +func (t *MediaTrackReceiver) HandleReceiverCodecChange(r sfu.TrackReceiver, codec webrtc.RTPCodecParameters, state sfu.ReceiverCodecState) { + // TODO: we only support codec regress to backup codec now, so the receiver will not be available + // once fallback / regression happens. + // We will support codec upgrade in the future then the primary receiver will be available again if + // all subscribers of the track negotiate it. + if state == sfu.ReceiverCodecStateNormal { + return + } + + t.lock.Lock() + // codec regression, find backup codec and switch all downtracks to it + var ( + oldReceiver *simulcastReceiver + backupCodecReceiver sfu.TrackReceiver + ) + for _, receiver := range t.receivers { + if receiver.TrackReceiver == r { + oldReceiver = receiver + continue + } + if d, ok := receiver.TrackReceiver.(*DummyReceiver); ok && d.Receiver() == r { + oldReceiver = receiver + continue + } + + if receiver.Mime() == t.params.RegressionTargetCodec { + backupCodecReceiver = receiver.TrackReceiver + } + + if oldReceiver != nil && backupCodecReceiver != nil { + break + } + } + + if oldReceiver == nil { + // should not happen + t.params.Logger.Errorw("could not find primary receiver for codec", nil, "codec", codec.MimeType) + t.lock.Unlock() + return + } + + if oldReceiver.IsRegressed() { + t.params.Logger.Infow("codec already regressed", "codec", codec.MimeType) + t.lock.Unlock() + return + } + + if backupCodecReceiver == nil { + t.params.Logger.Infow("no backup codec found, can't regress codec") + t.lock.Unlock() + return + } + + t.params.Logger.Infow( + "regressing codec", + "from", codec.MimeType, + "to", backupCodecReceiver.Codec().MimeType, + ) + + // remove old codec from potential codecs + for i, c := range t.potentialCodecs { + if strings.EqualFold(c.MimeType, codec.MimeType) { + t.potentialCodecs = slices.Delete(t.potentialCodecs, i, i+1) + break + } + } + onCodecRegression := t.onCodecRegression + t.lock.Unlock() + oldReceiver.RegressTo(backupCodecReceiver) + + if onCodecRegression != nil { + onCodecRegression(codec, backupCodecReceiver.Codec()) + } +} + +func (t *MediaTrackReceiver) OnCodecRegression(f func(old, new webrtc.RTPCodecParameters)) { + t.onCodecRegression = f +} + +func (t *MediaTrackReceiver) SetPotentialCodecs(codecs []webrtc.RTPCodecParameters, headers []webrtc.RTPHeaderExtensionParameter) { + // The potential codecs have not published yet, so we can't get the actual Extensions, the client/browser uses same extensions + // for all video codecs so we assume they will have same extensions as the primary codec. + t.lock.Lock() + receivers := slices.Clone(t.receivers) + t.potentialCodecs = codecs + for i, c := range codecs { + var exist bool + for _, r := range receivers { + if mime.NormalizeMimeType(c.MimeType) == r.Mime() { + exist = true + break + } + } + if !exist { + receivers = append(receivers, &simulcastReceiver{ + TrackReceiver: NewDummyReceiver(t.ID(), string(t.PublisherID()), c, headers), + priority: i, + }) + } + } + sort.Slice(receivers, func(i, j int) bool { + return receivers[i].Priority() < receivers[j].Priority() + }) + t.receivers = receivers + t.lock.Unlock() +} + +func (t *MediaTrackReceiver) ClearReceiver(mime mime.MimeType, isExpectedToResume bool) { + t.lock.Lock() + receivers := slices.Clone(t.receivers) + for idx, receiver := range receivers { + if receiver.Mime() == mime { + receivers[idx] = receivers[len(receivers)-1] + receivers[len(receivers)-1] = nil + receivers = receivers[:len(receivers)-1] + break + } + } + t.receivers = receivers + t.lock.Unlock() + + t.removeAllSubscribersForMime(mime, isExpectedToResume) +} + +func (t *MediaTrackReceiver) ClearAllReceivers(isExpectedToResume bool) { + t.params.Logger.Debugw("clearing all receivers", "isExpectedToResume", isExpectedToResume) + t.lock.Lock() + receivers := t.receivers + t.receivers = nil + + t.isExpectedToResume = isExpectedToResume + t.lock.Unlock() + + for _, r := range receivers { + t.removeAllSubscribersForMime(r.Mime(), isExpectedToResume) + } +} + +func (t *MediaTrackReceiver) OnMediaLossFeedback(f func(dt *sfu.DownTrack, rr *rtcp.ReceiverReport)) { + t.onMediaLossFeedback = f +} + +func (t *MediaTrackReceiver) IsOpen() bool { + t.lock.RLock() + defer t.lock.RUnlock() + if t.state != mediaTrackReceiverStateOpen { + return false + } + // If any one of the receivers has entered closed state, we would not consider the track open + for _, receiver := range t.receivers { + if receiver.IsClosed() { + return false + } + } + return true +} + +func (t *MediaTrackReceiver) SetClosing(isExpectedToResume bool) { + t.lock.Lock() + defer t.lock.Unlock() + + if t.state == mediaTrackReceiverStateOpen { + t.state = mediaTrackReceiverStateClosing + } + + t.isExpectedToResume = isExpectedToResume +} + +func (t *MediaTrackReceiver) TryClose() bool { + t.lock.RLock() + if t.state == mediaTrackReceiverStateClosed { + t.lock.RUnlock() + return true + } + + numActiveReceivers := 0 + for _, receiver := range t.receivers { + dr, ok := receiver.TrackReceiver.(*DummyReceiver) + if !ok || dr.Receiver() != nil { + // !ok means real receiver OR + // dummy receiver with a regular receiver attached + numActiveReceivers++ + } + } + + isExpectedToResume := t.isExpectedToResume + t.lock.RUnlock() + if numActiveReceivers != 0 { + return false + } + + t.Close(isExpectedToResume) + return true +} + +func (t *MediaTrackReceiver) Close(isExpectedToResume bool) { + t.ClearAllReceivers(isExpectedToResume) + + t.lock.Lock() + if t.state == mediaTrackReceiverStateClosed { + t.lock.Unlock() + return + } + + t.state = mediaTrackReceiverStateClosed + onclose := t.onClose + t.lock.Unlock() + + for _, f := range onclose { + f(isExpectedToResume) + } +} + +func (t *MediaTrackReceiver) ID() livekit.TrackID { + return livekit.TrackID(t.TrackInfo().Sid) +} + +func (t *MediaTrackReceiver) Kind() livekit.TrackType { + return t.TrackInfo().Type +} + +func (t *MediaTrackReceiver) Source() livekit.TrackSource { + return t.TrackInfo().Source +} + +func (t *MediaTrackReceiver) Stream() string { + return t.TrackInfo().Stream +} + +func (t *MediaTrackReceiver) PublisherID() livekit.ParticipantID { + return t.params.ParticipantID() +} + +func (t *MediaTrackReceiver) PublisherIdentity() livekit.ParticipantIdentity { + return t.params.ParticipantIdentity +} + +func (t *MediaTrackReceiver) PublisherVersion() uint32 { + return t.params.ParticipantVersion +} + +func (t *MediaTrackReceiver) Name() string { + return t.TrackInfo().Name +} + +func (t *MediaTrackReceiver) IsMuted() bool { + return t.TrackInfo().Muted +} + +func (t *MediaTrackReceiver) SetMuted(muted bool) { + t.lock.Lock() + trackInfo := t.TrackInfoClone() + trackInfo.Muted = muted + t.trackInfo.Store(trackInfo) + + receivers := t.receivers + t.lock.Unlock() + + for _, receiver := range receivers { + receiver.SetUpTrackPaused(muted) + } + + t.MediaTrackSubscriptions.SetMuted(muted) +} + +func (t *MediaTrackReceiver) IsEncrypted() bool { + return t.TrackInfo().Encryption != livekit.Encryption_NONE +} + +func (t *MediaTrackReceiver) AddOnClose(f func(isExpectedToResume bool)) { + if f == nil { + return + } + + t.lock.Lock() + t.onClose = append(t.onClose, f) + t.lock.Unlock() +} + +// AddSubscriber subscribes sub to current mediaTrack +func (t *MediaTrackReceiver) AddSubscriber(sub types.LocalParticipant) (types.SubscribedTrack, error) { + t.lock.RLock() + if t.state != mediaTrackReceiverStateOpen { + t.lock.RUnlock() + return nil, ErrNotOpen + } + + receivers := t.receivers + potentialCodecs := make([]webrtc.RTPCodecParameters, len(t.potentialCodecs)) + copy(potentialCodecs, t.potentialCodecs) + t.lock.RUnlock() + + if len(receivers) == 0 { + // cannot add, no receiver + return nil, ErrNoReceiver + } + + for _, receiver := range receivers { + if receiver.IsRegressed() { + continue + } + + codec := receiver.Codec() + var found bool + for _, pc := range potentialCodecs { + if mime.IsMimeTypeStringEqual(codec.MimeType, pc.MimeType) { + found = true + break + } + } + if !found { + potentialCodecs = append(potentialCodecs, codec) + } + } + + streamId := string(t.PublisherID()) + if sub.ProtocolVersion().SupportsPackedStreamId() { + // when possible, pack both IDs in streamID to allow new streams to be generated + // react-native-webrtc still uses stream based APIs and require this + streamId = PackStreamID(t.PublisherID(), t.ID()) + } + + tLogger := LoggerWithTrack(sub.GetLogger(), t.ID(), t.params.IsRelayed) + wr := NewWrappedReceiver(WrappedReceiverParams{ + Receivers: receivers, + TrackID: t.ID(), + StreamId: streamId, + UpstreamCodecs: potentialCodecs, + Logger: tLogger, + DisableRed: !IsRedEnabled(t.TrackInfo()) || !t.params.AudioConfig.ActiveREDEncoding, + IsEncrypted: t.IsEncrypted(), + }) + subID := sub.ID() + subTrack, err := t.MediaTrackSubscriptions.AddSubscriber(sub, wr) + + // media track could have been closed while adding subscription + remove := false + isExpectedToResume := false + t.lock.RLock() + if t.state != mediaTrackReceiverStateOpen { + isExpectedToResume = t.isExpectedToResume + remove = true + } + t.lock.RUnlock() + + if remove { + t.params.Logger.Debugw( + "removing subscriber on a not-open track", + "subscriberID", subID, + "isExpectedToResume", isExpectedToResume, + ) + _ = t.MediaTrackSubscriptions.RemoveSubscriber(subID, isExpectedToResume) + return nil, ErrNotOpen + } + + return subTrack, err +} + +// RemoveSubscriber removes participant from subscription +// stop all forwarders to the client +func (t *MediaTrackReceiver) RemoveSubscriber(subscriberID livekit.ParticipantID, isExpectedToResume bool) { + _ = t.MediaTrackSubscriptions.RemoveSubscriber(subscriberID, isExpectedToResume) +} + +func (t *MediaTrackReceiver) removeAllSubscribersForMime(mime mime.MimeType, isExpectedToResume bool) { + t.params.Logger.Debugw("removing all subscribers for mime", "mime", mime, "isExpectedToResume", isExpectedToResume) + for _, subscriberID := range t.MediaTrackSubscriptions.GetAllSubscribersForMime(mime) { + t.RemoveSubscriber(subscriberID, isExpectedToResume) + } +} + +func (t *MediaTrackReceiver) RevokeDisallowedSubscribers(allowedSubscriberIdentities []livekit.ParticipantIdentity) []livekit.ParticipantIdentity { + var revokedSubscriberIdentities []livekit.ParticipantIdentity + + // LK-TODO: large number of subscribers needs to be solved for this loop + for _, subTrack := range t.MediaTrackSubscriptions.getAllSubscribedTracks() { + if IsParticipantExemptFromTrackPermissionsRestrictions(subTrack.Subscriber()) { + continue + } + + found := slices.Contains(allowedSubscriberIdentities, subTrack.SubscriberIdentity()) + if !found { + t.params.Logger.Infow("revoking subscription", + "subscriber", subTrack.SubscriberIdentity(), + "subscriberID", subTrack.SubscriberID(), + ) + t.RemoveSubscriber(subTrack.SubscriberID(), false) + revokedSubscriberIdentities = append(revokedSubscriberIdentities, subTrack.SubscriberIdentity()) + } + } + + return revokedSubscriberIdentities +} + +func (t *MediaTrackReceiver) updateTrackInfoOfReceivers() { + ti := t.TrackInfo() + for _, r := range t.loadReceivers() { + r.UpdateTrackInfo(ti) + } +} + +func (t *MediaTrackReceiver) SetLayerSsrcsForRid(mimeType mime.MimeType, rid string, ssrc uint32, repairSSRC uint32) { + t.lock.Lock() + trackInfo := t.TrackInfoClone() + layer := buffer.GetSpatialLayerForRid(mimeType, rid, trackInfo) + if layer == buffer.InvalidLayerSpatial { + // non-simulcast case will not have `rid` + layer = 0 + } + quality := buffer.GetVideoQualityForSpatialLayer(mimeType, layer, trackInfo) + // set video layer ssrc info + for i, ci := range trackInfo.Codecs { + if mime.NormalizeMimeType(ci.MimeType) != mimeType { + continue + } + + // if origin layer has ssrc, don't override it + var matchingLayer *livekit.VideoLayer + ssrcFound := false + for _, l := range ci.Layers { + if l.Quality == quality { + matchingLayer = l + if l.Ssrc != 0 { + ssrcFound = true + } + break + } + } + if !ssrcFound && matchingLayer != nil { + matchingLayer.Ssrc = ssrc + if repairSSRC != 0 { + matchingLayer.RepairSsrc = repairSSRC + } + } + if ssrcFound { + t.params.Logger.Warnw( + "not overriding ssrc", nil, + "rid", rid, + "ssrc", ssrc, + "existingSSRC", matchingLayer.Ssrc, + "repairSSRC", repairSSRC, + "existingRepairSSRC", matchingLayer.RepairSsrc, + "trackInfo", trackInfo, + ) + } + + // for client don't use simulcast codecs (old client version or single codec) + if i == 0 { + trackInfo.Layers = ci.Layers + } + break + } + t.trackInfo.Store(trackInfo) + t.lock.Unlock() + + t.updateTrackInfoOfReceivers() +} + +func (t *MediaTrackReceiver) setLayerRtxInfo(ssrc uint32, repairSSRC uint32, rsid string) { + t.params.Logger.Debugw("rtx notification", "ssrc", ssrc, "repairSSRC", repairSSRC, "rsid", rsid) + if ssrc == 0 || repairSSRC == 0 || rsid == "" { + return + } + + t.lock.Lock() + trackInfo := t.TrackInfoClone() + +done: + for _, ci := range trackInfo.Codecs { + for _, l := range ci.Layers { + if l.Ssrc == ssrc { + if (l.RepairSsrc != 0 && l.RepairSsrc != repairSSRC) || (l.Rid != "" && l.Rid != rsid) { + t.params.Logger.Warnw( + "not overriding rtx info", nil, + "ssrc", ssrc, + "repairSSRC", repairSSRC, + "existingRepairSSRC", l.RepairSsrc, + "rsid", rsid, + "existingRid", l.Rid, + "trackInfo", logger.Proto(trackInfo), + ) + } else { + l.RepairSsrc = repairSSRC + t.params.Logger.Debugw( + "set rtx info", + "ssrc", ssrc, + "repairSSRC", repairSSRC, + "rsid", rsid, + "trackInfo", logger.Proto(trackInfo), + ) + } + break done + } + } + } + + // backwards compatibility + for _, l := range trackInfo.Layers { + if l.Ssrc == ssrc { + if (l.RepairSsrc != 0 && l.RepairSsrc != repairSSRC) || (l.Rid != "" && l.Rid != rsid) { + t.params.Logger.Warnw( + "not overriding rtx info", nil, + "ssrc", ssrc, + "repairSSRC", repairSSRC, + "existingRepairSSRC", l.RepairSsrc, + "rsid", rsid, + "existingRid", l.Rid, + "trackInfo", logger.Proto(trackInfo), + ) + } else { + l.RepairSsrc = repairSSRC + t.params.Logger.Debugw( + "set rtx info", + "ssrc", ssrc, + "repairSSRC", repairSSRC, + "rsid", rsid, + "trackInfo", logger.Proto(trackInfo), + ) + } + break + } + } + + t.trackInfo.Store(trackInfo) + t.lock.Unlock() + + // change not propagated as it is internal +} + +func (t *MediaTrackReceiver) UpdateCodecInfo(codecs []*livekit.SimulcastCodec) { + t.lock.Lock() + trackInfo := t.TrackInfoClone() + for _, c := range codecs { + for _, origin := range trackInfo.Codecs { + if mime.GetMimeTypeCodec(origin.MimeType) == mime.NormalizeMimeTypeCodec(c.Codec) { + origin.Cid = c.Cid + + if len(c.Layers) != 0 { + clonedLayers := make([]*livekit.VideoLayer, 0, len(c.Layers)) + for _, l := range c.Layers { + clonedLayers = append(clonedLayers, utils.CloneProto(l)) + } + origin.Layers = clonedLayers + + mimeType := mime.NormalizeMimeType(origin.MimeType) + for _, layer := range origin.Layers { + layer.SpatialLayer = buffer.VideoQualityToSpatialLayer(mimeType, layer.Quality, trackInfo) + layer.Rid = buffer.VideoQualityToRid(mimeType, layer.Quality, trackInfo, buffer.DefaultVideoLayersRid) + } + } + + break + } + } + } + t.trackInfo.Store(trackInfo) + t.lock.Unlock() + + t.updateTrackInfoOfReceivers() +} + +func (t *MediaTrackReceiver) UpdateCodecSdpCid(mimeType mime.MimeType, sdpCid string) { + t.lock.Lock() + trackInfo := t.TrackInfoClone() + for _, origin := range trackInfo.Codecs { + if mime.NormalizeMimeType(origin.MimeType) == mimeType { + if sdpCid != origin.Cid { + origin.SdpCid = sdpCid + } + } + } + t.trackInfo.Store(trackInfo) + t.lock.Unlock() + + t.updateTrackInfoOfReceivers() +} + +func (t *MediaTrackReceiver) UpdateCodecRids(mimeType mime.MimeType, rids buffer.VideoLayersRid) { + t.lock.Lock() + trackInfo := t.TrackInfoClone() + for _, origin := range trackInfo.Codecs { + originMimeType := mime.NormalizeMimeType(origin.MimeType) + if originMimeType != mimeType { + continue + } + + for _, layer := range origin.Layers { + layer.SpatialLayer = buffer.VideoQualityToSpatialLayer(mimeType, layer.Quality, trackInfo) + layer.Rid = buffer.VideoQualityToRid(mimeType, layer.Quality, trackInfo, rids) + } + break + } + t.trackInfo.Store(trackInfo) + t.lock.Unlock() + + t.updateTrackInfoOfReceivers() +} + +func (t *MediaTrackReceiver) UpdateTrackInfo(ti *livekit.TrackInfo) { + updateMute := false + clonedInfo := utils.CloneProto(ti) + + t.lock.Lock() + trackInfo := t.TrackInfo() + // patch Mid/Rid and Ssrc/RtxSsrc of codecs/layers by keeping original if available + for i, ci := range clonedInfo.Codecs { + for _, originCi := range trackInfo.Codecs { + if !mime.IsMimeTypeStringEqual(ci.MimeType, originCi.MimeType) { + continue + } + + if originCi.Mid != "" { + ci.Mid = originCi.Mid + } + + for _, layer := range ci.Layers { + for _, originLayer := range originCi.Layers { + if layer.Quality == originLayer.Quality { + if originLayer.Ssrc != 0 { + layer.Ssrc = originLayer.Ssrc + } + if originLayer.Rid != "" { + layer.Rid = originLayer.Rid + } + + if originLayer.RepairSsrc != 0 { + layer.RepairSsrc = originLayer.RepairSsrc + } + break + } + } + } + break + } + + // for clients that don't use simulcast codecs (old client version or single codec) + if i == 0 { + clonedInfo.Layers = ci.Layers + } + } + if trackInfo.Muted != clonedInfo.Muted { + updateMute = true + } + t.trackInfo.Store(clonedInfo) + t.lock.Unlock() + + if updateMute { + t.SetMuted(clonedInfo.Muted) + } + + t.updateTrackInfoOfReceivers() +} + +func (t *MediaTrackReceiver) UpdateAudioTrack(update *livekit.UpdateLocalAudioTrack) { + if t.Kind() != livekit.TrackType_AUDIO { + return + } + + t.lock.Lock() + trackInfo := t.TrackInfo() + clonedInfo := utils.CloneProto(trackInfo) + + clonedInfo.AudioFeatures = sutils.DedupeSlice(update.Features) + + clonedInfo.Stereo = false + clonedInfo.DisableDtx = false + for _, feature := range update.Features { + switch feature { + case livekit.AudioTrackFeature_TF_STEREO: + clonedInfo.Stereo = true + case livekit.AudioTrackFeature_TF_NO_DTX: + clonedInfo.DisableDtx = true + } + } + + if proto.Equal(trackInfo, clonedInfo) { + t.lock.Unlock() + return + } + + t.trackInfo.Store(clonedInfo) + t.lock.Unlock() + + t.updateTrackInfoOfReceivers() + + t.params.Telemetry.TrackPublishedUpdate(context.Background(), t.PublisherID(), clonedInfo) + t.params.Logger.Debugw("updated audio track", "before", logger.Proto(trackInfo), "after", logger.Proto(clonedInfo)) +} + +func (t *MediaTrackReceiver) UpdateVideoTrack(update *livekit.UpdateLocalVideoTrack) { + if t.Kind() != livekit.TrackType_VIDEO { + return + } + + t.lock.Lock() + trackInfo := t.TrackInfo() + clonedInfo := utils.CloneProto(trackInfo) + clonedInfo.Width = update.Width + clonedInfo.Height = update.Height + if proto.Equal(trackInfo, clonedInfo) { + t.lock.Unlock() + return + } + + t.trackInfo.Store(clonedInfo) + t.lock.Unlock() + + t.updateTrackInfoOfReceivers() + + t.params.Telemetry.TrackPublishedUpdate(context.Background(), t.PublisherID(), clonedInfo) + t.params.Logger.Debugw("updated video track", "before", logger.Proto(trackInfo), "after", logger.Proto(clonedInfo)) +} + +func (t *MediaTrackReceiver) UpdateVideoSize(mimeType mime.MimeType, sizes []buffer.VideoSize) { + var changed bool + t.lock.Lock() + trackInfo := t.TrackInfo() + clonedInfo := utils.CloneProto(trackInfo) + var maxWidth, maxHeight uint32 + for _, size := range sizes { + if size.Width > maxWidth { + maxWidth = size.Width + maxHeight = size.Height + } + } + + if clonedInfo.Width != maxWidth || clonedInfo.Height != maxHeight { + clonedInfo.Width = maxWidth + clonedInfo.Height = maxHeight + changed = true + } + + for _, c := range clonedInfo.Codecs { + if mime.NormalizeMimeType(c.MimeType) == mimeType { + for i, l := range c.Layers { + if i < len(sizes) && (sizes[i].Width != 0 || sizes[i].Height != 0) && + (l.Width != sizes[i].Width || l.Height != sizes[i].Height) { + l.Width = sizes[i].Width + l.Height = sizes[i].Height + changed = true + } + } + } + } + + if !changed { + t.lock.Unlock() + return + } + + t.trackInfo.Store(clonedInfo) + t.lock.Unlock() + + t.updateTrackInfoOfReceivers() + + t.params.Telemetry.TrackPublishedUpdate(context.Background(), t.PublisherID(), clonedInfo) + t.params.Logger.Debugw("updated video sizes", "before", logger.Proto(trackInfo), "after", logger.Proto(clonedInfo)) +} + +func (t *MediaTrackReceiver) TrackInfo() *livekit.TrackInfo { + return t.trackInfo.Load() +} + +func (t *MediaTrackReceiver) TrackInfoClone() *livekit.TrackInfo { + return utils.CloneProto(t.TrackInfo()) +} + +func (t *MediaTrackReceiver) NotifyMaxLayerChange(mimeType mime.MimeType, maxLayer int32) { + trackInfo := t.TrackInfo() + quality := buffer.GetVideoQualityForSpatialLayer(mimeType, maxLayer, trackInfo) + ti := &livekit.TrackInfo{ + Sid: trackInfo.Sid, + Type: trackInfo.Type, + Layers: []*livekit.VideoLayer{{Quality: quality}}, + } + if quality != livekit.VideoQuality_OFF { + layers := buffer.GetVideoLayersForMimeType(mimeType, trackInfo) + for _, layer := range layers { + if layer.Quality == quality { + ti.Layers[0].Width = layer.Width + ti.Layers[0].Height = layer.Height + break + } + } + } + + t.params.Telemetry.TrackPublishedUpdate(context.Background(), t.PublisherID(), ti) +} + +// GetQualityForDimension finds the closest quality to use for desired dimensions +// affords a 20% tolerance on dimension +func (t *MediaTrackReceiver) GetQualityForDimension(mimeType mime.MimeType, width, height uint32) livekit.VideoQuality { + quality := livekit.VideoQuality_HIGH + if t.Kind() == livekit.TrackType_AUDIO { + return quality + } + + trackInfo := t.TrackInfo() + + var mediaSizes []buffer.VideoSize + if receiver := t.Receiver(mimeType); receiver != nil { + mediaSizes = receiver.VideoSizes() + } + + if trackInfo.Height == 0 && len(mediaSizes) == 0 { + return quality + } + origSize := trackInfo.Height + requestedSize := height + if trackInfo.Width < trackInfo.Height { + // for portrait videos + origSize = trackInfo.Width + requestedSize = width + } + + if origSize == 0 { + for i := len(mediaSizes) - 1; i >= 0; i-- { + if mediaSizes[i].Height > 0 { + origSize = mediaSizes[i].Height + if mediaSizes[i].Width < mediaSizes[i].Height { + origSize = mediaSizes[i].Width + } + break + } + } + } + + // default sizes representing qualities low - high + layerSizes := []uint32{180, 360, origSize} + var providedSizes []uint32 + for _, layer := range buffer.GetVideoLayersForMimeType(mimeType, trackInfo) { + providedSizes = append(providedSizes, layer.Height) + } + + if len(providedSizes) == 0 || providedSizes[0] == 0 || t.params.PreferVideoSizeFromMedia { + if len(mediaSizes) > 0 { + providedSizes = providedSizes[:0] + for _, size := range mediaSizes { + providedSizes = append(providedSizes, size.Height) + } + } else { + t.params.Logger.Debugw("no video sizes provided by receiver, using track info sizes") + } + } + + if len(providedSizes) > 0 { + layerSizes = providedSizes + // comparing height always + requestedSize = height + sort.Slice(layerSizes, func(i, j int) bool { + return layerSizes[i] < layerSizes[j] + }) + } + + // finds the highest layer with smallest dimensions that still satisfy client demands + requestedSize = uint32(float32(requestedSize) * layerSelectionTolerance) + for i, s := range layerSizes { + quality = livekit.VideoQuality(i) + if i == len(layerSizes)-1 { + break + } else if s >= requestedSize && s != layerSizes[i+1] { + break + } + } + + return quality +} + +func (t *MediaTrackReceiver) GetAudioLevel() (float64, bool) { + receiver := t.ActiveReceiver() + if receiver == nil { + return 0, false + } + + return receiver.GetAudioLevel() +} + +func (t *MediaTrackReceiver) onDownTrackCreated(downTrack *sfu.DownTrack) { + if t.Kind() == livekit.TrackType_AUDIO { + downTrack.AddReceiverReportListener(func(dt *sfu.DownTrack, rr *rtcp.ReceiverReport) { + if t.onMediaLossFeedback != nil { + t.onMediaLossFeedback(dt, rr) + } + }) + } +} + +func (t *MediaTrackReceiver) DebugInfo() map[string]any { + info := map[string]any{ + "ID": t.ID(), + "Kind": t.Kind().String(), + "PubMuted": t.IsMuted(), + } + + info["DownTracks"] = t.MediaTrackSubscriptions.DebugInfo() + + for _, receiver := range t.loadReceivers() { + info[receiver.Codec().MimeType] = receiver.DebugInfo() + } + + return info +} + +func (t *MediaTrackReceiver) PrimaryReceiver() sfu.TrackReceiver { + receivers := t.loadReceivers() + if len(receivers) == 0 { + return nil + } + if dr, ok := receivers[0].TrackReceiver.(*DummyReceiver); ok { + return dr.Receiver() + } + return receivers[0].TrackReceiver +} + +func (t *MediaTrackReceiver) ActiveReceiver() sfu.TrackReceiver { + for _, r := range t.loadReceivers() { + if r.IsRegressed() { + return r.TrackReceiver + } + } + + return t.PrimaryReceiver() +} + +func (t *MediaTrackReceiver) Receiver(mime mime.MimeType) sfu.TrackReceiver { + for _, r := range t.loadReceivers() { + if r.Mime() == mime { + if dr, ok := r.TrackReceiver.(*DummyReceiver); ok { + return dr.Receiver() + } + return r.TrackReceiver + } + } + return nil +} + +func (t *MediaTrackReceiver) Receivers() []sfu.TrackReceiver { + receivers := t.loadReceivers() + trackReceivers := make([]sfu.TrackReceiver, len(receivers)) + for i, r := range receivers { + trackReceivers[i] = r.TrackReceiver + } + return trackReceivers +} + +func (t *MediaTrackReceiver) loadReceivers() []*simulcastReceiver { + t.lock.RLock() + defer t.lock.RUnlock() + return t.receivers +} + +func (t *MediaTrackReceiver) SetRTT(rtt uint32) { + for _, r := range t.loadReceivers() { + if wr, ok := r.TrackReceiver.(*sfu.WebRTCReceiver); ok { + wr.SetRTT(rtt) + } + } +} + +func (t *MediaTrackReceiver) GetTemporalLayerForSpatialFps(mimeType mime.MimeType, spatial int32, fps uint32) int32 { + receiver := t.Receiver(mimeType) + if receiver == nil { + return buffer.DefaultMaxLayerTemporal + } + + layerFps := receiver.GetTemporalLayerFpsForSpatial(spatial) + requestFps := float32(fps) * layerSelectionTolerance + for i, f := range layerFps { + if requestFps <= f { + return int32(i) + } + } + return buffer.DefaultMaxLayerTemporal +} + +func (t *MediaTrackReceiver) GetTrackStats() *livekit.RTPStats { + receivers := t.loadReceivers() + stats := make([]*livekit.RTPStats, 0, len(receivers)) + for _, receiver := range receivers { + receiverStats := receiver.GetTrackStats() + if receiverStats != nil { + stats = append(stats, receiverStats) + } + } + + return rtpstats.AggregateRTPStats(stats) +} diff --git a/livekit/pkg/rtc/mediatracksubscriptions.go b/livekit/pkg/rtc/mediatracksubscriptions.go new file mode 100644 index 0000000..fc529b0 --- /dev/null +++ b/livekit/pkg/rtc/mediatracksubscriptions.go @@ -0,0 +1,367 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rtc + +import ( + "errors" + "slices" + "sync" + + "github.com/livekit/livekit-server/pkg/sfu/mime" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/pion/webrtc/v4" + "go.uber.org/atomic" + + "github.com/livekit/livekit-server/pkg/rtc/types" + "github.com/livekit/livekit-server/pkg/sfu" + "github.com/livekit/livekit-server/pkg/telemetry" +) + +var ( + errAlreadySubscribed = errors.New("already subscribed") + errNotFound = errors.New("not found") +) + +// MediaTrackSubscriptions manages subscriptions of a media track +type MediaTrackSubscriptions struct { + params MediaTrackSubscriptionsParams + + subscribedTracksMu sync.RWMutex + subscribedTracks map[livekit.ParticipantID]types.SubscribedTrack + + onDownTrackCreated func(downTrack *sfu.DownTrack) + onSubscriberMaxQualityChange func(subscriberID livekit.ParticipantID, mime mime.MimeType, layer int32) + onSubscriberAudioCodecChange func(subscriberID livekit.ParticipantID, mime mime.MimeType, enabled bool) +} + +type MediaTrackSubscriptionsParams struct { + MediaTrack types.MediaTrack + IsRelayed bool + + ReceiverConfig ReceiverConfig + SubscriberConfig DirectionConfig + + Telemetry telemetry.TelemetryService + + Logger logger.Logger +} + +func NewMediaTrackSubscriptions(params MediaTrackSubscriptionsParams) *MediaTrackSubscriptions { + return &MediaTrackSubscriptions{ + params: params, + subscribedTracks: make(map[livekit.ParticipantID]types.SubscribedTrack), + } +} + +func (t *MediaTrackSubscriptions) OnDownTrackCreated(f func(downTrack *sfu.DownTrack)) { + t.onDownTrackCreated = f +} + +func (t *MediaTrackSubscriptions) OnSubscriberMaxQualityChange(f func(subscriberID livekit.ParticipantID, mime mime.MimeType, layer int32)) { + t.onSubscriberMaxQualityChange = f +} + +func (t *MediaTrackSubscriptions) OnSubscriberAudioCodecChange(f func(subscriberID livekit.ParticipantID, mime mime.MimeType, enabled bool)) { + t.onSubscriberAudioCodecChange = f +} + +func (t *MediaTrackSubscriptions) SetMuted(muted bool) { + // update mute of all subscribed tracks + for _, st := range t.getAllSubscribedTracks() { + st.SetPublisherMuted(muted) + } +} + +func (t *MediaTrackSubscriptions) IsSubscriber(subID livekit.ParticipantID) bool { + t.subscribedTracksMu.RLock() + defer t.subscribedTracksMu.RUnlock() + + _, ok := t.subscribedTracks[subID] + return ok +} + +// AddSubscriber subscribes sub to current mediaTrack +func (t *MediaTrackSubscriptions) AddSubscriber(sub types.LocalParticipant, wr *WrappedReceiver) (types.SubscribedTrack, error) { + trackID := t.params.MediaTrack.ID() + subscriberID := sub.ID() + + // don't subscribe to the same track multiple times + t.subscribedTracksMu.Lock() + if _, ok := t.subscribedTracks[subscriberID]; ok { + t.subscribedTracksMu.Unlock() + return nil, errAlreadySubscribed + } + t.subscribedTracksMu.Unlock() + + subTrack, err := NewSubscribedTrack(SubscribedTrackParams{ + ReceiverConfig: t.params.ReceiverConfig, + SubscriberConfig: t.params.SubscriberConfig, + Subscriber: sub, + MediaTrack: t.params.MediaTrack, + AdaptiveStream: sub.GetAdaptiveStream(), + Telemetry: t.params.Telemetry, + WrappedReceiver: wr, + IsRelayed: t.params.IsRelayed, + OnDownTrackCreated: t.onDownTrackCreated, + OnDownTrackClosed: func(subscriberID livekit.ParticipantID) { + t.subscribedTracksMu.Lock() + delete(t.subscribedTracks, subscriberID) + t.subscribedTracksMu.Unlock() + }, + OnSubscriberMaxQualityChange: t.onSubscriberMaxQualityChange, + OnSubscriberAudioCodecChange: t.onSubscriberAudioCodecChange, + }) + if err != nil { + return nil, err + } + + // Bind callback can happen from replaceTrack, so set it up early + var reusingTransceiver atomic.Bool + var dtState sfu.DownTrackState + downTrack := subTrack.DownTrack() + downTrack.OnBinding(func(err error) { + if err != nil { + go subTrack.Bound(err) + return + } + if reusingTransceiver.Load() { + sub.GetLogger().Debugw("seeding downtrack state", "trackID", trackID) + downTrack.SeedState(dtState) + } + if err = wr.AddDownTrack(downTrack); err != nil && err != sfu.ErrReceiverClosed { + sub.GetLogger().Errorw( + "could not add down track", err, + "publisher", subTrack.PublisherIdentity(), + "publisherID", subTrack.PublisherID(), + "trackID", trackID, + ) + } + + go subTrack.Bound(nil) + + subTrack.SetPublisherMuted(t.params.MediaTrack.IsMuted()) + }) + + var transceiver *webrtc.RTPTransceiver + var sender *webrtc.RTPSender + + // try cached RTP senders for a chance to replace track + var existingTransceiver *webrtc.RTPTransceiver + replacedTrack := false + existingTransceiver, dtState = sub.GetCachedDownTrack(trackID) + if existingTransceiver != nil { + sub.GetLogger().Debugw( + "trying to use existing transceiver", + "publisher", subTrack.PublisherIdentity(), + "publisherID", subTrack.PublisherID(), + "trackID", trackID, + ) + reusingTransceiver.Store(true) + rtpSender := existingTransceiver.Sender() + if rtpSender != nil { + // replaced track will bind immediately without negotiation, SetTransceiver first before bind + downTrack.SetTransceiver(existingTransceiver) + err := rtpSender.ReplaceTrack(downTrack) + if err == nil { + sender = rtpSender + transceiver = existingTransceiver + replacedTrack = true + sub.GetLogger().Debugw( + "track replaced", + "publisher", subTrack.PublisherIdentity(), + "publisherID", subTrack.PublisherID(), + "trackID", trackID, + ) + } + } + + if !replacedTrack { + // Could not re-use cached transceiver for this track. + // Stop the transceiver so that it is at least not active. + // It is not usable once stopped, + // + // Adding down track will create a new transceiver (or re-use + // an inactive existing one). In either case, a renegotiation + // will happen and that will notify remote of this stopped + // transceiver + existingTransceiver.Stop() + reusingTransceiver.Store(false) + } + } + + // if cannot replace, find an unused transceiver or add new one + if transceiver == nil { + info := t.params.MediaTrack.ToProto() + addTrackParams := types.AddTrackParams{ + Stereo: slices.Contains(info.AudioFeatures, livekit.AudioTrackFeature_TF_STEREO), + Red: IsRedEnabled(info), + } + codecs := wr.Codecs() + if addTrackParams.Red && (len(codecs) == 1 && mime.IsMimeTypeStringOpus(codecs[0].MimeType)) { + addTrackParams.Red = false + } + + sub.VerifySubscribeParticipantInfo(subTrack.PublisherID(), subTrack.PublisherVersion()) + if sub.SupportsTransceiverReuse() { + // + // AddTrack will create a new transceiver or re-use an unused one + // if the attributes match. This prevents SDP from bloating + // because of dormant transceivers building up. + // + sender, transceiver, err = sub.AddTrackLocal(downTrack, addTrackParams) + if err != nil { + return nil, err + } + } else { + sender, transceiver, err = sub.AddTransceiverFromTrackLocal(downTrack, addTrackParams) + if err != nil { + return nil, err + } + } + } + + // whether re-using or stopping remove transceiver from cache + // NOTE: safety net, if somehow a cached transceiver is re-used by a different track + sub.UncacheDownTrack(transceiver) + + // negotiation isn't required if we've replaced track + // ONE-SHOT-SIGNALLING-MODE: this should not be needed, but that mode information is not available here, + // but it is not detrimental to set this, needs clean up when participants modes are separated out better. + subTrack.SetNeedsNegotiation(!replacedTrack) + subTrack.SetRTPSender(sender) + + // it is possible that subscribed track is closed before subscription manager sets + // the `OnClose` callback. That handler in subscription manager removes the track + // from the peer connection. + // + // But, the subscription could be removed early if the published track is closed + // while adding subscription. In those cases, subscription manager would not have set + // the `OnClose` callback. So, set it here to handle cases of early close. + subTrack.OnClose(func(isExpectedToResume bool) { + if !isExpectedToResume { + if err := sub.RemoveTrackLocal(sender); err != nil { + t.params.Logger.Warnw("could not remove track from peer connection", err) + } + } + }) + + downTrack.SetTransceiver(transceiver) + + t.subscribedTracksMu.Lock() + t.subscribedTracks[subscriberID] = subTrack + t.subscribedTracksMu.Unlock() + + return subTrack, nil +} + +// RemoveSubscriber removes participant from subscription +// stop all forwarders to the client +func (t *MediaTrackSubscriptions) RemoveSubscriber(subscriberID livekit.ParticipantID, isExpectedToResume bool) error { + subTrack := t.getSubscribedTrack(subscriberID) + if subTrack == nil { + return errNotFound + } + + t.params.Logger.Debugw("removing subscriber", "subscriberID", subscriberID, "isExpectedToResume", isExpectedToResume) + t.closeSubscribedTrack(subTrack, isExpectedToResume) + return nil +} + +func (t *MediaTrackSubscriptions) closeSubscribedTrack(subTrack types.SubscribedTrack, isExpectedToResume bool) { + dt := subTrack.DownTrack() + if dt == nil { + return + } + + if isExpectedToResume { + dt.CloseWithFlush(false, false) + } else { + // flushing blocks, avoid blocking when publisher removes all its subscribers + go dt.CloseWithFlush(true, true) + } +} + +func (t *MediaTrackSubscriptions) GetAllSubscribers() []livekit.ParticipantID { + t.subscribedTracksMu.RLock() + defer t.subscribedTracksMu.RUnlock() + + subs := make([]livekit.ParticipantID, 0, len(t.subscribedTracks)) + for id := range t.subscribedTracks { + subs = append(subs, id) + } + return subs +} + +func (t *MediaTrackSubscriptions) GetAllSubscribersForMime(mime mime.MimeType) []livekit.ParticipantID { + t.subscribedTracksMu.RLock() + defer t.subscribedTracksMu.RUnlock() + + subs := make([]livekit.ParticipantID, 0, len(t.subscribedTracks)) + for id, subTrack := range t.subscribedTracks { + if subTrack.DownTrack().Mime() != mime { + continue + } + + subs = append(subs, id) + } + return subs +} + +func (t *MediaTrackSubscriptions) GetNumSubscribers() int { + t.subscribedTracksMu.RLock() + defer t.subscribedTracksMu.RUnlock() + + return len(t.subscribedTracks) +} + +func (t *MediaTrackSubscriptions) UpdateVideoLayers() { + for _, st := range t.getAllSubscribedTracks() { + st.UpdateVideoLayer() + } +} + +func (t *MediaTrackSubscriptions) getSubscribedTrack(subscriberID livekit.ParticipantID) types.SubscribedTrack { + t.subscribedTracksMu.RLock() + defer t.subscribedTracksMu.RUnlock() + + return t.subscribedTracks[subscriberID] +} + +func (t *MediaTrackSubscriptions) getAllSubscribedTracks() []types.SubscribedTrack { + t.subscribedTracksMu.RLock() + defer t.subscribedTracksMu.RUnlock() + + return t.getAllSubscribedTracksLocked() +} + +func (t *MediaTrackSubscriptions) getAllSubscribedTracksLocked() []types.SubscribedTrack { + subTracks := make([]types.SubscribedTrack, 0, len(t.subscribedTracks)) + for _, subTrack := range t.subscribedTracks { + subTracks = append(subTracks, subTrack) + } + return subTracks +} + +func (t *MediaTrackSubscriptions) DebugInfo() []map[string]any { + subscribedTrackInfo := make([]map[string]any, 0) + for _, val := range t.getAllSubscribedTracks() { + if st, ok := val.(*SubscribedTrack); ok { + subscribedTrackInfo = append(subscribedTrackInfo, st.DownTrack().DebugInfo()) + } + } + + return subscribedTrackInfo +} diff --git a/livekit/pkg/rtc/migrationdatacache.go b/livekit/pkg/rtc/migrationdatacache.go new file mode 100644 index 0000000..8c44eac --- /dev/null +++ b/livekit/pkg/rtc/migrationdatacache.go @@ -0,0 +1,59 @@ +package rtc + +import ( + "time" + + "github.com/livekit/protocol/livekit" +) + +type MigrationDataCacheState int + +const ( + MigrationDataCacheStateWaiting MigrationDataCacheState = iota + MigrationDataCacheStateTimeout + MigrationDataCacheStateDone +) + +type MigrationDataCache struct { + lastSeq uint32 + pkts []*livekit.DataPacket + state MigrationDataCacheState + expiredAt time.Time +} + +func NewMigrationDataCache(lastSeq uint32, expiredAt time.Time) *MigrationDataCache { + return &MigrationDataCache{ + lastSeq: lastSeq, + expiredAt: expiredAt, + } +} + +// Add adds a message to the cache if there is a gap between the last sequence number and cached messages then return the cache State: +// - MigrationDataCacheStateWaiting: waiting for the next packet (lastSeq + 1) of last sequence from old node +// - MigrationDataCacheStateTimeout: the next packet is not received before the expiredAt, participant will +// continue to process the reliable messages, subscribers will see the gap after the publisher migration +// - MigrationDataCacheStateDone: the next packet is received, participant can continue to process the reliable messages +func (c *MigrationDataCache) Add(pkt *livekit.DataPacket) MigrationDataCacheState { + if c.state == MigrationDataCacheStateDone || c.state == MigrationDataCacheStateTimeout { + return c.state + } + + if pkt.Sequence <= c.lastSeq { + return c.state + } + + if pkt.Sequence == c.lastSeq+1 { + c.state = MigrationDataCacheStateDone + return c.state + } + + c.pkts = append(c.pkts, pkt) + if time.Now().After(c.expiredAt) { + c.state = MigrationDataCacheStateTimeout + } + return c.state +} + +func (c *MigrationDataCache) Get() []*livekit.DataPacket { + return c.pkts +} diff --git a/livekit/pkg/rtc/migrationdatacache_test.go b/livekit/pkg/rtc/migrationdatacache_test.go new file mode 100644 index 0000000..046b2dd --- /dev/null +++ b/livekit/pkg/rtc/migrationdatacache_test.go @@ -0,0 +1,38 @@ +package rtc + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/livekit/protocol/livekit" +) + +func TestMigrationDataCache_Add(t *testing.T) { + expiredAt := time.Now().Add(100 * time.Millisecond) + cache := NewMigrationDataCache(10, expiredAt) + + pkt1 := &livekit.DataPacket{Sequence: 9} + state := cache.Add(pkt1) + require.Equal(t, MigrationDataCacheStateWaiting, state) + require.Empty(t, cache.Get()) + + pkt2 := &livekit.DataPacket{Sequence: 11} + state = cache.Add(pkt2) + require.Equal(t, MigrationDataCacheStateDone, state) + require.Empty(t, cache.Get()) + + pkt3 := &livekit.DataPacket{Sequence: 12} + state = cache.Add(pkt3) + require.Equal(t, MigrationDataCacheStateDone, state) + require.Empty(t, cache.Get()) + + cache2 := NewMigrationDataCache(20, time.Now().Add(10*time.Millisecond)) + pkt4 := &livekit.DataPacket{Sequence: 22} + time.Sleep(20 * time.Millisecond) + state = cache2.Add(pkt4) + require.Equal(t, MigrationDataCacheStateTimeout, state) + require.Len(t, cache2.Get(), 1) + require.Equal(t, uint32(22), cache2.Get()[0].Sequence) +} diff --git a/livekit/pkg/rtc/participant.go b/livekit/pkg/rtc/participant.go new file mode 100644 index 0000000..a436098 --- /dev/null +++ b/livekit/pkg/rtc/participant.go @@ -0,0 +1,4221 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rtc + +import ( + "context" + "fmt" + "io" + "math/rand" + "os" + "slices" + "strings" + "sync" + "time" + + "github.com/frostbyte73/core" + "github.com/google/uuid" + lru "github.com/hashicorp/golang-lru/v2" + "github.com/pion/rtcp" + "github.com/pion/sdp/v3" + "github.com/pion/webrtc/v4" + "github.com/pkg/errors" + "go.uber.org/atomic" + "go.uber.org/zap/zapcore" + "golang.org/x/exp/maps" + "google.golang.org/protobuf/proto" + + "github.com/livekit/mediatransportutil/pkg/twcc" + "github.com/livekit/protocol/auth" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/observability" + "github.com/livekit/protocol/observability/roomobs" + protosdp "github.com/livekit/protocol/sdp" + protosignalling "github.com/livekit/protocol/signalling" + "github.com/livekit/protocol/utils" + "github.com/livekit/protocol/utils/guid" + "github.com/livekit/protocol/utils/pointer" + "github.com/livekit/psrpc" + + "github.com/livekit/livekit-server/pkg/config" + "github.com/livekit/livekit-server/pkg/metric" + "github.com/livekit/livekit-server/pkg/routing" + "github.com/livekit/livekit-server/pkg/rtc/signalling" + "github.com/livekit/livekit-server/pkg/rtc/supervisor" + "github.com/livekit/livekit-server/pkg/rtc/transport" + "github.com/livekit/livekit-server/pkg/rtc/types" + "github.com/livekit/livekit-server/pkg/sfu" + "github.com/livekit/livekit-server/pkg/sfu/buffer" + "github.com/livekit/livekit-server/pkg/sfu/connectionquality" + "github.com/livekit/livekit-server/pkg/sfu/interceptor" + "github.com/livekit/livekit-server/pkg/sfu/mime" + "github.com/livekit/livekit-server/pkg/sfu/pacer" + "github.com/livekit/livekit-server/pkg/sfu/streamallocator" + "github.com/livekit/livekit-server/pkg/telemetry" + "github.com/livekit/livekit-server/pkg/telemetry/prometheus" + sutils "github.com/livekit/livekit-server/pkg/utils" +) + +var _ types.LocalParticipant = (*ParticipantImpl)(nil) + +const ( + sdBatchSize = 30 + rttUpdateInterval = 5 * time.Second + + disconnectCleanupDuration = 5 * time.Second + migrationWaitDuration = 3 * time.Second + migrationWaitContinuousMsgDuration = 2 * time.Second + + PingIntervalSeconds = 5 + PingTimeoutSeconds = 15 +) + +var ( + ErrMoveOldClientVersion = errors.New("participant client version does not support moving") +) + +// ------------------------------------------------- + +type pendingTrackInfo struct { + trackInfos []*livekit.TrackInfo + sdpRids buffer.VideoLayersRid + migrated bool + createdAt time.Time + + // indicates if this track is queued for publishing to avoid a track has been published + // before the previous track is unpublished(closed) because client is allowed to negotiate + // webrtc track before AddTrackRequest return to speed up the publishing process + queued bool +} + +func (p *pendingTrackInfo) MarshalLogObject(e zapcore.ObjectEncoder) error { + if p == nil { + return nil + } + + e.AddArray("trackInfos", logger.ProtoSlice(p.trackInfos)) + e.AddArray("sdpRids", logger.StringSlice(p.sdpRids[:])) + e.AddBool("migrated", p.migrated) + e.AddTime("createdAt", p.createdAt) + e.AddBool("queued", p.queued) + return nil +} + +// -------------------------------------------------- + +type pendingRemoteTrack struct { + track *webrtc.TrackRemote + receiver *webrtc.RTPReceiver +} + +type downTrackState struct { + transceiver *webrtc.RTPTransceiver + downTrack sfu.DownTrackState +} + +type postRtcpOp struct { + *ParticipantImpl + pkts []rtcp.Packet +} + +// --------------------------------------------------------------- + +type participantUpdateInfo struct { + identity livekit.ParticipantIdentity + version uint32 + state livekit.ParticipantInfo_State + updatedAt time.Time +} + +func (p participantUpdateInfo) String() string { + return fmt.Sprintf("identity: %s, version: %d, state: %s, updatedAt: %s", p.identity, p.version, p.state.String(), p.updatedAt.String()) +} + +type reliableDataInfo struct { + joiningMessageLock sync.Mutex + joiningMessageFirstSeqs map[livekit.ParticipantID]uint32 + joiningMessageLastWrittenSeqs map[livekit.ParticipantID]uint32 + lastPubReliableSeq atomic.Uint32 + stopReliableByMigrateOut atomic.Bool + canWriteReliable bool + migrateInPubDataCache atomic.Pointer[MigrationDataCache] +} + +// --------------------------------------------------------------- + +var _ types.LocalParticipant = (*ParticipantImpl)(nil) + +type ParticipantParams struct { + Identity livekit.ParticipantIdentity + Name livekit.ParticipantName + SID livekit.ParticipantID + Config *WebRTCConfig + Sink routing.MessageSink + AudioConfig sfu.AudioConfig + VideoConfig config.VideoConfig + LimitConfig config.LimitConfig + ProtocolVersion types.ProtocolVersion + SessionStartTime time.Time + Telemetry telemetry.TelemetryService + Trailer []byte + PLIThrottleConfig sfu.PLIThrottleConfig + CongestionControlConfig config.CongestionControlConfig + // codecs that are enabled for this room + PublishEnabledCodecs []*livekit.Codec + SubscribeEnabledCodecs []*livekit.Codec + Logger logger.Logger + LoggerResolver logger.DeferredFieldResolver + Reporter roomobs.ParticipantSessionReporter + ReporterResolver roomobs.ParticipantReporterResolver + SimTracks map[uint32]interceptor.SimulcastTrackInfo + Grants *auth.ClaimGrants + InitialVersion uint32 + ClientConf *livekit.ClientConfiguration + ClientInfo ClientInfo + Region string + Migration bool + Reconnect bool + AdaptiveStream bool + AllowTCPFallback bool + TCPFallbackRTTThreshold int + AllowUDPUnstableFallback bool + TURNSEnabled bool + ParticipantListener types.LocalParticipantListener + ParticipantHelper types.LocalParticipantHelper + DisableSupervisor bool + ReconnectOnPublicationError bool + ReconnectOnSubscriptionError bool + ReconnectOnDataChannelError bool + VersionGenerator utils.TimedVersionGenerator + DisableDynacast bool + SubscriberAllowPause bool + SubscriptionLimitAudio int32 + SubscriptionLimitVideo int32 + PlayoutDelay *livekit.PlayoutDelay + SyncStreams bool + ForwardStats *sfu.ForwardStats + DisableSenderReportPassThrough bool + MetricConfig metric.MetricConfig + UseOneShotSignallingMode bool + EnableMetrics bool + DataChannelMaxBufferedAmount uint64 + DatachannelSlowThreshold int + DatachannelLossyTargetLatency time.Duration + FireOnTrackBySdp bool + DisableCodecRegression bool + LastPubReliableSeq uint32 + Country string + PreferVideoSizeFromMedia bool + UseSinglePeerConnection bool + EnableDataTracks bool + EnableRTPStreamRestartDetection bool + ForceBackupCodecPolicySimulcast bool +} + +type ParticipantImpl struct { + // utils.TimedVersion is a atomic. To be correctly aligned also on 32bit archs + // 64it atomics need to be at the front of a struct + timedVersion utils.TimedVersion + + params ParticipantParams + + participantListener atomic.Pointer[types.LocalParticipantListener] + participantHelper atomic.Value // types.LocalParticipantHelper + id atomic.Value // types.ParticipantID + + isClosed atomic.Bool + closeReason atomic.Value // types.ParticipantCloseReason + + state atomic.Value // livekit.ParticipantInfo_State + disconnected chan struct{} + + grants atomic.Pointer[auth.ClaimGrants] + isPublisher atomic.Bool + + sessionStartRecorded atomic.Bool + lastActiveAt atomic.Pointer[time.Time] + // when first connected + connectedAt time.Time + disconnectedAt atomic.Pointer[time.Time] + // timer that's set when disconnect is detected on primary PC + disconnectTimer *time.Timer + migrationTimer *time.Timer + + pubRTCPQueue *sutils.TypedOpsQueue[postRtcpOp] + + // hold reference for MediaTrack + twcc *twcc.Responder + + // client intended to publish, yet to be reconciled + pendingTracksLock utils.RWMutex + pendingTracks map[string]*pendingTrackInfo + pendingPublishingTracks map[livekit.TrackID]*pendingTrackInfo + pendingRemoteTracks []*pendingRemoteTrack + + // supported codecs + enabledPublishCodecs []*livekit.Codec + enabledSubscribeCodecs []*livekit.Codec + + *TransportManager + *UpTrackManager + *UpDataTrackManager + *SubscriptionManager + + nextSubscribedDataTrackHandle uint16 + + icQueue [2]atomic.Pointer[webrtc.ICECandidate] + + requireBroadcast bool + // queued participant updates before join response is sent + // guarded by updateLock + queuedUpdates []*livekit.ParticipantInfo + // cache of recently sent updates, to ensure ordering by version + // guarded by updateLock + updateCache *lru.Cache[livekit.ParticipantID, participantUpdateInfo] + updateLock utils.Mutex + + dataChannelStats *telemetry.BytesTrackStats + + reliableDataInfo reliableDataInfo + + rttUpdatedAt time.Time + lastRTT uint32 + + // idempotent reference guard for telemetry stats worker + telemetryGuard *telemetry.ReferenceGuard + + lock utils.RWMutex + + dirty atomic.Bool + version atomic.Uint32 + + migrateState atomic.Value // types.MigrateState + migratedTracksPublishedFuse core.Fuse + + onClose map[string]func(types.LocalParticipant) + onClaimsChanged func(participant types.LocalParticipant) + onICEConfigChanged func(participant types.LocalParticipant, iceConfig *livekit.ICEConfig) + + cachedDownTracks map[livekit.TrackID]*downTrackState + forwarderState map[livekit.TrackID]*livekit.RTPForwarderState + + supervisor *supervisor.ParticipantSupervisor + + connectionQuality livekit.ConnectionQuality + + metricTimestamper *metric.MetricTimestamper + metricsCollector *metric.MetricsCollector + metricsReporter *metric.MetricsReporter + + signalling signalling.ParticipantSignalling + signalHandler signalling.ParticipantSignalHandler + signaller signalling.ParticipantSignaller + + // loggers for publisher and subscriber + pubLogger logger.Logger + subLogger logger.Logger + + rpcLock sync.Mutex + rpcPendingAcks map[string]*utils.DataChannelRpcPendingAckHandler + rpcPendingResponses map[string]*utils.DataChannelRpcPendingResponseHandler +} + +func NewParticipant(params ParticipantParams) (*ParticipantImpl, error) { + if params.Identity == "" { + return nil, ErrEmptyIdentity + } + if params.SID == "" { + return nil, ErrEmptyParticipantID + } + if params.Grants == nil || params.Grants.Video == nil { + return nil, ErrMissingGrants + } + p := &ParticipantImpl{ + params: params, + disconnected: make(chan struct{}), + pubRTCPQueue: sutils.NewTypedOpsQueue[postRtcpOp](sutils.OpsQueueParams{ + Name: "pub-rtcp", + MinSize: 64, + Logger: params.Logger, + }), + pendingTracks: make(map[string]*pendingTrackInfo), + pendingPublishingTracks: make(map[livekit.TrackID]*pendingTrackInfo), + connectedAt: time.Now().Truncate(time.Millisecond), + rttUpdatedAt: time.Now(), + cachedDownTracks: make(map[livekit.TrackID]*downTrackState), + connectionQuality: livekit.ConnectionQuality_EXCELLENT, + pubLogger: params.Logger.WithComponent(sutils.ComponentPub), + subLogger: params.Logger.WithComponent(sutils.ComponentSub), + reliableDataInfo: reliableDataInfo{ + joiningMessageFirstSeqs: make(map[livekit.ParticipantID]uint32), + joiningMessageLastWrittenSeqs: make(map[livekit.ParticipantID]uint32), + }, + rpcPendingAcks: make(map[string]*utils.DataChannelRpcPendingAckHandler), + rpcPendingResponses: make(map[string]*utils.DataChannelRpcPendingResponseHandler), + onClose: make(map[string]func(types.LocalParticipant)), + telemetryGuard: &telemetry.ReferenceGuard{}, + nextSubscribedDataTrackHandle: uint16(rand.Intn(256)), + } + p.setupSignalling() + + p.id.Store(params.SID) + p.dataChannelStats = telemetry.NewBytesTrackStats( + p.params.Country, + telemetry.BytesTrackIDForParticipantID(telemetry.BytesTrackTypeData, p.ID()), + p.ID(), + params.Telemetry, + params.Reporter, + ) + p.reliableDataInfo.lastPubReliableSeq.Store(params.LastPubReliableSeq) + p.setListener(params.ParticipantListener) + p.participantHelper.Store(params.ParticipantHelper) + if !params.DisableSupervisor { + p.supervisor = supervisor.NewParticipantSupervisor(supervisor.ParticipantSupervisorParams{Logger: params.Logger}) + } + p.closeReason.Store(types.ParticipantCloseReasonNone) + p.version.Store(params.InitialVersion) + p.timedVersion.Update(params.VersionGenerator.Next()) + + p.migrateState.Store(types.MigrateStateInit) + + p.state.Store(livekit.ParticipantInfo_JOINING) + p.grants.Store(params.Grants.Clone()) + p.SwapResponseSink(params.Sink, types.SignallingCloseReasonUnknown) + p.setupEnabledCodecs(params.PublishEnabledCodecs, params.SubscribeEnabledCodecs, params.ClientConf.GetDisabledCodecs()) + + if p.supervisor != nil { + p.supervisor.OnPublicationError(p.onPublicationError) + } + + sessionTimer := observability.NewSessionTimer(p.params.SessionStartTime) + params.Reporter.RegisterFunc(func(ts time.Time, tx roomobs.ParticipantSessionTx) bool { + if dts := p.disconnectedAt.Load(); dts != nil { + ts = *dts + tx.ReportEndTime(ts) + } + + millis, mins := sessionTimer.Advance(ts) + tx.ReportDuration(uint16(millis)) + tx.ReportDurationMinutes(uint8(mins)) + + return !p.IsClosed() + }) + + var err error + // keep last participants and when updates were sent + if p.updateCache, err = lru.New[livekit.ParticipantID, participantUpdateInfo](128); err != nil { + return nil, err + } + + err = p.setupTransportManager() + if err != nil { + return nil, err + } + + p.setupUpTrackManager() + p.setupUpDataTrackManager() + p.setupSubscriptionManager() + p.setupMetrics() + + return p, nil +} + +func (p *ParticipantImpl) setListener(listener types.LocalParticipantListener) { + if listener == nil { + p.participantListener.Store(nil) + return + } + p.participantListener.Store(&listener) +} + +func (p *ParticipantImpl) listener() types.LocalParticipantListener { + if l := p.participantListener.Load(); l != nil { + return *l + } + return &types.NullLocalParticipantListener{} +} + +func (p *ParticipantImpl) GetParticipantListener() types.ParticipantListener { + return p.listener() +} + +func (p *ParticipantImpl) ClearParticipantListener() { + p.setListener(nil) +} + +func (p *ParticipantImpl) GetCountry() string { + return p.params.Country +} + +func (p *ParticipantImpl) GetTrailer() []byte { + trailer := make([]byte, len(p.params.Trailer)) + copy(trailer, p.params.Trailer) + return trailer +} + +func (p *ParticipantImpl) GetLogger() logger.Logger { + return p.params.Logger +} + +func (p *ParticipantImpl) GetLoggerResolver() logger.DeferredFieldResolver { + return p.params.LoggerResolver +} + +func (p *ParticipantImpl) GetReporter() roomobs.ParticipantSessionReporter { + return p.params.Reporter +} + +func (p *ParticipantImpl) GetReporterResolver() roomobs.ParticipantReporterResolver { + return p.params.ReporterResolver +} + +func (p *ParticipantImpl) GetAdaptiveStream() bool { + return p.params.AdaptiveStream +} + +func (p *ParticipantImpl) GetPacer() pacer.Pacer { + return p.TransportManager.GetSubscriberPacer() +} + +func (p *ParticipantImpl) GetDisableSenderReportPassThrough() bool { + return p.params.DisableSenderReportPassThrough +} + +func (p *ParticipantImpl) ID() livekit.ParticipantID { + return p.id.Load().(livekit.ParticipantID) +} + +func (p *ParticipantImpl) Identity() livekit.ParticipantIdentity { + return p.params.Identity +} + +func (p *ParticipantImpl) State() livekit.ParticipantInfo_State { + return p.state.Load().(livekit.ParticipantInfo_State) +} + +func (p *ParticipantImpl) Kind() livekit.ParticipantInfo_Kind { + return p.grants.Load().GetParticipantKind() +} + +func (p *ParticipantImpl) IsRecorder() bool { + grants := p.grants.Load() + return grants.GetParticipantKind() == livekit.ParticipantInfo_EGRESS || grants.Video.Recorder +} + +func (p *ParticipantImpl) IsAgent() bool { + grants := p.grants.Load() + return grants.GetParticipantKind() == livekit.ParticipantInfo_AGENT || grants.Video.Agent +} + +func (p *ParticipantImpl) IsDependent() bool { + grants := p.grants.Load() + switch grants.GetParticipantKind() { + case livekit.ParticipantInfo_AGENT, livekit.ParticipantInfo_EGRESS: + return true + default: + return grants.Video.Agent || grants.Video.Recorder + } +} + +func (p *ParticipantImpl) ProtocolVersion() types.ProtocolVersion { + return p.params.ProtocolVersion +} + +func (p *ParticipantImpl) IsReady() bool { + state := p.State() + + // when migrating, there is no JoinResponse, state transitions from JOINING -> ACTIVE -> DISCONNECTED + // so JOINING is considered ready. + if p.params.Migration { + return state != livekit.ParticipantInfo_DISCONNECTED + } + + // when not migrating, there is a JoinResponse, state transitions from JOINING -> JOINED -> ACTIVE -> DISCONNECTED + return state == livekit.ParticipantInfo_JOINED || state == livekit.ParticipantInfo_ACTIVE +} + +func (p *ParticipantImpl) IsDisconnected() bool { + return p.State() == livekit.ParticipantInfo_DISCONNECTED +} + +func (p *ParticipantImpl) Disconnected() <-chan struct{} { + return p.disconnected +} + +func (p *ParticipantImpl) IsIdle() bool { + // check if there are any published tracks that are subscribed + for _, t := range p.GetPublishedTracks() { + if t.GetNumSubscribers() > 0 { + return false + } + } + + return !p.SubscriptionManager.HasSubscriptions() +} + +func (p *ParticipantImpl) ConnectedAt() time.Time { + return p.connectedAt +} + +func (p *ParticipantImpl) ActiveAt() time.Time { + if activeAt := p.lastActiveAt.Load(); activeAt != nil { + return *activeAt + } + + return time.Time{} +} + +func (p *ParticipantImpl) GetClientInfo() *livekit.ClientInfo { + p.lock.RLock() + defer p.lock.RUnlock() + return p.params.ClientInfo.ClientInfo +} + +func (p *ParticipantImpl) GetClientConfiguration() *livekit.ClientConfiguration { + p.lock.RLock() + defer p.lock.RUnlock() + return utils.CloneProto(p.params.ClientConf) +} + +func (p *ParticipantImpl) GetBufferFactory() *buffer.Factory { + return p.params.Config.BufferFactory +} + +// checkMetadataLimits check if name/metadata/attributes of a participant is within configured limits +func (p *ParticipantImpl) checkMetadataLimits( + name string, + metadata string, + attributes map[string]string, +) error { + if !p.params.LimitConfig.CheckParticipantNameLength(name) { + return signalling.ErrNameExceedsLimits + } + + if !p.params.LimitConfig.CheckMetadataSize(metadata) { + return signalling.ErrMetadataExceedsLimits + } + + if !p.params.LimitConfig.CheckAttributesSize(attributes) { + return signalling.ErrAttributesExceedsLimits + } + + return nil +} + +func (p *ParticipantImpl) UpdateMetadata(update *livekit.UpdateParticipantMetadata, fromAdmin bool) error { + lgr := p.params.Logger.WithUnlikelyValues( + "update", logger.Proto(update), + "fromAdmin", fromAdmin, + ) + lgr.Debugw("updating participant metadata") + + var err error + requestResponse := &livekit.RequestResponse{ + RequestId: update.RequestId, + } + sendRequestResponse := func() error { + if !fromAdmin || (update.RequestId != 0 || err != nil) { + requestResponse.Request = &livekit.RequestResponse_UpdateMetadata{ + UpdateMetadata: utils.CloneProto(update), + } + p.sendRequestResponse(requestResponse) + } + if err != nil { + lgr.Warnw("could not update metadata", err) + } + return err + } + + if !fromAdmin && !p.ClaimGrants().Video.GetCanUpdateOwnMetadata() { + requestResponse.Reason = livekit.RequestResponse_NOT_ALLOWED + requestResponse.Message = "does not have permission to update own metadata" + err = signalling.ErrUpdateOwnMetadataNotAllowed + return sendRequestResponse() + } + + if err = p.checkMetadataLimits(update.Name, update.Metadata, update.Attributes); err != nil { + switch err { + case signalling.ErrNameExceedsLimits: + requestResponse.Reason = livekit.RequestResponse_LIMIT_EXCEEDED + requestResponse.Message = "exceeds name length limit" + + case signalling.ErrMetadataExceedsLimits: + requestResponse.Reason = livekit.RequestResponse_LIMIT_EXCEEDED + requestResponse.Message = "exceeds metadata size limit" + + case signalling.ErrAttributesExceedsLimits: + requestResponse.Reason = livekit.RequestResponse_LIMIT_EXCEEDED + requestResponse.Message = "exceeds attributes size limit" + } + return sendRequestResponse() + } + + if update.Name != "" { + p.SetName(update.Name) + } + if update.Metadata != "" { + p.SetMetadata(update.Metadata) + } + if update.Attributes != nil { + p.SetAttributes(update.Attributes) + } + return sendRequestResponse() +} + +// SetName attaches name to the participant +func (p *ParticipantImpl) SetName(name string) { + p.lock.Lock() + grants := p.grants.Load() + if grants.Name == name { + p.lock.Unlock() + return + } + + grants = grants.Clone() + grants.Name = name + p.grants.Store(grants) + p.dirty.Store(true) + + onClaimsChanged := p.onClaimsChanged + p.lock.Unlock() + + p.listener().OnParticipantUpdate(p) + + if onClaimsChanged != nil { + onClaimsChanged(p) + } +} + +// SetMetadata attaches metadata to the participant +func (p *ParticipantImpl) SetMetadata(metadata string) { + p.lock.Lock() + grants := p.grants.Load() + if grants.Metadata == metadata { + p.lock.Unlock() + return + } + + grants = grants.Clone() + grants.Metadata = metadata + p.grants.Store(grants) + p.requireBroadcast = p.requireBroadcast || metadata != "" + p.dirty.Store(true) + + onClaimsChanged := p.onClaimsChanged + p.lock.Unlock() + + p.listener().OnParticipantUpdate(p) + + if onClaimsChanged != nil { + onClaimsChanged(p) + } +} + +func (p *ParticipantImpl) SetAttributes(attrs map[string]string) { + if len(attrs) == 0 { + return + } + p.lock.Lock() + grants := p.grants.Load().Clone() + if grants.Attributes == nil { + grants.Attributes = make(map[string]string) + } + var keysToDelete []string + for k, v := range attrs { + if v == "" { + keysToDelete = append(keysToDelete, k) + } else { + grants.Attributes[k] = v + } + } + for _, k := range keysToDelete { + delete(grants.Attributes, k) + } + + p.grants.Store(grants) + p.requireBroadcast = true // already checked above + p.dirty.Store(true) + + onClaimsChanged := p.onClaimsChanged + p.lock.Unlock() + + p.listener().OnParticipantUpdate(p) + + if onClaimsChanged != nil { + onClaimsChanged(p) + } +} + +func (p *ParticipantImpl) ClaimGrants() *auth.ClaimGrants { + return p.grants.Load() +} + +func (p *ParticipantImpl) SetPermission(permission *livekit.ParticipantPermission) bool { + if permission == nil { + return false + } + p.lock.Lock() + grants := p.grants.Load() + + if grants.Video.MatchesPermission(permission) { + p.lock.Unlock() + return false + } + + p.params.Logger.Infow("updating participant permission", "permission", permission) + + grants = grants.Clone() + grants.Video.UpdateFromPermission(permission) + p.grants.Store(grants) + p.dirty.Store(true) + + canPublish := grants.Video.GetCanPublish() + canSubscribe := grants.Video.GetCanSubscribe() + + onClaimsChanged := p.onClaimsChanged + + isPublisher := canPublish && p.TransportManager.IsPublisherEstablished() + p.requireBroadcast = p.requireBroadcast || isPublisher + p.lock.Unlock() + + // publish permission has been revoked then remove offending tracks + for _, track := range p.GetPublishedTracks() { + if !grants.Video.GetCanPublishSource(track.Source()) { + p.removePublishedTrack(track) + } + } + + if canSubscribe { + // reconcile everything + p.SubscriptionManager.ReconcileAll() + } else { + // revoke all subscriptions + for _, st := range p.SubscriptionManager.GetSubscribedTracks() { + st.MediaTrack().RemoveSubscriber(p.ID(), false) + } + } + + if !grants.Video.GetCanPublishData() { + for _, dt := range p.UpDataTrackManager.GetPublishedDataTracks() { + p.UpDataTrackManager.RemovePublishedDataTrack(dt) + } + } + + // update isPublisher attribute + p.isPublisher.Store(isPublisher) + + p.listener().OnParticipantUpdate(p) + + if onClaimsChanged != nil { + onClaimsChanged(p) + } + return true +} + +func (p *ParticipantImpl) CanSkipBroadcast() bool { + p.lock.RLock() + defer p.lock.RUnlock() + return !p.requireBroadcast +} + +func (p *ParticipantImpl) maybeIncVersion() { + if p.dirty.Load() { + p.lock.Lock() + if p.dirty.Swap(false) { + p.version.Inc() + p.timedVersion.Update(p.params.VersionGenerator.Next()) + } + p.lock.Unlock() + } +} + +func (p *ParticipantImpl) Version() utils.TimedVersion { + p.maybeIncVersion() + + p.lock.RLock() + defer p.lock.RUnlock() + return p.timedVersion +} + +func (p *ParticipantImpl) ToProtoWithVersion() (*livekit.ParticipantInfo, utils.TimedVersion) { + p.maybeIncVersion() + + p.lock.RLock() + grants := p.grants.Load() + v := p.version.Load() + piv := p.timedVersion + + pi := &livekit.ParticipantInfo{ + Sid: string(p.ID()), + Identity: string(p.params.Identity), + Name: grants.Name, + State: p.State(), + JoinedAt: p.ConnectedAt().Unix(), + JoinedAtMs: p.ConnectedAt().UnixMilli(), + Version: v, + Permission: grants.Video.ToPermission(), + Metadata: grants.Metadata, + Attributes: grants.Attributes, + Region: p.params.Region, + IsPublisher: p.IsPublisher(), + Kind: grants.GetParticipantKind(), + KindDetails: grants.GetKindDetails(), + DisconnectReason: p.CloseReason().ToDisconnectReason(), + } + p.lock.RUnlock() + + p.pendingTracksLock.RLock() + pi.Tracks = p.UpTrackManager.ToProto() + + // add any pending migrating tracks, else an update could delete/unsubscribe from yet to be published, migrating tracks + maybeAdd := func(pti *pendingTrackInfo) { + if !pti.migrated { + return + } + + found := false + for _, ti := range pi.Tracks { + if ti.Sid == pti.trackInfos[0].Sid { + found = true + break + } + } + + if !found { + pi.Tracks = append(pi.Tracks, utils.CloneProto(pti.trackInfos[0])) + } + } + + for _, pt := range p.pendingTracks { + maybeAdd(pt) + } + for _, ppt := range p.pendingPublishingTracks { + maybeAdd(ppt) + } + p.pendingTracksLock.RUnlock() + + pi.DataTracks = p.UpDataTrackManager.ToProto() + + return pi, piv +} + +func (p *ParticipantImpl) ToProto() *livekit.ParticipantInfo { + pi, _ := p.ToProtoWithVersion() + return pi +} + +func (p *ParticipantImpl) TelemetryGuard() *telemetry.ReferenceGuard { + return p.telemetryGuard +} + +func (p *ParticipantImpl) AddOnClose(key string, callback func(types.LocalParticipant)) { + if p.isClosed.Load() { + go callback(p) + return + } + + p.lock.Lock() + if callback == nil { + delete(p.onClose, key) + } else { + p.onClose[key] = callback + } + p.lock.Unlock() +} + +func (p *ParticipantImpl) OnClaimsChanged(callback func(types.LocalParticipant)) { + p.lock.Lock() + p.onClaimsChanged = callback + p.lock.Unlock() +} + +func (p *ParticipantImpl) HandleSignalSourceClose() { + p.TransportManager.SetSignalSourceValid(false) + + if !p.HasConnected() { + _ = p.Close(false, types.ParticipantCloseReasonSignalSourceClose, false) + } +} + +func (p *ParticipantImpl) synthesizeAddTrackRequests(parsedOffer *sdp.SessionDescription) error { + for _, m := range parsedOffer.MediaDescriptions { + if !strings.EqualFold(m.MediaName.Media, "audio") && !strings.EqualFold(m.MediaName.Media, "video") { + continue + } + + cid := protosdp.GetMediaStreamTrack(m) + if cid == "" { + cid = guid.New(utils.TrackPrefix) + } + + rids, ridsOk := protosdp.GetSimulcastRids(m) + + var ( + name string + trackSource livekit.TrackSource + trackType livekit.TrackType + ) + if strings.EqualFold(m.MediaName.Media, "audio") { + name = "synthesized-microphone" + trackSource = livekit.TrackSource_MICROPHONE + trackType = livekit.TrackType_AUDIO + } else { + name = "synthesized-camera" + trackSource = livekit.TrackSource_CAMERA + trackType = livekit.TrackType_VIDEO + } + req := &livekit.AddTrackRequest{ + Cid: cid, + Name: name, + Source: trackSource, + Type: trackType, + DisableDtx: true, + Stereo: false, + Stream: "camera", + } + if strings.EqualFold(m.MediaName.Media, "video") { + if ridsOk { + // add simulcast layers, NOTE: only quality can be set as dimensions/fps is not available + n := min(len(rids), int(buffer.DefaultMaxLayerSpatial)+1) + for i := range n { + // WARN: casting int -> protobuf enum + req.Layers = append(req.Layers, &livekit.VideoLayer{Quality: livekit.VideoQuality(i)}) + } + } else { + // dummy layer to ensure at least one layer is available + req.Layers = []*livekit.VideoLayer{{}} + } + } + p.AddTrack(req) + } + return nil +} + +func (p *ParticipantImpl) updateRidsFromSDP(parsed *sdp.SessionDescription, unmatchVideos []*sdp.MediaDescription) { + for _, m := range parsed.MediaDescriptions { + if m.MediaName.Media != "video" || !slices.Contains(unmatchVideos, m) { + continue + } + + mst := protosdp.GetMediaStreamTrack(m) + if mst == "" { + continue + } + + getRids := func(inRids buffer.VideoLayersRid) buffer.VideoLayersRid { + var outRids buffer.VideoLayersRid + rids, ok := protosdp.GetSimulcastRids(m) + if ok { + n := min(len(rids), len(inRids)) + for i := range n { + // disabled layers will have a `~` prefix, remove it while determining actual rid + if len(rids[i]) != 0 && rids[i][0] == '~' { + outRids[i] = rids[i][1:] + } else { + outRids[i] = rids[i] + } + } + for i := n; i < len(inRids); i++ { + outRids[i] = "" + } + outRids = buffer.NormalizeVideoLayersRid(outRids) + } else { + for i := range len(inRids) { + outRids[i] = "" + } + } + + return outRids + } + + p.pendingTracksLock.Lock() + pti := p.getPendingTrackPrimaryBySdpCid(mst) + if pti != nil { + pti.sdpRids = getRids(pti.sdpRids) + p.pubLogger.Debugw( + "pending track rids updated", + "trackID", pti.trackInfos[0].Sid, + "pendingTrack", pti, + ) + + ti := pti.trackInfos[0] + for _, codec := range ti.Codecs { + if codec.Cid == mst || codec.SdpCid == mst { + mimeType := mime.NormalizeMimeType(codec.MimeType) + for _, layer := range codec.Layers { + layer.SpatialLayer = buffer.VideoQualityToSpatialLayer(mimeType, layer.Quality, ti) + layer.Rid = buffer.VideoQualityToRid(mimeType, layer.Quality, ti, pti.sdpRids) + } + } + } + } + p.pendingTracksLock.Unlock() + + if pti == nil { + // track could already be published, but this could be back up codec offer, + // so check in published tracks also + mt := p.getPublishedTrackBySdpCid(mst) + if mt != nil { + mimeType := mt.(*MediaTrack).GetMimeTypeForSdpCid(mst) + if mimeType != mime.MimeTypeUnknown { + rids := getRids(buffer.DefaultVideoLayersRid) + mt.(*MediaTrack).UpdateCodecRids(mimeType, rids) + p.pubLogger.Debugw( + "published track rids updated", + "trackID", mt.ID(), + "mime", mimeType, + "track", logger.Proto(mt.ToProto()), + ) + } else { + p.pubLogger.Warnw( + "could not get mime type for sdp cid", nil, + "trackID", mt.ID(), + "sdpCid", mst, + "track", logger.Proto(mt.ToProto()), + ) + } + } + } + } +} + +func (p *ParticipantImpl) HandleICETrickle(trickleRequest *livekit.TrickleRequest) { + candidateInit, err := protosignalling.FromProtoTrickle(trickleRequest) + if err != nil { + p.params.Logger.Warnw("could not decode trickle", err) + p.sendRequestResponse(&livekit.RequestResponse{ + Reason: livekit.RequestResponse_UNCLASSIFIED_ERROR, + Message: err.Error(), + Request: &livekit.RequestResponse_Trickle{ + Trickle: utils.CloneProto(trickleRequest), + }, + }) + return + } + + p.TransportManager.AddICECandidate(candidateInit, trickleRequest.Target) +} + +// HandleOffer an offer from remote participant, used when clients make the initial connection +func (p *ParticipantImpl) HandleOffer(sd *livekit.SessionDescription) error { + offer, offerId, _ := protosignalling.FromProtoSessionDescription(sd) + lgr := p.pubLogger.WithUnlikelyValues( + "transport", livekit.SignalTarget_PUBLISHER, + "offer", offer, + "offerId", offerId, + ) + + lgr.Debugw("received offer") + + parsedOffer, err := offer.Unmarshal() + if err != nil { + lgr.Warnw("could not parse offer", err) + return err + } + + if p.params.UseOneShotSignallingMode { + if err := p.synthesizeAddTrackRequests(parsedOffer); err != nil { + lgr.Warnw("could not synthesize add track requests", err) + return err + } + } + + err = p.TransportManager.HandleOffer(offer, offerId, p.MigrateState() == types.MigrateStateInit) + if err != nil { + lgr.Warnw("could not handle offer", err, "mungedOffer", offer) + return err + } + + if p.params.UseOneShotSignallingMode { + go p.listener().OnSubscriberReady(p) + } + + p.handlePendingRemoteTracks() + return nil +} + +func (p *ParticipantImpl) onPublisherSetRemoteDescription() { + offer := p.TransportManager.LastPublisherOfferPending() + parsedOffer, err := offer.Unmarshal() + if err != nil { + p.pubLogger.Warnw("could not parse offer", err) + return + } + + // set publish codec preferences after remote description is set + // and required transceivers are created + unmatchAudios, unmatchVideos := p.populateSdpCid(parsedOffer) + p.setCodecPreferencesForPublisher(parsedOffer, unmatchAudios, unmatchVideos) + p.updateRidsFromSDP(parsedOffer, unmatchVideos) +} + +func (p *ParticipantImpl) onPublisherAnswer(answer webrtc.SessionDescription, answerId uint32, midToTrackID map[string]string) error { + if p.IsClosed() || p.IsDisconnected() { + return nil + } + + answer = p.configurePublisherAnswer(answer) + p.pubLogger.Debugw( + "sending answer", + "transport", livekit.SignalTarget_PUBLISHER, + "answer", answer, + "answerId", answerId, + "midToTrackID", midToTrackID, + ) + + return p.sendSdpAnswer(answer, answerId, midToTrackID) +} + +func (p *ParticipantImpl) GetAnswer() (webrtc.SessionDescription, uint32, error) { + if p.IsClosed() || p.IsDisconnected() { + return webrtc.SessionDescription{}, 0, ErrParticipantSessionClosed + } + + answer, answerId, err := p.TransportManager.GetAnswer() + if err != nil { + return answer, answerId, err + } + + answer = p.configurePublisherAnswer(answer) + p.pubLogger.Debugw( + "returning answer", + "transport", livekit.SignalTarget_PUBLISHER, + "answer", answer, + "answerId", answerId, + ) + return answer, answerId, nil +} + +// HandleAnswer handles a client answer response, with subscriber PC, server initiates the +// offer and client answers +func (p *ParticipantImpl) HandleAnswer(sd *livekit.SessionDescription) { + answer, answerId, _ := protosignalling.FromProtoSessionDescription(sd) + p.subLogger.Debugw( + "received answer", + "transport", livekit.SignalTarget_SUBSCRIBER, + "answer", answer, + "answerId", answerId, + ) + + /* from server received join request to client answer + * 1. server send join response & offer + * ... swap candidates + * 2. client send answer + */ + signalConnCost := time.Since(p.ConnectedAt()).Milliseconds() + p.TransportManager.UpdateSignalingRTT(uint32(signalConnCost)) + + p.TransportManager.HandleAnswer(answer, answerId) +} + +func (p *ParticipantImpl) handleMigrateTracks() []*MediaTrack { + // muted track won't send rtp packet, so it is required to add mediatrack manually. + // But, synthesising track publish for unmuted tracks keeps a consistent path. + // In both cases (muted and unmuted), when publisher sends media packets, OnTrack would register and go from there. + var addedTracks []*MediaTrack + p.pendingTracksLock.Lock() + for cid, pti := range p.pendingTracks { + if !pti.migrated { + continue + } + + if len(pti.trackInfos) > 1 { + p.pubLogger.Warnw( + "too many pending migrated tracks", nil, + "trackID", pti.trackInfos[0].Sid, + "count", len(pti.trackInfos), + "cid", cid, + ) + } + + mt := p.addMigratedTrack(cid, pti.trackInfos[0]) + if mt != nil { + addedTracks = append(addedTracks, mt) + } else { + p.pubLogger.Warnw("could not find migrated track, migration failed", nil, "cid", cid) + p.pendingTracksLock.Unlock() + p.IssueFullReconnect(types.ParticipantCloseReasonMigrateCodecMismatch) + return nil + } + } + + if len(addedTracks) != 0 { + p.dirty.Store(true) + } + p.pendingTracksLock.Unlock() + + return addedTracks +} + +// AddTrack is called when client intends to publish track. +// records track details and lets client know it's ok to proceed +func (p *ParticipantImpl) AddTrack(req *livekit.AddTrackRequest) { + p.params.Logger.Debugw("add track request", "trackID", req.Cid, "request", logger.Proto(req)) + if !p.CanPublishSource(req.Source) { + p.pubLogger.Warnw("no permission to publish track", nil, "trackID", req.Sid, "kind", req.Type) + p.sendRequestResponse(&livekit.RequestResponse{ + Reason: livekit.RequestResponse_NOT_ALLOWED, + Request: &livekit.RequestResponse_AddTrack{ + AddTrack: utils.CloneProto(req), + }, + }) + return + } + + if req.Type != livekit.TrackType_AUDIO && req.Type != livekit.TrackType_VIDEO { + p.pubLogger.Warnw("unsupported track type", nil, "trackID", req.Sid, "kind", req.Type) + p.sendRequestResponse(&livekit.RequestResponse{ + Reason: livekit.RequestResponse_UNSUPPORTED_TYPE, + Request: &livekit.RequestResponse_AddTrack{ + AddTrack: utils.CloneProto(req), + }, + }) + return + } + + p.pendingTracksLock.Lock() + ti := p.addPendingTrackLocked(req) + p.pendingTracksLock.Unlock() + if ti == nil { + return + } + + p.sendTrackPublished(req.Cid, ti) + + p.handlePendingRemoteTracks() +} + +func (p *ParticipantImpl) SetMigrateInfo( + previousOffer, previousAnswer *webrtc.SessionDescription, + mediaTracks []*livekit.TrackPublishedResponse, + dataChannels []*livekit.DataChannelInfo, + dataChannelReceiveState []*livekit.DataChannelReceiveState, + dataTracks []*livekit.PublishDataTrackResponse, +) { + p.pendingTracksLock.Lock() + for _, t := range mediaTracks { + ti := t.GetTrack() + + if p.supervisor != nil { + p.supervisor.AddPublication(livekit.TrackID(ti.Sid)) + p.supervisor.SetPublicationMute(livekit.TrackID(ti.Sid), ti.Muted) + } + + p.pendingTracks[t.GetCid()] = &pendingTrackInfo{ + trackInfos: []*livekit.TrackInfo{ti}, + migrated: true, + createdAt: time.Now(), + } + p.pubLogger.Infow( + "pending track added (migration)", + "trackID", ti.Sid, + "cid", t.GetCid(), + "pendingTrack", p.pendingTracks[t.GetCid()], + ) + } + p.pendingTracksLock.Unlock() + + for _, t := range dataTracks { + dti := t.GetInfo() + dt := NewDataTrack( + DataTrackParams{ + Logger: p.params.Logger.WithValues("trackID", dti.Sid), + ParticipantID: p.ID, + ParticipantIdentity: p.params.Identity, + }, + dti, + ) + p.UpDataTrackManager.AddPublishedDataTrack(dt) + } + + if len(mediaTracks) != 0 || len(dataTracks) != 0 { + p.setIsPublisher(true) + } + + p.reliableDataInfo.joiningMessageLock.Lock() + for _, state := range dataChannelReceiveState { + p.reliableDataInfo.joiningMessageFirstSeqs[livekit.ParticipantID(state.PublisherSid)] = state.LastSeq + 1 + } + p.reliableDataInfo.joiningMessageLock.Unlock() + + p.TransportManager.SetMigrateInfo(previousOffer, previousAnswer, dataChannels) +} + +func (p *ParticipantImpl) IsReconnect() bool { + return p.params.Reconnect +} + +func (p *ParticipantImpl) Close(sendLeave bool, reason types.ParticipantCloseReason, isExpectedToResume bool) error { + if p.isClosed.Swap(true) { + // already closed + return nil + } + + p.params.Logger.Infow( + "participant closing", + "sendLeave", sendLeave, + "reason", reason.String(), + "isExpectedToResume", isExpectedToResume, + ) + p.closeReason.Store(reason) + p.clearDisconnectTimer() + p.clearMigrationTimer() + + if sendLeave { + p.sendLeaveRequest( + reason, + isExpectedToResume, + false, // isExpectedToReconnect + false, // sendOnlyIfSupportingLeaveRequestWithAction + ) + } + + if p.supervisor != nil { + p.supervisor.Stop() + } + + p.pendingTracksLock.Lock() + for _, pti := range p.pendingTracks { + if len(pti.trackInfos) == 0 { + continue + } + prometheus.RecordTrackPublishCancels(pti.trackInfos[0].Type.String(), int32(len(pti.trackInfos))) + } + p.pendingTracks = make(map[string]*pendingTrackInfo) + p.pendingPublishingTracks = make(map[livekit.TrackID]*pendingTrackInfo) + p.pendingTracksLock.Unlock() + + p.UpTrackManager.Close(isExpectedToResume) + + p.rpcLock.Lock() + clear(p.rpcPendingAcks) + for _, handler := range p.rpcPendingResponses { + handler.Resolve("", utils.DataChannelRpcErrorFromBuiltInCodes(utils.DataChannelRpcRecipientDisconnected, "")) + } + p.rpcPendingResponses = make(map[string]*utils.DataChannelRpcPendingResponseHandler) + p.rpcLock.Unlock() + + p.updateState(livekit.ParticipantInfo_DISCONNECTED) + close(p.disconnected) + + // ensure this is synchronized + p.CloseSignalConnection(types.SignallingCloseReasonParticipantClose) + p.lock.RLock() + onClose := maps.Values(p.onClose) + p.lock.RUnlock() + for _, cb := range onClose { + cb(p) + } + + // Close peer connections without blocking participant Close. If peer connections are gathering candidates + // Close will block. + go func() { + p.SubscriptionManager.Close(isExpectedToResume) + p.TransportManager.Close() + + p.metricsCollector.Stop() + p.metricsReporter.Stop() + }() + + p.dataChannelStats.Stop() + return nil +} + +func (p *ParticipantImpl) IsClosed() bool { + return p.isClosed.Load() +} + +func (p *ParticipantImpl) CloseReason() types.ParticipantCloseReason { + return p.closeReason.Load().(types.ParticipantCloseReason) +} + +// Negotiate subscriber SDP with client, if force is true, will cancel pending +// negotiate task and negotiate immediately +func (p *ParticipantImpl) Negotiate(force bool) { + if p.params.UseOneShotSignallingMode { + return + } + + if p.MigrateState() != types.MigrateStateInit { + p.TransportManager.NegotiateSubscriber(force) + } +} + +func (p *ParticipantImpl) clearMigrationTimer() { + p.lock.Lock() + if p.migrationTimer != nil { + p.migrationTimer.Stop() + p.migrationTimer = nil + } + p.lock.Unlock() +} + +func (p *ParticipantImpl) setupMigrationTimerLocked() { + if p.params.UseSinglePeerConnection { + return + } + + // + // On subscriber peer connection, remote side will try ICE on both + // pre- and post-migration ICE candidates as the migrating out + // peer connection leaves itself open to enable transition of + // media with as less disruption as possible. + // + // But, sometimes clients could delay the migration because of + // pinging the incorrect ICE candidates. Give the remote some time + // to try and succeed. If not, close the subscriber peer connection + // and help the remote side to narrow down its ICE candidate pool. + // + p.migrationTimer = time.AfterFunc(migrationWaitDuration, func() { + p.clearMigrationTimer() + + if p.IsClosed() || p.IsDisconnected() { + return + } + p.subLogger.Debugw("closing subscriber peer connection to aid migration") + + // + // Close all down tracks before closing subscriber peer connection. + // Closing subscriber peer connection will call `Unbind` on all down tracks. + // DownTrack close has checks to handle the case of closing before bind. + // So, an `Unbind` before close would bypass that logic. + // + p.SubscriptionManager.Close(true) + + p.TransportManager.SubscriberClose() + }) +} + +func (p *ParticipantImpl) MaybeStartMigration(force bool, onStart func()) bool { + if p.params.UseOneShotSignallingMode { + return false + } + + allTransportConnected := p.TransportManager.HasSubscriberEverConnected() + if p.IsPublisher() { + allTransportConnected = allTransportConnected && p.TransportManager.HasPublisherEverConnected() + } + if !force && !allTransportConnected { + return false + } + + if onStart != nil { + onStart() + } + + p.sendLeaveRequest( + types.ParticipantCloseReasonMigrationRequested, + true, // isExpectedToResume + false, // isExpectedToReconnect + true, // sendOnlyIfSupportingLeaveRequestWithAction + ) + p.CloseSignalConnection(types.SignallingCloseReasonMigration) + + p.clearMigrationTimer() + + p.lock.Lock() + p.setupMigrationTimerLocked() + p.lock.Unlock() + + return true +} + +func (p *ParticipantImpl) NotifyMigration() { + p.lock.Lock() + defer p.lock.Unlock() + + if p.migrationTimer != nil { + // already set up + return + } + + p.setupMigrationTimerLocked() +} + +func (p *ParticipantImpl) SetMigrateState(s types.MigrateState) { + preState := p.MigrateState() + if preState == types.MigrateStateComplete || preState == s { + return + } + + p.params.Logger.Debugw("SetMigrateState", "state", s) + var migratedTracks []*MediaTrack + if s == types.MigrateStateComplete { + migratedTracks = p.handleMigrateTracks() + } + p.migrateState.Store(s) + p.dirty.Store(true) + + switch s { + case types.MigrateStateSync: + p.TransportManager.ProcessPendingPublisherOffer() + + case types.MigrateStateComplete: + if preState == types.MigrateStateSync { + p.params.Logger.Infow("migration complete") + + if p.params.LastPubReliableSeq > 0 { + p.reliableDataInfo.migrateInPubDataCache.Store(NewMigrationDataCache(p.params.LastPubReliableSeq, time.Now().Add(migrationWaitContinuousMsgDuration))) + } + } + p.TransportManager.ProcessPendingPublisherDataChannels() + go p.cacheForwarderState() + } + + go func() { + // launch callbacks in goroutine since they could block. + // callbacks handle webhooks as well as db persistence + for _, t := range migratedTracks { + p.handleTrackPublished(t, true) + } + + if s == types.MigrateStateComplete { + // wait for all migrated track to be published, + // it is possible that synthesized track publish above could + // race with actual publish from client and the above synthesized + // one could actually be a no-op because the actual publish path is active. + // + // if the actual publish path has not finished, the migration state change + // callback could close the remote participant/tracks before the local track + // is fully active. + // + // that could lead subscribers to unsubscribe due to source + // track going away, i. e. in this case, the remote track close would have + // notified the subscription manager, the subscription manager would + // re-resolve to check if the track is still active and unsubscribe if none + // is active, as local track is in the process of completing publish, + // the check would have resolved to an empty track leading to unsubscription. + go func() { + startTime := time.Now() + for { + if !p.hasPendingMigratedTrack() || p.IsDisconnected() || time.Since(startTime) > 15*time.Second { + // a time out just to be safe, but it should not be needed + p.migratedTracksPublishedFuse.Break() + return + } + + time.Sleep(20 * time.Millisecond) + } + }() + + <-p.migratedTracksPublishedFuse.Watch() + } + + p.listener().OnMigrateStateChange(p, s) + }() +} + +func (p *ParticipantImpl) MigrateState() types.MigrateState { + return p.migrateState.Load().(types.MigrateState) +} + +// ICERestart restarts subscriber ICE connections +func (p *ParticipantImpl) ICERestart(iceConfig *livekit.ICEConfig) { + if p.params.UseOneShotSignallingMode { + return + } + + p.clearDisconnectTimer() + p.clearMigrationTimer() + + for _, t := range p.GetPublishedTracks() { + t.(types.LocalMediaTrack).Restart() + } + + if err := p.TransportManager.ICERestart(iceConfig); err != nil { + p.IssueFullReconnect(types.ParticipantCloseReasonNegotiateFailed) + } +} + +func (p *ParticipantImpl) OnICEConfigChanged(f func(participant types.LocalParticipant, iceConfig *livekit.ICEConfig)) { + p.lock.Lock() + p.onICEConfigChanged = f + p.lock.Unlock() +} + +func (p *ParticipantImpl) GetConnectionQuality() *livekit.ConnectionQualityInfo { + minQuality := livekit.ConnectionQuality_EXCELLENT + minScore := connectionquality.MaxMOS + + for _, pt := range p.GetPublishedTracks() { + score, quality := pt.(types.LocalMediaTrack).GetConnectionScoreAndQuality() + if utils.IsConnectionQualityLower(minQuality, quality) { + minQuality = quality + minScore = score + } else if quality == minQuality && score < minScore { + minScore = score + } + } + + subscribedTracks := p.SubscriptionManager.GetSubscribedTracks() + for _, subTrack := range subscribedTracks { + score, quality := subTrack.DownTrack().GetConnectionScoreAndQuality() + if utils.IsConnectionQualityLower(minQuality, quality) { + minQuality = quality + minScore = score + } else if quality == minQuality && score < minScore { + minScore = score + } + } + + prometheus.RecordQuality(minQuality, minScore) + + if minQuality == livekit.ConnectionQuality_LOST && !p.ProtocolVersion().SupportsConnectionQualityLost() { + minQuality = livekit.ConnectionQuality_POOR + } + + p.lock.Lock() + if minQuality != p.connectionQuality { + p.params.Logger.Debugw("connection quality changed", "from", p.connectionQuality, "to", minQuality) + } + p.connectionQuality = minQuality + p.lock.Unlock() + + return &livekit.ConnectionQualityInfo{ + ParticipantSid: string(p.ID()), + Quality: minQuality, + Score: minScore, + } +} + +func (p *ParticipantImpl) IsPublisher() bool { + return p.isPublisher.Load() +} + +func (p *ParticipantImpl) CanPublish() bool { + return p.grants.Load().Video.GetCanPublish() +} + +func (p *ParticipantImpl) CanPublishSource(source livekit.TrackSource) bool { + return p.grants.Load().Video.GetCanPublishSource(source) +} + +func (p *ParticipantImpl) CanSubscribe() bool { + return p.grants.Load().Video.GetCanSubscribe() +} + +func (p *ParticipantImpl) CanPublishData() bool { + return p.grants.Load().Video.GetCanPublishData() +} + +func (p *ParticipantImpl) Hidden() bool { + return p.grants.Load().Video.Hidden +} + +func (p *ParticipantImpl) CanSubscribeMetrics() bool { + return p.grants.Load().Video.GetCanSubscribeMetrics() +} + +func (p *ParticipantImpl) Verify() bool { + state := p.State() + isActive := state != livekit.ParticipantInfo_JOINING && state != livekit.ParticipantInfo_JOINED + if p.params.UseOneShotSignallingMode { + isActive = isActive && p.TransportManager.HasPublisherEverConnected() + } + + return isActive +} + +func (p *ParticipantImpl) VerifySubscribeParticipantInfo(pID livekit.ParticipantID, version uint32) { + if !p.IsReady() { + // we have not sent a JoinResponse yet. metadata would be covered in JoinResponse + return + } + if info, ok := p.updateCache.Get(pID); ok && info.version >= version { + return + } + + if info := p.helper().GetParticipantInfo(pID); info != nil { + _ = p.SendParticipantUpdate([]*livekit.ParticipantInfo{info}) + } +} + +// onTrackSubscribed handles post-processing after a track is subscribed +func (p *ParticipantImpl) onTrackSubscribed(subTrack types.SubscribedTrack) { + if p.params.ClientInfo.FireTrackByRTPPacket() { + subTrack.DownTrack().SetActivePaddingOnMuteUpTrack() + } + + subTrack.AddOnBind(func(err error) { + if err != nil { + return + } + if p.params.UseOneShotSignallingMode { + if p.TransportManager.HasPublisherEverConnected() { + dt := subTrack.DownTrack() + dt.SeedState(sfu.DownTrackState{ForwarderState: p.getAndDeleteForwarderState(subTrack.ID())}) + dt.SetConnected() + } + // ONE-SHOT-SIGNALLING-MODE-TODO: video support should add to publisher PC for congestion control + } else { + if p.TransportManager.HasSubscriberEverConnected() { + dt := subTrack.DownTrack() + dt.SeedState(sfu.DownTrackState{ForwarderState: p.getAndDeleteForwarderState(subTrack.ID())}) + dt.SetConnected() + } + p.TransportManager.AddSubscribedTrack(subTrack) + } + }) +} + +// onTrackUnsubscribed handles post-processing after a track is unsubscribed +func (p *ParticipantImpl) onTrackUnsubscribed(subTrack types.SubscribedTrack) { + p.TransportManager.RemoveSubscribedTrack(subTrack) +} + +func (p *ParticipantImpl) UpdateMediaRTT(rtt uint32) { + now := time.Now() + p.lock.Lock() + if now.Sub(p.rttUpdatedAt) < rttUpdateInterval || p.lastRTT == rtt { + p.lock.Unlock() + return + } + p.rttUpdatedAt = now + p.lastRTT = rtt + p.lock.Unlock() + p.TransportManager.UpdateMediaRTT(rtt) + + for _, pt := range p.GetPublishedTracks() { + pt.(types.LocalMediaTrack).SetRTT(rtt) + } +} + +// ---------------------------------------------------------- + +var _ transport.Handler = (*AnyTransportHandler)(nil) + +type AnyTransportHandler struct { + transport.UnimplementedHandler + p *ParticipantImpl +} + +func (h AnyTransportHandler) OnFailed(_isShortLived bool, _ici *types.ICEConnectionInfo) { + h.p.onAnyTransportFailed() +} + +func (h AnyTransportHandler) OnNegotiationFailed() { + h.p.onAnyTransportNegotiationFailed() +} + +func (h AnyTransportHandler) OnICECandidate(c *webrtc.ICECandidate, target livekit.SignalTarget) error { + return h.p.onICECandidate(c, target) +} + +// ---------------------------------------------------------- + +type PublisherTransportHandler struct { + AnyTransportHandler +} + +func (h PublisherTransportHandler) OnSetRemoteDescriptionOffer() { + h.p.onPublisherSetRemoteDescription() +} + +func (h PublisherTransportHandler) OnAnswer(sd webrtc.SessionDescription, answerId uint32, midToTrackID map[string]string) error { + return h.p.onPublisherAnswer(sd, answerId, midToTrackID) +} + +func (h PublisherTransportHandler) OnTrack(track *webrtc.TrackRemote, rtpReceiver *webrtc.RTPReceiver) { + h.p.onMediaTrack(track, rtpReceiver) +} + +func (h PublisherTransportHandler) OnInitialConnected() { + h.p.onPublisherInitialConnected() +} + +func (h PublisherTransportHandler) OnDataMessage(kind livekit.DataPacket_Kind, data []byte) { + h.p.onReceivedDataMessage(kind, data) +} + +func (h PublisherTransportHandler) OnDataMessageUnlabeled(data []byte) { + h.p.onReceivedDataMessageUnlabeled(data) +} + +func (h PublisherTransportHandler) OnDataTrackMessage(data []byte, arrivalTime int64) { + h.p.onReceivedDataTrackMessage(data, arrivalTime) +} + +func (h PublisherTransportHandler) OnDataSendError(err error) { + h.p.onDataSendError(err) +} + +func (h PublisherTransportHandler) OnUnmatchedMedia(numAudios uint32, numVideos uint32) error { + return h.p.sendMediaSectionsRequirement(numAudios, numVideos) +} + +// ---------------------------------------------------------- + +type SubscriberTransportHandler struct { + AnyTransportHandler +} + +func (h SubscriberTransportHandler) OnOffer(sd webrtc.SessionDescription, offerId uint32, midToTrackID map[string]string) error { + return h.p.onSubscriberOffer(sd, offerId, midToTrackID) +} + +func (h SubscriberTransportHandler) OnStreamStateChange(update *streamallocator.StreamStateUpdate) error { + return h.p.onStreamStateChange(update) +} + +func (h SubscriberTransportHandler) OnInitialConnected() { + h.p.onSubscriberInitialConnected() +} + +func (h SubscriberTransportHandler) OnDataSendError(err error) { + h.p.onDataSendError(err) +} + +// ---------------------------------------------------------- + +type PrimaryTransportHandler struct { + transport.Handler + p *ParticipantImpl +} + +func (h PrimaryTransportHandler) OnInitialConnected() { + h.Handler.OnInitialConnected() + h.p.onPrimaryTransportInitialConnected() +} + +func (h PrimaryTransportHandler) OnFullyEstablished() { + h.p.onPrimaryTransportFullyEstablished() +} + +// ---------------------------------------------------------- + +func (p *ParticipantImpl) setupSignalling() { + p.signalling = signalling.NewSignalling(signalling.SignallingParams{ + Logger: p.params.Logger, + }) + p.signalHandler = signalling.NewSignalHandler(signalling.SignalHandlerParams{ + Logger: p.params.Logger, + Participant: p, + }) + p.signaller = signalling.NewSignallerAsync(signalling.SignallerAsyncParams{ + Logger: p.params.Logger, + Participant: p, + }) +} + +func (p *ParticipantImpl) setupTransportManager() error { + p.twcc = twcc.NewTransportWideCCResponder() + p.twcc.OnFeedback(func(pkts []rtcp.Packet) { + p.postRtcp(pkts) + }) + ath := AnyTransportHandler{p: p} + var pth transport.Handler = PublisherTransportHandler{ath} + var sth transport.Handler = SubscriberTransportHandler{ath} + + subscriberAsPrimary := !p.params.UseOneShotSignallingMode && (p.ProtocolVersion().SubscriberAsPrimary() && p.CanSubscribe()) && !p.params.UseSinglePeerConnection + if subscriberAsPrimary { + sth = PrimaryTransportHandler{sth, p} + } else { + pth = PrimaryTransportHandler{pth, p} + } + + params := TransportManagerParams{ + // primary connection does not change, canSubscribe can change if permission was updated + // after the participant has joined + SubscriberAsPrimary: subscriberAsPrimary, + UseSinglePeerConnection: p.params.UseSinglePeerConnection, + Config: p.params.Config, + Twcc: p.twcc, + ProtocolVersion: p.params.ProtocolVersion, + CongestionControlConfig: p.params.CongestionControlConfig, + EnabledPublishCodecs: p.enabledPublishCodecs, + EnabledSubscribeCodecs: p.enabledSubscribeCodecs, + SimTracks: p.params.SimTracks, + ClientInfo: p.params.ClientInfo, + Migration: p.params.Migration, + AllowTCPFallback: p.params.AllowTCPFallback, + TCPFallbackRTTThreshold: p.params.TCPFallbackRTTThreshold, + AllowUDPUnstableFallback: p.params.AllowUDPUnstableFallback, + TURNSEnabled: p.params.TURNSEnabled, + AllowPlayoutDelay: p.params.PlayoutDelay.GetEnabled(), + DataChannelMaxBufferedAmount: p.params.DataChannelMaxBufferedAmount, + DatachannelSlowThreshold: p.params.DatachannelSlowThreshold, + DatachannelLossyTargetLatency: p.params.DatachannelLossyTargetLatency, + Logger: p.params.Logger.WithComponent(sutils.ComponentTransport), + PublisherHandler: pth, + SubscriberHandler: sth, + DataChannelStats: p.dataChannelStats, + UseOneShotSignallingMode: p.params.UseOneShotSignallingMode, + FireOnTrackBySdp: p.params.FireOnTrackBySdp, + EnableDataTracks: p.params.EnableDataTracks, + } + if p.params.SyncStreams && p.params.PlayoutDelay.GetEnabled() && p.params.ClientInfo.isFirefox() { + // we will disable playout delay for Firefox if the user is expecting + // the streams to be synced. Firefox doesn't support SyncStreams + params.AllowPlayoutDelay = false + } + tm, err := NewTransportManager(params) + if err != nil { + return err + } + + tm.OnICEConfigChanged(func(iceConfig *livekit.ICEConfig) { + p.lock.Lock() + onICEConfigChanged := p.onICEConfigChanged + + if p.params.ClientConf == nil { + p.params.ClientConf = &livekit.ClientConfiguration{} + } + if iceConfig.PreferenceSubscriber == livekit.ICECandidateType_ICT_TLS { + p.params.ClientConf.ForceRelay = livekit.ClientConfigSetting_ENABLED + } else { + // UNSET indicates that clients could override RTCConfiguration to forceRelay + p.params.ClientConf.ForceRelay = livekit.ClientConfigSetting_UNSET + } + p.lock.Unlock() + + if onICEConfigChanged != nil { + onICEConfigChanged(p, iceConfig) + } + }) + + tm.SetSubscriberAllowPause(p.params.SubscriberAllowPause) + p.TransportManager = tm + return nil +} + +func (p *ParticipantImpl) setupUpTrackManager() { + p.UpTrackManager = NewUpTrackManager(UpTrackManagerParams{ + Logger: p.pubLogger, + VersionGenerator: p.params.VersionGenerator, + }) + + p.UpTrackManager.OnPublishedTrackUpdated(func(track types.MediaTrack) { + p.dirty.Store(true) + p.listener().OnTrackUpdated(p, track) + }) + + p.UpTrackManager.OnUpTrackManagerClose(p.onUpTrackManagerClose) +} + +func (p *ParticipantImpl) setupUpDataTrackManager() { + p.UpDataTrackManager = NewUpDataTrackManager(UpDataTrackManagerParams{ + Logger: p.pubLogger, + Participant: p, + }) +} + +func (p *ParticipantImpl) setupSubscriptionManager() { + p.SubscriptionManager = NewSubscriptionManager(SubscriptionManagerParams{ + Participant: p, + Logger: p.subLogger.WithoutSampler(), + TrackResolver: func(lp types.LocalParticipant, ti livekit.TrackID) types.MediaResolverResult { + return p.helper().ResolveMediaTrack(lp, ti) + }, + DataTrackResolver: func(lp types.LocalParticipant, ti livekit.TrackID) types.DataResolverResult { + return p.helper().ResolveDataTrack(lp, ti) + }, + Telemetry: p.params.Telemetry, + OnTrackSubscribed: p.onTrackSubscribed, + OnTrackUnsubscribed: p.onTrackUnsubscribed, + OnSubscriptionError: p.onSubscriptionError, + SubscriptionLimitVideo: p.params.SubscriptionLimitVideo, + SubscriptionLimitAudio: p.params.SubscriptionLimitAudio, + UseOneShotSignallingMode: p.params.UseOneShotSignallingMode, + }) +} + +func (p *ParticipantImpl) MetricsCollectorTimeToCollectMetrics() { + publisherRTT, ok := p.TransportManager.GetPublisherRTT() + if ok { + p.metricsCollector.AddPublisherRTT(p.Identity(), float32(publisherRTT)) + } + + subscriberRTT, ok := p.TransportManager.GetSubscriberRTT() + if ok { + p.metricsCollector.AddSubscriberRTT(float32(subscriberRTT)) + } +} + +func (p *ParticipantImpl) MetricsCollectorBatchReady(mb *livekit.MetricsBatch) { + p.listener().OnMetrics(p, &livekit.DataPacket{ + ParticipantIdentity: string(p.Identity()), + Value: &livekit.DataPacket_Metrics{ + Metrics: mb, + }, + }) +} + +func (p *ParticipantImpl) MetricsReporterBatchReady(mb *livekit.MetricsBatch) { + dpData, err := proto.Marshal(&livekit.DataPacket{ + ParticipantIdentity: string(p.Identity()), + Value: &livekit.DataPacket_Metrics{ + Metrics: mb, + }, + }) + if err != nil { + p.params.Logger.Errorw("failed to marshal data packet", err) + return + } + + p.TransportManager.SendDataMessage(livekit.DataPacket_RELIABLE, dpData) +} + +func (p *ParticipantImpl) setupMetrics() { + if !p.params.EnableMetrics { + return + } + + p.metricTimestamper = metric.NewMetricTimestamper(metric.MetricTimestamperParams{ + Config: p.params.MetricConfig.Timestamper, + Logger: p.params.Logger, + }) + p.metricsCollector = metric.NewMetricsCollector(metric.MetricsCollectorParams{ + ParticipantIdentity: p.Identity(), + Config: p.params.MetricConfig.Collector, + Provider: p, + Logger: p.params.Logger, + }) + p.metricsReporter = metric.NewMetricsReporter(metric.MetricsReporterParams{ + ParticipantIdentity: p.Identity(), + Config: p.params.MetricConfig.Reporter, + Consumer: p, + Logger: p.params.Logger, + }) +} + +func (p *ParticipantImpl) updateState(state livekit.ParticipantInfo_State) { + var oldState livekit.ParticipantInfo_State + for { + oldState = p.state.Load().(livekit.ParticipantInfo_State) + if state <= oldState { + p.params.Logger.Debugw("ignoring out of order participant state", "state", state.String()) + return + } + if state == livekit.ParticipantInfo_ACTIVE { + p.lastActiveAt.CompareAndSwap(nil, pointer.To(time.Now())) + } + if p.state.CompareAndSwap(oldState, state) { + break + } + } + + p.params.Logger.Debugw("updating participant state", "state", state.String()) + p.dirty.Store(true) + + go p.listener().OnStateChange(p) + + if state == livekit.ParticipantInfo_DISCONNECTED && oldState == livekit.ParticipantInfo_ACTIVE { + p.disconnectedAt.Store(pointer.To(time.Now())) + prometheus.RecordSessionDuration(int(p.ProtocolVersion()), time.Since(*p.lastActiveAt.Load())) + } +} + +func (p *ParticipantImpl) setIsPublisher(isPublisher bool) { + if p.isPublisher.Swap(isPublisher) == isPublisher { + return + } + + p.lock.Lock() + p.requireBroadcast = true + p.lock.Unlock() + + p.dirty.Store(true) + + // trigger update as well if participant is already fully connected + if p.State() == livekit.ParticipantInfo_ACTIVE { + p.listener().OnParticipantUpdate(p) + } +} + +// when the server has an offer for participant +func (p *ParticipantImpl) onSubscriberOffer(offer webrtc.SessionDescription, offerId uint32, midToTrackID map[string]string) error { + p.subLogger.Debugw( + "sending offer", + "transport", livekit.SignalTarget_SUBSCRIBER, + "offer", offer, + "offerId", offerId, + "midToTrackID", midToTrackID, + ) + return p.sendSdpOffer(offer, offerId, midToTrackID) +} + +func (p *ParticipantImpl) removePublishedTrack(track types.MediaTrack) { + p.RemovePublishedTrack(track, false) + if p.ProtocolVersion().SupportsUnpublish() { + p.sendTrackUnpublished(track.ID()) + } else { + // for older clients that don't support unpublish, mute to avoid them sending data + p.sendTrackMuted(track.ID(), true) + } +} + +// when a new remoteTrack is created, creates a Track and adds it to room +func (p *ParticipantImpl) onMediaTrack(rtcTrack *webrtc.TrackRemote, rtpReceiver *webrtc.RTPReceiver) { + if p.IsDisconnected() { + return + } + + var codec webrtc.RTPCodecParameters + var fromSdp bool + if rtcTrack.Kind() == webrtc.RTPCodecTypeVideo && p.params.ClientInfo.FireTrackByRTPPacket() { + if rtcTrack.Codec().PayloadType == 0 { + go func() { + // wait for the first packet to determine the codec + bytes := make([]byte, 1500) + _, _, err := rtcTrack.Read(bytes) + if err != nil { + if !errors.Is(err, io.EOF) { + p.params.Logger.Warnw( + "could not read first packet to determine codec, track will be ignored", err, + "trackID", rtcTrack.ID(), + "StreamID", rtcTrack.StreamID(), + ) + } + return + } + p.onMediaTrack(rtcTrack, rtpReceiver) + }() + return + } + codec = rtcTrack.Codec() + } else { + // track fired by sdp + codecs := rtpReceiver.GetParameters().Codecs + if len(codecs) == 0 { + p.pubLogger.Errorw( + "no negotiated codecs for track, track will be ignored", nil, + "trackID", rtcTrack.ID(), + "StreamID", rtcTrack.StreamID(), + ) + return + } + codec = codecs[0] + fromSdp = true + } + p.params.Logger.Debugw( + "onMediaTrack", + "codec", codec, + "payloadType", codec.PayloadType, + "fromSdp", fromSdp, + "parameters", rtpReceiver.GetParameters(), + ) + + var track sfu.TrackRemote = sfu.NewTrackRemoteFromSdp(rtcTrack, codec) + publishedTrack, isNewTrack, isReceiverAdded, sdpRids := p.mediaTrackReceived(track, rtpReceiver) + if publishedTrack == nil { + p.pubLogger.Debugw( + "webrtc track published but can't find MediaTrack in pendingTracks", + "kind", track.Kind().String(), + "webrtcTrackID", track.ID(), + "rid", track.RID(), + "ssrc", track.SSRC(), + "rtxSsrc", track.RtxSSRC(), + "mime", mime.NormalizeMimeType(codec.MimeType), + "isReceiverAdded", isReceiverAdded, + "sdpRids", logger.StringSlice(sdpRids[:]), + ) + return + } + + if !p.CanPublishSource(publishedTrack.Source()) { + p.pubLogger.Warnw("no permission to publish mediaTrack", nil, + "source", publishedTrack.Source(), + ) + p.removePublishedTrack(publishedTrack) + return + } + + p.TransportManager.RTPStreamPublished( + uint32(track.SSRC()), + p.TransportManager.GetPublisherMid(rtpReceiver), + track.RID(), + ) + + p.setIsPublisher(true) + p.dirty.Store(true) + + p.pubLogger.Infow( + "mediaTrack published", + "kind", track.Kind().String(), + "trackID", publishedTrack.ID(), + "webrtcTrackID", track.ID(), + "rid", track.RID(), + "ssrc", track.SSRC(), + "rtxSsrc", track.RtxSSRC(), + "mime", mime.NormalizeMimeType(codec.MimeType), + "trackInfo", logger.Proto(publishedTrack.ToProto()), + "fromSdp", fromSdp, + "isReceiverAdded", isReceiverAdded, + "sdpRids", logger.StringSlice(sdpRids[:]), + ) + + if !isNewTrack && !publishedTrack.HasPendingCodec() && p.IsReady() { + p.listener().OnTrackUpdated(p, publishedTrack) + } +} + +func (p *ParticipantImpl) handlePendingRemoteTracks() { + p.pendingTracksLock.Lock() + pendingTracks := p.pendingRemoteTracks + p.pendingRemoteTracks = nil + p.pendingTracksLock.Unlock() + for _, rt := range pendingTracks { + p.onMediaTrack(rt.track, rt.receiver) + } +} + +func (p *ParticipantImpl) onReceivedDataMessage(kind livekit.DataPacket_Kind, data []byte) { + if p.IsDisconnected() || !p.CanPublishData() { + return + } + + p.dataChannelStats.AddBytes(uint64(len(data)), false) + + dp := &livekit.DataPacket{} + if err := proto.Unmarshal(data, dp); err != nil { + p.pubLogger.Warnw("could not parse data packet", err) + return + } + + dp.ParticipantSid = string(p.ID()) + if kind == livekit.DataPacket_RELIABLE && dp.Sequence > 0 { + if p.reliableDataInfo.stopReliableByMigrateOut.Load() { + return + } + + if migrationCache := p.reliableDataInfo.migrateInPubDataCache.Load(); migrationCache != nil { + switch migrationCache.Add(dp) { + case MigrationDataCacheStateWaiting: + // waiting for the reliable sequence to continue from last node + return + + case MigrationDataCacheStateTimeout: + p.reliableDataInfo.migrateInPubDataCache.Store(nil) + // waiting time out, handle all cached messages + cachedMsgs := migrationCache.Get() + if len(cachedMsgs) == 0 { + p.pubLogger.Warnw( + "migration data cache timed out, no cached messages received", nil, + "lastPubReliableSeq", p.params.LastPubReliableSeq, + ) + } else { + p.pubLogger.Warnw( + "migration data cache timed out, handling cached messages", nil, + "cachedFirstSeq", cachedMsgs[0].Sequence, + "cachedLastSeq", cachedMsgs[len(cachedMsgs)-1].Sequence, + "lastPubReliableSeq", p.params.LastPubReliableSeq, + ) + } + for _, cachedDp := range cachedMsgs { + p.handleReceivedDataMessage(kind, cachedDp) + } + return + + case MigrationDataCacheStateDone: + // see the continuous message, drop the cache + p.reliableDataInfo.migrateInPubDataCache.Store(nil) + } + } + } + + p.handleReceivedDataMessage(kind, dp) +} + +func (p *ParticipantImpl) handleReceivedDataMessage(kind livekit.DataPacket_Kind, dp *livekit.DataPacket) { + if kind == livekit.DataPacket_RELIABLE && dp.Sequence > 0 { + if p.reliableDataInfo.lastPubReliableSeq.Load() >= dp.Sequence { + p.params.Logger.Infow( + "received out of order reliable data packet", + "lastPubReliableSeq", p.reliableDataInfo.lastPubReliableSeq.Load(), + "dpSequence", dp.Sequence, + ) + return + } + + p.reliableDataInfo.lastPubReliableSeq.Store(dp.Sequence) + } + + // trust the channel that it came in as the source of truth + dp.Kind = kind + + shouldForwardData := true + shouldForwardMetrics := false + overrideSenderIdentity := true + // only forward on user payloads + switch payload := dp.Value.(type) { + case *livekit.DataPacket_User: + if payload.User == nil { + return + } + u := payload.User + if p.Hidden() { + u.ParticipantSid = "" + u.ParticipantIdentity = "" + } else { + u.ParticipantSid = string(p.ID()) + u.ParticipantIdentity = string(p.params.Identity) + } + if len(dp.DestinationIdentities) != 0 { + u.DestinationIdentities = dp.DestinationIdentities + } else { + dp.DestinationIdentities = u.DestinationIdentities + } + case *livekit.DataPacket_SipDtmf: + if payload.SipDtmf == nil { + return + } + case *livekit.DataPacket_Transcription: + if payload.Transcription == nil { + return + } + if !p.IsAgent() { + shouldForwardData = false + } + case *livekit.DataPacket_ChatMessage: + if payload.ChatMessage == nil { + return + } + if p.IsAgent() && dp.ParticipantIdentity != "" && string(p.params.Identity) != dp.ParticipantIdentity { + overrideSenderIdentity = false + payload.ChatMessage.Generated = true + } + case *livekit.DataPacket_Metrics: + if payload.Metrics == nil { + return + } + shouldForwardData = false + shouldForwardMetrics = true + p.metricTimestamper.Process(payload.Metrics) + case *livekit.DataPacket_RpcRequest: + if payload.RpcRequest == nil { + return + } + p.pubLogger.Infow( + "received RPC request", + "method", payload.RpcRequest.Method, + "rpc_request_id", payload.RpcRequest.Id, + "destinationIdentities", dp.DestinationIdentities, + ) + case *livekit.DataPacket_RpcResponse: + if payload.RpcResponse == nil { + return + } + p.pubLogger.Infow( + "received RPC response", + "rpc_request_id", payload.RpcResponse.RequestId, + ) + + rpcResponse := payload.RpcResponse + switch res := rpcResponse.Value.(type) { + case *livekit.RpcResponse_Payload: + shouldForwardData = !p.handleIncomingRpcResponse(payload.RpcResponse.GetRequestId(), res.Payload, nil) + case *livekit.RpcResponse_Error: + shouldForwardData = !p.handleIncomingRpcResponse(payload.RpcResponse.GetRequestId(), "", &utils.DataChannelRpcError{ + Code: utils.DataChannelRpcErrorCode(res.Error.GetCode()), + Message: res.Error.GetMessage(), + Data: res.Error.GetData(), + }) + } + case *livekit.DataPacket_RpcAck: + if payload.RpcAck == nil { + return + } + p.pubLogger.Infow( + "received RPC ack", + "rpc_request_id", payload.RpcAck.RequestId, + ) + + shouldForwardData = !p.handleIncomingRpcAck(payload.RpcAck.GetRequestId()) + case *livekit.DataPacket_StreamHeader: + if payload.StreamHeader == nil { + return + } + + prometheus.RecordDataPacketStream(payload.StreamHeader, len(dp.DestinationIdentities)) + + if p.IsAgent() && dp.ParticipantIdentity != "" && string(p.params.Identity) != dp.ParticipantIdentity { + switch contentHeader := payload.StreamHeader.ContentHeader.(type) { + case *livekit.DataStream_Header_TextHeader: + contentHeader.TextHeader.Generated = true + overrideSenderIdentity = false + default: + overrideSenderIdentity = true + } + } + case *livekit.DataPacket_StreamChunk: + if payload.StreamChunk == nil { + return + } + case *livekit.DataPacket_StreamTrailer: + if payload.StreamTrailer == nil { + return + } + case *livekit.DataPacket_EncryptedPacket: + if payload.EncryptedPacket == nil { + return + } + default: + p.pubLogger.Warnw("received unsupported data packet", nil, "payload", payload) + } + + // SFU typically asserts the sender's identity. However, agents are able to + // publish data on behalf of the participant in case of transcriptions/text streams + // in those cases we'd leave the existing identity on the data packet alone. + if overrideSenderIdentity { + if p.Hidden() { + dp.ParticipantIdentity = "" + } else { + dp.ParticipantIdentity = string(p.params.Identity) + } + } + + if shouldForwardData { + p.listener().OnDataMessage(p, kind, dp) + } + if shouldForwardMetrics { + p.listener().OnMetrics(p, dp) + } +} + +func (p *ParticipantImpl) onReceivedDataMessageUnlabeled(data []byte) { + if p.IsDisconnected() || !p.CanPublishData() { + return + } + + p.dataChannelStats.AddBytes(uint64(len(data)), false) + + p.listener().OnDataMessageUnlabeled(p, data) +} + +func (p *ParticipantImpl) onICECandidate(c *webrtc.ICECandidate, target livekit.SignalTarget) error { + if p.IsDisconnected() || p.IsClosed() { + return nil + } + + if target == livekit.SignalTarget_SUBSCRIBER && p.MigrateState() == types.MigrateStateInit { + return nil + } + + return p.sendICECandidate(c, target) +} + +func (p *ParticipantImpl) onPublisherInitialConnected() { + p.SetMigrateState(types.MigrateStateComplete) + + if p.supervisor != nil { + p.supervisor.SetPublisherPeerConnectionConnected(true) + } + + if p.params.UseOneShotSignallingMode || p.params.UseSinglePeerConnection { + go p.subscriberRTCPWorker() + + p.setDownTracksConnected() + } + + p.pubRTCPQueue.Start() +} + +func (p *ParticipantImpl) onSubscriberInitialConnected() { + go p.subscriberRTCPWorker() + + p.setDownTracksConnected() +} + +func (p *ParticipantImpl) onPrimaryTransportInitialConnected() { + if !p.hasPendingMigratedTrack() && len(p.GetPublishedTracks()) == 0 { + // if there are no published tracks, declare migration complete on primary transport initial connect, + // else, wait for all tracks to be published and publisher peer connection established + p.SetMigrateState(types.MigrateStateComplete) + } + + if !p.sessionStartRecorded.Swap(true) { + prometheus.RecordSessionStartTime(int(p.ProtocolVersion()), time.Since(p.params.SessionStartTime)) + } + p.updateState(livekit.ParticipantInfo_ACTIVE) +} + +func (p *ParticipantImpl) onPrimaryTransportFullyEstablished() { + p.replayJoiningReliableMessages() +} + +func (p *ParticipantImpl) clearDisconnectTimer() { + p.lock.Lock() + if p.disconnectTimer != nil { + p.disconnectTimer.Stop() + p.disconnectTimer = nil + } + p.lock.Unlock() +} + +func (p *ParticipantImpl) setupDisconnectTimer() { + p.clearDisconnectTimer() + + p.lock.Lock() + p.disconnectTimer = time.AfterFunc(disconnectCleanupDuration, func() { + p.clearDisconnectTimer() + + if p.IsClosed() || p.IsDisconnected() { + return + } + _ = p.Close(true, types.ParticipantCloseReasonPeerConnectionDisconnected, false) + }) + p.lock.Unlock() +} + +func (p *ParticipantImpl) onAnyTransportFailed() { + if p.params.UseOneShotSignallingMode { + // as there is no way to notify participant, close the participant on transport failure + _ = p.Close(false, types.ParticipantCloseReasonPeerConnectionDisconnected, false) + return + } + + p.sendLeaveRequest( + types.ParticipantCloseReasonPeerConnectionDisconnected, + true, // isExpectedToResume + false, // isExpectedToReconnect + true, // sendOnlyIfSupportingLeaveRequestWithAction + ) + + // clients support resuming of connections when signalling becomes disconnected + p.CloseSignalConnection(types.SignallingCloseReasonTransportFailure) + + // detect when participant has actually left. + p.setupDisconnectTimer() +} + +// subscriberRTCPWorker sends SenderReports periodically when the participant is subscribed to +// other publishedTracks in the room. +func (p *ParticipantImpl) subscriberRTCPWorker() { + defer func() { + if r := Recover(p.GetLogger()); r != nil { + os.Exit(1) + } + }() + for { + if p.IsDisconnected() { + return + } + + subscribedTracks := p.SubscriptionManager.GetSubscribedTracks() + + // send in batches of sdBatchSize + batchSize := 0 + var pkts []rtcp.Packet + var sd []rtcp.SourceDescriptionChunk + for _, subTrack := range subscribedTracks { + sr := subTrack.DownTrack().CreateSenderReport() + chunks := subTrack.DownTrack().CreateSourceDescriptionChunks() + if sr == nil || chunks == nil { + continue + } + + pkts = append(pkts, sr) + sd = append(sd, chunks...) + numItems := 0 + for _, chunk := range chunks { + numItems += len(chunk.Items) + } + batchSize = batchSize + 1 + numItems + if batchSize >= sdBatchSize { + if len(sd) != 0 { + pkts = append(pkts, &rtcp.SourceDescription{Chunks: sd}) + } + if err := p.TransportManager.WriteSubscriberRTCP(pkts); err != nil { + if IsEOF(err) { + return + } + p.subLogger.Errorw("could not send down track reports", err) + } + + pkts = pkts[:0] + sd = sd[:0] + batchSize = 0 + } + } + + if len(pkts) != 0 || len(sd) != 0 { + if len(sd) != 0 { + pkts = append(pkts, &rtcp.SourceDescription{Chunks: sd}) + } + if err := p.TransportManager.WriteSubscriberRTCP(pkts); err != nil { + if IsEOF(err) { + return + } + p.subLogger.Errorw("could not send down track reports", err) + } + } + + time.Sleep(3 * time.Second) + } +} + +func (p *ParticipantImpl) onStreamStateChange(update *streamallocator.StreamStateUpdate) error { + if len(update.StreamStates) == 0 { + return nil + } + + streamStateUpdate := &livekit.StreamStateUpdate{} + for _, streamStateInfo := range update.StreamStates { + state := livekit.StreamState_ACTIVE + if streamStateInfo.State == streamallocator.StreamStatePaused { + state = livekit.StreamState_PAUSED + } + streamStateUpdate.StreamStates = append(streamStateUpdate.StreamStates, &livekit.StreamStateInfo{ + ParticipantSid: string(streamStateInfo.ParticipantID), + TrackSid: string(streamStateInfo.TrackID), + State: state, + }) + } + + return p.sendStreamStateUpdate(streamStateUpdate) +} + +func (p *ParticipantImpl) onSubscribedMaxQualityChange( + trackID livekit.TrackID, + trackInfo *livekit.TrackInfo, + subscribedQualities []*livekit.SubscribedCodec, + maxSubscribedQualities []types.SubscribedCodecQuality, +) error { + if p.params.DisableDynacast { + return nil + } + + if len(subscribedQualities) == 0 { + return nil + } + + // send layer info about max subscription changes to telemetry + for _, maxSubscribedQuality := range maxSubscribedQualities { + ti := &livekit.TrackInfo{ + Sid: trackInfo.Sid, + Type: trackInfo.Type, + } + for _, layer := range buffer.GetVideoLayersForMimeType(maxSubscribedQuality.CodecMime, trackInfo) { + if layer.Quality == maxSubscribedQuality.Quality { + ti.Width = layer.Width + ti.Height = layer.Height + break + } + } + p.params.Telemetry.TrackMaxSubscribedVideoQuality( + context.Background(), + p.ID(), + ti, + maxSubscribedQuality.CodecMime, + maxSubscribedQuality.Quality, + ) + } + + // normalize the codec name + for _, subscribedQuality := range subscribedQualities { + subscribedQuality.Codec = strings.ToLower(strings.TrimPrefix(subscribedQuality.Codec, mime.MimeTypePrefixVideo)) + } + + subscribedQualityUpdate := &livekit.SubscribedQualityUpdate{ + TrackSid: string(trackID), + SubscribedQualities: subscribedQualities[0].Qualities, // for compatible with old client + SubscribedCodecs: subscribedQualities, + } + + p.pubLogger.Debugw( + "sending max subscribed quality", + "trackID", trackID, + "qualities", subscribedQualities, + "max", maxSubscribedQualities, + ) + return p.sendSubscribedQualityUpdate(subscribedQualityUpdate) +} + +func (p *ParticipantImpl) onSubscribedAudioCodecChange( + trackID livekit.TrackID, + codecs []*livekit.SubscribedAudioCodec, +) error { + if p.params.DisableDynacast { + return nil + } + + if len(codecs) == 0 { + return nil + } + + // normalize the codec name + for _, codec := range codecs { + codec.Codec = strings.ToLower(strings.TrimPrefix(codec.Codec, mime.MimeTypePrefixAudio)) + } + + subscribedAudioCodecUpdate := &livekit.SubscribedAudioCodecUpdate{ + TrackSid: string(trackID), + SubscribedAudioCodecs: codecs, + } + p.pubLogger.Debugw( + "sending subscribed audio codec update", + "trackID", trackID, + "update", logger.Proto(subscribedAudioCodecUpdate), + ) + return p.sendSubscribedAudioCodecUpdate(subscribedAudioCodecUpdate) +} + +func (p *ParticipantImpl) addPendingTrackLocked(req *livekit.AddTrackRequest) *livekit.TrackInfo { + if req.Sid != "" { + track := p.GetPublishedTrack(livekit.TrackID(req.Sid)) + if track == nil { + p.pubLogger.Infow("could not find existing track for multi-codec simulcast", "trackID", req.Sid) + p.sendRequestResponse(&livekit.RequestResponse{ + Reason: livekit.RequestResponse_NOT_FOUND, + Request: &livekit.RequestResponse_AddTrack{ + AddTrack: utils.CloneProto(req), + }, + }) + return nil + } + + track.(*MediaTrack).UpdateCodecInfo(req.SimulcastCodecs) + return track.ToProto() + } + + backupCodecPolicy := req.BackupCodecPolicy + + // enable simulcast codec for audio by default + if (backupCodecPolicy != livekit.BackupCodecPolicy_REGRESSION && req.Type == livekit.TrackType_AUDIO) || + (backupCodecPolicy != livekit.BackupCodecPolicy_SIMULCAST && p.params.DisableCodecRegression) { + backupCodecPolicy = livekit.BackupCodecPolicy_SIMULCAST + } + + cloneLayers := func(layers []*livekit.VideoLayer) []*livekit.VideoLayer { + if len(layers) == 0 { + return nil + } + + clonedLayers := make([]*livekit.VideoLayer, 0, len(layers)) + for _, l := range layers { + clonedLayers = append(clonedLayers, utils.CloneProto(l)) + } + slices.SortFunc(clonedLayers, func(i, j *livekit.VideoLayer) int { + return int(i.Quality) - int(j.Quality) + }) + return clonedLayers + } + + ti := &livekit.TrackInfo{ + Type: req.Type, + Name: req.Name, + Width: req.Width, + Height: req.Height, + Muted: req.Muted, + DisableDtx: req.DisableDtx, + Source: req.Source, + Layers: cloneLayers(req.Layers), + DisableRed: req.DisableRed, + Stereo: req.Stereo, + Encryption: req.Encryption, + Stream: req.Stream, + BackupCodecPolicy: backupCodecPolicy, + AudioFeatures: sutils.DedupeSlice(req.AudioFeatures), + } + if req.Stereo && !slices.Contains(ti.AudioFeatures, livekit.AudioTrackFeature_TF_STEREO) { + ti.AudioFeatures = append(ti.AudioFeatures, livekit.AudioTrackFeature_TF_STEREO) + } + if req.DisableDtx && !slices.Contains(ti.AudioFeatures, livekit.AudioTrackFeature_TF_NO_DTX) { + ti.AudioFeatures = append(ti.AudioFeatures, livekit.AudioTrackFeature_TF_NO_DTX) + } + if ti.Stream == "" { + ti.Stream = StreamFromTrackSource(ti.Source) + } + p.setTrackID(req.Cid, ti) + + if len(req.SimulcastCodecs) == 0 { + // clients not supporting simulcast codecs, synthesise a codec + videoLayerMode := livekit.VideoLayer_MODE_UNUSED + if p.params.ClientInfo.isOBS() { + videoLayerMode = livekit.VideoLayer_ONE_SPATIAL_LAYER_PER_STREAM_INCOMPLETE_RTCP_SR + } + ti.Codecs = append(ti.Codecs, &livekit.SimulcastCodecInfo{ + Cid: req.Cid, + Layers: cloneLayers(req.Layers), + VideoLayerMode: videoLayerMode, + }) + } else { + seenCodecs := make(map[string]struct{}) + for _, codec := range req.SimulcastCodecs { + if codec.Codec == "" { + p.pubLogger.Warnw( + "simulcast codec without mime type", nil, + "trackID", ti.Sid, + "track", logger.Proto(ti), + "addTrackRequest", logger.Proto(req), + ) + } + + mimeType := codec.Codec + videoLayerMode := codec.VideoLayerMode + switch req.Type { + case livekit.TrackType_VIDEO: + if !mime.IsMimeTypeStringVideo(mimeType) { + mimeType = mime.MimeTypePrefixVideo + mimeType + } + if !IsCodecEnabled(p.enabledPublishCodecs, webrtc.RTPCodecCapability{MimeType: mimeType}) { + altCodec := selectAlternativeVideoCodec(p.enabledPublishCodecs) + p.pubLogger.Infow( + "falling back to alternative video codec", + "codec", mimeType, + "altCodec", altCodec, + "enabledPublishCodecs", logger.ProtoSlice(p.enabledPublishCodecs), + "trackID", ti.Sid, + ) + // select an alternative MIME type that's generally supported + mimeType = altCodec + } + if videoLayerMode == livekit.VideoLayer_MODE_UNUSED { + if mime.IsMimeTypeStringSVCCapable(mimeType) { + videoLayerMode = livekit.VideoLayer_MULTIPLE_SPATIAL_LAYERS_PER_STREAM + } else { + if p.params.ClientInfo.isOBS() { + videoLayerMode = livekit.VideoLayer_ONE_SPATIAL_LAYER_PER_STREAM_INCOMPLETE_RTCP_SR + } else { + videoLayerMode = livekit.VideoLayer_ONE_SPATIAL_LAYER_PER_STREAM + } + } + } + + case livekit.TrackType_AUDIO: + if !mime.IsMimeTypeStringAudio(mimeType) { + mimeType = mime.MimeTypePrefixAudio + mimeType + } + if !IsCodecEnabled(p.enabledPublishCodecs, webrtc.RTPCodecCapability{MimeType: mimeType}) { + altCodec := selectAlternativeAudioCodec(p.enabledPublishCodecs) + p.pubLogger.Infow( + "falling back to alternative audio codec", + "codec", mimeType, + "altCodec", altCodec, + "enabledPublishCodecs", logger.ProtoSlice(p.enabledPublishCodecs), + "trackID", ti.Sid, + ) + // select an alternative MIME type that's generally supported + mimeType = altCodec + } + } + + if _, ok := seenCodecs[mimeType]; ok || mimeType == "" { + continue + } + seenCodecs[mimeType] = struct{}{} + + ti.Codecs = append(ti.Codecs, &livekit.SimulcastCodecInfo{ + MimeType: mimeType, + Cid: codec.Cid, + VideoLayerMode: videoLayerMode, + }) + } + + // set up layers with codec specific layers, + // fall back to common layers if codec specific layer is not available + for idx, codec := range ti.Codecs { + found := false + for _, simulcastCodec := range req.SimulcastCodecs { + if mime.GetMimeTypeCodec(codec.MimeType) != mime.NormalizeMimeTypeCodec(simulcastCodec.Codec) { + continue + } + + if len(simulcastCodec.Layers) != 0 { + codec.Layers = cloneLayers(simulcastCodec.Layers) + } else { + codec.Layers = cloneLayers(req.Layers) + } + found = true + break + } + + if !found { + // could happen if an alternate codec is selected and that is not in the simulcast codecs list + codec.Layers = cloneLayers(req.Layers) + } + + // populate simulcast flag for compatibility, true if primary codec is not SVC and has multiple layers + if idx == 0 && codec.VideoLayerMode != livekit.VideoLayer_MULTIPLE_SPATIAL_LAYERS_PER_STREAM && len(codec.Layers) > 1 { + ti.Simulcast = true + } + } + } + + p.params.Telemetry.TrackPublishRequested(context.Background(), p.ID(), p.Identity(), utils.CloneProto(ti)) + + if p.supervisor != nil { + p.supervisor.AddPublication(livekit.TrackID(ti.Sid)) + p.supervisor.SetPublicationMute(livekit.TrackID(ti.Sid), ti.Muted) + } + + if p.getPublishedTrackBySignalCid(req.Cid) != nil || p.getPublishedTrackBySdpCid(req.Cid) != nil || p.pendingTracks[req.Cid] != nil { + if p.pendingTracks[req.Cid] == nil { + pti := &pendingTrackInfo{ + trackInfos: []*livekit.TrackInfo{ti}, + createdAt: time.Now(), + queued: true, + } + if ti.Type == livekit.TrackType_VIDEO { + pti.sdpRids = buffer.DefaultVideoLayersRid // could get updated from SDP + } + p.pendingTracks[req.Cid] = pti + } else { + p.pendingTracks[req.Cid].trackInfos = append(p.pendingTracks[req.Cid].trackInfos, ti) + } + p.pubLogger.Infow( + "pending track queued", + "trackID", ti.Sid, + "request", logger.Proto(req), + "pendingTrack", p.pendingTracks[req.Cid], + ) + p.sendRequestResponse(&livekit.RequestResponse{ + Reason: livekit.RequestResponse_QUEUED, + Request: &livekit.RequestResponse_AddTrack{ + AddTrack: utils.CloneProto(req), + }, + }) + return nil + } + + pti := &pendingTrackInfo{ + trackInfos: []*livekit.TrackInfo{ti}, + createdAt: time.Now(), + } + if ti.Type == livekit.TrackType_VIDEO { + pti.sdpRids = buffer.DefaultVideoLayersRid // could get updated from SDP + } + p.pendingTracks[req.Cid] = pti + p.pubLogger.Debugw( + "pending track added", + "trackID", ti.Sid, + "request", logger.Proto(req), + "pendingTrack", p.pendingTracks[req.Cid], + ) + return ti +} + +func (p *ParticipantImpl) GetPendingTrack(trackID livekit.TrackID) *livekit.TrackInfo { + p.pendingTracksLock.RLock() + defer p.pendingTracksLock.RUnlock() + + for _, t := range p.pendingTracks { + if livekit.TrackID(t.trackInfos[0].Sid) == trackID { + return t.trackInfos[0] + } + } + + return nil +} + +func (p *ParticipantImpl) HasConnected() bool { + return p.TransportManager.HasSubscriberEverConnected() || p.TransportManager.HasPublisherEverConnected() +} + +func (p *ParticipantImpl) SetTrackMuted(mute *livekit.MuteTrackRequest, fromAdmin bool) *livekit.TrackInfo { + // when request is coming from admin, send message to current participant + if fromAdmin { + p.sendTrackMuted(livekit.TrackID(mute.Sid), mute.Muted) + } + + return p.setTrackMuted(mute, fromAdmin) +} + +func (p *ParticipantImpl) setTrackMuted(mute *livekit.MuteTrackRequest, fromAdmin bool) *livekit.TrackInfo { + trackID := livekit.TrackID(mute.Sid) + p.dirty.Store(true) + if p.supervisor != nil { + p.supervisor.SetPublicationMute(trackID, mute.Muted) + } + + track, changed := p.UpTrackManager.SetPublishedTrackMuted(trackID, mute.Muted) + var trackInfo *livekit.TrackInfo + if track != nil { + trackInfo = track.ToProto() + } + + // update mute status in any pending/queued add track requests too + p.pendingTracksLock.RLock() + for _, pti := range p.pendingTracks { + for i, ti := range pti.trackInfos { + if livekit.TrackID(ti.Sid) == trackID { + ti = utils.CloneProto(ti) + changed = changed || ti.Muted != mute.Muted + ti.Muted = mute.Muted + pti.trackInfos[i] = ti + if trackInfo == nil { + trackInfo = ti + } + } + } + } + p.pendingTracksLock.RUnlock() + + if trackInfo != nil && changed { + if mute.Muted { + p.params.Telemetry.TrackMuted(context.Background(), p.ID(), trackInfo) + } else { + p.params.Telemetry.TrackUnmuted(context.Background(), p.ID(), trackInfo) + } + } + + if trackInfo == nil && !fromAdmin { + p.sendRequestResponse(&livekit.RequestResponse{ + Reason: livekit.RequestResponse_NOT_FOUND, + Request: &livekit.RequestResponse_Mute{ + Mute: utils.CloneProto(mute), + }, + }) + } + + return trackInfo +} + +func (p *ParticipantImpl) mediaTrackReceived( + track sfu.TrackRemote, + rtpReceiver *webrtc.RTPReceiver, +) (*MediaTrack, bool, bool, buffer.VideoLayersRid) { + p.pendingTracksLock.Lock() + newTrack := false + + mid := p.TransportManager.GetPublisherMid(rtpReceiver) + p.pubLogger.Debugw( + "media track received", + "kind", track.Kind().String(), + "trackID", track.ID(), + "rid", track.RID(), + "ssrc", track.SSRC(), + "rtxSsrc", track.RtxSSRC(), + "mime", mime.NormalizeMimeType(track.Codec().MimeType), + "mid", mid, + ) + if mid == "" { + p.pendingRemoteTracks = append( + p.pendingRemoteTracks, + &pendingRemoteTrack{track: track.RTCTrack(), receiver: rtpReceiver}, + ) + p.pendingTracksLock.Unlock() + p.pubLogger.Warnw("could not get mid for track", nil, "trackID", track.ID()) + return nil, false, false, buffer.VideoLayersRid{} + } + + // use existing media track to handle simulcast + var pubTime time.Duration + var isMigrated bool + var ridsFromSdp buffer.VideoLayersRid + mt, ok := p.getPublishedTrackBySdpCid(track.ID()).(*MediaTrack) + if !ok { + signalCid, ti, sdpRids, migrated, createdAt := p.getPendingTrack(track.ID(), ToProtoTrackKind(track.Kind()), true) + ridsFromSdp = sdpRids + if ti == nil { + p.pendingRemoteTracks = append( + p.pendingRemoteTracks, + &pendingRemoteTrack{track: track.RTCTrack(), receiver: rtpReceiver}, + ) + p.pendingTracksLock.Unlock() + return nil, false, false, ridsFromSdp + } + isMigrated = migrated + + // check if the migrated track has correct codec + if migrated && len(ti.Codecs) > 0 { + parameters := rtpReceiver.GetParameters() + var codecFound int + for _, c := range ti.Codecs { + for _, nc := range parameters.Codecs { + if mime.IsMimeTypeStringEqual(nc.MimeType, c.MimeType) { + codecFound++ + break + } + } + } + if codecFound != len(ti.Codecs) { + p.pubLogger.Warnw("migrated track codec mismatched", nil, "track", logger.Proto(ti), "webrtcCodec", parameters) + p.pendingTracksLock.Unlock() + p.IssueFullReconnect(types.ParticipantCloseReasonMigrateCodecMismatch) + return nil, false, false, ridsFromSdp + } + } + + ti.MimeType = track.Codec().MimeType + // set mime_type for tracks that the AddTrackRequest do not have simulcast_codecs set + if len(ti.Codecs) == 1 && ti.Codecs[0].MimeType == "" { + ti.Codecs[0].MimeType = track.Codec().MimeType + } + if utils.TimedVersionFromProto(ti.Version).IsZero() { + // only assign version on a fresh publish, i. e. avoid updating version in scenarios like migration + ti.Version = p.params.VersionGenerator.Next().ToProto() + } + + mimeType := mime.NormalizeMimeType(ti.MimeType) + for _, layer := range ti.Layers { + layer.SpatialLayer = buffer.VideoQualityToSpatialLayer(mimeType, layer.Quality, ti) + layer.Rid = buffer.VideoQualityToRid(mimeType, layer.Quality, ti, sdpRids) + } + + mt = p.addMediaTrack(signalCid, ti) + newTrack = true + + // if the addTrackRequest is sent before participant active then it means the client tries to publish + // before fully connected, in this case we only record the time when the participant is active since + // we want this metric to represent the time cost by publishing. + if activeAt := p.lastActiveAt.Load(); activeAt != nil && createdAt.Before(*activeAt) { + createdAt = *activeAt + } + pubTime = time.Since(createdAt) + p.dirty.Store(true) + } + p.pendingTracksLock.Unlock() + + _, isReceiverAdded := mt.AddReceiver(rtpReceiver, track, mid) + + if newTrack { + go func() { + // TODO: remove this after we know where the high delay is coming from + if pubTime > 3*time.Second { + p.pubLogger.Infow( + "track published with high delay", + "trackID", mt.ID(), + "track", logger.Proto(mt.ToProto()), + "cost", pubTime.Milliseconds(), + "rid", track.RID(), + "mime", track.Codec().MimeType, + "isMigrated", isMigrated, + ) + } else { + p.pubLogger.Debugw( + "track published", + "trackID", mt.ID(), + "track", logger.Proto(mt.ToProto()), + "cost", pubTime.Milliseconds(), + "isMigrated", isMigrated, + ) + } + + prometheus.RecordPublishTime( + p.params.Country, + mt.Source(), + mt.Kind(), + pubTime, + p.GetClientInfo().GetSdk(), + p.Kind(), + ) + p.handleTrackPublished(mt, isMigrated) + }() + } + + return mt, newTrack, isReceiverAdded, ridsFromSdp +} + +func (p *ParticipantImpl) addMigratedTrack(cid string, ti *livekit.TrackInfo) *MediaTrack { + p.pubLogger.Infow("add migrated track", "cid", cid, "trackID", ti.Sid, "track", logger.Proto(ti)) + rtpReceiver := p.TransportManager.GetPublisherRTPReceiver(ti.Mid) + if rtpReceiver == nil { + p.pubLogger.Errorw( + "could not find receiver for migrated track", nil, + "trackID", ti.Sid, + "mid", ti.Mid, + ) + return nil + } + + mt := p.addMediaTrack(cid, ti) + + potentialCodecs := make([]webrtc.RTPCodecParameters, 0, len(ti.Codecs)) + parameters := rtpReceiver.GetParameters() + for _, c := range ti.Codecs { + for _, nc := range parameters.Codecs { + if mime.IsMimeTypeStringEqual(nc.MimeType, c.MimeType) { + potentialCodecs = append(potentialCodecs, nc) + break + } + } + } + // check for mime_type for tracks that do not have simulcast_codecs set + if ti.MimeType != "" { + for _, nc := range parameters.Codecs { + if mime.IsMimeTypeStringEqual(nc.MimeType, ti.MimeType) { + alreadyAdded := false + for _, pc := range potentialCodecs { + if mime.IsMimeTypeStringEqual(pc.MimeType, ti.MimeType) { + alreadyAdded = true + break + } + } + if !alreadyAdded { + potentialCodecs = append(potentialCodecs, nc) + } + break + } + } + } + mt.SetPotentialCodecs(potentialCodecs, parameters.HeaderExtensions) + + for _, codec := range ti.Codecs { + for ssrc, info := range p.params.SimTracks { + if info.Mid == codec.Mid && !info.IsRepairStream { + mt.SetLayerSsrcsForRid(mime.NormalizeMimeType(codec.MimeType), info.StreamID, ssrc, info.RepairSSRC) + } + } + } + + return mt +} + +func (p *ParticipantImpl) addMediaTrack(signalCid string, ti *livekit.TrackInfo) *MediaTrack { + mt := NewMediaTrack(MediaTrackParams{ + ParticipantID: p.ID, + ParticipantIdentity: p.params.Identity, + ParticipantVersion: p.version.Load(), + ParticipantCountry: p.params.Country, + BufferFactory: p.params.Config.BufferFactory, + ReceiverConfig: p.params.Config.Receiver, + AudioConfig: p.params.AudioConfig, + VideoConfig: p.params.VideoConfig, + Telemetry: p.params.Telemetry, + Logger: LoggerWithTrack(p.pubLogger, livekit.TrackID(ti.Sid), false), + Reporter: p.params.Reporter.WithTrack(ti.Sid), + SubscriberConfig: p.params.Config.Subscriber, + PLIThrottleConfig: p.params.PLIThrottleConfig, + SimTracks: p.params.SimTracks, + OnRTCP: p.postRtcp, + ForwardStats: p.params.ForwardStats, + OnTrackEverSubscribed: p.sendTrackHasBeenSubscribed, + ShouldRegressCodec: func() bool { + return p.helper().ShouldRegressCodec() + }, + PreferVideoSizeFromMedia: p.params.PreferVideoSizeFromMedia, + EnableRTPStreamRestartDetection: p.params.EnableRTPStreamRestartDetection, + UpdateTrackInfoByVideoSizeChange: p.params.UseOneShotSignallingMode, + ForceBackupCodecPolicySimulcast: p.params.ForceBackupCodecPolicySimulcast, + }, ti) + + mt.OnSubscribedMaxQualityChange(p.onSubscribedMaxQualityChange) + mt.OnSubscribedAudioCodecChange(p.onSubscribedAudioCodecChange) + + // add to published and clean up pending + if p.supervisor != nil { + p.supervisor.SetPublishedTrack(livekit.TrackID(ti.Sid), mt) + } + p.UpTrackManager.AddPublishedTrack(mt) + + pti := p.pendingTracks[signalCid] + if pti != nil { + if p.pendingPublishingTracks[livekit.TrackID(ti.Sid)] != nil { + p.pubLogger.Infow("unexpected pending publish track", "trackID", ti.Sid) + } + p.pendingPublishingTracks[livekit.TrackID(ti.Sid)] = &pendingTrackInfo{ + trackInfos: []*livekit.TrackInfo{pti.trackInfos[0]}, + migrated: pti.migrated, + } + } + + p.pendingTracks[signalCid].trackInfos = p.pendingTracks[signalCid].trackInfos[1:] + if len(p.pendingTracks[signalCid].trackInfos) == 0 { + delete(p.pendingTracks, signalCid) + } else { + p.pendingTracks[signalCid].queued = true + p.pendingTracks[signalCid].createdAt = time.Now() + } + + trackID := livekit.TrackID(ti.Sid) + mt.AddOnClose(func(isExpectedToResume bool) { + if p.supervisor != nil { + p.supervisor.ClearPublishedTrack(trackID, mt) + } + + p.params.Telemetry.TrackUnpublished( + context.Background(), + p.ID(), + p.Identity(), + mt.ToProto(), + !isExpectedToResume, + ) + + p.pendingTracksLock.Lock() + if pti := p.pendingTracks[signalCid]; pti != nil { + p.sendTrackPublished(signalCid, pti.trackInfos[0]) + pti.queued = false + } + p.pendingTracksLock.Unlock() + p.handlePendingRemoteTracks() + + p.dirty.Store(true) + + p.pubLogger.Debugw( + "track unpublished", + "trackID", ti.Sid, + "expectedToResume", isExpectedToResume, + "track", logger.Proto(ti), + ) + p.listener().OnTrackUnpublished(p, mt) + }) + + return mt +} + +func (p *ParticipantImpl) handleTrackPublished(track types.MediaTrack, isMigrated bool) { + p.listener().OnTrackPublished(p, track) + + // send webhook after callbacks are complete, persistence and state handling happens + // in `onTrackPublished` cb + p.params.Telemetry.TrackPublished( + context.Background(), + p.ID(), + p.Identity(), + track.ToProto(), + !isMigrated, + ) + + p.pendingTracksLock.Lock() + delete(p.pendingPublishingTracks, track.ID()) + p.pendingTracksLock.Unlock() +} + +func (p *ParticipantImpl) hasPendingMigratedTrack() bool { + p.pendingTracksLock.RLock() + defer p.pendingTracksLock.RUnlock() + + for _, t := range p.pendingTracks { + if t.migrated { + return true + } + } + + for _, t := range p.pendingPublishingTracks { + if t.migrated { + return true + } + } + + return false +} + +func (p *ParticipantImpl) onUpTrackManagerClose() { + p.pubRTCPQueue.Stop() +} + +func (p *ParticipantImpl) getPendingTrack(clientId string, kind livekit.TrackType, skipQueued bool) (string, *livekit.TrackInfo, buffer.VideoLayersRid, bool, time.Time) { + signalCid := clientId + pendingInfo := p.pendingTracks[clientId] + if pendingInfo == nil { + track_loop: + for cid, pti := range p.pendingTracks { + ti := pti.trackInfos[0] + for _, c := range ti.Codecs { + if c.Cid == clientId { + pendingInfo = pti + signalCid = cid + break track_loop + } + } + } + + if pendingInfo == nil { + // + // If no match on client id, find first one matching type + // as MediaStreamTrack can change client id when transceiver + // is added to peer connection. + // + for cid, pti := range p.pendingTracks { + ti := pti.trackInfos[0] + if ti.Type == kind { + pendingInfo = pti + signalCid = cid + break + } + } + } + } + + // if still not found, we are done + if pendingInfo == nil || (skipQueued && pendingInfo.queued) { + return signalCid, nil, buffer.VideoLayersRid{}, false, time.Time{} + } + + return signalCid, utils.CloneProto(pendingInfo.trackInfos[0]), pendingInfo.sdpRids, pendingInfo.migrated, pendingInfo.createdAt +} + +func (p *ParticipantImpl) getPendingTrackPrimaryBySdpCid(sdpCid string) *pendingTrackInfo { + for _, pti := range p.pendingTracks { + ti := pti.trackInfos[0] + if len(ti.Codecs) == 0 { + continue + } + if ti.Codecs[0].Cid == sdpCid || ti.Codecs[0].SdpCid == sdpCid { + return pti + } + } + + return nil +} + +// setTrackID either generates a new TrackID for an AddTrackRequest +func (p *ParticipantImpl) setTrackID(cid string, info *livekit.TrackInfo) { + var trackID string + // if already pending, use the same SID + // it is possible to have multiple AddTrackRequests for the same track + if pti := p.pendingTracks[cid]; pti != nil { + trackID = pti.trackInfos[0].Sid + } + + // otherwise generate + if trackID == "" { + trackPrefix := utils.TrackPrefix + switch info.Type { + case livekit.TrackType_VIDEO: + trackPrefix += "V" + case livekit.TrackType_AUDIO: + trackPrefix += "A" + } + + switch info.Source { + case livekit.TrackSource_CAMERA: + trackPrefix += "C" + case livekit.TrackSource_MICROPHONE: + trackPrefix += "M" + case livekit.TrackSource_SCREEN_SHARE: + trackPrefix += "S" + case livekit.TrackSource_SCREEN_SHARE_AUDIO: + trackPrefix += "s" + } + trackID = guid.New(trackPrefix) + } + info.Sid = trackID +} + +func (p *ParticipantImpl) getPublishedTrackBySignalCid(clientId string) types.MediaTrack { + for _, publishedTrack := range p.GetPublishedTracks() { + if publishedTrack.(types.LocalMediaTrack).HasSignalCid(clientId) { + p.pubLogger.Debugw("found track by signal cid", "signalCid", clientId, "trackID", publishedTrack.ID()) + return publishedTrack + } + } + + return nil +} + +func (p *ParticipantImpl) getPublishedTrackBySdpCid(clientId string) types.MediaTrack { + for _, publishedTrack := range p.GetPublishedTracks() { + if publishedTrack.(types.LocalMediaTrack).HasSdpCid(clientId) { + p.pubLogger.Debugw("found track by SDP cid", "sdpCid", clientId, "trackID", publishedTrack.ID()) + return publishedTrack + } + } + + return nil +} + +func (p *ParticipantImpl) DebugInfo() map[string]any { + info := map[string]any{ + "ID": p.ID(), + "State": p.State().String(), + } + + pendingTrackInfo := make(map[string]any) + p.pendingTracksLock.RLock() + for clientID, pti := range p.pendingTracks { + var trackInfos []string + for _, ti := range pti.trackInfos { + trackInfos = append(trackInfos, ti.String()) + } + + pendingTrackInfo[clientID] = map[string]any{ + "TrackInfos": trackInfos, + "Migrated": pti.migrated, + } + } + p.pendingTracksLock.RUnlock() + info["PendingTracks"] = pendingTrackInfo + + info["UpTrackManager"] = p.UpTrackManager.DebugInfo() + + return info +} + +func (p *ParticipantImpl) postRtcp(pkts []rtcp.Packet) { + p.lock.RLock() + migrationTimer := p.migrationTimer + p.lock.RUnlock() + + // Once migration out is active, layers getting added would not be communicated to + // where the publisher is migrating to. Without SSRC, `UnhandleSimulcastInterceptor` + // cannot be set up on the migrating in node. Without that interceptor, simulcast + // probing will fail. + // + // Clients usually send `rid` RTP header extension till they get an RTCP Receiver Report + // from the remote side. So, by curbing RTCP when migration is active, even if a new layer + // get published to this node, client should continue to send `rid` to the new node + // post migration and the new node can do regular simulcast probing (without the + // `UnhandleSimulcastInterceptor`) to fire `OnTrack` on that layer. And when the new node + // sends RTCP Receiver Report back to the client, client will stop `rid`. + if migrationTimer != nil { + return + } + + p.pubRTCPQueue.Enqueue(func(op postRtcpOp) { + if err := op.TransportManager.WritePublisherRTCP(op.pkts); err != nil && !IsEOF(err) { + op.pubLogger.Errorw("could not write RTCP to participant", err) + } + }, postRtcpOp{p, pkts}) +} + +func (p *ParticipantImpl) setDownTracksConnected() { + for _, t := range p.SubscriptionManager.GetSubscribedTracks() { + if dt := t.DownTrack(); dt != nil { + dt.SeedState(sfu.DownTrackState{ForwarderState: p.getAndDeleteForwarderState(t.ID())}) + dt.SetConnected() + } + } +} + +func (p *ParticipantImpl) cacheForwarderState() { + // if migrating in, get forwarder state from migrating out node to facilitate resume + if fs, err := p.helper().GetSubscriberForwarderState(p); err == nil && fs != nil { + p.lock.Lock() + p.forwarderState = fs + p.lock.Unlock() + + for _, t := range p.SubscriptionManager.GetSubscribedTracks() { + if dt := t.DownTrack(); dt != nil { + dt.SeedState(sfu.DownTrackState{ForwarderState: p.getAndDeleteForwarderState(t.ID())}) + } + } + } +} + +func (p *ParticipantImpl) getAndDeleteForwarderState(trackID livekit.TrackID) *livekit.RTPForwarderState { + p.lock.Lock() + fs := p.forwarderState[trackID] + delete(p.forwarderState, trackID) + p.lock.Unlock() + + return fs +} + +func (p *ParticipantImpl) CacheDownTrack(trackID livekit.TrackID, rtpTransceiver *webrtc.RTPTransceiver, downTrack sfu.DownTrackState) { + p.lock.Lock() + if existing := p.cachedDownTracks[trackID]; existing != nil && existing.transceiver != rtpTransceiver { + p.subLogger.Warnw("cached transceiver changed", nil, "trackID", trackID) + } + p.cachedDownTracks[trackID] = &downTrackState{transceiver: rtpTransceiver, downTrack: downTrack} + p.subLogger.Debugw("caching downtrack", "trackID", trackID) + p.lock.Unlock() +} + +func (p *ParticipantImpl) UncacheDownTrack(rtpTransceiver *webrtc.RTPTransceiver) { + p.lock.Lock() + for trackID, dts := range p.cachedDownTracks { + if dts.transceiver == rtpTransceiver { + if dts := p.cachedDownTracks[trackID]; dts != nil { + p.subLogger.Debugw("uncaching downtrack", "trackID", trackID) + } + delete(p.cachedDownTracks, trackID) + break + } + } + p.lock.Unlock() +} + +func (p *ParticipantImpl) GetCachedDownTrack(trackID livekit.TrackID) (*webrtc.RTPTransceiver, sfu.DownTrackState) { + p.lock.RLock() + defer p.lock.RUnlock() + + if dts := p.cachedDownTracks[trackID]; dts != nil { + return dts.transceiver, dts.downTrack + } + + return nil, sfu.DownTrackState{} +} + +func (p *ParticipantImpl) IssueFullReconnect(reason types.ParticipantCloseReason) { + p.sendLeaveRequest( + reason, + false, // isExpectedToResume + true, // isExpectedToReconnect + false, // sendOnlyIfSupportingLeaveRequestWithAction + ) + + scr := types.SignallingCloseReasonUnknown + switch reason { + case types.ParticipantCloseReasonPublicationError, types.ParticipantCloseReasonMigrateCodecMismatch: + scr = types.SignallingCloseReasonFullReconnectPublicationError + case types.ParticipantCloseReasonSubscriptionError: + scr = types.SignallingCloseReasonFullReconnectSubscriptionError + case types.ParticipantCloseReasonDataChannelError: + scr = types.SignallingCloseReasonFullReconnectDataChannelError + case types.ParticipantCloseReasonNegotiateFailed: + scr = types.SignallingCloseReasonFullReconnectNegotiateFailed + } + p.CloseSignalConnection(scr) + + // a full reconnect == client should connect back with a new session, close current one + p.Close(false, reason, false) +} + +func (p *ParticipantImpl) onPublicationError(trackID livekit.TrackID) { + if p.params.ReconnectOnPublicationError { + p.pubLogger.Infow("issuing full reconnect on publication error", "trackID", trackID) + p.IssueFullReconnect(types.ParticipantCloseReasonPublicationError) + } +} + +func (p *ParticipantImpl) onSubscriptionError(trackID livekit.TrackID, fatal bool, err error) { + signalErr := livekit.SubscriptionError_SE_UNKNOWN + switch { + case errors.Is(err, webrtc.ErrUnsupportedCodec): + signalErr = livekit.SubscriptionError_SE_CODEC_UNSUPPORTED + case errors.Is(err, ErrTrackNotFound): + signalErr = livekit.SubscriptionError_SE_TRACK_NOTFOUND + } + + p.sendSubscriptionResponse(trackID, signalErr) + + if p.params.ReconnectOnSubscriptionError && fatal { + p.subLogger.Infow("issuing full reconnect on subscription error", "trackID", trackID) + p.IssueFullReconnect(types.ParticipantCloseReasonSubscriptionError) + } +} + +func (p *ParticipantImpl) onAnyTransportNegotiationFailed() { + if p.TransportManager.SinceLastSignal() < negotiationFailedTimeout/2 { + p.params.Logger.Infow("negotiation failed, starting full reconnect") + } + p.IssueFullReconnect(types.ParticipantCloseReasonNegotiateFailed) +} + +func (p *ParticipantImpl) UpdateSubscribedQuality(nodeID livekit.NodeID, trackID livekit.TrackID, maxQualities []types.SubscribedCodecQuality) error { + track := p.GetPublishedTrack(trackID) + if track == nil { + p.pubLogger.Debugw("could not find track", "trackID", trackID) + return errors.New("could not find published track") + } + + track.(types.LocalMediaTrack).NotifySubscriberNodeMaxQuality(nodeID, maxQualities) + return nil +} + +func (p *ParticipantImpl) UpdateSubscribedAudioCodecs(nodeID livekit.NodeID, trackID livekit.TrackID, codecs []*livekit.SubscribedAudioCodec) error { + track := p.GetPublishedTrack(trackID) + if track == nil { + p.pubLogger.Debugw("could not find track", "trackID", trackID) + return errors.New("could not find published track") + } + + track.(types.LocalMediaTrack).NotifySubscriptionNode(nodeID, codecs) + return nil +} + +func (p *ParticipantImpl) UpdateMediaLoss(nodeID livekit.NodeID, trackID livekit.TrackID, fractionalLoss uint32) error { + track := p.GetPublishedTrack(trackID) + if track == nil { + p.pubLogger.Debugw("could not find track", "trackID", trackID) + return errors.New("could not find published track") + } + + track.(types.LocalMediaTrack).NotifySubscriberNodeMediaLoss(nodeID, uint8(fractionalLoss)) + return nil +} + +func (p *ParticipantImpl) GetPlayoutDelayConfig() *livekit.PlayoutDelay { + return p.params.PlayoutDelay +} + +func (p *ParticipantImpl) SupportsSyncStreamID() bool { + return p.ProtocolVersion().SupportsSyncStreamID() && !p.params.ClientInfo.isFirefox() && p.params.SyncStreams +} + +func (p *ParticipantImpl) SupportsTransceiverReuse() bool { + if p.params.UseOneShotSignallingMode { + return p.ProtocolVersion().SupportsTransceiverReuse() + } + + return p.ProtocolVersion().SupportsTransceiverReuse() && !p.SupportsSyncStreamID() +} + +func (p *ParticipantImpl) SendDataMessage(kind livekit.DataPacket_Kind, data []byte, sender livekit.ParticipantID, seq uint32) error { + if sender == "" || kind != livekit.DataPacket_RELIABLE || seq == 0 { + if p.State() != livekit.ParticipantInfo_ACTIVE { + return ErrDataChannelUnavailable + } + return p.TransportManager.SendDataMessage(kind, data) + } + + p.reliableDataInfo.joiningMessageLock.Lock() + if !p.reliableDataInfo.canWriteReliable { + if _, ok := p.reliableDataInfo.joiningMessageFirstSeqs[sender]; !ok { + p.reliableDataInfo.joiningMessageFirstSeqs[sender] = seq + } + p.reliableDataInfo.joiningMessageLock.Unlock() + return nil + } + + lastWrittenSeq, ok := p.reliableDataInfo.joiningMessageLastWrittenSeqs[sender] + if ok { + if seq <= lastWrittenSeq { + // already sent by replayJoiningReliableMessages + p.reliableDataInfo.joiningMessageLock.Unlock() + return nil + } else { + delete(p.reliableDataInfo.joiningMessageLastWrittenSeqs, sender) + } + } + + p.reliableDataInfo.joiningMessageLock.Unlock() + + return p.TransportManager.SendDataMessage(kind, data) +} + +func (p *ParticipantImpl) SendDataMessageUnlabeled(data []byte, useRaw bool, sender livekit.ParticipantIdentity) error { + if p.State() != livekit.ParticipantInfo_ACTIVE { + return ErrDataChannelUnavailable + } + + return p.TransportManager.SendDataMessageUnlabeled(data, useRaw, sender) +} + +func (p *ParticipantImpl) onDataSendError(err error) { + if p.params.ReconnectOnDataChannelError { + p.params.Logger.Infow("issuing full reconnect on data channel error", "error", err) + p.IssueFullReconnect(types.ParticipantCloseReasonDataChannelError) + } +} + +func (p *ParticipantImpl) setupEnabledCodecs(publishEnabledCodecs []*livekit.Codec, subscribeEnabledCodecs []*livekit.Codec, disabledCodecs *livekit.DisabledCodecs) { + shouldDisable := func(c *livekit.Codec, disabled []*livekit.Codec) bool { + for _, disableCodec := range disabled { + // disable codec's fmtp is empty means disable this codec entirely + if mime.IsMimeTypeStringEqual(c.Mime, disableCodec.Mime) { + return true + } + } + return false + } + + publishCodecsAudio := make([]*livekit.Codec, 0, len(publishEnabledCodecs)) + publishCodecsVideo := make([]*livekit.Codec, 0, len(publishEnabledCodecs)) + for _, c := range publishEnabledCodecs { + if shouldDisable(c, disabledCodecs.GetCodecs()) || shouldDisable(c, disabledCodecs.GetPublish()) { + continue + } + + // sort by compatibility, since we will look for backups in these. + if mime.IsMimeTypeStringVP8(c.Mime) { + if len(p.enabledPublishCodecs) > 0 { + p.enabledPublishCodecs = slices.Insert(p.enabledPublishCodecs, 0, c) + } else { + p.enabledPublishCodecs = append(p.enabledPublishCodecs, c) + } + } else if mime.IsMimeTypeStringH264(c.Mime) { + p.enabledPublishCodecs = append(p.enabledPublishCodecs, c) + } else { + if mime.IsMimeTypeStringAudio(c.Mime) { + publishCodecsAudio = append(publishCodecsAudio, c) + } else { + publishCodecsVideo = append(publishCodecsVideo, c) + } + } + } + // list all video first and then audio to work around a client side issue with Flutter SDK 2.4.2 + p.enabledPublishCodecs = append(p.enabledPublishCodecs, publishCodecsVideo...) + p.enabledPublishCodecs = append(p.enabledPublishCodecs, publishCodecsAudio...) + + subscribeCodecs := make([]*livekit.Codec, 0, len(subscribeEnabledCodecs)) + for _, c := range subscribeEnabledCodecs { + if shouldDisable(c, disabledCodecs.GetCodecs()) { + continue + } + subscribeCodecs = append(subscribeCodecs, c) + } + p.enabledSubscribeCodecs = subscribeCodecs + p.params.Logger.Debugw( + "setup enabled codecs", + "publish", logger.ProtoSlice(p.enabledPublishCodecs), + "subscribe", logger.ProtoSlice(p.enabledSubscribeCodecs), + "disabled", logger.Proto(disabledCodecs), + ) +} + +func (p *ParticipantImpl) replayJoiningReliableMessages() { + p.reliableDataInfo.joiningMessageLock.Lock() + for _, msgCache := range p.helper().GetCachedReliableDataMessage(p.reliableDataInfo.joiningMessageFirstSeqs) { + if len(msgCache.DestIdentities) != 0 && !slices.Contains(msgCache.DestIdentities, p.Identity()) { + continue + } + if lastSeq, ok := p.reliableDataInfo.joiningMessageLastWrittenSeqs[msgCache.SenderID]; !ok || lastSeq < msgCache.Seq { + p.reliableDataInfo.joiningMessageLastWrittenSeqs[msgCache.SenderID] = msgCache.Seq + } + + p.TransportManager.SendDataMessage(livekit.DataPacket_RELIABLE, msgCache.Data) + } + + p.reliableDataInfo.joiningMessageFirstSeqs = make(map[livekit.ParticipantID]uint32) + p.reliableDataInfo.canWriteReliable = true + p.reliableDataInfo.joiningMessageLock.Unlock() +} + +func (p *ParticipantImpl) GetEnabledPublishCodecs() []*livekit.Codec { + codecs := make([]*livekit.Codec, 0, len(p.enabledPublishCodecs)) + for _, c := range p.enabledPublishCodecs { + if mime.IsMimeTypeStringRTX(c.Mime) { + continue + } + codecs = append(codecs, c) + } + return codecs +} + +func (p *ParticipantImpl) UpdateAudioTrack(update *livekit.UpdateLocalAudioTrack) error { + if track := p.UpTrackManager.UpdatePublishedAudioTrack(update); track != nil { + return nil + } + + isPending := false + p.pendingTracksLock.RLock() + for _, pti := range p.pendingTracks { + for _, ti := range pti.trackInfos { + if ti.Sid == update.TrackSid { + isPending = true + + ti.AudioFeatures = sutils.DedupeSlice(update.Features) + ti.Stereo = false + ti.DisableDtx = false + for _, feature := range update.Features { + switch feature { + case livekit.AudioTrackFeature_TF_STEREO: + ti.Stereo = true + case livekit.AudioTrackFeature_TF_NO_DTX: + ti.DisableDtx = true + } + } + + p.pubLogger.Debugw("updated pending track", "trackID", ti.Sid, "trackInfo", logger.Proto(ti)) + } + } + } + p.pendingTracksLock.RUnlock() + if isPending { + return nil + } + + p.pubLogger.Debugw("could not locate track", "trackID", update.TrackSid) + p.sendRequestResponse(&livekit.RequestResponse{ + Reason: livekit.RequestResponse_NOT_FOUND, + Request: &livekit.RequestResponse_UpdateAudioTrack{ + UpdateAudioTrack: utils.CloneProto(update), + }, + }) + return errors.New("could not find track") +} + +func (p *ParticipantImpl) UpdateVideoTrack(update *livekit.UpdateLocalVideoTrack) error { + if track := p.UpTrackManager.UpdatePublishedVideoTrack(update); track != nil { + return nil + } + + isPending := false + p.pendingTracksLock.RLock() + for _, pti := range p.pendingTracks { + for _, ti := range pti.trackInfos { + if ti.Sid == update.TrackSid { + isPending = true + + ti.Width = update.Width + ti.Height = update.Height + } + } + } + p.pendingTracksLock.RUnlock() + if isPending { + return nil + } + + p.pubLogger.Debugw("could not locate track", "trackID", update.TrackSid) + p.sendRequestResponse(&livekit.RequestResponse{ + Reason: livekit.RequestResponse_NOT_FOUND, + Request: &livekit.RequestResponse_UpdateVideoTrack{ + UpdateVideoTrack: utils.CloneProto(update), + }, + }) + return errors.New("could not find track") +} + +func (p *ParticipantImpl) HandleMetrics(senderParticipantID livekit.ParticipantID, metrics *livekit.MetricsBatch) error { + if p.State() != livekit.ParticipantInfo_ACTIVE { + return ErrDataChannelUnavailable + } + + if !p.CanSubscribeMetrics() { + return ErrNoSubscribeMetricsPermission + } + + if senderParticipantID != p.ID() && !p.SubscriptionManager.IsSubscribedTo(senderParticipantID) { + return nil + } + + p.metricsReporter.Merge(metrics) + return nil +} + +func (p *ParticipantImpl) SupportsCodecChange() bool { + return p.params.ClientInfo.SupportsCodecChange() +} + +func (p *ParticipantImpl) SupportsMoving() error { + if !p.ProtocolVersion().SupportsMoving() { + return ErrMoveOldClientVersion + } + + if kind := p.Kind(); kind == livekit.ParticipantInfo_EGRESS || kind == livekit.ParticipantInfo_AGENT || p.params.UseOneShotSignallingMode { + return fmt.Errorf("%s participants cannot be moved, one-shot signaling mode: %t", kind.String(), p.params.UseOneShotSignallingMode) + } + + return nil +} + +func (p *ParticipantImpl) MoveToRoom(params types.MoveToRoomParams) { + // fire onClose callback for original room + p.lock.Lock() + onClose := p.onClose + p.onClose = make(map[string]func(types.LocalParticipant)) + p.lock.Unlock() + for _, cb := range onClose { + cb(p) + } + + for _, track := range p.GetPublishedTracks() { + for _, sub := range track.GetAllSubscribers() { + track.RemoveSubscriber(sub, false) + } + + // clear the subscriber node max quality/audio codecs as the remote quality notify + // from source room would not reach the moving out participant. + track.(types.LocalMediaTrack).ClearSubscriberNodes() + + trackInfo := track.ToProto() + p.params.Telemetry.TrackUnpublished( + context.Background(), + p.ID(), + p.Identity(), + trackInfo, + true, + ) + } + + p.params.Logger.Infow("move participant to new room", "newRoomName", params.RoomName, "newID", params.ParticipantID) + + p.params.LoggerResolver.Reset() + p.params.ReporterResolver.Reset() + p.setListener(params.Listener) + p.participantHelper.Store(params.Helper) + p.SubscriptionManager.ClearAllSubscriptions() + p.id.Store(params.ParticipantID) + grants := p.grants.Load().Clone() + grants.Video.Room = string(params.RoomName) + p.grants.Store(grants) +} + +func (p *ParticipantImpl) helper() types.LocalParticipantHelper { + return p.participantHelper.Load().(types.LocalParticipantHelper) +} + +func (p *ParticipantImpl) GetLastReliableSequence(migrateOut bool) uint32 { + if migrateOut { + p.reliableDataInfo.stopReliableByMigrateOut.Store(true) + } + return p.reliableDataInfo.lastPubReliableSeq.Load() +} + +func (p *ParticipantImpl) HandleUpdateSubscriptions( + trackIDs []livekit.TrackID, + participantTracks []*livekit.ParticipantTracks, + subscribe bool, +) { + p.listener().OnUpdateSubscriptions(p, trackIDs, participantTracks, subscribe) +} + +func (p *ParticipantImpl) HandleUpdateSubscriptionPermission(subscriptionPermission *livekit.SubscriptionPermission) error { + return p.listener().OnUpdateSubscriptionPermission(p, subscriptionPermission) +} + +func (p *ParticipantImpl) HandleSyncState(syncState *livekit.SyncState) error { + return p.listener().OnSyncState(p, syncState) +} + +func (p *ParticipantImpl) HandleSimulateScenario(simulateScenario *livekit.SimulateScenario) error { + return p.listener().OnSimulateScenario(p, simulateScenario) +} + +func (p *ParticipantImpl) HandleLeaveRequest(reason types.ParticipantCloseReason) { + p.listener().OnLeave(p, reason) +} + +func (p *ParticipantImpl) HandleSignalMessage(msg proto.Message) error { + return p.signalHandler.HandleMessage(msg) +} + +func (p *ParticipantImpl) IsUsingSinglePeerConnection() bool { + return p.params.UseSinglePeerConnection +} + +func (p *ParticipantImpl) AddTrackLocal( + trackLocal webrtc.TrackLocal, + params types.AddTrackParams, +) (*webrtc.RTPSender, *webrtc.RTPTransceiver, error) { + if p.params.UseSinglePeerConnection { + return p.TransportManager.AddTrackLocal( + trackLocal, + params, + p.enabledSubscribeCodecs, + p.params.Config.Subscriber.RTCPFeedback, + ) + } else { + return p.TransportManager.AddTrackLocal(trackLocal, params, nil, RTCPFeedbackConfig{}) + } +} + +func (p *ParticipantImpl) AddTransceiverFromTrackLocal( + trackLocal webrtc.TrackLocal, + params types.AddTrackParams, +) (*webrtc.RTPSender, *webrtc.RTPTransceiver, error) { + if p.params.UseSinglePeerConnection { + return p.TransportManager.AddTransceiverFromTrackLocal( + trackLocal, + params, + p.enabledSubscribeCodecs, + p.params.Config.Subscriber.RTCPFeedback, + ) + } else { + return p.TransportManager.AddTransceiverFromTrackLocal( + trackLocal, + params, + nil, + RTCPFeedbackConfig{}, + ) + } +} + +func (p *ParticipantImpl) handleIncomingRpcAck(requestId string) bool { + p.rpcLock.Lock() + defer p.rpcLock.Unlock() + + handler, ok := p.rpcPendingAcks[requestId] + if !ok { + return false + } + + handler.Resolve() + delete(p.rpcPendingAcks, requestId) + return true +} + +func (p *ParticipantImpl) handleIncomingRpcResponse(requestId string, payload string, err *utils.DataChannelRpcError) bool { + p.rpcLock.Lock() + defer p.rpcLock.Unlock() + + handler, ok := p.rpcPendingResponses[requestId] + if !ok { + return false + } + + handler.Resolve(payload, err) + delete(p.rpcPendingResponses, requestId) + return true +} + +func (p *ParticipantImpl) PerformRpc(req *livekit.PerformRpcRequest, resultCh chan string, errorCh chan error) { + responseTimeout := req.GetResponseTimeoutMs() + if responseTimeout <= 0 { + responseTimeout = uint32(utils.DataChannelRpcDefaultResponseTimeout.Milliseconds()) + } + + go func() { + if len([]byte(req.GetPayload())) > utils.DataChannelRpcMaxPayloadBytes { + errorCh <- utils.DataChannelRpcErrorFromBuiltInCodes(utils.DataChannelRpcRequestPayloadTooLarge, "").PsrpcError() + return + } + + id := uuid.NewString() + + responseTimer := time.AfterFunc(time.Duration(responseTimeout)*time.Millisecond, func() { + p.rpcLock.Lock() + delete(p.rpcPendingResponses, id) + p.rpcLock.Unlock() + + select { + case errorCh <- utils.DataChannelRpcErrorFromBuiltInCodes(utils.DataChannelRpcResponseTimeout, "").PsrpcError(): + default: + } + }) + ackTimer := time.AfterFunc(utils.DataChannelRpcMaxRoundTripLatency, func() { + p.rpcLock.Lock() + delete(p.rpcPendingAcks, id) + delete(p.rpcPendingResponses, id) + p.rpcLock.Unlock() + responseTimer.Stop() + + select { + case errorCh <- utils.DataChannelRpcErrorFromBuiltInCodes(utils.DataChannelRpcConnectionTimeout, "").PsrpcError(): + default: + } + }) + + rpcRequest := &livekit.DataPacket{ + Kind: livekit.DataPacket_RELIABLE, + ParticipantIdentity: id, + Value: &livekit.DataPacket_RpcRequest{ + RpcRequest: &livekit.RpcRequest{ + Id: id, + Method: req.GetMethod(), + Payload: req.GetPayload(), + ResponseTimeoutMs: responseTimeout - p.lastRTT, + Version: 1, + }, + }, + } + data, err := proto.Marshal(rpcRequest) + if err != nil { + ackTimer.Stop() + responseTimer.Stop() + errorCh <- psrpc.NewError(psrpc.Internal, err) + return + } + + // using RPC ID as the unique ID for server to identify the response + err = p.SendDataMessage(livekit.DataPacket_RELIABLE, data, livekit.ParticipantID(id), 0) + if err != nil { + ackTimer.Stop() + responseTimer.Stop() + errorCh <- psrpc.NewError(psrpc.Internal, err) + return + } + + p.rpcLock.Lock() + p.rpcPendingAcks[id] = &utils.DataChannelRpcPendingAckHandler{ + Resolve: func() { + ackTimer.Stop() + }, + ParticipantIdentity: req.GetDestinationIdentity(), + } + p.rpcPendingResponses[id] = &utils.DataChannelRpcPendingResponseHandler{ + Resolve: func(payload string, error *utils.DataChannelRpcError) { + responseTimer.Stop() + if _, ok := p.rpcPendingAcks[id]; ok { + p.rpcPendingAcks[id].Resolve() + ackTimer.Stop() + } + + if error != nil { + errorCh <- error.PsrpcError() + } else { + resultCh <- payload + } + }, + ParticipantIdentity: req.GetDestinationIdentity(), + } + p.rpcLock.Unlock() + }() +} diff --git a/livekit/pkg/rtc/participant_data_track.go b/livekit/pkg/rtc/participant_data_track.go new file mode 100644 index 0000000..f5da89f --- /dev/null +++ b/livekit/pkg/rtc/participant_data_track.go @@ -0,0 +1,160 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rtc + +import ( + "github.com/livekit/livekit-server/pkg/rtc/datatrack" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/utils" + "github.com/livekit/protocol/utils/guid" +) + +func (p *ParticipantImpl) HandlePublishDataTrackRequest(req *livekit.PublishDataTrackRequest) { + if !p.CanPublishData() || !p.params.EnableDataTracks { + p.pubLogger.Warnw("no permission to publish data track", nil, "req", logger.Proto(req)) + p.sendRequestResponse(&livekit.RequestResponse{ + Reason: livekit.RequestResponse_NOT_ALLOWED, + Message: "does not have permission to publish data", + Request: &livekit.RequestResponse_PublishDataTrack{ + PublishDataTrack: utils.CloneProto(req), + }, + }) + return + } + + if req.PubHandle == 0 || req.PubHandle > 65535 { + p.pubLogger.Warnw("invalid data track handle", nil, "req", logger.Proto(req)) + p.sendRequestResponse(&livekit.RequestResponse{ + Reason: livekit.RequestResponse_INVALID_HANDLE, + Message: "handle should be > 0 AND < 65536", + Request: &livekit.RequestResponse_PublishDataTrack{ + PublishDataTrack: utils.CloneProto(req), + }, + }) + return + } + + if len(req.Name) == 0 || len(req.Name) > 256 { + p.pubLogger.Warnw("invalid data track name", nil, "req", logger.Proto(req)) + p.sendRequestResponse(&livekit.RequestResponse{ + Reason: livekit.RequestResponse_INVALID_NAME, + Message: "name should not be empty and should not exceed 256 characters", + Request: &livekit.RequestResponse_PublishDataTrack{ + PublishDataTrack: utils.CloneProto(req), + }, + }) + return + } + + publishedDataTracks := p.UpDataTrackManager.GetPublishedDataTracks() + for _, dt := range publishedDataTracks { + message := "" + reason := livekit.RequestResponse_OK + switch { + case dt.PubHandle() == uint16(req.PubHandle): + message = "a data track with same handle already exists" + reason = livekit.RequestResponse_DUPLICATE_HANDLE + case dt.Name() == req.Name: + message = "a data track with same name already exists" + reason = livekit.RequestResponse_DUPLICATE_NAME + } + if message != "" { + p.pubLogger.Warnw( + "cannot publish duplicate data track", nil, + "req", logger.Proto(req), + "existing", logger.Proto(dt.ToProto()), + ) + p.sendRequestResponse(&livekit.RequestResponse{ + Reason: reason, + Message: message, + Request: &livekit.RequestResponse_PublishDataTrack{ + PublishDataTrack: utils.CloneProto(req), + }, + }) + return + } + } + + dti := &livekit.DataTrackInfo{ + PubHandle: req.PubHandle, + Sid: guid.New(utils.DataTrackPrefix), + Name: req.Name, + Encryption: req.Encryption, + } + dt := NewDataTrack( + DataTrackParams{ + Logger: p.params.Logger.WithValues("trackID", dti.Sid), + ParticipantID: p.ID, + ParticipantIdentity: p.params.Identity, + }, + dti, + ) + + p.UpDataTrackManager.AddPublishedDataTrack(dt) + + p.sendPublishDataTrackResponse(dti) + + p.setIsPublisher(true) + p.dirty.Store(true) +} + +func (p *ParticipantImpl) HandleUnpublishDataTrackRequest(req *livekit.UnpublishDataTrackRequest) { + dt := p.UpDataTrackManager.GetPublishedDataTrack(uint16(req.PubHandle)) + if dt == nil { + p.pubLogger.Warnw("unpublish data track not found", nil, "req", logger.Proto(req)) + p.sendRequestResponse(&livekit.RequestResponse{ + Reason: livekit.RequestResponse_NOT_FOUND, + Request: &livekit.RequestResponse_UnpublishDataTrack{ + UnpublishDataTrack: utils.CloneProto(req), + }, + }) + return + } + + p.UpDataTrackManager.RemovePublishedDataTrack(dt) + + p.sendUnpublishDataTrackResponse(dt.ToProto()) + + p.dirty.Store(true) +} + +func (p *ParticipantImpl) HandleUpdateDataSubscription(req *livekit.UpdateDataSubscription) { + p.listener().OnUpdateDataSubscriptions(p, req) +} + +func (p *ParticipantImpl) onReceivedDataTrackMessage(data []byte, arrivalTime int64) { + var packet datatrack.Packet + if err := packet.Unmarshal(data); err != nil { + p.params.Logger.Errorw("could not unmarshal data track message", err) + return + } + + p.UpDataTrackManager.HandleReceivedDataTrackMessage(data, &packet, arrivalTime) + + p.listener().OnDataTrackMessage(p, data, &packet) +} + +func (p *ParticipantImpl) GetNextSubscribedDataTrackHandle() uint16 { + p.lock.Lock() + defer p.lock.Unlock() + + p.nextSubscribedDataTrackHandle++ + if p.nextSubscribedDataTrackHandle == 0 { + p.nextSubscribedDataTrackHandle++ + } + + return p.nextSubscribedDataTrackHandle +} diff --git a/livekit/pkg/rtc/participant_internal_test.go b/livekit/pkg/rtc/participant_internal_test.go new file mode 100644 index 0000000..f4d2653 --- /dev/null +++ b/livekit/pkg/rtc/participant_internal_test.go @@ -0,0 +1,828 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rtc + +import ( + "fmt" + "strings" + "testing" + "time" + + "github.com/pion/webrtc/v4" + "github.com/stretchr/testify/require" + "go.uber.org/atomic" + "google.golang.org/protobuf/proto" + + "github.com/livekit/livekit-server/pkg/sfu/buffer" + "github.com/livekit/livekit-server/pkg/sfu/mime" + "github.com/livekit/livekit-server/pkg/telemetry/telemetryfakes" + "github.com/livekit/protocol/auth" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/observability/roomobs" + lksdp "github.com/livekit/protocol/sdp" + "github.com/livekit/protocol/signalling" + "github.com/livekit/protocol/utils" + "github.com/livekit/protocol/utils/guid" + + "github.com/livekit/livekit-server/pkg/config" + "github.com/livekit/livekit-server/pkg/routing" + "github.com/livekit/livekit-server/pkg/routing/routingfakes" + "github.com/livekit/livekit-server/pkg/rtc/types" + "github.com/livekit/livekit-server/pkg/rtc/types/typesfakes" + "github.com/livekit/livekit-server/pkg/testutils" +) + +func TestIsReady(t *testing.T) { + tests := []struct { + state livekit.ParticipantInfo_State + ready bool + }{ + { + state: livekit.ParticipantInfo_JOINING, + ready: false, + }, + { + state: livekit.ParticipantInfo_JOINED, + ready: true, + }, + { + state: livekit.ParticipantInfo_ACTIVE, + ready: true, + }, + { + state: livekit.ParticipantInfo_DISCONNECTED, + ready: false, + }, + } + + for _, test := range tests { + t.Run(test.state.String(), func(t *testing.T) { + p := &ParticipantImpl{} + p.state.Store(test.state) + require.Equal(t, test.ready, p.IsReady()) + }) + } +} + +func TestTrackPublishing(t *testing.T) { + t.Run("should send the correct events", func(t *testing.T) { + p := newParticipantForTest("test") + track := &typesfakes.FakeMediaTrack{} + track.IDReturns("id") + published := false + updated := false + p.listener().(*typesfakes.FakeLocalParticipantListener).OnTrackUpdatedCalls(func(p types.Participant, track types.MediaTrack) { + updated = true + }) + p.listener().(*typesfakes.FakeLocalParticipantListener).OnTrackPublishedCalls(func(p types.Participant, track types.MediaTrack) { + published = true + }) + p.UpTrackManager.AddPublishedTrack(track) + p.handleTrackPublished(track, false) + require.True(t, published) + require.False(t, updated) + require.Len(t, p.UpTrackManager.publishedTracks, 1) + }) + + t.Run("sends back trackPublished event", func(t *testing.T) { + p := newParticipantForTest("test") + sink := p.params.Sink.(*routingfakes.FakeMessageSink) + p.AddTrack(&livekit.AddTrackRequest{ + Cid: "cid", + Name: "webcam", + Type: livekit.TrackType_VIDEO, + Width: 1024, + Height: 768, + }) + require.Equal(t, 1, sink.WriteMessageCallCount()) + res := sink.WriteMessageArgsForCall(0).(*livekit.SignalResponse) + require.IsType(t, &livekit.SignalResponse_TrackPublished{}, res.Message) + published := res.Message.(*livekit.SignalResponse_TrackPublished).TrackPublished + require.Equal(t, "cid", published.Cid) + require.Equal(t, "webcam", published.Track.Name) + require.Equal(t, livekit.TrackType_VIDEO, published.Track.Type) + require.Equal(t, uint32(1024), published.Track.Width) + require.Equal(t, uint32(768), published.Track.Height) + }) + + t.Run("should not allow adding of duplicate tracks", func(t *testing.T) { + p := newParticipantForTest("test") + sink := p.params.Sink.(*routingfakes.FakeMessageSink) + p.AddTrack(&livekit.AddTrackRequest{ + Cid: "cid", + Name: "webcam", + Type: livekit.TrackType_VIDEO, + }) + p.AddTrack(&livekit.AddTrackRequest{ + Cid: "cid", + Name: "duplicate", + Type: livekit.TrackType_AUDIO, + }) + + // error response on duplicate adds a message + require.Equal(t, 2, sink.WriteMessageCallCount()) + }) + + t.Run("should queue adding of duplicate tracks if already published by client id in signalling", func(t *testing.T) { + p := newParticipantForTest("test") + sink := p.params.Sink.(*routingfakes.FakeMessageSink) + + track := &typesfakes.FakeLocalMediaTrack{} + track.HasSignalCidCalls(func(s string) bool { return s == "cid" }) + track.ToProtoReturns(&livekit.TrackInfo{}) + // directly add to publishedTracks without lock - for testing purpose only + p.UpTrackManager.publishedTracks["cid"] = track + + p.AddTrack(&livekit.AddTrackRequest{ + Cid: "cid", + Name: "webcam", + Type: livekit.TrackType_VIDEO, + }) + // `queued` `RequestResponse` should add a message + require.Equal(t, 1, sink.WriteMessageCallCount()) + require.Equal(t, 1, len(p.pendingTracks["cid"].trackInfos)) + + // add again - it should be added to the queue + p.AddTrack(&livekit.AddTrackRequest{ + Cid: "cid", + Name: "webcam", + Type: livekit.TrackType_VIDEO, + }) + // `queued` `RequestResponse`s should have been sent for duplicate additions + require.Equal(t, 2, sink.WriteMessageCallCount()) + require.Equal(t, 2, len(p.pendingTracks["cid"].trackInfos)) + + // check SID is the same + require.Equal(t, p.pendingTracks["cid"].trackInfos[0].Sid, p.pendingTracks["cid"].trackInfos[1].Sid) + }) + + t.Run("should queue adding of duplicate tracks if already published by client id in sdp", func(t *testing.T) { + p := newParticipantForTest("test") + sink := p.params.Sink.(*routingfakes.FakeMessageSink) + + track := &typesfakes.FakeLocalMediaTrack{} + track.ToProtoReturns(&livekit.TrackInfo{}) + track.HasSdpCidCalls(func(s string) bool { return s == "cid" }) + // directly add to publishedTracks without lock - for testing purpose only + p.UpTrackManager.publishedTracks["cid"] = track + + p.AddTrack(&livekit.AddTrackRequest{ + Cid: "cid", + Name: "webcam", + Type: livekit.TrackType_VIDEO, + }) + // `queued` `RequestResponse` should add a message + require.Equal(t, 1, sink.WriteMessageCallCount()) + require.Equal(t, 1, len(p.pendingTracks["cid"].trackInfos)) + + // add again - it should be added to the queue + p.AddTrack(&livekit.AddTrackRequest{ + Cid: "cid", + Name: "webcam", + Type: livekit.TrackType_VIDEO, + }) + // `queued` `RequestResponse`s should have been sent for duplicate additions + require.Equal(t, 2, sink.WriteMessageCallCount()) + require.Equal(t, 2, len(p.pendingTracks["cid"].trackInfos)) + + // check SID is the same + require.Equal(t, p.pendingTracks["cid"].trackInfos[0].Sid, p.pendingTracks["cid"].trackInfos[1].Sid) + }) + + t.Run("should not allow adding disallowed sources", func(t *testing.T) { + p := newParticipantForTest("test") + p.SetPermission(&livekit.ParticipantPermission{ + CanPublish: true, + CanPublishSources: []livekit.TrackSource{ + livekit.TrackSource_CAMERA, + }, + }) + sink := p.params.Sink.(*routingfakes.FakeMessageSink) + p.AddTrack(&livekit.AddTrackRequest{ + Cid: "cid", + Name: "webcam", + Source: livekit.TrackSource_CAMERA, + Type: livekit.TrackType_VIDEO, + }) + require.Equal(t, 1, sink.WriteMessageCallCount()) + + p.AddTrack(&livekit.AddTrackRequest{ + Cid: "cid2", + Name: "rejected source", + Type: livekit.TrackType_AUDIO, + Source: livekit.TrackSource_MICROPHONE, + }) + // an error response for disallowed source should send a `RequestResponse`. + require.Equal(t, 2, sink.WriteMessageCallCount()) + }) +} + +func TestOutOfOrderUpdates(t *testing.T) { + p := newParticipantForTest("test") + p.updateState(livekit.ParticipantInfo_JOINED) + p.SetMetadata("initial metadata") + sink := p.GetResponseSink().(*routingfakes.FakeMessageSink) + pi1 := p.ToProto() + p.SetMetadata("second update") + pi2 := p.ToProto() + + require.Greater(t, pi2.Version, pi1.Version) + + // send the second update first + require.NoError(t, p.SendParticipantUpdate([]*livekit.ParticipantInfo{pi2})) + require.NoError(t, p.SendParticipantUpdate([]*livekit.ParticipantInfo{pi1})) + + // only sent once, and it's the earlier message + require.Equal(t, 1, sink.WriteMessageCallCount()) + sent := sink.WriteMessageArgsForCall(0).(*livekit.SignalResponse) + require.Equal(t, "second update", sent.GetUpdate().Participants[0].Metadata) +} + +// after disconnection, things should continue to function and not panic +func TestDisconnectTiming(t *testing.T) { + t.Run("Negotiate doesn't panic after channel closed", func(t *testing.T) { + p := newParticipantForTest("test") + msg := routing.NewMessageChannel(livekit.ConnectionID("test"), routing.DefaultMessageChannelSize) + p.params.Sink = msg + go func() { + for msg := range msg.ReadChan() { + t.Log("received message from chan", msg) + } + }() + track := &typesfakes.FakeMediaTrack{} + p.UpTrackManager.AddPublishedTrack(track) + p.handleTrackPublished(track, false) + + // close channel and then try to Negotiate + msg.Close() + }) +} + +func TestCorrectJoinedAt(t *testing.T) { + p := newParticipantForTest("test") + info := p.ToProto() + require.NotZero(t, info.JoinedAt) + require.True(t, time.Now().Unix()-info.JoinedAt <= 1) +} + +func TestMuteSetting(t *testing.T) { + t.Run("can set mute when track is pending", func(t *testing.T) { + p := newParticipantForTest("test") + ti := &livekit.TrackInfo{Sid: "testTrack"} + p.pendingTracks["cid"] = &pendingTrackInfo{trackInfos: []*livekit.TrackInfo{ti}} + + p.SetTrackMuted(&livekit.MuteTrackRequest{ + Sid: ti.Sid, + Muted: true, + }, false) + require.True(t, p.pendingTracks["cid"].trackInfos[0].Muted) + }) + + t.Run("can publish a muted track", func(t *testing.T) { + p := newParticipantForTest("test") + p.AddTrack(&livekit.AddTrackRequest{ + Cid: "cid", + Type: livekit.TrackType_AUDIO, + Muted: true, + }) + + _, ti, _, _, _ := p.getPendingTrack("cid", livekit.TrackType_AUDIO, false) + require.NotNil(t, ti) + require.True(t, ti.Muted) + }) +} + +func TestSubscriberAsPrimary(t *testing.T) { + t.Run("protocol 4 uses subs as primary", func(t *testing.T) { + p := newParticipantForTestWithOpts("test", &participantOpts{ + permissions: &livekit.ParticipantPermission{ + CanSubscribe: true, + CanPublish: true, + }, + }) + require.True(t, p.SubscriberAsPrimary()) + }) + + t.Run("protocol 2 uses pub as primary", func(t *testing.T) { + p := newParticipantForTestWithOpts("test", &participantOpts{ + protocolVersion: 2, + permissions: &livekit.ParticipantPermission{ + CanSubscribe: true, + CanPublish: true, + }, + }) + require.False(t, p.SubscriberAsPrimary()) + }) + + t.Run("publisher only uses pub as primary", func(t *testing.T) { + p := newParticipantForTestWithOpts("test", &participantOpts{ + permissions: &livekit.ParticipantPermission{ + CanSubscribe: false, + CanPublish: true, + }, + }) + require.False(t, p.SubscriberAsPrimary()) + + // ensure that it doesn't change after perms + p.SetPermission(&livekit.ParticipantPermission{ + CanSubscribe: true, + CanPublish: true, + }) + require.False(t, p.SubscriberAsPrimary()) + }) +} + +func TestDisableCodecs(t *testing.T) { + participant := newParticipantForTestWithOpts("123", &participantOpts{ + publisher: false, + clientConf: &livekit.ClientConfiguration{ + DisabledCodecs: &livekit.DisabledCodecs{ + Codecs: []*livekit.Codec{ + {Mime: "video/h264"}, + }, + }, + }, + }) + + participant.SetMigrateState(types.MigrateStateComplete) + + pc, err := webrtc.NewPeerConnection(webrtc.Configuration{}) + require.NoError(t, err) + transceiver, err := pc.AddTransceiverFromKind(webrtc.RTPCodecTypeVideo, webrtc.RTPTransceiverInit{Direction: webrtc.RTPTransceiverDirectionSendrecv}) + require.NoError(t, err) + sdp, err := pc.CreateOffer(nil) + require.NoError(t, err) + pc.SetLocalDescription(sdp) + codecs := transceiver.Receiver().GetParameters().Codecs + var found264 bool + for _, c := range codecs { + if mime.IsMimeTypeStringH264(c.MimeType) { + found264 = true + } + } + require.True(t, found264) + offerId := uint32(42) + + // negotiated codec should not contain h264 + sink := &routingfakes.FakeMessageSink{} + participant.SwapResponseSink(sink, types.SignallingCloseReasonUnknown) + var answer webrtc.SessionDescription + var answerId uint32 + var answerReceived atomic.Bool + var answerIdReceived atomic.Uint32 + sink.WriteMessageCalls(func(msg proto.Message) error { + if res, ok := msg.(*livekit.SignalResponse); ok { + if res.GetAnswer() != nil { + answer, answerId, _ = signalling.FromProtoSessionDescription(res.GetAnswer()) + answerReceived.Store(true) + answerIdReceived.Store(answerId) + } + } + return nil + }) + participant.HandleOffer(&livekit.SessionDescription{ + Type: webrtc.SDPTypeOffer.String(), + Sdp: sdp.SDP, + Id: offerId, + }) + + testutils.WithTimeout(t, func() string { + if answerReceived.Load() && answerIdReceived.Load() == offerId { + return "" + } else { + return "answer not received OR answer id mismatch" + } + }) + require.NoError(t, pc.SetRemoteDescription(answer), answer.SDP, sdp.SDP) + + codecs = transceiver.Receiver().GetParameters().Codecs + found264 = false + for _, c := range codecs { + if mime.IsMimeTypeStringH264(c.MimeType) { + found264 = true + } + } + require.False(t, found264) +} + +func TestDisablePublishCodec(t *testing.T) { + participant := newParticipantForTestWithOpts("123", &participantOpts{ + publisher: true, + clientConf: &livekit.ClientConfiguration{ + DisabledCodecs: &livekit.DisabledCodecs{ + Publish: []*livekit.Codec{ + {Mime: "video/h264"}, + }, + }, + }, + }) + + for _, codec := range participant.enabledPublishCodecs { + require.False(t, mime.IsMimeTypeStringH264(codec.Mime)) + } + + sink := &routingfakes.FakeMessageSink{} + participant.SwapResponseSink(sink, types.SignallingCloseReasonUnknown) + var publishReceived atomic.Bool + sink.WriteMessageCalls(func(msg proto.Message) error { + if res, ok := msg.(*livekit.SignalResponse); ok { + if published := res.GetTrackPublished(); published != nil { + publishReceived.Store(true) + require.NotEmpty(t, published.Track.Codecs) + require.True(t, mime.IsMimeTypeStringVP8(published.Track.Codecs[0].MimeType)) + } + } + return nil + }) + + // simulcast codec response should pick an alternative + participant.AddTrack(&livekit.AddTrackRequest{ + Cid: "cid1", + Type: livekit.TrackType_VIDEO, + SimulcastCodecs: []*livekit.SimulcastCodec{{ + Codec: "h264", + Cid: "cid1", + }}, + }) + + require.Eventually(t, func() bool { return publishReceived.Load() }, 5*time.Second, 10*time.Millisecond) + + // publishing a supported codec should not change + publishReceived.Store(false) + sink.WriteMessageCalls(func(msg proto.Message) error { + if res, ok := msg.(*livekit.SignalResponse); ok { + if published := res.GetTrackPublished(); published != nil { + publishReceived.Store(true) + require.NotEmpty(t, published.Track.Codecs) + require.True(t, mime.IsMimeTypeStringVP8(published.Track.Codecs[0].MimeType)) + } + } + return nil + }) + participant.AddTrack(&livekit.AddTrackRequest{ + Cid: "cid2", + Type: livekit.TrackType_VIDEO, + SimulcastCodecs: []*livekit.SimulcastCodec{{ + Codec: "vp8", + Cid: "cid2", + }}, + }) + require.Eventually(t, func() bool { return publishReceived.Load() }, 5*time.Second, 10*time.Millisecond) +} + +func TestPreferMediaCodecForPublisher(t *testing.T) { + testCases := []struct { + name string + mediaKind string + trackBaseCid string + preferredCodec string + addTrack *livekit.AddTrackRequest + mimeTypeStringChecker func(string) bool + mimeTypeCodecStringChecker func(string) bool + transceiverMimeType mime.MimeType + }{ + { + name: "video", + mediaKind: "video", + trackBaseCid: "preferH264Video", + preferredCodec: "h264", + addTrack: &livekit.AddTrackRequest{ + Type: livekit.TrackType_VIDEO, + Name: "video", + Width: 1280, + Height: 720, + Source: livekit.TrackSource_CAMERA, + }, + mimeTypeStringChecker: mime.IsMimeTypeStringH264, + mimeTypeCodecStringChecker: mime.IsMimeTypeCodecStringH264, + transceiverMimeType: mime.MimeTypeVP8, + }, + { + name: "audio", + mediaKind: "audio", + trackBaseCid: "preferPCMAAudio", + preferredCodec: "pcma", + addTrack: &livekit.AddTrackRequest{ + Type: livekit.TrackType_AUDIO, + Name: "audio", + Source: livekit.TrackSource_MICROPHONE, + }, + mimeTypeStringChecker: mime.IsMimeTypeStringPCMA, + mimeTypeCodecStringChecker: mime.IsMimeTypeCodecStringPCMA, + transceiverMimeType: mime.MimeTypeOpus, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + participant := newParticipantForTestWithOpts("123", &participantOpts{ + publisher: true, + }) + participant.SetMigrateState(types.MigrateStateComplete) + + pc, err := webrtc.NewPeerConnection(webrtc.Configuration{}) + require.NoError(t, err) + defer pc.Close() + + for i := 0; i < 2; i++ { + // publish preferred track without client using setCodecPreferences() + trackCid := fmt.Sprintf("%s-%d", tc.trackBaseCid, i) + req := utils.CloneProto(tc.addTrack) + req.SimulcastCodecs = []*livekit.SimulcastCodec{ + { + Codec: tc.preferredCodec, + Cid: trackCid, + }, + } + participant.AddTrack(req) + + track, err := webrtc.NewTrackLocalStaticRTP(webrtc.RTPCodecCapability{MimeType: tc.transceiverMimeType.String()}, trackCid, trackCid) + require.NoError(t, err) + transceiver, err := pc.AddTransceiverFromTrack(track, webrtc.RTPTransceiverInit{Direction: webrtc.RTPTransceiverDirectionSendrecv}) + require.NoError(t, err) + codecs := transceiver.Receiver().GetParameters().Codecs + + if i > 0 { + // the negotiated codecs order could be updated by first negotiation, + // reorder to make tested preferred codec not preferred + for tc.mimeTypeStringChecker(codecs[0].MimeType) { + codecs = append(codecs[1:], codecs[0]) + } + } + // preferred codec should not be preferred in `offer` + require.False(t, tc.mimeTypeStringChecker(codecs[0].MimeType), "codecs", codecs) + + sdp, err := pc.CreateOffer(nil) + require.NoError(t, err) + require.NoError(t, pc.SetLocalDescription(sdp)) + offerId := uint32(23) + + sink := &routingfakes.FakeMessageSink{} + participant.SwapResponseSink(sink, types.SignallingCloseReasonUnknown) + var answer webrtc.SessionDescription + var answerId uint32 + var answerReceived atomic.Bool + var answerIdReceived atomic.Uint32 + sink.WriteMessageCalls(func(msg proto.Message) error { + if res, ok := msg.(*livekit.SignalResponse); ok { + if res.GetAnswer() != nil { + answer, answerId, _ = signalling.FromProtoSessionDescription(res.GetAnswer()) + pc.SetRemoteDescription(answer) + answerReceived.Store(true) + answerIdReceived.Store(answerId) + } + } + return nil + }) + participant.HandleOffer(&livekit.SessionDescription{ + Type: webrtc.SDPTypeOffer.String(), + Sdp: sdp.SDP, + Id: offerId, + }) + + require.Eventually(t, func() bool { return answerReceived.Load() && answerIdReceived.Load() == offerId }, 5*time.Second, 10*time.Millisecond) + + var havePreferred bool + parsed, err := answer.Unmarshal() + require.NoError(t, err) + var mediaSectionIndex int + for _, m := range parsed.MediaDescriptions { + if m.MediaName.Media == tc.mediaKind { + if mediaSectionIndex == i { + codecs, err := lksdp.CodecsFromMediaDescription(m) + require.NoError(t, err) + if tc.mimeTypeCodecStringChecker(codecs[0].Name) { + havePreferred = true + break + } + } + mediaSectionIndex++ + } + } + + require.Truef(t, havePreferred, "%s should be preferred for %s section %d, answer sdp: \n%s", tc.preferredCodec, tc.mediaKind, i, answer.SDP) + } + }) + } +} + +func TestPreferAudioCodecForRed(t *testing.T) { + participant := newParticipantForTestWithOpts("123", &participantOpts{ + publisher: true, + }) + participant.SetMigrateState(types.MigrateStateComplete) + + me := webrtc.MediaEngine{} + opusCodecParameters := OpusCodecParameters + opusCodecParameters.RTPCodecCapability.RTCPFeedback = []webrtc.RTCPFeedback{{Type: webrtc.TypeRTCPFBNACK}} + require.NoError(t, me.RegisterCodec(opusCodecParameters, webrtc.RTPCodecTypeAudio)) + redCodecParameters := RedCodecParameters + redCodecParameters.RTPCodecCapability.RTCPFeedback = []webrtc.RTCPFeedback{{Type: webrtc.TypeRTCPFBNACK}} + require.NoError(t, me.RegisterCodec(redCodecParameters, webrtc.RTPCodecTypeAudio)) + + api := webrtc.NewAPI(webrtc.WithMediaEngine(&me)) + pc, err := api.NewPeerConnection(webrtc.Configuration{}) + require.NoError(t, err) + defer pc.Close() + + for idx, disableRed := range []bool{false, true, false, true} { + t.Run(fmt.Sprintf("disableRed=%v", disableRed), func(t *testing.T) { + trackCid := fmt.Sprintf("audiotrack%d", idx) + req := &livekit.AddTrackRequest{ + Type: livekit.TrackType_AUDIO, + Cid: trackCid, + } + if idx < 2 { + req.DisableRed = disableRed + } else { + codec := "red" + if disableRed { + codec = "opus" + } + req.SimulcastCodecs = []*livekit.SimulcastCodec{ + { + Codec: codec, + Cid: trackCid, + }, + } + } + participant.AddTrack(req) + + track, err := webrtc.NewTrackLocalStaticRTP( + webrtc.RTPCodecCapability{MimeType: "audio/opus"}, + trackCid, + trackCid, + ) + require.NoError(t, err) + + transceiver, err := pc.AddTransceiverFromTrack( + track, + webrtc.RTPTransceiverInit{Direction: webrtc.RTPTransceiverDirectionSendrecv}, + ) + require.NoError(t, err) + codecs := transceiver.Sender().GetParameters().Codecs + for i, c := range codecs { + if c.MimeType == "audio/opus" { + if i != 0 { + codecs[0], codecs[i] = codecs[i], codecs[0] + } + break + } + } + transceiver.SetCodecPreferences(codecs) + sdp, err := pc.CreateOffer(nil) + require.NoError(t, err) + pc.SetLocalDescription(sdp) + // opus should be preferred + require.Equal(t, codecs[0].MimeType, "audio/opus", sdp) + offerId := uint32(0xffffff) + + sink := &routingfakes.FakeMessageSink{} + participant.SwapResponseSink(sink, types.SignallingCloseReasonUnknown) + var answer webrtc.SessionDescription + var answerId uint32 + var answerReceived atomic.Bool + var answerIdReceived atomic.Uint32 + sink.WriteMessageCalls(func(msg proto.Message) error { + if res, ok := msg.(*livekit.SignalResponse); ok { + if res.GetAnswer() != nil { + answer, answerId, _ = signalling.FromProtoSessionDescription(res.GetAnswer()) + pc.SetRemoteDescription(answer) + answerReceived.Store(true) + answerIdReceived.Store(answerId) + } + } + return nil + }) + participant.HandleOffer(&livekit.SessionDescription{ + Type: webrtc.SDPTypeOffer.String(), + Sdp: sdp.SDP, + Id: offerId, + }) + require.Eventually( + t, + func() bool { + return answerReceived.Load() && answerIdReceived.Load() == offerId + }, + 5*time.Second, + 10*time.Millisecond, + ) + + var redPreferred bool + parsed, err := answer.Unmarshal() + require.NoError(t, err) + var audioSectionIndex int + for _, m := range parsed.MediaDescriptions { + if m.MediaName.Media == "audio" { + if audioSectionIndex == idx { + codecs, err := lksdp.CodecsFromMediaDescription(m) + require.NoError(t, err) + // nack is always enabled. if red is preferred, server will not generate nack request + var nackEnabled bool + for _, c := range codecs { + if c.Name == "opus" { + for _, fb := range c.RTCPFeedback { + if strings.Contains(fb, "nack") { + nackEnabled = true + break + } + } + } + } + require.True(t, nackEnabled, "nack should be enabled for opus") + + if mime.IsMimeTypeCodecStringRED(codecs[0].Name) { + redPreferred = true + break + } + } + audioSectionIndex++ + } + } + require.Equalf(t, !disableRed, redPreferred, "offer : \n%s\nanswer sdp: \n%s", sdp, answer.SDP) + }) + } +} + +type participantOpts struct { + permissions *livekit.ParticipantPermission + protocolVersion types.ProtocolVersion + publisher bool + clientConf *livekit.ClientConfiguration + clientInfo *livekit.ClientInfo +} + +func newParticipantForTestWithOpts(identity livekit.ParticipantIdentity, opts *participantOpts) *ParticipantImpl { + if opts == nil { + opts = &participantOpts{} + } + if opts.protocolVersion == 0 { + opts.protocolVersion = 6 + } + conf, _ := config.NewConfig("", true, nil, nil) + // disable mux, it doesn't play too well with unit test + conf.RTC.TCPPort = 0 + rtcConf, err := NewWebRTCConfig(conf) + if err != nil { + panic(err) + } + ff := buffer.NewFactoryOfBufferFactory(500, 200) + rtcConf.SetBufferFactory(ff.CreateBufferFactory()) + grants := &auth.ClaimGrants{ + Video: &auth.VideoGrant{}, + } + if opts.permissions != nil { + grants.Video.SetCanPublish(opts.permissions.CanPublish) + grants.Video.SetCanPublishData(opts.permissions.CanPublishData) + grants.Video.SetCanSubscribe(opts.permissions.CanSubscribe) + } + + enabledCodecs := make([]*livekit.Codec, 0, len(conf.Room.EnabledCodecs)) + for _, c := range conf.Room.EnabledCodecs { + enabledCodecs = append(enabledCodecs, &livekit.Codec{ + Mime: c.Mime, + FmtpLine: c.FmtpLine, + }) + } + sid := livekit.ParticipantID(guid.New(utils.ParticipantPrefix)) + p, _ := NewParticipant(ParticipantParams{ + SID: sid, + Identity: identity, + Config: rtcConf, + Sink: &routingfakes.FakeMessageSink{}, + ProtocolVersion: opts.protocolVersion, + SessionStartTime: time.Now(), + PLIThrottleConfig: conf.RTC.PLIThrottle, + Grants: grants, + PublishEnabledCodecs: enabledCodecs, + SubscribeEnabledCodecs: enabledCodecs, + ClientConf: opts.clientConf, + ClientInfo: ClientInfo{ClientInfo: opts.clientInfo}, + Logger: LoggerWithParticipant(logger.GetLogger(), identity, sid, false), + Reporter: roomobs.NewNoopParticipantSessionReporter(), + Telemetry: &telemetryfakes.FakeTelemetryService{}, + VersionGenerator: utils.NewDefaultTimedVersionGenerator(), + ParticipantListener: &typesfakes.FakeLocalParticipantListener{}, + ParticipantHelper: &typesfakes.FakeLocalParticipantHelper{}, + }) + p.isPublisher.Store(opts.publisher) + p.updateState(livekit.ParticipantInfo_ACTIVE) + + return p +} + +func newParticipantForTest(identity livekit.ParticipantIdentity) *ParticipantImpl { + return newParticipantForTestWithOpts(identity, nil) +} diff --git a/livekit/pkg/rtc/participant_sdp.go b/livekit/pkg/rtc/participant_sdp.go new file mode 100644 index 0000000..828863f --- /dev/null +++ b/livekit/pkg/rtc/participant_sdp.go @@ -0,0 +1,408 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rtc + +import ( + "fmt" + "slices" + "strconv" + "strings" + + "github.com/pion/sdp/v3" + "github.com/pion/webrtc/v4" + + "github.com/livekit/livekit-server/pkg/rtc/types" + "github.com/livekit/livekit-server/pkg/sfu/mime" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + lksdp "github.com/livekit/protocol/sdp" + "github.com/livekit/protocol/utils" +) + +func (p *ParticipantImpl) populateSdpCid(parsedOffer *sdp.SessionDescription) ([]*sdp.MediaDescription, []*sdp.MediaDescription) { + processUnmatch := func(unmatches []*sdp.MediaDescription, trackType livekit.TrackType) { + for _, unmatch := range unmatches { + streamID, ok := lksdp.ExtractStreamID(unmatch) + if !ok { + continue + } + + sdpCodecs, err := lksdp.CodecsFromMediaDescription(unmatch) + if err != nil || len(sdpCodecs) == 0 { + p.pubLogger.Errorw( + "extract codecs from media section failed", err, + "media", unmatch, + "parsedOffer", parsedOffer, + ) + continue + } + + p.pendingTracksLock.Lock() + signalCid, info, _, migrated, _ := p.getPendingTrack(streamID, trackType, false) + if migrated { + p.pendingTracksLock.Unlock() + continue + } + + if info == nil { + p.pendingTracksLock.Unlock() + + // could be already published track and the unmatch could be a back up codec publish + numUnmatchedTracks := 0 + var unmatchedTrack types.MediaTrack + var unmatchedSdpMimeType mime.MimeType + + found := false + for _, sdpCodec := range sdpCodecs { + sdpMimeType := mime.NormalizeMimeTypeCodec(sdpCodec.Name).ToMimeType() + for _, publishedTrack := range p.GetPublishedTracks() { + if sigCid, sdpCid := publishedTrack.(*MediaTrack).GetCidsForMimeType(sdpMimeType); sigCid != "" && sdpCid == "" { + // a back up codec has a SDP cid match + if sigCid == streamID { + found = true + break + } else { + numUnmatchedTracks++ + unmatchedTrack = publishedTrack + unmatchedSdpMimeType = sdpMimeType + } + } + } + if found { + break + } + } + if !found && unmatchedTrack != nil { + if numUnmatchedTracks != 1 { + p.pubLogger.Warnw( + "too many unmatched tracks", nil, + "media", unmatch, + "parsedOffer", parsedOffer, + ) + } + unmatchedTrack.(*MediaTrack).UpdateCodecSdpCid(unmatchedSdpMimeType, streamID) + p.pubLogger.Debugw( + "published track SDP cid updated", + "trackID", unmatchedTrack.ID(), + "track", logger.Proto(unmatchedTrack.ToProto()), + ) + } + continue + } + + if len(info.Codecs) == 0 { + p.pendingTracksLock.Unlock() + p.pubLogger.Warnw( + "track without codecs", nil, + "trackID", info.Sid, + "pendingTrack", p.pendingTracks[signalCid], + "media", unmatch, + "parsedOffer", parsedOffer, + ) + continue + } + + found := false + updated := false + for _, sdpCodec := range sdpCodecs { + if mime.NormalizeMimeTypeCodec(sdpCodec.Name) == mime.GetMimeTypeCodec(info.Codecs[0].MimeType) { + // set SdpCid only if different from SignalCid + if streamID != info.Codecs[0].Cid { + info.Codecs[0].SdpCid = streamID + updated = true + } + found = true + break + } + if found { + break + } + } + + if !found { + // not using SimulcastCodec, i. e. mime type not available till track publish + if len(info.Codecs) == 1 { + // set SdpCid only if different from SignalCid + if streamID != info.Codecs[0].Cid { + info.Codecs[0].SdpCid = streamID + updated = true + } + } + } + + if updated { + p.pendingTracks[signalCid].trackInfos[0] = utils.CloneProto(info) + p.pubLogger.Debugw( + "pending track SDP cid updated", + "signalCid", signalCid, + "trackID", info.Sid, + "pendingTrack", p.pendingTracks[signalCid], + ) + } + p.pendingTracksLock.Unlock() + } + } + + unmatchAudios, err := p.TransportManager.GetUnmatchMediaForOffer(parsedOffer, "audio") + if err != nil { + p.pubLogger.Warnw("could not get unmatch audios", err) + return nil, nil + } + + unmatchVideos, err := p.TransportManager.GetUnmatchMediaForOffer(parsedOffer, "video") + if err != nil { + p.pubLogger.Warnw("could not get unmatch videos", err) + return nil, nil + } + + processUnmatch(unmatchAudios, livekit.TrackType_AUDIO) + processUnmatch(unmatchVideos, livekit.TrackType_VIDEO) + return unmatchAudios, unmatchVideos +} + +func (p *ParticipantImpl) setCodecPreferencesForPublisher( + parsedOffer *sdp.SessionDescription, + unmatchAudios []*sdp.MediaDescription, + unmatchVideos []*sdp.MediaDescription, +) { + unprocessedUnmatchAudios := p.setCodecPreferencesForPublisherMedia( + parsedOffer, + unmatchAudios, + livekit.TrackType_AUDIO, + ) + p.setCodecPreferencesOpusRedForPublisher(parsedOffer, unprocessedUnmatchAudios) + _ = p.setCodecPreferencesForPublisherMedia( + parsedOffer, + unmatchVideos, + livekit.TrackType_VIDEO, + ) +} + +func (p *ParticipantImpl) setCodecPreferencesForPublisherMedia( + parsedOffer *sdp.SessionDescription, + unmatches []*sdp.MediaDescription, + trackType livekit.TrackType, +) []*sdp.MediaDescription { + unprocessed := make([]*sdp.MediaDescription, 0, len(unmatches)) + for _, unmatch := range unmatches { + var ti *livekit.TrackInfo + var mimeType string + + mid := lksdp.GetMidValue(unmatch) + if mid == "" { + unprocessed = append(unprocessed, unmatch) + continue + } + transceiver := p.TransportManager.GetPublisherRTPTransceiver(mid) + if transceiver == nil { + unprocessed = append(unprocessed, unmatch) + continue + } + + streamID, ok := lksdp.ExtractStreamID(unmatch) + if !ok { + unprocessed = append(unprocessed, unmatch) + continue + } + + p.pendingTracksLock.RLock() + mt := p.getPublishedTrackBySdpCid(streamID) + if mt != nil { + ti = mt.ToProto() + } else { + _, ti, _, _, _ = p.getPendingTrack(streamID, trackType, false) + } + p.pendingTracksLock.RUnlock() + + if ti == nil { + unprocessed = append(unprocessed, unmatch) + continue + } + + for _, c := range ti.Codecs { + if c.Cid == streamID || c.SdpCid == streamID { + mimeType = c.MimeType + break + } + } + if mimeType == "" && len(ti.Codecs) > 0 { + mimeType = ti.Codecs[0].MimeType + } + + if mimeType == "" { + unprocessed = append(unprocessed, unmatch) + continue + } + + configureReceiverCodecs( + transceiver, + mimeType, + p.params.ClientInfo.ComplyWithCodecOrderInSDPAnswer(), + ) + } + + return unprocessed +} + +func (p *ParticipantImpl) setCodecPreferencesOpusRedForPublisher( + parsedOffer *sdp.SessionDescription, + unmatchAudios []*sdp.MediaDescription, +) { + for _, unmatchAudio := range unmatchAudios { + mid := lksdp.GetMidValue(unmatchAudio) + if mid == "" { + continue + } + transceiver := p.TransportManager.GetPublisherRTPTransceiver(mid) + if transceiver == nil { + continue + } + + streamID, ok := lksdp.ExtractStreamID(unmatchAudio) + if !ok { + continue + } + + p.pendingTracksLock.RLock() + _, ti, _, _, _ := p.getPendingTrack(streamID, livekit.TrackType_AUDIO, false) + p.pendingTracksLock.RUnlock() + if ti == nil { + continue + } + + codecs, err := lksdp.CodecsFromMediaDescription(unmatchAudio) + if err != nil { + p.pubLogger.Errorw( + "extract codecs from media section failed", err, + "media", unmatchAudio, + "parsedOffer", parsedOffer, + ) + continue + } + + var opusPayload uint8 + for _, codec := range codecs { + if mime.IsMimeTypeCodecStringOpus(codec.Name) { + opusPayload = codec.PayloadType + break + } + } + if opusPayload == 0 { + continue + } + + preferRED := IsRedEnabled(ti) + // if RED is enabled for this track, prefer RED codec in offer + for _, codec := range codecs { + // codec contain opus/red + if preferRED && + mime.IsMimeTypeCodecStringRED(codec.Name) && + strings.Contains(codec.Fmtp, strconv.FormatInt(int64(opusPayload), 10)) { + configureReceiverCodecs(transceiver, "audio/red", true) + break + } + } + } +} + +// configure publisher answer for audio track's dtx and stereo settings +func (p *ParticipantImpl) configurePublisherAnswer(answer webrtc.SessionDescription) webrtc.SessionDescription { + offer := p.TransportManager.LastPublisherOffer() + parsedOffer, err := offer.Unmarshal() + if err != nil { + return answer + } + + parsedAnswer, err := answer.Unmarshal() + if err != nil { + return answer + } + + for _, m := range parsedAnswer.MediaDescriptions { + switch m.MediaName.Media { + case "audio": + _, ok := m.Attribute(sdp.AttrKeyInactive) + if ok { + continue + } + mid, ok := m.Attribute(sdp.AttrKeyMID) + if !ok { + continue + } + // find track info from offer's stream id + var ti *livekit.TrackInfo + for _, om := range parsedOffer.MediaDescriptions { + _, ok := om.Attribute(sdp.AttrKeyInactive) + if ok { + continue + } + omid, ok := om.Attribute(sdp.AttrKeyMID) + if ok && omid == mid { + streamID, ok := lksdp.ExtractStreamID(om) + if !ok { + continue + } + track, _ := p.getPublishedTrackBySdpCid(streamID).(*MediaTrack) + if track == nil { + p.pendingTracksLock.RLock() + _, ti, _, _, _ = p.getPendingTrack(streamID, livekit.TrackType_AUDIO, false) + p.pendingTracksLock.RUnlock() + } else { + ti = track.ToProto() + } + break + } + } + + if ti == nil { + // no need to configure + continue + } + + opusPT, err := parsedAnswer.GetPayloadTypeForCodec(sdp.Codec{Name: mime.MimeTypeCodecOpus.String()}) + if err != nil { + p.pubLogger.Infow("failed to get opus payload type", "error", err, "trackID", ti.Sid) + continue + } + + for i, attr := range m.Attributes { + if strings.HasPrefix(attr.String(), fmt.Sprintf("fmtp:%d", opusPT)) { + if !slices.Contains(ti.AudioFeatures, livekit.AudioTrackFeature_TF_NO_DTX) { + attr.Value += ";usedtx=1" + } else { + attr.Value = strings.ReplaceAll(attr.Value, ";usedtx=1", "") + } + if slices.Contains(ti.AudioFeatures, livekit.AudioTrackFeature_TF_STEREO) { + attr.Value += ";stereo=1;maxaveragebitrate=510000" + } else { + attr.Value = strings.ReplaceAll(attr.Value, ";stereo=1", "") + } + m.Attributes[i] = attr + } + } + + default: + continue + } + } + + bytes, err := parsedAnswer.Marshal() + if err != nil { + p.pubLogger.Infow("failed to marshal answer", "error", err) + return answer + } + answer.SDP = string(bytes) + return answer +} diff --git a/livekit/pkg/rtc/participant_signal.go b/livekit/pkg/rtc/participant_signal.go new file mode 100644 index 0000000..3fa388a --- /dev/null +++ b/livekit/pkg/rtc/participant_signal.go @@ -0,0 +1,370 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rtc + +import ( + "time" + + "github.com/pion/webrtc/v4" + + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + protosignalling "github.com/livekit/protocol/signalling" + + "github.com/livekit/livekit-server/pkg/routing" + "github.com/livekit/livekit-server/pkg/rtc/types" +) + +func (p *ParticipantImpl) SwapResponseSink(sink routing.MessageSink, reason types.SignallingCloseReason) { + p.signaller.SwapResponseSink(sink, reason) +} + +func (p *ParticipantImpl) GetResponseSink() routing.MessageSink { + return p.signaller.GetResponseSink() +} + +func (p *ParticipantImpl) CloseSignalConnection(reason types.SignallingCloseReason) { + p.signaller.CloseSignalConnection(reason) +} + +func (p *ParticipantImpl) SendJoinResponse(joinResponse *livekit.JoinResponse) error { + // keep track of participant updates and versions + p.updateLock.Lock() + for _, op := range joinResponse.OtherParticipants { + p.updateCache.Add(livekit.ParticipantID(op.Sid), participantUpdateInfo{ + identity: livekit.ParticipantIdentity(op.Identity), + version: op.Version, + state: op.State, + updatedAt: time.Now(), + }) + } + p.updateLock.Unlock() + + // send Join response + err := p.signaller.WriteMessage(p.signalling.SignalJoinResponse(joinResponse)) + if err != nil { + return err + } + + // update state after sending message, so that no participant updates could slip through before JoinResponse is sent + p.updateLock.Lock() + if p.State() == livekit.ParticipantInfo_JOINING { + p.updateState(livekit.ParticipantInfo_JOINED) + } + queuedUpdates := p.queuedUpdates + p.queuedUpdates = nil + p.updateLock.Unlock() + + if len(queuedUpdates) > 0 { + return p.SendParticipantUpdate(queuedUpdates) + } + + return nil +} + +func (p *ParticipantImpl) SendParticipantUpdate(participantsToUpdate []*livekit.ParticipantInfo) error { + p.updateLock.Lock() + if p.IsDisconnected() { + p.updateLock.Unlock() + return nil + } + + if !p.IsReady() { + // queue up updates + p.queuedUpdates = append(p.queuedUpdates, participantsToUpdate...) + p.updateLock.Unlock() + return nil + } + validUpdates := make([]*livekit.ParticipantInfo, 0, len(participantsToUpdate)) + for _, pi := range participantsToUpdate { + isValid := true + pID := livekit.ParticipantID(pi.Sid) + if lastVersion, ok := p.updateCache.Get(pID); ok { + // this is a message delivered out of order, a more recent version of the message had already been + // sent. + if pi.Version < lastVersion.version { + p.params.Logger.Debugw( + "skipping outdated participant update", + "otherParticipant", pi.Identity, + "otherPID", pi.Sid, + "version", pi.Version, + "lastVersion", lastVersion, + ) + isValid = false + } + } + if pi.Permission != nil && pi.Permission.Hidden && pi.Sid != string(p.ID()) { + p.params.Logger.Debugw("skipping hidden participant update", "otherParticipant", pi.Identity) + isValid = false + } + if isValid { + p.updateCache.Add(pID, participantUpdateInfo{ + identity: livekit.ParticipantIdentity(pi.Identity), + version: pi.Version, + state: pi.State, + updatedAt: time.Now(), + }) + validUpdates = append(validUpdates, pi) + } + } + p.updateLock.Unlock() + + return p.signaller.WriteMessage(p.signalling.SignalParticipantUpdate(validUpdates)) +} + +// SendSpeakerUpdate notifies participant changes to speakers. only send members that have changed since last update +func (p *ParticipantImpl) SendSpeakerUpdate(speakers []*livekit.SpeakerInfo, force bool) error { + if !p.IsReady() { + return nil + } + + var scopedSpeakers []*livekit.SpeakerInfo + if force { + scopedSpeakers = speakers + } else { + for _, s := range speakers { + participantID := livekit.ParticipantID(s.Sid) + if p.IsSubscribedTo(participantID) || participantID == p.ID() { + scopedSpeakers = append(scopedSpeakers, s) + } + } + } + + return p.signaller.WriteMessage(p.signalling.SignalSpeakerUpdate(scopedSpeakers)) +} + +func (p *ParticipantImpl) SendRoomUpdate(room *livekit.Room) error { + return p.signaller.WriteMessage(p.signalling.SignalRoomUpdate(room)) +} + +func (p *ParticipantImpl) SendConnectionQualityUpdate(update *livekit.ConnectionQualityUpdate) error { + return p.signaller.WriteMessage(p.signalling.SignalConnectionQualityUpdate(update)) +} + +func (p *ParticipantImpl) SendRefreshToken(token string) error { + return p.signaller.WriteMessage(p.signalling.SignalRefreshToken(token)) +} + +func (p *ParticipantImpl) sendRequestResponse(requestResponse *livekit.RequestResponse) error { + if !p.params.ClientInfo.SupportsRequestResponse() { + return nil + } + + if requestResponse.Reason == livekit.RequestResponse_OK && !p.ProtocolVersion().SupportsNonErrorSignalResponse() { + return nil + } + + return p.signaller.WriteMessage(p.signalling.SignalRequestResponse(requestResponse)) +} + +func (p *ParticipantImpl) SendRoomMovedResponse(roomMovedResponse *livekit.RoomMovedResponse) error { + return p.signaller.WriteMessage(p.signalling.SignalRoomMovedResponse(roomMovedResponse)) +} + +func (p *ParticipantImpl) HandleReconnectAndSendResponse(reconnectReason livekit.ReconnectReason, reconnectResponse *livekit.ReconnectResponse) error { + p.TransportManager.HandleClientReconnect(reconnectReason) + + if !p.params.ClientInfo.CanHandleReconnectResponse() { + return nil + } + if err := p.signaller.WriteMessage(p.signalling.SignalReconnectResponse(reconnectResponse)); err != nil { + return err + } + + if p.params.ProtocolVersion.SupportsDisconnectedUpdate() { + return p.sendDisconnectUpdatesForReconnect() + } + + return nil +} + +func (p *ParticipantImpl) sendDisconnectUpdatesForReconnect() error { + lastSignalAt := p.TransportManager.LastSeenSignalAt() + var disconnectedParticipants []*livekit.ParticipantInfo + p.updateLock.Lock() + keys := p.updateCache.Keys() + for i := len(keys) - 1; i >= 0; i-- { + if info, ok := p.updateCache.Get(keys[i]); ok { + if info.updatedAt.Before(lastSignalAt) { + break + } else if info.state == livekit.ParticipantInfo_DISCONNECTED { + disconnectedParticipants = append(disconnectedParticipants, &livekit.ParticipantInfo{ + Sid: string(keys[i]), + Identity: string(info.identity), + Version: info.version, + State: livekit.ParticipantInfo_DISCONNECTED, + }) + } + } + } + p.updateLock.Unlock() + + return p.signaller.WriteMessage(p.signalling.SignalParticipantUpdate(disconnectedParticipants)) +} + +func (p *ParticipantImpl) sendICECandidate(ic *webrtc.ICECandidate, target livekit.SignalTarget) error { + prevIC := p.icQueue[target].Swap(ic) + if prevIC == nil { + return nil + } + + trickle := protosignalling.ToProtoTrickle(prevIC.ToJSON(), target, ic == nil) + p.params.Logger.Debugw("sending ICE candidate", "transport", target, "trickle", logger.Proto(trickle)) + + return p.signaller.WriteMessage(p.signalling.SignalICECandidate(trickle)) +} + +func (p *ParticipantImpl) sendTrackMuted(trackID livekit.TrackID, muted bool) { + _ = p.signaller.WriteMessage(p.signalling.SignalTrackMuted(&livekit.MuteTrackRequest{ + Sid: string(trackID), + Muted: muted, + })) +} + +func (p *ParticipantImpl) sendTrackPublished(cid string, ti *livekit.TrackInfo) error { + p.pubLogger.Debugw("sending track published", "cid", cid, "trackInfo", logger.Proto(ti)) + return p.signaller.WriteMessage(p.signalling.SignalTrackPublished(&livekit.TrackPublishedResponse{ + Cid: cid, + Track: ti, + })) +} + +func (p *ParticipantImpl) sendTrackUnpublished(trackID livekit.TrackID) { + _ = p.signaller.WriteMessage(p.signalling.SignalTrackUnpublished(&livekit.TrackUnpublishedResponse{ + TrackSid: string(trackID), + })) +} + +func (p *ParticipantImpl) sendTrackHasBeenSubscribed(trackID livekit.TrackID) { + if !p.params.ClientInfo.SupportsTrackSubscribedEvent() { + return + } + _ = p.signaller.WriteMessage(p.signalling.SignalTrackSubscribed(&livekit.TrackSubscribed{ + TrackSid: string(trackID), + })) + p.params.Logger.Debugw("track has been subscribed", "trackID", trackID) +} + +func (p *ParticipantImpl) sendLeaveRequest( + reason types.ParticipantCloseReason, + isExpectedToResume bool, + isExpectedToReconnect bool, + sendOnlyIfSupportingLeaveRequestWithAction bool, +) error { + var leave *livekit.LeaveRequest + if p.ProtocolVersion().SupportsRegionsInLeaveRequest() { + leave = &livekit.LeaveRequest{ + Reason: reason.ToDisconnectReason(), + } + switch { + case isExpectedToResume: + leave.Action = livekit.LeaveRequest_RESUME + case isExpectedToReconnect: + leave.Action = livekit.LeaveRequest_RECONNECT + default: + leave.Action = livekit.LeaveRequest_DISCONNECT + } + if leave.Action != livekit.LeaveRequest_DISCONNECT { + // sending region settings even for RESUME just in case client wants to a full reconnect despite server saying RESUME + leave.Regions = p.helper().GetRegionSettings(p.params.ClientInfo.Address) + } + } else { + if !sendOnlyIfSupportingLeaveRequestWithAction { + leave = &livekit.LeaveRequest{ + CanReconnect: isExpectedToReconnect, + Reason: reason.ToDisconnectReason(), + } + } + } + if leave != nil { + return p.signaller.WriteMessage(p.signalling.SignalLeaveRequest(leave)) + } + + return nil +} + +func (p *ParticipantImpl) sendSdpAnswer(answer webrtc.SessionDescription, answerId uint32, midToTrackID map[string]string) error { + return p.signaller.WriteMessage(p.signalling.SignalSdpAnswer(protosignalling.ToProtoSessionDescription(answer, answerId, midToTrackID))) +} + +func (p *ParticipantImpl) sendSdpOffer(offer webrtc.SessionDescription, offerId uint32, midToTrackID map[string]string) error { + return p.signaller.WriteMessage(p.signalling.SignalSdpOffer(protosignalling.ToProtoSessionDescription(offer, offerId, midToTrackID))) +} + +func (p *ParticipantImpl) sendStreamStateUpdate(streamStateUpdate *livekit.StreamStateUpdate) error { + return p.signaller.WriteMessage(p.signalling.SignalStreamStateUpdate(streamStateUpdate)) +} + +func (p *ParticipantImpl) sendSubscribedQualityUpdate(subscribedQualityUpdate *livekit.SubscribedQualityUpdate) error { + return p.signaller.WriteMessage(p.signalling.SignalSubscribedQualityUpdate(subscribedQualityUpdate)) +} + +func (p *ParticipantImpl) sendSubscribedAudioCodecUpdate(subscribedAudioCodecUpdate *livekit.SubscribedAudioCodecUpdate) error { + return p.signaller.WriteMessage(p.signalling.SignalSubscribedAudioCodecUpdate(subscribedAudioCodecUpdate)) +} + +func (p *ParticipantImpl) sendSubscriptionResponse(trackID livekit.TrackID, subErr livekit.SubscriptionError) error { + return p.signaller.WriteMessage(p.signalling.SignalSubscriptionResponse(&livekit.SubscriptionResponse{ + TrackSid: string(trackID), + Err: subErr, + })) +} + +func (p *ParticipantImpl) SendSubscriptionPermissionUpdate(publisherID livekit.ParticipantID, trackID livekit.TrackID, allowed bool) error { + p.subLogger.Debugw("sending subscription permission update", "publisherID", publisherID, "trackID", trackID, "allowed", allowed) + err := p.signaller.WriteMessage(p.signalling.SignalSubscriptionPermissionUpdate(&livekit.SubscriptionPermissionUpdate{ + ParticipantSid: string(publisherID), + TrackSid: string(trackID), + Allowed: allowed, + })) + if err != nil { + p.subLogger.Errorw("could not send subscription permission update", err) + } + return err +} + +func (p *ParticipantImpl) sendMediaSectionsRequirement(numAudios uint32, numVideos uint32) error { + p.pubLogger.Debugw( + "sending media sections requirement", + "numAudios", numAudios, + "numVideos", numVideos, + ) + err := p.signaller.WriteMessage(p.signalling.SignalMediaSectionsRequirement(&livekit.MediaSectionsRequirement{ + NumAudios: numAudios, + NumVideos: numVideos, + })) + if err != nil { + p.subLogger.Errorw("could not send media sections requirement", err) + } + return err +} + +func (p *ParticipantImpl) sendPublishDataTrackResponse(dti *livekit.DataTrackInfo) error { + return p.signaller.WriteMessage(p.signalling.SignalPublishDataTrackResponse(&livekit.PublishDataTrackResponse{ + Info: dti, + })) +} + +func (p *ParticipantImpl) sendUnpublishDataTrackResponse(dti *livekit.DataTrackInfo) error { + return p.signaller.WriteMessage(p.signalling.SignalUnpublishDataTrackResponse(&livekit.UnpublishDataTrackResponse{ + Info: dti, + })) +} + +func (p *ParticipantImpl) SendDataTrackSubscriberHandles(handles map[uint32]*livekit.DataTrackSubscriberHandles_PublishedDataTrack) error { + return p.signaller.WriteMessage(p.signalling.SignalDataTrackSubscriberHandles(&livekit.DataTrackSubscriberHandles{ + SubHandles: handles, + })) +} diff --git a/livekit/pkg/rtc/room.go b/livekit/pkg/rtc/room.go new file mode 100644 index 0000000..3a094c3 --- /dev/null +++ b/livekit/pkg/rtc/room.go @@ -0,0 +1,2296 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rtc + +import ( + "context" + "fmt" + "math" + "slices" + "sort" + "strings" + "sync" + "time" + + "go.uber.org/atomic" + "golang.org/x/exp/maps" + "google.golang.org/protobuf/proto" + + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/observability/roomobs" + "github.com/livekit/protocol/rpc" + "github.com/livekit/protocol/utils" + "github.com/livekit/protocol/utils/guid" + "github.com/livekit/psrpc" + + "github.com/livekit/livekit-server/pkg/agent" + "github.com/livekit/livekit-server/pkg/config" + "github.com/livekit/livekit-server/pkg/routing" + "github.com/livekit/livekit-server/pkg/rtc/datatrack" + "github.com/livekit/livekit-server/pkg/rtc/types" + "github.com/livekit/livekit-server/pkg/sfu" + "github.com/livekit/livekit-server/pkg/sfu/buffer" + "github.com/livekit/livekit-server/pkg/sfu/connectionquality" + "github.com/livekit/livekit-server/pkg/telemetry" + "github.com/livekit/livekit-server/pkg/telemetry/prometheus" + sutils "github.com/livekit/livekit-server/pkg/utils" +) + +const ( + AudioLevelQuantization = 8 // ideally power of 2 to minimize float decimal + invAudioLevelQuantization = 1.0 / AudioLevelQuantization + subscriberUpdateInterval = 3 * time.Second + + dataForwardLoadBalanceThreshold = 4 + + simulateDisconnectSignalTimeout = 5 * time.Second + + dataMessageCacheTTL = 2 * time.Second + dataMessageCacheSize = 100_000 +) + +var ( + // var to allow unit test override + roomUpdateInterval = 5 * time.Second // frequency to update room participant counts + + ErrJobShutdownTimeout = psrpc.NewErrorf(psrpc.DeadlineExceeded, "timed out waiting for agent job to shutdown") +) + +// Duplicate the service.AgentStore interface to avoid a rtc -> service -> rtc import cycle +type AgentStore interface { + StoreAgentDispatch(ctx context.Context, dispatch *livekit.AgentDispatch) error + DeleteAgentDispatch(ctx context.Context, dispatch *livekit.AgentDispatch) error + ListAgentDispatches(ctx context.Context, roomName livekit.RoomName) ([]*livekit.AgentDispatch, error) + + StoreAgentJob(ctx context.Context, job *livekit.Job) error + DeleteAgentJob(ctx context.Context, job *livekit.Job) error +} + +type broadcastOptions struct { + skipSource bool + immediate bool +} + +type disconnectSignalOnResumeNoMessages struct { + expiry time.Time + closedCount int +} + +type Room struct { + // atomics always need to be 64bit/8byte aligned + // on 32bit arch only the beginning of the struct + // starts at such a boundary. + // time the first participant joined the room + joinedAt atomic.Int64 + // time that the last participant left the room + leftAt atomic.Int64 + holds atomic.Int32 + + lock sync.RWMutex + + protoRoom *livekit.Room + internal *livekit.RoomInternal + protoProxy *utils.ProtoProxy[*livekit.Room] + logger logger.Logger + + config WebRTCConfig + roomConfig config.RoomConfig + audioConfig *sfu.AudioConfig + serverInfo *livekit.ServerInfo + telemetry telemetry.TelemetryService + egressLauncher EgressLauncher + trackManager *RoomTrackManager + agentDispatches map[string]*agentDispatch + + // agents + agentClient agent.Client + agentStore AgentStore + + // map of identity -> Participant + participants map[livekit.ParticipantIdentity]types.LocalParticipant + participantOpts map[livekit.ParticipantIdentity]*ParticipantOptions + participantRequestSources map[livekit.ParticipantIdentity]routing.MessageSource + hasPublished map[livekit.ParticipantIdentity]bool + agentParticpants map[livekit.ParticipantIdentity]*agentJob + bufferFactory *buffer.FactoryOfBufferFactory + + // batch update participant info for non-publishers + batchedUpdates map[livekit.ParticipantIdentity]*ParticipantUpdate + batchedUpdatesMu sync.Mutex + + closed chan struct{} + + trailer []byte + + onParticipantChanged func(p types.Participant) + onRoomUpdated func() + onClose func() + + simulationLock sync.Mutex + disconnectSignalOnResumeParticipants map[livekit.ParticipantIdentity]time.Time + disconnectSignalOnResumeNoMessagesParticipants map[livekit.ParticipantIdentity]*disconnectSignalOnResumeNoMessages + + userPacketDeduper *UserPacketDeduper + + dataMessageCache *utils.TimeSizeCache[types.DataMessageCache] + + onStateChangeMu sync.Mutex + localParticipantListener types.LocalParticipantListener +} + +type ParticipantOptions struct { + AutoSubscribe bool + AutoSubscribeDataTrack bool +} + +type agentDispatch struct { + *livekit.AgentDispatch + lock sync.Mutex + pending map[chan struct{}]struct{} +} + +type agentJob struct { + *livekit.Job + lock sync.Mutex + done chan struct{} +} + +// This provides utilities attached the agent dispatch to ensure that all pending jobs are created +// before terminating jobs attached to an agent dispatch. This avoids a race that could cause some pending jobs +// to not be terminated when a dispatch is deleted. +func newAgentDispatch(ad *livekit.AgentDispatch) *agentDispatch { + return &agentDispatch{ + AgentDispatch: ad, + pending: make(map[chan struct{}]struct{}), + } +} + +func (ad *agentDispatch) jobsLaunching() (jobsLaunched func()) { + ad.lock.Lock() + c := make(chan struct{}) + ad.pending[c] = struct{}{} + ad.lock.Unlock() + + return func() { + close(c) + ad.lock.Lock() + delete(ad.pending, c) + ad.lock.Unlock() + } +} + +func (ad *agentDispatch) waitForPendingJobs() { + ad.lock.Lock() + cs := maps.Keys(ad.pending) + ad.lock.Unlock() + + for _, c := range cs { + <-c + } +} + +// This provides utilities to ensure that an agent left the room when killing a job +func newAgentJob(j *livekit.Job) *agentJob { + return &agentJob{ + Job: j, + done: make(chan struct{}), + } +} + +func (j *agentJob) participantLeft() { + j.lock.Lock() + if j.done != nil { + close(j.done) + j.done = nil + } + j.lock.Unlock() +} + +func (j *agentJob) waitForParticipantLeaving() error { + var done chan struct{} + + j.lock.Lock() + done = j.done + j.lock.Unlock() + + if done != nil { + select { + case <-done: + return nil + case <-time.After(3 * time.Second): + return ErrJobShutdownTimeout + } + } + + return nil +} + +func NewRoom( + room *livekit.Room, + internal *livekit.RoomInternal, + config WebRTCConfig, + roomConfig config.RoomConfig, + audioConfig *sfu.AudioConfig, + serverInfo *livekit.ServerInfo, + telemetry telemetry.TelemetryService, + agentClient agent.Client, + agentStore AgentStore, + egressLauncher EgressLauncher, +) *Room { + r := &Room{ + protoRoom: utils.CloneProto(room), + internal: internal, + logger: LoggerWithRoom( + logger.GetLogger().WithComponent(sutils.ComponentRoom), + livekit.RoomName(room.Name), + livekit.RoomID(room.Sid), + ), + config: config, + roomConfig: roomConfig, + audioConfig: audioConfig, + telemetry: telemetry, + egressLauncher: egressLauncher, + agentClient: agentClient, + agentStore: agentStore, + agentDispatches: make(map[string]*agentDispatch), + serverInfo: serverInfo, + participants: make(map[livekit.ParticipantIdentity]types.LocalParticipant), + participantOpts: make(map[livekit.ParticipantIdentity]*ParticipantOptions), + participantRequestSources: make(map[livekit.ParticipantIdentity]routing.MessageSource), + hasPublished: make(map[livekit.ParticipantIdentity]bool), + agentParticpants: make(map[livekit.ParticipantIdentity]*agentJob), + bufferFactory: buffer.NewFactoryOfBufferFactory(config.Receiver.PacketBufferSizeVideo, config.Receiver.PacketBufferSizeAudio), + batchedUpdates: make(map[livekit.ParticipantIdentity]*ParticipantUpdate), + closed: make(chan struct{}), + trailer: []byte(utils.RandomSecret()), + disconnectSignalOnResumeParticipants: make(map[livekit.ParticipantIdentity]time.Time), + disconnectSignalOnResumeNoMessagesParticipants: make(map[livekit.ParticipantIdentity]*disconnectSignalOnResumeNoMessages), + userPacketDeduper: NewUserPacketDeduper(), + dataMessageCache: utils.NewTimeSizeCache[types.DataMessageCache](utils.TimeSizeCacheParams{ + TTL: dataMessageCacheTTL, + MaxSize: dataMessageCacheSize, + }), + } + r.trackManager = NewRoomTrackManager(r.logger) + r.localParticipantListener = &localParticipantListener{room: r} + + if r.protoRoom.EmptyTimeout == 0 { + r.protoRoom.EmptyTimeout = roomConfig.EmptyTimeout + } + if r.protoRoom.DepartureTimeout == 0 { + r.protoRoom.DepartureTimeout = roomConfig.DepartureTimeout + } + if r.protoRoom.CreationTime == 0 { + now := time.Now() + r.protoRoom.CreationTime = now.Unix() + r.protoRoom.CreationTimeMs = now.UnixMilli() + } + r.protoProxy = utils.NewProtoProxy(roomUpdateInterval, r.updateProto) + + r.createAgentDispatchesFromRoomAgent() + + r.launchRoomAgents(maps.Values(r.agentDispatches)) + + go r.audioUpdateWorker() + go r.connectionQualityWorker() + go r.changeUpdateWorker() + go r.simulationCleanupWorker() + + return r +} + +func (r *Room) Logger() logger.Logger { + return r.logger +} + +func (r *Room) ToProto() *livekit.Room { + return r.protoProxy.Get() +} + +func (r *Room) Name() livekit.RoomName { + return livekit.RoomName(r.protoRoom.Name) +} + +func (r *Room) ID() livekit.RoomID { + return livekit.RoomID(r.protoRoom.Sid) +} + +func (r *Room) Trailer() []byte { + r.lock.RLock() + defer r.lock.RUnlock() + + trailer := make([]byte, len(r.trailer)) + copy(trailer, r.trailer) + return trailer +} + +func (r *Room) GetParticipant(identity livekit.ParticipantIdentity) types.LocalParticipant { + r.lock.RLock() + defer r.lock.RUnlock() + + return r.participants[identity] +} + +func (r *Room) GetParticipantByID(participantID livekit.ParticipantID) types.LocalParticipant { + r.lock.RLock() + defer r.lock.RUnlock() + + for _, p := range r.participants { + if p.ID() == participantID { + return p + } + } + + return nil +} + +func (r *Room) GetParticipants() []types.LocalParticipant { + r.lock.RLock() + defer r.lock.RUnlock() + + return maps.Values(r.participants) +} + +func (r *Room) GetLocalParticipants() []types.LocalParticipant { + return r.GetParticipants() +} + +func (r *Room) GetParticipantCount() int { + r.lock.RLock() + defer r.lock.RUnlock() + + return len(r.participants) +} + +func (r *Room) GetActiveSpeakers() []*livekit.SpeakerInfo { + participants := r.GetParticipants() + speakers := make([]*livekit.SpeakerInfo, 0, len(participants)) + for _, p := range participants { + level, active := p.GetAudioLevel() + if !active { + continue + } + speakers = append(speakers, &livekit.SpeakerInfo{ + Sid: string(p.ID()), + Level: float32(level), + Active: active, + }) + } + + sort.Slice(speakers, func(i, j int) bool { + return speakers[i].Level > speakers[j].Level + }) + + // quantize to smooth out small changes + for _, speaker := range speakers { + speaker.Level = float32(math.Ceil(float64(speaker.Level*AudioLevelQuantization)) * invAudioLevelQuantization) + } + + return speakers +} + +func (r *Room) GetBufferFactory() *buffer.Factory { + return r.bufferFactory.CreateBufferFactory() +} + +func (r *Room) FirstJoinedAt() int64 { + return r.joinedAt.Load() +} + +func (r *Room) LastLeftAt() int64 { + return r.leftAt.Load() +} + +func (r *Room) Internal() *livekit.RoomInternal { + return r.internal +} + +func (r *Room) Hold() bool { + r.lock.Lock() + defer r.lock.Unlock() + + if r.IsClosed() { + return false + } + + r.holds.Inc() + return true +} + +func (r *Room) Release() { + r.holds.Dec() +} + +func (r *Room) Join( + participant types.LocalParticipant, + requestSource routing.MessageSource, + opts *ParticipantOptions, + iceServers []*livekit.ICEServer, +) error { + r.lock.Lock() + defer r.lock.Unlock() + + if r.IsClosed() { + return ErrRoomClosed + } + + if r.participants[participant.Identity()] != nil { + return ErrAlreadyJoined + } + if r.protoRoom.MaxParticipants > 0 && !participant.IsDependent() { + numParticipants := uint32(0) + for _, p := range r.participants { + if !p.IsDependent() { + numParticipants++ + } + } + if numParticipants >= r.protoRoom.MaxParticipants { + return ErrMaxParticipantsExceeded + } + } + + if r.FirstJoinedAt() == 0 && !participant.IsDependent() { + r.joinedAt.Store(time.Now().Unix()) + } + + r.launchTargetAgents(maps.Values(r.agentDispatches), participant, livekit.JobType_JT_PARTICIPANT) + + r.logger.Debugw( + "new participant joined", + "pID", participant.ID(), + "participant", participant.Identity(), + "clientInfo", logger.Proto(participant.GetClientInfo()), + "options", opts, + "numParticipants", len(r.participants), + ) + + if participant.IsRecorder() && !r.protoRoom.ActiveRecording { + r.protoRoom.ActiveRecording = true + r.protoProxy.MarkDirty(true) + } else { + r.protoProxy.MarkDirty(false) + } + + r.participants[participant.Identity()] = participant + r.participantOpts[participant.Identity()] = opts + r.participantRequestSources[participant.Identity()] = requestSource + + if r.onParticipantChanged != nil { + r.onParticipantChanged(participant) + } + + time.AfterFunc(time.Minute, func() { + if !participant.Verify() { + r.RemoveParticipant(participant.Identity(), participant.ID(), types.ParticipantCloseReasonJoinTimeout) + } + }) + + joinResponse := r.createJoinResponseLocked(participant, iceServers) + if err := participant.SendJoinResponse(joinResponse); err != nil { + prometheus.RecordServiceOperationError("participant_join", "send_response") + return err + } + + participant.SetMigrateState(types.MigrateStateComplete) + + if participant.SubscriberAsPrimary() { + // initiates sub connection as primary + if participant.ProtocolVersion().SupportFastStart() { + go func() { + r.subscribeToExistingTracks(participant, false) + participant.Negotiate(true) + }() + } else { + participant.Negotiate(true) + } + } else { + if participant.IsUsingSinglePeerConnection() { + go r.subscribeToExistingTracks(participant, false) + } + } + + prometheus.RecordServiceOperationSuccess("participant_join") + + return nil +} + +func (r *Room) ReplaceParticipantRequestSource(identity livekit.ParticipantIdentity, reqSource routing.MessageSource) { + r.lock.Lock() + if rs, ok := r.participantRequestSources[identity]; ok { + rs.Close() + } + r.participantRequestSources[identity] = reqSource + r.lock.Unlock() +} + +func (r *Room) GetParticipantRequestSource(identity livekit.ParticipantIdentity) routing.MessageSource { + r.lock.RLock() + defer r.lock.RUnlock() + return r.participantRequestSources[identity] +} + +func (r *Room) ResumeParticipant( + p types.LocalParticipant, + requestSource routing.MessageSource, + responseSink routing.MessageSink, + iceConfig *livekit.ICEConfig, + iceServers []*livekit.ICEServer, + reason livekit.ReconnectReason, +) error { + r.ReplaceParticipantRequestSource(p.Identity(), requestSource) + // close previous sink, and link to new one + p.SwapResponseSink(responseSink, types.SignallingCloseReasonResume) + + p.SetSignalSourceValid(true) + + // check for simulated signal disconnect on resume before sending any signal response messages + r.simulationLock.Lock() + if state, ok := r.disconnectSignalOnResumeNoMessagesParticipants[p.Identity()]; ok { + // WARNING: this uses knowledge that service layer tries internally + simulated := false + if time.Now().Before(state.expiry) { + state.closedCount++ + p.CloseSignalConnection(types.SignallingCloseReasonDisconnectOnResumeNoMessages) + simulated = true + } + if state.closedCount == 3 { + delete(r.disconnectSignalOnResumeNoMessagesParticipants, p.Identity()) + } + if simulated { + r.simulationLock.Unlock() + return nil + } + } + r.simulationLock.Unlock() + + if err := p.HandleReconnectAndSendResponse(reason, &livekit.ReconnectResponse{ + IceServers: iceServers, + ClientConfiguration: p.GetClientConfiguration(), + ServerInfo: r.serverInfo, + LastMessageSeq: p.GetLastReliableSequence(false), + }); err != nil { + return err + } + + // include the local participant's info as well, since metadata could have been changed + updates := GetOtherParticipantInfo(nil, false, toParticipants(r.GetParticipants()), false) + if err := p.SendParticipantUpdate(updates); err != nil { + return err + } + + _ = p.SendRoomUpdate(r.ToProto()) + p.ICERestart(iceConfig) + + // check for simulated signal disconnect on resume + r.simulationLock.Lock() + if timeout, ok := r.disconnectSignalOnResumeParticipants[p.Identity()]; ok { + if time.Now().Before(timeout) { + p.CloseSignalConnection(types.SignallingCloseReasonDisconnectOnResume) + } + delete(r.disconnectSignalOnResumeParticipants, p.Identity()) + } + r.simulationLock.Unlock() + + return nil +} + +func (r *Room) HandleSyncState(participant types.LocalParticipant, state *livekit.SyncState) error { + if state != nil { + return r.onSyncState(participant, state) + } + + return nil +} + +func (r *Room) onSyncState(participant types.LocalParticipant, state *livekit.SyncState) error { + pLogger := participant.GetLogger() + pLogger.Infow("setting sync state", "state", logger.Proto(state)) + + shouldReconnect := false + pubTracks := state.GetPublishTracks() + existingPubTracks := participant.GetPublishedTracks() + for _, pubTrack := range pubTracks { + // client may not have sent TrackInfo for each published track + ti := pubTrack.Track + if ti == nil { + pLogger.Warnw("TrackInfo not sent during resume", nil) + shouldReconnect = true + break + } + + found := false + for _, existingPubTrack := range existingPubTracks { + if existingPubTrack.ID() == livekit.TrackID(ti.Sid) { + found = true + break + } + } + if !found { + // is there a pending track? + found = participant.GetPendingTrack(livekit.TrackID(ti.Sid)) != nil + } + if !found { + pLogger.Warnw("unknown track during resume", nil, "trackID", ti.Sid) + shouldReconnect = true + break + } + } + + pubDataTracks := state.GetPublishDataTracks() + existingPubDataTracks := participant.GetPublishedDataTracks() + for _, pubDataTrack := range pubDataTracks { + // client may not have sent DataTrackInfo for each published data track + dti := pubDataTrack.Info + if dti == nil { + pLogger.Warnw("DataTrackInfo not sent during resume", nil) + shouldReconnect = true + break + } + + found := false + for _, existingPubDataTrack := range existingPubDataTracks { + if existingPubDataTrack.ID() == livekit.TrackID(dti.Sid) { + found = true + break + } + } + if !found { + pLogger.Warnw("unknown data track during resume", nil, "trackID", dti.Sid) + shouldReconnect = true + break + } + } + + if shouldReconnect { + pLogger.Warnw("unable to resume due to missing published tracks, starting full reconnect", nil) + participant.IssueFullReconnect(types.ParticipantCloseReasonPublicationError) + return nil + } + + // synthesize a track setting for each disabled track, + // can be set before adding subscriptions, + // in fact it is done before so that setting can be updated immediately upon subscription. + for _, trackSid := range state.TrackSidsDisabled { + participant.UpdateSubscribedTrackSettings(livekit.TrackID(trackSid), &livekit.UpdateTrackSettings{Disabled: true}) + } + + participant.HandleUpdateSubscriptions( + livekit.StringsAsIDs[livekit.TrackID](state.Subscription.TrackSids), + state.Subscription.ParticipantTracks, + state.Subscription.Subscribe, + ) + return nil +} + +func (r *Room) onUpdateSubscriptionPermission(participant types.LocalParticipant, subscriptionPermission *livekit.SubscriptionPermission) error { + if err := participant.UpdateSubscriptionPermission(subscriptionPermission, utils.TimedVersion(0), r.GetParticipantByID); err != nil { + return err + } + for _, track := range participant.GetPublishedTracks() { + r.trackManager.NotifyTrackChanged(track.ID()) + } + return nil +} + +func (r *Room) ResolveMediaTrackForSubscriber(sub types.LocalParticipant, trackID livekit.TrackID) types.MediaResolverResult { + res := types.MediaResolverResult{} + + info := r.trackManager.GetTrackInfo(trackID) + res.TrackChangedNotifier = r.trackManager.GetOrCreateTrackChangeNotifier(trackID) + + if info == nil { + return res + } + + res.Track = info.Track + res.TrackRemovedNotifier = r.trackManager.GetOrCreateTrackRemoveNotifier(trackID) + res.PublisherIdentity = info.PublisherIdentity + res.PublisherID = info.PublisherID + + pub := r.GetParticipantByID(info.PublisherID) + // when publisher is not found, we will assume it doesn't have permission to access + if pub != nil { + res.HasPermission = IsParticipantExemptFromTrackPermissionsRestrictions(sub) || pub.HasPermission(trackID, sub.Identity()) + } + + return res +} + +func (r *Room) ResolveDataTrackForSubscriber(sub types.LocalParticipant, trackID livekit.TrackID) types.DataResolverResult { + res := types.DataResolverResult{} + + info := r.trackManager.GetDataTrackInfo(trackID) + res.TrackChangedNotifier = r.trackManager.GetOrCreateTrackChangeNotifier(trackID) + + if info == nil { + return res + } + + res.DataTrack = info.DataTrack + res.TrackRemovedNotifier = r.trackManager.GetOrCreateTrackRemoveNotifier(trackID) + res.PublisherIdentity = info.PublisherIdentity + res.PublisherID = info.PublisherID + return res +} + +func (r *Room) IsClosed() bool { + select { + case <-r.closed: + return true + default: + return false + } +} + +// CloseIfEmpty closes the room if all participants had left, or it's still empty past timeout +func (r *Room) CloseIfEmpty() { + r.lock.Lock() + + if r.IsClosed() || r.holds.Load() > 0 { + r.lock.Unlock() + return + } + + for _, p := range r.participants { + if !p.IsDependent() { + r.lock.Unlock() + return + } + } + + var timeout uint32 + var elapsed int64 + var reason string + if r.FirstJoinedAt() > 0 && r.LastLeftAt() > 0 { + elapsed = time.Now().Unix() - r.LastLeftAt() + // need to give time in case participant is reconnecting + timeout = r.protoRoom.DepartureTimeout + reason = "departure timeout" + } else { + elapsed = time.Now().Unix() - r.protoRoom.CreationTime + timeout = r.protoRoom.EmptyTimeout + reason = "empty timeout" + } + r.lock.Unlock() + + if elapsed >= int64(timeout) { + r.Close(types.ParticipantCloseReasonRoomClosed) + r.logger.Infow("closing idle room", "reason", reason) + } +} + +func (r *Room) Close(reason types.ParticipantCloseReason) { + r.lock.Lock() + select { + case <-r.closed: + r.lock.Unlock() + return + default: + // fall through + } + close(r.closed) + r.lock.Unlock() + + r.logger.Infow("closing room") + for _, p := range r.GetParticipants() { + _ = p.Close(true, reason, false) + } + + r.protoProxy.Stop() + + if r.onClose != nil { + r.onClose() + } +} + +func (r *Room) OnClose(f func()) { + r.onClose = f +} + +func (r *Room) OnParticipantChanged(f func(participant types.Participant)) { + r.onParticipantChanged = f +} + +func (r *Room) SendDataPacket(dp *livekit.DataPacket, kind livekit.DataPacket_Kind) { + r.onDataMessage(nil, kind, dp) +} + +func (r *Room) SetMetadata(metadata string) <-chan struct{} { + r.lock.Lock() + r.protoRoom.Metadata = metadata + r.lock.Unlock() + return r.protoProxy.MarkDirty(true) +} + +func (r *Room) sendRoomUpdate() { + roomInfo := r.ToProto() + // Send update to participants + for _, p := range r.GetParticipants() { + if !p.IsReady() { + continue + } + + err := p.SendRoomUpdate(roomInfo) + if err != nil { + r.logger.Warnw("failed to send room update", err, "participant", p.Identity()) + } + } +} + +func (r *Room) GetAgentDispatches(dispatchID string) ([]*livekit.AgentDispatch, error) { + r.lock.RLock() + defer r.lock.RUnlock() + + var ret []*livekit.AgentDispatch + + for _, ad := range r.agentDispatches { + if dispatchID == "" || ad.Id == dispatchID { + ret = append(ret, utils.CloneProto(ad.AgentDispatch)) + } + } + + return ret, nil +} + +func (r *Room) AddAgentDispatch(dispatch *livekit.AgentDispatch) (*livekit.AgentDispatch, error) { + ad, err := r.createAgentDispatch(dispatch) + if err != nil { + return nil, err + } + + r.launchRoomAgents([]*agentDispatch{ad}) + + r.lock.RLock() + // launchPublisherAgents starts a goroutine to send requests, so is safe to call locked + for _, p := range r.participants { + if p.IsPublisher() { + r.launchTargetAgents([]*agentDispatch{ad}, p, livekit.JobType_JT_PUBLISHER) + } + r.launchTargetAgents([]*agentDispatch{ad}, p, livekit.JobType_JT_PARTICIPANT) + } + r.lock.RUnlock() + + return ad.AgentDispatch, nil +} + +func (r *Room) DeleteAgentDispatch(dispatchID string) (*livekit.AgentDispatch, error) { + r.lock.Lock() + ad := r.agentDispatches[dispatchID] + if ad == nil { + r.lock.Unlock() + return nil, psrpc.NewErrorf(psrpc.NotFound, "dispatch ID not found") + } + + delete(r.agentDispatches, dispatchID) + r.lock.Unlock() + + // Should Delete be synchronous instead? + go func() { + ad.waitForPendingJobs() + + var jobs []*livekit.Job + r.lock.RLock() + if ad.State != nil { + jobs = ad.State.Jobs + } + r.lock.RUnlock() + + for _, j := range jobs { + state, err := r.agentClient.TerminateJob(context.Background(), j.Id, rpc.JobTerminateReason_TERMINATION_REQUESTED) + if err != nil { + continue + } + if state.ParticipantIdentity != "" { + r.lock.RLock() + agentJob := r.agentParticpants[livekit.ParticipantIdentity(state.ParticipantIdentity)] + p := r.participants[livekit.ParticipantIdentity(state.ParticipantIdentity)] + r.lock.RUnlock() + + if p != nil { + if agentJob != nil { + err := agentJob.waitForParticipantLeaving() + if err == ErrJobShutdownTimeout { + r.logger.Infow("Agent Worker did not disconnect after 3s") + } + } + r.RemoveParticipant(p.Identity(), p.ID(), types.ParticipantCloseReasonServiceRequestRemoveParticipant) + } + } + r.lock.Lock() + j.State = state + r.lock.Unlock() + } + }() + + return ad.AgentDispatch, nil +} + +func (r *Room) OnRoomUpdated(f func()) { + r.onRoomUpdated = f +} + +func (r *Room) onSimulateScenario(participant types.LocalParticipant, simulateScenario *livekit.SimulateScenario) error { + switch scenario := simulateScenario.Scenario.(type) { + case *livekit.SimulateScenario_SpeakerUpdate: + r.logger.Infow("simulating speaker update", "participant", participant.Identity(), "duration", scenario.SpeakerUpdate) + go func() { + <-time.After(time.Duration(scenario.SpeakerUpdate) * time.Second) + r.sendSpeakerChanges([]*livekit.SpeakerInfo{{ + Sid: string(participant.ID()), + Active: false, + Level: 0, + }}) + }() + r.sendSpeakerChanges([]*livekit.SpeakerInfo{{ + Sid: string(participant.ID()), + Active: true, + Level: 0.9, + }}) + case *livekit.SimulateScenario_Migration: + r.logger.Infow("simulating migration", "participant", participant.Identity()) + // drop participant without necessarily cleaning up + if err := participant.Close(false, types.ParticipantCloseReasonSimulateMigration, true); err != nil { + return err + } + case *livekit.SimulateScenario_NodeFailure: + r.logger.Infow("simulating node failure", "participant", participant.Identity()) + // drop participant without necessarily cleaning up + if err := participant.Close(false, types.ParticipantCloseReasonSimulateNodeFailure, true); err != nil { + return err + } + case *livekit.SimulateScenario_ServerLeave: + r.logger.Infow("simulating server leave", "participant", participant.Identity()) + if err := participant.Close(true, types.ParticipantCloseReasonSimulateServerLeave, false); err != nil { + return err + } + case *livekit.SimulateScenario_SwitchCandidateProtocol: + r.logger.Infow("simulating switch candidate protocol", "participant", participant.Identity()) + participant.ICERestart(&livekit.ICEConfig{ + PreferenceSubscriber: livekit.ICECandidateType(scenario.SwitchCandidateProtocol), + PreferencePublisher: livekit.ICECandidateType(scenario.SwitchCandidateProtocol), + }) + case *livekit.SimulateScenario_SubscriberBandwidth: + if scenario.SubscriberBandwidth > 0 { + r.logger.Infow("simulating subscriber bandwidth start", "participant", participant.Identity(), "bandwidth", scenario.SubscriberBandwidth) + } else { + r.logger.Infow("simulating subscriber bandwidth end", "participant", participant.Identity()) + } + participant.SetSubscriberChannelCapacity(scenario.SubscriberBandwidth) + case *livekit.SimulateScenario_DisconnectSignalOnResume: + participant.GetLogger().Infow("simulating disconnect signal on resume") + r.simulationLock.Lock() + r.disconnectSignalOnResumeParticipants[participant.Identity()] = time.Now().Add(simulateDisconnectSignalTimeout) + r.simulationLock.Unlock() + case *livekit.SimulateScenario_DisconnectSignalOnResumeNoMessages: + participant.GetLogger().Infow("simulating disconnect signal on resume before sending any response messages") + r.simulationLock.Lock() + r.disconnectSignalOnResumeNoMessagesParticipants[participant.Identity()] = &disconnectSignalOnResumeNoMessages{ + expiry: time.Now().Add(simulateDisconnectSignalTimeout), + } + r.simulationLock.Unlock() + } + return nil +} + +// checks if participant should be auto subscribed to new tracks, assumes lock is already acquired +func (r *Room) autoSubscribe(participant types.LocalParticipant) bool { + opts := r.participantOpts[participant.Identity()] + // default to true if no options are set + if opts != nil && !opts.AutoSubscribe { + return false + } + return true +} + +// checks if participant should be auto subscribed to new data tracks, assumes lock is already acquired +func (r *Room) autoSubscribeDataTrack(participant types.LocalParticipant) bool { + opts := r.participantOpts[participant.Identity()] + // default to true if no options are set + if opts != nil && !opts.AutoSubscribeDataTrack { + return false + } + return true +} + +func (r *Room) createJoinResponseLocked( + participant types.LocalParticipant, + iceServers []*livekit.ICEServer, +) *livekit.JoinResponse { + iceConfig := participant.GetICEConfig() + hasICEFallback := iceConfig.GetPreferencePublisher() != livekit.ICECandidateType_ICT_NONE || iceConfig.GetPreferenceSubscriber() != livekit.ICECandidateType_ICT_NONE + return &livekit.JoinResponse{ + Room: r.ToProto(), + Participant: participant.ToProto(), + OtherParticipants: GetOtherParticipantInfo( + participant, + false, // isMigratingIn + toParticipants(maps.Values(r.participants)), + false, // skipSubscriberBroadcast + ), + IceServers: iceServers, + // indicates both server and client support subscriber as primary + SubscriberPrimary: participant.SubscriberAsPrimary(), + ClientConfiguration: participant.GetClientConfiguration(), + // sane defaults for ping interval & timeout + PingInterval: PingIntervalSeconds, + PingTimeout: PingTimeoutSeconds, + ServerInfo: r.serverInfo, + ServerVersion: r.serverInfo.Version, + ServerRegion: r.serverInfo.Region, + SifTrailer: r.trailer, + EnabledPublishCodecs: participant.GetEnabledPublishCodecs(), + FastPublish: participant.CanPublish() && !hasICEFallback, + } +} + +// a ParticipantImpl in the room added a new track, subscribe other participants to it +func (r *Room) onTrackPublished(participant types.Participant, track types.MediaTrack) { + r.trackManager.AddTrack(track, participant.Identity(), participant.ID()) + + // publish participant update, since track state is changed + r.broadcastParticipantState(participant, broadcastOptions{skipSource: true}) + + r.lock.RLock() + // subscribe all existing participants to this MediaTrack + for _, existingParticipant := range r.participants { + if existingParticipant == participant { + // skip publishing participant + continue + } + if existingParticipant.State() != livekit.ParticipantInfo_ACTIVE { + // not fully joined. don't subscribe yet + continue + } + if !r.autoSubscribe(existingParticipant) { + continue + } + + existingParticipant.GetLogger().Debugw( + "subscribing to new track", + "publisher", participant.Identity(), + "publisherID", participant.ID(), + "trackID", track.ID(), + ) + existingParticipant.SubscribeToTrack(track.ID(), false) + } + onParticipantChanged := r.onParticipantChanged + r.lock.RUnlock() + + if onParticipantChanged != nil { + onParticipantChanged(participant) + } + + // launch jobs + r.lock.Lock() + hasPublished := r.hasPublished[participant.Identity()] + r.hasPublished[participant.Identity()] = true + r.lock.Unlock() + + if !hasPublished { + r.lock.RLock() + r.launchTargetAgents(maps.Values(r.agentDispatches), participant, livekit.JobType_JT_PUBLISHER) + r.lock.RUnlock() + if r.internal != nil && r.internal.ParticipantEgress != nil { + go func() { + if err := StartParticipantEgress( + context.Background(), + r.egressLauncher, + r.telemetry, + r.internal.ParticipantEgress, + participant.Identity(), + r.Name(), + r.ID(), + ); err != nil { + r.logger.Errorw("failed to launch participant egress", err) + } + }() + } + } + if participant.Kind() != livekit.ParticipantInfo_EGRESS && r.internal != nil && r.internal.TrackEgress != nil { + go func() { + if err := StartTrackEgress( + context.Background(), + r.egressLauncher, + r.telemetry, + r.internal.TrackEgress, + track, + r.Name(), + r.ID(), + ); err != nil { + r.logger.Errorw("failed to launch track egress", err) + } + }() + } +} + +func (r *Room) onTrackUpdated(p types.Participant, _ types.MediaTrack) { + // send track updates to everyone, especially if track was updated by admin + r.broadcastParticipantState(p, broadcastOptions{}) + if r.onParticipantChanged != nil { + r.onParticipantChanged(p) + } +} + +func (r *Room) onTrackUnpublished(p types.Participant, track types.MediaTrack) { + r.trackManager.RemoveTrack(track) + if !p.IsClosed() { + r.broadcastParticipantState(p, broadcastOptions{skipSource: true}) + } + if r.onParticipantChanged != nil { + r.onParticipantChanged(p) + } +} + +func (r *Room) onDataTrackPublished(participant types.Participant, dt types.DataTrack) { + r.trackManager.AddDataTrack(dt, participant.Identity(), participant.ID()) + + // publish participant update, since a new data track was published + r.broadcastParticipantState(participant, broadcastOptions{skipSource: true}) + + r.lock.RLock() + // subscribe all existing participants to this DataTrack + for _, existingParticipant := range r.participants { + if existingParticipant == participant { + // skip publishing participant + continue + } + if existingParticipant.State() != livekit.ParticipantInfo_ACTIVE { + // not fully joined. don't subscribe yet + continue + } + if !r.autoSubscribeDataTrack(existingParticipant) { + continue + } + + existingParticipant.GetLogger().Debugw( + "subscribing to new data track", + "publisher", participant.Identity(), + "publisherID", participant.ID(), + "trackID", dt.ID(), + ) + existingParticipant.SubscribeToDataTrack(dt.ID()) + } + onParticipantChanged := r.onParticipantChanged + r.lock.RUnlock() + + if onParticipantChanged != nil { + onParticipantChanged(participant) + } +} + +func (r *Room) onDataTrackUnpublished(p types.Participant, dt types.DataTrack) { + r.trackManager.RemoveDataTrack(dt) + if !p.IsClosed() { + r.broadcastParticipantState(p, broadcastOptions{skipSource: true}) + } + if r.onParticipantChanged != nil { + r.onParticipantChanged(p) + } +} + +func (r *Room) onParticipantUpdate(p types.Participant) { + r.protoProxy.MarkDirty(false) + // immediately notify when permissions or metadata changed + r.broadcastParticipantState(p, broadcastOptions{immediate: true}) + if r.onParticipantChanged != nil { + r.onParticipantChanged(p) + } +} + +func (r *Room) onStateChange(p types.LocalParticipant) { + if r.onParticipantChanged != nil { + r.onParticipantChanged(p) + } + r.broadcastParticipantState(p, broadcastOptions{skipSource: true}) + + r.onStateChangeMu.Lock() + defer r.onStateChangeMu.Unlock() + + switch p.State() { + case livekit.ParticipantInfo_ACTIVE: + // subscribe participant to existing published tracks + r.subscribeToExistingTracks(p, false) + + connectTime := time.Since(p.ConnectedAt()) + meta := &livekit.AnalyticsClientMeta{ + ClientConnectTime: uint32(connectTime.Milliseconds()), + } + infos := p.GetICEConnectionInfo() + var connectionType roomobs.ConnectionType + for _, info := range infos { + if info.Type != types.ICEConnectionTypeUnknown { + meta.ConnectionType = info.Type.String() + connectionType = info.Type.ReporterType() + break + } + } + r.telemetry.ParticipantActive(context.Background(), + r.ToProto(), + p.ToProto(), + meta, + false, + p.TelemetryGuard(), + ) + + p.GetReporter().Tx(func(tx roomobs.ParticipantSessionTx) { + tx.ReportClientConnectTime(uint16(connectTime.Milliseconds())) + tx.ReportConnectResult(roomobs.ConnectionResultSuccess) + tx.ReportConnectionType(connectionType) + }) + + fields := append( + connectionDetailsFields(infos), + "clientInfo", logger.Proto(sutils.ClientInfoWithoutAddress(p.GetClientInfo())), + "connectTime", connectTime, + ) + p.GetLogger().Infow("participant active", fields...) + + case livekit.ParticipantInfo_DISCONNECTED: + // remove participant from room + go r.RemoveParticipant(p.Identity(), p.ID(), p.CloseReason()) + } +} + +func (r *Room) onDataMessage(source types.LocalParticipant, kind livekit.DataPacket_Kind, dp *livekit.DataPacket) { + if kind == livekit.DataPacket_RELIABLE && source != nil && dp.GetSequence() > 0 { + data, err := proto.Marshal(dp) + if err != nil { + r.logger.Errorw("failed to marshal data packet for cache", err, "participant", source.Identity(), "seq", dp.GetSequence()) + return + } + r.dataMessageCache.Add(&types.DataMessageCache{ + SenderID: source.ID(), + Seq: dp.GetSequence(), + Data: data, + DestIdentities: livekit.StringsAsIDs[livekit.ParticipantIdentity](dp.DestinationIdentities), + }, len(data)) + } + BroadcastDataPacketForRoom(r, source, kind, dp, r.logger) +} + +func (r *Room) onDataMessageUnlabeled(source types.LocalParticipant, data []byte) { + BroadcastDataMessageForRoom(r, source, data, r.logger) +} + +func (r *Room) onMetrics(source types.Participant, dp *livekit.DataPacket) { + BroadcastMetricsForRoom(r, source, dp, r.logger) +} + +func (r *Room) onSubscribeStatusChanged(participant types.LocalParticipant, publisherID livekit.ParticipantID, subscribed bool) { + if subscribed { + pub := r.GetParticipantByID(publisherID) + if pub != nil && pub.State() == livekit.ParticipantInfo_ACTIVE { + // when a participant subscribes to another participant, + // send speaker update if the subscribed to participant is active. + level, active := pub.GetAudioLevel() + if active { + _ = participant.SendSpeakerUpdate([]*livekit.SpeakerInfo{ + { + Sid: string(pub.ID()), + Level: float32(level), + Active: active, + }, + }, false) + } + + if cq := pub.GetConnectionQuality(); cq != nil { + update := &livekit.ConnectionQualityUpdate{} + update.Updates = append(update.Updates, cq) + _ = participant.SendConnectionQualityUpdate(update) + } + } + } else { + // no longer subscribed to the publisher, clear speaker status + _ = participant.SendSpeakerUpdate([]*livekit.SpeakerInfo{ + { + Sid: string(publisherID), + Level: 0, + Active: false, + }, + }, true) + } +} + +func (r *Room) onUpdateSubscriptions( + participant types.LocalParticipant, + trackIDs []livekit.TrackID, + participantTracks []*livekit.ParticipantTracks, + subscribe bool, +) { + r.UpdateSubscriptions(participant, trackIDs, participantTracks, subscribe) +} + +func (r *Room) UpdateSubscriptions( + participant types.LocalParticipant, + trackIDs []livekit.TrackID, + participantTracks []*livekit.ParticipantTracks, + subscribe bool, +) { + for _, trackID := range trackIDs { + if subscribe { + participant.SubscribeToTrack(trackID, false) + } else { + participant.UnsubscribeFromTrack(trackID) + } + } + + for _, pt := range participantTracks { + for _, trackID := range livekit.StringsAsIDs[livekit.TrackID](pt.TrackSids) { + if subscribe { + participant.SubscribeToTrack(trackID, false) + } else { + participant.UnsubscribeFromTrack(trackID) + } + } + } +} + +func (r *Room) onUpdateDataSubscriptions(participant types.LocalParticipant, req *livekit.UpdateDataSubscription) { + for _, update := range req.Updates { + trackID := livekit.TrackID(update.TrackSid) + if update.Subscribe { + participant.SubscribeToDataTrack(trackID) + participant.UpdateDataTrackSubscriptionOptions(trackID, update.Options) + } else { + participant.UnsubscribeFromDataTrack(trackID) + participant.UpdateDataTrackSubscriptionOptions(trackID, nil) + } + } +} + +func (r *Room) onLeave(p types.LocalParticipant, reason types.ParticipantCloseReason) { + r.RemoveParticipant(p.Identity(), p.ID(), reason) +} + +func (r *Room) RemoveParticipant( + identity livekit.ParticipantIdentity, + pID livekit.ParticipantID, + reason types.ParticipantCloseReason, +) { + r.lock.Lock() + p, ok := r.participants[identity] + if !ok { + r.lock.Unlock() + return + } + + if pID != "" && p.ID() != pID { + // participant session has been replaced + r.lock.Unlock() + return + } + + agentJob := r.agentParticpants[identity] + + delete(r.participants, identity) + delete(r.participantOpts, identity) + delete(r.participantRequestSources, identity) + delete(r.hasPublished, identity) + delete(r.agentParticpants, identity) + if !p.Hidden() { + r.protoRoom.NumParticipants-- + } + + immediateChange := false + if p.IsRecorder() { + activeRecording := false + for _, op := range r.participants { + if op.IsRecorder() { + activeRecording = true + break + } + } + + if r.protoRoom.ActiveRecording != activeRecording { + r.protoRoom.ActiveRecording = activeRecording + immediateChange = true + } + } + r.lock.Unlock() + r.protoProxy.MarkDirty(immediateChange) + + if !p.HasConnected() { + fields := append( + connectionDetailsFields(p.GetICEConnectionInfo()), + "reason", reason.String(), + "clientInfo", logger.Proto(sutils.ClientInfoWithoutAddress(p.GetClientInfo())), + ) + p.GetLogger().Infow("removing participant without connection", fields...) + } + + // send broadcast only if it's not already closed + sendUpdates := !p.IsDisconnected() + + // remove all published tracks + for _, t := range p.GetPublishedTracks() { + r.trackManager.RemoveTrack(t) + } + + // remove all published data tracks + for _, t := range p.GetPublishedDataTracks() { + r.trackManager.RemoveDataTrack(t) + } + + if agentJob != nil { + agentJob.participantLeft() + + go func() { + _, err := r.agentClient.TerminateJob(context.Background(), agentJob.Id, rpc.JobTerminateReason_AGENT_LEFT_ROOM) + if err != nil { + r.logger.Infow("failed sending TerminateJob RPC", "error", err, "jobID", agentJob.Id, "participant", identity) + } + }() + } + + p.ClearParticipantListener() + + // close participant as well + _ = p.Close(true, reason, false) + + r.leftAt.Store(time.Now().Unix()) + + if sendUpdates { + if r.onParticipantChanged != nil { + r.onParticipantChanged(p) + } + r.broadcastParticipantState(p, broadcastOptions{skipSource: true}) + } +} + +func (r *Room) subscribeToExistingTracks(p types.LocalParticipant, isSync bool) { + r.lock.RLock() + autoSubscribe := r.autoSubscribe(p) + autoSubscribeDataTrack := r.autoSubscribeDataTrack(p) + r.lock.RUnlock() + + var trackIDs []livekit.TrackID + for _, op := range r.GetParticipants() { + if p.ID() == op.ID() { + // don't send to itself + continue + } + + // subscribe to all + if autoSubscribe { + for _, track := range op.GetPublishedTracks() { + trackIDs = append(trackIDs, track.ID()) + p.SubscribeToTrack(track.ID(), isSync) + } + } + + if autoSubscribeDataTrack { + for _, track := range op.GetPublishedDataTracks() { + trackIDs = append(trackIDs, track.ID()) + p.SubscribeToDataTrack(track.ID()) + } + } + } + if len(trackIDs) > 0 { + p.GetLogger().Debugw("subscribed participant to existing tracks", "trackID", trackIDs) + } +} + +// broadcast an update about participant p +func (r *Room) broadcastParticipantState(p types.Participant, opts broadcastOptions) { + pi := p.ToProto() + + // send it to the same participant immediately + selfSent := false + if !opts.skipSource { + defer func() { + if selfSent { + return + } + + if lp, ok := p.(types.LocalParticipant); ok { + err := lp.SendParticipantUpdate([]*livekit.ParticipantInfo{pi}) + if err != nil { + lp.GetLogger().Errorw("could not send update to participant", err) + } + } + }() + } + + if p.Hidden() { + // hidden participant updates are sent only to the hidden participant itself, + // these could be things like metadata update + return + } + + r.batchedUpdatesMu.Lock() + updates := PushAndDequeueUpdates( + pi, + p.CloseReason(), + opts.immediate, + r.GetParticipant(livekit.ParticipantIdentity(pi.Identity)), + r.batchedUpdates, + ) + r.batchedUpdatesMu.Unlock() + if len(updates) != 0 { + selfSent = true + SendParticipantUpdates(updates, r.GetParticipants(), r.roomConfig.UpdateBatchTargetSize) + } +} + +// for protocol 3, send only changed updates +func (r *Room) sendSpeakerChanges(speakers []*livekit.SpeakerInfo) { + for _, p := range r.GetParticipants() { + if p.ProtocolVersion().SupportsSpeakerChanged() { + _ = p.SendSpeakerUpdate(speakers, false) + } + } +} + +func (r *Room) updateProto() *livekit.Room { + r.lock.RLock() + room := utils.CloneProto(r.protoRoom) + r.lock.RUnlock() + + room.NumPublishers = 0 + room.NumParticipants = 0 + for _, p := range r.GetParticipants() { + if !p.IsDependent() { + room.NumParticipants++ + } + if p.IsPublisher() { + room.NumPublishers++ + } + } + + return room +} + +func (r *Room) changeUpdateWorker() { + subTicker := time.NewTicker(subscriberUpdateInterval) + defer subTicker.Stop() + + cleanDataMessageTicker := time.NewTicker(dataMessageCacheTTL) + + for !r.IsClosed() { + select { + case <-r.closed: + return + case <-r.protoProxy.Updated(): + if r.onRoomUpdated != nil { + r.onRoomUpdated() + } + r.sendRoomUpdate() + case <-subTicker.C: + r.batchedUpdatesMu.Lock() + if len(r.batchedUpdates) == 0 { + r.batchedUpdatesMu.Unlock() + continue + } + updatesMap := r.batchedUpdates + r.batchedUpdates = make(map[livekit.ParticipantIdentity]*ParticipantUpdate) + r.batchedUpdatesMu.Unlock() + + SendParticipantUpdates(maps.Values(updatesMap), r.GetParticipants(), r.roomConfig.UpdateBatchTargetSize) + + case <-cleanDataMessageTicker.C: + r.dataMessageCache.Prune() + } + } +} + +func (r *Room) audioUpdateWorker() { + lastActiveMap := make(map[livekit.ParticipantID]*livekit.SpeakerInfo) + for { + if r.IsClosed() { + return + } + + activeSpeakers := r.GetActiveSpeakers() + changedSpeakers := make([]*livekit.SpeakerInfo, 0, len(activeSpeakers)) + nextActiveMap := make(map[livekit.ParticipantID]*livekit.SpeakerInfo, len(activeSpeakers)) + for _, speaker := range activeSpeakers { + prev := lastActiveMap[livekit.ParticipantID(speaker.Sid)] + if prev == nil || prev.Level != speaker.Level { + changedSpeakers = append(changedSpeakers, speaker) + } + nextActiveMap[livekit.ParticipantID(speaker.Sid)] = speaker + } + + // changedSpeakers need to include previous speakers that are no longer speaking + for sid, speaker := range lastActiveMap { + if nextActiveMap[sid] == nil { + inactiveSpeaker := utils.CloneProto(speaker) + inactiveSpeaker.Level = 0 + inactiveSpeaker.Active = false + changedSpeakers = append(changedSpeakers, inactiveSpeaker) + } + } + + // see if an update is needed + if len(changedSpeakers) > 0 { + r.sendSpeakerChanges(changedSpeakers) + } + + lastActiveMap = nextActiveMap + + time.Sleep(time.Duration(r.audioConfig.UpdateInterval) * time.Millisecond) + } +} + +func (r *Room) connectionQualityWorker() { + ticker := time.NewTicker(connectionquality.UpdateInterval) + defer ticker.Stop() + + prevConnectionInfos := make(map[livekit.ParticipantID]*livekit.ConnectionQualityInfo) + // send updates to only users that are subscribed to each other + for !r.IsClosed() { + <-ticker.C + + participants := r.GetParticipants() + nowConnectionInfos := make(map[livekit.ParticipantID]*livekit.ConnectionQualityInfo, len(participants)) + + for _, p := range participants { + if p.State() != livekit.ParticipantInfo_ACTIVE { + continue + } + + if q := p.GetConnectionQuality(); q != nil { + nowConnectionInfos[p.ID()] = q + } + } + + // send an update if there is a change + // - new participant + // - quality change + // NOTE: participant leaving is explicitly omitted as `leave` signal notifies that a participant is not in the room anymore + sendUpdate := false + for _, p := range participants { + pID := p.ID() + prevInfo, prevOk := prevConnectionInfos[pID] + nowInfo, nowOk := nowConnectionInfos[pID] + if !nowOk { + // participant is not ACTIVE any more + continue + } + if !prevOk || nowInfo.Quality != prevInfo.Quality { + // new entrant OR change in quality + sendUpdate = true + break + } + } + + if !sendUpdate { + prevConnectionInfos = nowConnectionInfos + continue + } + + maybeAddToUpdate := func(pID livekit.ParticipantID, update *livekit.ConnectionQualityUpdate) { + if nowInfo, nowOk := nowConnectionInfos[pID]; nowOk { + update.Updates = append(update.Updates, nowInfo) + } + } + + for _, op := range participants { + if !op.ProtocolVersion().SupportsConnectionQuality() || op.State() != livekit.ParticipantInfo_ACTIVE { + continue + } + update := &livekit.ConnectionQualityUpdate{} + + // send to user itself + maybeAddToUpdate(op.ID(), update) + + // add connection quality of other participants its subscribed to + for _, sid := range op.GetSubscribedParticipants() { + maybeAddToUpdate(sid, update) + } + if len(update.Updates) == 0 { + // no change + continue + } + if err := op.SendConnectionQualityUpdate(update); err != nil { + r.logger.Warnw("could not send connection quality update", err, + "participant", op.Identity()) + } + } + + prevConnectionInfos = nowConnectionInfos + } +} + +func (r *Room) simulationCleanupWorker() { + for { + if r.IsClosed() { + return + } + + now := time.Now() + r.simulationLock.Lock() + for identity, timeout := range r.disconnectSignalOnResumeParticipants { + if now.After(timeout) { + delete(r.disconnectSignalOnResumeParticipants, identity) + } + } + + for identity, state := range r.disconnectSignalOnResumeNoMessagesParticipants { + if now.After(state.expiry) { + delete(r.disconnectSignalOnResumeNoMessagesParticipants, identity) + } + } + r.simulationLock.Unlock() + + time.Sleep(10 * time.Second) + } +} + +func (r *Room) launchRoomAgents(ads []*agentDispatch) { + if r.agentClient == nil { + return + } + + for _, ad := range ads { + done := ad.jobsLaunching() + + go func() { + inc := r.agentClient.LaunchJob(context.Background(), &agent.JobRequest{ + JobType: livekit.JobType_JT_ROOM, + Room: r.ToProto(), + Metadata: ad.Metadata, + AgentName: ad.AgentName, + DispatchId: ad.Id, + }) + r.handleNewJobs(ad.AgentDispatch, inc) + done() + }() + } +} + +func (r *Room) launchTargetAgents(ads []*agentDispatch, p types.Participant, jobType livekit.JobType) { + if p == nil || p.IsDependent() || r.agentClient == nil { + return + } + + for _, ad := range ads { + done := ad.jobsLaunching() + + go func() { + inc := r.agentClient.LaunchJob(context.Background(), &agent.JobRequest{ + JobType: jobType, + Room: r.ToProto(), + Participant: p.ToProto(), + Metadata: ad.Metadata, + AgentName: ad.AgentName, + DispatchId: ad.Id, + }) + r.handleNewJobs(ad.AgentDispatch, inc) + done() + }() + } +} + +func (r *Room) handleNewJobs(ad *livekit.AgentDispatch, inc *sutils.IncrementalDispatcher[*livekit.Job]) { + inc.ForEach(func(job *livekit.Job) { + r.agentStore.StoreAgentJob(context.Background(), job) + r.lock.Lock() + ad.State.Jobs = append(ad.State.Jobs, job) + if job.State != nil && job.State.ParticipantIdentity != "" { + r.agentParticpants[livekit.ParticipantIdentity(job.State.ParticipantIdentity)] = newAgentJob(job) + } + r.lock.Unlock() + }) +} + +func (r *Room) DebugInfo() map[string]any { + info := map[string]any{ + "Name": r.protoRoom.Name, + "Sid": r.protoRoom.Sid, + "CreatedAt": r.protoRoom.CreationTime, + } + + participants := r.GetParticipants() + participantInfo := make(map[string]any) + for _, p := range participants { + participantInfo[string(p.Identity())] = p.DebugInfo() + } + info["Participants"] = participantInfo + + return info +} + +func (r *Room) createAgentDispatch(dispatch *livekit.AgentDispatch) (*agentDispatch, error) { + dispatch.State = &livekit.AgentDispatchState{ + CreatedAt: time.Now().UnixNano(), + } + ad := newAgentDispatch(dispatch) + + r.lock.Lock() + r.agentDispatches[ad.Id] = ad + r.lock.Unlock() + if r.agentStore != nil { + err := r.agentStore.StoreAgentDispatch(context.Background(), ad.AgentDispatch) + if err != nil { + return nil, err + } + } + + return ad, nil +} + +func (r *Room) createAgentDispatchFromParams(agentName string, metadata string) (*agentDispatch, error) { + return r.createAgentDispatch(&livekit.AgentDispatch{ + Id: guid.New(guid.AgentDispatchPrefix), + AgentName: agentName, + Metadata: metadata, + Room: r.protoRoom.Name, + }) +} + +func (r *Room) createAgentDispatchesFromRoomAgent() { + if r.internal == nil { + return + } + + roomDisp := r.internal.AgentDispatches + if len(roomDisp) == 0 { + // Backward compatibility: by default, start any agent in the empty JobName + roomDisp = []*livekit.RoomAgentDispatch{{}} + } + + for _, ag := range roomDisp { + _, err := r.createAgentDispatchFromParams(ag.AgentName, ag.Metadata) + if err != nil { + r.logger.Warnw("failed storing room dispatch", err) + } + } +} + +func (r *Room) IsDataMessageUserPacketDuplicate(up *livekit.UserPacket) bool { + return r.userPacketDeduper.IsDuplicate(up) +} + +func (r *Room) GetCachedReliableDataMessage(seqs map[livekit.ParticipantID]uint32) []*types.DataMessageCache { + msgs := make([]*types.DataMessageCache, 0, len(seqs)*10) + for _, msg := range r.dataMessageCache.Get() { + seq, ok := seqs[msg.SenderID] + if ok && msg.Seq >= seq { + msgs = append(msgs, msg) + } + } + return msgs +} + +func (r *Room) LocalParticipantListener() types.LocalParticipantListener { + return r.localParticipantListener +} + +// ------------------------------------------------------------ + +var _ types.LocalParticipantListener = (*localParticipantListener)(nil) + +type localParticipantListener struct { + room *Room +} + +func (l *localParticipantListener) OnParticipantUpdate(p types.Participant) { + l.room.onParticipantUpdate(p) +} + +func (l *localParticipantListener) OnTrackPublished(p types.Participant, track types.MediaTrack) { + l.room.onTrackPublished(p, track) +} + +func (l *localParticipantListener) OnTrackUpdated(p types.Participant, track types.MediaTrack) { + l.room.onTrackUpdated(p, track) +} + +func (l *localParticipantListener) OnTrackUnpublished(p types.Participant, track types.MediaTrack) { + l.room.onTrackUnpublished(p, track) +} + +func (l *localParticipantListener) OnDataTrackPublished(p types.Participant, track types.DataTrack) { + l.room.onDataTrackPublished(p, track) +} + +func (l *localParticipantListener) OnDataTrackUnpublished(p types.Participant, track types.DataTrack) { + l.room.onDataTrackUnpublished(p, track) +} + +func (l *localParticipantListener) OnDataTrackMessage(_p types.Participant, _data []byte, _packet *datatrack.Packet) { +} + +func (l *localParticipantListener) OnMetrics(p types.Participant, dp *livekit.DataPacket) { + l.room.onMetrics(p, dp) +} + +func (l *localParticipantListener) OnStateChange(p types.LocalParticipant) { + l.room.onStateChange(p) +} + +func (l *localParticipantListener) OnSubscriberReady(p types.LocalParticipant) { + l.room.subscribeToExistingTracks(p, false) +} + +func (l *localParticipantListener) OnMigrateStateChange(_p types.LocalParticipant, _migrateState types.MigrateState) { +} + +func (l *localParticipantListener) OnDataMessage(p types.LocalParticipant, kind livekit.DataPacket_Kind, dp *livekit.DataPacket) { + l.room.onDataMessage(p, kind, dp) +} + +func (l *localParticipantListener) OnDataMessageUnlabeled(p types.LocalParticipant, data []byte) { + l.room.onDataMessageUnlabeled(p, data) +} + +func (l *localParticipantListener) OnSubscribeStatusChanged(p types.LocalParticipant, publisherID livekit.ParticipantID, subscribed bool) { + l.room.onSubscribeStatusChanged(p, publisherID, subscribed) +} + +func (l *localParticipantListener) OnUpdateSubscriptions( + p types.LocalParticipant, + trackIDs []livekit.TrackID, + participantTracks []*livekit.ParticipantTracks, + subscribe bool, +) { + l.room.onUpdateSubscriptions(p, trackIDs, participantTracks, subscribe) +} + +func (l *localParticipantListener) OnUpdateSubscriptionPermission(p types.LocalParticipant, subscriptionPermission *livekit.SubscriptionPermission) error { + return l.room.onUpdateSubscriptionPermission(p, subscriptionPermission) +} + +func (l *localParticipantListener) OnUpdateDataSubscriptions(p types.LocalParticipant, req *livekit.UpdateDataSubscription) { + l.room.onUpdateDataSubscriptions(p, req) +} + +func (l *localParticipantListener) OnSyncState(p types.LocalParticipant, state *livekit.SyncState) error { + return l.room.onSyncState(p, state) +} + +func (l *localParticipantListener) OnSimulateScenario(p types.LocalParticipant, simulateScenario *livekit.SimulateScenario) error { + return l.room.onSimulateScenario(p, simulateScenario) +} + +func (l *localParticipantListener) OnLeave(p types.LocalParticipant, closeReason types.ParticipantCloseReason) { + l.room.onLeave(p, closeReason) +} + +// ------------------------------------------------------------ + +func BroadcastDataPacketForRoom( + r types.Room, + source types.LocalParticipant, + kind livekit.DataPacket_Kind, + dp *livekit.DataPacket, + logger logger.Logger, +) { + dp.Kind = kind // backward compatibility + dest := dp.GetUser().GetDestinationSids() + if u := dp.GetUser(); u != nil { + if r.IsDataMessageUserPacketDuplicate(u) { + logger.Infow("dropping duplicate data message", "nonce", u.Nonce) + return + } + if len(dp.DestinationIdentities) == 0 { + dp.DestinationIdentities = u.DestinationIdentities + } else { + u.DestinationIdentities = dp.DestinationIdentities + } + if dp.ParticipantIdentity != "" { + u.ParticipantIdentity = dp.ParticipantIdentity + } else { + dp.ParticipantIdentity = u.ParticipantIdentity + } + } + destIdentities := dp.DestinationIdentities + + participants := r.GetLocalParticipants() + capacity := len(destIdentities) + if capacity == 0 { + capacity = len(dest) + } + if capacity == 0 { + capacity = len(participants) + } + destParticipants := make([]types.LocalParticipant, 0, capacity) + + var dpData []byte + for _, op := range participants { + if source != nil && op.ID() == source.ID() { + continue + } + if len(dest) > 0 || len(destIdentities) > 0 { + if !slices.Contains(dest, string(op.ID())) && !slices.Contains(destIdentities, string(op.Identity())) { + continue + } + } + if dpData == nil { + var err error + dpData, err = proto.Marshal(dp) + if err != nil { + logger.Errorw("failed to marshal data packet", err) + return + } + } + destParticipants = append(destParticipants, op) + } + + utils.ParallelExec(destParticipants, dataForwardLoadBalanceThreshold, 1, func(op types.LocalParticipant) { + op.SendDataMessage(kind, dpData, livekit.ParticipantID(dp.GetParticipantSid()), dp.GetSequence()) + }) +} + +func BroadcastDataMessageForRoom(r types.Room, source types.LocalParticipant, data []byte, logger logger.Logger) { + utils.ParallelExec(r.GetLocalParticipants(), dataForwardLoadBalanceThreshold, 1, func(op types.LocalParticipant) { + if source != nil && op.ID() == source.ID() { + return + } + + op.SendDataMessageUnlabeled(data, false, source.Identity()) + }) +} + +func BroadcastMetricsForRoom(r types.Room, source types.Participant, dp *livekit.DataPacket, logger logger.Logger) { + switch payload := dp.Value.(type) { + case *livekit.DataPacket_Metrics: + utils.ParallelExec(r.GetLocalParticipants(), dataForwardLoadBalanceThreshold, 1, func(op types.LocalParticipant) { + // echoing back to sender too + op.HandleMetrics(source.ID(), payload.Metrics) + }) + default: + } +} + +func IsCloseNotifySkippable(closeReason types.ParticipantCloseReason) bool { + return closeReason == types.ParticipantCloseReasonDuplicateIdentity +} + +func IsParticipantExemptFromTrackPermissionsRestrictions(p types.LocalParticipant) bool { + // egress/recorder participants bypass permissions as auto-egress does not + // have enough context to check permissions + return p.IsRecorder() +} + +func CompareParticipant(pi1 *livekit.ParticipantInfo, pi2 *livekit.ParticipantInfo) int { + if pi1.JoinedAt != pi2.JoinedAt { + if pi1.JoinedAt < pi2.JoinedAt { + return -1 + } else { + return 1 + } + } + + if pi1.JoinedAtMs != 0 && pi2.JoinedAtMs != 0 && pi1.JoinedAtMs != pi2.JoinedAtMs { + if pi1.JoinedAtMs < pi2.JoinedAtMs { + return -1 + } else { + return 1 + } + } + + // all join times being equal, it is not possible to really know which one is newer, + // pick the higher pID to be consistent + if pi1.Sid != pi2.Sid { + if pi1.Sid < pi2.Sid { + return -1 + } else { + return 1 + } + } + + return 0 +} + +type ParticipantUpdate struct { + ParticipantInfo *livekit.ParticipantInfo + IsSynthesizedDisconnect bool + CloseReason types.ParticipantCloseReason +} + +// push a participant update for batched broadcast, optionally returning immediate updates to broadcast. +// it handles the following scenarios +// * subscriber-only updates will be queued for batch updates +// * publisher & immediate updates will be returned without queuing +// * when the SID changes, it will return both updates, with the earlier participant set to disconnected +func PushAndDequeueUpdates( + pi *livekit.ParticipantInfo, + closeReason types.ParticipantCloseReason, + isImmediate bool, + existingParticipant types.Participant, + cache map[livekit.ParticipantIdentity]*ParticipantUpdate, +) []*ParticipantUpdate { + var updates []*ParticipantUpdate + identity := livekit.ParticipantIdentity(pi.Identity) + existing := cache[identity] + shouldSend := isImmediate || pi.IsPublisher + + if existing != nil { + if pi.Sid == existing.ParticipantInfo.Sid { + // same participant session + if pi.Version < existing.ParticipantInfo.Version { + // out of order update + return nil + } + } else { + // different participant sessions + if CompareParticipant(existing.ParticipantInfo, pi) < 0 { + // existing is older, synthesize a DISCONNECT for older and + // send immediately along with newer session to signal switch + shouldSend = true + existing.ParticipantInfo.State = livekit.ParticipantInfo_DISCONNECTED + existing.IsSynthesizedDisconnect = true + updates = append(updates, existing) + } else { + // older session update, newer session has already become active, so nothing to do + return nil + } + } + } else { + if existingParticipant != nil { + epi := existingParticipant.ToProto() + if CompareParticipant(epi, pi) > 0 { + // older session update, newer session has already become active, so nothing to do + return nil + } + } + } + + if shouldSend { + // include any queued update, and return + delete(cache, identity) + updates = append(updates, &ParticipantUpdate{ParticipantInfo: pi, CloseReason: closeReason}) + } else { + // enqueue for batch + cache[identity] = &ParticipantUpdate{ParticipantInfo: pi, CloseReason: closeReason} + } + + return updates +} + +func SendParticipantUpdates(updates []*ParticipantUpdate, participants []types.LocalParticipant, batchTargetSize int) { + if len(updates) == 0 { + return + } + + // For filtered updates, skip + // 1. synthesized DISCONNECT - this happens on SID change + // 2. close reasons of DUPLICATE_IDENTITY/STALE - A newer session for that identity exists. + // + // Filtered updates are used with clients that can handle identity based reconnect and hence those + // conditions can be skipped. + var filteredUpdates []*livekit.ParticipantInfo + for _, update := range updates { + if update.IsSynthesizedDisconnect || IsCloseNotifySkippable(update.CloseReason) { + continue + } + filteredUpdates = append(filteredUpdates, update.ParticipantInfo) + } + + var fullUpdates []*livekit.ParticipantInfo + for _, update := range updates { + fullUpdates = append(fullUpdates, update.ParticipantInfo) + } + + filteredUpdateChunks := ChunkProtoBatch(filteredUpdates, batchTargetSize) + fullUpdateChunks := ChunkProtoBatch(fullUpdates, batchTargetSize) + + for _, op := range participants { + updateChunks := fullUpdateChunks + if op.ProtocolVersion().SupportsIdentityBasedReconnection() { + updateChunks = filteredUpdateChunks + } + for _, chunk := range updateChunks { + if err := op.SendParticipantUpdate(chunk); err != nil { + op.GetLogger().Errorw("could not send update to participant", err) + break + } + } + } +} + +// GetOtherParticipantInfo returns ParticipantInfo for everyone in the room except for the participant identified by lp.Identity() +func GetOtherParticipantInfo( + lp types.LocalParticipant, + isMigratingIn bool, + allParticipants []types.Participant, + skipSubscriberBroadcast bool, +) []*livekit.ParticipantInfo { + var lpIdentity livekit.ParticipantIdentity + if lp != nil { + lpIdentity = lp.Identity() + } + + pInfos := make([]*livekit.ParticipantInfo, 0, len(allParticipants)) + for _, op := range allParticipants { + if !(skipSubscriberBroadcast && op.CanSkipBroadcast()) && + !op.Hidden() && + op.Identity() != lpIdentity && + !isMigratingIn { + pInfos = append(pInfos, op.ToProto()) + } + } + + return pInfos +} + +func connectionDetailsFields(infos []*types.ICEConnectionInfo) []any { + var fields []any + connectionType := types.ICEConnectionTypeUnknown + for _, info := range infos { + candidates := make([]string, 0, len(info.Remote)+len(info.Local)) + for _, c := range info.Local { + cStr := "[local]" + if c.SelectedOrder != 0 { + cStr += fmt.Sprintf("[selected:%d]", c.SelectedOrder) + } else if c.Filtered { + cStr += "[filtered]" + } + if c.Trickle { + cStr += "[trickle]" + } + cStr += " " + c.Local.String() + candidates = append(candidates, cStr) + } + for _, c := range info.Remote { + cStr := "[remote]" + if c.SelectedOrder != 0 { + cStr += fmt.Sprintf("[selected:%d]", c.SelectedOrder) + } else if c.Filtered { + cStr += "[filtered]" + } + if c.Trickle { + cStr += "[trickle]" + } + cStr += " " + fmt.Sprintf("%s %s %s:%d", c.Remote.NetworkType(), c.Remote.Type(), MaybeTruncateIP(c.Remote.Address()), c.Remote.Port()) + if relatedAddress := c.Remote.RelatedAddress(); relatedAddress != nil { + relatedAddr := MaybeTruncateIP(relatedAddress.Address) + if relatedAddr != "" { + cStr += " " + fmt.Sprintf(" related %s:%d", relatedAddr, relatedAddress.Port) + } + } + candidates = append(candidates, cStr) + } + if len(candidates) > 0 { + fields = append(fields, fmt.Sprintf("%sCandidates", strings.ToLower(info.Transport.String())), candidates) + } + if info.Type != types.ICEConnectionTypeUnknown { + connectionType = info.Type + } + } + fields = append(fields, "connectionType", connectionType) + return fields +} + +func toParticipants(lps []types.LocalParticipant) []types.Participant { + participants := make([]types.Participant, len(lps)) + for idx, lp := range lps { + participants[idx] = lp + } + return participants +} diff --git a/livekit/pkg/rtc/room_test.go b/livekit/pkg/rtc/room_test.go new file mode 100644 index 0000000..0d69710 --- /dev/null +++ b/livekit/pkg/rtc/room_test.go @@ -0,0 +1,860 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rtc + +import ( + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/proto" + + "github.com/livekit/protocol/auth/authfakes" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/utils" + "github.com/livekit/protocol/webhook" + + "github.com/livekit/livekit-server/version" + + "github.com/livekit/livekit-server/pkg/config" + "github.com/livekit/livekit-server/pkg/rtc/types" + "github.com/livekit/livekit-server/pkg/rtc/types/typesfakes" + "github.com/livekit/livekit-server/pkg/sfu" + "github.com/livekit/livekit-server/pkg/sfu/audio" + "github.com/livekit/livekit-server/pkg/telemetry" + "github.com/livekit/livekit-server/pkg/telemetry/prometheus" + "github.com/livekit/livekit-server/pkg/telemetry/telemetryfakes" + "github.com/livekit/livekit-server/pkg/testutils" +) + +func init() { + prometheus.Init("test", livekit.NodeType_SERVER) +} + +const ( + numParticipants = 3 + defaultDelay = 10 * time.Millisecond + audioUpdateInterval = 25 +) + +func init() { + config.InitLoggerFromConfig(&config.DefaultConfig.Logging) + roomUpdateInterval = defaultDelay +} + +var iceServersForRoom = []*livekit.ICEServer{{Urls: []string{"stun:stun.l.google.com:19302"}}} + +func TestJoinedState(t *testing.T) { + t.Run("new room should return joinedAt 0", func(t *testing.T) { + rm := newRoomWithParticipants(t, testRoomOpts{num: 0}) + require.Equal(t, int64(0), rm.FirstJoinedAt()) + require.Equal(t, int64(0), rm.LastLeftAt()) + }) + + t.Run("should be current time when a participant joins", func(t *testing.T) { + s := time.Now().Unix() + rm := newRoomWithParticipants(t, testRoomOpts{num: 1}) + require.LessOrEqual(t, s, rm.FirstJoinedAt()) + require.Equal(t, int64(0), rm.LastLeftAt()) + }) + + t.Run("should be set when a participant leaves", func(t *testing.T) { + rm := newRoomWithParticipants(t, testRoomOpts{num: 1}) + p0 := rm.GetParticipants()[0] + s := time.Now().Unix() + rm.RemoveParticipant(p0.Identity(), p0.ID(), types.ParticipantCloseReasonClientRequestLeave) + require.LessOrEqual(t, s, rm.LastLeftAt()) + }) + + t.Run("LastLeftAt should be set when there are still participants in the room", func(t *testing.T) { + rm := newRoomWithParticipants(t, testRoomOpts{num: 2}) + p0 := rm.GetParticipants()[0] + rm.RemoveParticipant(p0.Identity(), p0.ID(), types.ParticipantCloseReasonClientRequestLeave) + require.Greater(t, rm.LastLeftAt(), int64(0)) + }) +} + +func TestRoomJoin(t *testing.T) { + t.Run("joining returns existing participant data", func(t *testing.T) { + rm := newRoomWithParticipants(t, testRoomOpts{num: numParticipants}) + pNew := NewMockParticipant("new", types.CurrentProtocol, false, false, rm.LocalParticipantListener()) + + _ = rm.Join(pNew, nil, nil, iceServersForRoom) + + // expect new participant to get a JoinReply + res := pNew.SendJoinResponseArgsForCall(0) + require.Equal(t, livekit.RoomID(res.Room.Sid), rm.ID()) + require.Len(t, res.OtherParticipants, numParticipants) + require.Len(t, rm.GetParticipants(), numParticipants+1) + require.NotEmpty(t, res.IceServers) + }) + + t.Run("subscribe to existing channels upon join", func(t *testing.T) { + numExisting := 3 + rm := newRoomWithParticipants(t, testRoomOpts{num: numExisting}) + lpl := rm.LocalParticipantListener() + p := NewMockParticipant("new", types.CurrentProtocol, false, false, lpl) + + err := rm.Join(p, nil, &ParticipantOptions{AutoSubscribe: true}, iceServersForRoom) + require.NoError(t, err) + + p.StateReturns(livekit.ParticipantInfo_ACTIVE) + lpl.OnStateChange(p) + + // it should become a subscriber when connectivity changes + numTracks := 0 + for _, op := range rm.GetParticipants() { + if p == op { + continue + } + + numTracks += len(op.GetPublishedTracks()) + } + require.Equal(t, numTracks, p.SubscribeToTrackCallCount()) + }) + + t.Run("participant state change is broadcasted to others", func(t *testing.T) { + rm := newRoomWithParticipants(t, testRoomOpts{num: numParticipants}) + var changedParticipant types.Participant + rm.OnParticipantChanged(func(participant types.Participant) { + changedParticipant = participant + }) + participants := rm.GetParticipants() + p := participants[0].(*typesfakes.FakeLocalParticipant) + disconnectedParticipant := participants[1].(*typesfakes.FakeLocalParticipant) + disconnectedParticipant.StateReturns(livekit.ParticipantInfo_DISCONNECTED) + + rm.RemoveParticipant(p.Identity(), p.ID(), types.ParticipantCloseReasonClientRequestLeave) + time.Sleep(defaultDelay) + + require.Equal(t, p, changedParticipant) + + numUpdates := 0 + for _, op := range participants { + if op == p || op == disconnectedParticipant { + require.Zero(t, p.SendParticipantUpdateCallCount()) + continue + } + fakeP := op.(*typesfakes.FakeLocalParticipant) + require.Equal(t, 1, fakeP.SendParticipantUpdateCallCount()) + numUpdates += 1 + } + require.Equal(t, numParticipants-2, numUpdates) + }) + + t.Run("cannot exceed max participants", func(t *testing.T) { + rm := newRoomWithParticipants(t, testRoomOpts{num: 1}) + rm.lock.Lock() + rm.protoRoom.MaxParticipants = 1 + rm.lock.Unlock() + p := NewMockParticipant("second", types.ProtocolVersion(0), false, false, rm.LocalParticipantListener()) + + err := rm.Join(p, nil, nil, iceServersForRoom) + require.Equal(t, ErrMaxParticipantsExceeded, err) + }) +} + +// various state changes to participant and that others are receiving update +func TestParticipantUpdate(t *testing.T) { + tests := []struct { + name string + sendToSender bool // should sender receive it + action func(p types.LocalParticipant) + }{ + { + "track mutes are sent to everyone", + true, + func(p types.LocalParticipant) { + p.SetTrackMuted(&livekit.MuteTrackRequest{Muted: true}, false) + }, + }, + { + "track metadata updates are sent to everyone", + true, + func(p types.LocalParticipant) { + p.SetMetadata("") + }, + }, + { + "track publishes are sent to existing participants", + true, + func(p types.LocalParticipant) { + p.AddTrack(&livekit.AddTrackRequest{ + Type: livekit.TrackType_VIDEO, + }) + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + rm := newRoomWithParticipants(t, testRoomOpts{num: 3}) + // remember how many times send has been called for each + callCounts := make(map[livekit.ParticipantID]int) + for _, p := range rm.GetParticipants() { + fp := p.(*typesfakes.FakeLocalParticipant) + callCounts[p.ID()] = fp.SendParticipantUpdateCallCount() + } + + sender := rm.GetParticipants()[0] + test.action(sender) + + // go through the other participants, make sure they've received update + for _, p := range rm.GetParticipants() { + expected := callCounts[p.ID()] + if p != sender || test.sendToSender { + expected += 1 + } + fp := p.(*typesfakes.FakeLocalParticipant) + require.Equal(t, expected, fp.SendParticipantUpdateCallCount()) + } + }) + } +} + +func TestPushAndDequeueUpdates(t *testing.T) { + identity := "test_user" + publisher1v1 := &livekit.ParticipantInfo{ + Identity: identity, + Sid: "1", + IsPublisher: true, + Version: 1, + JoinedAt: 0, + } + publisher1v2 := &livekit.ParticipantInfo{ + Identity: identity, + Sid: "1", + IsPublisher: true, + Version: 2, + JoinedAt: 1, + } + publisher2 := &livekit.ParticipantInfo{ + Identity: identity, + Sid: "2", + IsPublisher: true, + Version: 1, + JoinedAt: 2, + } + subscriber1v1 := &livekit.ParticipantInfo{ + Identity: identity, + Sid: "1", + Version: 1, + JoinedAt: 0, + } + subscriber1v2 := &livekit.ParticipantInfo{ + Identity: identity, + Sid: "1", + Version: 2, + JoinedAt: 1, + } + + requirePIEquals := func(t *testing.T, a, b *livekit.ParticipantInfo) { + require.Equal(t, a.Sid, b.Sid) + require.Equal(t, a.Identity, b.Identity) + require.Equal(t, a.Version, b.Version) + } + testCases := []struct { + name string + pi *livekit.ParticipantInfo + closeReason types.ParticipantCloseReason + immediate bool + existing *ParticipantUpdate + expected []*ParticipantUpdate + validate func(t *testing.T, rm *Room, updates []*ParticipantUpdate) + }{ + { + name: "publisher updates are immediate", + pi: publisher1v1, + expected: []*ParticipantUpdate{{ParticipantInfo: publisher1v1}}, + }, + { + name: "subscriber updates are queued", + pi: subscriber1v1, + }, + { + name: "last version is enqueued", + pi: subscriber1v2, + existing: &ParticipantUpdate{ParticipantInfo: utils.CloneProto(subscriber1v1)}, // clone the existing value since it can be modified when setting to disconnected + validate: func(t *testing.T, rm *Room, _ []*ParticipantUpdate) { + queued := rm.batchedUpdates[livekit.ParticipantIdentity(identity)] + require.NotNil(t, queued) + requirePIEquals(t, subscriber1v2, queued.ParticipantInfo) + }, + }, + { + name: "latest version when immediate", + pi: subscriber1v2, + existing: &ParticipantUpdate{ParticipantInfo: utils.CloneProto(subscriber1v1)}, + immediate: true, + expected: []*ParticipantUpdate{{ParticipantInfo: subscriber1v2}}, + validate: func(t *testing.T, rm *Room, _ []*ParticipantUpdate) { + queued := rm.batchedUpdates[livekit.ParticipantIdentity(identity)] + require.Nil(t, queued) + }, + }, + { + name: "out of order updates are rejected", + pi: subscriber1v1, + existing: &ParticipantUpdate{ParticipantInfo: utils.CloneProto(subscriber1v2)}, + validate: func(t *testing.T, rm *Room, updates []*ParticipantUpdate) { + queued := rm.batchedUpdates[livekit.ParticipantIdentity(identity)] + requirePIEquals(t, subscriber1v2, queued.ParticipantInfo) + }, + }, + { + name: "sid change is broadcasted immediately with synthsized disconnect", + pi: publisher2, + closeReason: types.ParticipantCloseReasonServiceRequestRemoveParticipant, // just to test if update contain the close reason + existing: &ParticipantUpdate{ParticipantInfo: utils.CloneProto(subscriber1v2), CloseReason: types.ParticipantCloseReasonStale}, + expected: []*ParticipantUpdate{ + { + ParticipantInfo: &livekit.ParticipantInfo{ + Identity: identity, + Sid: "1", + Version: 2, + State: livekit.ParticipantInfo_DISCONNECTED, + }, + IsSynthesizedDisconnect: true, + CloseReason: types.ParticipantCloseReasonStale, + }, + {ParticipantInfo: publisher2, CloseReason: types.ParticipantCloseReasonServiceRequestRemoveParticipant}, + }, + }, + { + name: "when switching to publisher, queue is cleared", + pi: publisher1v2, + existing: &ParticipantUpdate{ParticipantInfo: utils.CloneProto(subscriber1v1)}, + expected: []*ParticipantUpdate{{ParticipantInfo: publisher1v2}}, + validate: func(t *testing.T, rm *Room, updates []*ParticipantUpdate) { + require.Empty(t, rm.batchedUpdates) + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + rm := newRoomWithParticipants(t, testRoomOpts{num: 1}) + if tc.existing != nil { + rm.batchedUpdates[livekit.ParticipantIdentity(tc.existing.ParticipantInfo.Identity)] = tc.existing + } + rm.batchedUpdatesMu.Lock() + updates := PushAndDequeueUpdates( + tc.pi, + tc.closeReason, + tc.immediate, + rm.GetParticipant(livekit.ParticipantIdentity(tc.pi.Identity)), + rm.batchedUpdates, + ) + rm.batchedUpdatesMu.Unlock() + require.Equal(t, len(tc.expected), len(updates)) + for i, item := range tc.expected { + requirePIEquals(t, item.ParticipantInfo, updates[i].ParticipantInfo) + require.Equal(t, item.IsSynthesizedDisconnect, updates[i].IsSynthesizedDisconnect) + require.Equal(t, item.CloseReason, updates[i].CloseReason) + } + + if tc.validate != nil { + tc.validate(t, rm, updates) + } + }) + } +} + +func TestRoomClosure(t *testing.T) { + t.Run("room closes after participant leaves", func(t *testing.T) { + rm := newRoomWithParticipants(t, testRoomOpts{num: 1}) + isClosed := false + rm.OnClose(func() { + isClosed = true + }) + p := rm.GetParticipants()[0] + rm.lock.Lock() + // allows immediate close after + rm.protoRoom.EmptyTimeout = 0 + rm.lock.Unlock() + rm.RemoveParticipant(p.Identity(), p.ID(), types.ParticipantCloseReasonClientRequestLeave) + + time.Sleep(time.Duration(rm.ToProto().DepartureTimeout)*time.Second + defaultDelay) + + rm.CloseIfEmpty() + require.Len(t, rm.GetParticipants(), 0) + require.True(t, isClosed) + + require.Equal(t, ErrRoomClosed, rm.Join(p, nil, nil, iceServersForRoom)) + }) + + t.Run("room does not close before empty timeout", func(t *testing.T) { + rm := newRoomWithParticipants(t, testRoomOpts{num: 0}) + isClosed := false + rm.OnClose(func() { + isClosed = true + }) + require.NotZero(t, rm.protoRoom.EmptyTimeout) + rm.CloseIfEmpty() + require.False(t, isClosed) + }) + + t.Run("room closes after empty timeout", func(t *testing.T) { + rm := newRoomWithParticipants(t, testRoomOpts{num: 0}) + isClosed := false + rm.OnClose(func() { + isClosed = true + }) + rm.lock.Lock() + rm.protoRoom.EmptyTimeout = 1 + rm.lock.Unlock() + + time.Sleep(1010 * time.Millisecond) + rm.CloseIfEmpty() + require.True(t, isClosed) + }) +} + +func TestNewTrack(t *testing.T) { + t.Run("new track should be added to ready participants", func(t *testing.T) { + rm := newRoomWithParticipants(t, testRoomOpts{num: 3}) + lpl := rm.LocalParticipantListener() + + participants := rm.GetParticipants() + p0 := participants[0].(*typesfakes.FakeLocalParticipant) + p0.StateReturns(livekit.ParticipantInfo_JOINED) + p1 := participants[1].(*typesfakes.FakeLocalParticipant) + p1.StateReturns(livekit.ParticipantInfo_ACTIVE) + + pub := participants[2].(*typesfakes.FakeLocalParticipant) + + // pub adds track + track := NewMockTrack(livekit.TrackType_VIDEO, "webcam") + lpl.OnTrackPublished(pub, track) + + // only p1 should've been subscribed to + require.Equal(t, 0, p0.SubscribeToTrackCallCount()) + require.Equal(t, 1, p1.SubscribeToTrackCallCount()) + }) +} + +func TestActiveSpeakers(t *testing.T) { + t.Parallel() + getActiveSpeakerUpdates := func(p *typesfakes.FakeLocalParticipant) [][]*livekit.SpeakerInfo { + var updates [][]*livekit.SpeakerInfo + numCalls := p.SendSpeakerUpdateCallCount() + for i := range numCalls { + infos, _ := p.SendSpeakerUpdateArgsForCall(i) + updates = append(updates, infos) + } + return updates + } + + audioUpdateDuration := (audioUpdateInterval + 10) * time.Millisecond + t.Run("participant should not be getting audio updates (protocol 2)", func(t *testing.T) { + rm := newRoomWithParticipants(t, testRoomOpts{num: 1, protocol: 2}) + defer rm.Close(types.ParticipantCloseReasonNone) + p := rm.GetParticipants()[0].(*typesfakes.FakeLocalParticipant) + require.Empty(t, rm.GetActiveSpeakers()) + + time.Sleep(audioUpdateDuration) + + updates := getActiveSpeakerUpdates(p) + require.Empty(t, updates) + }) + + t.Run("speakers should be sorted by loudness", func(t *testing.T) { + rm := newRoomWithParticipants(t, testRoomOpts{num: 2}) + defer rm.Close(types.ParticipantCloseReasonNone) + participants := rm.GetParticipants() + p := participants[0].(*typesfakes.FakeLocalParticipant) + p2 := participants[1].(*typesfakes.FakeLocalParticipant) + p.GetAudioLevelReturns(20, true) + p2.GetAudioLevelReturns(10, true) + + speakers := rm.GetActiveSpeakers() + require.Len(t, speakers, 2) + require.Equal(t, string(p.ID()), speakers[0].Sid) + require.Equal(t, string(p2.ID()), speakers[1].Sid) + }) + + t.Run("participants are getting audio updates (protocol 3+)", func(t *testing.T) { + rm := newRoomWithParticipants(t, testRoomOpts{num: 2, protocol: 3}) + defer rm.Close(types.ParticipantCloseReasonNone) + participants := rm.GetParticipants() + p := participants[0].(*typesfakes.FakeLocalParticipant) + time.Sleep(time.Millisecond) // let the first update cycle run + p.GetAudioLevelReturns(30, true) + + speakers := rm.GetActiveSpeakers() + require.NotEmpty(t, speakers) + require.Equal(t, string(p.ID()), speakers[0].Sid) + + testutils.WithTimeout(t, func() string { + for _, op := range participants { + op := op.(*typesfakes.FakeLocalParticipant) + updates := getActiveSpeakerUpdates(op) + if len(updates) == 0 { + return fmt.Sprintf("%s did not get any audio updates", op.Identity()) + } + } + return "" + }) + + // no longer speaking, send update with empty items + p.GetAudioLevelReturns(127, false) + + testutils.WithTimeout(t, func() string { + updates := getActiveSpeakerUpdates(p) + lastUpdate := updates[len(updates)-1] + if len(lastUpdate) == 0 { + return "did not get updates of speaker going quiet" + } + if lastUpdate[0].Active { + return "speaker should not have been active" + } + return "" + }) + }) + + t.Run("audio level is smoothed", func(t *testing.T) { + rm := newRoomWithParticipants(t, testRoomOpts{num: 2, protocol: 3, audioSmoothIntervals: 3}) + defer rm.Close(types.ParticipantCloseReasonNone) + + participants := rm.GetParticipants() + p := participants[0].(*typesfakes.FakeLocalParticipant) + op := participants[1].(*typesfakes.FakeLocalParticipant) + p.GetAudioLevelReturns(30, true) + convertedLevel := float32(audio.ConvertAudioLevel(30)) + + testutils.WithTimeout(t, func() string { + updates := getActiveSpeakerUpdates(op) + if len(updates) == 0 { + return "no speaker updates received" + } + lastSpeakers := updates[len(updates)-1] + if len(lastSpeakers) == 0 { + return "no speakers in the update" + } + if lastSpeakers[0].Level > convertedLevel { + return "" + } + return "level mismatch" + }) + + testutils.WithTimeout(t, func() string { + updates := getActiveSpeakerUpdates(op) + if len(updates) == 0 { + return "no updates received" + } + lastSpeakers := updates[len(updates)-1] + if len(lastSpeakers) == 0 { + return "no speakers found" + } + if lastSpeakers[0].Level > convertedLevel { + return "" + } + return "did not match expected levels" + }) + + p.GetAudioLevelReturns(127, false) + testutils.WithTimeout(t, func() string { + updates := getActiveSpeakerUpdates(op) + if len(updates) == 0 { + return "no speaker updates received" + } + lastSpeakers := updates[len(updates)-1] + if len(lastSpeakers) == 1 && !lastSpeakers[0].Active { + return "" + } + return "speakers didn't go back to zero" + }) + }) +} + +func TestDataChannel(t *testing.T) { + t.Parallel() + + const ( + curAPI = iota + legacySID + legacyIdentity + ) + modes := []int{ + curAPI, legacySID, legacyIdentity, + } + modeNames := []string{ + "cur", "legacy sid", "legacy identity", + } + + setSource := func(mode int, dp *livekit.DataPacket, p types.LocalParticipant) { + switch mode { + case curAPI: + dp.ParticipantIdentity = string(p.Identity()) + case legacySID: + dp.GetUser().ParticipantSid = string(p.ID()) + case legacyIdentity: + dp.GetUser().ParticipantIdentity = string(p.Identity()) + } + } + setDest := func(mode int, dp *livekit.DataPacket, p types.LocalParticipant) { + switch mode { + case curAPI: + dp.DestinationIdentities = []string{string(p.Identity())} + case legacySID: + dp.GetUser().DestinationSids = []string{string(p.ID())} + case legacyIdentity: + dp.GetUser().DestinationIdentities = []string{string(p.Identity())} + } + } + + t.Run("participants should receive data", func(t *testing.T) { + for _, mode := range modes { + mode := mode + t.Run(modeNames[mode], func(t *testing.T) { + rm := newRoomWithParticipants(t, testRoomOpts{num: 3}) + defer rm.Close(types.ParticipantCloseReasonNone) + + lpl := rm.LocalParticipantListener() + + participants := rm.GetParticipants() + p := participants[0].(*typesfakes.FakeLocalParticipant) + + packet := &livekit.DataPacket{ + Kind: livekit.DataPacket_RELIABLE, + Value: &livekit.DataPacket_User{ + User: &livekit.UserPacket{ + Payload: []byte("message.."), + }, + }, + } + setSource(mode, packet, p) + + packetExp := utils.CloneProto(packet) + if mode != legacySID { + packetExp.ParticipantIdentity = string(p.Identity()) + packetExp.GetUser().ParticipantIdentity = string(p.Identity()) + } + + encoded, _ := proto.Marshal(packetExp) + lpl.OnDataMessage(p, packet.Kind, packet) + + // ensure everyone has received the packet + for _, op := range participants { + fp := op.(*typesfakes.FakeLocalParticipant) + if fp == p { + require.Zero(t, fp.SendDataMessageCallCount()) + continue + } + require.Equal(t, 1, fp.SendDataMessageCallCount()) + _, got, _, _ := fp.SendDataMessageArgsForCall(0) + require.Equal(t, encoded, got) + } + }) + } + }) + + t.Run("only one participant should receive the data", func(t *testing.T) { + for _, mode := range modes { + mode := mode + t.Run(modeNames[mode], func(t *testing.T) { + rm := newRoomWithParticipants(t, testRoomOpts{num: 4}) + defer rm.Close(types.ParticipantCloseReasonNone) + + lpl := rm.LocalParticipantListener() + + participants := rm.GetParticipants() + p := participants[0].(*typesfakes.FakeLocalParticipant) + p1 := participants[1].(*typesfakes.FakeLocalParticipant) + + packet := &livekit.DataPacket{ + Kind: livekit.DataPacket_RELIABLE, + Value: &livekit.DataPacket_User{ + User: &livekit.UserPacket{ + Payload: []byte("message to p1.."), + }, + }, + } + setSource(mode, packet, p) + setDest(mode, packet, p1) + + packetExp := utils.CloneProto(packet) + if mode != legacySID { + packetExp.ParticipantIdentity = string(p.Identity()) + packetExp.GetUser().ParticipantIdentity = string(p.Identity()) + packetExp.DestinationIdentities = []string{string(p1.Identity())} + packetExp.GetUser().DestinationIdentities = []string{string(p1.Identity())} + } + + encoded, _ := proto.Marshal(packetExp) + lpl.OnDataMessage(p, packet.Kind, packet) + + // only p1 should receive the data + for _, op := range participants { + fp := op.(*typesfakes.FakeLocalParticipant) + if fp != p1 { + require.Zero(t, fp.SendDataMessageCallCount()) + } + } + require.Equal(t, 1, p1.SendDataMessageCallCount()) + _, got, _, _ := p1.SendDataMessageArgsForCall(0) + require.Equal(t, encoded, got) + }) + } + }) + + t.Run("publishing disallowed", func(t *testing.T) { + rm := newRoomWithParticipants(t, testRoomOpts{num: 2}) + defer rm.Close(types.ParticipantCloseReasonNone) + + participants := rm.GetParticipants() + p := participants[0].(*typesfakes.FakeLocalParticipant) + p.CanPublishDataReturns(false) + + packet := livekit.DataPacket{ + Kind: livekit.DataPacket_RELIABLE, + Value: &livekit.DataPacket_User{ + User: &livekit.UserPacket{ + Payload: []byte{}, + }, + }, + } + if p.CanPublishData() { + lpl := rm.LocalParticipantListener() + lpl.OnDataMessage(p, packet.Kind, &packet) + } + + // no one should've been sent packet + for _, op := range participants { + fp := op.(*typesfakes.FakeLocalParticipant) + require.Zero(t, fp.SendDataMessageCallCount()) + } + }) +} + +func TestHiddenParticipants(t *testing.T) { + t.Run("other participants don't receive hidden updates", func(t *testing.T) { + rm := newRoomWithParticipants(t, testRoomOpts{num: 2, numHidden: 1}) + defer rm.Close(types.ParticipantCloseReasonNone) + + pNew := NewMockParticipant("new", types.CurrentProtocol, false, false, rm.LocalParticipantListener()) + rm.Join(pNew, nil, nil, iceServersForRoom) + + // expect new participant to get a JoinReply + res := pNew.SendJoinResponseArgsForCall(0) + require.Equal(t, livekit.RoomID(res.Room.Sid), rm.ID()) + require.Len(t, res.OtherParticipants, 2) + require.Len(t, rm.GetParticipants(), 4) + require.NotEmpty(t, res.IceServers) + require.Equal(t, "testregion", res.ServerInfo.Region) + }) + + t.Run("hidden participant subscribes to tracks", func(t *testing.T) { + rm := newRoomWithParticipants(t, testRoomOpts{num: 2}) + lpl := rm.LocalParticipantListener() + hidden := NewMockParticipant("hidden", types.CurrentProtocol, true, false, lpl) + + err := rm.Join(hidden, nil, &ParticipantOptions{AutoSubscribe: true}, iceServersForRoom) + require.NoError(t, err) + + hidden.StateReturns(livekit.ParticipantInfo_ACTIVE) + lpl.OnStateChange(hidden) + + require.Eventually(t, func() bool { return hidden.SubscribeToTrackCallCount() == 2 }, 5*time.Second, 10*time.Millisecond) + }) +} + +func TestRoomUpdate(t *testing.T) { + t.Run("updates are sent when participant joined", func(t *testing.T) { + rm := newRoomWithParticipants(t, testRoomOpts{num: 1}) + defer rm.Close(types.ParticipantCloseReasonNone) + + p1 := rm.GetParticipants()[0].(*typesfakes.FakeLocalParticipant) + require.Equal(t, 0, p1.SendRoomUpdateCallCount()) + + p2 := NewMockParticipant("p2", types.CurrentProtocol, false, false, rm.LocalParticipantListener()) + require.NoError(t, rm.Join(p2, nil, nil, iceServersForRoom)) + + // p1 should have received an update + time.Sleep(2 * defaultDelay) + require.LessOrEqual(t, 1, p1.SendRoomUpdateCallCount()) + require.EqualValues(t, 2, p1.SendRoomUpdateArgsForCall(p1.SendRoomUpdateCallCount()-1).NumParticipants) + }) + + t.Run("participants should receive metadata update", func(t *testing.T) { + rm := newRoomWithParticipants(t, testRoomOpts{num: 2}) + defer rm.Close(types.ParticipantCloseReasonNone) + + rm.SetMetadata("test metadata...") + + // callbacks are updated from goroutine + time.Sleep(2 * defaultDelay) + + for _, op := range rm.GetParticipants() { + fp := op.(*typesfakes.FakeLocalParticipant) + // room updates are now sent for both participant joining and room metadata + require.GreaterOrEqual(t, fp.SendRoomUpdateCallCount(), 1) + } + }) +} + +type testRoomOpts struct { + num int + numHidden int + protocol types.ProtocolVersion + audioSmoothIntervals uint32 +} + +func newRoomWithParticipants(t *testing.T, opts testRoomOpts) *Room { + kp := &authfakes.FakeKeyProvider{} + kp.GetSecretReturns("testkey") + + n, err := webhook.NewDefaultNotifier(webhook.DefaultWebHookConfig, kp) + require.NoError(t, err) + + rm := NewRoom( + &livekit.Room{Name: "room"}, + nil, + WebRTCConfig{}, + config.RoomConfig{ + EmptyTimeout: 5 * 60, + DepartureTimeout: 1, + }, + &sfu.AudioConfig{ + AudioLevelConfig: audio.AudioLevelConfig{ + UpdateInterval: audioUpdateInterval, + SmoothIntervals: opts.audioSmoothIntervals, + }, + }, + &livekit.ServerInfo{ + Edition: livekit.ServerInfo_Standard, + Version: version.Version, + Protocol: types.CurrentProtocol, + NodeId: "testnode", + Region: "testregion", + }, + telemetry.NewTelemetryService(n, &telemetryfakes.FakeAnalyticsService{}), + nil, nil, nil, + ) + for i := 0; i < opts.num+opts.numHidden; i++ { + identity := livekit.ParticipantIdentity(fmt.Sprintf("p%d", i)) + participant := NewMockParticipant(identity, opts.protocol, i >= opts.num, true, rm.LocalParticipantListener()) + err := rm.Join(participant, nil, &ParticipantOptions{AutoSubscribe: true}, iceServersForRoom) + require.NoError(t, err) + participant.StateReturns(livekit.ParticipantInfo_ACTIVE) + participant.IsReadyReturns(true) + // each participant has a track + participant.GetPublishedTracksReturns([]types.MediaTrack{ + &typesfakes.FakeMediaTrack{}, + }) + } + return rm +} diff --git a/livekit/pkg/rtc/roomtrackmanager.go b/livekit/pkg/rtc/roomtrackmanager.go new file mode 100644 index 0000000..2a5c93f --- /dev/null +++ b/livekit/pkg/rtc/roomtrackmanager.go @@ -0,0 +1,263 @@ +/* + * Copyright 2023 LiveKit, Inc + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package rtc + +import ( + "sync" + + "github.com/livekit/livekit-server/pkg/rtc/types" + "github.com/livekit/livekit-server/pkg/utils" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" +) + +// RoomTrackManager holds tracks that are published to the room +type RoomTrackManager struct { + logger logger.Logger + + lock sync.RWMutex + changedNotifier *utils.ChangeNotifierManager + removedNotifier *utils.ChangeNotifierManager + tracks map[livekit.TrackID][]*TrackInfo + dataTracks map[livekit.TrackID][]*DataTrackInfo +} + +type TrackInfo struct { + Track types.MediaTrack + PublisherIdentity livekit.ParticipantIdentity + PublisherID livekit.ParticipantID +} + +type DataTrackInfo struct { + DataTrack types.DataTrack + PublisherIdentity livekit.ParticipantIdentity + PublisherID livekit.ParticipantID +} + +func NewRoomTrackManager(logger logger.Logger) *RoomTrackManager { + return &RoomTrackManager{ + logger: logger, + tracks: make(map[livekit.TrackID][]*TrackInfo), + dataTracks: make(map[livekit.TrackID][]*DataTrackInfo), + changedNotifier: utils.NewChangeNotifierManager(), + removedNotifier: utils.NewChangeNotifierManager(), + } +} + +func (r *RoomTrackManager) AddTrack(track types.MediaTrack, publisherIdentity livekit.ParticipantIdentity, publisherID livekit.ParticipantID) { + trackID := track.ID() + r.lock.Lock() + infos, ok := r.tracks[trackID] + if ok { + for _, info := range infos { + if info.Track == track { + r.lock.Unlock() + r.logger.Infow("not adding duplicate track", "trackID", trackID) + return + } + } + } + r.tracks[trackID] = append(r.tracks[trackID], &TrackInfo{ + Track: track, + PublisherIdentity: publisherIdentity, + PublisherID: publisherID, + }) + r.lock.Unlock() + + r.NotifyTrackChanged(trackID) +} + +func (r *RoomTrackManager) RemoveTrack(track types.MediaTrack) { + trackID := track.ID() + r.lock.Lock() + // ensure we are removing the same track as added + infos, ok := r.tracks[trackID] + if !ok { + r.lock.Unlock() + return + } + + numRemoved := 0 + idx := 0 + for _, info := range infos { + if info.Track == track { + numRemoved++ + } else { + r.tracks[trackID][idx] = info + idx++ + } + } + for j := idx; j < len(infos); j++ { + r.tracks[trackID][j] = nil + } + r.tracks[trackID] = r.tracks[trackID][:idx] + if len(r.tracks[trackID]) == 0 { + delete(r.tracks, trackID) + } + r.lock.Unlock() + if numRemoved == 0 { + return + } + if numRemoved > 1 { + r.logger.Warnw("removed multiple tracks", nil, "trackID", trackID, "numRemoved", numRemoved) + } + + n := r.removedNotifier.GetNotifier(string(trackID)) + if n != nil { + n.NotifyChanged() + } + + r.changedNotifier.RemoveNotifier(string(trackID), true) + r.removedNotifier.RemoveNotifier(string(trackID), true) +} + +func (r *RoomTrackManager) GetTrackInfo(trackID livekit.TrackID) *TrackInfo { + r.lock.RLock() + defer r.lock.RUnlock() + + infos := r.tracks[trackID] + if len(infos) == 0 { + return nil + } + + // earliest added track is used till it is removed + info := infos[0] + + // when track is about to close, do not resolve + if info.Track != nil && !info.Track.IsOpen() { + return nil + } + return info +} + +func (r *RoomTrackManager) NotifyTrackChanged(trackID livekit.TrackID) { + n := r.changedNotifier.GetNotifier(string(trackID)) + if n != nil { + n.NotifyChanged() + } +} + +// HasObservers lets caller know if the current media track has any observers +// this is used to signal interest in the track. when another MediaTrack with the same +// trackID is being used, track is not considered to be observed. +func (r *RoomTrackManager) HasObservers(track types.MediaTrack) bool { + n := r.changedNotifier.GetNotifier(string(track.ID())) + if n == nil || !n.HasObservers() { + return false + } + + info := r.GetTrackInfo(track.ID()) + if info == nil || info.Track != track { + return false + } + return true +} + +func (r *RoomTrackManager) GetOrCreateTrackChangeNotifier(trackID livekit.TrackID) *utils.ChangeNotifier { + return r.changedNotifier.GetOrCreateNotifier(string(trackID)) +} + +func (r *RoomTrackManager) GetOrCreateTrackRemoveNotifier(trackID livekit.TrackID) *utils.ChangeNotifier { + return r.removedNotifier.GetOrCreateNotifier(string(trackID)) +} + +func (r *RoomTrackManager) AddDataTrack(dataTrack types.DataTrack, publisherIdentity livekit.ParticipantIdentity, publisherID livekit.ParticipantID) { + trackID := dataTrack.ID() + r.lock.Lock() + infos, ok := r.dataTracks[trackID] + if ok { + for _, info := range infos { + if info.DataTrack == dataTrack { + r.lock.Unlock() + r.logger.Infow("not adding duplicate data track", "trackID", trackID) + return + } + } + } + r.dataTracks[trackID] = append(r.dataTracks[trackID], &DataTrackInfo{ + DataTrack: dataTrack, + PublisherIdentity: publisherIdentity, + PublisherID: publisherID, + }) + r.lock.Unlock() + + r.NotifyTrackChanged(trackID) +} + +func (r *RoomTrackManager) RemoveDataTrack(dataTrack types.DataTrack) { + trackID := dataTrack.ID() + r.lock.Lock() + // ensure we are removing the same track as added + infos, ok := r.dataTracks[trackID] + if !ok { + r.lock.Unlock() + return + } + + numRemoved := 0 + idx := 0 + for _, info := range infos { + if info.DataTrack == dataTrack { + numRemoved++ + } else { + r.dataTracks[trackID][idx] = info + idx++ + } + } + for j := idx; j < len(infos); j++ { + r.dataTracks[trackID][j] = nil + } + r.dataTracks[trackID] = r.dataTracks[trackID][:idx] + if len(r.dataTracks[trackID]) == 0 { + delete(r.dataTracks, trackID) + } + r.lock.Unlock() + if numRemoved == 0 { + return + } + if numRemoved > 1 { + r.logger.Warnw("removed multiple data tracks", nil, "trackID", trackID, "numRemoved", numRemoved) + } + + n := r.removedNotifier.GetNotifier(string(trackID)) + if n != nil { + n.NotifyChanged() + } + + r.changedNotifier.RemoveNotifier(string(trackID), true) + r.removedNotifier.RemoveNotifier(string(trackID), true) +} + +func (r *RoomTrackManager) GetDataTrackInfo(trackID livekit.TrackID) *DataTrackInfo { + r.lock.RLock() + defer r.lock.RUnlock() + + infos := r.dataTracks[trackID] + if len(infos) == 0 { + return nil + } + + // earliest added data track is used till it is removed + return infos[0] +} + +func (r *RoomTrackManager) Report() (int, int) { + r.lock.RLock() + defer r.lock.RUnlock() + + return len(r.tracks), len(r.dataTracks) +} diff --git a/livekit/pkg/rtc/signalling/errors.go b/livekit/pkg/rtc/signalling/errors.go new file mode 100644 index 0000000..c78a4c3 --- /dev/null +++ b/livekit/pkg/rtc/signalling/errors.go @@ -0,0 +1,27 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package signalling + +import ( + "errors" +) + +var ( + ErrInvalidMessageType = errors.New("invalid message type") + ErrNameExceedsLimits = errors.New("name length exceeds limits") + ErrMetadataExceedsLimits = errors.New("metadata size exceeds limits") + ErrAttributesExceedsLimits = errors.New("attributes size exceeds limits") + ErrUpdateOwnMetadataNotAllowed = errors.New("update own metadata not allowed") +) diff --git a/livekit/pkg/rtc/signalling/interfaces.go b/livekit/pkg/rtc/signalling/interfaces.go new file mode 100644 index 0000000..2470e9d --- /dev/null +++ b/livekit/pkg/rtc/signalling/interfaces.go @@ -0,0 +1,65 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package signalling + +import ( + "github.com/livekit/protocol/livekit" + + "github.com/livekit/livekit-server/pkg/routing" + "github.com/livekit/livekit-server/pkg/rtc/types" + + "google.golang.org/protobuf/proto" +) + +type ParticipantSignalHandler interface { + HandleMessage(msg proto.Message) error +} + +type ParticipantSignaller interface { + SwapResponseSink(sink routing.MessageSink, reason types.SignallingCloseReason) + GetResponseSink() routing.MessageSink + CloseSignalConnection(reason types.SignallingCloseReason) + + WriteMessage(msg proto.Message) error +} + +type ParticipantSignalling interface { + SignalJoinResponse(join *livekit.JoinResponse) proto.Message + SignalParticipantUpdate(participants []*livekit.ParticipantInfo) proto.Message + SignalSpeakerUpdate(speakers []*livekit.SpeakerInfo) proto.Message + SignalRoomUpdate(room *livekit.Room) proto.Message + SignalConnectionQualityUpdate(connectionQuality *livekit.ConnectionQualityUpdate) proto.Message + SignalRefreshToken(token string) proto.Message + SignalRequestResponse(requestResponse *livekit.RequestResponse) proto.Message + SignalRoomMovedResponse(roomMoved *livekit.RoomMovedResponse) proto.Message + SignalReconnectResponse(reconnect *livekit.ReconnectResponse) proto.Message + SignalICECandidate(trickle *livekit.TrickleRequest) proto.Message + SignalTrackMuted(mute *livekit.MuteTrackRequest) proto.Message + SignalTrackPublished(trackPublished *livekit.TrackPublishedResponse) proto.Message + SignalTrackUnpublished(trackUnpublished *livekit.TrackUnpublishedResponse) proto.Message + SignalTrackSubscribed(trackSubscribed *livekit.TrackSubscribed) proto.Message + SignalLeaveRequest(leave *livekit.LeaveRequest) proto.Message + SignalSdpAnswer(answer *livekit.SessionDescription) proto.Message + SignalSdpOffer(offer *livekit.SessionDescription) proto.Message + SignalStreamStateUpdate(streamStateUpdate *livekit.StreamStateUpdate) proto.Message + SignalSubscribedQualityUpdate(subscribedQualityUpdate *livekit.SubscribedQualityUpdate) proto.Message + SignalSubscriptionResponse(subscriptionResponse *livekit.SubscriptionResponse) proto.Message + SignalSubscriptionPermissionUpdate(subscriptionPermissionUpdate *livekit.SubscriptionPermissionUpdate) proto.Message + SignalMediaSectionsRequirement(mediaSectionsRequirement *livekit.MediaSectionsRequirement) proto.Message + SignalSubscribedAudioCodecUpdate(subscribedAudioCodecUpdate *livekit.SubscribedAudioCodecUpdate) proto.Message + SignalPublishDataTrackResponse(publishDataTrackResponse *livekit.PublishDataTrackResponse) proto.Message + SignalUnpublishDataTrackResponse(unpublishDataTrackResponse *livekit.UnpublishDataTrackResponse) proto.Message + SignalDataTrackSubscriberHandles(dataTrackSubscriberHandles *livekit.DataTrackSubscriberHandles) proto.Message +} diff --git a/livekit/pkg/rtc/signalling/signalhandler.go b/livekit/pkg/rtc/signalling/signalhandler.go new file mode 100644 index 0000000..490e1c4 --- /dev/null +++ b/livekit/pkg/rtc/signalling/signalhandler.go @@ -0,0 +1,157 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package signalling + +import ( + "fmt" + + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "google.golang.org/protobuf/proto" + + "github.com/livekit/livekit-server/pkg/rtc/types" +) + +var _ ParticipantSignalHandler = (*signalhandler)(nil) + +type SignalHandlerParams struct { + Logger logger.Logger + Participant types.LocalParticipant +} + +type signalhandler struct { + signalhandlerUnimplemented + + params SignalHandlerParams +} + +func NewSignalHandler(params SignalHandlerParams) ParticipantSignalHandler { + return &signalhandler{ + params: params, + } +} + +func (s *signalhandler) HandleMessage(msg proto.Message) error { + req, ok := msg.(*livekit.SignalRequest) + if !ok { + s.params.Logger.Warnw( + "unknown message type", nil, + "messageType", fmt.Sprintf("%T", msg), + ) + return ErrInvalidMessageType + } + s.params.Participant.UpdateLastSeenSignal() + + s.params.Logger.Debugw("handling signal request", "request", logger.Proto(req)) + switch msg := req.GetMessage().(type) { + case *livekit.SignalRequest_Offer: + s.params.Participant.HandleOffer(msg.Offer) + + case *livekit.SignalRequest_Answer: + s.params.Participant.HandleAnswer(msg.Answer) + + case *livekit.SignalRequest_Trickle: + s.params.Participant.HandleICETrickle(msg.Trickle) + + case *livekit.SignalRequest_AddTrack: + s.params.Participant.AddTrack(msg.AddTrack) + + case *livekit.SignalRequest_Mute: + s.params.Participant.SetTrackMuted(msg.Mute, false) + + case *livekit.SignalRequest_Subscription: + // allow participant to indicate their interest in the subscription + // permission check happens later in SubscriptionManager + s.params.Participant.HandleUpdateSubscriptions( + livekit.StringsAsIDs[livekit.TrackID](msg.Subscription.TrackSids), + msg.Subscription.ParticipantTracks, + msg.Subscription.Subscribe, + ) + + case *livekit.SignalRequest_TrackSetting: + for _, sid := range livekit.StringsAsIDs[livekit.TrackID](msg.TrackSetting.TrackSids) { + s.params.Participant.UpdateSubscribedTrackSettings(sid, msg.TrackSetting) + } + + case *livekit.SignalRequest_Leave: + reason := types.ParticipantCloseReasonClientRequestLeave + switch msg.Leave.Reason { + case livekit.DisconnectReason_CLIENT_INITIATED: + reason = types.ParticipantCloseReasonClientRequestLeave + case livekit.DisconnectReason_USER_UNAVAILABLE: + reason = types.ParticipantCloseReasonUserUnavailable + case livekit.DisconnectReason_USER_REJECTED: + reason = types.ParticipantCloseReasonUserRejected + } + s.params.Logger.Debugw("client leaving room", "reason", reason) + s.params.Participant.HandleLeaveRequest(reason) + + case *livekit.SignalRequest_SubscriptionPermission: + err := s.params.Participant.HandleUpdateSubscriptionPermission(msg.SubscriptionPermission) + if err != nil { + s.params.Logger.Warnw( + "could not update subscription permission", err, + "permissions", msg.SubscriptionPermission, + ) + } + + case *livekit.SignalRequest_SyncState: + err := s.params.Participant.HandleSyncState(msg.SyncState) + if err != nil { + s.params.Logger.Warnw( + "could not sync state", err, + "state", msg.SyncState, + ) + } + + case *livekit.SignalRequest_Simulate: + err := s.params.Participant.HandleSimulateScenario(msg.Simulate) + if err != nil { + s.params.Logger.Warnw( + "could not simulate scenario", err, + "simulate", msg.Simulate, + ) + } + + case *livekit.SignalRequest_PingReq: + if msg.PingReq.Rtt > 0 { + s.params.Participant.UpdateSignalingRTT(uint32(msg.PingReq.Rtt)) + } + + case *livekit.SignalRequest_UpdateMetadata: + s.params.Participant.UpdateMetadata(msg.UpdateMetadata, false) + + case *livekit.SignalRequest_UpdateAudioTrack: + if err := s.params.Participant.UpdateAudioTrack(msg.UpdateAudioTrack); err != nil { + s.params.Logger.Warnw("could not update audio track", err, "update", msg.UpdateAudioTrack) + } + + case *livekit.SignalRequest_UpdateVideoTrack: + if err := s.params.Participant.UpdateVideoTrack(msg.UpdateVideoTrack); err != nil { + s.params.Logger.Warnw("could not update video track", err, "update", msg.UpdateVideoTrack) + } + + case *livekit.SignalRequest_PublishDataTrackRequest: + s.params.Participant.HandlePublishDataTrackRequest(msg.PublishDataTrackRequest) + + case *livekit.SignalRequest_UnpublishDataTrackRequest: + s.params.Participant.HandleUnpublishDataTrackRequest(msg.UnpublishDataTrackRequest) + + case *livekit.SignalRequest_UpdateDataSubscription: + s.params.Participant.HandleUpdateDataSubscription(msg.UpdateDataSubscription) + } + + return nil +} diff --git a/livekit/pkg/rtc/signalling/signalhandlerunimplemented.go b/livekit/pkg/rtc/signalling/signalhandlerunimplemented.go new file mode 100644 index 0000000..9df37bd --- /dev/null +++ b/livekit/pkg/rtc/signalling/signalhandlerunimplemented.go @@ -0,0 +1,27 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package signalling + +import ( + "google.golang.org/protobuf/proto" +) + +var _ ParticipantSignalHandler = (*signalhandlerUnimplemented)(nil) + +type signalhandlerUnimplemented struct{} + +func (u *signalhandlerUnimplemented) HandleMessage(msg proto.Message) error { + return nil +} diff --git a/livekit/pkg/rtc/signalling/signallerasync.go b/livekit/pkg/rtc/signalling/signallerasync.go new file mode 100644 index 0000000..15fb581 --- /dev/null +++ b/livekit/pkg/rtc/signalling/signallerasync.go @@ -0,0 +1,108 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package signalling + +import ( + "fmt" + + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/utils" + "github.com/livekit/psrpc" + + "github.com/livekit/livekit-server/pkg/routing" + "github.com/livekit/livekit-server/pkg/rtc/types" + + "google.golang.org/protobuf/proto" +) + +var _ ParticipantSignaller = (*signallerAsync)(nil) + +type SignallerAsyncParams struct { + Logger logger.Logger + Participant types.LocalParticipant +} + +type signallerAsync struct { + params SignallerAsyncParams + + *signallerAsyncBase +} + +func NewSignallerAsync(params SignallerAsyncParams) ParticipantSignaller { + return &signallerAsync{ + params: params, + signallerAsyncBase: newSignallerAsyncBase(signallerAsyncBaseParams{Logger: params.Logger}), + } +} + +func (s *signallerAsync) WriteMessage(msg proto.Message) error { + if msg == nil { + return nil + } + + if s.params.Participant.IsDisconnected() { + return nil + } + + if !s.params.Participant.IsReady() { + if typed, ok := msg.(*livekit.SignalResponse); !ok { + s.params.Logger.Warnw( + "unknown message type", nil, + "messageType", fmt.Sprintf("%T", msg), + ) + } else { + if typed.GetJoin() == nil { + return nil + } + } + } + + sink := s.GetResponseSink() + if sink == nil { + if typed, ok := msg.(*livekit.SignalResponse); ok { + s.params.Logger.Debugw( + "could not send message to participant", + "messageType", fmt.Sprintf("%T", typed.Message), + ) + } + return nil + } + + err := sink.WriteMessage(msg) + if err != nil { + if utils.ErrorIsOneOf(err, psrpc.Canceled, routing.ErrChannelClosed) { + if typed, ok := msg.(*livekit.SignalResponse); ok { + s.params.Logger.Debugw( + "could not send message to participant", + "error", err, + "messageType", fmt.Sprintf("%T", typed.Message), + ) + } + return nil + } else { + if typed, ok := msg.(*livekit.SignalResponse); ok { + s.params.Logger.Warnw( + "could not send message to participant", err, + "messageType", fmt.Sprintf("%T", typed.Message), + ) + } + return err + } + } else { + s.params.Logger.Debugw("sent signal response", "response", logger.Proto(msg)) + } + return nil +} diff --git a/livekit/pkg/rtc/signalling/signallerasyncbase.go b/livekit/pkg/rtc/signalling/signallerasyncbase.go new file mode 100644 index 0000000..0b4d2b6 --- /dev/null +++ b/livekit/pkg/rtc/signalling/signallerasyncbase.go @@ -0,0 +1,79 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package signalling + +import ( + "sync" + + "github.com/livekit/protocol/logger" + + "github.com/livekit/livekit-server/pkg/routing" + "github.com/livekit/livekit-server/pkg/rtc/types" +) + +type signallerAsyncBaseParams struct { + Logger logger.Logger +} + +type signallerAsyncBase struct { + signallerUnimplemented + + params signallerAsyncBaseParams + + resSinkMu sync.Mutex + resSink routing.MessageSink +} + +func newSignallerAsyncBase(params signallerAsyncBaseParams) *signallerAsyncBase { + return &signallerAsyncBase{ + params: params, + } +} + +func (s *signallerAsyncBase) SwapResponseSink(sink routing.MessageSink, reason types.SignallingCloseReason) { + s.resSinkMu.Lock() + oldSink := s.resSink + s.resSink = sink + s.resSinkMu.Unlock() + + if oldSink != nil { + if sink != nil { + s.params.Logger.Debugw( + "swapping signal connection", + "reason", reason, + "connID", oldSink.ConnectionID(), + "newConnID", sink.ConnectionID(), + ) + } else { + s.params.Logger.Debugw( + "closing signal connection", + "reason", reason, + "connID", oldSink.ConnectionID(), + ) + } + oldSink.Close() + } +} + +func (s *signallerAsyncBase) GetResponseSink() routing.MessageSink { + s.resSinkMu.Lock() + defer s.resSinkMu.Unlock() + return s.resSink +} + +// closes signal connection to notify client to resume/reconnect +func (s *signallerAsyncBase) CloseSignalConnection(reason types.SignallingCloseReason) { + s.SwapResponseSink(nil, reason) +} diff --git a/livekit/pkg/rtc/signalling/signallerunimplemented.go b/livekit/pkg/rtc/signalling/signallerunimplemented.go new file mode 100644 index 0000000..3592fd8 --- /dev/null +++ b/livekit/pkg/rtc/signalling/signallerunimplemented.go @@ -0,0 +1,39 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package signalling + +import ( + "github.com/livekit/livekit-server/pkg/routing" + "github.com/livekit/livekit-server/pkg/rtc/types" + + "google.golang.org/protobuf/proto" +) + +var _ ParticipantSignaller = (*signallerUnimplemented)(nil) + +type signallerUnimplemented struct{} + +func (u *signallerUnimplemented) SwapResponseSink(sink routing.MessageSink, reason types.SignallingCloseReason) { +} + +func (u *signallerUnimplemented) GetResponseSink() routing.MessageSink { + return nil +} + +func (u *signallerUnimplemented) CloseSignalConnection(reason types.SignallingCloseReason) {} + +func (u *signallerUnimplemented) WriteMessage(msg proto.Message) error { + return nil +} diff --git a/livekit/pkg/rtc/signalling/signalling.go b/livekit/pkg/rtc/signalling/signalling.go new file mode 100644 index 0000000..700d125 --- /dev/null +++ b/livekit/pkg/rtc/signalling/signalling.go @@ -0,0 +1,262 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package signalling + +import ( + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + + "google.golang.org/protobuf/proto" +) + +var _ ParticipantSignalling = (*signalling)(nil) + +type SignallingParams struct { + Logger logger.Logger +} + +type signalling struct { + signallingUnimplemented + + params SignallingParams +} + +func NewSignalling(params SignallingParams) ParticipantSignalling { + return &signalling{ + params: params, + } +} + +func (s *signalling) SignalJoinResponse(join *livekit.JoinResponse) proto.Message { + return &livekit.SignalResponse{ + Message: &livekit.SignalResponse_Join{ + Join: join, + }, + } +} + +func (s *signalling) SignalParticipantUpdate(participants []*livekit.ParticipantInfo) proto.Message { + if len(participants) == 0 { + return nil + } + + return &livekit.SignalResponse{ + Message: &livekit.SignalResponse_Update{ + Update: &livekit.ParticipantUpdate{ + Participants: participants, + }, + }, + } +} + +func (s *signalling) SignalSpeakerUpdate(speakers []*livekit.SpeakerInfo) proto.Message { + if len(speakers) == 0 { + return nil + } + + return &livekit.SignalResponse{ + Message: &livekit.SignalResponse_SpeakersChanged{ + SpeakersChanged: &livekit.SpeakersChanged{ + Speakers: speakers, + }, + }, + } +} + +func (s *signalling) SignalRoomUpdate(room *livekit.Room) proto.Message { + return &livekit.SignalResponse{ + Message: &livekit.SignalResponse_RoomUpdate{ + RoomUpdate: &livekit.RoomUpdate{ + Room: room, + }, + }, + } +} + +func (s *signalling) SignalConnectionQualityUpdate(connectionQuality *livekit.ConnectionQualityUpdate) proto.Message { + return &livekit.SignalResponse{ + Message: &livekit.SignalResponse_ConnectionQuality{ + ConnectionQuality: connectionQuality, + }, + } +} + +func (s *signalling) SignalRefreshToken(token string) proto.Message { + return &livekit.SignalResponse{ + Message: &livekit.SignalResponse_RefreshToken{ + RefreshToken: token, + }, + } +} + +func (s *signalling) SignalRequestResponse(requestResponse *livekit.RequestResponse) proto.Message { + return &livekit.SignalResponse{ + Message: &livekit.SignalResponse_RequestResponse{ + RequestResponse: requestResponse, + }, + } +} + +func (s *signalling) SignalRoomMovedResponse(roomMoved *livekit.RoomMovedResponse) proto.Message { + return &livekit.SignalResponse{ + Message: &livekit.SignalResponse_RoomMoved{ + RoomMoved: roomMoved, + }, + } +} + +func (s *signalling) SignalReconnectResponse(reconnect *livekit.ReconnectResponse) proto.Message { + return &livekit.SignalResponse{ + Message: &livekit.SignalResponse_Reconnect{ + Reconnect: reconnect, + }, + } +} + +func (s *signalling) SignalICECandidate(trickle *livekit.TrickleRequest) proto.Message { + return &livekit.SignalResponse{ + Message: &livekit.SignalResponse_Trickle{ + Trickle: trickle, + }, + } +} + +func (s *signalling) SignalTrackMuted(mute *livekit.MuteTrackRequest) proto.Message { + return &livekit.SignalResponse{ + Message: &livekit.SignalResponse_Mute{ + Mute: mute, + }, + } +} + +func (s *signalling) SignalTrackPublished(trackPublished *livekit.TrackPublishedResponse) proto.Message { + return &livekit.SignalResponse{ + Message: &livekit.SignalResponse_TrackPublished{ + TrackPublished: trackPublished, + }, + } +} + +func (s *signalling) SignalTrackUnpublished(trackUnpublished *livekit.TrackUnpublishedResponse) proto.Message { + return &livekit.SignalResponse{ + Message: &livekit.SignalResponse_TrackUnpublished{ + TrackUnpublished: trackUnpublished, + }, + } +} + +func (s *signalling) SignalTrackSubscribed(trackSubscribed *livekit.TrackSubscribed) proto.Message { + return &livekit.SignalResponse{ + Message: &livekit.SignalResponse_TrackSubscribed{ + TrackSubscribed: trackSubscribed, + }, + } +} + +func (s *signalling) SignalLeaveRequest(leave *livekit.LeaveRequest) proto.Message { + return &livekit.SignalResponse{ + Message: &livekit.SignalResponse_Leave{ + Leave: leave, + }, + } +} + +func (s *signalling) SignalSdpAnswer(answer *livekit.SessionDescription) proto.Message { + return &livekit.SignalResponse{ + Message: &livekit.SignalResponse_Answer{ + Answer: answer, + }, + } +} + +func (s *signalling) SignalSdpOffer(offer *livekit.SessionDescription) proto.Message { + return &livekit.SignalResponse{ + Message: &livekit.SignalResponse_Offer{ + Offer: offer, + }, + } +} + +func (s *signalling) SignalStreamStateUpdate(streamStateUpdate *livekit.StreamStateUpdate) proto.Message { + return &livekit.SignalResponse{ + Message: &livekit.SignalResponse_StreamStateUpdate{ + StreamStateUpdate: streamStateUpdate, + }, + } +} + +func (s *signalling) SignalSubscribedQualityUpdate(subscribedQualityUpdate *livekit.SubscribedQualityUpdate) proto.Message { + return &livekit.SignalResponse{ + Message: &livekit.SignalResponse_SubscribedQualityUpdate{ + SubscribedQualityUpdate: subscribedQualityUpdate, + }, + } +} + +func (s *signalling) SignalSubscriptionResponse(subscriptionResponse *livekit.SubscriptionResponse) proto.Message { + return &livekit.SignalResponse{ + Message: &livekit.SignalResponse_SubscriptionResponse{ + SubscriptionResponse: subscriptionResponse, + }, + } +} + +func (s *signalling) SignalSubscriptionPermissionUpdate(subscriptionPermissionUpdate *livekit.SubscriptionPermissionUpdate) proto.Message { + return &livekit.SignalResponse{ + Message: &livekit.SignalResponse_SubscriptionPermissionUpdate{ + SubscriptionPermissionUpdate: subscriptionPermissionUpdate, + }, + } +} + +func (u *signalling) SignalMediaSectionsRequirement(mediaSectionsRequirement *livekit.MediaSectionsRequirement) proto.Message { + return &livekit.SignalResponse{ + Message: &livekit.SignalResponse_MediaSectionsRequirement{ + MediaSectionsRequirement: mediaSectionsRequirement, + }, + } +} + +func (s *signalling) SignalSubscribedAudioCodecUpdate(subscribedAudioCodecUpdate *livekit.SubscribedAudioCodecUpdate) proto.Message { + return &livekit.SignalResponse{ + Message: &livekit.SignalResponse_SubscribedAudioCodecUpdate{ + SubscribedAudioCodecUpdate: subscribedAudioCodecUpdate, + }, + } +} + +func (u *signalling) SignalPublishDataTrackResponse(publishDataTrackResponse *livekit.PublishDataTrackResponse) proto.Message { + return &livekit.SignalResponse{ + Message: &livekit.SignalResponse_PublishDataTrackResponse{ + PublishDataTrackResponse: publishDataTrackResponse, + }, + } +} + +func (u *signalling) SignalUnpublishDataTrackResponse(unpublishDataTrackResponse *livekit.UnpublishDataTrackResponse) proto.Message { + return &livekit.SignalResponse{ + Message: &livekit.SignalResponse_UnpublishDataTrackResponse{ + UnpublishDataTrackResponse: unpublishDataTrackResponse, + }, + } +} + +func (u *signalling) SignalDataTrackSubscriberHandles(dataTrackSubscriberHandles *livekit.DataTrackSubscriberHandles) proto.Message { + return &livekit.SignalResponse{ + Message: &livekit.SignalResponse_DataTrackSubscriberHandles{ + DataTrackSubscriberHandles: dataTrackSubscriberHandles, + }, + } +} diff --git a/livekit/pkg/rtc/signalling/signallingunimplemented.go b/livekit/pkg/rtc/signalling/signallingunimplemented.go new file mode 100644 index 0000000..dca4877 --- /dev/null +++ b/livekit/pkg/rtc/signalling/signallingunimplemented.go @@ -0,0 +1,129 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package signalling + +import ( + "github.com/livekit/protocol/livekit" + + "google.golang.org/protobuf/proto" +) + +var _ ParticipantSignalling = (*signallingUnimplemented)(nil) + +type signallingUnimplemented struct{} + +func (u *signallingUnimplemented) SignalJoinResponse(join *livekit.JoinResponse) proto.Message { + return nil +} + +func (u *signallingUnimplemented) SignalParticipantUpdate(participants []*livekit.ParticipantInfo) proto.Message { + return nil +} + +func (u *signallingUnimplemented) SignalSpeakerUpdate(speakers []*livekit.SpeakerInfo) proto.Message { + return nil +} + +func (u *signallingUnimplemented) SignalRoomUpdate(room *livekit.Room) proto.Message { + return nil +} + +func (u *signallingUnimplemented) SignalConnectionQualityUpdate(connectionQuality *livekit.ConnectionQualityUpdate) proto.Message { + return nil +} + +func (u *signallingUnimplemented) SignalRefreshToken(token string) proto.Message { + return nil +} + +func (u *signallingUnimplemented) SignalRequestResponse(requestResponse *livekit.RequestResponse) proto.Message { + return nil +} + +func (u *signallingUnimplemented) SignalRoomMovedResponse(roomMoved *livekit.RoomMovedResponse) proto.Message { + return nil +} + +func (u *signallingUnimplemented) SignalReconnectResponse(reconnect *livekit.ReconnectResponse) proto.Message { + return nil +} + +func (u *signallingUnimplemented) SignalICECandidate(trickle *livekit.TrickleRequest) proto.Message { + return nil +} + +func (u *signallingUnimplemented) SignalTrackMuted(mute *livekit.MuteTrackRequest) proto.Message { + return nil +} + +func (u *signallingUnimplemented) SignalTrackPublished(trackPublished *livekit.TrackPublishedResponse) proto.Message { + return nil +} + +func (u *signallingUnimplemented) SignalTrackUnpublished(trackUnpublished *livekit.TrackUnpublishedResponse) proto.Message { + return nil +} + +func (u *signallingUnimplemented) SignalTrackSubscribed(trackSubscribed *livekit.TrackSubscribed) proto.Message { + return nil +} + +func (u *signallingUnimplemented) SignalLeaveRequest(leave *livekit.LeaveRequest) proto.Message { + return nil +} + +func (u *signallingUnimplemented) SignalSdpAnswer(answer *livekit.SessionDescription) proto.Message { + return nil +} + +func (u *signallingUnimplemented) SignalSdpOffer(offer *livekit.SessionDescription) proto.Message { + return nil +} + +func (u *signallingUnimplemented) SignalStreamStateUpdate(streamStateUpdate *livekit.StreamStateUpdate) proto.Message { + return nil +} + +func (u *signallingUnimplemented) SignalSubscribedQualityUpdate(subscribedQualityUpdate *livekit.SubscribedQualityUpdate) proto.Message { + return nil +} + +func (u *signallingUnimplemented) SignalSubscriptionResponse(subscriptionResponse *livekit.SubscriptionResponse) proto.Message { + return nil +} + +func (u *signallingUnimplemented) SignalSubscriptionPermissionUpdate(subscriptionPermissionUpdate *livekit.SubscriptionPermissionUpdate) proto.Message { + return nil +} + +func (u *signallingUnimplemented) SignalMediaSectionsRequirement(mediaSectionsRequirement *livekit.MediaSectionsRequirement) proto.Message { + return nil +} + +func (u *signallingUnimplemented) SignalSubscribedAudioCodecUpdate(subscribedAudioCodecUpdate *livekit.SubscribedAudioCodecUpdate) proto.Message { + return nil +} + +func (u *signallingUnimplemented) SignalPublishDataTrackResponse(publishDataTrackResponse *livekit.PublishDataTrackResponse) proto.Message { + return nil +} + +func (u *signallingUnimplemented) SignalUnpublishDataTrackResponse(unpublishDataTrackResponse *livekit.UnpublishDataTrackResponse) proto.Message { + return nil +} + +func (u *signallingUnimplemented) SignalDataTrackSubscriberHandles(dataTrackSubscriberHandles *livekit.DataTrackSubscriberHandles) proto.Message { + return nil +} diff --git a/livekit/pkg/rtc/subscribedtrack.go b/livekit/pkg/rtc/subscribedtrack.go new file mode 100644 index 0000000..6daa69d --- /dev/null +++ b/livekit/pkg/rtc/subscribedtrack.go @@ -0,0 +1,481 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rtc + +import ( + "sync" + "time" + + "github.com/bep/debounce" + "github.com/pion/webrtc/v4" + "go.uber.org/atomic" + "google.golang.org/protobuf/proto" + + "github.com/livekit/livekit-server/pkg/telemetry" + sutils "github.com/livekit/livekit-server/pkg/utils" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/observability/roomobs" + "github.com/livekit/protocol/utils" + + "github.com/livekit/livekit-server/pkg/rtc/types" + "github.com/livekit/livekit-server/pkg/sfu" + "github.com/livekit/livekit-server/pkg/sfu/buffer" + "github.com/livekit/livekit-server/pkg/sfu/mime" +) + +const ( + subscriptionDebounceInterval = 100 * time.Millisecond +) + +var _ types.SubscribedTrack = (*SubscribedTrack)(nil) + +type SubscribedTrackParams struct { + ReceiverConfig ReceiverConfig + SubscriberConfig DirectionConfig + Subscriber types.LocalParticipant + MediaTrack types.MediaTrack + AdaptiveStream bool + Telemetry telemetry.TelemetryService + WrappedReceiver *WrappedReceiver + IsRelayed bool + OnDownTrackCreated func(downTrack *sfu.DownTrack) + OnDownTrackClosed func(subscriberID livekit.ParticipantID) + OnSubscriberMaxQualityChange func(subscriberID livekit.ParticipantID, mime mime.MimeType, layer int32) + OnSubscriberAudioCodecChange func(subscriberID livekit.ParticipantID, mime mime.MimeType, enabled bool) +} + +type SubscribedTrack struct { + params SubscribedTrackParams + logger logger.Logger + downTrack *sfu.DownTrack + sender atomic.Pointer[webrtc.RTPSender] + needsNegotiation atomic.Bool + + versionGenerator utils.TimedVersionGenerator + settingsLock sync.Mutex + settings *livekit.UpdateTrackSettings + settingsVersion utils.TimedVersion + + bindLock sync.Mutex + bound bool + onBindCallbacks []func(error) + + onClose atomic.Value // func(bool) + + debouncer func(func()) + + statsKey telemetry.StatsKey + reporter roomobs.TrackReporter +} + +func NewSubscribedTrack(params SubscribedTrackParams) (*SubscribedTrack, error) { + s := &SubscribedTrack{ + params: params, + logger: params.Subscriber.GetLogger().WithComponent(sutils.ComponentSub).WithValues( + "trackID", params.MediaTrack.ID(), + "publisherID", params.MediaTrack.PublisherID(), + "publisher", params.MediaTrack.PublisherIdentity(), + ), + versionGenerator: utils.NewDefaultTimedVersionGenerator(), + debouncer: debounce.New(subscriptionDebounceInterval), + statsKey: telemetry.StatsKeyForTrack( + params.Subscriber.GetCountry(), + livekit.StreamType_DOWNSTREAM, + params.Subscriber.ID(), + params.MediaTrack.ID(), + params.MediaTrack.Source(), + params.MediaTrack.Kind(), + ), + reporter: params.Subscriber.GetReporter().WithTrack(params.MediaTrack.ID().String()), + } + + var rtcpFeedback []webrtc.RTCPFeedback + var maxTrack int + switch params.MediaTrack.Kind() { + case livekit.TrackType_AUDIO: + rtcpFeedback = params.SubscriberConfig.RTCPFeedback.Audio + maxTrack = params.ReceiverConfig.PacketBufferSizeAudio + case livekit.TrackType_VIDEO: + rtcpFeedback = params.SubscriberConfig.RTCPFeedback.Video + maxTrack = params.ReceiverConfig.PacketBufferSizeVideo + default: + s.logger.Warnw("unexpected track type", nil, "kind", params.MediaTrack.Kind()) + } + codecs := params.WrappedReceiver.Codecs() + for _, c := range codecs { + c.RTCPFeedback = rtcpFeedback + } + + streamID := params.WrappedReceiver.StreamID() + if params.Subscriber.SupportsSyncStreamID() && params.MediaTrack.Stream() != "" { + streamID = PackSyncStreamID(params.MediaTrack.PublisherID(), params.MediaTrack.Stream()) + } + + isEncrypted := params.MediaTrack.IsEncrypted() + var trailer []byte + if isEncrypted { + trailer = params.Subscriber.GetTrailer() + } + downTrack, err := sfu.NewDownTrack(sfu.DownTrackParams{ + Codecs: codecs, + IsEncrypted: isEncrypted, + Source: params.MediaTrack.Source(), + Receiver: params.WrappedReceiver, + BufferFactory: params.Subscriber.GetBufferFactory(), + SubID: params.Subscriber.ID(), + StreamID: streamID, + MaxTrack: maxTrack, + PlayoutDelayLimit: params.Subscriber.GetPlayoutDelayConfig(), + Pacer: params.Subscriber.GetPacer(), + Trailer: trailer, + Logger: LoggerWithTrack( + params.Subscriber.GetLogger().WithComponent(sutils.ComponentSub), + params.MediaTrack.ID(), + params.IsRelayed, + ), + RTCPWriter: params.Subscriber.WriteSubscriberRTCP, + DisableSenderReportPassThrough: params.Subscriber.GetDisableSenderReportPassThrough(), + SupportsCodecChange: params.Subscriber.SupportsCodecChange(), + Listener: s, + }) + if err != nil { + return nil, err + } + + if params.OnDownTrackCreated != nil { + params.OnDownTrackCreated(downTrack) + } + + downTrack.AddReceiverReportListener(params.Subscriber.HandleReceiverReport) + + s.downTrack = downTrack + return s, nil +} + +func (t *SubscribedTrack) AddOnBind(f func(error)) { + t.bindLock.Lock() + bound := t.bound + if !bound { + t.onBindCallbacks = append(t.onBindCallbacks, f) + } + t.bindLock.Unlock() + + if bound { + // fire immediately, do not need to persist since bind is a one time event + go f(nil) + } +} + +// for DownTrack callback to notify us that it's bound +func (t *SubscribedTrack) Bound(err error) { + t.bindLock.Lock() + if err == nil { + t.bound = true + } + callbacks := t.onBindCallbacks + t.onBindCallbacks = nil + t.bindLock.Unlock() + + if err == nil && t.MediaTrack().Kind() == livekit.TrackType_VIDEO { + // When AdaptiveStream is enabled, default the subscriber to LOW quality stream + // we would want LOW instead of OFF for a couple of reasons + // 1. when a subscriber unsubscribes from a track, we would forget their previously defined settings + // depending on client implementation, subscription on/off is kept separately from adaptive stream + // So when there are no changes to desired resolution, but the user re-subscribes, we may leave stream at OFF + // 2. when interacting with dynacast *and* adaptive stream. If the publisher was not publishing at the + // time of subscription, we might not be able to trigger adaptive stream updates on the client side + // (since there isn't any video frames coming through). this will leave the stream "stuck" on off, without + // a trigger to re-enable it + t.settingsLock.Lock() + if t.settings != nil { + if t.params.AdaptiveStream { + // remove `disabled` flag to force a visibility update + t.settings.Disabled = false + t.logger.Debugw("enabling subscriber track settings on bind", "settings", logger.Proto(t.settings)) + } + } else { + if t.params.AdaptiveStream { + t.settings = &livekit.UpdateTrackSettings{Quality: livekit.VideoQuality_LOW} + } else { + t.settings = &livekit.UpdateTrackSettings{Quality: livekit.VideoQuality_HIGH} + } + t.logger.Debugw("initializing subscriber track settings on bind", "settings", logger.Proto(t.settings)) + } + t.settingsLock.Unlock() + t.applySettings() + } + + for _, cb := range callbacks { + go cb(err) + } +} + +// for DownTrack callback to notify us that it's closed +func (t *SubscribedTrack) Close(isExpectedToResume bool) { + if onClose := t.onClose.Load(); onClose != nil { + go onClose.(func(bool))(isExpectedToResume) + } +} + +func (t *SubscribedTrack) OnClose(f func(bool)) { + t.onClose.Store(f) +} + +func (t *SubscribedTrack) IsBound() bool { + t.bindLock.Lock() + defer t.bindLock.Unlock() + + return t.bound +} + +func (t *SubscribedTrack) ID() livekit.TrackID { + return livekit.TrackID(t.downTrack.ID()) +} + +func (t *SubscribedTrack) PublisherID() livekit.ParticipantID { + return t.params.MediaTrack.PublisherID() +} + +func (t *SubscribedTrack) PublisherIdentity() livekit.ParticipantIdentity { + return t.params.MediaTrack.PublisherIdentity() +} + +func (t *SubscribedTrack) PublisherVersion() uint32 { + return t.params.MediaTrack.PublisherVersion() +} + +func (t *SubscribedTrack) SubscriberID() livekit.ParticipantID { + return t.params.Subscriber.ID() +} + +func (t *SubscribedTrack) SubscriberIdentity() livekit.ParticipantIdentity { + return t.params.Subscriber.Identity() +} + +func (t *SubscribedTrack) Subscriber() types.LocalParticipant { + return t.params.Subscriber +} + +func (t *SubscribedTrack) DownTrack() *sfu.DownTrack { + return t.downTrack +} + +func (t *SubscribedTrack) MediaTrack() types.MediaTrack { + return t.params.MediaTrack +} + +// has subscriber indicated it wants to mute this track +func (t *SubscribedTrack) IsMuted() bool { + t.settingsLock.Lock() + defer t.settingsLock.Unlock() + + return t.isMutedLocked() +} + +func (t *SubscribedTrack) isMutedLocked() bool { + if t.settings == nil { + return false + } + + return t.settings.Disabled +} + +func (t *SubscribedTrack) SetPublisherMuted(muted bool) { + t.downTrack.PubMute(muted) +} + +func (t *SubscribedTrack) UpdateSubscriberSettings(settings *livekit.UpdateTrackSettings, isImmediate bool) { + t.settingsLock.Lock() + if proto.Equal(t.settings, settings) { + t.logger.Debugw("skipping subscriber track settings", "settings", logger.Proto(t.settings)) + t.settingsLock.Unlock() + return + } + + isImmediate = isImmediate || (!settings.Disabled && settings.Disabled != t.isMutedLocked()) + t.settings = utils.CloneProto(settings) + t.logger.Debugw("saving subscriber track settings", "settings", logger.Proto(t.settings)) + t.settingsLock.Unlock() + + if isImmediate { + t.applySettings() + } else { + // avoid frequent changes to mute & video layers, unless it became visible + t.debouncer(t.applySettings) + } +} + +func (t *SubscribedTrack) UpdateVideoLayer() { + t.applySettings() +} + +func (t *SubscribedTrack) applySettings() { + t.settingsLock.Lock() + if t.settings == nil { + t.settingsLock.Unlock() + return + } + + t.settingsVersion = t.versionGenerator.Next() + settingsVersion := t.settingsVersion + t.settingsLock.Unlock() + + dt := t.DownTrack() + spatial := buffer.InvalidLayerSpatial + temporal := buffer.InvalidLayerTemporal + if dt.Kind() == webrtc.RTPCodecTypeVideo { + mt := t.MediaTrack() + quality := t.settings.Quality + mimeType := dt.Mime() + if t.settings.Width > 0 { + quality = mt.GetQualityForDimension(mimeType, t.settings.Width, t.settings.Height) + } + + spatial = buffer.GetSpatialLayerForVideoQuality(mimeType, quality, mt.ToProto()) + if t.settings.Fps > 0 { + temporal = mt.GetTemporalLayerForSpatialFps(mimeType, spatial, t.settings.Fps) + } + } + + t.settingsLock.Lock() + if settingsVersion != t.settingsVersion { + // a newer settings has superseded this one + t.settingsLock.Unlock() + return + } + + t.logger.Debugw("applying subscriber track settings", "settings", logger.Proto(t.settings)) + if t.settings.Disabled { + dt.Mute(true) + t.settingsLock.Unlock() + return + } else { + dt.Mute(false) + } + + if dt.Kind() == webrtc.RTPCodecTypeVideo { + dt.SetMaxSpatialLayer(spatial) + if temporal != buffer.InvalidLayerTemporal { + dt.SetMaxTemporalLayer(temporal) + } + } + t.settingsLock.Unlock() +} + +func (t *SubscribedTrack) NeedsNegotiation() bool { + return t.needsNegotiation.Load() +} + +func (t *SubscribedTrack) SetNeedsNegotiation(needs bool) { + t.needsNegotiation.Store(needs) +} + +func (t *SubscribedTrack) RTPSender() *webrtc.RTPSender { + return t.sender.Load() +} + +func (t *SubscribedTrack) SetRTPSender(sender *webrtc.RTPSender) { + t.sender.Store(sender) +} + +// DownTrackListener implementation +var _ sfu.DownTrackListener = (*SubscribedTrack)(nil) + +func (t *SubscribedTrack) OnBindAndConnected() { + if t.params.Subscriber.Hidden() { + return + } + + t.params.MediaTrack.OnTrackSubscribed() +} + +func (t *SubscribedTrack) OnStatsUpdate(stat *livekit.AnalyticsStat) { + t.params.Telemetry.TrackStats(t.statsKey, stat) + + if cs, ok := telemetry.CondenseStat(stat); ok { + ti := t.params.WrappedReceiver.TrackInfo() + t.reporter.Tx(func(tx roomobs.TrackTx) { + tx.ReportName(ti.Name) + tx.ReportKind(roomobs.TrackKindSub) + tx.ReportType(roomobs.TrackTypeFromProto(ti.Type)) + tx.ReportSource(roomobs.TrackSourceFromProto(ti.Source)) + tx.ReportMime(mime.NormalizeMimeType(ti.MimeType).ReporterType()) + tx.ReportLayer(roomobs.PackTrackLayer(ti.Height, ti.Width)) + tx.ReportDuration(uint16(cs.EndTime.Sub(cs.StartTime).Milliseconds())) + tx.ReportFrames(uint16(cs.Frames)) + tx.ReportSendBytes(uint32(cs.Bytes)) + tx.ReportSendPackets(cs.Packets) + tx.ReportPacketsLost(cs.PacketsLost) + tx.ReportScore(stat.Score) + }) + } +} + +func (t *SubscribedTrack) OnMaxSubscribedLayerChanged(layer int32) { + if t.params.OnSubscriberMaxQualityChange != nil { + t.params.OnSubscriberMaxQualityChange(t.downTrack.SubscriberID(), t.downTrack.Mime(), layer) + } +} + +func (t *SubscribedTrack) OnRttUpdate(rtt uint32) { + go t.params.Subscriber.UpdateMediaRTT(rtt) +} + +func (t *SubscribedTrack) OnCodecNegotiated(codec webrtc.RTPCodecCapability) { + if isAvailable, needsPublish := t.params.WrappedReceiver.DetermineReceiver(codec); !isAvailable || !needsPublish { + return + } + + if t.params.OnSubscriberMaxQualityChange != nil || t.params.OnSubscriberAudioCodecChange != nil { + go func() { + mimeType := mime.NormalizeMimeType(codec.MimeType) + switch t.params.MediaTrack.Kind() { + case livekit.TrackType_VIDEO: + spatial := buffer.GetSpatialLayerForVideoQuality( + mimeType, + livekit.VideoQuality_HIGH, + t.params.MediaTrack.ToProto(), + ) + if t.params.OnSubscriberMaxQualityChange != nil { + t.params.OnSubscriberMaxQualityChange(t.downTrack.SubscriberID(), mimeType, spatial) + } + + case livekit.TrackType_AUDIO: + if t.params.OnSubscriberAudioCodecChange != nil { + t.params.OnSubscriberAudioCodecChange(t.downTrack.SubscriberID(), mimeType, true) + } + } + }() + } +} + +func (t *SubscribedTrack) OnDownTrackClose(isExpectedToResume bool) { + // Cache transceiver for potential re-use on resume. + // To ensure subscription manager does not re-subscribe before caching, + // delete the subscribed track only after caching. + if isExpectedToResume { + if tr := t.downTrack.GetTransceiver(); tr != nil { + t.params.Subscriber.CacheDownTrack(t.ID(), tr, t.downTrack.GetState()) + } + } + + go func() { + if t.params.OnDownTrackClosed != nil { + t.params.OnDownTrackClosed(t.params.Subscriber.ID()) + } + t.Close(isExpectedToResume) + }() +} diff --git a/livekit/pkg/rtc/subscriptionmanager.go b/livekit/pkg/rtc/subscriptionmanager.go new file mode 100644 index 0000000..8a5d070 --- /dev/null +++ b/livekit/pkg/rtc/subscriptionmanager.go @@ -0,0 +1,1577 @@ +/* + * Copyright 2023 LiveKit, Inc + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package rtc + +import ( + "context" + "errors" + "sync" + "time" + + "github.com/pion/webrtc/v4/pkg/rtcerr" + "go.uber.org/atomic" + "golang.org/x/exp/maps" + + "github.com/livekit/livekit-server/pkg/rtc/types" + "github.com/livekit/livekit-server/pkg/sfu" + "github.com/livekit/livekit-server/pkg/telemetry" + "github.com/livekit/livekit-server/pkg/telemetry/prometheus" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/utils" +) + +// using var instead of const to override in tests +var ( + reconcileInterval = 3 * time.Second + // amount of time to give up if a track or publisher isn't found + // ensuring this is longer than iceFailedTimeout so we are certain the participant won't return + notFoundTimeout = time.Minute + // amount of time to try otherwise before flagging subscription as failed + subscriptionTimeout = iceFailedTimeoutTotal + maxUnsubscribeWait = time.Second +) + +const ( + trackIDForReconcileSubscriptions = livekit.TrackID("subscriptions_reconcile") +) + +type SubscriptionManagerParams struct { + Logger logger.Logger + Participant types.LocalParticipant + TrackResolver types.MediaTrackResolver + OnTrackSubscribed func(subTrack types.SubscribedTrack) + OnTrackUnsubscribed func(subTrack types.SubscribedTrack) + OnSubscriptionError func(trackID livekit.TrackID, fatal bool, err error) + Telemetry telemetry.TelemetryService + + SubscriptionLimitVideo, SubscriptionLimitAudio int32 + + DataTrackResolver types.DataTrackResolver + + UseOneShotSignallingMode bool +} + +// SubscriptionManager manages a participant's subscriptions +type SubscriptionManager struct { + params SubscriptionManagerParams + lock sync.RWMutex + subscriptions map[livekit.TrackID]*mediaTrackSubscription + pendingUnsubscribes atomic.Int32 + + subscribedVideoCount, subscribedAudioCount atomic.Int32 + + subscribedTo map[livekit.ParticipantID]map[livekit.TrackID]struct{} + reconcileCh chan livekit.TrackID + reconcileDataTrackCh chan livekit.TrackID + closeCh chan struct{} + doneCh chan struct{} + + onSubscribeStatusChanged func(publisherID livekit.ParticipantID, subscribed bool) + + dataTrackSubscriptions map[livekit.TrackID]*dataTrackSubscription +} + +func NewSubscriptionManager(params SubscriptionManagerParams) *SubscriptionManager { + m := &SubscriptionManager{ + params: params, + subscriptions: make(map[livekit.TrackID]*mediaTrackSubscription), + subscribedTo: make(map[livekit.ParticipantID]map[livekit.TrackID]struct{}), + reconcileCh: make(chan livekit.TrackID, 50), + reconcileDataTrackCh: make(chan livekit.TrackID, 5), + closeCh: make(chan struct{}), + doneCh: make(chan struct{}), + dataTrackSubscriptions: make(map[livekit.TrackID]*dataTrackSubscription), + } + + go m.reconcileWorker() + return m +} + +func (m *SubscriptionManager) Close(isExpectedToResume bool) { + m.lock.Lock() + if m.isClosed() { + m.lock.Unlock() + return + } + close(m.closeCh) + m.lock.Unlock() + + <-m.doneCh + + prometheus.RecordTrackSubscribeCancels(int32(m.getNumCancellations())) + + subTracks := m.GetSubscribedTracks() + downTracksToClose := make([]*sfu.DownTrack, 0, len(subTracks)) + for _, st := range subTracks { + m.setDesired(st.ID(), false) + dt := st.DownTrack() + // nil check exists primarily for tests + if dt != nil { + downTracksToClose = append(downTracksToClose, st.DownTrack()) + } + } + + if isExpectedToResume { + for _, dt := range downTracksToClose { + dt.CloseWithFlush(false, true) + } + } else { + // flush blocks, so execute in parallel + for _, dt := range downTracksToClose { + go dt.CloseWithFlush(true, true) + } + } + + m.lock.Lock() + for _, sub := range m.dataTrackSubscriptions { + dataDownTrack := sub.getDataDownTrack() + if dataDownTrack == nil { + // already unsubscribed + continue + } + + dataTrack := dataDownTrack.PublishDataTrack() + if dataTrack == nil { + continue + } + + dataTrack.RemoveSubscriber(sub.subscriberID) + } + m.lock.Unlock() + m.notifyDataTrackSubscriberHandles() +} + +func (m *SubscriptionManager) isClosed() bool { + select { + case <-m.closeCh: + return true + default: + return false + } +} + +func (m *SubscriptionManager) SubscribeToTrack(trackID livekit.TrackID, isSync bool) { + if m.params.UseOneShotSignallingMode || isSync { + m.subscribeSynchronous(trackID) + return + } + + sub, desireChanged := m.setDesired(trackID, true) + if sub == nil { + sLogger := m.params.Logger.WithValues( + "trackID", trackID, + ) + sub = newMediaTrackSubscription(m.params.Participant.ID(), trackID, sLogger) + + m.lock.Lock() + m.subscriptions[trackID] = sub + m.lock.Unlock() + + sub, desireChanged = m.setDesired(trackID, true) + } + if desireChanged { + sub.logger.Debugw("subscribing to track") + } + + // always reconcile, since SubscribeToTrack could be called when the track is ready + m.queueReconcile(trackID) +} + +func (m *SubscriptionManager) UnsubscribeFromTrack(trackID livekit.TrackID) { + if m.params.UseOneShotSignallingMode { + m.unsubscribeSynchronous(trackID) + return + } + + sub, desireChanged := m.setDesired(trackID, false) + if sub == nil || !desireChanged { + return + } + + if sub.isCanceled() { + prometheus.RecordTrackSubscribeCancels(1) + } + + sub.logger.Debugw("unsubscribing from track") + m.queueReconcile(trackID) +} + +func (m *SubscriptionManager) SubscribeToDataTrack(trackID livekit.TrackID) { + sub, desireChanged := m.setDataTrackDesired(trackID, true) + if sub == nil { + sLogger := m.params.Logger.WithValues( + "trackID", trackID, + ) + sub = newDataTrackSubscription(m.params.Participant.ID(), trackID, sLogger) + + m.lock.Lock() + m.dataTrackSubscriptions[trackID] = sub + m.lock.Unlock() + + sub, desireChanged = m.setDataTrackDesired(trackID, true) + } + if desireChanged { + sub.logger.Debugw("subscribing to data track") + } + + m.queueReconcileDataTrack(trackID) +} + +func (m *SubscriptionManager) UnsubscribeFromDataTrack(trackID livekit.TrackID) { + sub, desireChanged := m.setDataTrackDesired(trackID, false) + if sub == nil || !desireChanged { + return + } + + sub.logger.Debugw("unsubscribing from data track") + m.queueReconcileDataTrack(trackID) +} + +func (m *SubscriptionManager) ClearAllSubscriptions() { + m.params.Logger.Debugw("clearing all subscriptions") + + if m.params.UseOneShotSignallingMode { + for _, track := range m.GetSubscribedTracks() { + m.unsubscribeSynchronous(track.ID()) + } + + // no synchronous data tracks + } + + numCancellations := 0 + m.lock.RLock() + for _, sub := range m.subscriptions { + if sub.isCanceled() { + numCancellations++ + } + sub.setDesired(false) + } + + for _, sub := range m.dataTrackSubscriptions { + sub.setDesired(false) + } + m.lock.RUnlock() + prometheus.RecordTrackSubscribeCancels(int32(numCancellations)) + + m.ReconcileAll() +} + +func (m *SubscriptionManager) GetSubscribedTracks() []types.SubscribedTrack { + m.lock.RLock() + defer m.lock.RUnlock() + + tracks := make([]types.SubscribedTrack, 0, len(m.subscriptions)) + for _, t := range m.subscriptions { + st := t.getSubscribedTrack() + if st != nil { + tracks = append(tracks, st) + } + } + return tracks +} + +func (m *SubscriptionManager) IsTrackNameSubscribed(publisherIdentity livekit.ParticipantIdentity, trackName string) bool { + m.lock.RLock() + defer m.lock.RUnlock() + + for _, s := range m.subscriptions { + st := s.getSubscribedTrack() + if st != nil && st.PublisherIdentity() == publisherIdentity && st.MediaTrack() != nil && st.MediaTrack().Name() == trackName { + return true + } + } + return false +} + +func (m *SubscriptionManager) StopAndGetSubscribedTracksForwarderState() map[livekit.TrackID]*livekit.RTPForwarderState { + m.lock.RLock() + defer m.lock.RUnlock() + + states := make(map[livekit.TrackID]*livekit.RTPForwarderState, len(m.subscriptions)) + for trackID, t := range m.subscriptions { + st := t.getSubscribedTrack() + if st != nil { + dt := st.DownTrack() + if dt != nil { + state := dt.StopWriteAndGetState() + if state.ForwarderState != nil { + states[trackID] = state.ForwarderState + } + } + } + } + return states +} + +func (m *SubscriptionManager) HasSubscriptions() bool { + m.lock.RLock() + defer m.lock.RUnlock() + for _, s := range m.subscriptions { + if s.isDesired() { + return true + } + } + return false +} + +func (m *SubscriptionManager) GetSubscribedParticipants() []livekit.ParticipantID { + m.lock.RLock() + defer m.lock.RUnlock() + + return maps.Keys(m.subscribedTo) +} + +func (m *SubscriptionManager) IsSubscribedTo(participantID livekit.ParticipantID) bool { + m.lock.RLock() + defer m.lock.RUnlock() + + _, ok := m.subscribedTo[participantID] + return ok +} + +func (m *SubscriptionManager) UpdateSubscribedTrackSettings(trackID livekit.TrackID, settings *livekit.UpdateTrackSettings) { + m.lock.Lock() + sub, ok := m.subscriptions[trackID] + if !ok { + sLogger := m.params.Logger.WithValues( + "trackID", trackID, + ) + sub = newMediaTrackSubscription(m.params.Participant.ID(), trackID, sLogger) + m.subscriptions[trackID] = sub + } + m.lock.Unlock() + + sub.setSettings(settings) +} + +func (m *SubscriptionManager) UpdateDataTrackSubscriptionOptions(trackID livekit.TrackID, subscriptionOptions *livekit.DataTrackSubscriptionOptions) { + m.lock.Lock() + sub, ok := m.dataTrackSubscriptions[trackID] + if !ok { + sLogger := m.params.Logger.WithValues( + "trackID", trackID, + ) + sub = newDataTrackSubscription(m.params.Participant.ID(), trackID, sLogger) + m.dataTrackSubscriptions[trackID] = sub + } + m.lock.Unlock() + + sub.setSubscriptionOptions(subscriptionOptions) +} + +// OnSubscribeStatusChanged callback will be notified when a participant subscribes or unsubscribes to another participant +// it will only fire once per publisher. If current participant is subscribed to multiple tracks from another, this +// callback will only fire once. +func (m *SubscriptionManager) OnSubscribeStatusChanged(fn func(publisherID livekit.ParticipantID, subscribed bool)) { + m.lock.Lock() + m.onSubscribeStatusChanged = fn + m.lock.Unlock() +} + +func (m *SubscriptionManager) WaitUntilSubscribed(timeout time.Duration) error { + expiresAt := time.Now().Add(timeout) + for expiresAt.After(time.Now()) { + allSubscribed := true + m.lock.RLock() + for _, sub := range m.subscriptions { + if sub.needsSubscribe() { + allSubscribed = false + break + } + } + m.lock.RUnlock() + if allSubscribed { + return nil + } + time.Sleep(50 * time.Millisecond) + } + + return context.DeadlineExceeded +} + +func (m *SubscriptionManager) ReconcileAll() { + m.queueReconcile(trackIDForReconcileSubscriptions) + m.queueReconcileDataTrack(trackIDForReconcileSubscriptions) +} + +func (m *SubscriptionManager) setDesired(trackID livekit.TrackID, desired bool) (*mediaTrackSubscription, bool) { + m.lock.RLock() + defer m.lock.RUnlock() + + sub, ok := m.subscriptions[trackID] + if !ok { + return nil, false + } + + return sub, sub.setDesired(desired) +} + +func (m *SubscriptionManager) setDataTrackDesired(trackID livekit.TrackID, desired bool) (*dataTrackSubscription, bool) { + m.lock.RLock() + defer m.lock.RUnlock() + + sub, ok := m.dataTrackSubscriptions[trackID] + if !ok { + return nil, false + } + + return sub, sub.setDesired(desired) +} + +func (m *SubscriptionManager) canReconcile() bool { + p := m.params.Participant + if m.isClosed() || p.IsClosed() || p.IsDisconnected() { + return false + } + return true +} + +func (m *SubscriptionManager) reconcileSubscriptions() { + var needsToReconcile []*mediaTrackSubscription + m.lock.RLock() + for _, sub := range m.subscriptions { + if sub.needsSubscribe() || sub.needsUnsubscribe() || sub.needsBind() || sub.needsCleanup() { + needsToReconcile = append(needsToReconcile, sub) + } + } + m.lock.RUnlock() + + for _, s := range needsToReconcile { + m.reconcileSubscription(s) + } +} + +func (m *SubscriptionManager) reconcileSubscription(s *mediaTrackSubscription) { + if !m.canReconcile() { + return + } + if s.needsSubscribe() { + if m.pendingUnsubscribes.Load() != 0 && s.durationSinceStart() < maxUnsubscribeWait { + // enqueue this in a bit, after pending unsubscribes are complete + go func() { + time.Sleep(time.Duration(sfu.RTPBlankFramesCloseSeconds * float32(time.Second))) + m.queueReconcile(s.trackID) + }() + return + } + + numAttempts := s.getNumAttempts() + if numAttempts == 0 { + m.params.Telemetry.TrackSubscribeRequested( + context.Background(), + s.subscriberID, + &livekit.TrackInfo{ + Sid: string(s.trackID), + }, + ) + } + if err := m.subscribe(s); err != nil { + s.recordAttempt(false) + + switch err { + case ErrNoTrackPermission, ErrNoSubscribePermission, ErrNoReceiver, ErrNotOpen, ErrSubscriptionLimitExceeded: + // these are errors that are outside of our control, so we'll keep trying + // - ErrNoTrackPermission: publisher did not grant subscriber permission, may change any moment + // - ErrNoSubscribePermission: participant was not granted canSubscribe, may change any moment + // - ErrNoReceiver: Track is in the process of closing (another local track published to the same instance) + // - ErrNotOpen: Track is closing or already closed + // - ErrSubscriptionLimitExceeded: the participant have reached the limit of subscriptions, wait for the other subscription to be unsubscribed + // We'll still log an event to reflect this in telemetry since it's been too long + if s.durationSinceStart() > subscriptionTimeout { + s.maybeRecordError(m.params.Telemetry, s.subscriberID, err, true) + } + case ErrTrackNotFound: + // source track was never published or closed + // if after timeout we'd unsubscribe from it. + // this is the *only* case we'd change desired state + if s.durationSinceStart() > notFoundTimeout { + s.maybeRecordError(m.params.Telemetry, s.subscriberID, err, true) + s.logger.Infow("unsubscribing from track after notFoundTimeout", "error", err) + s.setDesired(false) + m.queueReconcile(s.trackID) + m.params.OnSubscriptionError(s.trackID, false, err) + } + default: + // all other errors + if s.durationSinceStart() > subscriptionTimeout { + s.logger.Warnw( + "failed to subscribe, triggering error handler", err, + "attempt", s.getNumAttempts(), + ) + s.maybeRecordError(m.params.Telemetry, s.subscriberID, err, false) + m.params.OnSubscriptionError(s.trackID, true, err) + } else { + s.logger.Debugw( + "failed to subscribe, retrying", + "error", err, + "attempt", s.getNumAttempts(), + ) + } + } + } else { + s.recordAttempt(true) + } + + return + } + + if s.needsUnsubscribe() { + if err := m.unsubscribe(s); err != nil { + s.logger.Warnw("failed to unsubscribe", err) + } + // do not remove subscription from map. Wait for subscribed track to close + // and the callback (handleSubscribedTrackClose) to set the subscribedTrack to nil + // and the clean up path to handle removing subscription from the subscription map. + // It is possible that the track is re-published before subscribed track is closed. + // That could create a new subscription and a duplicate entry in SDP. + // Waiting for subscribed track close would ensure that the track is removed from + // the peer connection before re-published track is re-subscribed and added back to the SDP. + return + } + + if s.needsBind() { + // check bound status, notify error callback if it's not bound + // if a publisher leaves or closes the source track, SubscribedTrack will be closed as well and it will go + // back to needsSubscribe state + if activeAt := m.params.Participant.ActiveAt(); !activeAt.IsZero() { + wait := min(time.Since(activeAt), s.durationSinceStart()) + if wait > subscriptionTimeout { + s.logger.Warnw("track not bound after timeout", nil) + s.maybeRecordError(m.params.Telemetry, s.subscriberID, ErrTrackNotBound, false) + m.params.OnSubscriptionError(s.trackID, true, ErrTrackNotBound) + } + } + } + + m.lock.Lock() + if s.needsCleanup() { + s.logger.Debugw("cleanup removing subscription") + delete(m.subscriptions, s.trackID) + } + m.lock.Unlock() +} + +func (m *SubscriptionManager) reconcileDataTrackSubscriptions() { + var needsToReconcile []*dataTrackSubscription + m.lock.RLock() + for _, sub := range m.dataTrackSubscriptions { + if sub.needsSubscribe() || sub.needsUnsubscribe() { + needsToReconcile = append(needsToReconcile, sub) + } + } + m.lock.RUnlock() + + for _, s := range needsToReconcile { + m.reconcileDataTrackSubscription(s) + } +} + +func (m *SubscriptionManager) reconcileDataTrackSubscription(s *dataTrackSubscription) { + if !m.canReconcile() { + return + } + if s.needsSubscribe() { + if err := m.subscribeDataTrack(s); err != nil { + s.recordAttempt(false) + + switch err { + case ErrNoSubscribePermission: + // these are errors that are outside of our control, so we'll keep trying + // - ErrNoSubscribePermission: participant was not granted canSubscribe, may change any moment + case ErrTrackNotFound: + // source track was never published or closed + // if after timeout we'd unsubscribe from it. + // this is the *only* case we'd change desired state + if s.durationSinceStart() > notFoundTimeout { + s.logger.Infow("unsubscribing from data track after notFoundTimeout", "error", err) + s.setDesired(false) + m.queueReconcile(s.trackID) + } + default: + // all other errors + if s.durationSinceStart() > subscriptionTimeout { + s.logger.Warnw( + "failed to subscribe, triggering error handler", err, + "attempt", s.getNumAttempts(), + ) + } else { + s.logger.Debugw( + "failed to subscribe, retrying", + "error", err, + "attempt", s.getNumAttempts(), + ) + } + } + } else { + s.recordAttempt(true) + m.notifyDataTrackSubscriberHandles() + } + + return + } + + if s.needsUnsubscribe() { + if err := m.unsubscribeDataTrack(s); err != nil { + s.logger.Warnw("failed to unsubscribe", err) + } + + m.lock.Lock() + delete(m.dataTrackSubscriptions, s.trackID) + m.lock.Unlock() + m.notifyDataTrackSubscriberHandles() + return + } + + m.lock.Lock() + if s.needsCleanup() { + s.logger.Debugw("cleanup removing data track subscription") + delete(m.dataTrackSubscriptions, s.trackID) + m.notifyDataTrackSubscriberHandles() + } + m.lock.Unlock() +} + +// trigger an immediate reconciliation, when trackID is empty, will reconcile all subscriptions +func (m *SubscriptionManager) queueReconcile(trackID livekit.TrackID) { + select { + case m.reconcileCh <- trackID: + default: + // queue is full, will reconcile based on timer + } +} + +func (m *SubscriptionManager) queueReconcileDataTrack(trackID livekit.TrackID) { + select { + case m.reconcileDataTrackCh <- trackID: + default: + // queue is full, will reconcile based on timer + } +} + +func (m *SubscriptionManager) reconcileWorker() { + defer close(m.doneCh) + + reconcileTicker := time.NewTicker(reconcileInterval) + defer reconcileTicker.Stop() + + for { + select { + case <-m.closeCh: + return + + case <-reconcileTicker.C: + m.reconcileSubscriptions() + m.reconcileDataTrackSubscriptions() + + case trackID := <-m.reconcileCh: + m.lock.Lock() + s := m.subscriptions[trackID] + m.lock.Unlock() + if s != nil { + m.reconcileSubscription(s) + } else { + m.reconcileSubscriptions() + } + + case trackID := <-m.reconcileDataTrackCh: + m.lock.Lock() + s := m.dataTrackSubscriptions[trackID] + m.lock.Unlock() + if s != nil { + m.reconcileDataTrackSubscription(s) + } else { + m.reconcileDataTrackSubscriptions() + } + } + } +} + +func (m *SubscriptionManager) hasCapacityForSubscription(kind livekit.TrackType) bool { + switch kind { + case livekit.TrackType_VIDEO: + if m.params.SubscriptionLimitVideo > 0 && m.subscribedVideoCount.Load() >= m.params.SubscriptionLimitVideo { + return false + } + + case livekit.TrackType_AUDIO: + if m.params.SubscriptionLimitAudio > 0 && m.subscribedAudioCount.Load() >= m.params.SubscriptionLimitAudio { + return false + } + } + return true +} + +func (m *SubscriptionManager) subscribe(sub *mediaTrackSubscription) error { + sub.logger.Debugw("executing subscribe") + + if !m.params.Participant.CanSubscribe() { + return ErrNoSubscribePermission + } + + if kind, ok := sub.getKind(); ok && !m.hasCapacityForSubscription(kind) { + return ErrSubscriptionLimitExceeded + } + + trackID := sub.trackID + res := m.params.TrackResolver(m.params.Participant, trackID) + sub.logger.Debugw("resolved track", "result", res) + + if res.TrackChangedNotifier != nil && sub.setChangedNotifier(res.TrackChangedNotifier) { + // set callback only when we haven't done it before + // we set the observer before checking for existence of track, so that we may get notified + // when the track becomes available + res.TrackChangedNotifier.AddObserver(string(sub.subscriberID), func() { + m.queueReconcile(trackID) + }) + } + if res.TrackRemovedNotifier != nil && sub.setRemovedNotifier(res.TrackRemovedNotifier) { + res.TrackRemovedNotifier.AddObserver(string(sub.subscriberID), func() { + // re-resolve the track in case the same track had been re-published + res := m.params.TrackResolver(m.params.Participant, trackID) + if res.Track != nil { + // do not unsubscribe, track is still available + return + } + m.handleSourceTrackRemoved(trackID) + }) + } + + track := res.Track + if track == nil { + return ErrTrackNotFound + } + sub.trySetKind(track.Kind()) + if !m.hasCapacityForSubscription(track.Kind()) { + return ErrSubscriptionLimitExceeded + } + + sub.setPublisher(res.PublisherIdentity, res.PublisherID) + + permChanged := sub.setHasPermission(res.HasPermission) + if permChanged { + m.params.Participant.SendSubscriptionPermissionUpdate(sub.getPublisherID(), trackID, res.HasPermission) + } + if !res.HasPermission { + return ErrNoTrackPermission + } + + if err := m.addSubscriber(sub, track, false); err != nil { + return err + } + + m.markSubscribedTo(sub.getPublisherID(), trackID) + return nil +} + +func (m *SubscriptionManager) subscribeSynchronous(trackID livekit.TrackID) error { + m.params.Logger.Debugw("executing subscribe synchronous", "trackID", trackID) + + if !m.params.Participant.CanSubscribe() { + return ErrNoSubscribePermission + } + + res := m.params.TrackResolver(m.params.Participant, trackID) + m.params.Logger.Debugw("resolved track", "trackID", trackID, " result", res) + + track := res.Track + if track == nil { + return ErrTrackNotFound + } + + m.lock.Lock() + sub, ok := m.subscriptions[trackID] + if !ok { + sLogger := m.params.Logger.WithValues( + "trackID", trackID, + ) + sub = newMediaTrackSubscription(m.params.Participant.ID(), trackID, sLogger) + m.subscriptions[trackID] = sub + } + m.lock.Unlock() + sub.setDesired(true) + + return m.addSubscriber(sub, track, true) +} + +func (m *SubscriptionManager) addSubscriber(sub *mediaTrackSubscription, track types.MediaTrack, isSync bool) error { + trackID := track.ID() + subTrack, err := track.AddSubscriber(m.params.Participant) + if err != nil && !errors.Is(err, errAlreadySubscribed) { + // ignore error(s): already subscribed + if !utils.ErrorIsOneOf(err, ErrNoReceiver) { + // as track resolution could take some time, not logging errors due to waiting for track resolution + m.params.Logger.Warnw("add subscriber failed", err, "trackID", trackID) + } + return err + } + if err == errAlreadySubscribed { + sub.logger.Debugw( + "already subscribed to track", + "subscribedAudioCount", m.subscribedAudioCount.Load(), + "subscribedVideoCount", m.subscribedVideoCount.Load(), + ) + } + if err == nil && subTrack != nil { // subTrack could be nil if already subscribed + subTrack.OnClose(func(isExpectedToResume bool) { + m.handleSubscribedTrackClose(sub, isExpectedToResume) + + if isSync { + m.lock.Lock() + delete(m.subscriptions, trackID) + m.lock.Unlock() + } + }) + subTrack.AddOnBind(func(err error) { + if err != nil { + sub.logger.Infow("failed to bind track", "err", err) + sub.maybeRecordError(m.params.Telemetry, sub.subscriberID, err, true) + m.UnsubscribeFromTrack(trackID) + m.params.OnSubscriptionError(trackID, false, err) + return + } + sub.setBound() + sub.maybeRecordSuccess(m.params.Telemetry, sub.subscriberID) + }) + sub.setSubscribedTrack(subTrack) + + switch track.Kind() { + case livekit.TrackType_VIDEO: + m.subscribedVideoCount.Inc() + case livekit.TrackType_AUDIO: + m.subscribedAudioCount.Inc() + } + + if !isSync && subTrack.NeedsNegotiation() { + m.params.Participant.Negotiate(false) + } + + go m.params.OnTrackSubscribed(subTrack) + + sub.logger.Debugw( + "subscribed to track", + "subscribedAudioCount", m.subscribedAudioCount.Load(), + "subscribedVideoCount", m.subscribedVideoCount.Load(), + ) + } + return nil +} + +func (m *SubscriptionManager) unsubscribe(s *mediaTrackSubscription) error { + s.logger.Debugw("executing unsubscribe") + + subTrack := s.getSubscribedTrack() + if subTrack == nil { + // already unsubscribed + return nil + } + + track := subTrack.MediaTrack() + pID := s.subscriberID + m.pendingUnsubscribes.Inc() + go func() { + defer m.pendingUnsubscribes.Dec() + track.RemoveSubscriber(pID, false) + }() + + return nil +} + +func (m *SubscriptionManager) unsubscribeSynchronous(trackID livekit.TrackID) error { + m.lock.Lock() + sub := m.subscriptions[trackID] + delete(m.subscriptions, trackID) + m.lock.Unlock() + if sub == nil { + // already unsubscribed or not subscribed + return nil + } + + sub.logger.Debugw("executing unsubscribe synchronous") + + if sub.isCanceled() { + prometheus.RecordTrackSubscribeCancels(1) + } + + subTrack := sub.getSubscribedTrack() + if subTrack == nil { + // already unsubscribed + return nil + } + + track := subTrack.MediaTrack() + track.RemoveSubscriber(sub.subscriberID, false) + return nil +} + +func (m *SubscriptionManager) handleSourceTrackRemoved(trackID livekit.TrackID) { + m.lock.Lock() + sub := m.subscriptions[trackID] + m.lock.Unlock() + + if sub != nil { + if sub.isCanceled() { + prometheus.RecordTrackSubscribeCancels(1) + } + sub.handleSourceTrackRemoved() + } +} + +// DownTrack closing is how the publisher signifies that the subscription is no longer fulfilled +// this could be due to a few reasons: +// - subscriber-initiated unsubscribe +// - UpTrack was closed +// - publisher revoked permissions for the participant +func (m *SubscriptionManager) handleSubscribedTrackClose(s *mediaTrackSubscription, isExpectedToResume bool) { + s.logger.Debugw( + "subscribed track closed", + "isExpectedToResume", isExpectedToResume, + ) + wasBound := s.isBound() + subTrack := s.getSubscribedTrack() + if subTrack == nil { + return + } + s.setSubscribedTrack(nil) + + var relieveFromLimits bool + switch subTrack.MediaTrack().Kind() { + case livekit.TrackType_VIDEO: + videoCount := m.subscribedVideoCount.Dec() + relieveFromLimits = m.params.SubscriptionLimitVideo > 0 && videoCount == m.params.SubscriptionLimitVideo-1 + case livekit.TrackType_AUDIO: + audioCount := m.subscribedAudioCount.Dec() + relieveFromLimits = m.params.SubscriptionLimitAudio > 0 && audioCount == m.params.SubscriptionLimitAudio-1 + } + + m.unmarkSubscribedTo(s.getPublisherID(), s.trackID) + + go m.params.OnTrackUnsubscribed(subTrack) + + // trigger to decrement unsubscribed counter as long as track has been bound + // Only log an analytics event when + // * the participant isn't closing + // * it's not a migration + if wasBound { + m.params.Telemetry.TrackUnsubscribed( + context.Background(), + s.subscriberID, + &livekit.TrackInfo{Sid: string(s.trackID), Type: subTrack.MediaTrack().Kind()}, + !isExpectedToResume, + ) + + dt := subTrack.DownTrack() + if dt != nil { + stats := dt.GetTrackStats() + if stats != nil { + m.params.Telemetry.TrackSubscribeRTPStats( + context.Background(), + s.subscriberID, + s.trackID, + dt.Mime(), + stats, + ) + } + } + } + + if !isExpectedToResume { + sender := subTrack.RTPSender() + if sender != nil { + s.logger.Debugw("removing PeerConnection track", + "kind", subTrack.MediaTrack().Kind(), + ) + + if err := m.params.Participant.RemoveTrackLocal(sender); err != nil { + if _, ok := err.(*rtcerr.InvalidStateError); !ok { + // most of these are safe to ignore, since the track state might have already + // been set to Inactive + m.params.Logger.Debugw("could not remove remoteTrack from forwarder", + "error", err, + "publisher", subTrack.PublisherIdentity(), + "publisherID", subTrack.PublisherID(), + ) + } + } + } + + m.params.Participant.Negotiate(false) + } else { + t := time.Now() + s.subscribeAt.Store(&t) + } + if !m.params.UseOneShotSignallingMode { + if relieveFromLimits { + m.queueReconcile(trackIDForReconcileSubscriptions) + } else { + m.queueReconcile(s.trackID) + } + } +} + +func (m *SubscriptionManager) subscribeDataTrack(sub *dataTrackSubscription) error { + sub.logger.Debugw("executing subscribe") + + if !m.params.Participant.CanSubscribe() { + return ErrNoSubscribePermission + } + + trackID := sub.trackID + res := m.params.DataTrackResolver(m.params.Participant, trackID) + sub.logger.Debugw("resolved data track", "result", res) + + if res.TrackChangedNotifier != nil && sub.setChangedNotifier(res.TrackChangedNotifier) { + // set callback only when we haven't done it before + // we set the observer before checking for existence of track, so that we may get notified + // when the track becomes available + res.TrackChangedNotifier.AddObserver(string(sub.subscriberID), func() { + m.queueReconcileDataTrack(trackID) + }) + } + if res.TrackRemovedNotifier != nil && sub.setRemovedNotifier(res.TrackRemovedNotifier) { + res.TrackRemovedNotifier.AddObserver(string(sub.subscriberID), func() { + // re-resolve the track in case the same track had been re-published + res := m.params.DataTrackResolver(m.params.Participant, trackID) + if res.DataTrack != nil { + // do not unsubscribe, track is still available + return + } + m.handleSourceDataTrackRemoved(trackID) + }) + } + + dataTrack := res.DataTrack + if dataTrack == nil { + return ErrTrackNotFound + } + + sub.setPublisher(res.PublisherIdentity, res.PublisherID) + + dataDownTrack, err := dataTrack.AddSubscriber(m.params.Participant) + if err != nil && !errors.Is(err, errAlreadySubscribed) { + return err + } + if err == errAlreadySubscribed { + sub.logger.Debugw("already subscribed to data track") + } + if err == nil && dataDownTrack != nil { // subTrack could be nil if already subscribed + sub.setDataDownTrack(dataDownTrack) + sub.logger.Debugw("subscribed to data track") + } + + m.markSubscribedTo(sub.getPublisherID(), trackID) + return nil +} + +func (m *SubscriptionManager) unsubscribeDataTrack(s *dataTrackSubscription) error { + s.logger.Debugw("executing unsubscribe") + + dataDownTrack := s.getDataDownTrack() + if dataDownTrack == nil { + // already unsubscribed + return nil + } + + dataTrack := dataDownTrack.PublishDataTrack() + dataTrack.RemoveSubscriber(s.subscriberID) + + m.unmarkSubscribedTo(s.getPublisherID(), s.trackID) + return nil +} + +func (m *SubscriptionManager) notifyDataTrackSubscriberHandles() { + m.lock.Lock() + handles := make(map[uint32]*livekit.DataTrackSubscriberHandles_PublishedDataTrack, len(m.dataTrackSubscriptions)) + for _, sub := range m.dataTrackSubscriptions { + if !sub.isDesired() { + continue + } + dataDownTrack := sub.getDataDownTrack() + if dataDownTrack == nil { + continue + } + handles[uint32(dataDownTrack.Handle())] = &livekit.DataTrackSubscriberHandles_PublishedDataTrack{ + PublisherIdentity: string(sub.publisherIdentity), + PublisherSid: string(sub.publisherID), + TrackSid: string(sub.trackID), + } + } + m.lock.Unlock() + + m.params.Participant.SendDataTrackSubscriberHandles(handles) +} + +func (m *SubscriptionManager) handleSourceDataTrackRemoved(trackID livekit.TrackID) { + m.lock.Lock() + sub := m.dataTrackSubscriptions[trackID] + m.lock.Unlock() + + if sub != nil { + sub.handleSourceTrackRemoved() + } +} + +func (m *SubscriptionManager) markSubscribedTo(publisherID livekit.ParticipantID, trackID livekit.TrackID) { + // add mark the participant as someone we've subscribed to + firstSubscribe := false + m.lock.Lock() + pTracks := m.subscribedTo[publisherID] + changedCB := m.onSubscribeStatusChanged + if pTracks == nil { + pTracks = make(map[livekit.TrackID]struct{}) + m.subscribedTo[publisherID] = pTracks + firstSubscribe = true + } + pTracks[trackID] = struct{}{} + m.lock.Unlock() + + if changedCB != nil && firstSubscribe { + changedCB(publisherID, true) + } +} + +func (m *SubscriptionManager) unmarkSubscribedTo(publisherID livekit.ParticipantID, trackID livekit.TrackID) { + // remove from subscribedTo + lastSubscription := false + m.lock.Lock() + changedCB := m.onSubscribeStatusChanged + pTracks := m.subscribedTo[publisherID] + if pTracks != nil { + delete(pTracks, trackID) + if len(pTracks) == 0 { + delete(m.subscribedTo, publisherID) + lastSubscription = true + } + } + m.lock.Unlock() + if changedCB != nil && lastSubscription { + go changedCB(publisherID, false) + } +} + +func (m *SubscriptionManager) getNumCancellations() int { + m.lock.RLock() + defer m.lock.RUnlock() + + numCancellations := 0 + for _, sub := range m.subscriptions { + if sub.isCanceled() { + numCancellations++ + } + } + return numCancellations +} + +// -------------------------------------------------------------------------------------- + +type trackSubscription struct { + subscriberID livekit.ParticipantID + trackID livekit.TrackID + logger logger.Logger + + lock sync.RWMutex + desired bool + publisherID livekit.ParticipantID + publisherIdentity livekit.ParticipantIdentity + changedNotifier types.ChangeNotifier + removedNotifier types.ChangeNotifier + + numAttempts atomic.Int32 + + // the later of when subscription was requested OR when the first failure was encountered OR when permission is granted + // this timestamp determines when failures are reported + subStartedAt atomic.Pointer[time.Time] + + // the timestamp when the subscription was started, will be reset when downtrack is closed with expected resume + subscribeAt atomic.Pointer[time.Time] +} + +func (s *trackSubscription) setPublisher(publisherIdentity livekit.ParticipantIdentity, publisherID livekit.ParticipantID) { + s.lock.Lock() + defer s.lock.Unlock() + + s.publisherID = publisherID + s.publisherIdentity = publisherIdentity +} + +func (s *trackSubscription) getPublisherID() livekit.ParticipantID { + s.lock.RLock() + defer s.lock.RUnlock() + return s.publisherID +} + +func (s *trackSubscription) setDesired(desired bool) bool { + s.lock.Lock() + defer s.lock.Unlock() + + if desired { + // as long as user explicitly set it to desired + // we'll reset the timer so it has sufficient time to reconcile + t := time.Now() + s.subStartedAt.Store(&t) + s.subscribeAt.Store(&t) + } + + if s.desired == desired { + return false + } + s.desired = desired + + // when no longer desired, we no longer care about change notifications + if desired { + // reset attempts + s.numAttempts.Store(0) + } else { + s.setChangedNotifierLocked(nil) + s.setRemovedNotifierLocked(nil) + } + return true +} + +func (s *trackSubscription) isDesired() bool { + s.lock.RLock() + defer s.lock.RUnlock() + return s.desired +} + +func (s *trackSubscription) recordAttempt(success bool) { + if !success { + if s.numAttempts.Load() == 0 { + // on first failure, we'd want to start the timer + t := time.Now() + s.subStartedAt.Store(&t) + } + s.numAttempts.Add(1) + } else { + s.numAttempts.Store(0) + } +} + +func (s *trackSubscription) getNumAttempts() int32 { + return s.numAttempts.Load() +} + +func (s *trackSubscription) durationSinceStart() time.Duration { + t := s.subStartedAt.Load() + if t == nil { + return 0 + } + return time.Since(*t) +} + +func (s *trackSubscription) setChangedNotifier(notifier types.ChangeNotifier) bool { + s.lock.Lock() + defer s.lock.Unlock() + return s.setChangedNotifierLocked(notifier) +} + +func (s *trackSubscription) setChangedNotifierLocked(notifier types.ChangeNotifier) bool { + if s.changedNotifier == notifier { + return false + } + + existing := s.changedNotifier + s.changedNotifier = notifier + + if existing != nil { + go existing.RemoveObserver(string(s.subscriberID)) + } + return true +} + +func (s *trackSubscription) setRemovedNotifier(notifier types.ChangeNotifier) bool { + s.lock.Lock() + defer s.lock.Unlock() + return s.setRemovedNotifierLocked(notifier) +} + +func (s *trackSubscription) setRemovedNotifierLocked(notifier types.ChangeNotifier) bool { + if s.removedNotifier == notifier { + return false + } + + existing := s.removedNotifier + s.removedNotifier = notifier + + if existing != nil { + go existing.RemoveObserver(string(s.subscriberID)) + } + return true +} + +func (s *trackSubscription) handleSourceTrackRemoved() { + s.lock.Lock() + defer s.lock.Unlock() + + // source track removed, we would unsubscribe + s.logger.Debugw("unsubscribing from track since source track was removed") + s.desired = false + + s.setChangedNotifierLocked(nil) + s.setRemovedNotifierLocked(nil) +} + +// -------------------------------------------------------------------------------------- + +type mediaTrackSubscription struct { + trackSubscription + + settings *livekit.UpdateTrackSettings + hasPermissionInitialized bool + hasPermission bool + subscribedTrack types.SubscribedTrack + eventSent atomic.Bool + bound bool + kind atomic.Pointer[livekit.TrackType] + + succRecordCounter atomic.Int32 +} + +func newMediaTrackSubscription(subscriberID livekit.ParticipantID, trackID livekit.TrackID, l logger.Logger) *mediaTrackSubscription { + s := &mediaTrackSubscription{ + trackSubscription: trackSubscription{ + subscriberID: subscriberID, + trackID: trackID, + logger: l, + }, + } + t := time.Now() + s.subscribeAt.Store(&t) + return s +} + +// set permission and return true if it has changed +func (s *mediaTrackSubscription) setHasPermission(perm bool) bool { + s.lock.Lock() + defer s.lock.Unlock() + if s.hasPermissionInitialized && s.hasPermission == perm { + return false + } + + s.hasPermissionInitialized = true + s.hasPermission = perm + if s.hasPermission { + // when permission is granted, reset the timer so it has sufficient time to reconcile + t := time.Now() + s.subStartedAt.Store(&t) + s.subscribeAt.Store(&t) + } + return true +} + +func (s *mediaTrackSubscription) getHasPermission() bool { + s.lock.RLock() + defer s.lock.RUnlock() + return s.hasPermission +} + +func (s *mediaTrackSubscription) setSubscribedTrack(track types.SubscribedTrack) { + s.lock.Lock() + oldTrack := s.subscribedTrack + s.subscribedTrack = track + s.bound = false + settings := s.settings + s.lock.Unlock() + + if settings != nil && track != nil { + s.logger.Debugw("restoring subscriber settings", "settings", logger.Proto(settings)) + track.UpdateSubscriberSettings(settings, true) + } + if oldTrack != nil { + oldTrack.OnClose(nil) + } +} + +func (s *mediaTrackSubscription) getSubscribedTrack() types.SubscribedTrack { + s.lock.RLock() + defer s.lock.RUnlock() + return s.subscribedTrack +} + +func (s *mediaTrackSubscription) trySetKind(kind livekit.TrackType) { + s.kind.CompareAndSwap(nil, &kind) +} + +func (s *mediaTrackSubscription) getKind() (livekit.TrackType, bool) { + kind := s.kind.Load() + if kind == nil { + return livekit.TrackType_AUDIO, false + } + return *kind, true +} + +func (s *mediaTrackSubscription) setSettings(settings *livekit.UpdateTrackSettings) { + s.lock.Lock() + s.settings = settings + subTrack := s.subscribedTrack + s.lock.Unlock() + if subTrack != nil { + subTrack.UpdateSubscriberSettings(settings, false) + } +} + +// mark the subscription as bound - when we've received the client's answer +func (s *mediaTrackSubscription) setBound() { + s.lock.Lock() + defer s.lock.Unlock() + s.bound = true +} + +func (s *mediaTrackSubscription) isBound() bool { + s.lock.RLock() + defer s.lock.RUnlock() + return s.bound +} + +func (s *mediaTrackSubscription) maybeRecordError(ts telemetry.TelemetryService, pID livekit.ParticipantID, err error, isUserError bool) { + if s.eventSent.Swap(true) { + return + } + + ts.TrackSubscribeFailed(context.Background(), pID, s.trackID, err, isUserError) +} + +func (s *mediaTrackSubscription) maybeRecordSuccess(ts telemetry.TelemetryService, pID livekit.ParticipantID) { + subTrack := s.getSubscribedTrack() + if subTrack == nil { + return + } + mediaTrack := subTrack.MediaTrack() + if mediaTrack == nil { + return + } + + d := time.Since(*s.subscribeAt.Load()) + s.logger.Debugw("track subscribed", "cost", d.Milliseconds()) + subscriber := subTrack.Subscriber() + prometheus.RecordSubscribeTime( + subscriber.GetCountry(), + mediaTrack.Source(), + mediaTrack.Kind(), + d, + subscriber.GetClientInfo().GetSdk(), + subscriber.Kind(), + int(s.succRecordCounter.Inc()), + ) + + eventSent := s.eventSent.Swap(true) + + pi := &livekit.ParticipantInfo{ + Identity: string(subTrack.PublisherIdentity()), + Sid: string(subTrack.PublisherID()), + } + ts.TrackSubscribed(context.Background(), pID, mediaTrack.ToProto(), pi, !eventSent) +} + +func (s *mediaTrackSubscription) isCanceled() bool { + return !s.eventSent.Load() && s.isDesired() +} + +func (s *mediaTrackSubscription) needsSubscribe() bool { + s.lock.RLock() + defer s.lock.RUnlock() + return s.desired && s.subscribedTrack == nil +} + +func (s *mediaTrackSubscription) needsUnsubscribe() bool { + s.lock.RLock() + defer s.lock.RUnlock() + return !s.desired && s.subscribedTrack != nil +} + +func (s *mediaTrackSubscription) needsBind() bool { + s.lock.RLock() + defer s.lock.RUnlock() + return s.desired && s.subscribedTrack != nil && !s.bound +} + +func (s *mediaTrackSubscription) needsCleanup() bool { + s.lock.RLock() + defer s.lock.RUnlock() + return !s.desired && s.subscribedTrack == nil +} + +// ----------------------------------------------------------------- + +type dataTrackSubscription struct { + trackSubscription + + subscriptionOptions *livekit.DataTrackSubscriptionOptions + + dataDownTrack types.DataDownTrack +} + +func newDataTrackSubscription(subscriberID livekit.ParticipantID, trackID livekit.TrackID, l logger.Logger) *dataTrackSubscription { + s := &dataTrackSubscription{ + trackSubscription: trackSubscription{ + subscriberID: subscriberID, + trackID: trackID, + logger: l, + }, + } + t := time.Now() + s.subscribeAt.Store(&t) + return s +} + +func (s *dataTrackSubscription) needsSubscribe() bool { + s.lock.RLock() + defer s.lock.RUnlock() + return s.desired && s.dataDownTrack == nil +} + +func (s *dataTrackSubscription) needsUnsubscribe() bool { + s.lock.RLock() + defer s.lock.RUnlock() + return !s.desired && s.dataDownTrack != nil +} + +func (s *dataTrackSubscription) needsCleanup() bool { + s.lock.RLock() + defer s.lock.RUnlock() + return !s.desired && s.dataDownTrack == nil +} + +func (s *dataTrackSubscription) setDataDownTrack(dataDownTrack types.DataDownTrack) { + s.lock.Lock() + s.dataDownTrack = dataDownTrack + subscriptionOptions := s.subscriptionOptions + s.lock.Unlock() + + if dataDownTrack != nil { + s.logger.Debugw("restoring data track subscription options", "subscriptionOptions", logger.Proto(subscriptionOptions)) + dataDownTrack.UpdateSubscriptionOptions(subscriptionOptions) + } + + // DT-TODO - DataTrack close callback on previous if not nil?, see setSubscribedTrack for example +} + +func (s *dataTrackSubscription) getDataDownTrack() types.DataDownTrack { + s.lock.RLock() + defer s.lock.RUnlock() + return s.dataDownTrack +} + +func (s *dataTrackSubscription) setSubscriptionOptions(subscriptionOptions *livekit.DataTrackSubscriptionOptions) { + s.lock.Lock() + s.subscriptionOptions = subscriptionOptions + dataDownTrack := s.dataDownTrack + s.lock.Unlock() + if dataDownTrack != nil { + dataDownTrack.UpdateSubscriptionOptions(subscriptionOptions) + } +} diff --git a/livekit/pkg/rtc/subscriptionmanager_test.go b/livekit/pkg/rtc/subscriptionmanager_test.go new file mode 100644 index 0000000..af7fce8 --- /dev/null +++ b/livekit/pkg/rtc/subscriptionmanager_test.go @@ -0,0 +1,547 @@ +/* + * Copyright 2023 LiveKit, Inc + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package rtc + +import ( + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" + "go.uber.org/atomic" + + "github.com/livekit/livekit-server/pkg/rtc/types" + "github.com/livekit/livekit-server/pkg/rtc/types/typesfakes" + "github.com/livekit/livekit-server/pkg/telemetry/telemetryfakes" + "github.com/livekit/livekit-server/pkg/utils" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" +) + +func init() { + reconcileInterval = 50 * time.Millisecond + notFoundTimeout = 200 * time.Millisecond + subscriptionTimeout = 200 * time.Millisecond +} + +const ( + subSettleTimeout = 600 * time.Millisecond + subCheckInterval = 10 * time.Millisecond +) + +func TestSubscribe(t *testing.T) { + t.Run("happy path subscribe", func(t *testing.T) { + sm := newTestSubscriptionManager() + defer sm.Close(false) + resolver := newTestResolver(true, true, "pub", "pubID") + sm.params.TrackResolver = resolver.Resolve + subCount := atomic.Int32{} + failed := atomic.Bool{} + sm.params.OnTrackSubscribed = func(subTrack types.SubscribedTrack) { + subCount.Add(1) + } + sm.params.OnSubscriptionError = func(trackID livekit.TrackID, fatal bool, err error) { + failed.Store(true) + } + numParticipantSubscribed := atomic.Int32{} + numParticipantUnsubscribed := atomic.Int32{} + sm.OnSubscribeStatusChanged(func(pubID livekit.ParticipantID, subscribed bool) { + if subscribed { + numParticipantSubscribed.Add(1) + } else { + numParticipantUnsubscribed.Add(1) + } + }) + + sm.SubscribeToTrack("track", false) + s := sm.subscriptions["track"] + require.True(t, s.isDesired()) + require.Eventually(t, func() bool { + return subCount.Load() == 1 + }, subSettleTimeout, subCheckInterval, "track was not subscribed") + + require.NotNil(t, s.getSubscribedTrack()) + require.Len(t, sm.GetSubscribedTracks(), 1) + + require.Eventually(t, func() bool { + return len(sm.GetSubscribedParticipants()) == 1 + }, subSettleTimeout, subCheckInterval, "GetSubscribedParticipants should have returned one item") + require.Equal(t, "pubID", string(sm.GetSubscribedParticipants()[0])) + + // ensure telemetry events are sent + tm := sm.params.Telemetry.(*telemetryfakes.FakeTelemetryService) + require.Equal(t, 1, tm.TrackSubscribeRequestedCallCount()) + + // ensure bound + setTestSubscribedTrackBound(t, s.getSubscribedTrack()) + require.Eventually(t, func() bool { + return !s.needsBind() + }, subSettleTimeout, subCheckInterval, "track was not bound") + + // telemetry event should have been sent + require.Equal(t, 1, tm.TrackSubscribedCallCount()) + + time.Sleep(notFoundTimeout) + require.False(t, failed.Load()) + + resolver.SetPause(true) + // ensure its resilience after being closed + setTestSubscribedTrackClosed(t, s.getSubscribedTrack(), false) + require.Eventually(t, func() bool { + return s.needsSubscribe() + }, subSettleTimeout, subCheckInterval, "needs subscribe did not persist across track close") + resolver.SetPause(false) + + require.Eventually(t, func() bool { + return s.isDesired() && !s.needsSubscribe() + }, subSettleTimeout, subCheckInterval, "track was not resubscribed") + + // was subscribed twice, unsubscribed once (due to close) + require.Eventually(t, func() bool { + return numParticipantSubscribed.Load() == 2 + }, subSettleTimeout, subCheckInterval, "participant subscribe status was not updated twice") + require.Equal(t, int32(1), numParticipantUnsubscribed.Load()) + }) + + t.Run("no track permission", func(t *testing.T) { + sm := newTestSubscriptionManager() + defer sm.Close(false) + resolver := newTestResolver(false, true, "pub", "pubID") + sm.params.TrackResolver = resolver.Resolve + failed := atomic.Bool{} + sm.params.OnSubscriptionError = func(trackID livekit.TrackID, fatal bool, err error) { + failed.Store(true) + } + + sm.SubscribeToTrack("track", false) + s := sm.subscriptions["track"] + require.Eventually(t, func() bool { + return !s.getHasPermission() + }, subSettleTimeout, subCheckInterval, "should not have permission to subscribe") + + time.Sleep(subscriptionTimeout) + + // should not have called failed callbacks, isDesired remains unchanged + require.True(t, s.isDesired()) + require.False(t, failed.Load()) + require.True(t, s.needsSubscribe()) + require.Len(t, sm.GetSubscribedTracks(), 0) + + // trackSubscribed telemetry not sent + tm := sm.params.Telemetry.(*telemetryfakes.FakeTelemetryService) + require.Equal(t, 1, tm.TrackSubscribeRequestedCallCount()) + require.Equal(t, 0, tm.TrackSubscribedCallCount()) + + // give permissions now + resolver.lock.Lock() + resolver.hasPermission = true + resolver.lock.Unlock() + + require.Eventually(t, func() bool { + return !s.needsSubscribe() + }, subSettleTimeout, subCheckInterval, "should be subscribed") + + require.Len(t, sm.GetSubscribedTracks(), 1) + }) + + t.Run("publisher left", func(t *testing.T) { + sm := newTestSubscriptionManager() + defer sm.Close(false) + resolver := newTestResolver(true, true, "pub", "pubID") + sm.params.TrackResolver = resolver.Resolve + failed := atomic.Bool{} + sm.params.OnSubscriptionError = func(trackID livekit.TrackID, fatal bool, err error) { + failed.Store(true) + } + + sm.SubscribeToTrack("track", false) + s := sm.subscriptions["track"] + require.Eventually(t, func() bool { + return !s.needsSubscribe() + }, subSettleTimeout, subCheckInterval, "should be subscribed") + + resolver.lock.Lock() + resolver.hasTrack = false + resolver.lock.Unlock() + + // publisher triggers close + setTestSubscribedTrackClosed(t, s.getSubscribedTrack(), false) + + require.Eventually(t, func() bool { + return !s.isDesired() + }, subSettleTimeout, subCheckInterval, "isDesired not set to false") + }) +} + +func TestUnsubscribe(t *testing.T) { + sm := newTestSubscriptionManager() + defer sm.Close(false) + unsubCount := atomic.Int32{} + sm.params.OnTrackUnsubscribed = func(subTrack types.SubscribedTrack) { + unsubCount.Add(1) + } + + resolver := newTestResolver(true, true, "pub", "pubID") + + s := &mediaTrackSubscription{ + trackSubscription: trackSubscription{ + trackID: "track", + desired: true, + subscriberID: sm.params.Participant.ID(), + publisherID: "pubID", + publisherIdentity: "pub", + logger: logger.GetLogger(), + }, + hasPermission: true, + bound: true, + } + // a bunch of unfortunate manual wiring + res := resolver.Resolve(nil, s.trackID) + res.TrackChangedNotifier.AddObserver(string(sm.params.Participant.ID()), func() {}) + s.changedNotifier = res.TrackChangedNotifier + st, err := res.Track.AddSubscriber(sm.params.Participant) + require.NoError(t, err) + s.subscribedTrack = st + st.OnClose(func(isExpectedToResume bool) { + sm.handleSubscribedTrackClose(s, isExpectedToResume) + }) + res.Track.(*typesfakes.FakeMediaTrack).RemoveSubscriberCalls(func(pID livekit.ParticipantID, isExpectedToResume bool) { + setTestSubscribedTrackClosed(t, st, isExpectedToResume) + }) + + sm.lock.Lock() + sm.subscriptions["track"] = s + sm.lock.Unlock() + + require.False(t, s.needsSubscribe()) + require.False(t, s.needsUnsubscribe()) + + // unsubscribe + sm.UnsubscribeFromTrack("track") + require.False(t, s.isDesired()) + + require.Eventually(t, func() bool { + if s.needsUnsubscribe() { + return false + } + if sm.pendingUnsubscribes.Load() != 0 { + return false + } + sm.lock.RLock() + subLen := len(sm.subscriptions) + sm.lock.RUnlock() + if subLen != 0 { + return false + } + return true + }, subSettleTimeout, subCheckInterval, "Track was not unsubscribed") + + // no traces should be left + require.Len(t, sm.GetSubscribedTracks(), 0) + require.False(t, res.TrackChangedNotifier.HasObservers()) + + tm := sm.params.Telemetry.(*telemetryfakes.FakeTelemetryService) + require.Equal(t, 1, tm.TrackUnsubscribedCallCount()) +} + +func TestSubscribeStatusChanged(t *testing.T) { + sm := newTestSubscriptionManager() + defer sm.Close(false) + resolver := newTestResolver(true, true, "pub", "pubID") + sm.params.TrackResolver = resolver.Resolve + numParticipantSubscribed := atomic.Int32{} + numParticipantUnsubscribed := atomic.Int32{} + sm.OnSubscribeStatusChanged(func(pubID livekit.ParticipantID, subscribed bool) { + if subscribed { + numParticipantSubscribed.Add(1) + } else { + numParticipantUnsubscribed.Add(1) + } + }) + + sm.SubscribeToTrack("track1", false) + sm.SubscribeToTrack("track2", false) + s1 := sm.subscriptions["track1"] + s2 := sm.subscriptions["track2"] + require.Eventually(t, func() bool { + return !s1.needsSubscribe() && !s2.needsSubscribe() + }, subSettleTimeout, subCheckInterval, "track1 and track2 should be subscribed") + st1 := s1.getSubscribedTrack() + st1.OnClose(func(isExpectedToResume bool) { + sm.handleSubscribedTrackClose(s1, isExpectedToResume) + }) + st2 := s2.getSubscribedTrack() + st2.OnClose(func(isExpectedToResume bool) { + sm.handleSubscribedTrackClose(s2, isExpectedToResume) + }) + st1.MediaTrack().(*typesfakes.FakeMediaTrack).RemoveSubscriberCalls(func(pID livekit.ParticipantID, isExpectedToResume bool) { + setTestSubscribedTrackClosed(t, st1, isExpectedToResume) + }) + st2.MediaTrack().(*typesfakes.FakeMediaTrack).RemoveSubscriberCalls(func(pID livekit.ParticipantID, isExpectedToResume bool) { + setTestSubscribedTrackClosed(t, st2, isExpectedToResume) + }) + + require.Eventually(t, func() bool { + return numParticipantSubscribed.Load() == 1 + }, subSettleTimeout, subCheckInterval, "should be subscribed to publisher") + require.Equal(t, int32(0), numParticipantUnsubscribed.Load()) + require.True(t, sm.IsSubscribedTo("pubID")) + + // now unsubscribe track2, no event should be fired + sm.UnsubscribeFromTrack("track2") + require.Eventually(t, func() bool { + return !s2.needsUnsubscribe() + }, subSettleTimeout, subCheckInterval, "track2 should be unsubscribed") + require.Equal(t, int32(0), numParticipantUnsubscribed.Load()) + + // unsubscribe track1, expect event + sm.UnsubscribeFromTrack("track1") + require.Eventually(t, func() bool { + return !s1.needsUnsubscribe() + }, subSettleTimeout, subCheckInterval, "track1 should be unsubscribed") + require.Eventually(t, func() bool { + return numParticipantUnsubscribed.Load() == 1 + }, subSettleTimeout, subCheckInterval, "should be subscribed to publisher") + require.False(t, sm.IsSubscribedTo("pubID")) +} + +// clients may send update subscribed settings prior to subscription events coming through +// settings should be persisted and used when the subscription does take place. +func TestUpdateSettingsBeforeSubscription(t *testing.T) { + sm := newTestSubscriptionManager() + defer sm.Close(false) + resolver := newTestResolver(true, true, "pub", "pubID") + sm.params.TrackResolver = resolver.Resolve + + settings := &livekit.UpdateTrackSettings{ + Disabled: true, + Width: 100, + Height: 100, + } + sm.UpdateSubscribedTrackSettings("track", settings) + + sm.SubscribeToTrack("track", false) + + s := sm.subscriptions["track"] + require.Eventually(t, func() bool { + return !s.needsSubscribe() + }, subSettleTimeout, subCheckInterval, "Track should be subscribed") + + st := s.getSubscribedTrack().(*typesfakes.FakeSubscribedTrack) + require.Eventually(t, func() bool { + return st.UpdateSubscriberSettingsCallCount() == 1 + }, subSettleTimeout, subCheckInterval, "UpdateSubscriberSettings should be called once") + + applied, _ := st.UpdateSubscriberSettingsArgsForCall(0) + require.Equal(t, settings.Disabled, applied.Disabled) + require.Equal(t, settings.Width, applied.Width) + require.Equal(t, settings.Height, applied.Height) +} + +func TestSubscriptionLimits(t *testing.T) { + sm := newTestSubscriptionManagerWithParams(testSubscriptionParams{ + SubscriptionLimitAudio: 1, + SubscriptionLimitVideo: 1, + }) + defer sm.Close(false) + resolver := newTestResolver(true, true, "pub", "pubID") + sm.params.TrackResolver = resolver.Resolve + subCount := atomic.Int32{} + failed := atomic.Bool{} + sm.params.OnTrackSubscribed = func(subTrack types.SubscribedTrack) { + subCount.Add(1) + } + sm.params.OnSubscriptionError = func(trackID livekit.TrackID, fatal bool, err error) { + failed.Store(true) + } + numParticipantSubscribed := atomic.Int32{} + numParticipantUnsubscribed := atomic.Int32{} + sm.OnSubscribeStatusChanged(func(pubID livekit.ParticipantID, subscribed bool) { + if subscribed { + numParticipantSubscribed.Add(1) + } else { + numParticipantUnsubscribed.Add(1) + } + }) + + sm.SubscribeToTrack("track", false) + s := sm.subscriptions["track"] + require.True(t, s.isDesired()) + require.Eventually(t, func() bool { + return subCount.Load() == 1 + }, subSettleTimeout, subCheckInterval, "track was not subscribed") + + require.NotNil(t, s.getSubscribedTrack()) + require.Len(t, sm.GetSubscribedTracks(), 1) + + require.Eventually(t, func() bool { + return len(sm.GetSubscribedParticipants()) == 1 + }, subSettleTimeout, subCheckInterval, "GetSubscribedParticipants should have returned one item") + require.Equal(t, "pubID", string(sm.GetSubscribedParticipants()[0])) + + // ensure telemetry events are sent + tm := sm.params.Telemetry.(*telemetryfakes.FakeTelemetryService) + require.Equal(t, 1, tm.TrackSubscribeRequestedCallCount()) + + // ensure bound + setTestSubscribedTrackBound(t, s.getSubscribedTrack()) + require.Eventually(t, func() bool { + return !s.needsBind() + }, subSettleTimeout, subCheckInterval, "track was not bound") + + // telemetry event should have been sent + require.Equal(t, 1, tm.TrackSubscribedCallCount()) + + // reach subscription limit, subscribe pending + sm.SubscribeToTrack("track2", false) + s2 := sm.subscriptions["track2"] + time.Sleep(subscriptionTimeout * 2) + require.True(t, s2.needsSubscribe()) + require.Equal(t, 2, tm.TrackSubscribeRequestedCallCount()) + require.Equal(t, 1, tm.TrackSubscribeFailedCallCount()) + require.Len(t, sm.GetSubscribedTracks(), 1) + + // unsubscribe track1, then track2 should be subscribed + sm.UnsubscribeFromTrack("track") + require.False(t, s.isDesired()) + require.True(t, s.needsUnsubscribe()) + // wait for unsubscribe to take effect + time.Sleep(reconcileInterval) + setTestSubscribedTrackClosed(t, s.getSubscribedTrack(), false) + require.Nil(t, s.getSubscribedTrack()) + + time.Sleep(reconcileInterval) + require.True(t, s2.isDesired()) + require.False(t, s2.needsSubscribe()) + require.EqualValues(t, 2, subCount.Load()) + require.NotNil(t, s2.getSubscribedTrack()) + require.Equal(t, 2, tm.TrackSubscribeRequestedCallCount()) + require.Len(t, sm.GetSubscribedTracks(), 1) + + // ensure bound + setTestSubscribedTrackBound(t, s2.getSubscribedTrack()) + require.Eventually(t, func() bool { + return !s2.needsBind() + }, subSettleTimeout, subCheckInterval, "track was not bound") + + // subscribe to track1 again, which should pending + sm.SubscribeToTrack("track", false) + s = sm.subscriptions["track"] + require.True(t, s.isDesired()) + time.Sleep(subscriptionTimeout * 2) + require.True(t, s.needsSubscribe()) + require.Equal(t, 3, tm.TrackSubscribeRequestedCallCount()) + require.Equal(t, 2, tm.TrackSubscribeFailedCallCount()) + require.Len(t, sm.GetSubscribedTracks(), 1) +} + +type testSubscriptionParams struct { + SubscriptionLimitAudio int32 + SubscriptionLimitVideo int32 +} + +func newTestSubscriptionManager() *SubscriptionManager { + return newTestSubscriptionManagerWithParams(testSubscriptionParams{}) +} + +func newTestSubscriptionManagerWithParams(params testSubscriptionParams) *SubscriptionManager { + p := &typesfakes.FakeLocalParticipant{} + p.CanSubscribeReturns(true) + p.IDReturns("subID") + p.IdentityReturns("sub") + p.KindReturns(livekit.ParticipantInfo_STANDARD) + return NewSubscriptionManager(SubscriptionManagerParams{ + Participant: p, + Logger: logger.GetLogger(), + OnTrackSubscribed: func(subTrack types.SubscribedTrack) {}, + OnTrackUnsubscribed: func(subTrack types.SubscribedTrack) {}, + OnSubscriptionError: func(trackID livekit.TrackID, fatal bool, err error) {}, + TrackResolver: func(sub types.LocalParticipant, trackID livekit.TrackID) types.MediaResolverResult { + return types.MediaResolverResult{} + }, + Telemetry: &telemetryfakes.FakeTelemetryService{}, + SubscriptionLimitAudio: params.SubscriptionLimitAudio, + SubscriptionLimitVideo: params.SubscriptionLimitVideo, + }) +} + +type testResolver struct { + lock sync.Mutex + hasPermission bool + hasTrack bool + pubIdentity livekit.ParticipantIdentity + pubID livekit.ParticipantID + + paused bool +} + +func newTestResolver(hasPermission bool, hasTrack bool, pubIdentity livekit.ParticipantIdentity, pubID livekit.ParticipantID) *testResolver { + return &testResolver{ + hasPermission: hasPermission, + hasTrack: hasTrack, + pubIdentity: pubIdentity, + pubID: pubID, + } +} + +func (t *testResolver) SetPause(paused bool) { + t.lock.Lock() + defer t.lock.Unlock() + t.paused = paused +} + +func (t *testResolver) Resolve(_subscriber types.LocalParticipant, trackID livekit.TrackID) types.MediaResolverResult { + t.lock.Lock() + defer t.lock.Unlock() + res := types.MediaResolverResult{ + TrackChangedNotifier: utils.NewChangeNotifier(), + TrackRemovedNotifier: utils.NewChangeNotifier(), + HasPermission: t.hasPermission, + PublisherID: t.pubID, + PublisherIdentity: t.pubIdentity, + } + if t.hasTrack && !t.paused { + mt := &typesfakes.FakeMediaTrack{} + st := &typesfakes.FakeSubscribedTrack{} + st.IDReturns(trackID) + st.PublisherIDReturns(t.pubID) + st.PublisherIdentityReturns(t.pubIdentity) + mt.AddSubscriberCalls(func(sub types.LocalParticipant) (types.SubscribedTrack, error) { + st.SubscriberReturns(sub) + return st, nil + }) + st.MediaTrackReturns(mt) + res.Track = mt + } + return res +} + +func setTestSubscribedTrackBound(t *testing.T, st types.SubscribedTrack) { + fst, ok := st.(*typesfakes.FakeSubscribedTrack) + require.True(t, ok) + + for i := 0; i < fst.AddOnBindCallCount(); i++ { + fst.AddOnBindArgsForCall(i)(nil) + } +} + +func setTestSubscribedTrackClosed(t *testing.T, st types.SubscribedTrack, isExpectedToResume bool) { + fst, ok := st.(*typesfakes.FakeSubscribedTrack) + require.True(t, ok) + + fst.OnCloseArgsForCall(0)(isExpectedToResume) +} diff --git a/livekit/pkg/rtc/supervisor/participant_supervisor.go b/livekit/pkg/rtc/supervisor/participant_supervisor.go new file mode 100644 index 0000000..1dc6b9d --- /dev/null +++ b/livekit/pkg/rtc/supervisor/participant_supervisor.go @@ -0,0 +1,183 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package supervisor + +import ( + "sync" + "time" + + "go.uber.org/atomic" + + "github.com/livekit/livekit-server/pkg/rtc/types" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" +) + +const ( + monitorInterval = 1 * time.Second +) + +type ParticipantSupervisorParams struct { + Logger logger.Logger +} + +type trackMonitor struct { + opMon types.OperationMonitor + err error +} + +type ParticipantSupervisor struct { + params ParticipantSupervisorParams + + lock sync.RWMutex + isPublisherConnected bool + publications map[livekit.TrackID]*trackMonitor + + isStopped atomic.Bool + + onPublicationError func(trackID livekit.TrackID) +} + +func NewParticipantSupervisor(params ParticipantSupervisorParams) *ParticipantSupervisor { + p := &ParticipantSupervisor{ + params: params, + publications: make(map[livekit.TrackID]*trackMonitor), + } + + go p.checkState() + + return p +} + +func (p *ParticipantSupervisor) Stop() { + p.isStopped.Store(true) +} + +func (p *ParticipantSupervisor) OnPublicationError(f func(trackID livekit.TrackID)) { + p.lock.Lock() + defer p.lock.Unlock() + + p.onPublicationError = f +} + +func (p *ParticipantSupervisor) getOnPublicationError() func(trackID livekit.TrackID) { + p.lock.RLock() + defer p.lock.RUnlock() + + return p.onPublicationError +} + +func (p *ParticipantSupervisor) SetPublisherPeerConnectionConnected(isConnected bool) { + p.lock.Lock() + p.isPublisherConnected = isConnected + + for _, pm := range p.publications { + pm.opMon.PostEvent(types.OperationMonitorEventPublisherPeerConnectionConnected, p.isPublisherConnected) + } + p.lock.Unlock() +} + +func (p *ParticipantSupervisor) AddPublication(trackID livekit.TrackID) { + p.lock.Lock() + pm, ok := p.publications[trackID] + if !ok { + pm = &trackMonitor{ + opMon: NewPublicationMonitor( + PublicationMonitorParams{ + TrackID: trackID, + IsPeerConnectionConnected: p.isPublisherConnected, + Logger: p.params.Logger, + }, + ), + } + p.publications[trackID] = pm + } + pm.opMon.PostEvent(types.OperationMonitorEventAddPendingPublication, nil) + p.lock.Unlock() +} + +func (p *ParticipantSupervisor) SetPublicationMute(trackID livekit.TrackID, isMuted bool) { + p.lock.Lock() + pm, ok := p.publications[trackID] + if ok { + pm.opMon.PostEvent(types.OperationMonitorEventSetPublicationMute, isMuted) + } + p.lock.Unlock() +} + +func (p *ParticipantSupervisor) SetPublishedTrack(trackID livekit.TrackID, pubTrack types.LocalMediaTrack) { + p.lock.RLock() + pm, ok := p.publications[trackID] + if ok { + pm.opMon.PostEvent(types.OperationMonitorEventSetPublishedTrack, pubTrack) + } + p.lock.RUnlock() +} + +func (p *ParticipantSupervisor) ClearPublishedTrack(trackID livekit.TrackID, pubTrack types.LocalMediaTrack) { + p.lock.RLock() + pm, ok := p.publications[trackID] + if ok { + pm.opMon.PostEvent(types.OperationMonitorEventClearPublishedTrack, pubTrack) + } + p.lock.RUnlock() +} + +func (p *ParticipantSupervisor) checkState() { + ticker := time.NewTicker(monitorInterval) + defer ticker.Stop() + + for !p.isStopped.Load() { + <-ticker.C + + p.checkPublications() + } +} + +func (p *ParticipantSupervisor) checkPublications() { + var erroredPublications []livekit.TrackID + var removablePublications []livekit.TrackID + p.lock.RLock() + for trackID, pm := range p.publications { + if err := pm.opMon.Check(); err != nil { + if pm.err == nil { + p.params.Logger.Errorw("supervisor error on publication", err, "trackID", trackID) + pm.err = err + erroredPublications = append(erroredPublications, trackID) + } + } else { + if pm.err != nil { + p.params.Logger.Infow("supervisor publication recovered", "trackID", trackID) + pm.err = err + } + if pm.opMon.IsIdle() { + removablePublications = append(removablePublications, trackID) + } + } + } + p.lock.RUnlock() + + p.lock.Lock() + for _, trackID := range removablePublications { + delete(p.publications, trackID) + } + p.lock.Unlock() + + if onPublicationError := p.getOnPublicationError(); onPublicationError != nil { + for _, trackID := range erroredPublications { + onPublicationError(trackID) + } + } +} diff --git a/livekit/pkg/rtc/supervisor/publication_monitor.go b/livekit/pkg/rtc/supervisor/publication_monitor.go new file mode 100644 index 0000000..199d218 --- /dev/null +++ b/livekit/pkg/rtc/supervisor/publication_monitor.go @@ -0,0 +1,190 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package supervisor + +import ( + "errors" + "sync" + "time" + + "github.com/gammazero/deque" + + "github.com/livekit/livekit-server/pkg/rtc/types" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" +) + +const ( + publishWaitDuration = 30 * time.Second +) + +var ( + errPublishTimeout = errors.New("publish time out") +) + +type publish struct { + isStart bool +} + +type PublicationMonitorParams struct { + TrackID livekit.TrackID + IsPeerConnectionConnected bool + Logger logger.Logger +} + +type PublicationMonitor struct { + params PublicationMonitorParams + + lock sync.RWMutex + desiredPublishes deque.Deque[*publish] + + isConnected bool + + publishedTrack types.LocalMediaTrack + isMuted bool + unmutedAt time.Time +} + +func NewPublicationMonitor(params PublicationMonitorParams) *PublicationMonitor { + p := &PublicationMonitor{ + params: params, + isConnected: params.IsPeerConnectionConnected, + } + p.desiredPublishes.SetBaseCap(4) + return p +} + +func (p *PublicationMonitor) PostEvent(ome types.OperationMonitorEvent, omd types.OperationMonitorData) { + switch ome { + case types.OperationMonitorEventPublisherPeerConnectionConnected: + p.setConnected(omd.(bool)) + case types.OperationMonitorEventAddPendingPublication: + p.addPending() + case types.OperationMonitorEventSetPublicationMute: + p.setMute(omd.(bool)) + case types.OperationMonitorEventSetPublishedTrack: + p.setPublishedTrack(omd.(types.LocalMediaTrack)) + case types.OperationMonitorEventClearPublishedTrack: + p.clearPublishedTrack(omd.(types.LocalMediaTrack)) + } +} + +func (p *PublicationMonitor) addPending() { + p.lock.Lock() + p.desiredPublishes.PushBack( + &publish{ + isStart: true, + }, + ) + + // synthesize an end + p.desiredPublishes.PushBack( + &publish{ + isStart: false, + }, + ) + p.update() + p.lock.Unlock() +} + +func (p *PublicationMonitor) maybeStartMonitor() { + if p.isConnected && !p.isMuted { + p.unmutedAt = time.Now() + } +} + +func (p *PublicationMonitor) setConnected(isConnected bool) { + p.lock.Lock() + p.isConnected = isConnected + p.maybeStartMonitor() + p.lock.Unlock() +} + +func (p *PublicationMonitor) setMute(isMuted bool) { + p.lock.Lock() + p.isMuted = isMuted + p.maybeStartMonitor() + p.lock.Unlock() +} + +func (p *PublicationMonitor) setPublishedTrack(pubTrack types.LocalMediaTrack) { + p.lock.Lock() + p.publishedTrack = pubTrack + p.update() + p.lock.Unlock() +} + +func (p *PublicationMonitor) clearPublishedTrack(pubTrack types.LocalMediaTrack) { + p.lock.Lock() + if p.publishedTrack == pubTrack { + p.publishedTrack = nil + } else { + p.params.Logger.Errorw("supervisor: mismatched published track on clear", nil, "trackID", p.params.TrackID) + } + + p.update() + p.lock.Unlock() +} + +func (p *PublicationMonitor) Check() error { + p.lock.RLock() + var pub *publish + if p.desiredPublishes.Len() > 0 { + pub = p.desiredPublishes.Front() + } + + isMuted := p.isMuted + unmutedAt := p.unmutedAt + p.lock.RUnlock() + + if pub == nil { + return nil + } + + if pub.isStart && !isMuted && !unmutedAt.IsZero() && time.Since(unmutedAt) > publishWaitDuration { + // timed out waiting for publish + return errPublishTimeout + } + + // give more time for publish to happen + // NOTE: synthesized end events do not have a start time, so do not check them for time out + return nil +} + +func (p *PublicationMonitor) IsIdle() bool { + p.lock.RLock() + defer p.lock.RUnlock() + + return p.desiredPublishes.Len() == 0 && p.publishedTrack == nil +} + +func (p *PublicationMonitor) update() { + for { + var pub *publish + if p.desiredPublishes.Len() > 0 { + pub = p.desiredPublishes.PopFront() + } + + if pub == nil { + return + } + + if (pub.isStart && p.publishedTrack == nil) || (!pub.isStart && p.publishedTrack != nil) { + // put it back as the condition is not satisfied + p.desiredPublishes.PushFront(pub) + return + } + } +} diff --git a/livekit/pkg/rtc/testutils.go b/livekit/pkg/rtc/testutils.go new file mode 100644 index 0000000..0855821 --- /dev/null +++ b/livekit/pkg/rtc/testutils.go @@ -0,0 +1,88 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rtc + +import ( + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/observability/roomobs" + "github.com/livekit/protocol/utils" + "github.com/livekit/protocol/utils/guid" + + "github.com/livekit/livekit-server/pkg/rtc/types" + "github.com/livekit/livekit-server/pkg/rtc/types/typesfakes" +) + +func NewMockParticipant( + identity livekit.ParticipantIdentity, + protocol types.ProtocolVersion, + hidden bool, + publisher bool, + participantListener types.LocalParticipantListener, +) *typesfakes.FakeLocalParticipant { + p := &typesfakes.FakeLocalParticipant{} + sid := guid.New(utils.ParticipantPrefix) + p.IDReturns(livekit.ParticipantID(sid)) + p.IdentityReturns(identity) + p.StateReturns(livekit.ParticipantInfo_JOINED) + p.ProtocolVersionReturns(protocol) + p.CanSubscribeReturns(true) + p.CanPublishSourceReturns(!hidden) + p.CanPublishDataReturns(!hidden) + p.HiddenReturns(hidden) + p.ToProtoReturns(&livekit.ParticipantInfo{ + Sid: sid, + Identity: string(identity), + State: livekit.ParticipantInfo_JOINED, + IsPublisher: publisher, + }) + p.ToProtoWithVersionReturns(&livekit.ParticipantInfo{ + Sid: sid, + Identity: string(identity), + State: livekit.ParticipantInfo_JOINED, + IsPublisher: publisher, + }, utils.TimedVersion(0)) + + p.SetMetadataCalls(func(m string) { + participantListener.OnParticipantUpdate(p) + }) + updateTrack := func() { + participantListener.OnTrackUpdated(p, NewMockTrack(livekit.TrackType_VIDEO, "testcam")) + } + + p.SetTrackMutedCalls(func(mute *livekit.MuteTrackRequest, fromServer bool) *livekit.TrackInfo { + updateTrack() + return nil + }) + p.AddTrackCalls(func(req *livekit.AddTrackRequest) { + updateTrack() + }) + p.GetLoggerReturns(logger.GetLogger()) + p.GetReporterReturns(roomobs.NewNoopParticipantSessionReporter()) + + return p +} + +func NewMockTrack(kind livekit.TrackType, name string) *typesfakes.FakeMediaTrack { + t := &typesfakes.FakeMediaTrack{} + t.IDReturns(livekit.TrackID(guid.New(utils.TrackPrefix))) + t.KindReturns(kind) + t.NameReturns(name) + t.ToProtoReturns(&livekit.TrackInfo{ + Type: kind, + Name: name, + }) + return t +} diff --git a/livekit/pkg/rtc/transport.go b/livekit/pkg/rtc/transport.go new file mode 100644 index 0000000..1dbd2a4 --- /dev/null +++ b/livekit/pkg/rtc/transport.go @@ -0,0 +1,3222 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rtc + +import ( + "fmt" + "io" + "maps" + "math/rand" + "net" + "slices" + "strconv" + "strings" + "sync" + "time" + + "github.com/pion/dtls/v3/pkg/crypto/elliptic" + "github.com/pion/ice/v4" + "github.com/pion/interceptor" + "github.com/pion/interceptor/pkg/cc" + "github.com/pion/interceptor/pkg/gcc" + "github.com/pion/interceptor/pkg/twcc" + "github.com/pion/rtcp" + "github.com/pion/sdp/v3" + "github.com/pion/webrtc/v4" + "github.com/pkg/errors" + "go.uber.org/atomic" + "go.uber.org/zap/zapcore" + "google.golang.org/protobuf/proto" + + "github.com/livekit/livekit-server/pkg/config" + "github.com/livekit/livekit-server/pkg/rtc/transport" + "github.com/livekit/livekit-server/pkg/rtc/types" + "github.com/livekit/livekit-server/pkg/sfu/buffer" + "github.com/livekit/livekit-server/pkg/sfu/bwe" + "github.com/livekit/livekit-server/pkg/sfu/bwe/remotebwe" + "github.com/livekit/livekit-server/pkg/sfu/bwe/sendsidebwe" + "github.com/livekit/livekit-server/pkg/sfu/datachannel" + sfuinterceptor "github.com/livekit/livekit-server/pkg/sfu/interceptor" + "github.com/livekit/livekit-server/pkg/sfu/mime" + "github.com/livekit/livekit-server/pkg/sfu/pacer" + pd "github.com/livekit/livekit-server/pkg/sfu/rtpextension/playoutdelay" + "github.com/livekit/livekit-server/pkg/sfu/streamallocator" + sfuutils "github.com/livekit/livekit-server/pkg/sfu/utils" + "github.com/livekit/livekit-server/pkg/telemetry/prometheus" + "github.com/livekit/livekit-server/pkg/utils" + lkinterceptor "github.com/livekit/mediatransportutil/pkg/interceptor" + lktwcc "github.com/livekit/mediatransportutil/pkg/twcc" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/logger/pionlogger" + lksdp "github.com/livekit/protocol/sdp" + "github.com/livekit/protocol/utils/mono" +) + +const ( + LossyDataChannel = "_lossy" + ReliableDataChannel = "_reliable" + DataTrackDataChannel = "_data_track" + + fastNegotiationFrequency = 10 * time.Millisecond + negotiationFrequency = 150 * time.Millisecond + negotiationFailedTimeout = 15 * time.Second + dtlsRetransmissionInterval = 100 * time.Millisecond + + iceDisconnectedTimeout = 10 * time.Second // compatible for ice-lite with firefox client + iceFailedTimeout = 5 * time.Second // time between disconnected and failed + iceFailedTimeoutTotal = iceFailedTimeout + iceDisconnectedTimeout // total time between connecting and failure + iceKeepaliveInterval = 2 * time.Second // pion's default + + minTcpICEConnectTimeout = 5 * time.Second + maxTcpICEConnectTimeout = 12 * time.Second // js-sdk has a default 15s timeout for first connection, let server detect failure earlier before that + + minConnectTimeoutAfterICE = 10 * time.Second + maxConnectTimeoutAfterICE = 20 * time.Second // max duration for waiting pc to connect after ICE is connected + + shortConnectionThreshold = 90 * time.Second + + dataChannelBufferSize = 65535 + lossyDataChannelMinBufferedAmount = 8 * 1024 +) + +var ( + ErrNoICETransport = errors.New("no ICE transport") + ErrIceRestartWithoutLocalSDP = errors.New("ICE restart without local SDP settled") + ErrIceRestartOnClosedPeerConnection = errors.New("ICE restart on closed peer connection") + ErrNoTransceiver = errors.New("no transceiver") + ErrNoSender = errors.New("no sender") + ErrMidNotFound = errors.New("mid not found") + ErrNotSynchronousLocalCandidatesMode = errors.New("not using synchronous local candidates mode") + ErrNoRemoteDescription = errors.New("no remote description") + ErrNoLocalDescription = errors.New("no local description") + ErrInvalidSDPFragment = errors.New("invalid sdp fragment") + ErrNoBundleMid = errors.New("could not get bundle mid") + ErrMidMismatch = errors.New("media mid does not match bundle mid") + ErrICECredentialMismatch = errors.New("ice credential mismatch") +) + +// ------------------------------------------------------------------------- + +type signal int + +const ( + signalICEGatheringComplete signal = iota + signalLocalICECandidate + signalRemoteICECandidate + signalSendOffer + signalRemoteDescriptionReceived + signalICERestart +) + +func (s signal) String() string { + switch s { + case signalICEGatheringComplete: + return "ICE_GATHERING_COMPLETE" + case signalLocalICECandidate: + return "LOCAL_ICE_CANDIDATE" + case signalRemoteICECandidate: + return "REMOTE_ICE_CANDIDATE" + case signalSendOffer: + return "SEND_OFFER" + case signalRemoteDescriptionReceived: + return "REMOTE_DESCRIPTION_RECEIVED" + case signalICERestart: + return "ICE_RESTART" + default: + return fmt.Sprintf("%d", int(s)) + } +} + +// ------------------------------------------------------- + +type event struct { + *PCTransport + signal signal + data any +} + +func (e event) String() string { + return fmt.Sprintf("PCTransport:Event{signal: %s, data: %+v}", e.signal, e.data) +} + +// ------------------------------------------------------- + +type wrappedICECandidatePairLogger struct { + pair *webrtc.ICECandidatePair +} + +func (w wrappedICECandidatePairLogger) MarshalLogObject(e zapcore.ObjectEncoder) error { + if w.pair == nil { + return nil + } + + if w.pair.Local != nil { + e.AddString("localProtocol", w.pair.Local.Protocol.String()) + e.AddString("localCandidateType", w.pair.Local.Typ.String()) + e.AddString("localAddress", w.pair.Local.Address) + e.AddUint16("localPort", w.pair.Local.Port) + } + if w.pair.Remote != nil { + e.AddString("remoteProtocol", w.pair.Remote.Protocol.String()) + e.AddString("remoteCandidateType", w.pair.Remote.Typ.String()) + e.AddString("remoteAddress", MaybeTruncateIP(w.pair.Remote.Address)) + e.AddUint16("remotePort", w.pair.Remote.Port) + if w.pair.Remote.RelatedAddress != "" { + e.AddString("relatedAddress", MaybeTruncateIP(w.pair.Remote.RelatedAddress)) + e.AddUint16("relatedPort", w.pair.Remote.RelatedPort) + } + } + return nil +} + +// ------------------------------------------------------------------- + +type trackDescription struct { + mid string + sender *webrtc.RTPSender +} + +// PCTransport is a wrapper around PeerConnection, with some helper methods +type PCTransport struct { + params TransportParams + pc *webrtc.PeerConnection + iceTransport *webrtc.ICETransport + me *webrtc.MediaEngine + + lock sync.RWMutex + + firstOfferReceived bool + firstOfferNoDataChannel bool + reliableDC *datachannel.DataChannelWriter[*webrtc.DataChannel] + reliableDCOpened bool + lossyDC *datachannel.DataChannelWriter[*webrtc.DataChannel] + lossyDCOpened bool + dataTrackDC *datachannel.DataChannelWriter[*webrtc.DataChannel] + unlabeledDataChannels []*datachannel.DataChannelWriter[*webrtc.DataChannel] + + iceStartedAt time.Time + iceConnectedAt time.Time + firstConnectedAt time.Time + connectedAt time.Time + tcpICETimer *time.Timer + connectAfterICETimer *time.Timer // timer to wait for pc to connect after ice connected + resetShortConnOnICERestart atomic.Bool + signalingRTT atomic.Uint32 // milliseconds + + debouncedNegotiate *sfuutils.Debouncer + debouncePending bool + lastNegotiate time.Time + + onNegotiationStateChanged func(state transport.NegotiationState) + + rtxInfoExtractorFactory *sfuinterceptor.RTXInfoExtractorFactory + + // stream allocator for subscriber PC + streamAllocator *streamallocator.StreamAllocator + + // only for subscriber PC + bwe bwe.BWE + pacer pacer.Pacer + + // transceivers (senders) waiting for SetRemoteDescription (offer) to happen before + // SetCodecPreferences can be invoked on them. + // Pion adapts codecs/payload types from remote description. + // If SetCodecPreferences are done before the remote description is processed, + // it is possible that the transceiver gets payload types from media engine. + // Subssequently if the peer sends an offer with different payload type for the + // same codec, there could be two payload types for the same codec and the wrong + // one could be used in the forwarding path. So, wait for `SetRemoteDescription` + // to happen so that remote side payload types are adapted. + sendersPendingConfigMu sync.Mutex + sendersPendingConfig []configureSenderParams + + previousAnswer *webrtc.SessionDescription + // track id -> description map in previous offer sdp + previousTrackDescription map[string]*trackDescription + canReuseTransceiver bool + + preferTCP atomic.Bool + isClosed atomic.Bool + + // used to check for offer/answer pairing, + // i. e. every offer should have an answer before another offer can be sent + localOfferId atomic.Uint32 + remoteAnswerId atomic.Uint32 + + remoteOfferId atomic.Uint32 + localAnswerId atomic.Uint32 + + eventsQueue *utils.TypedOpsQueue[event] + + connectionDetails *types.ICEConnectionDetails + selectedPair atomic.Pointer[webrtc.ICECandidatePair] + mayFailedICEStats []iceCandidatePairStats + mayFailedICEStatsTimer *time.Timer + + numOutstandingAudios uint32 + numRequestSentAudios uint32 + numOutstandingVideos uint32 + numRequestSentVideos uint32 + + // the following should be accessed only in event processing go routine + cacheLocalCandidates bool + cachedLocalCandidates []*webrtc.ICECandidate + pendingRemoteCandidates []*webrtc.ICECandidateInit + restartAfterGathering bool + restartAtNextOffer bool + negotiationState transport.NegotiationState + negotiateCounter atomic.Int32 + signalStateCheckTimer *time.Timer + currentOfferIceCredential string // ice user:pwd, for publish side ice restart checking + pendingRestartIceOffer *webrtc.SessionDescription +} + +type TransportParams struct { + Handler transport.Handler + ProtocolVersion types.ProtocolVersion + Config *WebRTCConfig + Twcc *lktwcc.Responder + DirectionConfig DirectionConfig + CongestionControlConfig config.CongestionControlConfig + EnabledCodecs []*livekit.Codec + Logger logger.Logger + Transport livekit.SignalTarget + SimTracks map[uint32]sfuinterceptor.SimulcastTrackInfo + ClientInfo ClientInfo + IsOfferer bool + IsSendSide bool + AllowPlayoutDelay bool + UseOneShotSignallingMode bool + FireOnTrackBySdp bool + DataChannelMaxBufferedAmount uint64 + DatachannelSlowThreshold int + DatachannelLossyTargetLatency time.Duration + + // for development test + DatachannelMaxReceiverBufferSize int + + EnableDataTracks bool +} + +func newPeerConnection( + params TransportParams, + onBandwidthEstimator func(estimator cc.BandwidthEstimator), +) (*webrtc.PeerConnection, *webrtc.MediaEngine, *sfuinterceptor.RTXInfoExtractorFactory, error) { + directionConfig := params.DirectionConfig + if params.AllowPlayoutDelay { + directionConfig.RTPHeaderExtension.Video = append(directionConfig.RTPHeaderExtension.Video, pd.PlayoutDelayURI) + } + + // Some of the browser clients do not handle H.264 High Profile in signalling properly. + // They still decode if the actual stream is H.264 High Profile, but do not handle it well in signalling. + // So, disable H.264 High Profile for SUBSCRIBER peer connection to ensure it is not offered. + me, err := createMediaEngine(params.EnabledCodecs, directionConfig, params.IsOfferer) + if err != nil { + return nil, nil, nil, err + } + + se := params.Config.SettingEngine + se.DisableMediaEngineCopy(true) + // simulcast layer disable/enable signalled via signalling channel, + // so disable rid pause in SDP + se.SetIgnoreRidPauseForRecv(true) + + // Change elliptic curve to improve connectivity + // https://github.com/pion/dtls/pull/474 + se.SetDTLSEllipticCurves(elliptic.X25519, elliptic.P384, elliptic.P256) + + // Disable close by dtls to avoid peerconnection close too early in migration + // https://github.com/pion/webrtc/pull/2961 + se.DisableCloseByDTLS(true) + + se.DetachDataChannels() + if params.DatachannelSlowThreshold > 0 { + se.EnableDataChannelBlockWrite(true) + } + if params.DatachannelMaxReceiverBufferSize > 0 { + se.SetSCTPMaxReceiveBufferSize(uint32(params.DatachannelMaxReceiverBufferSize)) + } + if params.FireOnTrackBySdp { + se.SetFireOnTrackBeforeFirstRTP(true) + } + + if params.ClientInfo.SupportsSctpZeroChecksum() { + se.EnableSCTPZeroChecksum(true) + } + + // + // Disable SRTP replay protection (https://datatracker.ietf.org/doc/html/rfc3711#page-15). + // Needed due to lack of RTX stream support in Pion. + // + // When clients probe for bandwidth, there are several possible approaches + // 1. Use padding packet (Chrome uses this) + // 2. Use an older packet (Firefox uses this) + // Typically, these are sent over the RTX stream and hence SRTP replay protection will not + // trigger. As Pion does not support RTX, when firefox uses older packet for probing, they + // trigger the replay protection. + // + // That results in two issues + // - Firefox bandwidth probing is not successful + // - Pion runs out of read buffer capacity - this potentially looks like a Pion issue + // + // NOTE: It is not required to disable RTCP replay protection, but doing it to be symmetric. + // + se.DisableSRTPReplayProtection(true) + se.DisableSRTCPReplayProtection(true) + if !params.ProtocolVersion.SupportsICELite() || !params.ClientInfo.SupportsPrflxOverRelay() { + // if client don't support prflx over relay which is only Firefox, disable ICE Lite to ensure that + // aggressive nomination is handled properly. Firefox does aggressive nomination even if peer is + // ICE Lite (see comment as to historical reasons: https://github.com/pion/ice/pull/739#issuecomment-2452245066). + // pion/ice (as of v2.3.37) will accept all use-candidate switches when in ICE Lite mode. + // That combined with aggressive nomination from Firefox could potentially lead to the two ends + // ending up with different candidates. + // As Firefox does not support migration, ICE Lite can be disabled. + se.SetLite(false) + } + se.SetDTLSRetransmissionInterval(dtlsRetransmissionInterval) + se.SetICETimeouts(iceDisconnectedTimeout, iceFailedTimeout, iceKeepaliveInterval) + + // if client don't support prflx over relay, we should not expose private address to it, use single external ip as host candidate + if !params.ClientInfo.SupportsPrflxOverRelay() && len(params.Config.NAT1To1IPs) > 0 { + var nat1to1Ips []string + var includeIps []string + for _, mapping := range params.Config.NAT1To1IPs { + if ips := strings.Split(mapping, "/"); len(ips) == 2 { + if ips[0] != ips[1] { + nat1to1Ips = append(nat1to1Ips, mapping) + includeIps = append(includeIps, ips[1]) + } + } + } + if len(nat1to1Ips) > 0 { + params.Logger.Infow("client doesn't support prflx over relay, use external ip only as host candidate", "ips", nat1to1Ips) + se.SetNAT1To1IPs(nat1to1Ips, webrtc.ICECandidateTypeHost) + se.SetIPFilter(func(ip net.IP) bool { + if ip.To4() == nil { + return true + } + ipstr := ip.String() + return slices.Contains(includeIps, ipstr) + }) + } + } + + lf := pionlogger.NewLoggerFactory(params.Logger) + if lf != nil { + se.LoggerFactory = lf + } + + ir := &interceptor.Registry{} + if params.IsSendSide { + if params.CongestionControlConfig.UseSendSideBWEInterceptor && !params.CongestionControlConfig.UseSendSideBWE { + params.Logger.Infow("using send side BWE - interceptor") + gf, err := cc.NewInterceptor(func() (cc.BandwidthEstimator, error) { + return gcc.NewSendSideBWE( + gcc.SendSideBWEInitialBitrate(1*1000*1000), + gcc.SendSideBWEPacer(gcc.NewNoOpPacer()), + ) + }) + if err == nil { + gf.OnNewPeerConnection(func(id string, estimator cc.BandwidthEstimator) { + if onBandwidthEstimator != nil { + onBandwidthEstimator(estimator) + } + }) + ir.Add(gf) + + tf, err := twcc.NewHeaderExtensionInterceptor() + if err == nil { + ir.Add(tf) + } + } + } + } + if !params.IsOfferer { + // sfu only use interceptor to send XR but don't read response from it (use buffer instead), + // so use a empty callback here + ir.Add(lkinterceptor.NewRTTFromXRFactory(func(rtt uint32) {})) + } + if len(params.SimTracks) > 0 { + f, err := sfuinterceptor.NewUnhandleSimulcastInterceptorFactory(sfuinterceptor.UnhandleSimulcastTracks(params.Logger, params.SimTracks)) + if err != nil { + params.Logger.Warnw("NewUnhandleSimulcastInterceptorFactory failed", err) + } else { + ir.Add(f) + } + } + + setTWCCForVideo := func(info *interceptor.StreamInfo) { + if !mime.IsMimeTypeStringVideo(info.MimeType) { + return + } + // rtx stream don't have rtcp feedback, always set twcc for rtx stream + twccFb := mime.GetMimeTypeCodec(info.MimeType) == mime.MimeTypeCodecRTX + if !twccFb { + for _, fb := range info.RTCPFeedback { + if fb.Type == webrtc.TypeRTCPFBTransportCC { + twccFb = true + break + } + } + } + if !twccFb { + return + } + + twccExtID := sfuutils.GetHeaderExtensionID(info.RTPHeaderExtensions, webrtc.RTPHeaderExtensionCapability{URI: sdp.TransportCCURI}) + if twccExtID != 0 { + if buffer := params.Config.BufferFactory.GetBuffer(info.SSRC); buffer != nil { + params.Logger.Debugw( + "set twcc and ext id", + "ssrc", info.SSRC, + "isRTX", mime.GetMimeTypeCodec(info.MimeType) == mime.MimeTypeCodecRTX, + "twccExtID", twccExtID, + ) + buffer.SetTWCCAndExtID(params.Twcc, uint8(twccExtID)) + } else { + params.Logger.Warnw("failed to get buffer for stream", nil, "ssrc", info.SSRC) + } + } + } + rtxInfoExtractorFactory := sfuinterceptor.NewRTXInfoExtractorFactory( + setTWCCForVideo, + func(repair, base uint32, rsid string) { + params.Logger.Debugw("rtx pair found from extension", "repair", repair, "base", base, "rsid", rsid) + params.Config.BufferFactory.SetRTXPair(repair, base, rsid) + }, + params.Logger, + ) + // put rtx interceptor behind unhandle simulcast interceptor so it can get the correct mid & rid + ir.Add(rtxInfoExtractorFactory) + + api := webrtc.NewAPI( + webrtc.WithMediaEngine(me), + webrtc.WithSettingEngine(se), + webrtc.WithInterceptorRegistry(ir), + ) + pc, err := api.NewPeerConnection(params.Config.Configuration) + return pc, me, rtxInfoExtractorFactory, err +} + +func NewPCTransport(params TransportParams) (*PCTransport, error) { + if params.Logger == nil { + params.Logger = logger.GetLogger() + } + t := &PCTransport{ + params: params, + debouncedNegotiate: sfuutils.NewDebouncer(negotiationFrequency), + negotiationState: transport.NegotiationStateNone, + eventsQueue: utils.NewTypedOpsQueue[event](utils.OpsQueueParams{ + Name: "transport", + MinSize: 64, + Logger: params.Logger, + }), + previousTrackDescription: make(map[string]*trackDescription), + canReuseTransceiver: true, + connectionDetails: types.NewICEConnectionDetails(params.Transport, params.Logger), + lastNegotiate: time.Now(), + } + t.localOfferId.Store(uint32(rand.Intn(1<<8) + 1)) + + bwe, err := t.createPeerConnection() + if err != nil { + return nil, err + } + + if params.IsSendSide { + if params.CongestionControlConfig.UseSendSideBWE { + params.Logger.Infow("using send side BWE", "pacerBehavior", params.CongestionControlConfig.SendSideBWEPacer) + t.bwe = sendsidebwe.NewSendSideBWE(sendsidebwe.SendSideBWEParams{ + Config: params.CongestionControlConfig.SendSideBWE, + Logger: params.Logger, + }) + switch pacer.PacerBehavior(params.CongestionControlConfig.SendSideBWEPacer) { + case pacer.PacerBehaviorPassThrough: + t.pacer = pacer.NewPassThrough(params.Logger, t.bwe) + case pacer.PacerBehaviorNoQueue: + t.pacer = pacer.NewNoQueue(params.Logger, t.bwe) + default: + t.pacer = pacer.NewNoQueue(params.Logger, t.bwe) + } + } else { + t.bwe = remotebwe.NewRemoteBWE(remotebwe.RemoteBWEParams{ + Config: params.CongestionControlConfig.RemoteBWE, + Logger: params.Logger, + }) + t.pacer = pacer.NewPassThrough(params.Logger, nil) + } + + t.streamAllocator = streamallocator.NewStreamAllocator(streamallocator.StreamAllocatorParams{ + Config: params.CongestionControlConfig.StreamAllocator, + BWE: t.bwe, + Pacer: t.pacer, + RTTGetter: t.GetRTT, + Logger: params.Logger.WithComponent(utils.ComponentCongestionControl), + }, params.CongestionControlConfig.Enabled, params.CongestionControlConfig.AllowPause) + t.streamAllocator.OnStreamStateChange(params.Handler.OnStreamStateChange) + t.streamAllocator.Start() + + if bwe != nil { + t.streamAllocator.SetSendSideBWEInterceptor(bwe) + } + } + + t.eventsQueue.Start() + + return t, nil +} + +func (t *PCTransport) createPeerConnection() (cc.BandwidthEstimator, error) { + var bwe cc.BandwidthEstimator + pc, me, rtxInfoExtractorFactory, err := newPeerConnection(t.params, func(estimator cc.BandwidthEstimator) { + bwe = estimator + }) + if err != nil { + return bwe, err + } + + t.pc = pc + if !t.params.UseOneShotSignallingMode { + // one shot signalling mode gathers all candidates and sends in answer + t.pc.OnICEGatheringStateChange(t.onICEGatheringStateChange) + t.pc.OnICECandidate(t.onICECandidateTrickle) + } + t.pc.OnICEConnectionStateChange(t.onICEConnectionStateChange) + t.pc.OnConnectionStateChange(t.onPeerConnectionStateChange) + + t.pc.OnDataChannel(t.onDataChannel) + t.pc.OnTrack(t.params.Handler.OnTrack) + + t.iceTransport = t.pc.SCTP().Transport().ICETransport() + if t.iceTransport == nil { + return bwe, ErrNoICETransport + } + t.iceTransport.OnSelectedCandidatePairChange(func(pair *webrtc.ICECandidatePair) { + t.params.Logger.Debugw("selected ICE candidate pair changed", "pair", wrappedICECandidatePairLogger{pair}) + t.connectionDetails.SetSelectedPair(pair) + existingPair := t.selectedPair.Load() + if existingPair != nil { + t.params.Logger.Infow( + "ice reconnected or switched pair", + "existingPair", wrappedICECandidatePairLogger{existingPair}, + "newPair", wrappedICECandidatePairLogger{pair}) + } + t.selectedPair.Store(pair) + }) + + t.me = me + + t.rtxInfoExtractorFactory = rtxInfoExtractorFactory + return bwe, nil +} + +func (t *PCTransport) RTPStreamPublished(ssrc uint32, mid, rid string) { + t.rtxInfoExtractorFactory.SetStreamInfo(ssrc, mid, rid, "") +} + +func (t *PCTransport) GetPacer() pacer.Pacer { + return t.pacer +} + +func (t *PCTransport) SetSignalingRTT(rtt uint32) { + t.signalingRTT.Store(rtt) +} + +func (t *PCTransport) setICEStartedAt(at time.Time) { + t.lock.Lock() + if t.iceStartedAt.IsZero() { + t.iceStartedAt = at + + // checklist of ice agent will be cleared on ice failed, get stats before that + t.mayFailedICEStatsTimer = time.AfterFunc(iceFailedTimeoutTotal-time.Second, t.logMayFailedICEStats) + + // set failure timer for tcp ice connection based on signaling RTT + if t.preferTCP.Load() { + signalingRTT := t.signalingRTT.Load() + if signalingRTT < 1000 { + tcpICETimeout := time.Duration(signalingRTT*8) * time.Millisecond + if tcpICETimeout < minTcpICEConnectTimeout { + tcpICETimeout = minTcpICEConnectTimeout + } else if tcpICETimeout > maxTcpICEConnectTimeout { + tcpICETimeout = maxTcpICEConnectTimeout + } + t.params.Logger.Debugw("set TCP ICE connect timer", "timeout", tcpICETimeout, "signalRTT", signalingRTT) + t.tcpICETimer = time.AfterFunc(tcpICETimeout, func() { + if t.pc.ICEConnectionState() == webrtc.ICEConnectionStateChecking { + t.params.Logger.Infow("TCP ICE connect timeout", "timeout", tcpICETimeout, "signalRTT", signalingRTT) + t.logMayFailedICEStats() + t.handleConnectionFailed(true) + } + }) + } + } + } + t.lock.Unlock() +} + +func (t *PCTransport) setICEConnectedAt(at time.Time) { + t.lock.Lock() + if t.iceConnectedAt.IsZero() { + // + // Record initial connection time. + // This prevents reset of connected at time if ICE goes `Connected` -> `Disconnected` -> `Connected`. + // + t.iceConnectedAt = at + + // set failure timer for dtls handshake + iceDuration := at.Sub(t.iceStartedAt) + connTimeoutAfterICE := min(max(minConnectTimeoutAfterICE, 3*iceDuration), maxConnectTimeoutAfterICE) + t.params.Logger.Debugw("setting connection timer after ICE connected", "timeout", connTimeoutAfterICE, "iceDuration", iceDuration) + t.connectAfterICETimer = time.AfterFunc(connTimeoutAfterICE, func() { + state := t.pc.ConnectionState() + // if pc is still checking or connected but not fully established after timeout, then fire connection fail + if state != webrtc.PeerConnectionStateClosed && state != webrtc.PeerConnectionStateFailed && !t.isFullyEstablished() { + t.params.Logger.Infow("connect timeout after ICE connected", "timeout", connTimeoutAfterICE, "iceDuration", iceDuration) + t.handleConnectionFailed(false) + } + }) + + // clear tcp ice connect timer + if t.tcpICETimer != nil { + t.tcpICETimer.Stop() + t.tcpICETimer = nil + } + } + + if t.mayFailedICEStatsTimer != nil { + t.mayFailedICEStatsTimer.Stop() + t.mayFailedICEStatsTimer = nil + } + t.mayFailedICEStats = nil + t.lock.Unlock() +} + +func (t *PCTransport) logMayFailedICEStats() { + if t.pc.ConnectionState() == webrtc.PeerConnectionStateClosed { + return + } + + var candidatePairStats []webrtc.ICECandidatePairStats + pairStats := t.pc.GetStats() + candidateStats := make(map[string]webrtc.ICECandidateStats) + for _, stat := range pairStats { + switch stat := stat.(type) { + case webrtc.ICECandidatePairStats: + candidatePairStats = append(candidatePairStats, stat) + case webrtc.ICECandidateStats: + candidateStats[stat.ID] = stat + } + } + + iceStats := make([]iceCandidatePairStats, 0, len(candidatePairStats)) + for _, pairStat := range candidatePairStats { + iceStat := iceCandidatePairStats{ICECandidatePairStats: pairStat} + if local, ok := candidateStats[pairStat.LocalCandidateID]; ok { + iceStat.local = local + } + if remote, ok := candidateStats[pairStat.RemoteCandidateID]; ok { + remote.IP = MaybeTruncateIP(remote.IP) + iceStat.remote = remote + } + iceStats = append(iceStats, iceStat) + } + + t.lock.Lock() + t.mayFailedICEStats = iceStats + t.lock.Unlock() +} + +func (t *PCTransport) resetShortConn() { + t.params.Logger.Infow("resetting short connection on ICE restart") + t.lock.Lock() + t.iceStartedAt = time.Time{} + t.iceConnectedAt = time.Time{} + t.connectedAt = time.Time{} + if t.connectAfterICETimer != nil { + t.connectAfterICETimer.Stop() + t.connectAfterICETimer = nil + } + if t.tcpICETimer != nil { + t.tcpICETimer.Stop() + t.tcpICETimer = nil + } + t.lock.Unlock() +} + +func (t *PCTransport) IsShortConnection(at time.Time) (bool, time.Duration) { + t.lock.RLock() + defer t.lock.RUnlock() + + if t.iceConnectedAt.IsZero() { + return false, 0 + } + + duration := at.Sub(t.iceConnectedAt) + return duration < shortConnectionThreshold, duration +} + +func (t *PCTransport) setConnectedAt(at time.Time) bool { + t.lock.Lock() + t.connectedAt = at + if !t.firstConnectedAt.IsZero() { + t.lock.Unlock() + return false + } + + t.firstConnectedAt = at + prometheus.RecordServiceOperationSuccess("peer_connection") + t.lock.Unlock() + return true +} + +func (t *PCTransport) onICEGatheringStateChange(state webrtc.ICEGatheringState) { + t.params.Logger.Debugw("ice gathering state change", "state", state.String()) + if state != webrtc.ICEGatheringStateComplete { + return + } + + t.postEvent(event{ + signal: signalICEGatheringComplete, + }) +} + +func (t *PCTransport) onICECandidateTrickle(c *webrtc.ICECandidate) { + t.postEvent(event{ + signal: signalLocalICECandidate, + data: c, + }) +} + +func (t *PCTransport) handleConnectionFailed(forceShortConn bool) { + isShort := forceShortConn + if !isShort { + var duration time.Duration + isShort, duration = t.IsShortConnection(time.Now()) + if isShort { + t.params.Logger.Debugw("short ICE connection", "pair", wrappedICECandidatePairLogger{t.selectedPair.Load()}, "duration", duration) + } + } + + t.params.Handler.OnFailed(isShort, t.GetICEConnectionInfo()) +} + +func (t *PCTransport) onICEConnectionStateChange(state webrtc.ICEConnectionState) { + t.params.Logger.Debugw("ice connection state change", "state", state.String()) + switch state { + case webrtc.ICEConnectionStateConnected: + t.setICEConnectedAt(time.Now()) + + case webrtc.ICEConnectionStateChecking: + t.setICEStartedAt(time.Now()) + } +} + +func (t *PCTransport) onPeerConnectionStateChange(state webrtc.PeerConnectionState) { + t.params.Logger.Debugw("peer connection state change", "state", state.String()) + switch state { + case webrtc.PeerConnectionStateConnected: + t.clearConnTimer() + isInitialConnection := t.setConnectedAt(time.Now()) + if isInitialConnection { + t.params.Handler.OnInitialConnected() + + t.maybeNotifyFullyEstablished() + } + case webrtc.PeerConnectionStateFailed: + t.clearConnTimer() + t.handleConnectionFailed(false) + } +} + +func (t *PCTransport) onDataChannel(dc *webrtc.DataChannel) { + dc.OnOpen(func() { + t.params.Logger.Debugw(dc.Label() + " data channel open") + var kind livekit.DataPacket_Kind + var isDataTrack bool + var isUnlabeled bool + switch dc.Label() { + case ReliableDataChannel: + kind = livekit.DataPacket_RELIABLE + + case LossyDataChannel: + kind = livekit.DataPacket_LOSSY + + case DataTrackDataChannel: + isDataTrack = true + + default: + t.params.Logger.Infow("unlabeled datachannel added", "label", dc.Label()) + isUnlabeled = true + } + + rawDC, err := dc.DetachWithDeadline() + if err != nil { + t.params.Logger.Errorw("failed to detach data channel", err, "label", dc.Label()) + return + } + + isHandled := true + t.lock.Lock() + switch { + case isUnlabeled: + t.unlabeledDataChannels = append( + t.unlabeledDataChannels, + datachannel.NewDataChannelWriterReliable(dc, rawDC, t.params.DatachannelSlowThreshold), + ) + + case isDataTrack: + if !t.params.EnableDataTracks { + t.params.Logger.Debugw("data tracks not enabled") + isHandled = false + } else { + if t.dataTrackDC != nil { + t.dataTrackDC.Close() + } + t.dataTrackDC = datachannel.NewDataChannelWriterUnreliable(dc, rawDC, 0, 0) + } + + case kind == livekit.DataPacket_RELIABLE: + if t.reliableDC != nil { + t.reliableDC.Close() + } + t.reliableDC = datachannel.NewDataChannelWriterReliable(dc, rawDC, t.params.DatachannelSlowThreshold) + t.reliableDCOpened = true + + case kind == livekit.DataPacket_LOSSY: + if t.lossyDC != nil { + t.lossyDC.Close() + } + t.lossyDC = datachannel.NewDataChannelWriterUnreliable(dc, rawDC, t.params.DatachannelLossyTargetLatency, uint64(lossyDataChannelMinBufferedAmount)) + t.lossyDCOpened = true + } + t.lock.Unlock() + + if !isHandled { + rawDC.Close() + return + } + + go func() { + defer rawDC.Close() + buffer := make([]byte, dataChannelBufferSize) + for { + n, _, err := rawDC.ReadDataChannel(buffer) + if err != nil { + if !errors.Is(err, io.EOF) && !strings.Contains(err.Error(), "state=Closed") { + t.params.Logger.Warnw("error reading data channel", err, "label", dc.Label()) + } + return + } + + switch { + case isUnlabeled: + t.params.Handler.OnDataMessageUnlabeled(buffer[:n]) + + case isDataTrack: + t.params.Handler.OnDataTrackMessage(buffer[:n], mono.UnixNano()) + + default: + t.params.Handler.OnDataMessage(kind, buffer[:n]) + } + } + }() + + t.maybeNotifyFullyEstablished() + }) +} + +func (t *PCTransport) maybeNotifyFullyEstablished() { + if t.isFullyEstablished() { + t.params.Handler.OnFullyEstablished() + } +} + +func (t *PCTransport) isFullyEstablished() bool { + t.lock.RLock() + defer t.lock.RUnlock() + + dataChannelReady := t.params.UseOneShotSignallingMode || t.firstOfferNoDataChannel || (t.reliableDCOpened && t.lossyDCOpened) + + return dataChannelReady && !t.connectedAt.IsZero() +} + +func (t *PCTransport) SetPreferTCP(preferTCP bool) { + t.preferTCP.Store(preferTCP) +} + +func (t *PCTransport) AddICECandidate(candidate webrtc.ICECandidateInit) { + t.postEvent(event{ + signal: signalRemoteICECandidate, + data: &candidate, + }) +} + +func (t *PCTransport) queueOrConfigureSender( + transceiver *webrtc.RTPTransceiver, + enabledCodecs []*livekit.Codec, + rtcpFeedbackConfig RTCPFeedbackConfig, + enableAudioStereo bool, + enableAudioNACK bool, +) { + params := configureSenderParams{ + transceiver, + enabledCodecs, + rtcpFeedbackConfig, + !t.params.IsOfferer, + enableAudioStereo, + enableAudioNACK, + } + if !t.params.IsOfferer { + t.sendersPendingConfigMu.Lock() + t.sendersPendingConfig = append(t.sendersPendingConfig, params) + t.sendersPendingConfigMu.Unlock() + return + } + + configureSender(params) +} + +func (t *PCTransport) processSendersPendingConfig() { + t.sendersPendingConfigMu.Lock() + pending := t.sendersPendingConfig + t.sendersPendingConfig = nil + t.sendersPendingConfigMu.Unlock() + + var unprocessed []configureSenderParams + for _, p := range pending { + if p.transceiver.Mid() == "" { + unprocessed = append(unprocessed, p) + continue + } + + configureSender(p) + } + + if len(unprocessed) != 0 { + t.sendersPendingConfigMu.Lock() + t.sendersPendingConfig = append(t.sendersPendingConfig, unprocessed...) + t.sendersPendingConfigMu.Unlock() + } +} + +func (t *PCTransport) AddTrack( + trackLocal webrtc.TrackLocal, + params types.AddTrackParams, + enabledCodecs []*livekit.Codec, + rtcpFeedbackConfig RTCPFeedbackConfig, +) (sender *webrtc.RTPSender, transceiver *webrtc.RTPTransceiver, err error) { + t.lock.Lock() + canReuse := t.canReuseTransceiver + td, ok := t.previousTrackDescription[trackLocal.ID()] + if ok { + delete(t.previousTrackDescription, trackLocal.ID()) + } + t.lock.Unlock() + + // keep track use same mid after migration if possible + if td != nil && td.sender != nil { + for _, tr := range t.pc.GetTransceivers() { + if tr.Mid() == td.mid { + return td.sender, tr, tr.SetSender(td.sender, trackLocal) + } + } + } + + // if never negotiated with client, can't reuse transceiver for track not subscribed before migration + if !canReuse { + return t.AddTransceiverFromTrack(trackLocal, params, enabledCodecs, rtcpFeedbackConfig) + } + + sender, err = t.pc.AddTrack(trackLocal) + if err != nil { + return + } + + for _, tr := range t.pc.GetTransceivers() { + if tr.Sender() == sender { + transceiver = tr + break + } + } + + if transceiver == nil { + err = ErrNoTransceiver + return + } + + t.queueOrConfigureSender( + transceiver, + enabledCodecs, + rtcpFeedbackConfig, + params.Stereo, + !params.Red || !t.params.ClientInfo.SupportsAudioRED(), + ) + + t.adjustNumOutstandingMedia(transceiver) + return +} + +func (t *PCTransport) AddTransceiverFromTrack( + trackLocal webrtc.TrackLocal, + params types.AddTrackParams, + enabledCodecs []*livekit.Codec, + rtcpFeedbackConfig RTCPFeedbackConfig, +) (sender *webrtc.RTPSender, transceiver *webrtc.RTPTransceiver, err error) { + transceiver, err = t.pc.AddTransceiverFromTrack(trackLocal) + if err != nil { + return + } + + sender = transceiver.Sender() + if sender == nil { + err = ErrNoSender + return + } + + t.queueOrConfigureSender( + transceiver, + enabledCodecs, + rtcpFeedbackConfig, + params.Stereo, + !params.Red || !t.params.ClientInfo.SupportsAudioRED(), + ) + + t.adjustNumOutstandingMedia(transceiver) + return +} + +func (t *PCTransport) AddTransceiverFromKind( + kind webrtc.RTPCodecType, + init webrtc.RTPTransceiverInit, +) (*webrtc.RTPTransceiver, error) { + return t.pc.AddTransceiverFromKind(kind, init) +} + +func (t *PCTransport) RemoveTrack(sender *webrtc.RTPSender) error { + return t.pc.RemoveTrack(sender) +} + +func (t *PCTransport) CurrentLocalDescription() *webrtc.SessionDescription { + cld := t.pc.CurrentLocalDescription() + if cld == nil { + return nil + } + + ld := *cld + return &ld +} + +func (t *PCTransport) CurrentRemoteDescription() *webrtc.SessionDescription { + crd := t.pc.CurrentRemoteDescription() + if crd == nil { + return nil + } + + rd := *crd + return &rd +} + +func (t *PCTransport) PendingRemoteDescription() *webrtc.SessionDescription { + prd := t.pc.PendingRemoteDescription() + if prd == nil { + return nil + } + + rd := *prd + return &rd +} + +func (t *PCTransport) GetMid(rtpReceiver *webrtc.RTPReceiver) string { + tr := rtpReceiver.RTPTransceiver() + if tr != nil { + return tr.Mid() + } + + return "" +} + +func (t *PCTransport) GetRTPTransceiver(mid string) *webrtc.RTPTransceiver { + for _, tr := range t.pc.GetTransceivers() { + if tr.Mid() == mid { + return tr + } + } + + return nil +} + +func (t *PCTransport) GetRTPReceiver(mid string) *webrtc.RTPReceiver { + for _, tr := range t.pc.GetTransceivers() { + if tr.Mid() == mid { + return tr.Receiver() + } + } + + return nil +} + +func (t *PCTransport) getNumUnmatchedTransceivers() (uint32, uint32) { + if t.isClosed.Load() || t.pc.ConnectionState() == webrtc.PeerConnectionStateClosed { + return 0, 0 + } + + numAudios := uint32(0) + numVideos := uint32(0) + for _, tr := range t.pc.GetTransceivers() { + if tr.Mid() != "" { + continue + } + + switch tr.Kind() { + case webrtc.RTPCodecTypeAudio: + numAudios++ + + case webrtc.RTPCodecTypeVideo: + numVideos++ + } + } + + return numAudios, numVideos +} + +func (t *PCTransport) CreateDataChannel(label string, dci *webrtc.DataChannelInit) error { + if label == DataTrackDataChannel && !t.params.EnableDataTracks { + t.params.Logger.Debugw("data tracks not enabled") + return nil + } + + dc, err := t.pc.CreateDataChannel(label, dci) + if err != nil { + return err + } + var ( + dcPtr **datachannel.DataChannelWriter[*webrtc.DataChannel] + dcReady *bool + isDataTrack bool + isUnlabeled bool + kind livekit.DataPacket_Kind + ) + switch dc.Label() { + default: + isUnlabeled = true + t.params.Logger.Infow("unlabeled datachannel added", "label", dc.Label()) + + case ReliableDataChannel: + dcPtr = &t.reliableDC + dcReady = &t.reliableDCOpened + kind = livekit.DataPacket_RELIABLE + + case LossyDataChannel: + dcPtr = &t.lossyDC + dcReady = &t.lossyDCOpened + kind = livekit.DataPacket_LOSSY + + case DataTrackDataChannel: + dcPtr = &t.dataTrackDC + isDataTrack = true + } + + dc.OnOpen(func() { + rawDC, err := dc.DetachWithDeadline() + if err != nil { + t.params.Logger.Warnw("failed to detach data channel", err) + return + } + + var slowThreshold int + if dc.Label() == ReliableDataChannel || isUnlabeled { + slowThreshold = t.params.DatachannelSlowThreshold + } + + t.lock.Lock() + if isUnlabeled { + t.unlabeledDataChannels = append( + t.unlabeledDataChannels, + datachannel.NewDataChannelWriterReliable(dc, rawDC, slowThreshold), + ) + } else { + if *dcPtr != nil { + (*dcPtr).Close() + } + switch { + case dcPtr == &t.reliableDC: + *dcPtr = datachannel.NewDataChannelWriterReliable(dc, rawDC, slowThreshold) + case dcPtr == &t.lossyDC: + *dcPtr = datachannel.NewDataChannelWriterUnreliable(dc, rawDC, t.params.DatachannelLossyTargetLatency, uint64(lossyDataChannelMinBufferedAmount)) + case dcPtr == &t.dataTrackDC: + *dcPtr = datachannel.NewDataChannelWriterUnreliable(dc, rawDC, 0, 0) + } + if dcReady != nil { + *dcReady = true + } + } + t.lock.Unlock() + t.params.Logger.Debugw(dc.Label() + " data channel open") + + go func() { + defer rawDC.Close() + buffer := make([]byte, dataChannelBufferSize) + for { + n, _, err := rawDC.ReadDataChannel(buffer) + if err != nil { + if !errors.Is(err, io.EOF) && !strings.Contains(err.Error(), "state=Closed") { + t.params.Logger.Warnw("error reading data channel", err, "label", dc.Label()) + } + return + } + + switch { + case isUnlabeled: + t.params.Handler.OnDataMessageUnlabeled(buffer[:n]) + + case isDataTrack: + t.params.Handler.OnDataTrackMessage(buffer[:n], mono.UnixNano()) + + default: + t.params.Handler.OnDataMessage(kind, buffer[:n]) + } + } + }() + + t.maybeNotifyFullyEstablished() + }) + + return nil +} + +// for testing only +func (t *PCTransport) CreateReadableDataChannel(label string, dci *webrtc.DataChannelInit) error { + dc, err := t.pc.CreateDataChannel(label, dci) + if err != nil { + return err + } + + dc.OnOpen(func() { + t.params.Logger.Debugw(dc.Label() + " data channel open") + rawDC, err := dc.DetachWithDeadline() + if err != nil { + t.params.Logger.Errorw("failed to detach data channel", err, "label", dc.Label()) + return + } + + t.lock.Lock() + t.unlabeledDataChannels = append( + t.unlabeledDataChannels, + datachannel.NewDataChannelWriterReliable(dc, rawDC, t.params.DatachannelSlowThreshold), + ) + t.lock.Unlock() + + go func() { + defer rawDC.Close() + buffer := make([]byte, dataChannelBufferSize) + for { + n, _, err := rawDC.ReadDataChannel(buffer) + if err != nil { + if !errors.Is(err, io.EOF) && !strings.Contains(err.Error(), "state=Closed") { + t.params.Logger.Warnw("error reading data channel", err, "label", dc.Label()) + } + return + } + + t.params.Handler.OnDataMessageUnlabeled(buffer[:n]) + } + }() + }) + return nil +} + +func (t *PCTransport) CreateDataChannelIfEmpty(dcLabel string, dci *webrtc.DataChannelInit) (label string, id uint16, existing bool, err error) { + if dcLabel == DataTrackDataChannel && !t.params.EnableDataTracks { + t.params.Logger.Debugw("data tracks not enabled") + err = errors.New("data tracks not enabled") + return + } + + t.lock.RLock() + var dcw *datachannel.DataChannelWriter[*webrtc.DataChannel] + switch dcLabel { + case ReliableDataChannel: + dcw = t.reliableDC + case LossyDataChannel: + dcw = t.lossyDC + case DataTrackDataChannel: + dcw = t.dataTrackDC + default: + t.params.Logger.Warnw("unknown data channel label", nil, "label", label) + err = errors.New("unknown data channel label") + } + t.lock.RUnlock() + if err != nil { + return + } + + if dcw != nil { + dc := dcw.BufferedAmountGetter() + return dc.Label(), *dc.ID(), true, nil + } + + dc, err := t.pc.CreateDataChannel(dcLabel, dci) + if err != nil { + return + } + + t.onDataChannel(dc) + return dc.Label(), *dc.ID(), false, nil +} + +func (t *PCTransport) GetRTT() (float64, bool) { + scps, ok := t.iceTransport.GetSelectedCandidatePairStats() + if !ok { + return 0.0, false + } + + return scps.CurrentRoundTripTime, true +} + +func (t *PCTransport) IsEstablished() bool { + return t.pc.ConnectionState() != webrtc.PeerConnectionStateNew +} + +func (t *PCTransport) HasEverConnected() bool { + t.lock.RLock() + defer t.lock.RUnlock() + + return !t.firstConnectedAt.IsZero() +} + +func (t *PCTransport) GetICEConnectionInfo() *types.ICEConnectionInfo { + return t.connectionDetails.GetInfo() +} + +func (t *PCTransport) GetICEConnectionType() types.ICEConnectionType { + return t.connectionDetails.GetConnectionType() +} + +func (t *PCTransport) WriteRTCP(pkts []rtcp.Packet) error { + return t.pc.WriteRTCP(pkts) +} + +func (t *PCTransport) SendDataMessage(kind livekit.DataPacket_Kind, data []byte) error { + convertFromUserPacket := false + var dc *datachannel.DataChannelWriter[*webrtc.DataChannel] + t.lock.RLock() + if t.params.UseOneShotSignallingMode { + if len(t.unlabeledDataChannels) > 0 { + // use the first unlabeled to send + dc = t.unlabeledDataChannels[0] + } + convertFromUserPacket = true + } else { + if kind == livekit.DataPacket_RELIABLE { + dc = t.reliableDC + } else { + dc = t.lossyDC + } + } + t.lock.RUnlock() + + if convertFromUserPacket { + dp := &livekit.DataPacket{} + if err := proto.Unmarshal(data, dp); err != nil { + return err + } + + switch payload := dp.Value.(type) { + case *livekit.DataPacket_User: + return t.sendDataMessage(dc, payload.User.Payload) + default: + return errors.New("cannot forward non user data packet") + } + } + + return t.sendDataMessage(dc, data) +} + +func (t *PCTransport) SendDataMessageUnlabeled(data []byte, useRaw bool, sender livekit.ParticipantIdentity) error { + convertToUserPacket := false + var dc *datachannel.DataChannelWriter[*webrtc.DataChannel] + t.lock.RLock() + if t.params.UseOneShotSignallingMode || useRaw { + if len(t.unlabeledDataChannels) > 0 { + // use the first unlabeled to send + dc = t.unlabeledDataChannels[0] + } + } else { + if t.reliableDC != nil { + dc = t.reliableDC + } else if t.lossyDC != nil { + dc = t.lossyDC + } + + convertToUserPacket = true + } + t.lock.RUnlock() + + if convertToUserPacket { + dpData, err := proto.Marshal(&livekit.DataPacket{ + ParticipantIdentity: string(sender), + Value: &livekit.DataPacket_User{ + User: &livekit.UserPacket{Payload: data}, + }, + }) + if err != nil { + return err + } + return t.sendDataMessage(dc, dpData) + } + + return t.sendDataMessage(dc, data) +} + +func (t *PCTransport) SendDataTrackMessage(data []byte) error { + t.lock.RLock() + dc := t.dataTrackDC + t.lock.RUnlock() + + return t.sendDataMessage(dc, data) +} + +func (t *PCTransport) sendDataMessage(dc *datachannel.DataChannelWriter[*webrtc.DataChannel], data []byte) error { + if dc == nil { + return ErrDataChannelUnavailable + } + + if t.pc.ConnectionState() == webrtc.PeerConnectionStateFailed { + return ErrTransportFailure + } + + if t.params.DatachannelSlowThreshold == 0 && t.params.DataChannelMaxBufferedAmount > 0 && dc.BufferedAmountGetter().BufferedAmount() > t.params.DataChannelMaxBufferedAmount { + return ErrDataChannelBufferFull + } + _, err := dc.Write(data) + return err +} + +func (t *PCTransport) Close() { + if t.isClosed.Swap(true) { + return + } + + <-t.eventsQueue.Stop() + t.clearSignalStateCheckTimer() + + if t.streamAllocator != nil { + t.streamAllocator.Stop() + } + + if t.pacer != nil { + t.pacer.Stop() + } + + t.clearConnTimer() + + t.lock.Lock() + if t.mayFailedICEStatsTimer != nil { + t.mayFailedICEStatsTimer.Stop() + t.mayFailedICEStatsTimer = nil + } + + if t.reliableDC != nil { + t.reliableDC.Close() + t.reliableDC = nil + } + + if t.lossyDC != nil { + t.lossyDC.Close() + t.lossyDC = nil + } + + if t.dataTrackDC != nil { + t.dataTrackDC.Close() + t.dataTrackDC = nil + } + + for _, dc := range t.unlabeledDataChannels { + dc.Close() + } + t.unlabeledDataChannels = nil + t.lock.Unlock() + + if err := t.pc.Close(); err != nil { + t.params.Logger.Warnw("unclean close of peer connection", err) + } + + t.outputAndClearICEStats() +} + +func (t *PCTransport) clearConnTimer() { + t.lock.Lock() + defer t.lock.Unlock() + + if t.connectAfterICETimer != nil { + t.connectAfterICETimer.Stop() + t.connectAfterICETimer = nil + } + + if t.tcpICETimer != nil { + t.tcpICETimer.Stop() + t.tcpICETimer = nil + } +} + +func (t *PCTransport) HandleRemoteDescription(sd webrtc.SessionDescription, remoteId uint32) error { + if t.params.UseOneShotSignallingMode { + if sd.Type == webrtc.SDPTypeOffer { + remoteOfferId := t.remoteOfferId.Load() + if remoteOfferId != 0 && remoteOfferId != t.localAnswerId.Load() { + t.params.Logger.Warnw( + "sdp state: multiple offers without answer", nil, + "remoteOfferId", remoteOfferId, + "localAnswerId", t.localAnswerId.Load(), + "receivedRemoteOfferId", remoteId, + ) + } + t.remoteOfferId.Store(remoteId) + } else { + if remoteId != 0 && remoteId != t.localOfferId.Load() { + t.params.Logger.Warnw("sdp state: answer id mismatch", nil, "expected", t.localOfferId.Load(), "got", remoteId) + } + t.remoteAnswerId.Store(remoteId) + } + + // add remote candidates to ICE connection details + parsed, err := sd.Unmarshal() + if err == nil { + addRemoteICECandidates := func(attrs []sdp.Attribute) { + for _, a := range attrs { + if a.IsICECandidate() { + c, err := ice.UnmarshalCandidate(a.Value) + if err != nil { + continue + } + t.connectionDetails.AddRemoteICECandidate(c, false, false, false) + } + } + } + + addRemoteICECandidates(parsed.Attributes) + for _, m := range parsed.MediaDescriptions { + addRemoteICECandidates(m.Attributes) + } + } + + err = t.pc.SetRemoteDescription(sd) + if err != nil { + t.params.Logger.Errorw("could not set remote description on synchronous mode peer connection", err) + return err + } + + rtxRepairs := nonSimulcastRTXRepairsFromSDP(parsed, t.params.Logger) + if len(rtxRepairs) > 0 { + t.params.Logger.Debugw("rtx pairs found from sdp", "ssrcs", rtxRepairs) + for repair, base := range rtxRepairs { + t.params.Config.BufferFactory.SetRTXPair(repair, base, "") + } + } + return nil + } + + t.postEvent(event{ + signal: signalRemoteDescriptionReceived, + data: remoteDescriptionData{ + sessionDescription: &sd, + remoteId: remoteId, + }, + }) + return nil +} + +func (t *PCTransport) GetAnswer() (webrtc.SessionDescription, uint32, error) { + if !t.params.UseOneShotSignallingMode { + return webrtc.SessionDescription{}, 0, ErrNotSynchronousLocalCandidatesMode + } + + prd := t.pc.PendingRemoteDescription() + if prd == nil || prd.Type != webrtc.SDPTypeOffer { + return webrtc.SessionDescription{}, 0, ErrNoRemoteDescription + } + + answer, err := t.pc.CreateAnswer(nil) + if err != nil { + return webrtc.SessionDescription{}, 0, err + } + + if err = t.pc.SetLocalDescription(answer); err != nil { + return webrtc.SessionDescription{}, 0, err + } + + // wait for gathering to complete to include all candidates in the answer + <-webrtc.GatheringCompletePromise(t.pc) + + cld := t.pc.CurrentLocalDescription() + + // add local candidates to ICE connection details + parsed, err := cld.Unmarshal() + if err == nil { + addLocalICECandidates := func(attrs []sdp.Attribute) { + for _, a := range attrs { + if a.IsICECandidate() { + c, err := ice.UnmarshalCandidate(a.Value) + if err != nil { + continue + } + t.connectionDetails.AddLocalICECandidate(c, false, false) + } + } + } + + addLocalICECandidates(parsed.Attributes) + for _, m := range parsed.MediaDescriptions { + addLocalICECandidates(m.Attributes) + } + } + + answerId := t.remoteOfferId.Load() + t.localAnswerId.Store(answerId) + + return *cld, answerId, nil +} + +func (t *PCTransport) GetICESessionUfrag() (string, error) { + cld := t.pc.CurrentLocalDescription() + if cld == nil { + return "", ErrNoLocalDescription + } + + parsed, err := cld.Unmarshal() + if err != nil { + return "", err + } + + ufrag, _, err := lksdp.ExtractICECredential(parsed) + if err != nil { + return "", err + } + + return ufrag, nil +} + +// Handles SDP Fragment for ICE Trickle in WHIP +func (t *PCTransport) HandleICETrickleSDPFragment(sdpFragment string) error { + if !t.params.UseOneShotSignallingMode { + return ErrNotSynchronousLocalCandidatesMode + } + + parsedFragment := &lksdp.SDPFragment{} + if err := parsedFragment.Unmarshal(sdpFragment); err != nil { + t.params.Logger.Warnw("could not parse SDP fragment", err, "sdpFragment", sdpFragment) + return ErrInvalidSDPFragment + } + + crd := t.pc.CurrentRemoteDescription() + if crd == nil { + t.params.Logger.Warnw("no remote description", nil) + return ErrNoRemoteDescription + } + + parsedRemote, err := crd.Unmarshal() + if err != nil { + t.params.Logger.Warnw("could not parse remote description", err, "offer", crd) + return err + } + + // check if BUNDLE mid matches the "mid" in the SDP fragment + bundleMid, found := lksdp.GetBundleMid(parsedRemote) + if !found { + return ErrNoBundleMid + } + + if parsedFragment.Mid() != bundleMid { + t.params.Logger.Warnw("incorrect mid", nil, "sdpFragment", sdpFragment) + return ErrMidMismatch + } + + fragmentICEUfrag, fragmentICEPwd, err := parsedFragment.ExtractICECredential() + if err != nil { + t.params.Logger.Warnw( + "could not get ICE crendential from fragment", err, + "sdpFragment", sdpFragment, + ) + return ErrInvalidSDPFragment + } + remoteICEUfrag, remoteICEPwd, err := lksdp.ExtractICECredential(parsedRemote) + if err != nil { + t.params.Logger.Warnw("could not get ICE crendential from remote description", err, "sdpFragment", sdpFragment, "remoteDescription", crd) + return err + } + if fragmentICEUfrag != "" && fragmentICEUfrag != remoteICEUfrag { + t.params.Logger.Warnw( + "ice ufrag mismatch", nil, + "remoteICEUfrag", remoteICEUfrag, + "fragmentICEUfrag", fragmentICEUfrag, + "sdpFragment", sdpFragment, + "remoteDescription", crd, + ) + return ErrICECredentialMismatch + } + if fragmentICEPwd != "" && fragmentICEPwd != remoteICEPwd { + t.params.Logger.Warnw( + "ice pwd mismatch", nil, + "remoteICEPwd", remoteICEPwd, + "fragmentICEPwd", fragmentICEPwd, + "sdpFragment", sdpFragment, + "remoteDescription", crd, + ) + return ErrICECredentialMismatch + } + + // add candidates from media description + for _, ic := range parsedFragment.Candidates() { + c, err := ice.UnmarshalCandidate(ic) + if err == nil { + t.connectionDetails.AddRemoteICECandidate(c, false, false, false) + } + + candidate := webrtc.ICECandidateInit{ + Candidate: ic, + } + if err := t.pc.AddICECandidate(candidate); err != nil { + t.params.Logger.Warnw("failed to add ICE candidate", err, "candidate", candidate) + } else { + t.params.Logger.Debugw("added ICE candidate", "candidate", candidate) + } + } + return nil +} + +// Handles SDP Fragment for ICE Restart in WHIP +func (t *PCTransport) HandleICERestartSDPFragment(sdpFragment string) (string, error) { + if !t.params.UseOneShotSignallingMode { + return "", ErrNotSynchronousLocalCandidatesMode + } + + parsedFragment := &lksdp.SDPFragment{} + if err := parsedFragment.Unmarshal(sdpFragment); err != nil { + t.params.Logger.Warnw("could not parse SDP fragment", err, "sdpFragment", sdpFragment) + return "", ErrInvalidSDPFragment + } + + crd := t.pc.CurrentRemoteDescription() + if crd == nil { + t.params.Logger.Warnw("no remote description", nil) + return "", ErrNoRemoteDescription + } + + parsedRemote, err := crd.Unmarshal() + if err != nil { + t.params.Logger.Warnw("could not parse remote description", err, "offer", crd) + return "", err + } + + if err := parsedFragment.PatchICECredentialAndCandidatesIntoSDP(parsedRemote); err != nil { + t.params.Logger.Warnw("could not patch SDP fragment into remote description", err, "offer", crd, "sdpFragment", sdpFragment) + return "", err + } + + bytes, err := parsedRemote.Marshal() + if err != nil { + t.params.Logger.Warnw("could not marshal SDP with patched remote", err) + return "", err + } + sd := webrtc.SessionDescription{ + SDP: string(bytes), + Type: webrtc.SDPTypeOffer, + } + if err := t.pc.SetRemoteDescription(sd); err != nil { + t.params.Logger.Warnw("could not set remote description", err) + return "", err + } + + // clear out connection details on ICE restart and re-populate + t.connectionDetails.Clear() + for _, candidate := range parsedFragment.Candidates() { + c, err := ice.UnmarshalCandidate(candidate) + if err != nil { + continue + } + t.connectionDetails.AddRemoteICECandidate(c, false, false, false) + } + + ans, err := t.pc.CreateAnswer(nil) + if err != nil { + t.params.Logger.Warnw("could not create answer", err) + return "", err + } + + if err = t.pc.SetLocalDescription(ans); err != nil { + t.params.Logger.Warnw("could not set local description", err) + return "", err + } + + // wait for gathering to complete to include all candidates in the answer + <-webrtc.GatheringCompletePromise(t.pc) + + cld := t.pc.CurrentLocalDescription() + + // add local candidates to ICE connection details + parsedAnswer, err := cld.Unmarshal() + if err != nil { + t.params.Logger.Warnw("could not parse local description", err) + return "", err + } + + addLocalICECandidates := func(attrs []sdp.Attribute) { + for _, a := range attrs { + if a.IsICECandidate() { + c, err := ice.UnmarshalCandidate(a.Value) + if err != nil { + continue + } + t.connectionDetails.AddLocalICECandidate(c, false, false) + } + } + } + + addLocalICECandidates(parsedAnswer.Attributes) + for _, m := range parsedAnswer.MediaDescriptions { + addLocalICECandidates(m.Attributes) + } + + parsedFragmentAnswer, err := lksdp.ExtractSDPFragment(parsedAnswer) + if err != nil { + t.params.Logger.Warnw("could not extract SDP fragment", err) + return "", err + } + + answerFragment, err := parsedFragmentAnswer.Marshal() + if err != nil { + t.params.Logger.Warnw("could not marshal answer SDP fragment", err) + return "", err + } + + return answerFragment, nil +} + +func (t *PCTransport) OnNegotiationStateChanged(f func(state transport.NegotiationState)) { + t.lock.Lock() + t.onNegotiationStateChanged = f + t.lock.Unlock() +} + +func (t *PCTransport) getOnNegotiationStateChanged() func(state transport.NegotiationState) { + t.lock.RLock() + defer t.lock.RUnlock() + + return t.onNegotiationStateChanged +} + +func (t *PCTransport) Negotiate(force bool) { + if t.isClosed.Load() { + return + } + + var postEvent bool + t.lock.Lock() + if force { + t.debouncedNegotiate.Add(func() { + // no op to cancel pending negotiation + }) + t.debouncePending = false + t.updateLastNegotiateLocked() + + postEvent = true + } else { + if !t.debouncePending { + if time.Since(t.lastNegotiate) > negotiationFrequency { + t.debouncedNegotiate.SetDuration(fastNegotiationFrequency) + } else { + t.debouncedNegotiate.SetDuration(negotiationFrequency) + } + + t.debouncedNegotiate.Add(func() { + t.lock.Lock() + t.debouncePending = false + t.updateLastNegotiateLocked() + t.lock.Unlock() + + t.postEvent(event{ + signal: signalSendOffer, + }) + }) + t.debouncePending = true + } + } + t.lock.Unlock() + + if postEvent { + t.postEvent(event{ + signal: signalSendOffer, + }) + } +} + +func (t *PCTransport) updateLastNegotiateLocked() { + if now := time.Now(); now.After(t.lastNegotiate) { + t.lastNegotiate = now + } +} + +func (t *PCTransport) ICERestart() error { + if t.pc.ConnectionState() == webrtc.PeerConnectionStateClosed { + t.params.Logger.Warnw("trying to restart ICE on closed peer connection", nil) + return ErrIceRestartOnClosedPeerConnection + } + + t.postEvent(event{ + signal: signalICERestart, + }) + return nil +} + +func (t *PCTransport) ResetShortConnOnICERestart() { + t.resetShortConnOnICERestart.Store(true) +} + +func (t *PCTransport) AddTrackToStreamAllocator(subTrack types.SubscribedTrack) { + if t.streamAllocator == nil { + return + } + + layers := buffer.GetVideoLayersForMimeType( + subTrack.DownTrack().Mime(), + subTrack.MediaTrack().ToProto(), + ) + t.streamAllocator.AddTrack(subTrack.DownTrack(), streamallocator.AddTrackParams{ + Source: subTrack.MediaTrack().Source(), + IsMultiLayered: len(layers) > 1, + PublisherID: subTrack.MediaTrack().PublisherID(), + }) +} + +func (t *PCTransport) RemoveTrackFromStreamAllocator(subTrack types.SubscribedTrack) { + if t.streamAllocator == nil { + return + } + + t.streamAllocator.RemoveTrack(subTrack.DownTrack()) +} + +func (t *PCTransport) SetAllowPauseOfStreamAllocator(allowPause bool) { + if t.streamAllocator == nil { + return + } + + t.streamAllocator.SetAllowPause(allowPause) +} + +func (t *PCTransport) SetChannelCapacityOfStreamAllocator(channelCapacity int64) { + if t.streamAllocator == nil { + return + } + + t.streamAllocator.SetChannelCapacity(channelCapacity) +} + +func (t *PCTransport) preparePC(previousAnswer webrtc.SessionDescription) error { + // sticky data channel to first m-lines, if someday we don't send sdp without media streams to + // client's subscribe pc after joining, should change this step + parsed, err := previousAnswer.Unmarshal() + if err != nil { + return err + } + fp, fpHahs, err := lksdp.ExtractFingerprint(parsed) + if err != nil { + return err + } + + offer, err := t.pc.CreateOffer(nil) + if err != nil { + return err + } + if err := t.pc.SetLocalDescription(offer); err != nil { + return err + } + + // + // Simulate client side peer connection and set DTLS role from previous answer. + // Role needs to be set properly (one side needs to be server and the other side + // needs to be the client) for DTLS connection to form properly. As this is + // trying to replicate previous setup, read from previous answer and use that role. + // + se := webrtc.SettingEngine{} + _ = se.SetAnsweringDTLSRole(lksdp.ExtractDTLSRole(parsed)) + se.SetIgnoreRidPauseForRecv(true) + api := webrtc.NewAPI( + webrtc.WithSettingEngine(se), + webrtc.WithMediaEngine(t.me), + ) + pc2, err := api.NewPeerConnection(webrtc.Configuration{ + SDPSemantics: webrtc.SDPSemanticsUnifiedPlan, + }) + if err != nil { + return err + } + defer pc2.Close() + + if err := pc2.SetRemoteDescription(offer); err != nil { + return err + } + ans, err := pc2.CreateAnswer(nil) + if err != nil { + return err + } + + // replace client's fingerprint into dummy pc's answer, for pion's dtls process, it will + // keep the fingerprint at first call of SetRemoteDescription, if dummy pc and client pc use + // different fingerprint, that will cause pion denied dtls data after handshake with client + // complete (can't pass fingerprint change). + // in this step, we don't established connection with dummy pc(no candidate swap), just use + // sdp negotiation to sticky data channel and keep client's fingerprint + parsedAns, _ := ans.Unmarshal() + fpLine := fpHahs + " " + fp + replaceFP := func(attrs []sdp.Attribute, fpLine string) { + for k := range attrs { + if attrs[k].Key == "fingerprint" { + attrs[k].Value = fpLine + } + } + } + replaceFP(parsedAns.Attributes, fpLine) + for _, m := range parsedAns.MediaDescriptions { + replaceFP(m.Attributes, fpLine) + } + bytes, err := parsedAns.Marshal() + if err != nil { + return err + } + ans.SDP = string(bytes) + + return t.pc.SetRemoteDescription(ans) +} + +func (t *PCTransport) initPCWithPreviousAnswer(previousAnswer webrtc.SessionDescription) (map[string]*webrtc.RTPSender, error) { + senders := make(map[string]*webrtc.RTPSender) + parsed, err := previousAnswer.Unmarshal() + if err != nil { + return senders, err + } + for _, m := range parsed.MediaDescriptions { + var codecType webrtc.RTPCodecType + switch m.MediaName.Media { + case "video": + codecType = webrtc.RTPCodecTypeVideo + case "audio": + codecType = webrtc.RTPCodecTypeAudio + case "application": + if t.params.IsOfferer { + // for pion generate unmatched sdp, it always appends data channel to last m-lines, + // that not consistent with our previous answer that data channel might at middle-line + // because sdp can negotiate multi times before migration.(it will sticky to the last m-line at first negotiate) + // so use a dummy pc to negotiate sdp to fixed the datachannel's mid at same position with previous answer + if err := t.preparePC(previousAnswer); err != nil { + t.params.Logger.Warnw("prepare pc for migration failed", err) + return senders, err + } + } + continue + default: + continue + } + + if !t.params.IsOfferer { + // `sendrecv` or `sendonly` means this transceiver is used for sending + + // Note that a transceiver previously used to send could be `inactive`. + // Let those transceivers be created when remote description is set. + _, ok1 := m.Attribute(webrtc.RTPTransceiverDirectionSendrecv.String()) + _, ok2 := m.Attribute(webrtc.RTPTransceiverDirectionSendonly.String()) + if !ok1 && !ok2 { + continue + } + } + + tr, err := t.pc.AddTransceiverFromKind( + codecType, + webrtc.RTPTransceiverInit{ + Direction: webrtc.RTPTransceiverDirectionSendonly, + }, + ) + if err != nil { + return senders, err + } + mid := lksdp.GetMidValue(m) + if mid == "" { + return senders, ErrMidNotFound + } + tr.SetMid(mid) + + // save mid -> senders for migration reuse + sender := tr.Sender() + senders[mid] = sender + + // set transceiver to inactive + tr.SetSender(sender, nil) + } + return senders, nil +} + +func (t *PCTransport) SetPreviousSdp(localDescription, remoteDescription *webrtc.SessionDescription) { + // when there is no answer, cannot migrate, force a full reconnect + if (t.params.IsOfferer && remoteDescription == nil) || (!t.params.IsOfferer && localDescription == nil) { + t.onNegotiationFailed(true, "no previous answer") + return + } + + t.lock.Lock() + var ( + senders map[string]*webrtc.RTPSender + err error + parseMids bool + ) + if t.params.IsOfferer { + if t.pc.RemoteDescription() == nil && t.previousAnswer == nil { + t.previousAnswer = remoteDescription + senders, err = t.initPCWithPreviousAnswer(*remoteDescription) + parseMids = true + } + } else { + if t.pc.LocalDescription() == nil { + senders, err = t.initPCWithPreviousAnswer(*localDescription) + parseMids = true + } + } + if err != nil { + t.lock.Unlock() + t.onNegotiationFailed(true, fmt.Sprintf("initPCWithPreviousAnswer failed, error: %s", err)) + return + } + + if localDescription != nil && parseMids { + // in migration case, can't reuse transceiver before negotiating excepted tracks + // that were subscribed at previous node + t.canReuseTransceiver = false + if err := t.parseTrackMid(*localDescription, senders); err != nil { + t.params.Logger.Warnw( + "parse previous local description failed", err, + "localDescription", localDescription.SDP, + ) + } + } + + if t.params.IsOfferer { + // disable fast negotiation temporarily after migration to avoid sending offer + // contains part of subscribed tracks before migration, let the subscribed track + // resume at the same time. + t.lastNegotiate = time.Now().Add(iceFailedTimeoutTotal) + } + t.lock.Unlock() +} + +func (t *PCTransport) parseTrackMid(sd webrtc.SessionDescription, senders map[string]*webrtc.RTPSender) error { + parsed, err := sd.Unmarshal() + if err != nil { + return err + } + + t.previousTrackDescription = make(map[string]*trackDescription) + for _, m := range parsed.MediaDescriptions { + msid, ok := m.Attribute(sdp.AttrKeyMsid) + if !ok { + continue + } + + if split := strings.Split(msid, " "); len(split) == 2 { + trackID := split[1] + mid := lksdp.GetMidValue(m) + if mid == "" { + return ErrMidNotFound + } + if sender, ok := senders[mid]; ok { + t.previousTrackDescription[trackID] = &trackDescription{mid, sender} + } + } + } + return nil +} + +func (t *PCTransport) postEvent(e event) { + e.PCTransport = t + t.eventsQueue.Enqueue(func(e event) { + var err error + switch e.signal { + case signalICEGatheringComplete: + err = e.handleICEGatheringComplete(e) + case signalLocalICECandidate: + err = e.handleLocalICECandidate(e) + case signalRemoteICECandidate: + err = e.handleRemoteICECandidate(e) + case signalSendOffer: + err = e.handleSendOffer(e) + case signalRemoteDescriptionReceived: + err = e.handleRemoteDescriptionReceived(e) + case signalICERestart: + err = e.handleICERestart(e) + } + if err != nil { + if !e.isClosed.Load() { + e.onNegotiationFailed(true, fmt.Sprintf("error handling event. err: %s, event: %s", err, e)) + } + } + }, e) +} + +func (t *PCTransport) handleICEGatheringComplete(_ event) error { + if t.params.IsOfferer { + return t.handleICEGatheringCompleteOfferer() + } else { + return t.handleICEGatheringCompleteAnswerer() + } +} + +func (t *PCTransport) handleICEGatheringCompleteOfferer() error { + if !t.restartAfterGathering { + return nil + } + + t.params.Logger.Debugw("restarting ICE after ICE gathering") + t.restartAfterGathering = false + return t.doICERestart() +} + +func (t *PCTransport) handleICEGatheringCompleteAnswerer() error { + if t.pendingRestartIceOffer == nil { + return nil + } + + offer := *t.pendingRestartIceOffer + t.pendingRestartIceOffer = nil + + t.params.Logger.Debugw("accept remote restart ice offer after ICE gathering") + if err := t.setRemoteDescription(offer); err != nil { + return err + } + t.params.Handler.OnSetRemoteDescriptionOffer() + t.processSendersPendingConfig() + + return t.createAndSendAnswer() +} + +func (t *PCTransport) localDescriptionSent() error { + if !t.cacheLocalCandidates { + return nil + } + + t.cacheLocalCandidates = false + + cachedLocalCandidates := t.cachedLocalCandidates + t.cachedLocalCandidates = nil + + for _, c := range cachedLocalCandidates { + if err := t.params.Handler.OnICECandidate(c, t.params.Transport); err != nil { + t.params.Logger.Warnw("failed to send cached ICE candidate", err, "candidate", c) + return err + } + } + return nil +} + +func (t *PCTransport) clearLocalDescriptionSent() { + t.cacheLocalCandidates = true + t.cachedLocalCandidates = nil + t.connectionDetails.Clear() +} + +func (t *PCTransport) handleLocalICECandidate(e event) error { + c := e.data.(*webrtc.ICECandidate) + + filtered := false + if c != nil { + if t.preferTCP.Load() && c.Protocol != webrtc.ICEProtocolTCP { + t.params.Logger.Debugw("filtering out local candidate", "candidate", c.String()) + filtered = true + } + t.connectionDetails.AddLocalCandidate(c, filtered, true) + } + + if filtered { + return nil + } + + if t.cacheLocalCandidates { + t.cachedLocalCandidates = append(t.cachedLocalCandidates, c) + return nil + } + + if err := t.params.Handler.OnICECandidate(c, t.params.Transport); err != nil { + t.params.Logger.Warnw("failed to send ICE candidate", err, "candidate", c) + return err + } + + return nil +} + +func (t *PCTransport) handleRemoteICECandidate(e event) error { + c := e.data.(*webrtc.ICECandidateInit) + + filtered := false + if t.preferTCP.Load() && !strings.Contains(strings.ToLower(c.Candidate), "tcp") { + t.params.Logger.Debugw("filtering out remote candidate", "candidate", c.Candidate) + filtered = true + } + + if !t.params.Config.UseMDNS && types.IsCandidateMDNS(*c) { + t.params.Logger.Debugw("ignoring mDNS candidate", "candidate", c.Candidate) + filtered = true + } + + t.connectionDetails.AddRemoteCandidate(*c, filtered, true, false) + if filtered { + return nil + } + + if t.pc.RemoteDescription() == nil { + t.pendingRemoteCandidates = append(t.pendingRemoteCandidates, c) + return nil + } + + if err := t.pc.AddICECandidate(*c); err != nil { + t.params.Logger.Warnw("failed to add ICE candidate", err, "candidate", c) + return errors.Wrap(err, "add ice candidate failed") + } else { + t.params.Logger.Debugw("added ICE candidate", "candidate", c) + } + + return nil +} + +func (t *PCTransport) setNegotiationState(state transport.NegotiationState) { + t.negotiationState = state + if onNegotiationStateChanged := t.getOnNegotiationStateChanged(); onNegotiationStateChanged != nil { + onNegotiationStateChanged(t.negotiationState) + } +} + +func (t *PCTransport) filterCandidates(sd webrtc.SessionDescription, preferTCP, isLocal bool) webrtc.SessionDescription { + parsed, err := sd.Unmarshal() + if err != nil { + t.params.Logger.Warnw("could not unmarshal SDP to filter candidates", err) + return sd + } + + filterAttributes := func(attrs []sdp.Attribute) []sdp.Attribute { + filteredAttrs := make([]sdp.Attribute, 0, len(attrs)) + for _, a := range attrs { + if a.IsICECandidate() { + c, err := ice.UnmarshalCandidate(a.Value) + if err != nil { + t.params.Logger.Errorw("failed to unmarshal candidate in sdp", err, "isLocal", isLocal, "sdp", sd.SDP) + filteredAttrs = append(filteredAttrs, a) + continue + } + excluded := preferTCP && !c.NetworkType().IsTCP() + if !excluded { + if !t.params.Config.UseMDNS && types.IsICECandidateMDNS(c) { + excluded = true + } + } + if !excluded { + filteredAttrs = append(filteredAttrs, a) + } + + if isLocal { + t.connectionDetails.AddLocalICECandidate(c, excluded, false) + } else { + t.connectionDetails.AddRemoteICECandidate(c, excluded, false, false) + } + } else { + filteredAttrs = append(filteredAttrs, a) + } + } + + return filteredAttrs + } + + parsed.Attributes = filterAttributes(parsed.Attributes) + for _, m := range parsed.MediaDescriptions { + m.Attributes = filterAttributes(m.Attributes) + } + + bytes, err := parsed.Marshal() + if err != nil { + t.params.Logger.Warnw("could not marshal SDP to filter candidates", err) + return sd + } + sd.SDP = string(bytes) + return sd +} + +func (t *PCTransport) clearSignalStateCheckTimer() { + if t.signalStateCheckTimer != nil { + t.signalStateCheckTimer.Stop() + t.signalStateCheckTimer = nil + } +} + +func (t *PCTransport) setupSignalStateCheckTimer() { + t.clearSignalStateCheckTimer() + + negotiateVersion := t.negotiateCounter.Inc() + t.signalStateCheckTimer = time.AfterFunc(negotiationFailedTimeout, func() { + t.clearSignalStateCheckTimer() + + failed := t.negotiationState != transport.NegotiationStateNone + + if t.negotiateCounter.Load() == negotiateVersion && failed && t.pc.ConnectionState() == webrtc.PeerConnectionStateConnected { + t.onNegotiationFailed(false, "negotiation timed out") + } + }) +} + +func (t *PCTransport) adjustNumOutstandingMedia(transceiver *webrtc.RTPTransceiver) { + if transceiver.Mid() != "" { + return + } + + t.lock.Lock() + if transceiver.Kind() == webrtc.RTPCodecTypeAudio { + t.numOutstandingAudios++ + } else { + t.numOutstandingVideos++ + } + t.lock.Unlock() +} + +func (t *PCTransport) sendUnmatchedMediaRequirement(force bool) error { + // if there are unmatched media sections, notify remote peer to generate offer with + // enough media section in subsequent offers + t.lock.Lock() + numAudios := t.numOutstandingAudios - t.numRequestSentAudios + t.numRequestSentAudios += numAudios + + numVideos := t.numOutstandingVideos - t.numRequestSentVideos + t.numRequestSentVideos += numVideos + t.lock.Unlock() + + if force || (numAudios+numVideos) != 0 { + if err := t.params.Handler.OnUnmatchedMedia(numAudios, numVideos); err != nil { + return errors.Wrap(err, "could not send unmatched media requirements") + } + } + + return nil +} + +func (t *PCTransport) createAndSendOffer(options *webrtc.OfferOptions) error { + if t.pc.ConnectionState() == webrtc.PeerConnectionStateClosed { + t.params.Logger.Warnw("trying to send offer on closed peer connection", nil) + return nil + } + + // when there's an ongoing negotiation, let it finish and not disrupt its state + if t.negotiationState == transport.NegotiationStateRemote { + t.params.Logger.Debugw("skipping negotiation, trying again later") + t.setNegotiationState(transport.NegotiationStateRetry) + return nil + } else if t.negotiationState == transport.NegotiationStateRetry { + // already set to retry, we can safely skip this attempt + return nil + } + + ensureICERestart := func(options *webrtc.OfferOptions) *webrtc.OfferOptions { + if options == nil { + options = &webrtc.OfferOptions{} + } + options.ICERestart = true + return options + } + + t.lock.Lock() + if t.previousAnswer != nil { + t.previousAnswer = nil + options = ensureICERestart(options) + t.params.Logger.Infow("ice restart due to previous answer") + } + t.lock.Unlock() + + if t.restartAtNextOffer { + t.restartAtNextOffer = false + options = ensureICERestart(options) + t.params.Logger.Infow("ice restart at next offer") + } + + if options != nil && options.ICERestart { + t.clearLocalDescriptionSent() + } + + offer, err := t.pc.CreateOffer(options) + if err != nil { + if errors.Is(err, webrtc.ErrConnectionClosed) { + t.params.Logger.Warnw("trying to create offer on closed peer connection", nil) + return nil + } + + prometheus.RecordServiceOperationError("offer", "create") + return errors.Wrap(err, "create offer failed") + } + + preferTCP := t.preferTCP.Load() + if preferTCP { + t.params.Logger.Debugw("local offer (unfiltered)", "sdp", offer.SDP) + } + + err = t.pc.SetLocalDescription(offer) + if err != nil { + if errors.Is(err, webrtc.ErrConnectionClosed) { + t.params.Logger.Warnw("trying to set local description on closed peer connection", nil) + return nil + } + + prometheus.RecordServiceOperationError("offer", "local_description") + return errors.Wrap(err, "setting local description failed") + } + + // + // Filter after setting local description as pion expects the offer + // to match between CreateOffer and SetLocalDescription. + // Filtered offer is sent to remote so that remote does not + // see filtered candidates. + // + offer = t.filterCandidates(offer, preferTCP, true) + if preferTCP { + t.params.Logger.Debugw("local offer (filtered)", "sdp", offer.SDP) + } + + // indicate waiting for remote + t.setNegotiationState(transport.NegotiationStateRemote) + + t.setupSignalStateCheckTimer() + + remoteAnswerId := t.remoteAnswerId.Load() + if remoteAnswerId != 0 && remoteAnswerId != t.localOfferId.Load() { + t.params.Logger.Warnw( + "sdp state: sending offer before receiving answer", nil, + "localOfferId", t.localOfferId.Load(), + "remoteAnswerId", remoteAnswerId, + ) + } + + if err := t.params.Handler.OnOffer(offer, t.localOfferId.Inc(), t.getMidToTrackIDMapping()); err != nil { + prometheus.RecordServiceOperationError("offer", "write_message") + return errors.Wrap(err, "could not send offer") + } + prometheus.RecordServiceOperationSuccess("offer") + + return t.localDescriptionSent() +} + +func (t *PCTransport) handleSendOffer(_ event) error { + if !t.params.IsOfferer { + return t.sendUnmatchedMediaRequirement(true) + } + + return t.createAndSendOffer(nil) +} + +type remoteDescriptionData struct { + sessionDescription *webrtc.SessionDescription + remoteId uint32 +} + +func (t *PCTransport) handleRemoteDescriptionReceived(e event) error { + rdd := e.data.(remoteDescriptionData) + if rdd.sessionDescription.Type == webrtc.SDPTypeOffer { + return t.handleRemoteOfferReceived(rdd.sessionDescription, rdd.remoteId) + } else { + return t.handleRemoteAnswerReceived(rdd.sessionDescription, rdd.remoteId) + } +} + +func (t *PCTransport) isRemoteOfferRestartICE(parsed *sdp.SessionDescription) (string, bool, error) { + user, pwd, err := lksdp.ExtractICECredential(parsed) + if err != nil { + return "", false, err + } + + credential := fmt.Sprintf("%s:%s", user, pwd) + // ice credential changed, remote offer restart ice + restartICE := t.currentOfferIceCredential != "" && t.currentOfferIceCredential != credential + return credential, restartICE, nil +} + +func (t *PCTransport) setRemoteDescription(sd webrtc.SessionDescription) error { + // filter before setting remote description so that pion does not see filtered remote candidates + preferTCP := t.preferTCP.Load() + if preferTCP { + t.params.Logger.Debugw("remote description (unfiltered)", "type", sd.Type, "sdp", sd.SDP) + } + sd = t.filterCandidates(sd, preferTCP, false) + if preferTCP { + t.params.Logger.Debugw("remote description (filtered)", "type", sd.Type, "sdp", sd.SDP) + } + + if err := t.pc.SetRemoteDescription(sd); err != nil { + if errors.Is(err, webrtc.ErrConnectionClosed) { + t.params.Logger.Warnw("trying to set remote description on closed peer connection", nil) + return nil + } + + sdpType := "offer" + if sd.Type == webrtc.SDPTypeAnswer { + sdpType = "answer" + } + prometheus.RecordServiceOperationError(sdpType, "remote_description") + return errors.Wrap(err, "setting remote description failed") + } else if sd.Type == webrtc.SDPTypeAnswer { + t.lock.Lock() + if !t.canReuseTransceiver { + t.canReuseTransceiver = true + t.previousTrackDescription = make(map[string]*trackDescription) + } + t.lock.Unlock() + } + + for _, c := range t.pendingRemoteCandidates { + if err := t.pc.AddICECandidate(*c); err != nil { + t.params.Logger.Warnw("failed to add cached ICE candidate", err, "candidate", c) + return errors.Wrap(err, "add ice candidate failed") + } else { + t.params.Logger.Debugw("added cached ICE candidate", "candidate", c) + } + } + t.pendingRemoteCandidates = nil + + return nil +} + +func (t *PCTransport) createAndSendAnswer() error { + numOutstandingAudios, numOutstandingVideos := t.getNumUnmatchedTransceivers() + t.lock.Lock() + t.numOutstandingAudios, t.numOutstandingVideos = numOutstandingAudios, numOutstandingVideos + t.numRequestSentAudios, t.numRequestSentVideos = 0, 0 + t.lock.Unlock() + + answer, err := t.pc.CreateAnswer(nil) + if err != nil { + if errors.Is(err, webrtc.ErrConnectionClosed) { + t.params.Logger.Warnw("trying to create answer on closed peer connection", nil) + return nil + } + + prometheus.RecordServiceOperationError("answer", "create") + return errors.Wrap(err, "create answer failed") + } + + preferTCP := t.preferTCP.Load() + if preferTCP { + t.params.Logger.Debugw("local answer (unfiltered)", "sdp", answer.SDP) + } + + if err = t.pc.SetLocalDescription(answer); err != nil { + prometheus.RecordServiceOperationError("answer", "local_description") + return errors.Wrap(err, "setting local description failed") + } + + // + // Filter after setting local description as pion expects the answer + // to match between CreateAnswer and SetLocalDescription. + // Filtered answer is sent to remote so that remote does not + // see filtered candidates. + // + answer = t.filterCandidates(answer, preferTCP, true) + if preferTCP { + t.params.Logger.Debugw("local answer (filtered)", "sdp", answer.SDP) + } + + localAnswerId := t.localAnswerId.Load() + if localAnswerId != 0 && localAnswerId >= t.remoteOfferId.Load() { + t.params.Logger.Warnw( + "sdp state: duplicate answer", nil, + "localAnswerId", localAnswerId, + "remoteOfferId", t.remoteOfferId.Load(), + ) + } + + answerId := t.remoteOfferId.Load() + + if err := t.params.Handler.OnAnswer(answer, answerId, t.getMidToTrackIDMapping()); err != nil { + prometheus.RecordServiceOperationError("answer", "write_message") + return errors.Wrap(err, "could not send answer") + } + t.localAnswerId.Store(answerId) + prometheus.RecordServiceOperationSuccess("asnwer") + + if err := t.sendUnmatchedMediaRequirement(false); err != nil { + return err + } + + t.lock.Lock() + if !t.canReuseTransceiver { + t.canReuseTransceiver = true + t.previousTrackDescription = make(map[string]*trackDescription) + } + t.lock.Unlock() + + return t.localDescriptionSent() +} + +func (t *PCTransport) handleRemoteOfferReceived(sd *webrtc.SessionDescription, offerId uint32) error { + t.params.Logger.Debugw("processing offer", "offerId", offerId) + remoteOfferId := t.remoteOfferId.Load() + if remoteOfferId != 0 && remoteOfferId != t.localAnswerId.Load() { + t.params.Logger.Warnw( + "sdp state: multiple offers without answer", nil, + "remoteOfferId", remoteOfferId, + "localAnswerId", t.localAnswerId.Load(), + "receivedRemoteOfferId", offerId, + ) + } + t.remoteOfferId.Store(offerId) + + parsed, err := sd.Unmarshal() + if err != nil { + return err + } + + t.lock.Lock() + if !t.firstOfferReceived { + t.firstOfferReceived = true + var dataChannelFound bool + for _, media := range parsed.MediaDescriptions { + if strings.EqualFold(media.MediaName.Media, "application") { + dataChannelFound = true + break + } + } + t.firstOfferNoDataChannel = !dataChannelFound + } + t.lock.Unlock() + + iceCredential, offerRestartICE, err := t.isRemoteOfferRestartICE(parsed) + if err != nil { + return errors.Wrap(err, "check remote offer restart ice failed") + } + + if offerRestartICE && t.pendingRestartIceOffer == nil { + t.clearLocalDescriptionSent() + } + + if offerRestartICE && t.pc.ICEGatheringState() == webrtc.ICEGatheringStateGathering { + t.params.Logger.Debugw("remote offer restart ice while ice gathering") + t.pendingRestartIceOffer = sd + return nil + } + + if offerRestartICE && t.resetShortConnOnICERestart.CompareAndSwap(true, false) { + t.resetShortConn() + } + + if offerRestartICE { + t.outputAndClearICEStats() + } + + if err := t.setRemoteDescription(*sd); err != nil { + return err + } + t.params.Handler.OnSetRemoteDescriptionOffer() + t.processSendersPendingConfig() + + rtxRepairs := nonSimulcastRTXRepairsFromSDP(parsed, t.params.Logger) + if len(rtxRepairs) > 0 { + t.params.Logger.Debugw("rtx pairs found from sdp", "ssrcs", rtxRepairs) + for repair, base := range rtxRepairs { + t.params.Config.BufferFactory.SetRTXPair(repair, base, "") + } + } + + if t.currentOfferIceCredential == "" || offerRestartICE { + t.currentOfferIceCredential = iceCredential + } + + return t.createAndSendAnswer() +} + +func (t *PCTransport) handleRemoteAnswerReceived(sd *webrtc.SessionDescription, answerId uint32) error { + t.params.Logger.Debugw("processing answer", "answerId", answerId) + if answerId != 0 && answerId != t.localOfferId.Load() { + t.params.Logger.Warnw( + "sdp state: answer id mismatch", nil, + "expected", t.localOfferId.Load(), + "got", answerId, + ) + } + t.remoteAnswerId.Store(answerId) + + t.clearSignalStateCheckTimer() + + if err := t.setRemoteDescription(*sd); err != nil { + // Pion will call RTPSender.Send method for each new added Downtrack, and return error if the DownTrack.Bind + // returns error. In case of Downtrack.Bind returns ErrUnsupportedCodec, the signal state will be stable as negotiation is aleady compelted + // before startRTPSenders, and the peerconnection state can be recovered by next negotiation which will be triggered + // by the SubscriptionManager unsubscribe the failure DownTrack. So don't treat this error as negotiation failure. + if !errors.Is(err, webrtc.ErrUnsupportedCodec) { + return err + } + } + + if t.negotiationState == transport.NegotiationStateRetry { + t.setNegotiationState(transport.NegotiationStateNone) + + t.params.Logger.Debugw("re-negotiate after receiving answer") + return t.createAndSendOffer(nil) + } + + t.setNegotiationState(transport.NegotiationStateNone) + return nil +} + +func (t *PCTransport) doICERestart() error { + if t.pc.ConnectionState() == webrtc.PeerConnectionStateClosed { + t.params.Logger.Warnw("trying to restart ICE on closed peer connection", nil) + return nil + } + + // if restart is requested, but negotiation never started + iceGatheringState := t.pc.ICEGatheringState() + if iceGatheringState == webrtc.ICEGatheringStateNew { + t.params.Logger.Debugw("skipping ICE restart on not yet started peer connection") + return nil + } + + // if restart is requested, and we are not ready, then continue afterwards + if iceGatheringState == webrtc.ICEGatheringStateGathering { + t.params.Logger.Debugw("deferring ICE restart to after gathering") + t.restartAfterGathering = true + return nil + } + + if t.resetShortConnOnICERestart.CompareAndSwap(true, false) { + t.resetShortConn() + } + + if t.negotiationState == transport.NegotiationStateNone { + t.outputAndClearICEStats() + return t.createAndSendOffer(&webrtc.OfferOptions{ICERestart: true}) + } + + currentRemoteDescription := t.pc.CurrentRemoteDescription() + if currentRemoteDescription == nil { + // restart without current remote description, send current local description again to try recover + offer := t.pc.LocalDescription() + if offer == nil { + // it should not happen, log just in case + t.params.Logger.Warnw("ice restart without local offer", nil) + return ErrIceRestartWithoutLocalSDP + } else { + t.params.Logger.Infow("deferring ice restart to next offer") + t.setNegotiationState(transport.NegotiationStateRetry) + t.restartAtNextOffer = true + + remoteAnswerId := t.remoteAnswerId.Load() + if remoteAnswerId != 0 && remoteAnswerId != t.localOfferId.Load() { + t.params.Logger.Warnw( + "sdp state: answer not received in ICE restart", nil, + "localOfferId", t.localOfferId.Load(), + "remoteAnswerId", remoteAnswerId, + ) + } + + err := t.params.Handler.OnOffer(*offer, t.localOfferId.Inc(), t.getMidToTrackIDMapping()) + if err != nil { + prometheus.RecordServiceOperationError("offer", "write_message") + } else { + prometheus.RecordServiceOperationSuccess("offer") + } + return err + } + } else { + // recover by re-applying the last answer + t.params.Logger.Infow("recovering from client negotiation state on ICE restart") + if err := t.pc.SetRemoteDescription(*currentRemoteDescription); err != nil { + prometheus.RecordServiceOperationError("offer", "remote_description") + return errors.Wrap(err, "set remote description failed") + } else { + t.setNegotiationState(transport.NegotiationStateNone) + t.outputAndClearICEStats() + return t.createAndSendOffer(&webrtc.OfferOptions{ICERestart: true}) + } + } +} + +func (t *PCTransport) handleICERestart(_ event) error { + return t.doICERestart() +} + +func (t *PCTransport) onNegotiationFailed(warning bool, reason string) { + logFields := []any{ + "reason", reason, + "localCurrent", t.pc.CurrentLocalDescription(), + "localPending", t.pc.PendingLocalDescription(), + "remoteCurrent", t.pc.CurrentRemoteDescription(), + "remotePending", t.pc.PendingRemoteDescription(), + } + if warning { + t.params.Logger.Warnw( + "negotiation failed", + nil, + logFields..., + ) + } else { + t.params.Logger.Infow("negotiation failed", logFields...) + } + t.params.Handler.OnNegotiationFailed() +} + +func (t *PCTransport) outputAndClearICEStats() { + t.lock.Lock() + stats := t.mayFailedICEStats + t.mayFailedICEStats = nil + t.lock.Unlock() + + if len(stats) > 0 { + t.params.Logger.Infow("ICE candidate pair stats", "stats", iceCandidatePairStatsEncoder{stats}) + } +} + +func (t *PCTransport) getMidToTrackIDMapping() map[string]string { + transceivers := t.pc.GetTransceivers() + midToTrackID := make(map[string]string, len(transceivers)) + for _, tr := range transceivers { + if mid := tr.Mid(); mid != "" { + if sender := tr.Sender(); sender != nil { + if track := sender.Track(); track != nil { + midToTrackID[mid] = track.ID() + } + } + } + } + return midToTrackID +} + +// ---------------------- + +type configureSenderParams struct { + transceiver *webrtc.RTPTransceiver + enabledCodecs []*livekit.Codec + rtcpFeedbackConfig RTCPFeedbackConfig + filterOutH264HighProfile bool + enableAudioStereo bool + enableAudioNACK bool +} + +func configureSender(params configureSenderParams) { + configureSenderCodecs( + params.transceiver, + params.enabledCodecs, + params.rtcpFeedbackConfig, + params.filterOutH264HighProfile, + ) + + if params.transceiver.Kind() == webrtc.RTPCodecTypeAudio { + configureSenderAudio(params.transceiver, params.enableAudioStereo, params.enableAudioNACK) + } +} + +// configure subscriber transceiver for audio stereo and nack +// pion doesn't support per transciver codec configuration, so the nack of this session will be disabled +// forever once it is first disabled by a transceiver. +func configureSenderAudio(tr *webrtc.RTPTransceiver, stereo bool, nack bool) { + sender := tr.Sender() + if sender == nil { + return + } + + // enable stereo + codecs := sender.GetParameters().Codecs + configCodecs := make([]webrtc.RTPCodecParameters, 0, len(codecs)) + for _, c := range codecs { + if mime.IsMimeTypeStringOpus(c.MimeType) { + c.SDPFmtpLine = strings.ReplaceAll(c.SDPFmtpLine, ";sprop-stereo=1", "") + if stereo { + c.SDPFmtpLine += ";sprop-stereo=1" + } + if !nack { + for i, fb := range c.RTCPFeedback { + if fb.Type == webrtc.TypeRTCPFBNACK { + c.RTCPFeedback = append(c.RTCPFeedback[:i], c.RTCPFeedback[i+1:]...) + break + } + } + } + } + configCodecs = append(configCodecs, c) + } + + tr.SetCodecPreferences(configCodecs) +} + +// In single peer connection mode, set up enebled codecs for sender. +// The config provides config of direction. +// For publisher peer connection those are publish enabled codecs +// and for subscriber peer connection those are subscribe enabled codecs. +// +// But, in single peer connection mode, if setting up a transceiver where the media is +// flowing in the other direction, the other direction codec config needs to be set. +func configureSenderCodecs( + tr *webrtc.RTPTransceiver, + enabledCodecs []*livekit.Codec, + rtcpFeedbackConfig RTCPFeedbackConfig, + filterOutH264HighProfile bool, +) { + if len(enabledCodecs) == 0 { + return + } + + sender := tr.Sender() + if sender == nil { + return + } + + filteredCodecs := filterCodecs( + sender.GetParameters().Codecs, + enabledCodecs, + rtcpFeedbackConfig, + filterOutH264HighProfile, + ) + tr.SetCodecPreferences(filteredCodecs) +} + +func configureReceiverCodecs( + tr *webrtc.RTPTransceiver, + preferredMimeType string, + compliesWithCodecOrderInSDPAnswer bool, +) { + receiver := tr.Receiver() + if receiver == nil { + return + } + + var preferredCodecs, leftCodecs []webrtc.RTPCodecParameters + for _, c := range receiver.GetParameters().Codecs { + if tr.Kind() == webrtc.RTPCodecTypeAudio { + nackFound := false + for _, fb := range c.RTCPFeedback { + if fb.Type == webrtc.TypeRTCPFBNACK { + nackFound = true + break + } + } + + if !nackFound { + c.RTCPFeedback = append(c.RTCPFeedback, webrtc.RTCPFeedback{Type: webrtc.TypeRTCPFBNACK}) + } + } + + if mime.GetMimeTypeCodec(preferredMimeType) == mime.GetMimeTypeCodec(c.RTPCodecCapability.MimeType) { + preferredCodecs = append(preferredCodecs, c) + } else { + leftCodecs = append(leftCodecs, c) + } + } + if len(preferredCodecs) == 0 { + return + } + + reorderedCodecs := append([]webrtc.RTPCodecParameters{}, preferredCodecs...) + if tr.Kind() == webrtc.RTPCodecTypeVideo { + // if the client don't comply with codec order in SDP answer, only keep preferred codecs to force client to use it + if compliesWithCodecOrderInSDPAnswer { + reorderedCodecs = append(reorderedCodecs, leftCodecs...) + } + } else { + reorderedCodecs = append(reorderedCodecs, leftCodecs...) + } + tr.SetCodecPreferences(reorderedCodecs) +} + +func nonSimulcastRTXRepairsFromSDP(s *sdp.SessionDescription, logger logger.Logger) map[uint32]uint32 { + rtxRepairFlows := map[uint32]uint32{} + for _, media := range s.MediaDescriptions { + // extract rtx repair flows from the media section for non-simulcast stream, + // pion will handle simulcast streams by rid probe, don't need handle it here. + var ridFound bool + rtxPairs := make(map[uint32]uint32) + findRTX: + for _, attr := range media.Attributes { + switch attr.Key { + case "rid": + ridFound = true + break findRTX + case sdp.AttrKeySSRCGroup: + split := strings.Split(attr.Value, " ") + if split[0] == sdp.SemanticTokenFlowIdentification { + // Essentially lines like `a=ssrc-group:FID 2231627014 632943048` are processed by this section + // as this declares that the second SSRC (632943048) is a rtx repair flow (RFC4588) for the first + // (2231627014) as specified in RFC5576 + if len(split) == 3 { + baseSsrc, err := strconv.ParseUint(split[1], 10, 32) + if err != nil { + logger.Warnw("Failed to parse SSRC", err, "ssrc", split[1]) + continue + } + rtxRepairFlow, err := strconv.ParseUint(split[2], 10, 32) + if err != nil { + logger.Warnw("Failed to parse SSRC", err, "ssrc", split[2]) + continue + } + rtxPairs[uint32(rtxRepairFlow)] = uint32(baseSsrc) + } + } + } + } + if !ridFound { + maps.Copy(rtxRepairFlows, rtxPairs) + } + } + + return rtxRepairFlows +} + +// ---------------------- + +type iceCandidatePairStatsEncoder struct { + stats []iceCandidatePairStats +} + +func (e iceCandidatePairStatsEncoder) MarshalLogArray(arr zapcore.ArrayEncoder) error { + for _, s := range e.stats { + if err := arr.AppendObject(s); err != nil { + return err + } + } + return nil +} + +type iceCandidatePairStats struct { + webrtc.ICECandidatePairStats + local, remote webrtc.ICECandidateStats +} + +func (r iceCandidatePairStats) MarshalLogObject(e zapcore.ObjectEncoder) error { + candidateToString := func(c webrtc.ICECandidateStats) string { + return fmt.Sprintf("%s:%d %s type(%s/%s), priority(%d)", c.IP, c.Port, c.Protocol, c.CandidateType, c.RelayProtocol, c.Priority) + } + e.AddString("state", string(r.State)) + e.AddBool("nominated", r.Nominated) + e.AddString("local", candidateToString(r.local)) + e.AddString("remote", candidateToString(r.remote)) + e.AddUint64("requestsSent", r.RequestsSent) + e.AddUint64("responsesReceived", r.ResponsesReceived) + e.AddUint64("requestsReceived", r.RequestsReceived) + e.AddUint64("responsesSent", r.ResponsesSent) + e.AddTime("firstRequestSentAt", r.FirstRequestTimestamp.Time()) + e.AddTime("lastRequestSentAt", r.LastRequestTimestamp.Time()) + e.AddTime("firstResponseReceivedAt", r.FirstResponseTimestamp.Time()) + e.AddTime("lastResponseReceivedAt", r.LastResponseTimestamp.Time()) + e.AddTime("firstRequestReceivedAt", r.FirstRequestReceivedTimestamp.Time()) + e.AddTime("lastRequestReceivedAt", r.LastRequestReceivedTimestamp.Time()) + + return nil +} diff --git a/livekit/pkg/rtc/transport/handler.go b/livekit/pkg/rtc/transport/handler.go new file mode 100644 index 0000000..5b6ea75 --- /dev/null +++ b/livekit/pkg/rtc/transport/handler.go @@ -0,0 +1,82 @@ +// Copyright 2024 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package transport + +import ( + "errors" + + "github.com/pion/webrtc/v4" + + "github.com/livekit/livekit-server/pkg/rtc/types" + "github.com/livekit/livekit-server/pkg/sfu/streamallocator" + "github.com/livekit/protocol/livekit" +) + +//go:generate go run github.com/maxbrunsfeld/counterfeiter/v6 -generate + +var ( + ErrNoICECandidateHandler = errors.New("no ICE candidate handler") + ErrNoOfferHandler = errors.New("no offer handler") + ErrNoAnswerHandler = errors.New("no answer handler") +) + +//counterfeiter:generate . Handler +type Handler interface { + OnICECandidate(c *webrtc.ICECandidate, target livekit.SignalTarget) error + OnInitialConnected() + OnFullyEstablished() + OnFailed(isShortLived bool, iceConnectionInfo *types.ICEConnectionInfo) + OnTrack(track *webrtc.TrackRemote, rtpReceiver *webrtc.RTPReceiver) + OnDataMessage(kind livekit.DataPacket_Kind, data []byte) + OnDataMessageUnlabeled(data []byte) + OnDataTrackMessage(data []byte, arrivalTime int64) + OnDataSendError(err error) + OnOffer(sd webrtc.SessionDescription, offerId uint32, midToTrackID map[string]string) error + OnSetRemoteDescriptionOffer() + OnAnswer(sd webrtc.SessionDescription, answerId uint32, midToTrackID map[string]string) error + OnNegotiationStateChanged(state NegotiationState) + OnNegotiationFailed() + OnStreamStateChange(update *streamallocator.StreamStateUpdate) error + OnUnmatchedMedia(numAudios uint32, numVideos uint32) error +} + +type UnimplementedHandler struct{} + +func (h UnimplementedHandler) OnICECandidate(c *webrtc.ICECandidate, target livekit.SignalTarget) error { + return ErrNoICECandidateHandler +} +func (h UnimplementedHandler) OnInitialConnected() {} +func (h UnimplementedHandler) OnFullyEstablished() {} +func (h UnimplementedHandler) OnFailed(isShortLived bool) {} +func (h UnimplementedHandler) OnTrack(track *webrtc.TrackRemote, rtpReceiver *webrtc.RTPReceiver) {} +func (h UnimplementedHandler) OnDataMessage(kind livekit.DataPacket_Kind, data []byte) {} +func (h UnimplementedHandler) OnDataMessageUnlabeled(data []byte) {} +func (h UnimplementedHandler) OnDataTrackMessage(data []byte, arrivalTime int64) {} +func (h UnimplementedHandler) OnDataSendError(err error) {} +func (h UnimplementedHandler) OnOffer(sd webrtc.SessionDescription, offerId uint32, midToTrackID map[string]string) error { + return ErrNoOfferHandler +} +func (h UnimplementedHandler) OnSetRemoteDescriptionOffer() {} +func (h UnimplementedHandler) OnAnswer(sd webrtc.SessionDescription, answerId uint32, midToTrackID map[string]string) error { + return ErrNoAnswerHandler +} +func (h UnimplementedHandler) OnNegotiationStateChanged(state NegotiationState) {} +func (h UnimplementedHandler) OnNegotiationFailed() {} +func (h UnimplementedHandler) OnStreamStateChange(update *streamallocator.StreamStateUpdate) error { + return nil +} +func (h UnimplementedHandler) OnUnmatchedMedia(numAudios uint32, numVideos uint32) error { + return nil +} diff --git a/livekit/pkg/rtc/transport/negotiationstate.go b/livekit/pkg/rtc/transport/negotiationstate.go new file mode 100644 index 0000000..26d6d1a --- /dev/null +++ b/livekit/pkg/rtc/transport/negotiationstate.go @@ -0,0 +1,40 @@ +// Copyright 2024 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package transport + +import "fmt" + +type NegotiationState int + +const ( + NegotiationStateNone NegotiationState = iota + // waiting for remote description + NegotiationStateRemote + // need to Negotiate again + NegotiationStateRetry +) + +func (n NegotiationState) String() string { + switch n { + case NegotiationStateNone: + return "NONE" + case NegotiationStateRemote: + return "WAITING_FOR_REMOTE" + case NegotiationStateRetry: + return "RETRY" + default: + return fmt.Sprintf("%d", int(n)) + } +} diff --git a/livekit/pkg/rtc/transport/transportfakes/fake_handler.go b/livekit/pkg/rtc/transport/transportfakes/fake_handler.go new file mode 100644 index 0000000..5c09af0 --- /dev/null +++ b/livekit/pkg/rtc/transport/transportfakes/fake_handler.go @@ -0,0 +1,807 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package transportfakes + +import ( + "sync" + + "github.com/livekit/livekit-server/pkg/rtc/transport" + "github.com/livekit/livekit-server/pkg/rtc/types" + "github.com/livekit/livekit-server/pkg/sfu/streamallocator" + "github.com/livekit/protocol/livekit" + webrtc "github.com/pion/webrtc/v4" +) + +type FakeHandler struct { + OnAnswerStub func(webrtc.SessionDescription, uint32, map[string]string) error + onAnswerMutex sync.RWMutex + onAnswerArgsForCall []struct { + arg1 webrtc.SessionDescription + arg2 uint32 + arg3 map[string]string + } + onAnswerReturns struct { + result1 error + } + onAnswerReturnsOnCall map[int]struct { + result1 error + } + OnDataMessageStub func(livekit.DataPacket_Kind, []byte) + onDataMessageMutex sync.RWMutex + onDataMessageArgsForCall []struct { + arg1 livekit.DataPacket_Kind + arg2 []byte + } + OnDataMessageUnlabeledStub func([]byte) + onDataMessageUnlabeledMutex sync.RWMutex + onDataMessageUnlabeledArgsForCall []struct { + arg1 []byte + } + OnDataSendErrorStub func(error) + onDataSendErrorMutex sync.RWMutex + onDataSendErrorArgsForCall []struct { + arg1 error + } + OnDataTrackMessageStub func([]byte, int64) + onDataTrackMessageMutex sync.RWMutex + onDataTrackMessageArgsForCall []struct { + arg1 []byte + arg2 int64 + } + OnFailedStub func(bool, *types.ICEConnectionInfo) + onFailedMutex sync.RWMutex + onFailedArgsForCall []struct { + arg1 bool + arg2 *types.ICEConnectionInfo + } + OnFullyEstablishedStub func() + onFullyEstablishedMutex sync.RWMutex + onFullyEstablishedArgsForCall []struct { + } + OnICECandidateStub func(*webrtc.ICECandidate, livekit.SignalTarget) error + onICECandidateMutex sync.RWMutex + onICECandidateArgsForCall []struct { + arg1 *webrtc.ICECandidate + arg2 livekit.SignalTarget + } + onICECandidateReturns struct { + result1 error + } + onICECandidateReturnsOnCall map[int]struct { + result1 error + } + OnInitialConnectedStub func() + onInitialConnectedMutex sync.RWMutex + onInitialConnectedArgsForCall []struct { + } + OnNegotiationFailedStub func() + onNegotiationFailedMutex sync.RWMutex + onNegotiationFailedArgsForCall []struct { + } + OnNegotiationStateChangedStub func(transport.NegotiationState) + onNegotiationStateChangedMutex sync.RWMutex + onNegotiationStateChangedArgsForCall []struct { + arg1 transport.NegotiationState + } + OnOfferStub func(webrtc.SessionDescription, uint32, map[string]string) error + onOfferMutex sync.RWMutex + onOfferArgsForCall []struct { + arg1 webrtc.SessionDescription + arg2 uint32 + arg3 map[string]string + } + onOfferReturns struct { + result1 error + } + onOfferReturnsOnCall map[int]struct { + result1 error + } + OnSetRemoteDescriptionOfferStub func() + onSetRemoteDescriptionOfferMutex sync.RWMutex + onSetRemoteDescriptionOfferArgsForCall []struct { + } + OnStreamStateChangeStub func(*streamallocator.StreamStateUpdate) error + onStreamStateChangeMutex sync.RWMutex + onStreamStateChangeArgsForCall []struct { + arg1 *streamallocator.StreamStateUpdate + } + onStreamStateChangeReturns struct { + result1 error + } + onStreamStateChangeReturnsOnCall map[int]struct { + result1 error + } + OnTrackStub func(*webrtc.TrackRemote, *webrtc.RTPReceiver) + onTrackMutex sync.RWMutex + onTrackArgsForCall []struct { + arg1 *webrtc.TrackRemote + arg2 *webrtc.RTPReceiver + } + OnUnmatchedMediaStub func(uint32, uint32) error + onUnmatchedMediaMutex sync.RWMutex + onUnmatchedMediaArgsForCall []struct { + arg1 uint32 + arg2 uint32 + } + onUnmatchedMediaReturns struct { + result1 error + } + onUnmatchedMediaReturnsOnCall map[int]struct { + result1 error + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeHandler) OnAnswer(arg1 webrtc.SessionDescription, arg2 uint32, arg3 map[string]string) error { + fake.onAnswerMutex.Lock() + ret, specificReturn := fake.onAnswerReturnsOnCall[len(fake.onAnswerArgsForCall)] + fake.onAnswerArgsForCall = append(fake.onAnswerArgsForCall, struct { + arg1 webrtc.SessionDescription + arg2 uint32 + arg3 map[string]string + }{arg1, arg2, arg3}) + stub := fake.OnAnswerStub + fakeReturns := fake.onAnswerReturns + fake.recordInvocation("OnAnswer", []interface{}{arg1, arg2, arg3}) + fake.onAnswerMutex.Unlock() + if stub != nil { + return stub(arg1, arg2, arg3) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeHandler) OnAnswerCallCount() int { + fake.onAnswerMutex.RLock() + defer fake.onAnswerMutex.RUnlock() + return len(fake.onAnswerArgsForCall) +} + +func (fake *FakeHandler) OnAnswerCalls(stub func(webrtc.SessionDescription, uint32, map[string]string) error) { + fake.onAnswerMutex.Lock() + defer fake.onAnswerMutex.Unlock() + fake.OnAnswerStub = stub +} + +func (fake *FakeHandler) OnAnswerArgsForCall(i int) (webrtc.SessionDescription, uint32, map[string]string) { + fake.onAnswerMutex.RLock() + defer fake.onAnswerMutex.RUnlock() + argsForCall := fake.onAnswerArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeHandler) OnAnswerReturns(result1 error) { + fake.onAnswerMutex.Lock() + defer fake.onAnswerMutex.Unlock() + fake.OnAnswerStub = nil + fake.onAnswerReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeHandler) OnAnswerReturnsOnCall(i int, result1 error) { + fake.onAnswerMutex.Lock() + defer fake.onAnswerMutex.Unlock() + fake.OnAnswerStub = nil + if fake.onAnswerReturnsOnCall == nil { + fake.onAnswerReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.onAnswerReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeHandler) OnDataMessage(arg1 livekit.DataPacket_Kind, arg2 []byte) { + var arg2Copy []byte + if arg2 != nil { + arg2Copy = make([]byte, len(arg2)) + copy(arg2Copy, arg2) + } + fake.onDataMessageMutex.Lock() + fake.onDataMessageArgsForCall = append(fake.onDataMessageArgsForCall, struct { + arg1 livekit.DataPacket_Kind + arg2 []byte + }{arg1, arg2Copy}) + stub := fake.OnDataMessageStub + fake.recordInvocation("OnDataMessage", []interface{}{arg1, arg2Copy}) + fake.onDataMessageMutex.Unlock() + if stub != nil { + fake.OnDataMessageStub(arg1, arg2) + } +} + +func (fake *FakeHandler) OnDataMessageCallCount() int { + fake.onDataMessageMutex.RLock() + defer fake.onDataMessageMutex.RUnlock() + return len(fake.onDataMessageArgsForCall) +} + +func (fake *FakeHandler) OnDataMessageCalls(stub func(livekit.DataPacket_Kind, []byte)) { + fake.onDataMessageMutex.Lock() + defer fake.onDataMessageMutex.Unlock() + fake.OnDataMessageStub = stub +} + +func (fake *FakeHandler) OnDataMessageArgsForCall(i int) (livekit.DataPacket_Kind, []byte) { + fake.onDataMessageMutex.RLock() + defer fake.onDataMessageMutex.RUnlock() + argsForCall := fake.onDataMessageArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeHandler) OnDataMessageUnlabeled(arg1 []byte) { + var arg1Copy []byte + if arg1 != nil { + arg1Copy = make([]byte, len(arg1)) + copy(arg1Copy, arg1) + } + fake.onDataMessageUnlabeledMutex.Lock() + fake.onDataMessageUnlabeledArgsForCall = append(fake.onDataMessageUnlabeledArgsForCall, struct { + arg1 []byte + }{arg1Copy}) + stub := fake.OnDataMessageUnlabeledStub + fake.recordInvocation("OnDataMessageUnlabeled", []interface{}{arg1Copy}) + fake.onDataMessageUnlabeledMutex.Unlock() + if stub != nil { + fake.OnDataMessageUnlabeledStub(arg1) + } +} + +func (fake *FakeHandler) OnDataMessageUnlabeledCallCount() int { + fake.onDataMessageUnlabeledMutex.RLock() + defer fake.onDataMessageUnlabeledMutex.RUnlock() + return len(fake.onDataMessageUnlabeledArgsForCall) +} + +func (fake *FakeHandler) OnDataMessageUnlabeledCalls(stub func([]byte)) { + fake.onDataMessageUnlabeledMutex.Lock() + defer fake.onDataMessageUnlabeledMutex.Unlock() + fake.OnDataMessageUnlabeledStub = stub +} + +func (fake *FakeHandler) OnDataMessageUnlabeledArgsForCall(i int) []byte { + fake.onDataMessageUnlabeledMutex.RLock() + defer fake.onDataMessageUnlabeledMutex.RUnlock() + argsForCall := fake.onDataMessageUnlabeledArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeHandler) OnDataSendError(arg1 error) { + fake.onDataSendErrorMutex.Lock() + fake.onDataSendErrorArgsForCall = append(fake.onDataSendErrorArgsForCall, struct { + arg1 error + }{arg1}) + stub := fake.OnDataSendErrorStub + fake.recordInvocation("OnDataSendError", []interface{}{arg1}) + fake.onDataSendErrorMutex.Unlock() + if stub != nil { + fake.OnDataSendErrorStub(arg1) + } +} + +func (fake *FakeHandler) OnDataSendErrorCallCount() int { + fake.onDataSendErrorMutex.RLock() + defer fake.onDataSendErrorMutex.RUnlock() + return len(fake.onDataSendErrorArgsForCall) +} + +func (fake *FakeHandler) OnDataSendErrorCalls(stub func(error)) { + fake.onDataSendErrorMutex.Lock() + defer fake.onDataSendErrorMutex.Unlock() + fake.OnDataSendErrorStub = stub +} + +func (fake *FakeHandler) OnDataSendErrorArgsForCall(i int) error { + fake.onDataSendErrorMutex.RLock() + defer fake.onDataSendErrorMutex.RUnlock() + argsForCall := fake.onDataSendErrorArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeHandler) OnDataTrackMessage(arg1 []byte, arg2 int64) { + var arg1Copy []byte + if arg1 != nil { + arg1Copy = make([]byte, len(arg1)) + copy(arg1Copy, arg1) + } + fake.onDataTrackMessageMutex.Lock() + fake.onDataTrackMessageArgsForCall = append(fake.onDataTrackMessageArgsForCall, struct { + arg1 []byte + arg2 int64 + }{arg1Copy, arg2}) + stub := fake.OnDataTrackMessageStub + fake.recordInvocation("OnDataTrackMessage", []interface{}{arg1Copy, arg2}) + fake.onDataTrackMessageMutex.Unlock() + if stub != nil { + fake.OnDataTrackMessageStub(arg1, arg2) + } +} + +func (fake *FakeHandler) OnDataTrackMessageCallCount() int { + fake.onDataTrackMessageMutex.RLock() + defer fake.onDataTrackMessageMutex.RUnlock() + return len(fake.onDataTrackMessageArgsForCall) +} + +func (fake *FakeHandler) OnDataTrackMessageCalls(stub func([]byte, int64)) { + fake.onDataTrackMessageMutex.Lock() + defer fake.onDataTrackMessageMutex.Unlock() + fake.OnDataTrackMessageStub = stub +} + +func (fake *FakeHandler) OnDataTrackMessageArgsForCall(i int) ([]byte, int64) { + fake.onDataTrackMessageMutex.RLock() + defer fake.onDataTrackMessageMutex.RUnlock() + argsForCall := fake.onDataTrackMessageArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeHandler) OnFailed(arg1 bool, arg2 *types.ICEConnectionInfo) { + fake.onFailedMutex.Lock() + fake.onFailedArgsForCall = append(fake.onFailedArgsForCall, struct { + arg1 bool + arg2 *types.ICEConnectionInfo + }{arg1, arg2}) + stub := fake.OnFailedStub + fake.recordInvocation("OnFailed", []interface{}{arg1, arg2}) + fake.onFailedMutex.Unlock() + if stub != nil { + fake.OnFailedStub(arg1, arg2) + } +} + +func (fake *FakeHandler) OnFailedCallCount() int { + fake.onFailedMutex.RLock() + defer fake.onFailedMutex.RUnlock() + return len(fake.onFailedArgsForCall) +} + +func (fake *FakeHandler) OnFailedCalls(stub func(bool, *types.ICEConnectionInfo)) { + fake.onFailedMutex.Lock() + defer fake.onFailedMutex.Unlock() + fake.OnFailedStub = stub +} + +func (fake *FakeHandler) OnFailedArgsForCall(i int) (bool, *types.ICEConnectionInfo) { + fake.onFailedMutex.RLock() + defer fake.onFailedMutex.RUnlock() + argsForCall := fake.onFailedArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeHandler) OnFullyEstablished() { + fake.onFullyEstablishedMutex.Lock() + fake.onFullyEstablishedArgsForCall = append(fake.onFullyEstablishedArgsForCall, struct { + }{}) + stub := fake.OnFullyEstablishedStub + fake.recordInvocation("OnFullyEstablished", []interface{}{}) + fake.onFullyEstablishedMutex.Unlock() + if stub != nil { + fake.OnFullyEstablishedStub() + } +} + +func (fake *FakeHandler) OnFullyEstablishedCallCount() int { + fake.onFullyEstablishedMutex.RLock() + defer fake.onFullyEstablishedMutex.RUnlock() + return len(fake.onFullyEstablishedArgsForCall) +} + +func (fake *FakeHandler) OnFullyEstablishedCalls(stub func()) { + fake.onFullyEstablishedMutex.Lock() + defer fake.onFullyEstablishedMutex.Unlock() + fake.OnFullyEstablishedStub = stub +} + +func (fake *FakeHandler) OnICECandidate(arg1 *webrtc.ICECandidate, arg2 livekit.SignalTarget) error { + fake.onICECandidateMutex.Lock() + ret, specificReturn := fake.onICECandidateReturnsOnCall[len(fake.onICECandidateArgsForCall)] + fake.onICECandidateArgsForCall = append(fake.onICECandidateArgsForCall, struct { + arg1 *webrtc.ICECandidate + arg2 livekit.SignalTarget + }{arg1, arg2}) + stub := fake.OnICECandidateStub + fakeReturns := fake.onICECandidateReturns + fake.recordInvocation("OnICECandidate", []interface{}{arg1, arg2}) + fake.onICECandidateMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeHandler) OnICECandidateCallCount() int { + fake.onICECandidateMutex.RLock() + defer fake.onICECandidateMutex.RUnlock() + return len(fake.onICECandidateArgsForCall) +} + +func (fake *FakeHandler) OnICECandidateCalls(stub func(*webrtc.ICECandidate, livekit.SignalTarget) error) { + fake.onICECandidateMutex.Lock() + defer fake.onICECandidateMutex.Unlock() + fake.OnICECandidateStub = stub +} + +func (fake *FakeHandler) OnICECandidateArgsForCall(i int) (*webrtc.ICECandidate, livekit.SignalTarget) { + fake.onICECandidateMutex.RLock() + defer fake.onICECandidateMutex.RUnlock() + argsForCall := fake.onICECandidateArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeHandler) OnICECandidateReturns(result1 error) { + fake.onICECandidateMutex.Lock() + defer fake.onICECandidateMutex.Unlock() + fake.OnICECandidateStub = nil + fake.onICECandidateReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeHandler) OnICECandidateReturnsOnCall(i int, result1 error) { + fake.onICECandidateMutex.Lock() + defer fake.onICECandidateMutex.Unlock() + fake.OnICECandidateStub = nil + if fake.onICECandidateReturnsOnCall == nil { + fake.onICECandidateReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.onICECandidateReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeHandler) OnInitialConnected() { + fake.onInitialConnectedMutex.Lock() + fake.onInitialConnectedArgsForCall = append(fake.onInitialConnectedArgsForCall, struct { + }{}) + stub := fake.OnInitialConnectedStub + fake.recordInvocation("OnInitialConnected", []interface{}{}) + fake.onInitialConnectedMutex.Unlock() + if stub != nil { + fake.OnInitialConnectedStub() + } +} + +func (fake *FakeHandler) OnInitialConnectedCallCount() int { + fake.onInitialConnectedMutex.RLock() + defer fake.onInitialConnectedMutex.RUnlock() + return len(fake.onInitialConnectedArgsForCall) +} + +func (fake *FakeHandler) OnInitialConnectedCalls(stub func()) { + fake.onInitialConnectedMutex.Lock() + defer fake.onInitialConnectedMutex.Unlock() + fake.OnInitialConnectedStub = stub +} + +func (fake *FakeHandler) OnNegotiationFailed() { + fake.onNegotiationFailedMutex.Lock() + fake.onNegotiationFailedArgsForCall = append(fake.onNegotiationFailedArgsForCall, struct { + }{}) + stub := fake.OnNegotiationFailedStub + fake.recordInvocation("OnNegotiationFailed", []interface{}{}) + fake.onNegotiationFailedMutex.Unlock() + if stub != nil { + fake.OnNegotiationFailedStub() + } +} + +func (fake *FakeHandler) OnNegotiationFailedCallCount() int { + fake.onNegotiationFailedMutex.RLock() + defer fake.onNegotiationFailedMutex.RUnlock() + return len(fake.onNegotiationFailedArgsForCall) +} + +func (fake *FakeHandler) OnNegotiationFailedCalls(stub func()) { + fake.onNegotiationFailedMutex.Lock() + defer fake.onNegotiationFailedMutex.Unlock() + fake.OnNegotiationFailedStub = stub +} + +func (fake *FakeHandler) OnNegotiationStateChanged(arg1 transport.NegotiationState) { + fake.onNegotiationStateChangedMutex.Lock() + fake.onNegotiationStateChangedArgsForCall = append(fake.onNegotiationStateChangedArgsForCall, struct { + arg1 transport.NegotiationState + }{arg1}) + stub := fake.OnNegotiationStateChangedStub + fake.recordInvocation("OnNegotiationStateChanged", []interface{}{arg1}) + fake.onNegotiationStateChangedMutex.Unlock() + if stub != nil { + fake.OnNegotiationStateChangedStub(arg1) + } +} + +func (fake *FakeHandler) OnNegotiationStateChangedCallCount() int { + fake.onNegotiationStateChangedMutex.RLock() + defer fake.onNegotiationStateChangedMutex.RUnlock() + return len(fake.onNegotiationStateChangedArgsForCall) +} + +func (fake *FakeHandler) OnNegotiationStateChangedCalls(stub func(transport.NegotiationState)) { + fake.onNegotiationStateChangedMutex.Lock() + defer fake.onNegotiationStateChangedMutex.Unlock() + fake.OnNegotiationStateChangedStub = stub +} + +func (fake *FakeHandler) OnNegotiationStateChangedArgsForCall(i int) transport.NegotiationState { + fake.onNegotiationStateChangedMutex.RLock() + defer fake.onNegotiationStateChangedMutex.RUnlock() + argsForCall := fake.onNegotiationStateChangedArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeHandler) OnOffer(arg1 webrtc.SessionDescription, arg2 uint32, arg3 map[string]string) error { + fake.onOfferMutex.Lock() + ret, specificReturn := fake.onOfferReturnsOnCall[len(fake.onOfferArgsForCall)] + fake.onOfferArgsForCall = append(fake.onOfferArgsForCall, struct { + arg1 webrtc.SessionDescription + arg2 uint32 + arg3 map[string]string + }{arg1, arg2, arg3}) + stub := fake.OnOfferStub + fakeReturns := fake.onOfferReturns + fake.recordInvocation("OnOffer", []interface{}{arg1, arg2, arg3}) + fake.onOfferMutex.Unlock() + if stub != nil { + return stub(arg1, arg2, arg3) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeHandler) OnOfferCallCount() int { + fake.onOfferMutex.RLock() + defer fake.onOfferMutex.RUnlock() + return len(fake.onOfferArgsForCall) +} + +func (fake *FakeHandler) OnOfferCalls(stub func(webrtc.SessionDescription, uint32, map[string]string) error) { + fake.onOfferMutex.Lock() + defer fake.onOfferMutex.Unlock() + fake.OnOfferStub = stub +} + +func (fake *FakeHandler) OnOfferArgsForCall(i int) (webrtc.SessionDescription, uint32, map[string]string) { + fake.onOfferMutex.RLock() + defer fake.onOfferMutex.RUnlock() + argsForCall := fake.onOfferArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeHandler) OnOfferReturns(result1 error) { + fake.onOfferMutex.Lock() + defer fake.onOfferMutex.Unlock() + fake.OnOfferStub = nil + fake.onOfferReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeHandler) OnOfferReturnsOnCall(i int, result1 error) { + fake.onOfferMutex.Lock() + defer fake.onOfferMutex.Unlock() + fake.OnOfferStub = nil + if fake.onOfferReturnsOnCall == nil { + fake.onOfferReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.onOfferReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeHandler) OnSetRemoteDescriptionOffer() { + fake.onSetRemoteDescriptionOfferMutex.Lock() + fake.onSetRemoteDescriptionOfferArgsForCall = append(fake.onSetRemoteDescriptionOfferArgsForCall, struct { + }{}) + stub := fake.OnSetRemoteDescriptionOfferStub + fake.recordInvocation("OnSetRemoteDescriptionOffer", []interface{}{}) + fake.onSetRemoteDescriptionOfferMutex.Unlock() + if stub != nil { + fake.OnSetRemoteDescriptionOfferStub() + } +} + +func (fake *FakeHandler) OnSetRemoteDescriptionOfferCallCount() int { + fake.onSetRemoteDescriptionOfferMutex.RLock() + defer fake.onSetRemoteDescriptionOfferMutex.RUnlock() + return len(fake.onSetRemoteDescriptionOfferArgsForCall) +} + +func (fake *FakeHandler) OnSetRemoteDescriptionOfferCalls(stub func()) { + fake.onSetRemoteDescriptionOfferMutex.Lock() + defer fake.onSetRemoteDescriptionOfferMutex.Unlock() + fake.OnSetRemoteDescriptionOfferStub = stub +} + +func (fake *FakeHandler) OnStreamStateChange(arg1 *streamallocator.StreamStateUpdate) error { + fake.onStreamStateChangeMutex.Lock() + ret, specificReturn := fake.onStreamStateChangeReturnsOnCall[len(fake.onStreamStateChangeArgsForCall)] + fake.onStreamStateChangeArgsForCall = append(fake.onStreamStateChangeArgsForCall, struct { + arg1 *streamallocator.StreamStateUpdate + }{arg1}) + stub := fake.OnStreamStateChangeStub + fakeReturns := fake.onStreamStateChangeReturns + fake.recordInvocation("OnStreamStateChange", []interface{}{arg1}) + fake.onStreamStateChangeMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeHandler) OnStreamStateChangeCallCount() int { + fake.onStreamStateChangeMutex.RLock() + defer fake.onStreamStateChangeMutex.RUnlock() + return len(fake.onStreamStateChangeArgsForCall) +} + +func (fake *FakeHandler) OnStreamStateChangeCalls(stub func(*streamallocator.StreamStateUpdate) error) { + fake.onStreamStateChangeMutex.Lock() + defer fake.onStreamStateChangeMutex.Unlock() + fake.OnStreamStateChangeStub = stub +} + +func (fake *FakeHandler) OnStreamStateChangeArgsForCall(i int) *streamallocator.StreamStateUpdate { + fake.onStreamStateChangeMutex.RLock() + defer fake.onStreamStateChangeMutex.RUnlock() + argsForCall := fake.onStreamStateChangeArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeHandler) OnStreamStateChangeReturns(result1 error) { + fake.onStreamStateChangeMutex.Lock() + defer fake.onStreamStateChangeMutex.Unlock() + fake.OnStreamStateChangeStub = nil + fake.onStreamStateChangeReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeHandler) OnStreamStateChangeReturnsOnCall(i int, result1 error) { + fake.onStreamStateChangeMutex.Lock() + defer fake.onStreamStateChangeMutex.Unlock() + fake.OnStreamStateChangeStub = nil + if fake.onStreamStateChangeReturnsOnCall == nil { + fake.onStreamStateChangeReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.onStreamStateChangeReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeHandler) OnTrack(arg1 *webrtc.TrackRemote, arg2 *webrtc.RTPReceiver) { + fake.onTrackMutex.Lock() + fake.onTrackArgsForCall = append(fake.onTrackArgsForCall, struct { + arg1 *webrtc.TrackRemote + arg2 *webrtc.RTPReceiver + }{arg1, arg2}) + stub := fake.OnTrackStub + fake.recordInvocation("OnTrack", []interface{}{arg1, arg2}) + fake.onTrackMutex.Unlock() + if stub != nil { + fake.OnTrackStub(arg1, arg2) + } +} + +func (fake *FakeHandler) OnTrackCallCount() int { + fake.onTrackMutex.RLock() + defer fake.onTrackMutex.RUnlock() + return len(fake.onTrackArgsForCall) +} + +func (fake *FakeHandler) OnTrackCalls(stub func(*webrtc.TrackRemote, *webrtc.RTPReceiver)) { + fake.onTrackMutex.Lock() + defer fake.onTrackMutex.Unlock() + fake.OnTrackStub = stub +} + +func (fake *FakeHandler) OnTrackArgsForCall(i int) (*webrtc.TrackRemote, *webrtc.RTPReceiver) { + fake.onTrackMutex.RLock() + defer fake.onTrackMutex.RUnlock() + argsForCall := fake.onTrackArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeHandler) OnUnmatchedMedia(arg1 uint32, arg2 uint32) error { + fake.onUnmatchedMediaMutex.Lock() + ret, specificReturn := fake.onUnmatchedMediaReturnsOnCall[len(fake.onUnmatchedMediaArgsForCall)] + fake.onUnmatchedMediaArgsForCall = append(fake.onUnmatchedMediaArgsForCall, struct { + arg1 uint32 + arg2 uint32 + }{arg1, arg2}) + stub := fake.OnUnmatchedMediaStub + fakeReturns := fake.onUnmatchedMediaReturns + fake.recordInvocation("OnUnmatchedMedia", []interface{}{arg1, arg2}) + fake.onUnmatchedMediaMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeHandler) OnUnmatchedMediaCallCount() int { + fake.onUnmatchedMediaMutex.RLock() + defer fake.onUnmatchedMediaMutex.RUnlock() + return len(fake.onUnmatchedMediaArgsForCall) +} + +func (fake *FakeHandler) OnUnmatchedMediaCalls(stub func(uint32, uint32) error) { + fake.onUnmatchedMediaMutex.Lock() + defer fake.onUnmatchedMediaMutex.Unlock() + fake.OnUnmatchedMediaStub = stub +} + +func (fake *FakeHandler) OnUnmatchedMediaArgsForCall(i int) (uint32, uint32) { + fake.onUnmatchedMediaMutex.RLock() + defer fake.onUnmatchedMediaMutex.RUnlock() + argsForCall := fake.onUnmatchedMediaArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeHandler) OnUnmatchedMediaReturns(result1 error) { + fake.onUnmatchedMediaMutex.Lock() + defer fake.onUnmatchedMediaMutex.Unlock() + fake.OnUnmatchedMediaStub = nil + fake.onUnmatchedMediaReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeHandler) OnUnmatchedMediaReturnsOnCall(i int, result1 error) { + fake.onUnmatchedMediaMutex.Lock() + defer fake.onUnmatchedMediaMutex.Unlock() + fake.OnUnmatchedMediaStub = nil + if fake.onUnmatchedMediaReturnsOnCall == nil { + fake.onUnmatchedMediaReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.onUnmatchedMediaReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeHandler) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeHandler) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ transport.Handler = new(FakeHandler) diff --git a/livekit/pkg/rtc/transport_test.go b/livekit/pkg/rtc/transport_test.go new file mode 100644 index 0000000..d27ec0a --- /dev/null +++ b/livekit/pkg/rtc/transport_test.go @@ -0,0 +1,637 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rtc + +import ( + "fmt" + "strings" + "sync" + "testing" + "time" + + "github.com/pion/sdp/v3" + "github.com/pion/webrtc/v4" + "github.com/stretchr/testify/require" + "go.uber.org/atomic" + + "github.com/livekit/livekit-server/pkg/rtc/transport" + "github.com/livekit/livekit-server/pkg/rtc/transport/transportfakes" + "github.com/livekit/livekit-server/pkg/sfu/mime" + "github.com/livekit/livekit-server/pkg/testutils" + "github.com/livekit/protocol/livekit" +) + +func TestMissingAnswerDuringICERestart(t *testing.T) { + params := TransportParams{ + Config: &WebRTCConfig{}, + IsOfferer: true, + } + + paramsA := params + handlerA := &transportfakes.FakeHandler{} + paramsA.Handler = handlerA + transportA, err := NewPCTransport(paramsA) + require.NoError(t, err) + _, err = transportA.pc.CreateDataChannel(ReliableDataChannel, nil) + require.NoError(t, err) + + paramsB := params + handlerB := &transportfakes.FakeHandler{} + paramsB.Handler = handlerB + paramsB.IsOfferer = false + transportB, err := NewPCTransport(paramsB) + require.NoError(t, err) + + // exchange ICE + handleICEExchange(t, transportA, transportB, handlerA, handlerB) + + connectTransports(t, transportA, transportB, handlerA, handlerB, false, 1, 1) + require.Equal(t, webrtc.ICEConnectionStateConnected, transportA.pc.ICEConnectionState()) + require.Equal(t, webrtc.ICEConnectionStateConnected, transportB.pc.ICEConnectionState()) + + var negotiationState atomic.Value + transportA.OnNegotiationStateChanged(func(state transport.NegotiationState) { + negotiationState.Store(state) + }) + + // offer again, but missed + var offerReceived atomic.Bool + handlerA.OnOfferCalls(func(sd webrtc.SessionDescription, _offerId uint32, _midToTrackID map[string]string) error { + require.Equal(t, webrtc.SignalingStateHaveLocalOffer, transportA.pc.SignalingState()) + require.Equal(t, transport.NegotiationStateRemote, negotiationState.Load().(transport.NegotiationState)) + offerReceived.Store(true) + return nil + }) + transportA.Negotiate(true) + require.Eventually(t, func() bool { + return offerReceived.Load() + }, 10*time.Second, time.Millisecond*10, "transportA offer not received") + + connectTransports(t, transportA, transportB, handlerA, handlerB, true, 1, 1) + require.Equal(t, webrtc.ICEConnectionStateConnected, transportA.pc.ICEConnectionState()) + require.Equal(t, webrtc.ICEConnectionStateConnected, transportB.pc.ICEConnectionState()) + + transportA.Close() + transportB.Close() +} + +func TestNegotiationTiming(t *testing.T) { + params := TransportParams{ + Config: &WebRTCConfig{}, + IsOfferer: true, + } + + paramsA := params + handlerA := &transportfakes.FakeHandler{} + paramsA.Handler = handlerA + transportA, err := NewPCTransport(paramsA) + require.NoError(t, err) + _, err = transportA.pc.CreateDataChannel(LossyDataChannel, nil) + require.NoError(t, err) + + paramsB := params + handlerB := &transportfakes.FakeHandler{} + paramsB.Handler = handlerB + paramsB.IsOfferer = false + transportB, err := NewPCTransport(paramsB) + require.NoError(t, err) + + require.False(t, transportA.IsEstablished()) + require.False(t, transportB.IsEstablished()) + + handleICEExchange(t, transportA, transportB, handlerA, handlerB) + firstOffer := atomic.Value{} + firstOfferId := atomic.Uint32{} + secondOffer := atomic.Value{} + handlerA.OnOfferCalls(func(sd webrtc.SessionDescription, offerId uint32, _midToTrackID map[string]string) error { + if _, ok := firstOffer.Load().(*webrtc.SessionDescription); !ok { + firstOffer.Store(&sd) + firstOfferId.Store(offerId) + } else { + secondOffer.Store(&sd) + } + return nil + }) + + var negotiationState atomic.Value + transportA.OnNegotiationStateChanged(func(state transport.NegotiationState) { + negotiationState.Store(state) + }) + + // initial offer + transportA.Negotiate(true) + require.Eventually(t, func() bool { + state, ok := negotiationState.Load().(transport.NegotiationState) + if !ok { + return false + } + + return state == transport.NegotiationStateRemote + }, 10*time.Second, 10*time.Millisecond, "negotiation state does not match NegotiateStateRemote") + + // second try, should've flipped transport status to retry + transportA.Negotiate(true) + require.Eventually(t, func() bool { + state, ok := negotiationState.Load().(transport.NegotiationState) + if !ok { + return false + } + + return state == transport.NegotiationStateRetry + }, 10*time.Second, 10*time.Millisecond, "negotiation state does not match NegotiateStateRetry") + + // third try, should've stayed at retry + transportA.Negotiate(true) + time.Sleep(100 * time.Millisecond) // some time to process the negotiate event + require.Eventually(t, func() bool { + state, ok := negotiationState.Load().(transport.NegotiationState) + if !ok { + return false + } + + return state == transport.NegotiationStateRetry + }, 10*time.Second, 10*time.Millisecond, "negotiation state does not match NegotiateStateRetry") + + require.Eventually(t, func() bool { + _, ok := firstOffer.Load().(*webrtc.SessionDescription) + if !ok { + return false + } + if firstOfferId.Load() == 0 { + return false + } + return true + }, 10*time.Second, 10*time.Millisecond, "first offer not received yet") + + handlerB.OnAnswerCalls(func(answer webrtc.SessionDescription, answerId uint32, _midToTrackID map[string]string) error { + transportA.HandleRemoteDescription(answer, answerId) + return nil + }) + transportB.HandleRemoteDescription(*firstOffer.Load().(*webrtc.SessionDescription), firstOfferId.Load()) + + require.Eventually(t, func() bool { + return transportA.IsEstablished() + }, 10*time.Second, time.Millisecond*10, "transportA is not established") + require.Eventually(t, func() bool { + return transportB.IsEstablished() + }, 10*time.Second, time.Millisecond*10, "transportB is not established") + + // offerer should send another offer after processing the answer + // as there were forced negotiations a couple of time above + require.Eventually(t, func() bool { + state, ok := negotiationState.Load().(transport.NegotiationState) + if !ok { + return false + } + + return state == transport.NegotiationStateRemote + }, 10*time.Second, 10*time.Millisecond, "negotiation state does not match NegotiateStateRemote") + _, ok := secondOffer.Load().(*webrtc.SessionDescription) + require.True(t, ok) + + transportA.Close() + transportB.Close() +} + +func TestFirstOfferMissedDuringICERestart(t *testing.T) { + params := TransportParams{ + Config: &WebRTCConfig{}, + IsOfferer: true, + } + + paramsA := params + handlerA := &transportfakes.FakeHandler{} + paramsA.Handler = handlerA + transportA, err := NewPCTransport(paramsA) + require.NoError(t, err) + _, err = transportA.pc.CreateDataChannel(ReliableDataChannel, nil) + require.NoError(t, err) + + paramsB := params + handlerB := &transportfakes.FakeHandler{} + paramsB.Handler = handlerB + paramsB.IsOfferer = false + transportB, err := NewPCTransport(paramsB) + require.NoError(t, err) + + // exchange ICE + handleICEExchange(t, transportA, transportB, handlerA, handlerB) + + // first offer missed + var firstOfferReceived atomic.Bool + handlerA.OnOfferCalls(func(sd webrtc.SessionDescription, _offerId uint32, _midToTrackID map[string]string) error { + firstOfferReceived.Store(true) + return nil + }) + transportA.Negotiate(true) + require.Eventually(t, func() bool { + return firstOfferReceived.Load() + }, 10*time.Second, 10*time.Millisecond, "first offer not received") + + // set offer/answer with restart ICE, will negotiate twice, + // first one is recover from missed offer + // second one is restartICE + handlerB.OnAnswerCalls(func(answer webrtc.SessionDescription, answerId uint32, _midToTrackID map[string]string) error { + transportA.HandleRemoteDescription(answer, answerId) + return nil + }) + + var offerCount atomic.Int32 + handlerA.OnOfferCalls(func(sd webrtc.SessionDescription, offerId uint32, _midToTrackID map[string]string) error { + offerCount.Inc() + + // the second offer is a ice restart offer, so we wait transportB complete the ice gathering + if transportB.pc.ICEGatheringState() == webrtc.ICEGatheringStateGathering { + require.Eventually(t, func() bool { + return transportB.pc.ICEGatheringState() == webrtc.ICEGatheringStateComplete + }, 10*time.Second, time.Millisecond*10) + } + + transportB.HandleRemoteDescription(sd, offerId) + return nil + }) + + // first establish connection + transportA.ICERestart() + + // ensure we are connected + require.Eventually(t, func() bool { + return transportA.pc.ICEConnectionState() == webrtc.ICEConnectionStateConnected && + transportB.pc.ICEConnectionState() == webrtc.ICEConnectionStateConnected && + offerCount.Load() == 2 + }, testutils.ConnectTimeout, 10*time.Millisecond, "transport did not connect") + + transportA.Close() + transportB.Close() +} + +func TestFirstAnswerMissedDuringICERestart(t *testing.T) { + params := TransportParams{ + Config: &WebRTCConfig{}, + IsOfferer: true, + } + + paramsA := params + handlerA := &transportfakes.FakeHandler{} + paramsA.Handler = handlerA + transportA, err := NewPCTransport(paramsA) + require.NoError(t, err) + _, err = transportA.pc.CreateDataChannel(LossyDataChannel, nil) + require.NoError(t, err) + + paramsB := params + handlerB := &transportfakes.FakeHandler{} + paramsB.Handler = handlerB + paramsB.IsOfferer = false + transportB, err := NewPCTransport(paramsB) + require.NoError(t, err) + + // exchange ICE + handleICEExchange(t, transportA, transportB, handlerA, handlerB) + + // first answer missed + var firstAnswerReceived atomic.Bool + handlerB.OnAnswerCalls(func(sd webrtc.SessionDescription, answerId uint32, _midToTrackID map[string]string) error { + if firstAnswerReceived.Load() { + transportA.HandleRemoteDescription(sd, answerId) + } else { + // do not send first answer so that remote misses the first answer + firstAnswerReceived.Store(true) + } + return nil + }) + handlerA.OnOfferCalls(func(sd webrtc.SessionDescription, offerId uint32, _midToTrackID map[string]string) error { + transportB.HandleRemoteDescription(sd, offerId) + return nil + }) + + transportA.Negotiate(true) + require.Eventually(t, func() bool { + return transportB.pc.SignalingState() == webrtc.SignalingStateStable && firstAnswerReceived.Load() + }, time.Second, 10*time.Millisecond, "transportB signaling state did not go to stable") + + // set offer/answer with restart ICE, will negotiate twice, + // first one is recover from missed offer + // second one is restartICE + var offerCount atomic.Int32 + handlerA.OnOfferCalls(func(sd webrtc.SessionDescription, offerId uint32, _midToTrackID map[string]string) error { + offerCount.Inc() + + // the second offer is a ice restart offer, so we wait for transportB to complete ICE gathering + if transportB.pc.ICEGatheringState() == webrtc.ICEGatheringStateGathering { + require.Eventually(t, func() bool { + return transportB.pc.ICEGatheringState() == webrtc.ICEGatheringStateComplete + }, 10*time.Second, time.Millisecond*10) + } + + transportB.HandleRemoteDescription(sd, offerId) + return nil + }) + + // first establish connection + transportA.ICERestart() + + // ensure we are connected + require.Eventually(t, func() bool { + return transportA.pc.ICEConnectionState() == webrtc.ICEConnectionStateConnected && + transportB.pc.ICEConnectionState() == webrtc.ICEConnectionStateConnected && + offerCount.Load() == 2 + }, testutils.ConnectTimeout, 10*time.Millisecond, "transport did not connect") + + transportA.Close() + transportB.Close() +} + +func TestNegotiationFailed(t *testing.T) { + params := TransportParams{ + Config: &WebRTCConfig{}, + IsOfferer: true, + } + + paramsA := params + handlerA := &transportfakes.FakeHandler{} + paramsA.Handler = handlerA + transportA, err := NewPCTransport(paramsA) + require.NoError(t, err) + _, err = transportA.pc.CreateDataChannel(ReliableDataChannel, nil) + require.NoError(t, err) + + paramsB := params + handlerB := &transportfakes.FakeHandler{} + paramsB.Handler = handlerB + paramsB.IsOfferer = false + transportB, err := NewPCTransport(paramsB) + require.NoError(t, err) + + // exchange ICE + handleICEExchange(t, transportA, transportB, handlerA, handlerB) + + // wait for transport to be connected before maiming the signalling channel + connectTransports(t, transportA, transportB, handlerA, handlerB, false, 1, 1) + + // reset OnOffer to force a negotiation failure + handlerA.OnOfferCalls(func(sd webrtc.SessionDescription, offerId uint32, _midToTrackID map[string]string) error { + return nil + }) + var failed atomic.Int32 + handlerA.OnNegotiationFailedCalls(func() { + failed.Inc() + }) + transportA.Negotiate(true) + require.Eventually(t, func() bool { + return failed.Load() == 1 + }, negotiationFailedTimeout+time.Second, 10*time.Millisecond, "negotiation failed") + + transportA.Close() +} + +func TestFilteringCandidates(t *testing.T) { + params := TransportParams{ + Config: &WebRTCConfig{}, + EnabledCodecs: []*livekit.Codec{ + {Mime: mime.MimeTypeOpus.String()}, + {Mime: mime.MimeTypeVP8.String()}, + {Mime: mime.MimeTypeH264.String()}, + }, + Handler: &transportfakes.FakeHandler{}, + } + transport, err := NewPCTransport(params) + require.NoError(t, err) + + _, err = transport.pc.CreateDataChannel(ReliableDataChannel, nil) + require.NoError(t, err) + + _, err = transport.pc.AddTransceiverFromKind(webrtc.RTPCodecTypeAudio) + require.NoError(t, err) + + _, err = transport.pc.AddTransceiverFromKind(webrtc.RTPCodecTypeVideo) + require.NoError(t, err) + + offer, err := transport.pc.CreateOffer(nil) + require.NoError(t, err) + + offerGatheringComplete := webrtc.GatheringCompletePromise(transport.pc) + require.NoError(t, transport.pc.SetLocalDescription(offer)) + <-offerGatheringComplete + + // should not filter out UDP candidates if TCP is not preferred + offer = *transport.pc.LocalDescription() + filteredOffer := transport.filterCandidates(offer, false, true) + require.EqualValues(t, offer.SDP, filteredOffer.SDP) + + parsed, err := offer.Unmarshal() + require.NoError(t, err) + + // add a couple of TCP candidates + done := false + for _, m := range parsed.MediaDescriptions { + for _, a := range m.Attributes { + if a.Key == sdp.AttrKeyCandidate { + for idx, aa := range m.Attributes { + if aa.Key == sdp.AttrKeyEndOfCandidates { + modifiedAttributes := make([]sdp.Attribute, idx) + copy(modifiedAttributes, m.Attributes[:idx]) + modifiedAttributes = append(modifiedAttributes, []sdp.Attribute{ + { + Key: sdp.AttrKeyCandidate, + Value: "054225987 1 tcp 2124414975 159.203.70.248 7881 typ host tcptype passive", + }, + { + Key: sdp.AttrKeyCandidate, + Value: "054225987 2 tcp 2124414975 159.203.70.248 7881 typ host tcptype passive", + }, + }...) + m.Attributes = append(modifiedAttributes, m.Attributes[idx:]...) + done = true + break + } + } + } + if done { + break + } + } + if done { + break + } + } + bytes, err := parsed.Marshal() + require.NoError(t, err) + offer.SDP = string(bytes) + + parsed, err = offer.Unmarshal() + require.NoError(t, err) + + getNumTransportTypeCandidates := func(sd *sdp.SessionDescription) (int, int) { + numUDPCandidates := 0 + numTCPCandidates := 0 + for _, a := range sd.Attributes { + if a.Key == sdp.AttrKeyCandidate { + if strings.Contains(a.Value, "udp") { + numUDPCandidates++ + } + if strings.Contains(a.Value, "tcp") { + numTCPCandidates++ + } + } + } + for _, m := range sd.MediaDescriptions { + for _, a := range m.Attributes { + if a.Key == sdp.AttrKeyCandidate { + if strings.Contains(a.Value, "udp") { + numUDPCandidates++ + } + if strings.Contains(a.Value, "tcp") { + numTCPCandidates++ + } + } + } + } + return numUDPCandidates, numTCPCandidates + } + udp, tcp := getNumTransportTypeCandidates(parsed) + require.NotZero(t, udp) + require.Equal(t, 2, tcp) + + transport.SetPreferTCP(true) + filteredOffer = transport.filterCandidates(offer, true, true) + parsed, err = filteredOffer.Unmarshal() + require.NoError(t, err) + udp, tcp = getNumTransportTypeCandidates(parsed) + require.Zero(t, udp) + require.Equal(t, 2, tcp) + + transport.Close() +} + +func handleICEExchange(t *testing.T, a, b *PCTransport, ah, bh *transportfakes.FakeHandler) { + ah.OnICECandidateCalls(func(candidate *webrtc.ICECandidate, target livekit.SignalTarget) error { + if candidate == nil { + return nil + } + t.Logf("got ICE candidate from A: %v", candidate) + b.AddICECandidate(candidate.ToJSON()) + return nil + }) + bh.OnICECandidateCalls(func(candidate *webrtc.ICECandidate, target livekit.SignalTarget) error { + if candidate == nil { + return nil + } + t.Logf("got ICE candidate from B: %v", candidate) + a.AddICECandidate(candidate.ToJSON()) + return nil + }) +} + +func connectTransports(t *testing.T, offerer, answerer *PCTransport, offererHandler, answererHandler *transportfakes.FakeHandler, isICERestart bool, expectedOfferCount int32, expectedAnswerCount int32) { + var offerCount atomic.Int32 + var answerCount atomic.Int32 + answererHandler.OnAnswerCalls(func(answer webrtc.SessionDescription, answerId uint32, _midToTrackID map[string]string) error { + answerCount.Inc() + offerer.HandleRemoteDescription(answer, answerId) + return nil + }) + + offererHandler.OnOfferCalls(func(offer webrtc.SessionDescription, offerId uint32, _midToTrackID map[string]string) error { + offerCount.Inc() + answerer.HandleRemoteDescription(offer, offerId) + return nil + }) + + if isICERestart { + offerer.ICERestart() + } else { + offerer.Negotiate(true) + } + + require.Eventually(t, func() bool { + return offerCount.Load() == expectedOfferCount + }, 10*time.Second, time.Millisecond*10, fmt.Sprintf("offer count mismatch, expected: %d, actual: %d", expectedOfferCount, offerCount.Load())) + + require.Eventually(t, func() bool { + return offerer.pc.ICEConnectionState() == webrtc.ICEConnectionStateConnected + }, 10*time.Second, time.Millisecond*10, "offerer did not become connected") + + require.Eventually(t, func() bool { + return answerCount.Load() == expectedAnswerCount + }, 10*time.Second, time.Millisecond*10, fmt.Sprintf("answer count mismatch, expected: %d, actual: %d", expectedAnswerCount, answerCount.Load())) + + require.Eventually(t, func() bool { + return answerer.pc.ICEConnectionState() == webrtc.ICEConnectionStateConnected + }, 10*time.Second, time.Millisecond*10, "answerer did not become connected") + + transportsConnected := untilTransportsConnected(offererHandler, answererHandler) + transportsConnected.Wait() +} + +func untilTransportsConnected(transports ...*transportfakes.FakeHandler) *sync.WaitGroup { + var triggered sync.WaitGroup + triggered.Add(len(transports)) + + for _, t := range transports { + var done atomic.Value + done.Store(false) + hdlr := func() { + if val, ok := done.Load().(bool); ok && !val { + done.Store(true) + triggered.Done() + } + } + + if t.OnInitialConnectedCallCount() != 0 { + hdlr() + } + t.OnInitialConnectedCalls(hdlr) + } + return &triggered +} + +func TestConfigureAudioTransceiver(t *testing.T) { + for _, testcase := range []struct { + nack bool + stereo bool + }{ + {false, false}, + {true, false}, + {false, true}, + {true, true}, + } { + t.Run(fmt.Sprintf("nack=%v,stereo=%v", testcase.nack, testcase.stereo), func(t *testing.T) { + var me webrtc.MediaEngine + registerCodecs(&me, []*livekit.Codec{{Mime: mime.MimeTypeOpus.String()}}, RTCPFeedbackConfig{Audio: []webrtc.RTCPFeedback{{Type: webrtc.TypeRTCPFBNACK}}}, false) + pc, err := webrtc.NewAPI(webrtc.WithMediaEngine(&me)).NewPeerConnection(webrtc.Configuration{}) + require.NoError(t, err) + defer pc.Close() + tr, err := pc.AddTransceiverFromKind(webrtc.RTPCodecTypeAudio, webrtc.RTPTransceiverInit{Direction: webrtc.RTPTransceiverDirectionSendonly}) + require.NoError(t, err) + + configureSenderAudio(tr, testcase.stereo, testcase.nack) + codecs := tr.Sender().GetParameters().Codecs + for _, codec := range codecs { + if mime.IsMimeTypeStringOpus(codec.MimeType) { + require.Equal(t, testcase.stereo, strings.Contains(codec.SDPFmtpLine, "sprop-stereo=1")) + var nackEnabled bool + for _, fb := range codec.RTCPFeedback { + if fb.Type == webrtc.TypeRTCPFBNACK { + nackEnabled = true + break + } + } + require.Equal(t, testcase.nack, nackEnabled) + } + } + }) + } +} diff --git a/livekit/pkg/rtc/transportmanager.go b/livekit/pkg/rtc/transportmanager.go new file mode 100644 index 0000000..c3c01b7 --- /dev/null +++ b/livekit/pkg/rtc/transportmanager.go @@ -0,0 +1,1034 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rtc + +import ( + "context" + "io" + "math/bits" + "sync" + "time" + + "github.com/pion/rtcp" + "github.com/pion/sctp" + "github.com/pion/sdp/v3" + "github.com/pion/webrtc/v4" + "github.com/pkg/errors" + "go.uber.org/atomic" + "google.golang.org/protobuf/proto" + + "github.com/livekit/mediatransportutil/pkg/twcc" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/utils" + + "github.com/livekit/livekit-server/pkg/config" + "github.com/livekit/livekit-server/pkg/rtc/transport" + "github.com/livekit/livekit-server/pkg/rtc/types" + "github.com/livekit/livekit-server/pkg/sfu" + "github.com/livekit/livekit-server/pkg/sfu/datachannel" + "github.com/livekit/livekit-server/pkg/sfu/interceptor" + "github.com/livekit/livekit-server/pkg/sfu/pacer" + "github.com/livekit/livekit-server/pkg/telemetry" +) + +const ( + failureCountThreshold = 2 + preferNextByFailureWindow = time.Minute + + // when RR report loss percentage over this threshold, we consider it is a unstable event + udpLossFracUnstable = 25 + // if in last 32 times RR, the unstable report count over this threshold, the connection is unstable + udpLossUnstableCountThreshold = 20 +) + +// ------------------------------- + +type TransportManagerTransportHandler struct { + transport.Handler + t *TransportManager + logger logger.Logger +} + +func (h TransportManagerTransportHandler) OnFailed(isShortLived bool, iceConnectionInfo *types.ICEConnectionInfo) { + if isShortLived { + h.logger.Infow("short ice connection", connectionDetailsFields([]*types.ICEConnectionInfo{iceConnectionInfo})...) + } + h.t.handleConnectionFailed(isShortLived) + h.Handler.OnFailed(isShortLived, iceConnectionInfo) +} + +// ------------------------------- + +type TransportManagerParams struct { + SubscriberAsPrimary bool + UseSinglePeerConnection bool + Config *WebRTCConfig + Twcc *twcc.Responder + ProtocolVersion types.ProtocolVersion + CongestionControlConfig config.CongestionControlConfig + EnabledSubscribeCodecs []*livekit.Codec + EnabledPublishCodecs []*livekit.Codec + SimTracks map[uint32]interceptor.SimulcastTrackInfo + ClientInfo ClientInfo + Migration bool + AllowTCPFallback bool + TCPFallbackRTTThreshold int + AllowUDPUnstableFallback bool + TURNSEnabled bool + AllowPlayoutDelay bool + DataChannelMaxBufferedAmount uint64 + DatachannelSlowThreshold int + DatachannelLossyTargetLatency time.Duration + Logger logger.Logger + PublisherHandler transport.Handler + SubscriberHandler transport.Handler + DataChannelStats *telemetry.BytesTrackStats + UseOneShotSignallingMode bool + FireOnTrackBySdp bool + EnableDataTracks bool +} + +type TransportManager struct { + params TransportManagerParams + + lock sync.RWMutex + + publisher *PCTransport + subscriber *PCTransport + failureCount int + isTransportReconfigured bool + lastFailure time.Time + lastSignalAt time.Time + signalSourceValid atomic.Bool + + pendingOfferPublisher *webrtc.SessionDescription + pendingOfferIdPublisher uint32 + pendingDataChannelsPublisher []*livekit.DataChannelInfo + iceConfig *livekit.ICEConfig + + mediaLossProxy *MediaLossProxy + udpLossUnstableCount uint32 + signalingRTT, udpRTT uint32 + + onICEConfigChanged func(iceConfig *livekit.ICEConfig) + + droppedBySlowReaderCount atomic.Uint32 +} + +func NewTransportManager(params TransportManagerParams) (*TransportManager, error) { + if params.Logger == nil { + params.Logger = logger.GetLogger() + } + t := &TransportManager{ + params: params, + mediaLossProxy: NewMediaLossProxy(MediaLossProxyParams{Logger: params.Logger}), + iceConfig: &livekit.ICEConfig{}, + } + t.mediaLossProxy.OnMediaLossUpdate(t.onMediaLossUpdate) + + lgr := LoggerWithPCTarget(params.Logger, livekit.SignalTarget_PUBLISHER) + publisher, err := NewPCTransport(TransportParams{ + ProtocolVersion: params.ProtocolVersion, + Config: params.Config, + Twcc: params.Twcc, + DirectionConfig: params.Config.Publisher, + CongestionControlConfig: params.CongestionControlConfig, + EnabledCodecs: params.EnabledPublishCodecs, + Logger: lgr, + SimTracks: params.SimTracks, + ClientInfo: params.ClientInfo, + IsSendSide: params.UseOneShotSignallingMode || params.UseSinglePeerConnection, + AllowPlayoutDelay: params.AllowPlayoutDelay, + Transport: livekit.SignalTarget_PUBLISHER, + Handler: TransportManagerTransportHandler{params.PublisherHandler, t, lgr}, + UseOneShotSignallingMode: params.UseOneShotSignallingMode, + DataChannelMaxBufferedAmount: params.DataChannelMaxBufferedAmount, + DatachannelSlowThreshold: params.DatachannelSlowThreshold, + DatachannelLossyTargetLatency: params.DatachannelLossyTargetLatency, + FireOnTrackBySdp: params.FireOnTrackBySdp, + EnableDataTracks: params.EnableDataTracks, + }) + if err != nil { + return nil, err + } + t.publisher = publisher + + if !t.params.UseOneShotSignallingMode && !t.params.UseSinglePeerConnection { + lgr := LoggerWithPCTarget(params.Logger, livekit.SignalTarget_SUBSCRIBER) + subscriber, err := NewPCTransport(TransportParams{ + ProtocolVersion: params.ProtocolVersion, + Config: params.Config, + DirectionConfig: params.Config.Subscriber, + CongestionControlConfig: params.CongestionControlConfig, + EnabledCodecs: params.EnabledSubscribeCodecs, + Logger: lgr, + ClientInfo: params.ClientInfo, + IsOfferer: true, + IsSendSide: true, + AllowPlayoutDelay: params.AllowPlayoutDelay, + DataChannelMaxBufferedAmount: params.DataChannelMaxBufferedAmount, + DatachannelSlowThreshold: params.DatachannelSlowThreshold, + DatachannelLossyTargetLatency: params.DatachannelLossyTargetLatency, + Transport: livekit.SignalTarget_SUBSCRIBER, + Handler: TransportManagerTransportHandler{params.SubscriberHandler, t, lgr}, + FireOnTrackBySdp: params.FireOnTrackBySdp, + EnableDataTracks: params.EnableDataTracks, + }) + if err != nil { + return nil, err + } + t.subscriber = subscriber + } + if !t.params.Migration && t.params.SubscriberAsPrimary { + if err := t.createDataChannelsForSubscriber(nil); err != nil { + return nil, err + } + } + + t.signalSourceValid.Store(true) + return t, nil +} + +func (t *TransportManager) Close() { + if t.publisher != nil { + t.publisher.Close() + } + if t.subscriber != nil { + t.subscriber.Close() + } +} + +func (t *TransportManager) SubscriberClose() { + var subscriberClosed atomic.Bool + time.AfterFunc(time.Minute, func() { // CLOSE-DEBUG-CLEANUP + if !subscriberClosed.Load() { + t.params.Logger.Infow( + "transport maanager subscriber close timeout", + "subscriberClosed", subscriberClosed.Load(), + ) + } + }) + t.subscriber.Close() + subscriberClosed.Store(true) +} + +func (t *TransportManager) HasPublisherEverConnected() bool { + return t.publisher.HasEverConnected() +} + +func (t *TransportManager) IsPublisherEstablished() bool { + return t.publisher.IsEstablished() +} + +func (t *TransportManager) GetPublisherRTT() (float64, bool) { + return t.publisher.GetRTT() +} + +func (t *TransportManager) GetPublisherMid(rtpReceiver *webrtc.RTPReceiver) string { + return t.publisher.GetMid(rtpReceiver) +} + +func (t *TransportManager) GetPublisherRTPTransceiver(mid string) *webrtc.RTPTransceiver { + return t.publisher.GetRTPTransceiver(mid) +} + +func (t *TransportManager) GetPublisherRTPReceiver(mid string) *webrtc.RTPReceiver { + return t.publisher.GetRTPReceiver(mid) +} + +func (t *TransportManager) WritePublisherRTCP(pkts []rtcp.Packet) error { + return t.publisher.WriteRTCP(pkts) +} + +func (t *TransportManager) GetSubscriberRTT() (float64, bool) { + if t.params.UseOneShotSignallingMode || t.params.UseSinglePeerConnection { + return t.publisher.GetRTT() + } else { + return t.subscriber.GetRTT() + } +} + +func (t *TransportManager) HasSubscriberEverConnected() bool { + if t.params.UseOneShotSignallingMode || t.params.UseSinglePeerConnection { + return t.publisher.HasEverConnected() + } else { + return t.subscriber.HasEverConnected() + } +} + +func (t *TransportManager) AddTrackLocal( + trackLocal webrtc.TrackLocal, + params types.AddTrackParams, + enabledCodecs []*livekit.Codec, + rtcpFeedbackConfig RTCPFeedbackConfig, +) (*webrtc.RTPSender, *webrtc.RTPTransceiver, error) { + if t.params.UseOneShotSignallingMode || t.params.UseSinglePeerConnection { + return t.publisher.AddTrack(trackLocal, params, enabledCodecs, rtcpFeedbackConfig) + } else { + return t.subscriber.AddTrack(trackLocal, params, enabledCodecs, rtcpFeedbackConfig) + } +} + +func (t *TransportManager) AddTransceiverFromTrackLocal( + trackLocal webrtc.TrackLocal, + params types.AddTrackParams, + enabledCodecs []*livekit.Codec, + rtcpFeedbackConfig RTCPFeedbackConfig, +) (*webrtc.RTPSender, *webrtc.RTPTransceiver, error) { + if t.params.UseOneShotSignallingMode || t.params.UseSinglePeerConnection { + return t.publisher.AddTransceiverFromTrack(trackLocal, params, enabledCodecs, rtcpFeedbackConfig) + } else { + return t.subscriber.AddTransceiverFromTrack(trackLocal, params, enabledCodecs, rtcpFeedbackConfig) + } +} + +func (t *TransportManager) RemoveTrackLocal(sender *webrtc.RTPSender) error { + if t.params.UseOneShotSignallingMode || t.params.UseSinglePeerConnection { + return t.publisher.RemoveTrack(sender) + } else { + return t.subscriber.RemoveTrack(sender) + } +} + +func (t *TransportManager) WriteSubscriberRTCP(pkts []rtcp.Packet) error { + if t.params.UseOneShotSignallingMode || t.params.UseSinglePeerConnection { + return t.publisher.WriteRTCP(pkts) + } else { + return t.subscriber.WriteRTCP(pkts) + } +} + +func (t *TransportManager) GetSubscriberPacer() pacer.Pacer { + if t.params.UseOneShotSignallingMode || t.params.UseSinglePeerConnection { + return t.publisher.GetPacer() + } else { + return t.subscriber.GetPacer() + } +} + +func (t *TransportManager) AddSubscribedTrack(subTrack types.SubscribedTrack) { + if t.params.UseOneShotSignallingMode || t.params.UseSinglePeerConnection { + t.publisher.AddTrackToStreamAllocator(subTrack) + } else { + t.subscriber.AddTrackToStreamAllocator(subTrack) + } +} + +func (t *TransportManager) RemoveSubscribedTrack(subTrack types.SubscribedTrack) { + if t.params.UseOneShotSignallingMode || t.params.UseSinglePeerConnection { + t.publisher.RemoveTrackFromStreamAllocator(subTrack) + } else { + t.subscriber.RemoveTrackFromStreamAllocator(subTrack) + } +} + +func (t *TransportManager) SendDataMessage(kind livekit.DataPacket_Kind, data []byte) error { + // downstream data is sent via primary peer connection + return t.handleSendDataResult(t.getTransport(true).SendDataMessage(kind, data), kind.String(), len(data)) +} + +func (t *TransportManager) SendDataMessageUnlabeled(data []byte, useRaw bool, sender livekit.ParticipantIdentity) error { + // downstream data is sent via primary peer connection + return t.handleSendDataResult( + t.getTransport(true).SendDataMessageUnlabeled(data, useRaw, sender), + "unlabeled", + len(data), + ) +} + +func (t *TransportManager) handleSendDataResult(err error, kind string, size int) error { + if err != nil { + if !utils.ErrorIsOneOf( + err, + io.ErrClosedPipe, + sctp.ErrStreamClosed, + ErrTransportFailure, + ErrDataChannelBufferFull, + context.DeadlineExceeded, + datachannel.ErrDataDroppedByHighBufferedAmount, + ) { + if errors.Is(err, datachannel.ErrDataDroppedBySlowReader) { + droppedBySlowReaderCount := t.droppedBySlowReaderCount.Inc() + if (droppedBySlowReaderCount-1)%100 == 0 { + t.params.Logger.Infow( + "drop data message by slow reader", + "error", err, + "kind", kind, + "count", droppedBySlowReaderCount, + ) + } + } else { + t.params.Logger.Warnw("send data message error", err) + } + } + if utils.ErrorIsOneOf(err, sctp.ErrStreamClosed, io.ErrClosedPipe) { + if t.params.SubscriberAsPrimary { + t.params.SubscriberHandler.OnDataSendError(err) + } else { + t.params.PublisherHandler.OnDataSendError(err) + } + } + } else { + t.params.DataChannelStats.AddBytes(uint64(size), true) + } + + return err +} + +func (t *TransportManager) createDataChannelsForSubscriber(pendingDataChannels []*livekit.DataChannelInfo) error { + var ( + reliableID, lossyID, dataTrackID uint16 + reliableIDPtr, lossyIDPtr, dataTrackIDPtr *uint16 + ) + + // + // For old version migration clients, they don't send subscriber data channel info + // so we need to create data channels with default ID and don't negotiate as client already has + // data channels with default ID. + // + // For new version migration clients, we create data channels with new ID and negotiate with client + // + for _, dc := range pendingDataChannels { + switch dc.Label { + case ReliableDataChannel: + // pion use step 2 for auto generated ID, so we need to add 6 to avoid conflict + reliableID = uint16(dc.Id) + 6 + reliableIDPtr = &reliableID + case LossyDataChannel: + lossyID = uint16(dc.Id) + 6 + lossyIDPtr = &lossyID + case DataTrackDataChannel: + dataTrackID = uint16(dc.Id) + 6 + dataTrackIDPtr = &dataTrackID + } + } + + ordered := true + negotiated := t.params.Migration && reliableIDPtr == nil + if err := t.subscriber.CreateDataChannel(ReliableDataChannel, &webrtc.DataChannelInit{ + Ordered: &ordered, + ID: reliableIDPtr, + Negotiated: &negotiated, + }); err != nil { + return err + } + + ordered = false + retransmits := uint16(0) + negotiated = t.params.Migration && lossyIDPtr == nil + if err := t.subscriber.CreateDataChannel(LossyDataChannel, &webrtc.DataChannelInit{ + Ordered: &ordered, + MaxRetransmits: &retransmits, + ID: lossyIDPtr, + Negotiated: &negotiated, + }); err != nil { + return err + } + + negotiated = t.params.Migration && dataTrackIDPtr == nil + if err := t.subscriber.CreateDataChannel(DataTrackDataChannel, &webrtc.DataChannelInit{ + Ordered: &ordered, + MaxRetransmits: &retransmits, + ID: dataTrackIDPtr, + Negotiated: &negotiated, + }); err != nil { + return err + } + + return nil +} + +func (t *TransportManager) GetUnmatchMediaForOffer(parsedOffer *sdp.SessionDescription, mediaType string) (unmatched []*sdp.MediaDescription, err error) { + var lastMatchedMid string + if lastAnswer := t.publisher.CurrentLocalDescription(); lastAnswer != nil { + parsedAnswer, err1 := lastAnswer.Unmarshal() + if err1 != nil { + // should not happen + t.params.Logger.Errorw("failed to parse last answer", err1) + return unmatched, err1 + } + + for i := len(parsedAnswer.MediaDescriptions) - 1; i >= 0; i-- { + media := parsedAnswer.MediaDescriptions[i] + if media.MediaName.Media == mediaType { + lastMatchedMid, _ = media.Attribute(sdp.AttrKeyMID) + break + } + } + } + + for i := len(parsedOffer.MediaDescriptions) - 1; i >= 0; i-- { + media := parsedOffer.MediaDescriptions[i] + if media.MediaName.Media == mediaType { + mid, _ := media.Attribute(sdp.AttrKeyMID) + if mid == lastMatchedMid { + break + } + unmatched = append(unmatched, media) + } + } + + return +} + +func (t *TransportManager) LastPublisherOffer() *webrtc.SessionDescription { + return t.publisher.CurrentRemoteDescription() +} + +func (t *TransportManager) LastPublisherOfferPending() *webrtc.SessionDescription { + return t.publisher.PendingRemoteDescription() +} + +func (t *TransportManager) HandleOffer(offer webrtc.SessionDescription, offerId uint32, shouldPend bool) error { + t.lock.Lock() + if shouldPend { + t.pendingOfferPublisher = &offer + t.pendingOfferIdPublisher = offerId + t.lock.Unlock() + return nil + } + t.lock.Unlock() + + return t.publisher.HandleRemoteDescription(offer, offerId) +} + +func (t *TransportManager) GetAnswer() (webrtc.SessionDescription, uint32, error) { + return t.publisher.GetAnswer() +} + +func (t *TransportManager) GetPublisherICESessionUfrag() (string, error) { + return t.publisher.GetICESessionUfrag() +} + +func (t *TransportManager) HandleICETrickleSDPFragment(sdpFragment string) error { + return t.publisher.HandleICETrickleSDPFragment(sdpFragment) +} + +func (t *TransportManager) HandleICERestartSDPFragment(sdpFragment string) (string, error) { + return t.publisher.HandleICERestartSDPFragment(sdpFragment) +} + +func (t *TransportManager) ProcessPendingPublisherOffer() { + t.lock.Lock() + pendingOffer := t.pendingOfferPublisher + t.pendingOfferPublisher = nil + + pendingOfferId := t.pendingOfferIdPublisher + t.pendingOfferIdPublisher = 0 + t.lock.Unlock() + + if pendingOffer != nil { + t.HandleOffer(*pendingOffer, pendingOfferId, false) + } +} + +func (t *TransportManager) HandleAnswer(answer webrtc.SessionDescription, answerId uint32) { + t.subscriber.HandleRemoteDescription(answer, answerId) +} + +// AddICECandidate adds candidates for remote peer +func (t *TransportManager) AddICECandidate(candidate webrtc.ICECandidateInit, target livekit.SignalTarget) { + switch target { + case livekit.SignalTarget_PUBLISHER: + t.publisher.AddICECandidate(candidate) + case livekit.SignalTarget_SUBSCRIBER: + t.subscriber.AddICECandidate(candidate) + default: + err := errors.New("unknown signal target") + t.params.Logger.Errorw("ice candidate for unknown signal target", err, "target", target) + } +} + +func (t *TransportManager) NegotiateSubscriber(force bool) { + if t.subscriber != nil { + t.subscriber.Negotiate(force) + } else { + t.publisher.Negotiate(force) + } +} + +func (t *TransportManager) HandleClientReconnect(reason livekit.ReconnectReason) { + var ( + isShort bool + duration time.Duration + resetShortConnection bool + ) + switch reason { + case livekit.ReconnectReason_RR_PUBLISHER_FAILED: + if t.publisher != nil { + resetShortConnection = true + isShort, duration = t.publisher.IsShortConnection(time.Now()) + } + + case livekit.ReconnectReason_RR_SUBSCRIBER_FAILED: + if t.subscriber != nil { + resetShortConnection = true + isShort, duration = t.subscriber.IsShortConnection(time.Now()) + } + } + + if isShort { + t.lock.Lock() + t.resetTransportConfigureLocked(false) + t.lock.Unlock() + t.params.Logger.Infow("short connection by client ice restart", "duration", duration, "reason", reason) + t.handleConnectionFailed(isShort) + } + + if resetShortConnection { + if t.publisher != nil { + t.publisher.ResetShortConnOnICERestart() + } + if t.subscriber != nil { + t.subscriber.ResetShortConnOnICERestart() + } + } +} + +func (t *TransportManager) ICERestart(iceConfig *livekit.ICEConfig) error { + t.SetICEConfig(iceConfig) + + if t.subscriber != nil { + return t.subscriber.ICERestart() + } + + return nil +} + +func (t *TransportManager) OnICEConfigChanged(f func(iceConfig *livekit.ICEConfig)) { + t.lock.Lock() + t.onICEConfigChanged = f + t.lock.Unlock() +} + +func (t *TransportManager) SetICEConfig(iceConfig *livekit.ICEConfig) { + if iceConfig != nil { + t.configureICE(iceConfig, true) + } +} + +func (t *TransportManager) GetICEConfig() *livekit.ICEConfig { + t.lock.RLock() + defer t.lock.RUnlock() + if t.iceConfig == nil { + return nil + } + return utils.CloneProto(t.iceConfig) +} + +func (t *TransportManager) resetTransportConfigureLocked(reconfigured bool) { + t.failureCount = 0 + t.isTransportReconfigured = reconfigured + t.udpLossUnstableCount = 0 + t.lastFailure = time.Time{} +} + +func (t *TransportManager) configureICE(iceConfig *livekit.ICEConfig, reset bool) { + t.lock.Lock() + isEqual := proto.Equal(t.iceConfig, iceConfig) + if reset || !isEqual { + t.resetTransportConfigureLocked(!reset) + } + + if isEqual { + t.lock.Unlock() + return + } + + t.params.Logger.Infow("setting ICE config", "iceConfig", logger.Proto(iceConfig)) + onICEConfigChanged := t.onICEConfigChanged + t.iceConfig = iceConfig + t.lock.Unlock() + + if iceConfig.PreferenceSubscriber != livekit.ICECandidateType_ICT_NONE { + t.mediaLossProxy.OnMediaLossUpdate(nil) + } + + if t.publisher != nil { + t.publisher.SetPreferTCP(iceConfig.PreferencePublisher == livekit.ICECandidateType_ICT_TCP) + } + if t.subscriber != nil { + t.subscriber.SetPreferTCP(iceConfig.PreferenceSubscriber == livekit.ICECandidateType_ICT_TCP) + } + + if onICEConfigChanged != nil { + onICEConfigChanged(iceConfig) + } +} + +func (t *TransportManager) SubscriberAsPrimary() bool { + return t.params.SubscriberAsPrimary +} + +func (t *TransportManager) GetICEConnectionInfo() []*types.ICEConnectionInfo { + infos := make([]*types.ICEConnectionInfo, 0, 2) + for _, pc := range []*PCTransport{t.publisher, t.subscriber} { + if pc == nil { + continue + } + + info := pc.GetICEConnectionInfo() + if info.HasCandidates() { + infos = append(infos, info) + } + } + return infos +} + +func (t *TransportManager) GetDataTrackTransport() types.DataTrackTransport { + return t.getTransport(true) +} + +func (t *TransportManager) getTransport(isPrimary bool) *PCTransport { + switch { + case t.publisher == nil: + return t.subscriber + + case t.subscriber == nil: + return t.publisher + + default: + pcTransport := t.publisher + if (isPrimary && t.params.SubscriberAsPrimary) || (!isPrimary && !t.params.SubscriberAsPrimary) { + pcTransport = t.subscriber + } + + return pcTransport + } +} + +func (t *TransportManager) getLowestPriorityConnectionType() types.ICEConnectionType { + switch { + case t.publisher == nil: + return t.subscriber.GetICEConnectionType() + + case t.subscriber == nil: + return t.publisher.GetICEConnectionType() + + default: + ctype := t.publisher.GetICEConnectionType() + if stype := t.subscriber.GetICEConnectionType(); stype > ctype { + ctype = stype + } + return ctype + } +} + +func (t *TransportManager) handleConnectionFailed(isShortLived bool) { + if !t.params.AllowTCPFallback || t.params.UseOneShotSignallingMode { + return + } + + t.lock.Lock() + if t.isTransportReconfigured { + t.lock.Unlock() + return + } + + lastSignalSince := time.Since(t.lastSignalAt) + signalValid := t.signalSourceValid.Load() + if !t.hasRecentSignalLocked() || !signalValid { + // the failed might cause by network interrupt because signal closed or we have not seen any signal in the time window, + // so don't switch to next candidate type + t.params.Logger.Debugw( + "ignoring prefer candidate check by ICE failure because signal connection interrupted", + "lastSignalSince", lastSignalSince, + "signalValid", signalValid, + ) + t.failureCount = 0 + t.lastFailure = time.Time{} + t.lock.Unlock() + return + } + + lowestPriorityConnectionType := t.getLowestPriorityConnectionType() + + // + // Checking only `PreferenceSubscriber` field although any connection failure (PUBLISHER OR SUBSCRIBER) will + // flow through here. + // + // As both transports are switched to the same type on any failure, checking just subscriber should be fine. + // + getNext := func(ic *livekit.ICEConfig) livekit.ICECandidateType { + switch lowestPriorityConnectionType { + case types.ICEConnectionTypeUDP: + // try ICE/TCP if ICE/UDP failed + if ic.PreferenceSubscriber == livekit.ICECandidateType_ICT_NONE { + if t.params.ClientInfo.SupportsICETCP() && t.canUseICETCP() { + return livekit.ICECandidateType_ICT_TCP + } else if t.params.TURNSEnabled { + // fallback to TURN/TLS if TCP is not supported + return livekit.ICECandidateType_ICT_TLS + } + } + + case types.ICEConnectionTypeTCP: + // try TURN/TLS if ICE/TCP failed, + // the configuration could have been ICT_NONE or ICT_TCP, + // in either case, fallback to TURN/TLS + if t.params.TURNSEnabled { + return livekit.ICECandidateType_ICT_TLS + } else { + // keep the current config + return ic.PreferenceSubscriber + } + + case types.ICEConnectionTypeTURN: + // TURN/TLS is the most permissive option, if that fails there is nowhere to go to + // the configuration could have been ICT_NONE or ICT_TLS, + // keep the current config + return ic.PreferenceSubscriber + } + return livekit.ICECandidateType_ICT_NONE + } + + var preferNext livekit.ICECandidateType + if isShortLived { + preferNext = getNext(t.iceConfig) + } else { + t.failureCount++ + lastFailure := t.lastFailure + t.lastFailure = time.Now() + if t.failureCount < failureCountThreshold || time.Since(lastFailure) > preferNextByFailureWindow { + t.lock.Unlock() + return + } + + preferNext = getNext(t.iceConfig) + } + + if preferNext == t.iceConfig.PreferenceSubscriber { + t.lock.Unlock() + return + } + + t.isTransportReconfigured = true + t.lock.Unlock() + + switch preferNext { + case livekit.ICECandidateType_ICT_TCP: + t.params.Logger.Debugw("prefer TCP transport on both peer connections") + + case livekit.ICECandidateType_ICT_TLS: + t.params.Logger.Debugw("prefer TLS transport both peer connections") + + case livekit.ICECandidateType_ICT_NONE: + t.params.Logger.Debugw("allowing all transports on both peer connections") + } + + // irrespective of which one fails, force prefer candidate on both as the other one might + // fail at a different time and cause another disruption + t.configureICE(&livekit.ICEConfig{ + PreferenceSubscriber: preferNext, + PreferencePublisher: preferNext, + }, false) +} + +func (t *TransportManager) SetMigrateInfo( + previousOffer *webrtc.SessionDescription, + previousAnswer *webrtc.SessionDescription, + dataChannels []*livekit.DataChannelInfo, +) { + t.lock.Lock() + t.pendingDataChannelsPublisher = make([]*livekit.DataChannelInfo, 0, len(dataChannels)) + pendingDataChannelsSubscriber := make([]*livekit.DataChannelInfo, 0, len(dataChannels)) + for _, dci := range dataChannels { + if dci.Target == livekit.SignalTarget_SUBSCRIBER { + pendingDataChannelsSubscriber = append(pendingDataChannelsSubscriber, dci) + } else { + t.pendingDataChannelsPublisher = append(t.pendingDataChannelsPublisher, dci) + } + } + t.lock.Unlock() + + if t.params.SubscriberAsPrimary { + if err := t.createDataChannelsForSubscriber(pendingDataChannelsSubscriber); err != nil { + t.params.Logger.Errorw("create subscriber data channels during migration failed", err) + } + } + + if t.params.UseSinglePeerConnection { + t.publisher.SetPreviousSdp(previousAnswer, previousOffer) + } else { + t.subscriber.SetPreviousSdp(previousOffer, previousAnswer) + } +} + +func (t *TransportManager) ProcessPendingPublisherDataChannels() { + t.lock.Lock() + pendingDataChannels := t.pendingDataChannelsPublisher + t.pendingDataChannelsPublisher = nil + t.lock.Unlock() + + ordered := true + negotiated := true + + for _, ci := range pendingDataChannels { + var ( + dcLabel string + dcID uint16 + dcExisting bool + err error + ) + switch ci.Label { + case ReliableDataChannel: + id := uint16(ci.GetId()) + dcLabel, dcID, dcExisting, err = t.publisher.CreateDataChannelIfEmpty(ReliableDataChannel, &webrtc.DataChannelInit{ + Ordered: &ordered, + Negotiated: &negotiated, + ID: &id, + }) + case LossyDataChannel: + ordered = false + retransmits := uint16(0) + id := uint16(ci.GetId()) + dcLabel, dcID, dcExisting, err = t.publisher.CreateDataChannelIfEmpty(LossyDataChannel, &webrtc.DataChannelInit{ + Ordered: &ordered, + MaxRetransmits: &retransmits, + Negotiated: &negotiated, + ID: &id, + }) + case DataTrackDataChannel: + ordered = false + retransmits := uint16(0) + id := uint16(ci.GetId()) + dcLabel, dcID, dcExisting, err = t.publisher.CreateDataChannelIfEmpty(DataTrackDataChannel, &webrtc.DataChannelInit{ + Ordered: &ordered, + MaxRetransmits: &retransmits, + Negotiated: &negotiated, + ID: &id, + }) + } + if err != nil { + t.params.Logger.Errorw("create migrated data channel failed", err, "label", ci.Label) + } else if dcExisting { + t.params.Logger.Debugw("existing data channel during migration", "label", dcLabel, "id", dcID) + } else { + t.params.Logger.Debugw("create migrated data channel", "label", dcLabel, "id", dcID) + } + } +} + +func (t *TransportManager) HandleReceiverReport(dt *sfu.DownTrack, report *rtcp.ReceiverReport) { + t.mediaLossProxy.HandleMaxLossFeedback(dt, report) +} + +func (t *TransportManager) onMediaLossUpdate(loss uint8) { + if t.params.TCPFallbackRTTThreshold == 0 || !t.params.AllowUDPUnstableFallback { + return + } + t.lock.Lock() + t.udpLossUnstableCount <<= 1 + if loss >= uint8(255*udpLossFracUnstable/100) { + t.udpLossUnstableCount |= 1 + if bits.OnesCount32(t.udpLossUnstableCount) >= udpLossUnstableCountThreshold { + if t.udpRTT > 0 && t.signalingRTT < uint32(float32(t.udpRTT)*1.3) && int(t.signalingRTT) < t.params.TCPFallbackRTTThreshold && t.hasRecentSignalLocked() { + t.udpLossUnstableCount = 0 + t.lock.Unlock() + + t.params.Logger.Infow("udp connection unstable, switch to tcp", "signalingRTT", t.signalingRTT) + if t.params.UseSinglePeerConnection { + t.params.PublisherHandler.OnFailed(true, t.publisher.GetICEConnectionInfo()) + } else { + t.params.SubscriberHandler.OnFailed(true, t.subscriber.GetICEConnectionInfo()) + } + return + } + } + } + t.lock.Unlock() +} + +func (t *TransportManager) UpdateSignalingRTT(rtt uint32) { + t.lock.Lock() + t.signalingRTT = rtt + t.lock.Unlock() + if t.publisher != nil { + t.publisher.SetSignalingRTT(rtt) + } + if t.subscriber != nil { + t.subscriber.SetSignalingRTT(rtt) + } + + // TODO: considering using tcp rtt to calculate ice connection cost, if ice connection can't be established + // within 5 * tcp rtt(at least 5s), means udp traffic might be block/dropped, switch to tcp. + // Currently, most cases reported is that ice connected but subsequent connection, so left the thinking for now. +} + +func (t *TransportManager) UpdateMediaRTT(rtt uint32) { + t.lock.Lock() + if t.udpRTT == 0 { + t.udpRTT = rtt + } else { + t.udpRTT = uint32(int(t.udpRTT) + (int(rtt)-int(t.udpRTT))/2) + } + t.lock.Unlock() +} + +func (t *TransportManager) UpdateLastSeenSignal() { + t.lock.Lock() + t.lastSignalAt = time.Now() + t.lock.Unlock() +} + +func (t *TransportManager) SinceLastSignal() time.Duration { + t.lock.RLock() + defer t.lock.RUnlock() + return time.Since(t.lastSignalAt) +} + +func (t *TransportManager) LastSeenSignalAt() time.Time { + t.lock.RLock() + defer t.lock.RUnlock() + return t.lastSignalAt +} + +func (t *TransportManager) canUseICETCP() bool { + return t.params.TCPFallbackRTTThreshold == 0 || int(t.signalingRTT) < t.params.TCPFallbackRTTThreshold +} + +func (t *TransportManager) SetSignalSourceValid(valid bool) { + t.signalSourceValid.Store(valid) + t.params.Logger.Debugw("signal source valid", "valid", valid) +} + +func (t *TransportManager) SetSubscriberAllowPause(allowPause bool) { + if t.params.UseOneShotSignallingMode || t.params.UseSinglePeerConnection { + t.publisher.SetAllowPauseOfStreamAllocator(allowPause) + } else { + t.subscriber.SetAllowPauseOfStreamAllocator(allowPause) + } +} + +func (t *TransportManager) SetSubscriberChannelCapacity(channelCapacity int64) { + if t.params.UseOneShotSignallingMode || t.params.UseSinglePeerConnection { + t.publisher.SetChannelCapacityOfStreamAllocator(channelCapacity) + } else { + t.subscriber.SetChannelCapacityOfStreamAllocator(channelCapacity) + } +} + +func (t *TransportManager) hasRecentSignalLocked() bool { + return time.Since(t.lastSignalAt) < PingTimeoutSeconds*time.Second +} + +func (t *TransportManager) RTPStreamPublished(ssrc uint32, mid, rid string) { + if t.params.UseOneShotSignallingMode || t.params.UseSinglePeerConnection { + t.publisher.RTPStreamPublished(ssrc, mid, rid) + } else { + t.subscriber.RTPStreamPublished(ssrc, mid, rid) + } +} diff --git a/livekit/pkg/rtc/types/ice.go b/livekit/pkg/rtc/types/ice.go new file mode 100644 index 0000000..fbcbc5a --- /dev/null +++ b/livekit/pkg/rtc/types/ice.go @@ -0,0 +1,450 @@ +/* + * Copyright 2023 LiveKit, Inc + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package types + +import ( + "fmt" + "strings" + "sync" + + "github.com/pion/ice/v4" + "github.com/pion/webrtc/v4" + "golang.org/x/exp/slices" + + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/observability/roomobs" +) + +type ICEConnectionType int + +const ( + // this is in ICE priority highest -> lowest ordering + // WARNING: Keep this ordering as it is used to find lowest priority connection type. + ICEConnectionTypeUnknown ICEConnectionType = iota + ICEConnectionTypeUDP + ICEConnectionTypeTCP + ICEConnectionTypeTURN +) + +func (i ICEConnectionType) String() string { + switch i { + case ICEConnectionTypeUnknown: + return "unknown" + case ICEConnectionTypeUDP: + return "udp" + case ICEConnectionTypeTCP: + return "tcp" + case ICEConnectionTypeTURN: + return "turn" + default: + return "unknown" + } +} + +func (i ICEConnectionType) ReporterType() roomobs.ConnectionType { + switch i { + case ICEConnectionTypeUnknown: + return roomobs.ConnectionTypeUndefined + case ICEConnectionTypeUDP: + return roomobs.ConnectionTypeUDP + case ICEConnectionTypeTCP: + return roomobs.ConnectionTypeTCP + case ICEConnectionTypeTURN: + return roomobs.ConnectionTypeTurn + default: + return roomobs.ConnectionTypeUndefined + } +} + +// -------------------------------------------- + +type ICECandidateExtended struct { + // only one of local or remote is set. This is due to type foo in Pion + Local *webrtc.ICECandidate + Remote ice.Candidate + SelectedOrder int + Filtered bool + Trickle bool +} + +// -------------------------------------------- + +type ICEConnectionInfo struct { + Local []*ICECandidateExtended + Remote []*ICECandidateExtended + Transport livekit.SignalTarget + Type ICEConnectionType +} + +func (i *ICEConnectionInfo) HasCandidates() bool { + return len(i.Local) > 0 || len(i.Remote) > 0 +} + +func ICEConnectionInfosType(infos []*ICEConnectionInfo) ICEConnectionType { + for _, info := range infos { + if info.Type != ICEConnectionTypeUnknown { + return info.Type + } + } + return ICEConnectionTypeUnknown +} + +// -------------------------------------------- + +type ICEConnectionDetails struct { + ICEConnectionInfo + lock sync.Mutex + selectedCount int + logger logger.Logger +} + +func NewICEConnectionDetails(transport livekit.SignalTarget, l logger.Logger) *ICEConnectionDetails { + d := &ICEConnectionDetails{ + ICEConnectionInfo: ICEConnectionInfo{ + Transport: transport, + Type: ICEConnectionTypeUnknown, + }, + logger: l, + } + return d +} + +func (d *ICEConnectionDetails) GetInfo() *ICEConnectionInfo { + d.lock.Lock() + defer d.lock.Unlock() + info := &ICEConnectionInfo{ + Transport: d.Transport, + Type: d.Type, + Local: make([]*ICECandidateExtended, 0, len(d.Local)), + Remote: make([]*ICECandidateExtended, 0, len(d.Remote)), + } + for _, c := range d.Local { + info.Local = append(info.Local, &ICECandidateExtended{ + Local: c.Local, + Filtered: c.Filtered, + SelectedOrder: c.SelectedOrder, + Trickle: c.Trickle, + }) + } + for _, c := range d.Remote { + info.Remote = append(info.Remote, &ICECandidateExtended{ + Remote: c.Remote, + Filtered: c.Filtered, + SelectedOrder: c.SelectedOrder, + Trickle: c.Trickle, + }) + } + return info +} + +func (d *ICEConnectionDetails) GetConnectionType() ICEConnectionType { + d.lock.Lock() + defer d.lock.Unlock() + + return d.Type +} + +func (d *ICEConnectionDetails) AddLocalCandidate(c *webrtc.ICECandidate, filtered, trickle bool) { + d.lock.Lock() + defer d.lock.Unlock() + compFn := func(e *ICECandidateExtended) bool { + return isCandidateEqualTo(e.Local, c) + } + if slices.ContainsFunc(d.Local, compFn) { + return + } + d.Local = append(d.Local, &ICECandidateExtended{ + Local: c, + Filtered: filtered, + Trickle: trickle, + }) +} + +func (d *ICEConnectionDetails) AddLocalICECandidate(c ice.Candidate, filtered, trickle bool) { + candidate, err := unmarshalCandidate(c) + if err != nil { + d.logger.Errorw("could not unmarshal ice candidate", err, "candidate", c) + return + } + + d.AddLocalCandidate(candidate, filtered, trickle) +} + +func (d *ICEConnectionDetails) AddRemoteCandidate(c webrtc.ICECandidateInit, filtered, trickle, canUpdate bool) { + candidate, err := unmarshalICECandidate(c) + if err != nil { + d.logger.Errorw("could not unmarshal candidate", err, "candidate", c) + return + } + d.AddRemoteICECandidate(candidate, filtered, trickle, canUpdate) +} + +func (d *ICEConnectionDetails) AddRemoteICECandidate(candidate ice.Candidate, filtered, trickle, canUpdate bool) { + if candidate == nil { + // end-of-candidates candidate + return + } + + d.lock.Lock() + defer d.lock.Unlock() + indexFn := func(e *ICECandidateExtended) bool { + return isICECandidateEqualTo(e.Remote, candidate) + } + if idx := slices.IndexFunc(d.Remote, indexFn); idx != -1 { + if canUpdate { + d.Remote[idx].Filtered = filtered + d.Remote[idx].Trickle = trickle + } + return + } + d.Remote = append(d.Remote, &ICECandidateExtended{ + Remote: candidate, + Filtered: filtered, + Trickle: trickle, + }) + d.updateConnectionTypeLocked() +} + +func (d *ICEConnectionDetails) Clear() { + d.lock.Lock() + defer d.lock.Unlock() + d.Local = nil + d.Remote = nil + d.Type = ICEConnectionTypeUnknown +} + +func (d *ICEConnectionDetails) SetSelectedPair(pair *webrtc.ICECandidatePair) { + d.lock.Lock() + defer d.lock.Unlock() + + d.selectedCount++ + + remoteIdx := slices.IndexFunc(d.Remote, func(e *ICECandidateExtended) bool { + return isICECandidateEqualToCandidate(e.Remote, pair.Remote) + }) + if remoteIdx < 0 { + // it's possible for prflx candidates to be generated by Pion, we'll add them + candidate, err := unmarshalICECandidate(pair.Remote.ToJSON()) + if err != nil { + d.logger.Errorw("could not unmarshal remote candidate", err, "candidate", pair.Remote) + return + } + if candidate == nil { + return + } + d.Remote = append(d.Remote, &ICECandidateExtended{ + Remote: candidate, + Filtered: false, + Trickle: false, + }) + remoteIdx = len(d.Remote) - 1 + } + d.Remote[remoteIdx].SelectedOrder = d.selectedCount + d.updateConnectionTypeLocked() + + localIdx := slices.IndexFunc(d.Local, func(e *ICECandidateExtended) bool { + return isCandidateEqualTo(e.Local, pair.Local) + }) + if localIdx < 0 { + d.logger.Errorw("could not match local candidate", nil, "local", pair.Local) + // should not happen + return + } + d.Local[localIdx].SelectedOrder = d.selectedCount +} + +func (d *ICEConnectionDetails) updateConnectionTypeLocked() { + highestSelectedOrder := -1 + var selectedRemoteCandidate *ICECandidateExtended + for _, remote := range d.Remote { + if remote.SelectedOrder == 0 { + continue + } + + if remote.SelectedOrder > highestSelectedOrder { + highestSelectedOrder = remote.SelectedOrder + selectedRemoteCandidate = remote + } + } + + if selectedRemoteCandidate == nil { + return + } + + remoteCandidate := selectedRemoteCandidate.Remote + switch remoteCandidate.NetworkType() { + case ice.NetworkTypeUDP4, ice.NetworkTypeUDP6: + d.Type = ICEConnectionTypeUDP + + case ice.NetworkTypeTCP4, ice.NetworkTypeTCP6: + d.Type = ICEConnectionTypeTCP + } + + switch remoteCandidate.Type() { + case ice.CandidateTypeRelay: + d.Type = ICEConnectionTypeTURN + + case ice.CandidateTypePeerReflexive: + // if the remote relay candidate pings us *before* we get a relay candidate, + // Pion would have created a prflx candidate with the same address as the relay candidate. + // to report an accurate connection type, we'll compare to see if existing relay candidates match + for _, other := range d.Remote { + or := other.Remote + if or.Type() == ice.CandidateTypeRelay && + remoteCandidate.Address() == or.Address() && + // NOTE: port is not compared as relayed address reported by TURN ALLOCATE from + // pion/turn server -> client and later sent from client -> server via ICE Trickle does not + // match port of `prflx` candidate learnt via TURN path. TODO-INVESTIGATE: how and why doesn't + // port match? + //remoteCanddiate.Port() == or.Port() && + remoteCandidate.NetworkType().NetworkShort() == or.NetworkType().NetworkShort() { + d.Type = ICEConnectionTypeTURN + break + } + } + } +} + +// ------------------------------------------------------------- + +func isCandidateEqualTo(c1, c2 *webrtc.ICECandidate) bool { + if c1 == nil && c2 == nil { + return true + } + if (c1 == nil && c2 != nil) || (c1 != nil && c2 == nil) { + return false + } + return c1.Typ == c2.Typ && + c1.Protocol == c2.Protocol && + c1.Address == c2.Address && + c1.Port == c2.Port && + c1.Foundation == c2.Foundation && + c1.Priority == c2.Priority && + c1.RelatedAddress == c2.RelatedAddress && + c1.RelatedPort == c2.RelatedPort && + c1.TCPType == c2.TCPType +} + +func isICECandidateEqualTo(c1, c2 ice.Candidate) bool { + if c1 == nil && c2 == nil { + return true + } + if (c1 == nil && c2 != nil) || (c1 != nil && c2 == nil) { + return false + } + return c1.Type() == c2.Type() && + c1.NetworkType() == c2.NetworkType() && + c1.Address() == c2.Address() && + c1.Port() == c2.Port() && + c1.Foundation() == c2.Foundation() && + c1.Priority() == c2.Priority() && + c1.RelatedAddress().Equal(c2.RelatedAddress()) && + c1.TCPType() == c2.TCPType() +} + +func isICECandidateEqualToCandidate(c1 ice.Candidate, c2 *webrtc.ICECandidate) bool { + if c1 == nil && c2 == nil { + return true + } + if (c1 == nil && c2 != nil) || (c1 != nil && c2 == nil) { + return false + } + return c1.Type().String() == c2.Typ.String() && + c1.NetworkType().NetworkShort() == c2.Protocol.String() && + c1.Address() == c2.Address && + c1.Port() == int(c2.Port) && + c1.Foundation() == c2.Foundation && + c1.Priority() == c2.Priority && + c1.TCPType().String() == c2.TCPType +} + +func unmarshalICECandidate(c webrtc.ICECandidateInit) (ice.Candidate, error) { + candidateValue := strings.TrimPrefix(c.Candidate, "candidate:") + if candidateValue == "" { + return nil, nil + } + + candidate, err := ice.UnmarshalCandidate(candidateValue) + if err != nil { + return nil, err + } + + return candidate, nil +} + +func unmarshalCandidate(i ice.Candidate) (*webrtc.ICECandidate, error) { + var typ webrtc.ICECandidateType + switch i.Type() { + case ice.CandidateTypeHost: + typ = webrtc.ICECandidateTypeHost + case ice.CandidateTypeServerReflexive: + typ = webrtc.ICECandidateTypeSrflx + case ice.CandidateTypePeerReflexive: + typ = webrtc.ICECandidateTypePrflx + case ice.CandidateTypeRelay: + typ = webrtc.ICECandidateTypeRelay + default: + return nil, fmt.Errorf("unknown candidate type: %s", i.Type()) + } + + var protocol webrtc.ICEProtocol + switch strings.ToLower(i.NetworkType().NetworkShort()) { + case "udp": + protocol = webrtc.ICEProtocolUDP + case "tcp": + protocol = webrtc.ICEProtocolTCP + default: + return nil, fmt.Errorf("unknown network type: %s", i.NetworkType()) + } + + c := webrtc.ICECandidate{ + Foundation: i.Foundation(), + Priority: i.Priority(), + Address: i.Address(), + Protocol: protocol, + Port: uint16(i.Port()), + Component: i.Component(), + Typ: typ, + TCPType: i.TCPType().String(), + } + + if i.RelatedAddress() != nil { + c.RelatedAddress = i.RelatedAddress().Address + c.RelatedPort = uint16(i.RelatedAddress().Port) + } + + return &c, nil +} + +func IsCandidateMDNS(candidate webrtc.ICECandidateInit) bool { + c, err := unmarshalICECandidate(candidate) + if err != nil { + return false + } + + return IsICECandidateMDNS(c) +} + +func IsICECandidateMDNS(candidate ice.Candidate) bool { + if candidate == nil { + // end-of-candidates candidate + return false + } + + return strings.HasSuffix(candidate.Address(), ".local") +} diff --git a/livekit/pkg/rtc/types/interfaces.go b/livekit/pkg/rtc/types/interfaces.go new file mode 100644 index 0000000..ec8b00e --- /dev/null +++ b/livekit/pkg/rtc/types/interfaces.go @@ -0,0 +1,865 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package types + +import ( + "fmt" + "time" + + "github.com/pion/rtcp" + "github.com/pion/webrtc/v4" + + "github.com/livekit/protocol/auth" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/observability/roomobs" + "github.com/livekit/protocol/utils" + + "github.com/livekit/livekit-server/pkg/routing" + "github.com/livekit/livekit-server/pkg/rtc/datatrack" + "github.com/livekit/livekit-server/pkg/sfu" + "github.com/livekit/livekit-server/pkg/sfu/buffer" + "github.com/livekit/livekit-server/pkg/sfu/mime" + "github.com/livekit/livekit-server/pkg/sfu/pacer" + "github.com/livekit/livekit-server/pkg/telemetry" + + "google.golang.org/protobuf/proto" +) + +//go:generate go run github.com/maxbrunsfeld/counterfeiter/v6 -generate + +//counterfeiter:generate . WebsocketClient +type WebsocketClient interface { + ReadMessage() (messageType int, p []byte, err error) + WriteMessage(messageType int, data []byte) error + WriteControl(messageType int, data []byte, deadline time.Time) error + SetReadDeadline(deadline time.Time) error + Close() error +} + +type AddSubscriberParams struct { + AllTracks bool + TrackIDs []livekit.TrackID +} + +// --------------------------------------------- + +type MigrateState int32 + +const ( + MigrateStateInit MigrateState = iota + MigrateStateSync + MigrateStateComplete +) + +func (m MigrateState) String() string { + switch m { + case MigrateStateInit: + return "MIGRATE_STATE_INIT" + case MigrateStateSync: + return "MIGRATE_STATE_SYNC" + case MigrateStateComplete: + return "MIGRATE_STATE_COMPLETE" + default: + return fmt.Sprintf("%d", int(m)) + } +} + +// --------------------------------------------- + +type SubscribedCodecQuality struct { + CodecMime mime.MimeType + Quality livekit.VideoQuality +} + +// --------------------------------------------- + +type ParticipantCloseReason int + +const ( + ParticipantCloseReasonNone ParticipantCloseReason = iota + ParticipantCloseReasonClientRequestLeave + ParticipantCloseReasonRoomManagerStop + ParticipantCloseReasonVerifyFailed + ParticipantCloseReasonJoinFailed + ParticipantCloseReasonJoinTimeout + ParticipantCloseReasonMessageBusFailed + ParticipantCloseReasonPeerConnectionDisconnected + ParticipantCloseReasonDuplicateIdentity + ParticipantCloseReasonMigrationComplete + ParticipantCloseReasonStale + ParticipantCloseReasonServiceRequestRemoveParticipant + ParticipantCloseReasonServiceRequestDeleteRoom + ParticipantCloseReasonSimulateMigration + ParticipantCloseReasonSimulateNodeFailure + ParticipantCloseReasonSimulateServerLeave + ParticipantCloseReasonSimulateLeaveRequest + ParticipantCloseReasonNegotiateFailed + ParticipantCloseReasonMigrationRequested + ParticipantCloseReasonPublicationError + ParticipantCloseReasonSubscriptionError + ParticipantCloseReasonDataChannelError + ParticipantCloseReasonMigrateCodecMismatch + ParticipantCloseReasonSignalSourceClose + ParticipantCloseReasonRoomClosed + ParticipantCloseReasonUserUnavailable + ParticipantCloseReasonUserRejected + ParticipantCloseReasonMoveFailed +) + +func (p ParticipantCloseReason) String() string { + switch p { + case ParticipantCloseReasonNone: + return "NONE" + case ParticipantCloseReasonClientRequestLeave: + return "CLIENT_REQUEST_LEAVE" + case ParticipantCloseReasonRoomManagerStop: + return "ROOM_MANAGER_STOP" + case ParticipantCloseReasonVerifyFailed: + return "VERIFY_FAILED" + case ParticipantCloseReasonJoinFailed: + return "JOIN_FAILED" + case ParticipantCloseReasonJoinTimeout: + return "JOIN_TIMEOUT" + case ParticipantCloseReasonMessageBusFailed: + return "MESSAGE_BUS_FAILED" + case ParticipantCloseReasonPeerConnectionDisconnected: + return "PEER_CONNECTION_DISCONNECTED" + case ParticipantCloseReasonDuplicateIdentity: + return "DUPLICATE_IDENTITY" + case ParticipantCloseReasonMigrationComplete: + return "MIGRATION_COMPLETE" + case ParticipantCloseReasonStale: + return "STALE" + case ParticipantCloseReasonServiceRequestRemoveParticipant: + return "SERVICE_REQUEST_REMOVE_PARTICIPANT" + case ParticipantCloseReasonServiceRequestDeleteRoom: + return "SERVICE_REQUEST_DELETE_ROOM" + case ParticipantCloseReasonSimulateMigration: + return "SIMULATE_MIGRATION" + case ParticipantCloseReasonSimulateNodeFailure: + return "SIMULATE_NODE_FAILURE" + case ParticipantCloseReasonSimulateServerLeave: + return "SIMULATE_SERVER_LEAVE" + case ParticipantCloseReasonSimulateLeaveRequest: + return "SIMULATE_LEAVE_REQUEST" + case ParticipantCloseReasonNegotiateFailed: + return "NEGOTIATE_FAILED" + case ParticipantCloseReasonMigrationRequested: + return "MIGRATION_REQUESTED" + case ParticipantCloseReasonPublicationError: + return "PUBLICATION_ERROR" + case ParticipantCloseReasonSubscriptionError: + return "SUBSCRIPTION_ERROR" + case ParticipantCloseReasonDataChannelError: + return "DATA_CHANNEL_ERROR" + case ParticipantCloseReasonMigrateCodecMismatch: + return "MIGRATE_CODEC_MISMATCH" + case ParticipantCloseReasonSignalSourceClose: + return "SIGNAL_SOURCE_CLOSE" + case ParticipantCloseReasonRoomClosed: + return "ROOM_CLOSED" + case ParticipantCloseReasonUserUnavailable: + return "USER_UNAVAILABLE" + case ParticipantCloseReasonUserRejected: + return "USER_REJECTED" + case ParticipantCloseReasonMoveFailed: + return "MOVE_FAILED" + default: + return fmt.Sprintf("%d", int(p)) + } +} + +func (p ParticipantCloseReason) ToDisconnectReason() livekit.DisconnectReason { + switch p { + case ParticipantCloseReasonClientRequestLeave, ParticipantCloseReasonSimulateLeaveRequest: + return livekit.DisconnectReason_CLIENT_INITIATED + case ParticipantCloseReasonRoomManagerStop: + return livekit.DisconnectReason_SERVER_SHUTDOWN + case ParticipantCloseReasonVerifyFailed, ParticipantCloseReasonJoinFailed, ParticipantCloseReasonJoinTimeout, ParticipantCloseReasonMessageBusFailed: + // expected to be connected but is not + return livekit.DisconnectReason_JOIN_FAILURE + case ParticipantCloseReasonPeerConnectionDisconnected: + return livekit.DisconnectReason_CONNECTION_TIMEOUT + case ParticipantCloseReasonDuplicateIdentity, ParticipantCloseReasonStale: + return livekit.DisconnectReason_DUPLICATE_IDENTITY + case ParticipantCloseReasonMigrationRequested, ParticipantCloseReasonMigrationComplete, ParticipantCloseReasonSimulateMigration: + return livekit.DisconnectReason_MIGRATION + case ParticipantCloseReasonServiceRequestRemoveParticipant: + return livekit.DisconnectReason_PARTICIPANT_REMOVED + case ParticipantCloseReasonServiceRequestDeleteRoom: + return livekit.DisconnectReason_ROOM_DELETED + case ParticipantCloseReasonSimulateNodeFailure, ParticipantCloseReasonSimulateServerLeave: + return livekit.DisconnectReason_SERVER_SHUTDOWN + case ParticipantCloseReasonNegotiateFailed, ParticipantCloseReasonPublicationError, ParticipantCloseReasonSubscriptionError, + ParticipantCloseReasonDataChannelError, ParticipantCloseReasonMigrateCodecMismatch, ParticipantCloseReasonMoveFailed: + return livekit.DisconnectReason_STATE_MISMATCH + case ParticipantCloseReasonSignalSourceClose: + return livekit.DisconnectReason_SIGNAL_CLOSE + case ParticipantCloseReasonRoomClosed: + return livekit.DisconnectReason_ROOM_CLOSED + case ParticipantCloseReasonUserUnavailable: + return livekit.DisconnectReason_USER_UNAVAILABLE + case ParticipantCloseReasonUserRejected: + return livekit.DisconnectReason_USER_REJECTED + default: + // the other types will map to unknown reason + return livekit.DisconnectReason_UNKNOWN_REASON + } +} + +// --------------------------------------------- + +type SignallingCloseReason int + +const ( + SignallingCloseReasonUnknown SignallingCloseReason = iota + SignallingCloseReasonMigration + SignallingCloseReasonResume + SignallingCloseReasonTransportFailure + SignallingCloseReasonFullReconnectPublicationError + SignallingCloseReasonFullReconnectSubscriptionError + SignallingCloseReasonFullReconnectDataChannelError + SignallingCloseReasonFullReconnectNegotiateFailed + SignallingCloseReasonParticipantClose + SignallingCloseReasonDisconnectOnResume + SignallingCloseReasonDisconnectOnResumeNoMessages +) + +func (s SignallingCloseReason) String() string { + switch s { + case SignallingCloseReasonUnknown: + return "UNKNOWN" + case SignallingCloseReasonMigration: + return "MIGRATION" + case SignallingCloseReasonResume: + return "RESUME" + case SignallingCloseReasonTransportFailure: + return "TRANSPORT_FAILURE" + case SignallingCloseReasonFullReconnectPublicationError: + return "FULL_RECONNECT_PUBLICATION_ERROR" + case SignallingCloseReasonFullReconnectSubscriptionError: + return "FULL_RECONNECT_SUBSCRIPTION_ERROR" + case SignallingCloseReasonFullReconnectDataChannelError: + return "FULL_RECONNECT_DATA_CHANNEL_ERROR" + case SignallingCloseReasonFullReconnectNegotiateFailed: + return "FULL_RECONNECT_NEGOTIATE_FAILED" + case SignallingCloseReasonParticipantClose: + return "PARTICIPANT_CLOSE" + case SignallingCloseReasonDisconnectOnResume: + return "DISCONNECT_ON_RESUME" + case SignallingCloseReasonDisconnectOnResumeNoMessages: + return "DISCONNECT_ON_RESUME_NO_MESSAGES" + default: + return fmt.Sprintf("%d", int(s)) + } +} + +// --------------------------------------------- +const ( + ParticipantCloseKeyNormal = "normal" + ParticipantCloseKeyWHIP = "whip" +) + +// --------------------------------------------- + +//counterfeiter:generate . Participant +type Participant interface { + ID() livekit.ParticipantID + Identity() livekit.ParticipantIdentity + State() livekit.ParticipantInfo_State + ConnectedAt() time.Time + CloseReason() ParticipantCloseReason + Kind() livekit.ParticipantInfo_Kind + IsRecorder() bool + IsDependent() bool + IsAgent() bool + + GetLogger() logger.Logger + + CanSkipBroadcast() bool + Version() utils.TimedVersion + ToProto() *livekit.ParticipantInfo + ToProtoWithVersion() (*livekit.ParticipantInfo, utils.TimedVersion) + + IsPublisher() bool + GetPublishedTrack(trackID livekit.TrackID) MediaTrack + GetPublishedTracks() []MediaTrack + RemovePublishedTrack(track MediaTrack, isExpectedToResume bool) + + GetPublishedDataTracks() []DataTrack + GetPublishedDataTrack(handle uint16) DataTrack + RemovePublishedDataTrack(track DataTrack) + + GetAudioLevel() (smoothedLevel float64, active bool) + + // HasPermission checks permission of the subscriber by identity. Returns true if subscriber is allowed to subscribe + // to the track with trackID + HasPermission(trackID livekit.TrackID, subIdentity livekit.ParticipantIdentity) bool + + // permissions + Hidden() bool + + MigrateState() MigrateState + + Close(sendLeave bool, reason ParticipantCloseReason, isExpectedToResume bool) error + IsClosed() bool + IsDisconnected() bool + + SubscriptionPermission() (*livekit.SubscriptionPermission, utils.TimedVersion) + + // updates from remotes + UpdateSubscriptionPermission( + subscriptionPermission *livekit.SubscriptionPermission, + timedVersion utils.TimedVersion, + resolverBySid func(participantID livekit.ParticipantID) LocalParticipant, + ) error + + DebugInfo() map[string]any + + HandleReceivedDataTrackMessage([]byte, *datatrack.Packet, int64) + + GetParticipantListener() ParticipantListener +} + +// ------------------------------------------------------- + +type AddTrackParams struct { + Stereo bool + Red bool +} + +type MoveToRoomParams struct { + RoomName livekit.RoomName + ParticipantID livekit.ParticipantID + Listener LocalParticipantListener + Helper LocalParticipantHelper +} + +type DataMessageCache struct { + Data []byte + SenderID livekit.ParticipantID + Seq uint32 + DestIdentities []livekit.ParticipantIdentity +} + +//counterfeiter:generate . LocalParticipantHelper +type LocalParticipantHelper interface { + ResolveMediaTrack(LocalParticipant, livekit.TrackID) MediaResolverResult + ResolveDataTrack(LocalParticipant, livekit.TrackID) DataResolverResult + GetParticipantInfo(pID livekit.ParticipantID) *livekit.ParticipantInfo + GetRegionSettings(ip string) *livekit.RegionSettings + GetSubscriberForwarderState(p LocalParticipant) (map[livekit.TrackID]*livekit.RTPForwarderState, error) + ShouldRegressCodec() bool + GetCachedReliableDataMessage(seqs map[livekit.ParticipantID]uint32) []*DataMessageCache +} + +//counterfeiter:generate . LocalParticipant +type LocalParticipant interface { + Participant + + TelemetryGuard() *telemetry.ReferenceGuard + + // getters + GetCountry() string + GetTrailer() []byte + GetLoggerResolver() logger.DeferredFieldResolver + GetReporter() roomobs.ParticipantSessionReporter + GetReporterResolver() roomobs.ParticipantReporterResolver + GetAdaptiveStream() bool + ProtocolVersion() ProtocolVersion + SupportsSyncStreamID() bool + SupportsTransceiverReuse() bool + IsUsingSinglePeerConnection() bool + IsReady() bool + ActiveAt() time.Time + Disconnected() <-chan struct{} + IsIdle() bool + SubscriberAsPrimary() bool + GetClientInfo() *livekit.ClientInfo + GetClientConfiguration() *livekit.ClientConfiguration + GetBufferFactory() *buffer.Factory + GetPlayoutDelayConfig() *livekit.PlayoutDelay + GetPendingTrack(trackID livekit.TrackID) *livekit.TrackInfo + GetICEConnectionInfo() []*ICEConnectionInfo + HasConnected() bool + GetEnabledPublishCodecs() []*livekit.Codec + GetPublisherICESessionUfrag() (string, error) + SupportsMoving() error + GetLastReliableSequence(migrateOut bool) uint32 + + SwapResponseSink(sink routing.MessageSink, reason SignallingCloseReason) + GetResponseSink() routing.MessageSink + CloseSignalConnection(reason SignallingCloseReason) + UpdateLastSeenSignal() + SetSignalSourceValid(valid bool) + HandleSignalSourceClose() + + // updates + UpdateMetadata(update *livekit.UpdateParticipantMetadata, fromAdmin bool) error + SetName(name string) + SetMetadata(metadata string) + SetAttributes(attributes map[string]string) + UpdateAudioTrack(update *livekit.UpdateLocalAudioTrack) error + UpdateVideoTrack(update *livekit.UpdateLocalVideoTrack) error + + // permissions + ClaimGrants() *auth.ClaimGrants + SetPermission(permission *livekit.ParticipantPermission) bool + CanPublish() bool + CanPublishSource(source livekit.TrackSource) bool + CanSubscribe() bool + CanPublishData() bool + + // PeerConnection + HandleICETrickle(trickleRequest *livekit.TrickleRequest) + HandleOffer(sd *livekit.SessionDescription) error + GetAnswer() (webrtc.SessionDescription, uint32, error) + HandleICETrickleSDPFragment(sdpFragment string) error + HandleICERestartSDPFragment(sdpFragment string) (string, error) + AddTrack(req *livekit.AddTrackRequest) + SetTrackMuted(mute *livekit.MuteTrackRequest, fromAdmin bool) *livekit.TrackInfo + + HandleAnswer(sd *livekit.SessionDescription) + Negotiate(force bool) + ICERestart(iceConfig *livekit.ICEConfig) + AddTrackLocal(trackLocal webrtc.TrackLocal, params AddTrackParams) (*webrtc.RTPSender, *webrtc.RTPTransceiver, error) + AddTransceiverFromTrackLocal(trackLocal webrtc.TrackLocal, params AddTrackParams) (*webrtc.RTPSender, *webrtc.RTPTransceiver, error) + RemoveTrackLocal(sender *webrtc.RTPSender) error + + WriteSubscriberRTCP(pkts []rtcp.Packet) error + + // subscriptions + SubscribeToTrack(trackID livekit.TrackID, isSync bool) + UnsubscribeFromTrack(trackID livekit.TrackID) + UpdateSubscribedTrackSettings(trackID livekit.TrackID, settings *livekit.UpdateTrackSettings) + GetSubscribedTracks() []SubscribedTrack + IsTrackNameSubscribed(publisherIdentity livekit.ParticipantIdentity, trackName string) bool + SubscribeToDataTrack(trackID livekit.TrackID) + UnsubscribeFromDataTrack(trackID livekit.TrackID) + UpdateDataTrackSubscriptionOptions(trackID livekit.TrackID, subscriptionOptions *livekit.DataTrackSubscriptionOptions) + Verify() bool + VerifySubscribeParticipantInfo(pID livekit.ParticipantID, version uint32) + // WaitUntilSubscribed waits until all subscriptions have been settled, or if the timeout + // has been reached. If the timeout expires, it will return an error. + WaitUntilSubscribed(timeout time.Duration) error + StopAndGetSubscribedTracksForwarderState() map[livekit.TrackID]*livekit.RTPForwarderState + SupportsCodecChange() bool + + // returns list of participant identities that the current participant is subscribed to + GetSubscribedParticipants() []livekit.ParticipantID + IsSubscribedTo(sid livekit.ParticipantID) bool + + GetConnectionQuality() *livekit.ConnectionQualityInfo + + // server sent messages + SendJoinResponse(joinResponse *livekit.JoinResponse) error + SendParticipantUpdate(participants []*livekit.ParticipantInfo) error + SendSpeakerUpdate(speakers []*livekit.SpeakerInfo, force bool) error + SendDataMessage(kind livekit.DataPacket_Kind, data []byte, senderID livekit.ParticipantID, seq uint32) error + SendDataMessageUnlabeled(data []byte, useRaw bool, sender livekit.ParticipantIdentity) error + SendRoomUpdate(room *livekit.Room) error + SendConnectionQualityUpdate(update *livekit.ConnectionQualityUpdate) error + SendSubscriptionPermissionUpdate(publisherID livekit.ParticipantID, trackID livekit.TrackID, allowed bool) error + SendRefreshToken(token string) error + HandleReconnectAndSendResponse(reconnectReason livekit.ReconnectReason, reconnectResponse *livekit.ReconnectResponse) error + IssueFullReconnect(reason ParticipantCloseReason) + SendRoomMovedResponse(moved *livekit.RoomMovedResponse) error + SendDataTrackSubscriberHandles(handles map[uint32]*livekit.DataTrackSubscriberHandles_PublishedDataTrack) error + + AddOnClose(key string, callback func(LocalParticipant)) + OnClaimsChanged(callback func(LocalParticipant)) + + HandleReceiverReport(dt *sfu.DownTrack, report *rtcp.ReceiverReport) + + // session migration + MaybeStartMigration(force bool, onStart func()) bool + NotifyMigration() + SetMigrateState(s MigrateState) + SetMigrateInfo( + previousOffer *webrtc.SessionDescription, + previousAnswer *webrtc.SessionDescription, + mediaTracks []*livekit.TrackPublishedResponse, + dataChannels []*livekit.DataChannelInfo, + dataChannelReceiveState []*livekit.DataChannelReceiveState, + dataTracks []*livekit.PublishDataTrackResponse, + ) + IsReconnect() bool + MoveToRoom(params MoveToRoomParams) + + UpdateMediaRTT(rtt uint32) + UpdateSignalingRTT(rtt uint32) + + CacheDownTrack(trackID livekit.TrackID, rtpTransceiver *webrtc.RTPTransceiver, downTrackState sfu.DownTrackState) + UncacheDownTrack(rtpTransceiver *webrtc.RTPTransceiver) + GetCachedDownTrack(trackID livekit.TrackID) (*webrtc.RTPTransceiver, sfu.DownTrackState) + + SetICEConfig(iceConfig *livekit.ICEConfig) + GetICEConfig() *livekit.ICEConfig + OnICEConfigChanged(callback func(participant LocalParticipant, iceConfig *livekit.ICEConfig)) + + UpdateSubscribedQuality(nodeID livekit.NodeID, trackID livekit.TrackID, maxQualities []SubscribedCodecQuality) error + UpdateSubscribedAudioCodecs(nodeID livekit.NodeID, trackID livekit.TrackID, codecs []*livekit.SubscribedAudioCodec) error + UpdateMediaLoss(nodeID livekit.NodeID, trackID livekit.TrackID, fractionalLoss uint32) error + + // down stream bandwidth management + SetSubscriberAllowPause(allowPause bool) + SetSubscriberChannelCapacity(channelCapacity int64) + + GetPacer() pacer.Pacer + + GetDisableSenderReportPassThrough() bool + + HandleMetrics(senderParticipantID livekit.ParticipantID, batch *livekit.MetricsBatch) error + HandleUpdateSubscriptions( + []livekit.TrackID, + []*livekit.ParticipantTracks, + bool, + ) + HandleUpdateSubscriptionPermission(*livekit.SubscriptionPermission) error + HandleSyncState(*livekit.SyncState) error + HandleSimulateScenario(*livekit.SimulateScenario) error + HandleLeaveRequest(reason ParticipantCloseReason) + + HandlePublishDataTrackRequest(*livekit.PublishDataTrackRequest) + HandleUnpublishDataTrackRequest(*livekit.UnpublishDataTrackRequest) + HandleUpdateDataSubscription(*livekit.UpdateDataSubscription) + + HandleSignalMessage(msg proto.Message) error + + PerformRpc(req *livekit.PerformRpcRequest, resultCh chan string, errorCh chan error) + + GetDataTrackTransport() DataTrackTransport + + ClearParticipantListener() + + GetNextSubscribedDataTrackHandle() uint16 +} + +// --------------------------------------------- + +//counterfeiter:generate . ParticipantListener +type ParticipantListener interface { + OnParticipantUpdate(Participant) + OnTrackPublished(Participant, MediaTrack) + OnTrackUpdated(Participant, MediaTrack) + OnTrackUnpublished(Participant, MediaTrack) + OnDataTrackPublished(Participant, DataTrack) + OnDataTrackUnpublished(Participant, DataTrack) + OnDataTrackMessage(Participant, []byte, *datatrack.Packet) + OnMetrics(Participant, *livekit.DataPacket) +} + +var _ ParticipantListener = (*NullParticipantListener)(nil) + +type NullParticipantListener struct{} + +func (*NullParticipantListener) OnParticipantUpdate(Participant) {} +func (*NullParticipantListener) OnTrackPublished(Participant, MediaTrack) {} +func (*NullParticipantListener) OnTrackUpdated(Participant, MediaTrack) {} +func (*NullParticipantListener) OnTrackUnpublished(Participant, MediaTrack) {} +func (*NullParticipantListener) OnDataTrackPublished(Participant, DataTrack) {} +func (*NullParticipantListener) OnDataTrackUnpublished(Participant, DataTrack) {} +func (*NullParticipantListener) OnDataTrackMessage(Participant, []byte, *datatrack.Packet) {} +func (*NullParticipantListener) OnMetrics(Participant, *livekit.DataPacket) {} + +// --------------------------------------------- + +//counterfeiter:generate . LocalParticipantListener +type LocalParticipantListener interface { + ParticipantListener + + OnStateChange(LocalParticipant) + OnSubscriberReady(LocalParticipant) + OnMigrateStateChange(LocalParticipant, MigrateState) + OnDataMessage(LocalParticipant, livekit.DataPacket_Kind, *livekit.DataPacket) + OnDataMessageUnlabeled(LocalParticipant, []byte) + OnSubscribeStatusChanged(LocalParticipant, livekit.ParticipantID, bool) + OnUpdateSubscriptions( + LocalParticipant, + []livekit.TrackID, + []*livekit.ParticipantTracks, + bool, + ) + OnUpdateSubscriptionPermission(LocalParticipant, *livekit.SubscriptionPermission) error + OnUpdateDataSubscriptions(LocalParticipant, *livekit.UpdateDataSubscription) + OnSyncState(LocalParticipant, *livekit.SyncState) error + OnSimulateScenario(LocalParticipant, *livekit.SimulateScenario) error + OnLeave(LocalParticipant, ParticipantCloseReason) +} + +var _ LocalParticipantListener = (*NullLocalParticipantListener)(nil) + +type NullLocalParticipantListener struct { + NullParticipantListener +} + +func (*NullLocalParticipantListener) OnStateChange(LocalParticipant) {} +func (*NullLocalParticipantListener) OnSubscriberReady(LocalParticipant) {} +func (*NullLocalParticipantListener) OnMigrateStateChange(LocalParticipant, MigrateState) {} +func (*NullLocalParticipantListener) OnDataMessage(LocalParticipant, livekit.DataPacket_Kind, *livekit.DataPacket) { +} +func (*NullLocalParticipantListener) OnDataMessageUnlabeled(LocalParticipant, []byte) {} +func (*NullLocalParticipantListener) OnSubscribeStatusChanged(LocalParticipant, livekit.ParticipantID, bool) { +} +func (*NullLocalParticipantListener) OnUpdateSubscriptions( + LocalParticipant, + []livekit.TrackID, + []*livekit.ParticipantTracks, + bool, +) { +} +func (*NullLocalParticipantListener) OnUpdateSubscriptionPermission(LocalParticipant, *livekit.SubscriptionPermission) error { + return nil +} +func (*NullLocalParticipantListener) OnUpdateDataSubscriptions(LocalParticipant, *livekit.UpdateDataSubscription) { +} +func (*NullLocalParticipantListener) OnSyncState(LocalParticipant, *livekit.SyncState) error { + return nil +} +func (*NullLocalParticipantListener) OnSimulateScenario(LocalParticipant, *livekit.SimulateScenario) error { + return nil +} +func (*NullLocalParticipantListener) OnLeave(LocalParticipant, ParticipantCloseReason) {} + +// --------------------------------------------- + +// Room is a container of participants, and can provide room-level actions +// +//counterfeiter:generate . Room +type Room interface { + Name() livekit.RoomName + ID() livekit.RoomID + RemoveParticipant(identity livekit.ParticipantIdentity, pID livekit.ParticipantID, reason ParticipantCloseReason) + UpdateSubscriptions( + participant LocalParticipant, + trackIDs []livekit.TrackID, + participantTracks []*livekit.ParticipantTracks, + subscribe bool, + ) + ResolveMediaTrackForSubscriber(sub LocalParticipant, trackID livekit.TrackID) MediaResolverResult + ResolveDataTrackForSubscriber(sub LocalParticipant, trackID livekit.TrackID) DataResolverResult + GetLocalParticipants() []LocalParticipant + IsDataMessageUserPacketDuplicate(ip *livekit.UserPacket) bool +} + +// MediaTrack represents a media track +// +//counterfeiter:generate . MediaTrack +type MediaTrack interface { + ID() livekit.TrackID + Kind() livekit.TrackType + Name() string + Source() livekit.TrackSource + Stream() string + + UpdateTrackInfo(ti *livekit.TrackInfo) + UpdateAudioTrack(update *livekit.UpdateLocalAudioTrack) + UpdateVideoTrack(update *livekit.UpdateLocalVideoTrack) + ToProto() *livekit.TrackInfo + + PublisherID() livekit.ParticipantID + PublisherIdentity() livekit.ParticipantIdentity + PublisherVersion() uint32 + Logger() logger.Logger + + IsMuted() bool + SetMuted(muted bool) + + GetAudioLevel() (level float64, active bool) + + Close(isExpectedToResume bool) + IsOpen() bool + + // callbacks + AddOnClose(func(isExpectedToResume bool)) + + // subscribers + AddSubscriber(participant LocalParticipant) (SubscribedTrack, error) + RemoveSubscriber(participantID livekit.ParticipantID, isExpectedToResume bool) + IsSubscriber(subID livekit.ParticipantID) bool + RevokeDisallowedSubscribers(allowedSubscriberIdentities []livekit.ParticipantIdentity) []livekit.ParticipantIdentity + GetAllSubscribers() []livekit.ParticipantID + GetNumSubscribers() int + OnTrackSubscribed() + + // returns quality information that's appropriate for width & height + GetQualityForDimension(mimeType mime.MimeType, width, height uint32) livekit.VideoQuality + + // returns temporal layer that's appropriate for fps + GetTemporalLayerForSpatialFps(mimeType mime.MimeType, spatial int32, fps uint32) int32 + + Receivers() []sfu.TrackReceiver + ClearAllReceivers(isExpectedToResume bool) + + IsEncrypted() bool +} + +//counterfeiter:generate . LocalMediaTrack +type LocalMediaTrack interface { + MediaTrack + + Restart() + + HasSignalCid(cid string) bool + HasSdpCid(cid string) bool + + GetConnectionScoreAndQuality() (float32, livekit.ConnectionQuality) + GetTrackStats() *livekit.RTPStats + + SetRTT(rtt uint32) + + NotifySubscriberNodeMaxQuality(nodeID livekit.NodeID, qualities []SubscribedCodecQuality) + NotifySubscriptionNode(nodeID livekit.NodeID, codecs []*livekit.SubscribedAudioCodec) + ClearSubscriberNodes() + NotifySubscriberNodeMediaLoss(nodeID livekit.NodeID, fractionalLoss uint8) +} + +// DataTrack represents a data track +// +//counterfeiter:generate . DataTrack +type DataTrack interface { + ID() livekit.TrackID + PubHandle() uint16 + Name() string + ToProto() *livekit.DataTrackInfo + + PublisherID() livekit.ParticipantID + PublisherIdentity() livekit.ParticipantIdentity + + AddSubscriber(sub LocalParticipant) (DataDownTrack, error) + RemoveSubscriber(participantID livekit.ParticipantID) + IsSubscriber(subID livekit.ParticipantID) bool + + AddDataDownTrack(sender DataTrackSender) error + DeleteDataDownTrack(subscriberID livekit.ParticipantID) + + HandlePacket(data []byte, packet *datatrack.Packet, arrivalTime int64) + + Close() +} + +//counterfeiter:generate . DataDownTrack +type DataDownTrack interface { + Close() + + Handle() uint16 + PublishDataTrack() DataTrack + + UpdateSubscriptionOptions(subscriptionOptions *livekit.DataTrackSubscriptionOptions) +} + +//counterfeiter:generate . DataTrackSender +type DataTrackSender interface { + SubscriberID() livekit.ParticipantID + + WritePacket(data []byte, packet *datatrack.Packet, arrivalTime int64) +} + +//counterfeiter:generate . DataTrackTransport +type DataTrackTransport interface { + SendDataTrackMessage(data []byte) error +} + +//counterfeiter:generate . SubscribedTrack +type SubscribedTrack interface { + AddOnBind(f func(error)) + IsBound() bool + Close(isExpectedToResume bool) + OnClose(f func(isExpectedToResume bool)) + ID() livekit.TrackID + PublisherID() livekit.ParticipantID + PublisherIdentity() livekit.ParticipantIdentity + PublisherVersion() uint32 + SubscriberID() livekit.ParticipantID + SubscriberIdentity() livekit.ParticipantIdentity + Subscriber() LocalParticipant + DownTrack() *sfu.DownTrack + MediaTrack() MediaTrack + RTPSender() *webrtc.RTPSender + IsMuted() bool + SetPublisherMuted(muted bool) + UpdateSubscriberSettings(settings *livekit.UpdateTrackSettings, isImmediate bool) + // selects appropriate video layer according to subscriber preferences + UpdateVideoLayer() + NeedsNegotiation() bool +} + +type ChangeNotifier interface { + AddObserver(key string, onChanged func()) + RemoveObserver(key string) + HasObservers() bool + NotifyChanged() +} + +type MediaResolverResult struct { + TrackChangedNotifier ChangeNotifier + TrackRemovedNotifier ChangeNotifier + Track MediaTrack + // is permission given to the requesting participant + HasPermission bool + PublisherID livekit.ParticipantID + PublisherIdentity livekit.ParticipantIdentity +} + +type DataResolverResult struct { + TrackChangedNotifier ChangeNotifier + TrackRemovedNotifier ChangeNotifier + DataTrack DataTrack + PublisherID livekit.ParticipantID + PublisherIdentity livekit.ParticipantIdentity +} + +// MediaTrackResolver locates a specific media track for a subscriber +type MediaTrackResolver func(LocalParticipant, livekit.TrackID) MediaResolverResult + +// DataTrackResolver locates a specific data track for a subscriber +type DataTrackResolver func(LocalParticipant, livekit.TrackID) DataResolverResult + +// Supervisor/operation monitor related definitions +type OperationMonitorEvent int + +const ( + OperationMonitorEventPublisherPeerConnectionConnected OperationMonitorEvent = iota + OperationMonitorEventAddPendingPublication + OperationMonitorEventSetPublicationMute + OperationMonitorEventSetPublishedTrack + OperationMonitorEventClearPublishedTrack +) + +func (o OperationMonitorEvent) String() string { + switch o { + case OperationMonitorEventPublisherPeerConnectionConnected: + return "PUBLISHER_PEER_CONNECTION_CONNECTED" + case OperationMonitorEventAddPendingPublication: + return "ADD_PENDING_PUBLICATION" + case OperationMonitorEventSetPublicationMute: + return "SET_PUBLICATION_MUTE" + case OperationMonitorEventSetPublishedTrack: + return "SET_PUBLISHED_TRACK" + case OperationMonitorEventClearPublishedTrack: + return "CLEAR_PUBLISHED_TRACK" + default: + return fmt.Sprintf("%d", int(o)) + } +} + +type OperationMonitorData any + +type OperationMonitor interface { + PostEvent(ome OperationMonitorEvent, omd OperationMonitorData) + Check() error + IsIdle() bool +} diff --git a/livekit/pkg/rtc/types/protocol_version.go b/livekit/pkg/rtc/types/protocol_version.go new file mode 100644 index 0000000..e7f5698 --- /dev/null +++ b/livekit/pkg/rtc/types/protocol_version.go @@ -0,0 +1,101 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package types + +type ProtocolVersion int + +const CurrentProtocol = 16 + +func (v ProtocolVersion) SupportsPackedStreamId() bool { + return v > 0 +} + +func (v ProtocolVersion) SupportsProtobuf() bool { + return v > 0 +} + +func (v ProtocolVersion) HandlesDataPackets() bool { + return v > 1 +} + +// SubscriberAsPrimary indicates clients initiate subscriber connection as primary +func (v ProtocolVersion) SubscriberAsPrimary() bool { + return v > 2 +} + +// SupportsSpeakerChanged - if client handles speaker info deltas, instead of a comprehensive list +func (v ProtocolVersion) SupportsSpeakerChanged() bool { + return v > 2 +} + +// SupportsTransceiverReuse - if transceiver reuse is supported, optimizes SDP size +func (v ProtocolVersion) SupportsTransceiverReuse() bool { + return v > 3 +} + +// SupportsConnectionQuality - avoid sending frequent ConnectionQuality updates for lower protocol versions +func (v ProtocolVersion) SupportsConnectionQuality() bool { + return v > 4 +} + +func (v ProtocolVersion) SupportsSessionMigrate() bool { + return v > 5 +} + +func (v ProtocolVersion) SupportsICELite() bool { + return v > 5 +} + +func (v ProtocolVersion) SupportsUnpublish() bool { + return v > 6 +} + +// SupportFastStart - if client supports fast start, server side will send media streams +// in the first offer +func (v ProtocolVersion) SupportFastStart() bool { + return v > 7 +} + +func (v ProtocolVersion) SupportsDisconnectedUpdate() bool { + return v > 8 +} + +func (v ProtocolVersion) SupportsSyncStreamID() bool { + return v > 9 +} + +func (v ProtocolVersion) SupportsConnectionQualityLost() bool { + return v > 10 +} + +func (v ProtocolVersion) SupportsAsyncRoomID() bool { + return v > 11 +} + +func (v ProtocolVersion) SupportsIdentityBasedReconnection() bool { + return v > 11 +} + +func (v ProtocolVersion) SupportsRegionsInLeaveRequest() bool { + return v > 12 +} + +func (v ProtocolVersion) SupportsNonErrorSignalResponse() bool { + return v > 14 +} + +func (v ProtocolVersion) SupportsMoving() bool { + return v > 15 +} diff --git a/livekit/pkg/rtc/types/trafficstats.go b/livekit/pkg/rtc/types/trafficstats.go new file mode 100644 index 0000000..f5c0b38 --- /dev/null +++ b/livekit/pkg/rtc/types/trafficstats.go @@ -0,0 +1,161 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package types + +import ( + "time" + + "github.com/livekit/protocol/livekit" +) + +type TrafficStats struct { + StartTime time.Time + EndTime time.Time + Packets uint32 + PacketsLost uint32 + PacketsPadding uint32 + PacketsOutOfOrder uint32 + Bytes uint64 +} + +type TrafficTypeStats struct { + TrackType livekit.TrackType + StreamType livekit.StreamType + TrafficStats *TrafficStats +} + +type TrafficLoad struct { + TrafficTypeStats []*TrafficTypeStats +} + +func RTPStatsDiffToTrafficStats(before, after *livekit.RTPStats) *TrafficStats { + if after == nil { + return nil + } + + startTime := after.StartTime + if before != nil { + startTime = before.EndTime + } + + getAfter := func() *TrafficStats { + return &TrafficStats{ + StartTime: startTime.AsTime(), + EndTime: after.EndTime.AsTime(), + Packets: after.Packets, + PacketsLost: after.PacketsLost, + PacketsPadding: after.PacketsPadding, + PacketsOutOfOrder: after.PacketsOutOfOrder, + Bytes: after.Bytes + after.BytesDuplicate + after.BytesPadding, + } + } + + if before == nil { + return getAfter() + } + + if (after.Packets - before.Packets) > (1 << 31) { + // after packets < before packets, probably got reset, just return after + return getAfter() + } + if ((after.Bytes + after.BytesDuplicate + after.BytesPadding) - (before.Bytes + before.BytesDuplicate + before.BytesPadding)) > (1 << 63) { + // after bytes < before bytes, probably got reset, just return after + return getAfter() + } + + packetsLost := uint32(0) + if after.PacketsLost >= before.PacketsLost { + packetsLost = after.PacketsLost - before.PacketsLost + } + return &TrafficStats{ + StartTime: startTime.AsTime(), + EndTime: after.EndTime.AsTime(), + Packets: after.Packets - before.Packets, + PacketsLost: packetsLost, + PacketsPadding: after.PacketsPadding - before.PacketsPadding, + PacketsOutOfOrder: after.PacketsOutOfOrder - before.PacketsOutOfOrder, + Bytes: (after.Bytes + after.BytesDuplicate + after.BytesPadding) - (before.Bytes + before.BytesDuplicate + before.BytesPadding), + } +} + +func AggregateTrafficStats(statsList ...*TrafficStats) *TrafficStats { + if len(statsList) == 0 { + return nil + } + + startTime := time.Time{} + endTime := time.Time{} + + packets := uint32(0) + packetsLost := uint32(0) + packetsPadding := uint32(0) + packetsOutOfOrder := uint32(0) + bytes := uint64(0) + + for _, stats := range statsList { + if startTime.IsZero() || startTime.After(stats.StartTime) { + startTime = stats.StartTime + } + + if endTime.IsZero() || endTime.Before(stats.EndTime) { + endTime = stats.EndTime + } + + packets += stats.Packets + packetsLost += stats.PacketsLost + packetsPadding += stats.PacketsPadding + packetsOutOfOrder += stats.PacketsOutOfOrder + bytes += stats.Bytes + } + + if endTime.IsZero() { + endTime = time.Now() + } + return &TrafficStats{ + StartTime: startTime, + EndTime: endTime, + Packets: packets, + PacketsLost: packetsLost, + PacketsPadding: packetsPadding, + PacketsOutOfOrder: packetsOutOfOrder, + Bytes: bytes, + } +} + +func TrafficLoadToTrafficRate(trafficLoad *TrafficLoad) ( + packetRateIn float64, + byteRateIn float64, + packetRateOut float64, + byteRateOut float64, +) { + if trafficLoad == nil { + return + } + + for _, trafficTypeStat := range trafficLoad.TrafficTypeStats { + elapsed := trafficTypeStat.TrafficStats.EndTime.Sub(trafficTypeStat.TrafficStats.StartTime).Seconds() + packetRate := float64(trafficTypeStat.TrafficStats.Packets) / elapsed + byteRate := float64(trafficTypeStat.TrafficStats.Bytes) / elapsed + switch trafficTypeStat.StreamType { + case livekit.StreamType_UPSTREAM: + packetRateIn += packetRate + byteRateIn += byteRate + case livekit.StreamType_DOWNSTREAM: + packetRateOut += packetRate + byteRateOut += byteRate + } + } + return +} diff --git a/livekit/pkg/rtc/types/typesfakes/fake_data_down_track.go b/livekit/pkg/rtc/types/typesfakes/fake_data_down_track.go new file mode 100644 index 0000000..56d0d10 --- /dev/null +++ b/livekit/pkg/rtc/types/typesfakes/fake_data_down_track.go @@ -0,0 +1,229 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package typesfakes + +import ( + "sync" + + "github.com/livekit/livekit-server/pkg/rtc/types" + "github.com/livekit/protocol/livekit" +) + +type FakeDataDownTrack struct { + CloseStub func() + closeMutex sync.RWMutex + closeArgsForCall []struct { + } + HandleStub func() uint16 + handleMutex sync.RWMutex + handleArgsForCall []struct { + } + handleReturns struct { + result1 uint16 + } + handleReturnsOnCall map[int]struct { + result1 uint16 + } + PublishDataTrackStub func() types.DataTrack + publishDataTrackMutex sync.RWMutex + publishDataTrackArgsForCall []struct { + } + publishDataTrackReturns struct { + result1 types.DataTrack + } + publishDataTrackReturnsOnCall map[int]struct { + result1 types.DataTrack + } + UpdateSubscriptionOptionsStub func(*livekit.DataTrackSubscriptionOptions) + updateSubscriptionOptionsMutex sync.RWMutex + updateSubscriptionOptionsArgsForCall []struct { + arg1 *livekit.DataTrackSubscriptionOptions + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeDataDownTrack) Close() { + fake.closeMutex.Lock() + fake.closeArgsForCall = append(fake.closeArgsForCall, struct { + }{}) + stub := fake.CloseStub + fake.recordInvocation("Close", []interface{}{}) + fake.closeMutex.Unlock() + if stub != nil { + fake.CloseStub() + } +} + +func (fake *FakeDataDownTrack) CloseCallCount() int { + fake.closeMutex.RLock() + defer fake.closeMutex.RUnlock() + return len(fake.closeArgsForCall) +} + +func (fake *FakeDataDownTrack) CloseCalls(stub func()) { + fake.closeMutex.Lock() + defer fake.closeMutex.Unlock() + fake.CloseStub = stub +} + +func (fake *FakeDataDownTrack) Handle() uint16 { + fake.handleMutex.Lock() + ret, specificReturn := fake.handleReturnsOnCall[len(fake.handleArgsForCall)] + fake.handleArgsForCall = append(fake.handleArgsForCall, struct { + }{}) + stub := fake.HandleStub + fakeReturns := fake.handleReturns + fake.recordInvocation("Handle", []interface{}{}) + fake.handleMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeDataDownTrack) HandleCallCount() int { + fake.handleMutex.RLock() + defer fake.handleMutex.RUnlock() + return len(fake.handleArgsForCall) +} + +func (fake *FakeDataDownTrack) HandleCalls(stub func() uint16) { + fake.handleMutex.Lock() + defer fake.handleMutex.Unlock() + fake.HandleStub = stub +} + +func (fake *FakeDataDownTrack) HandleReturns(result1 uint16) { + fake.handleMutex.Lock() + defer fake.handleMutex.Unlock() + fake.HandleStub = nil + fake.handleReturns = struct { + result1 uint16 + }{result1} +} + +func (fake *FakeDataDownTrack) HandleReturnsOnCall(i int, result1 uint16) { + fake.handleMutex.Lock() + defer fake.handleMutex.Unlock() + fake.HandleStub = nil + if fake.handleReturnsOnCall == nil { + fake.handleReturnsOnCall = make(map[int]struct { + result1 uint16 + }) + } + fake.handleReturnsOnCall[i] = struct { + result1 uint16 + }{result1} +} + +func (fake *FakeDataDownTrack) PublishDataTrack() types.DataTrack { + fake.publishDataTrackMutex.Lock() + ret, specificReturn := fake.publishDataTrackReturnsOnCall[len(fake.publishDataTrackArgsForCall)] + fake.publishDataTrackArgsForCall = append(fake.publishDataTrackArgsForCall, struct { + }{}) + stub := fake.PublishDataTrackStub + fakeReturns := fake.publishDataTrackReturns + fake.recordInvocation("PublishDataTrack", []interface{}{}) + fake.publishDataTrackMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeDataDownTrack) PublishDataTrackCallCount() int { + fake.publishDataTrackMutex.RLock() + defer fake.publishDataTrackMutex.RUnlock() + return len(fake.publishDataTrackArgsForCall) +} + +func (fake *FakeDataDownTrack) PublishDataTrackCalls(stub func() types.DataTrack) { + fake.publishDataTrackMutex.Lock() + defer fake.publishDataTrackMutex.Unlock() + fake.PublishDataTrackStub = stub +} + +func (fake *FakeDataDownTrack) PublishDataTrackReturns(result1 types.DataTrack) { + fake.publishDataTrackMutex.Lock() + defer fake.publishDataTrackMutex.Unlock() + fake.PublishDataTrackStub = nil + fake.publishDataTrackReturns = struct { + result1 types.DataTrack + }{result1} +} + +func (fake *FakeDataDownTrack) PublishDataTrackReturnsOnCall(i int, result1 types.DataTrack) { + fake.publishDataTrackMutex.Lock() + defer fake.publishDataTrackMutex.Unlock() + fake.PublishDataTrackStub = nil + if fake.publishDataTrackReturnsOnCall == nil { + fake.publishDataTrackReturnsOnCall = make(map[int]struct { + result1 types.DataTrack + }) + } + fake.publishDataTrackReturnsOnCall[i] = struct { + result1 types.DataTrack + }{result1} +} + +func (fake *FakeDataDownTrack) UpdateSubscriptionOptions(arg1 *livekit.DataTrackSubscriptionOptions) { + fake.updateSubscriptionOptionsMutex.Lock() + fake.updateSubscriptionOptionsArgsForCall = append(fake.updateSubscriptionOptionsArgsForCall, struct { + arg1 *livekit.DataTrackSubscriptionOptions + }{arg1}) + stub := fake.UpdateSubscriptionOptionsStub + fake.recordInvocation("UpdateSubscriptionOptions", []interface{}{arg1}) + fake.updateSubscriptionOptionsMutex.Unlock() + if stub != nil { + fake.UpdateSubscriptionOptionsStub(arg1) + } +} + +func (fake *FakeDataDownTrack) UpdateSubscriptionOptionsCallCount() int { + fake.updateSubscriptionOptionsMutex.RLock() + defer fake.updateSubscriptionOptionsMutex.RUnlock() + return len(fake.updateSubscriptionOptionsArgsForCall) +} + +func (fake *FakeDataDownTrack) UpdateSubscriptionOptionsCalls(stub func(*livekit.DataTrackSubscriptionOptions)) { + fake.updateSubscriptionOptionsMutex.Lock() + defer fake.updateSubscriptionOptionsMutex.Unlock() + fake.UpdateSubscriptionOptionsStub = stub +} + +func (fake *FakeDataDownTrack) UpdateSubscriptionOptionsArgsForCall(i int) *livekit.DataTrackSubscriptionOptions { + fake.updateSubscriptionOptionsMutex.RLock() + defer fake.updateSubscriptionOptionsMutex.RUnlock() + argsForCall := fake.updateSubscriptionOptionsArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeDataDownTrack) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeDataDownTrack) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ types.DataDownTrack = new(FakeDataDownTrack) diff --git a/livekit/pkg/rtc/types/typesfakes/fake_data_track.go b/livekit/pkg/rtc/types/typesfakes/fake_data_track.go new file mode 100644 index 0000000..c54d77c --- /dev/null +++ b/livekit/pkg/rtc/types/typesfakes/fake_data_track.go @@ -0,0 +1,786 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package typesfakes + +import ( + "sync" + + "github.com/livekit/livekit-server/pkg/rtc/datatrack" + "github.com/livekit/livekit-server/pkg/rtc/types" + "github.com/livekit/protocol/livekit" +) + +type FakeDataTrack struct { + AddDataDownTrackStub func(types.DataTrackSender) error + addDataDownTrackMutex sync.RWMutex + addDataDownTrackArgsForCall []struct { + arg1 types.DataTrackSender + } + addDataDownTrackReturns struct { + result1 error + } + addDataDownTrackReturnsOnCall map[int]struct { + result1 error + } + AddSubscriberStub func(types.LocalParticipant) (types.DataDownTrack, error) + addSubscriberMutex sync.RWMutex + addSubscriberArgsForCall []struct { + arg1 types.LocalParticipant + } + addSubscriberReturns struct { + result1 types.DataDownTrack + result2 error + } + addSubscriberReturnsOnCall map[int]struct { + result1 types.DataDownTrack + result2 error + } + CloseStub func() + closeMutex sync.RWMutex + closeArgsForCall []struct { + } + DeleteDataDownTrackStub func(livekit.ParticipantID) + deleteDataDownTrackMutex sync.RWMutex + deleteDataDownTrackArgsForCall []struct { + arg1 livekit.ParticipantID + } + HandlePacketStub func([]byte, *datatrack.Packet, int64) + handlePacketMutex sync.RWMutex + handlePacketArgsForCall []struct { + arg1 []byte + arg2 *datatrack.Packet + arg3 int64 + } + IDStub func() livekit.TrackID + iDMutex sync.RWMutex + iDArgsForCall []struct { + } + iDReturns struct { + result1 livekit.TrackID + } + iDReturnsOnCall map[int]struct { + result1 livekit.TrackID + } + IsSubscriberStub func(livekit.ParticipantID) bool + isSubscriberMutex sync.RWMutex + isSubscriberArgsForCall []struct { + arg1 livekit.ParticipantID + } + isSubscriberReturns struct { + result1 bool + } + isSubscriberReturnsOnCall map[int]struct { + result1 bool + } + NameStub func() string + nameMutex sync.RWMutex + nameArgsForCall []struct { + } + nameReturns struct { + result1 string + } + nameReturnsOnCall map[int]struct { + result1 string + } + PubHandleStub func() uint16 + pubHandleMutex sync.RWMutex + pubHandleArgsForCall []struct { + } + pubHandleReturns struct { + result1 uint16 + } + pubHandleReturnsOnCall map[int]struct { + result1 uint16 + } + PublisherIDStub func() livekit.ParticipantID + publisherIDMutex sync.RWMutex + publisherIDArgsForCall []struct { + } + publisherIDReturns struct { + result1 livekit.ParticipantID + } + publisherIDReturnsOnCall map[int]struct { + result1 livekit.ParticipantID + } + PublisherIdentityStub func() livekit.ParticipantIdentity + publisherIdentityMutex sync.RWMutex + publisherIdentityArgsForCall []struct { + } + publisherIdentityReturns struct { + result1 livekit.ParticipantIdentity + } + publisherIdentityReturnsOnCall map[int]struct { + result1 livekit.ParticipantIdentity + } + RemoveSubscriberStub func(livekit.ParticipantID) + removeSubscriberMutex sync.RWMutex + removeSubscriberArgsForCall []struct { + arg1 livekit.ParticipantID + } + ToProtoStub func() *livekit.DataTrackInfo + toProtoMutex sync.RWMutex + toProtoArgsForCall []struct { + } + toProtoReturns struct { + result1 *livekit.DataTrackInfo + } + toProtoReturnsOnCall map[int]struct { + result1 *livekit.DataTrackInfo + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeDataTrack) AddDataDownTrack(arg1 types.DataTrackSender) error { + fake.addDataDownTrackMutex.Lock() + ret, specificReturn := fake.addDataDownTrackReturnsOnCall[len(fake.addDataDownTrackArgsForCall)] + fake.addDataDownTrackArgsForCall = append(fake.addDataDownTrackArgsForCall, struct { + arg1 types.DataTrackSender + }{arg1}) + stub := fake.AddDataDownTrackStub + fakeReturns := fake.addDataDownTrackReturns + fake.recordInvocation("AddDataDownTrack", []interface{}{arg1}) + fake.addDataDownTrackMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeDataTrack) AddDataDownTrackCallCount() int { + fake.addDataDownTrackMutex.RLock() + defer fake.addDataDownTrackMutex.RUnlock() + return len(fake.addDataDownTrackArgsForCall) +} + +func (fake *FakeDataTrack) AddDataDownTrackCalls(stub func(types.DataTrackSender) error) { + fake.addDataDownTrackMutex.Lock() + defer fake.addDataDownTrackMutex.Unlock() + fake.AddDataDownTrackStub = stub +} + +func (fake *FakeDataTrack) AddDataDownTrackArgsForCall(i int) types.DataTrackSender { + fake.addDataDownTrackMutex.RLock() + defer fake.addDataDownTrackMutex.RUnlock() + argsForCall := fake.addDataDownTrackArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeDataTrack) AddDataDownTrackReturns(result1 error) { + fake.addDataDownTrackMutex.Lock() + defer fake.addDataDownTrackMutex.Unlock() + fake.AddDataDownTrackStub = nil + fake.addDataDownTrackReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeDataTrack) AddDataDownTrackReturnsOnCall(i int, result1 error) { + fake.addDataDownTrackMutex.Lock() + defer fake.addDataDownTrackMutex.Unlock() + fake.AddDataDownTrackStub = nil + if fake.addDataDownTrackReturnsOnCall == nil { + fake.addDataDownTrackReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.addDataDownTrackReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeDataTrack) AddSubscriber(arg1 types.LocalParticipant) (types.DataDownTrack, error) { + fake.addSubscriberMutex.Lock() + ret, specificReturn := fake.addSubscriberReturnsOnCall[len(fake.addSubscriberArgsForCall)] + fake.addSubscriberArgsForCall = append(fake.addSubscriberArgsForCall, struct { + arg1 types.LocalParticipant + }{arg1}) + stub := fake.AddSubscriberStub + fakeReturns := fake.addSubscriberReturns + fake.recordInvocation("AddSubscriber", []interface{}{arg1}) + fake.addSubscriberMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeDataTrack) AddSubscriberCallCount() int { + fake.addSubscriberMutex.RLock() + defer fake.addSubscriberMutex.RUnlock() + return len(fake.addSubscriberArgsForCall) +} + +func (fake *FakeDataTrack) AddSubscriberCalls(stub func(types.LocalParticipant) (types.DataDownTrack, error)) { + fake.addSubscriberMutex.Lock() + defer fake.addSubscriberMutex.Unlock() + fake.AddSubscriberStub = stub +} + +func (fake *FakeDataTrack) AddSubscriberArgsForCall(i int) types.LocalParticipant { + fake.addSubscriberMutex.RLock() + defer fake.addSubscriberMutex.RUnlock() + argsForCall := fake.addSubscriberArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeDataTrack) AddSubscriberReturns(result1 types.DataDownTrack, result2 error) { + fake.addSubscriberMutex.Lock() + defer fake.addSubscriberMutex.Unlock() + fake.AddSubscriberStub = nil + fake.addSubscriberReturns = struct { + result1 types.DataDownTrack + result2 error + }{result1, result2} +} + +func (fake *FakeDataTrack) AddSubscriberReturnsOnCall(i int, result1 types.DataDownTrack, result2 error) { + fake.addSubscriberMutex.Lock() + defer fake.addSubscriberMutex.Unlock() + fake.AddSubscriberStub = nil + if fake.addSubscriberReturnsOnCall == nil { + fake.addSubscriberReturnsOnCall = make(map[int]struct { + result1 types.DataDownTrack + result2 error + }) + } + fake.addSubscriberReturnsOnCall[i] = struct { + result1 types.DataDownTrack + result2 error + }{result1, result2} +} + +func (fake *FakeDataTrack) Close() { + fake.closeMutex.Lock() + fake.closeArgsForCall = append(fake.closeArgsForCall, struct { + }{}) + stub := fake.CloseStub + fake.recordInvocation("Close", []interface{}{}) + fake.closeMutex.Unlock() + if stub != nil { + fake.CloseStub() + } +} + +func (fake *FakeDataTrack) CloseCallCount() int { + fake.closeMutex.RLock() + defer fake.closeMutex.RUnlock() + return len(fake.closeArgsForCall) +} + +func (fake *FakeDataTrack) CloseCalls(stub func()) { + fake.closeMutex.Lock() + defer fake.closeMutex.Unlock() + fake.CloseStub = stub +} + +func (fake *FakeDataTrack) DeleteDataDownTrack(arg1 livekit.ParticipantID) { + fake.deleteDataDownTrackMutex.Lock() + fake.deleteDataDownTrackArgsForCall = append(fake.deleteDataDownTrackArgsForCall, struct { + arg1 livekit.ParticipantID + }{arg1}) + stub := fake.DeleteDataDownTrackStub + fake.recordInvocation("DeleteDataDownTrack", []interface{}{arg1}) + fake.deleteDataDownTrackMutex.Unlock() + if stub != nil { + fake.DeleteDataDownTrackStub(arg1) + } +} + +func (fake *FakeDataTrack) DeleteDataDownTrackCallCount() int { + fake.deleteDataDownTrackMutex.RLock() + defer fake.deleteDataDownTrackMutex.RUnlock() + return len(fake.deleteDataDownTrackArgsForCall) +} + +func (fake *FakeDataTrack) DeleteDataDownTrackCalls(stub func(livekit.ParticipantID)) { + fake.deleteDataDownTrackMutex.Lock() + defer fake.deleteDataDownTrackMutex.Unlock() + fake.DeleteDataDownTrackStub = stub +} + +func (fake *FakeDataTrack) DeleteDataDownTrackArgsForCall(i int) livekit.ParticipantID { + fake.deleteDataDownTrackMutex.RLock() + defer fake.deleteDataDownTrackMutex.RUnlock() + argsForCall := fake.deleteDataDownTrackArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeDataTrack) HandlePacket(arg1 []byte, arg2 *datatrack.Packet, arg3 int64) { + var arg1Copy []byte + if arg1 != nil { + arg1Copy = make([]byte, len(arg1)) + copy(arg1Copy, arg1) + } + fake.handlePacketMutex.Lock() + fake.handlePacketArgsForCall = append(fake.handlePacketArgsForCall, struct { + arg1 []byte + arg2 *datatrack.Packet + arg3 int64 + }{arg1Copy, arg2, arg3}) + stub := fake.HandlePacketStub + fake.recordInvocation("HandlePacket", []interface{}{arg1Copy, arg2, arg3}) + fake.handlePacketMutex.Unlock() + if stub != nil { + fake.HandlePacketStub(arg1, arg2, arg3) + } +} + +func (fake *FakeDataTrack) HandlePacketCallCount() int { + fake.handlePacketMutex.RLock() + defer fake.handlePacketMutex.RUnlock() + return len(fake.handlePacketArgsForCall) +} + +func (fake *FakeDataTrack) HandlePacketCalls(stub func([]byte, *datatrack.Packet, int64)) { + fake.handlePacketMutex.Lock() + defer fake.handlePacketMutex.Unlock() + fake.HandlePacketStub = stub +} + +func (fake *FakeDataTrack) HandlePacketArgsForCall(i int) ([]byte, *datatrack.Packet, int64) { + fake.handlePacketMutex.RLock() + defer fake.handlePacketMutex.RUnlock() + argsForCall := fake.handlePacketArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeDataTrack) ID() livekit.TrackID { + fake.iDMutex.Lock() + ret, specificReturn := fake.iDReturnsOnCall[len(fake.iDArgsForCall)] + fake.iDArgsForCall = append(fake.iDArgsForCall, struct { + }{}) + stub := fake.IDStub + fakeReturns := fake.iDReturns + fake.recordInvocation("ID", []interface{}{}) + fake.iDMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeDataTrack) IDCallCount() int { + fake.iDMutex.RLock() + defer fake.iDMutex.RUnlock() + return len(fake.iDArgsForCall) +} + +func (fake *FakeDataTrack) IDCalls(stub func() livekit.TrackID) { + fake.iDMutex.Lock() + defer fake.iDMutex.Unlock() + fake.IDStub = stub +} + +func (fake *FakeDataTrack) IDReturns(result1 livekit.TrackID) { + fake.iDMutex.Lock() + defer fake.iDMutex.Unlock() + fake.IDStub = nil + fake.iDReturns = struct { + result1 livekit.TrackID + }{result1} +} + +func (fake *FakeDataTrack) IDReturnsOnCall(i int, result1 livekit.TrackID) { + fake.iDMutex.Lock() + defer fake.iDMutex.Unlock() + fake.IDStub = nil + if fake.iDReturnsOnCall == nil { + fake.iDReturnsOnCall = make(map[int]struct { + result1 livekit.TrackID + }) + } + fake.iDReturnsOnCall[i] = struct { + result1 livekit.TrackID + }{result1} +} + +func (fake *FakeDataTrack) IsSubscriber(arg1 livekit.ParticipantID) bool { + fake.isSubscriberMutex.Lock() + ret, specificReturn := fake.isSubscriberReturnsOnCall[len(fake.isSubscriberArgsForCall)] + fake.isSubscriberArgsForCall = append(fake.isSubscriberArgsForCall, struct { + arg1 livekit.ParticipantID + }{arg1}) + stub := fake.IsSubscriberStub + fakeReturns := fake.isSubscriberReturns + fake.recordInvocation("IsSubscriber", []interface{}{arg1}) + fake.isSubscriberMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeDataTrack) IsSubscriberCallCount() int { + fake.isSubscriberMutex.RLock() + defer fake.isSubscriberMutex.RUnlock() + return len(fake.isSubscriberArgsForCall) +} + +func (fake *FakeDataTrack) IsSubscriberCalls(stub func(livekit.ParticipantID) bool) { + fake.isSubscriberMutex.Lock() + defer fake.isSubscriberMutex.Unlock() + fake.IsSubscriberStub = stub +} + +func (fake *FakeDataTrack) IsSubscriberArgsForCall(i int) livekit.ParticipantID { + fake.isSubscriberMutex.RLock() + defer fake.isSubscriberMutex.RUnlock() + argsForCall := fake.isSubscriberArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeDataTrack) IsSubscriberReturns(result1 bool) { + fake.isSubscriberMutex.Lock() + defer fake.isSubscriberMutex.Unlock() + fake.IsSubscriberStub = nil + fake.isSubscriberReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeDataTrack) IsSubscriberReturnsOnCall(i int, result1 bool) { + fake.isSubscriberMutex.Lock() + defer fake.isSubscriberMutex.Unlock() + fake.IsSubscriberStub = nil + if fake.isSubscriberReturnsOnCall == nil { + fake.isSubscriberReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.isSubscriberReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + +func (fake *FakeDataTrack) Name() string { + fake.nameMutex.Lock() + ret, specificReturn := fake.nameReturnsOnCall[len(fake.nameArgsForCall)] + fake.nameArgsForCall = append(fake.nameArgsForCall, struct { + }{}) + stub := fake.NameStub + fakeReturns := fake.nameReturns + fake.recordInvocation("Name", []interface{}{}) + fake.nameMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeDataTrack) NameCallCount() int { + fake.nameMutex.RLock() + defer fake.nameMutex.RUnlock() + return len(fake.nameArgsForCall) +} + +func (fake *FakeDataTrack) NameCalls(stub func() string) { + fake.nameMutex.Lock() + defer fake.nameMutex.Unlock() + fake.NameStub = stub +} + +func (fake *FakeDataTrack) NameReturns(result1 string) { + fake.nameMutex.Lock() + defer fake.nameMutex.Unlock() + fake.NameStub = nil + fake.nameReturns = struct { + result1 string + }{result1} +} + +func (fake *FakeDataTrack) NameReturnsOnCall(i int, result1 string) { + fake.nameMutex.Lock() + defer fake.nameMutex.Unlock() + fake.NameStub = nil + if fake.nameReturnsOnCall == nil { + fake.nameReturnsOnCall = make(map[int]struct { + result1 string + }) + } + fake.nameReturnsOnCall[i] = struct { + result1 string + }{result1} +} + +func (fake *FakeDataTrack) PubHandle() uint16 { + fake.pubHandleMutex.Lock() + ret, specificReturn := fake.pubHandleReturnsOnCall[len(fake.pubHandleArgsForCall)] + fake.pubHandleArgsForCall = append(fake.pubHandleArgsForCall, struct { + }{}) + stub := fake.PubHandleStub + fakeReturns := fake.pubHandleReturns + fake.recordInvocation("PubHandle", []interface{}{}) + fake.pubHandleMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeDataTrack) PubHandleCallCount() int { + fake.pubHandleMutex.RLock() + defer fake.pubHandleMutex.RUnlock() + return len(fake.pubHandleArgsForCall) +} + +func (fake *FakeDataTrack) PubHandleCalls(stub func() uint16) { + fake.pubHandleMutex.Lock() + defer fake.pubHandleMutex.Unlock() + fake.PubHandleStub = stub +} + +func (fake *FakeDataTrack) PubHandleReturns(result1 uint16) { + fake.pubHandleMutex.Lock() + defer fake.pubHandleMutex.Unlock() + fake.PubHandleStub = nil + fake.pubHandleReturns = struct { + result1 uint16 + }{result1} +} + +func (fake *FakeDataTrack) PubHandleReturnsOnCall(i int, result1 uint16) { + fake.pubHandleMutex.Lock() + defer fake.pubHandleMutex.Unlock() + fake.PubHandleStub = nil + if fake.pubHandleReturnsOnCall == nil { + fake.pubHandleReturnsOnCall = make(map[int]struct { + result1 uint16 + }) + } + fake.pubHandleReturnsOnCall[i] = struct { + result1 uint16 + }{result1} +} + +func (fake *FakeDataTrack) PublisherID() livekit.ParticipantID { + fake.publisherIDMutex.Lock() + ret, specificReturn := fake.publisherIDReturnsOnCall[len(fake.publisherIDArgsForCall)] + fake.publisherIDArgsForCall = append(fake.publisherIDArgsForCall, struct { + }{}) + stub := fake.PublisherIDStub + fakeReturns := fake.publisherIDReturns + fake.recordInvocation("PublisherID", []interface{}{}) + fake.publisherIDMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeDataTrack) PublisherIDCallCount() int { + fake.publisherIDMutex.RLock() + defer fake.publisherIDMutex.RUnlock() + return len(fake.publisherIDArgsForCall) +} + +func (fake *FakeDataTrack) PublisherIDCalls(stub func() livekit.ParticipantID) { + fake.publisherIDMutex.Lock() + defer fake.publisherIDMutex.Unlock() + fake.PublisherIDStub = stub +} + +func (fake *FakeDataTrack) PublisherIDReturns(result1 livekit.ParticipantID) { + fake.publisherIDMutex.Lock() + defer fake.publisherIDMutex.Unlock() + fake.PublisherIDStub = nil + fake.publisherIDReturns = struct { + result1 livekit.ParticipantID + }{result1} +} + +func (fake *FakeDataTrack) PublisherIDReturnsOnCall(i int, result1 livekit.ParticipantID) { + fake.publisherIDMutex.Lock() + defer fake.publisherIDMutex.Unlock() + fake.PublisherIDStub = nil + if fake.publisherIDReturnsOnCall == nil { + fake.publisherIDReturnsOnCall = make(map[int]struct { + result1 livekit.ParticipantID + }) + } + fake.publisherIDReturnsOnCall[i] = struct { + result1 livekit.ParticipantID + }{result1} +} + +func (fake *FakeDataTrack) PublisherIdentity() livekit.ParticipantIdentity { + fake.publisherIdentityMutex.Lock() + ret, specificReturn := fake.publisherIdentityReturnsOnCall[len(fake.publisherIdentityArgsForCall)] + fake.publisherIdentityArgsForCall = append(fake.publisherIdentityArgsForCall, struct { + }{}) + stub := fake.PublisherIdentityStub + fakeReturns := fake.publisherIdentityReturns + fake.recordInvocation("PublisherIdentity", []interface{}{}) + fake.publisherIdentityMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeDataTrack) PublisherIdentityCallCount() int { + fake.publisherIdentityMutex.RLock() + defer fake.publisherIdentityMutex.RUnlock() + return len(fake.publisherIdentityArgsForCall) +} + +func (fake *FakeDataTrack) PublisherIdentityCalls(stub func() livekit.ParticipantIdentity) { + fake.publisherIdentityMutex.Lock() + defer fake.publisherIdentityMutex.Unlock() + fake.PublisherIdentityStub = stub +} + +func (fake *FakeDataTrack) PublisherIdentityReturns(result1 livekit.ParticipantIdentity) { + fake.publisherIdentityMutex.Lock() + defer fake.publisherIdentityMutex.Unlock() + fake.PublisherIdentityStub = nil + fake.publisherIdentityReturns = struct { + result1 livekit.ParticipantIdentity + }{result1} +} + +func (fake *FakeDataTrack) PublisherIdentityReturnsOnCall(i int, result1 livekit.ParticipantIdentity) { + fake.publisherIdentityMutex.Lock() + defer fake.publisherIdentityMutex.Unlock() + fake.PublisherIdentityStub = nil + if fake.publisherIdentityReturnsOnCall == nil { + fake.publisherIdentityReturnsOnCall = make(map[int]struct { + result1 livekit.ParticipantIdentity + }) + } + fake.publisherIdentityReturnsOnCall[i] = struct { + result1 livekit.ParticipantIdentity + }{result1} +} + +func (fake *FakeDataTrack) RemoveSubscriber(arg1 livekit.ParticipantID) { + fake.removeSubscriberMutex.Lock() + fake.removeSubscriberArgsForCall = append(fake.removeSubscriberArgsForCall, struct { + arg1 livekit.ParticipantID + }{arg1}) + stub := fake.RemoveSubscriberStub + fake.recordInvocation("RemoveSubscriber", []interface{}{arg1}) + fake.removeSubscriberMutex.Unlock() + if stub != nil { + fake.RemoveSubscriberStub(arg1) + } +} + +func (fake *FakeDataTrack) RemoveSubscriberCallCount() int { + fake.removeSubscriberMutex.RLock() + defer fake.removeSubscriberMutex.RUnlock() + return len(fake.removeSubscriberArgsForCall) +} + +func (fake *FakeDataTrack) RemoveSubscriberCalls(stub func(livekit.ParticipantID)) { + fake.removeSubscriberMutex.Lock() + defer fake.removeSubscriberMutex.Unlock() + fake.RemoveSubscriberStub = stub +} + +func (fake *FakeDataTrack) RemoveSubscriberArgsForCall(i int) livekit.ParticipantID { + fake.removeSubscriberMutex.RLock() + defer fake.removeSubscriberMutex.RUnlock() + argsForCall := fake.removeSubscriberArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeDataTrack) ToProto() *livekit.DataTrackInfo { + fake.toProtoMutex.Lock() + ret, specificReturn := fake.toProtoReturnsOnCall[len(fake.toProtoArgsForCall)] + fake.toProtoArgsForCall = append(fake.toProtoArgsForCall, struct { + }{}) + stub := fake.ToProtoStub + fakeReturns := fake.toProtoReturns + fake.recordInvocation("ToProto", []interface{}{}) + fake.toProtoMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeDataTrack) ToProtoCallCount() int { + fake.toProtoMutex.RLock() + defer fake.toProtoMutex.RUnlock() + return len(fake.toProtoArgsForCall) +} + +func (fake *FakeDataTrack) ToProtoCalls(stub func() *livekit.DataTrackInfo) { + fake.toProtoMutex.Lock() + defer fake.toProtoMutex.Unlock() + fake.ToProtoStub = stub +} + +func (fake *FakeDataTrack) ToProtoReturns(result1 *livekit.DataTrackInfo) { + fake.toProtoMutex.Lock() + defer fake.toProtoMutex.Unlock() + fake.ToProtoStub = nil + fake.toProtoReturns = struct { + result1 *livekit.DataTrackInfo + }{result1} +} + +func (fake *FakeDataTrack) ToProtoReturnsOnCall(i int, result1 *livekit.DataTrackInfo) { + fake.toProtoMutex.Lock() + defer fake.toProtoMutex.Unlock() + fake.ToProtoStub = nil + if fake.toProtoReturnsOnCall == nil { + fake.toProtoReturnsOnCall = make(map[int]struct { + result1 *livekit.DataTrackInfo + }) + } + fake.toProtoReturnsOnCall[i] = struct { + result1 *livekit.DataTrackInfo + }{result1} +} + +func (fake *FakeDataTrack) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeDataTrack) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ types.DataTrack = new(FakeDataTrack) diff --git a/livekit/pkg/rtc/types/typesfakes/fake_data_track_sender.go b/livekit/pkg/rtc/types/typesfakes/fake_data_track_sender.go new file mode 100644 index 0000000..a9a6948 --- /dev/null +++ b/livekit/pkg/rtc/types/typesfakes/fake_data_track_sender.go @@ -0,0 +1,148 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package typesfakes + +import ( + "sync" + + "github.com/livekit/livekit-server/pkg/rtc/datatrack" + "github.com/livekit/livekit-server/pkg/rtc/types" + "github.com/livekit/protocol/livekit" +) + +type FakeDataTrackSender struct { + SubscriberIDStub func() livekit.ParticipantID + subscriberIDMutex sync.RWMutex + subscriberIDArgsForCall []struct { + } + subscriberIDReturns struct { + result1 livekit.ParticipantID + } + subscriberIDReturnsOnCall map[int]struct { + result1 livekit.ParticipantID + } + WritePacketStub func([]byte, *datatrack.Packet, int64) + writePacketMutex sync.RWMutex + writePacketArgsForCall []struct { + arg1 []byte + arg2 *datatrack.Packet + arg3 int64 + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeDataTrackSender) SubscriberID() livekit.ParticipantID { + fake.subscriberIDMutex.Lock() + ret, specificReturn := fake.subscriberIDReturnsOnCall[len(fake.subscriberIDArgsForCall)] + fake.subscriberIDArgsForCall = append(fake.subscriberIDArgsForCall, struct { + }{}) + stub := fake.SubscriberIDStub + fakeReturns := fake.subscriberIDReturns + fake.recordInvocation("SubscriberID", []interface{}{}) + fake.subscriberIDMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeDataTrackSender) SubscriberIDCallCount() int { + fake.subscriberIDMutex.RLock() + defer fake.subscriberIDMutex.RUnlock() + return len(fake.subscriberIDArgsForCall) +} + +func (fake *FakeDataTrackSender) SubscriberIDCalls(stub func() livekit.ParticipantID) { + fake.subscriberIDMutex.Lock() + defer fake.subscriberIDMutex.Unlock() + fake.SubscriberIDStub = stub +} + +func (fake *FakeDataTrackSender) SubscriberIDReturns(result1 livekit.ParticipantID) { + fake.subscriberIDMutex.Lock() + defer fake.subscriberIDMutex.Unlock() + fake.SubscriberIDStub = nil + fake.subscriberIDReturns = struct { + result1 livekit.ParticipantID + }{result1} +} + +func (fake *FakeDataTrackSender) SubscriberIDReturnsOnCall(i int, result1 livekit.ParticipantID) { + fake.subscriberIDMutex.Lock() + defer fake.subscriberIDMutex.Unlock() + fake.SubscriberIDStub = nil + if fake.subscriberIDReturnsOnCall == nil { + fake.subscriberIDReturnsOnCall = make(map[int]struct { + result1 livekit.ParticipantID + }) + } + fake.subscriberIDReturnsOnCall[i] = struct { + result1 livekit.ParticipantID + }{result1} +} + +func (fake *FakeDataTrackSender) WritePacket(arg1 []byte, arg2 *datatrack.Packet, arg3 int64) { + var arg1Copy []byte + if arg1 != nil { + arg1Copy = make([]byte, len(arg1)) + copy(arg1Copy, arg1) + } + fake.writePacketMutex.Lock() + fake.writePacketArgsForCall = append(fake.writePacketArgsForCall, struct { + arg1 []byte + arg2 *datatrack.Packet + arg3 int64 + }{arg1Copy, arg2, arg3}) + stub := fake.WritePacketStub + fake.recordInvocation("WritePacket", []interface{}{arg1Copy, arg2, arg3}) + fake.writePacketMutex.Unlock() + if stub != nil { + fake.WritePacketStub(arg1, arg2, arg3) + } +} + +func (fake *FakeDataTrackSender) WritePacketCallCount() int { + fake.writePacketMutex.RLock() + defer fake.writePacketMutex.RUnlock() + return len(fake.writePacketArgsForCall) +} + +func (fake *FakeDataTrackSender) WritePacketCalls(stub func([]byte, *datatrack.Packet, int64)) { + fake.writePacketMutex.Lock() + defer fake.writePacketMutex.Unlock() + fake.WritePacketStub = stub +} + +func (fake *FakeDataTrackSender) WritePacketArgsForCall(i int) ([]byte, *datatrack.Packet, int64) { + fake.writePacketMutex.RLock() + defer fake.writePacketMutex.RUnlock() + argsForCall := fake.writePacketArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeDataTrackSender) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeDataTrackSender) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ types.DataTrackSender = new(FakeDataTrackSender) diff --git a/livekit/pkg/rtc/types/typesfakes/fake_data_track_transport.go b/livekit/pkg/rtc/types/typesfakes/fake_data_track_transport.go new file mode 100644 index 0000000..d93c84f --- /dev/null +++ b/livekit/pkg/rtc/types/typesfakes/fake_data_track_transport.go @@ -0,0 +1,114 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package typesfakes + +import ( + "sync" + + "github.com/livekit/livekit-server/pkg/rtc/types" +) + +type FakeDataTrackTransport struct { + SendDataTrackMessageStub func([]byte) error + sendDataTrackMessageMutex sync.RWMutex + sendDataTrackMessageArgsForCall []struct { + arg1 []byte + } + sendDataTrackMessageReturns struct { + result1 error + } + sendDataTrackMessageReturnsOnCall map[int]struct { + result1 error + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeDataTrackTransport) SendDataTrackMessage(arg1 []byte) error { + var arg1Copy []byte + if arg1 != nil { + arg1Copy = make([]byte, len(arg1)) + copy(arg1Copy, arg1) + } + fake.sendDataTrackMessageMutex.Lock() + ret, specificReturn := fake.sendDataTrackMessageReturnsOnCall[len(fake.sendDataTrackMessageArgsForCall)] + fake.sendDataTrackMessageArgsForCall = append(fake.sendDataTrackMessageArgsForCall, struct { + arg1 []byte + }{arg1Copy}) + stub := fake.SendDataTrackMessageStub + fakeReturns := fake.sendDataTrackMessageReturns + fake.recordInvocation("SendDataTrackMessage", []interface{}{arg1Copy}) + fake.sendDataTrackMessageMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeDataTrackTransport) SendDataTrackMessageCallCount() int { + fake.sendDataTrackMessageMutex.RLock() + defer fake.sendDataTrackMessageMutex.RUnlock() + return len(fake.sendDataTrackMessageArgsForCall) +} + +func (fake *FakeDataTrackTransport) SendDataTrackMessageCalls(stub func([]byte) error) { + fake.sendDataTrackMessageMutex.Lock() + defer fake.sendDataTrackMessageMutex.Unlock() + fake.SendDataTrackMessageStub = stub +} + +func (fake *FakeDataTrackTransport) SendDataTrackMessageArgsForCall(i int) []byte { + fake.sendDataTrackMessageMutex.RLock() + defer fake.sendDataTrackMessageMutex.RUnlock() + argsForCall := fake.sendDataTrackMessageArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeDataTrackTransport) SendDataTrackMessageReturns(result1 error) { + fake.sendDataTrackMessageMutex.Lock() + defer fake.sendDataTrackMessageMutex.Unlock() + fake.SendDataTrackMessageStub = nil + fake.sendDataTrackMessageReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeDataTrackTransport) SendDataTrackMessageReturnsOnCall(i int, result1 error) { + fake.sendDataTrackMessageMutex.Lock() + defer fake.sendDataTrackMessageMutex.Unlock() + fake.SendDataTrackMessageStub = nil + if fake.sendDataTrackMessageReturnsOnCall == nil { + fake.sendDataTrackMessageReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.sendDataTrackMessageReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeDataTrackTransport) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeDataTrackTransport) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ types.DataTrackTransport = new(FakeDataTrackTransport) diff --git a/livekit/pkg/rtc/types/typesfakes/fake_local_media_track.go b/livekit/pkg/rtc/types/typesfakes/fake_local_media_track.go new file mode 100644 index 0000000..a327509 --- /dev/null +++ b/livekit/pkg/rtc/types/typesfakes/fake_local_media_track.go @@ -0,0 +1,2316 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package typesfakes + +import ( + "sync" + + "github.com/livekit/livekit-server/pkg/rtc/types" + "github.com/livekit/livekit-server/pkg/sfu" + "github.com/livekit/livekit-server/pkg/sfu/mime" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" +) + +type FakeLocalMediaTrack struct { + AddOnCloseStub func(func(isExpectedToResume bool)) + addOnCloseMutex sync.RWMutex + addOnCloseArgsForCall []struct { + arg1 func(isExpectedToResume bool) + } + AddSubscriberStub func(types.LocalParticipant) (types.SubscribedTrack, error) + addSubscriberMutex sync.RWMutex + addSubscriberArgsForCall []struct { + arg1 types.LocalParticipant + } + addSubscriberReturns struct { + result1 types.SubscribedTrack + result2 error + } + addSubscriberReturnsOnCall map[int]struct { + result1 types.SubscribedTrack + result2 error + } + ClearAllReceiversStub func(bool) + clearAllReceiversMutex sync.RWMutex + clearAllReceiversArgsForCall []struct { + arg1 bool + } + ClearSubscriberNodesStub func() + clearSubscriberNodesMutex sync.RWMutex + clearSubscriberNodesArgsForCall []struct { + } + CloseStub func(bool) + closeMutex sync.RWMutex + closeArgsForCall []struct { + arg1 bool + } + GetAllSubscribersStub func() []livekit.ParticipantID + getAllSubscribersMutex sync.RWMutex + getAllSubscribersArgsForCall []struct { + } + getAllSubscribersReturns struct { + result1 []livekit.ParticipantID + } + getAllSubscribersReturnsOnCall map[int]struct { + result1 []livekit.ParticipantID + } + GetAudioLevelStub func() (float64, bool) + getAudioLevelMutex sync.RWMutex + getAudioLevelArgsForCall []struct { + } + getAudioLevelReturns struct { + result1 float64 + result2 bool + } + getAudioLevelReturnsOnCall map[int]struct { + result1 float64 + result2 bool + } + GetConnectionScoreAndQualityStub func() (float32, livekit.ConnectionQuality) + getConnectionScoreAndQualityMutex sync.RWMutex + getConnectionScoreAndQualityArgsForCall []struct { + } + getConnectionScoreAndQualityReturns struct { + result1 float32 + result2 livekit.ConnectionQuality + } + getConnectionScoreAndQualityReturnsOnCall map[int]struct { + result1 float32 + result2 livekit.ConnectionQuality + } + GetNumSubscribersStub func() int + getNumSubscribersMutex sync.RWMutex + getNumSubscribersArgsForCall []struct { + } + getNumSubscribersReturns struct { + result1 int + } + getNumSubscribersReturnsOnCall map[int]struct { + result1 int + } + GetQualityForDimensionStub func(mime.MimeType, uint32, uint32) livekit.VideoQuality + getQualityForDimensionMutex sync.RWMutex + getQualityForDimensionArgsForCall []struct { + arg1 mime.MimeType + arg2 uint32 + arg3 uint32 + } + getQualityForDimensionReturns struct { + result1 livekit.VideoQuality + } + getQualityForDimensionReturnsOnCall map[int]struct { + result1 livekit.VideoQuality + } + GetTemporalLayerForSpatialFpsStub func(mime.MimeType, int32, uint32) int32 + getTemporalLayerForSpatialFpsMutex sync.RWMutex + getTemporalLayerForSpatialFpsArgsForCall []struct { + arg1 mime.MimeType + arg2 int32 + arg3 uint32 + } + getTemporalLayerForSpatialFpsReturns struct { + result1 int32 + } + getTemporalLayerForSpatialFpsReturnsOnCall map[int]struct { + result1 int32 + } + GetTrackStatsStub func() *livekit.RTPStats + getTrackStatsMutex sync.RWMutex + getTrackStatsArgsForCall []struct { + } + getTrackStatsReturns struct { + result1 *livekit.RTPStats + } + getTrackStatsReturnsOnCall map[int]struct { + result1 *livekit.RTPStats + } + HasSdpCidStub func(string) bool + hasSdpCidMutex sync.RWMutex + hasSdpCidArgsForCall []struct { + arg1 string + } + hasSdpCidReturns struct { + result1 bool + } + hasSdpCidReturnsOnCall map[int]struct { + result1 bool + } + HasSignalCidStub func(string) bool + hasSignalCidMutex sync.RWMutex + hasSignalCidArgsForCall []struct { + arg1 string + } + hasSignalCidReturns struct { + result1 bool + } + hasSignalCidReturnsOnCall map[int]struct { + result1 bool + } + IDStub func() livekit.TrackID + iDMutex sync.RWMutex + iDArgsForCall []struct { + } + iDReturns struct { + result1 livekit.TrackID + } + iDReturnsOnCall map[int]struct { + result1 livekit.TrackID + } + IsEncryptedStub func() bool + isEncryptedMutex sync.RWMutex + isEncryptedArgsForCall []struct { + } + isEncryptedReturns struct { + result1 bool + } + isEncryptedReturnsOnCall map[int]struct { + result1 bool + } + IsMutedStub func() bool + isMutedMutex sync.RWMutex + isMutedArgsForCall []struct { + } + isMutedReturns struct { + result1 bool + } + isMutedReturnsOnCall map[int]struct { + result1 bool + } + IsOpenStub func() bool + isOpenMutex sync.RWMutex + isOpenArgsForCall []struct { + } + isOpenReturns struct { + result1 bool + } + isOpenReturnsOnCall map[int]struct { + result1 bool + } + IsSubscriberStub func(livekit.ParticipantID) bool + isSubscriberMutex sync.RWMutex + isSubscriberArgsForCall []struct { + arg1 livekit.ParticipantID + } + isSubscriberReturns struct { + result1 bool + } + isSubscriberReturnsOnCall map[int]struct { + result1 bool + } + KindStub func() livekit.TrackType + kindMutex sync.RWMutex + kindArgsForCall []struct { + } + kindReturns struct { + result1 livekit.TrackType + } + kindReturnsOnCall map[int]struct { + result1 livekit.TrackType + } + LoggerStub func() logger.Logger + loggerMutex sync.RWMutex + loggerArgsForCall []struct { + } + loggerReturns struct { + result1 logger.Logger + } + loggerReturnsOnCall map[int]struct { + result1 logger.Logger + } + NameStub func() string + nameMutex sync.RWMutex + nameArgsForCall []struct { + } + nameReturns struct { + result1 string + } + nameReturnsOnCall map[int]struct { + result1 string + } + NotifySubscriberNodeMaxQualityStub func(livekit.NodeID, []types.SubscribedCodecQuality) + notifySubscriberNodeMaxQualityMutex sync.RWMutex + notifySubscriberNodeMaxQualityArgsForCall []struct { + arg1 livekit.NodeID + arg2 []types.SubscribedCodecQuality + } + NotifySubscriberNodeMediaLossStub func(livekit.NodeID, uint8) + notifySubscriberNodeMediaLossMutex sync.RWMutex + notifySubscriberNodeMediaLossArgsForCall []struct { + arg1 livekit.NodeID + arg2 uint8 + } + NotifySubscriptionNodeStub func(livekit.NodeID, []*livekit.SubscribedAudioCodec) + notifySubscriptionNodeMutex sync.RWMutex + notifySubscriptionNodeArgsForCall []struct { + arg1 livekit.NodeID + arg2 []*livekit.SubscribedAudioCodec + } + OnTrackSubscribedStub func() + onTrackSubscribedMutex sync.RWMutex + onTrackSubscribedArgsForCall []struct { + } + PublisherIDStub func() livekit.ParticipantID + publisherIDMutex sync.RWMutex + publisherIDArgsForCall []struct { + } + publisherIDReturns struct { + result1 livekit.ParticipantID + } + publisherIDReturnsOnCall map[int]struct { + result1 livekit.ParticipantID + } + PublisherIdentityStub func() livekit.ParticipantIdentity + publisherIdentityMutex sync.RWMutex + publisherIdentityArgsForCall []struct { + } + publisherIdentityReturns struct { + result1 livekit.ParticipantIdentity + } + publisherIdentityReturnsOnCall map[int]struct { + result1 livekit.ParticipantIdentity + } + PublisherVersionStub func() uint32 + publisherVersionMutex sync.RWMutex + publisherVersionArgsForCall []struct { + } + publisherVersionReturns struct { + result1 uint32 + } + publisherVersionReturnsOnCall map[int]struct { + result1 uint32 + } + ReceiversStub func() []sfu.TrackReceiver + receiversMutex sync.RWMutex + receiversArgsForCall []struct { + } + receiversReturns struct { + result1 []sfu.TrackReceiver + } + receiversReturnsOnCall map[int]struct { + result1 []sfu.TrackReceiver + } + RemoveSubscriberStub func(livekit.ParticipantID, bool) + removeSubscriberMutex sync.RWMutex + removeSubscriberArgsForCall []struct { + arg1 livekit.ParticipantID + arg2 bool + } + RestartStub func() + restartMutex sync.RWMutex + restartArgsForCall []struct { + } + RevokeDisallowedSubscribersStub func([]livekit.ParticipantIdentity) []livekit.ParticipantIdentity + revokeDisallowedSubscribersMutex sync.RWMutex + revokeDisallowedSubscribersArgsForCall []struct { + arg1 []livekit.ParticipantIdentity + } + revokeDisallowedSubscribersReturns struct { + result1 []livekit.ParticipantIdentity + } + revokeDisallowedSubscribersReturnsOnCall map[int]struct { + result1 []livekit.ParticipantIdentity + } + SetMutedStub func(bool) + setMutedMutex sync.RWMutex + setMutedArgsForCall []struct { + arg1 bool + } + SetRTTStub func(uint32) + setRTTMutex sync.RWMutex + setRTTArgsForCall []struct { + arg1 uint32 + } + SourceStub func() livekit.TrackSource + sourceMutex sync.RWMutex + sourceArgsForCall []struct { + } + sourceReturns struct { + result1 livekit.TrackSource + } + sourceReturnsOnCall map[int]struct { + result1 livekit.TrackSource + } + StreamStub func() string + streamMutex sync.RWMutex + streamArgsForCall []struct { + } + streamReturns struct { + result1 string + } + streamReturnsOnCall map[int]struct { + result1 string + } + ToProtoStub func() *livekit.TrackInfo + toProtoMutex sync.RWMutex + toProtoArgsForCall []struct { + } + toProtoReturns struct { + result1 *livekit.TrackInfo + } + toProtoReturnsOnCall map[int]struct { + result1 *livekit.TrackInfo + } + UpdateAudioTrackStub func(*livekit.UpdateLocalAudioTrack) + updateAudioTrackMutex sync.RWMutex + updateAudioTrackArgsForCall []struct { + arg1 *livekit.UpdateLocalAudioTrack + } + UpdateTrackInfoStub func(*livekit.TrackInfo) + updateTrackInfoMutex sync.RWMutex + updateTrackInfoArgsForCall []struct { + arg1 *livekit.TrackInfo + } + UpdateVideoTrackStub func(*livekit.UpdateLocalVideoTrack) + updateVideoTrackMutex sync.RWMutex + updateVideoTrackArgsForCall []struct { + arg1 *livekit.UpdateLocalVideoTrack + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeLocalMediaTrack) AddOnClose(arg1 func(isExpectedToResume bool)) { + fake.addOnCloseMutex.Lock() + fake.addOnCloseArgsForCall = append(fake.addOnCloseArgsForCall, struct { + arg1 func(isExpectedToResume bool) + }{arg1}) + stub := fake.AddOnCloseStub + fake.recordInvocation("AddOnClose", []interface{}{arg1}) + fake.addOnCloseMutex.Unlock() + if stub != nil { + fake.AddOnCloseStub(arg1) + } +} + +func (fake *FakeLocalMediaTrack) AddOnCloseCallCount() int { + fake.addOnCloseMutex.RLock() + defer fake.addOnCloseMutex.RUnlock() + return len(fake.addOnCloseArgsForCall) +} + +func (fake *FakeLocalMediaTrack) AddOnCloseCalls(stub func(func(isExpectedToResume bool))) { + fake.addOnCloseMutex.Lock() + defer fake.addOnCloseMutex.Unlock() + fake.AddOnCloseStub = stub +} + +func (fake *FakeLocalMediaTrack) AddOnCloseArgsForCall(i int) func(isExpectedToResume bool) { + fake.addOnCloseMutex.RLock() + defer fake.addOnCloseMutex.RUnlock() + argsForCall := fake.addOnCloseArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalMediaTrack) AddSubscriber(arg1 types.LocalParticipant) (types.SubscribedTrack, error) { + fake.addSubscriberMutex.Lock() + ret, specificReturn := fake.addSubscriberReturnsOnCall[len(fake.addSubscriberArgsForCall)] + fake.addSubscriberArgsForCall = append(fake.addSubscriberArgsForCall, struct { + arg1 types.LocalParticipant + }{arg1}) + stub := fake.AddSubscriberStub + fakeReturns := fake.addSubscriberReturns + fake.recordInvocation("AddSubscriber", []interface{}{arg1}) + fake.addSubscriberMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeLocalMediaTrack) AddSubscriberCallCount() int { + fake.addSubscriberMutex.RLock() + defer fake.addSubscriberMutex.RUnlock() + return len(fake.addSubscriberArgsForCall) +} + +func (fake *FakeLocalMediaTrack) AddSubscriberCalls(stub func(types.LocalParticipant) (types.SubscribedTrack, error)) { + fake.addSubscriberMutex.Lock() + defer fake.addSubscriberMutex.Unlock() + fake.AddSubscriberStub = stub +} + +func (fake *FakeLocalMediaTrack) AddSubscriberArgsForCall(i int) types.LocalParticipant { + fake.addSubscriberMutex.RLock() + defer fake.addSubscriberMutex.RUnlock() + argsForCall := fake.addSubscriberArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalMediaTrack) AddSubscriberReturns(result1 types.SubscribedTrack, result2 error) { + fake.addSubscriberMutex.Lock() + defer fake.addSubscriberMutex.Unlock() + fake.AddSubscriberStub = nil + fake.addSubscriberReturns = struct { + result1 types.SubscribedTrack + result2 error + }{result1, result2} +} + +func (fake *FakeLocalMediaTrack) AddSubscriberReturnsOnCall(i int, result1 types.SubscribedTrack, result2 error) { + fake.addSubscriberMutex.Lock() + defer fake.addSubscriberMutex.Unlock() + fake.AddSubscriberStub = nil + if fake.addSubscriberReturnsOnCall == nil { + fake.addSubscriberReturnsOnCall = make(map[int]struct { + result1 types.SubscribedTrack + result2 error + }) + } + fake.addSubscriberReturnsOnCall[i] = struct { + result1 types.SubscribedTrack + result2 error + }{result1, result2} +} + +func (fake *FakeLocalMediaTrack) ClearAllReceivers(arg1 bool) { + fake.clearAllReceiversMutex.Lock() + fake.clearAllReceiversArgsForCall = append(fake.clearAllReceiversArgsForCall, struct { + arg1 bool + }{arg1}) + stub := fake.ClearAllReceiversStub + fake.recordInvocation("ClearAllReceivers", []interface{}{arg1}) + fake.clearAllReceiversMutex.Unlock() + if stub != nil { + fake.ClearAllReceiversStub(arg1) + } +} + +func (fake *FakeLocalMediaTrack) ClearAllReceiversCallCount() int { + fake.clearAllReceiversMutex.RLock() + defer fake.clearAllReceiversMutex.RUnlock() + return len(fake.clearAllReceiversArgsForCall) +} + +func (fake *FakeLocalMediaTrack) ClearAllReceiversCalls(stub func(bool)) { + fake.clearAllReceiversMutex.Lock() + defer fake.clearAllReceiversMutex.Unlock() + fake.ClearAllReceiversStub = stub +} + +func (fake *FakeLocalMediaTrack) ClearAllReceiversArgsForCall(i int) bool { + fake.clearAllReceiversMutex.RLock() + defer fake.clearAllReceiversMutex.RUnlock() + argsForCall := fake.clearAllReceiversArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalMediaTrack) ClearSubscriberNodes() { + fake.clearSubscriberNodesMutex.Lock() + fake.clearSubscriberNodesArgsForCall = append(fake.clearSubscriberNodesArgsForCall, struct { + }{}) + stub := fake.ClearSubscriberNodesStub + fake.recordInvocation("ClearSubscriberNodes", []interface{}{}) + fake.clearSubscriberNodesMutex.Unlock() + if stub != nil { + fake.ClearSubscriberNodesStub() + } +} + +func (fake *FakeLocalMediaTrack) ClearSubscriberNodesCallCount() int { + fake.clearSubscriberNodesMutex.RLock() + defer fake.clearSubscriberNodesMutex.RUnlock() + return len(fake.clearSubscriberNodesArgsForCall) +} + +func (fake *FakeLocalMediaTrack) ClearSubscriberNodesCalls(stub func()) { + fake.clearSubscriberNodesMutex.Lock() + defer fake.clearSubscriberNodesMutex.Unlock() + fake.ClearSubscriberNodesStub = stub +} + +func (fake *FakeLocalMediaTrack) Close(arg1 bool) { + fake.closeMutex.Lock() + fake.closeArgsForCall = append(fake.closeArgsForCall, struct { + arg1 bool + }{arg1}) + stub := fake.CloseStub + fake.recordInvocation("Close", []interface{}{arg1}) + fake.closeMutex.Unlock() + if stub != nil { + fake.CloseStub(arg1) + } +} + +func (fake *FakeLocalMediaTrack) CloseCallCount() int { + fake.closeMutex.RLock() + defer fake.closeMutex.RUnlock() + return len(fake.closeArgsForCall) +} + +func (fake *FakeLocalMediaTrack) CloseCalls(stub func(bool)) { + fake.closeMutex.Lock() + defer fake.closeMutex.Unlock() + fake.CloseStub = stub +} + +func (fake *FakeLocalMediaTrack) CloseArgsForCall(i int) bool { + fake.closeMutex.RLock() + defer fake.closeMutex.RUnlock() + argsForCall := fake.closeArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalMediaTrack) GetAllSubscribers() []livekit.ParticipantID { + fake.getAllSubscribersMutex.Lock() + ret, specificReturn := fake.getAllSubscribersReturnsOnCall[len(fake.getAllSubscribersArgsForCall)] + fake.getAllSubscribersArgsForCall = append(fake.getAllSubscribersArgsForCall, struct { + }{}) + stub := fake.GetAllSubscribersStub + fakeReturns := fake.getAllSubscribersReturns + fake.recordInvocation("GetAllSubscribers", []interface{}{}) + fake.getAllSubscribersMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalMediaTrack) GetAllSubscribersCallCount() int { + fake.getAllSubscribersMutex.RLock() + defer fake.getAllSubscribersMutex.RUnlock() + return len(fake.getAllSubscribersArgsForCall) +} + +func (fake *FakeLocalMediaTrack) GetAllSubscribersCalls(stub func() []livekit.ParticipantID) { + fake.getAllSubscribersMutex.Lock() + defer fake.getAllSubscribersMutex.Unlock() + fake.GetAllSubscribersStub = stub +} + +func (fake *FakeLocalMediaTrack) GetAllSubscribersReturns(result1 []livekit.ParticipantID) { + fake.getAllSubscribersMutex.Lock() + defer fake.getAllSubscribersMutex.Unlock() + fake.GetAllSubscribersStub = nil + fake.getAllSubscribersReturns = struct { + result1 []livekit.ParticipantID + }{result1} +} + +func (fake *FakeLocalMediaTrack) GetAllSubscribersReturnsOnCall(i int, result1 []livekit.ParticipantID) { + fake.getAllSubscribersMutex.Lock() + defer fake.getAllSubscribersMutex.Unlock() + fake.GetAllSubscribersStub = nil + if fake.getAllSubscribersReturnsOnCall == nil { + fake.getAllSubscribersReturnsOnCall = make(map[int]struct { + result1 []livekit.ParticipantID + }) + } + fake.getAllSubscribersReturnsOnCall[i] = struct { + result1 []livekit.ParticipantID + }{result1} +} + +func (fake *FakeLocalMediaTrack) GetAudioLevel() (float64, bool) { + fake.getAudioLevelMutex.Lock() + ret, specificReturn := fake.getAudioLevelReturnsOnCall[len(fake.getAudioLevelArgsForCall)] + fake.getAudioLevelArgsForCall = append(fake.getAudioLevelArgsForCall, struct { + }{}) + stub := fake.GetAudioLevelStub + fakeReturns := fake.getAudioLevelReturns + fake.recordInvocation("GetAudioLevel", []interface{}{}) + fake.getAudioLevelMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeLocalMediaTrack) GetAudioLevelCallCount() int { + fake.getAudioLevelMutex.RLock() + defer fake.getAudioLevelMutex.RUnlock() + return len(fake.getAudioLevelArgsForCall) +} + +func (fake *FakeLocalMediaTrack) GetAudioLevelCalls(stub func() (float64, bool)) { + fake.getAudioLevelMutex.Lock() + defer fake.getAudioLevelMutex.Unlock() + fake.GetAudioLevelStub = stub +} + +func (fake *FakeLocalMediaTrack) GetAudioLevelReturns(result1 float64, result2 bool) { + fake.getAudioLevelMutex.Lock() + defer fake.getAudioLevelMutex.Unlock() + fake.GetAudioLevelStub = nil + fake.getAudioLevelReturns = struct { + result1 float64 + result2 bool + }{result1, result2} +} + +func (fake *FakeLocalMediaTrack) GetAudioLevelReturnsOnCall(i int, result1 float64, result2 bool) { + fake.getAudioLevelMutex.Lock() + defer fake.getAudioLevelMutex.Unlock() + fake.GetAudioLevelStub = nil + if fake.getAudioLevelReturnsOnCall == nil { + fake.getAudioLevelReturnsOnCall = make(map[int]struct { + result1 float64 + result2 bool + }) + } + fake.getAudioLevelReturnsOnCall[i] = struct { + result1 float64 + result2 bool + }{result1, result2} +} + +func (fake *FakeLocalMediaTrack) GetConnectionScoreAndQuality() (float32, livekit.ConnectionQuality) { + fake.getConnectionScoreAndQualityMutex.Lock() + ret, specificReturn := fake.getConnectionScoreAndQualityReturnsOnCall[len(fake.getConnectionScoreAndQualityArgsForCall)] + fake.getConnectionScoreAndQualityArgsForCall = append(fake.getConnectionScoreAndQualityArgsForCall, struct { + }{}) + stub := fake.GetConnectionScoreAndQualityStub + fakeReturns := fake.getConnectionScoreAndQualityReturns + fake.recordInvocation("GetConnectionScoreAndQuality", []interface{}{}) + fake.getConnectionScoreAndQualityMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeLocalMediaTrack) GetConnectionScoreAndQualityCallCount() int { + fake.getConnectionScoreAndQualityMutex.RLock() + defer fake.getConnectionScoreAndQualityMutex.RUnlock() + return len(fake.getConnectionScoreAndQualityArgsForCall) +} + +func (fake *FakeLocalMediaTrack) GetConnectionScoreAndQualityCalls(stub func() (float32, livekit.ConnectionQuality)) { + fake.getConnectionScoreAndQualityMutex.Lock() + defer fake.getConnectionScoreAndQualityMutex.Unlock() + fake.GetConnectionScoreAndQualityStub = stub +} + +func (fake *FakeLocalMediaTrack) GetConnectionScoreAndQualityReturns(result1 float32, result2 livekit.ConnectionQuality) { + fake.getConnectionScoreAndQualityMutex.Lock() + defer fake.getConnectionScoreAndQualityMutex.Unlock() + fake.GetConnectionScoreAndQualityStub = nil + fake.getConnectionScoreAndQualityReturns = struct { + result1 float32 + result2 livekit.ConnectionQuality + }{result1, result2} +} + +func (fake *FakeLocalMediaTrack) GetConnectionScoreAndQualityReturnsOnCall(i int, result1 float32, result2 livekit.ConnectionQuality) { + fake.getConnectionScoreAndQualityMutex.Lock() + defer fake.getConnectionScoreAndQualityMutex.Unlock() + fake.GetConnectionScoreAndQualityStub = nil + if fake.getConnectionScoreAndQualityReturnsOnCall == nil { + fake.getConnectionScoreAndQualityReturnsOnCall = make(map[int]struct { + result1 float32 + result2 livekit.ConnectionQuality + }) + } + fake.getConnectionScoreAndQualityReturnsOnCall[i] = struct { + result1 float32 + result2 livekit.ConnectionQuality + }{result1, result2} +} + +func (fake *FakeLocalMediaTrack) GetNumSubscribers() int { + fake.getNumSubscribersMutex.Lock() + ret, specificReturn := fake.getNumSubscribersReturnsOnCall[len(fake.getNumSubscribersArgsForCall)] + fake.getNumSubscribersArgsForCall = append(fake.getNumSubscribersArgsForCall, struct { + }{}) + stub := fake.GetNumSubscribersStub + fakeReturns := fake.getNumSubscribersReturns + fake.recordInvocation("GetNumSubscribers", []interface{}{}) + fake.getNumSubscribersMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalMediaTrack) GetNumSubscribersCallCount() int { + fake.getNumSubscribersMutex.RLock() + defer fake.getNumSubscribersMutex.RUnlock() + return len(fake.getNumSubscribersArgsForCall) +} + +func (fake *FakeLocalMediaTrack) GetNumSubscribersCalls(stub func() int) { + fake.getNumSubscribersMutex.Lock() + defer fake.getNumSubscribersMutex.Unlock() + fake.GetNumSubscribersStub = stub +} + +func (fake *FakeLocalMediaTrack) GetNumSubscribersReturns(result1 int) { + fake.getNumSubscribersMutex.Lock() + defer fake.getNumSubscribersMutex.Unlock() + fake.GetNumSubscribersStub = nil + fake.getNumSubscribersReturns = struct { + result1 int + }{result1} +} + +func (fake *FakeLocalMediaTrack) GetNumSubscribersReturnsOnCall(i int, result1 int) { + fake.getNumSubscribersMutex.Lock() + defer fake.getNumSubscribersMutex.Unlock() + fake.GetNumSubscribersStub = nil + if fake.getNumSubscribersReturnsOnCall == nil { + fake.getNumSubscribersReturnsOnCall = make(map[int]struct { + result1 int + }) + } + fake.getNumSubscribersReturnsOnCall[i] = struct { + result1 int + }{result1} +} + +func (fake *FakeLocalMediaTrack) GetQualityForDimension(arg1 mime.MimeType, arg2 uint32, arg3 uint32) livekit.VideoQuality { + fake.getQualityForDimensionMutex.Lock() + ret, specificReturn := fake.getQualityForDimensionReturnsOnCall[len(fake.getQualityForDimensionArgsForCall)] + fake.getQualityForDimensionArgsForCall = append(fake.getQualityForDimensionArgsForCall, struct { + arg1 mime.MimeType + arg2 uint32 + arg3 uint32 + }{arg1, arg2, arg3}) + stub := fake.GetQualityForDimensionStub + fakeReturns := fake.getQualityForDimensionReturns + fake.recordInvocation("GetQualityForDimension", []interface{}{arg1, arg2, arg3}) + fake.getQualityForDimensionMutex.Unlock() + if stub != nil { + return stub(arg1, arg2, arg3) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalMediaTrack) GetQualityForDimensionCallCount() int { + fake.getQualityForDimensionMutex.RLock() + defer fake.getQualityForDimensionMutex.RUnlock() + return len(fake.getQualityForDimensionArgsForCall) +} + +func (fake *FakeLocalMediaTrack) GetQualityForDimensionCalls(stub func(mime.MimeType, uint32, uint32) livekit.VideoQuality) { + fake.getQualityForDimensionMutex.Lock() + defer fake.getQualityForDimensionMutex.Unlock() + fake.GetQualityForDimensionStub = stub +} + +func (fake *FakeLocalMediaTrack) GetQualityForDimensionArgsForCall(i int) (mime.MimeType, uint32, uint32) { + fake.getQualityForDimensionMutex.RLock() + defer fake.getQualityForDimensionMutex.RUnlock() + argsForCall := fake.getQualityForDimensionArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeLocalMediaTrack) GetQualityForDimensionReturns(result1 livekit.VideoQuality) { + fake.getQualityForDimensionMutex.Lock() + defer fake.getQualityForDimensionMutex.Unlock() + fake.GetQualityForDimensionStub = nil + fake.getQualityForDimensionReturns = struct { + result1 livekit.VideoQuality + }{result1} +} + +func (fake *FakeLocalMediaTrack) GetQualityForDimensionReturnsOnCall(i int, result1 livekit.VideoQuality) { + fake.getQualityForDimensionMutex.Lock() + defer fake.getQualityForDimensionMutex.Unlock() + fake.GetQualityForDimensionStub = nil + if fake.getQualityForDimensionReturnsOnCall == nil { + fake.getQualityForDimensionReturnsOnCall = make(map[int]struct { + result1 livekit.VideoQuality + }) + } + fake.getQualityForDimensionReturnsOnCall[i] = struct { + result1 livekit.VideoQuality + }{result1} +} + +func (fake *FakeLocalMediaTrack) GetTemporalLayerForSpatialFps(arg1 mime.MimeType, arg2 int32, arg3 uint32) int32 { + fake.getTemporalLayerForSpatialFpsMutex.Lock() + ret, specificReturn := fake.getTemporalLayerForSpatialFpsReturnsOnCall[len(fake.getTemporalLayerForSpatialFpsArgsForCall)] + fake.getTemporalLayerForSpatialFpsArgsForCall = append(fake.getTemporalLayerForSpatialFpsArgsForCall, struct { + arg1 mime.MimeType + arg2 int32 + arg3 uint32 + }{arg1, arg2, arg3}) + stub := fake.GetTemporalLayerForSpatialFpsStub + fakeReturns := fake.getTemporalLayerForSpatialFpsReturns + fake.recordInvocation("GetTemporalLayerForSpatialFps", []interface{}{arg1, arg2, arg3}) + fake.getTemporalLayerForSpatialFpsMutex.Unlock() + if stub != nil { + return stub(arg1, arg2, arg3) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalMediaTrack) GetTemporalLayerForSpatialFpsCallCount() int { + fake.getTemporalLayerForSpatialFpsMutex.RLock() + defer fake.getTemporalLayerForSpatialFpsMutex.RUnlock() + return len(fake.getTemporalLayerForSpatialFpsArgsForCall) +} + +func (fake *FakeLocalMediaTrack) GetTemporalLayerForSpatialFpsCalls(stub func(mime.MimeType, int32, uint32) int32) { + fake.getTemporalLayerForSpatialFpsMutex.Lock() + defer fake.getTemporalLayerForSpatialFpsMutex.Unlock() + fake.GetTemporalLayerForSpatialFpsStub = stub +} + +func (fake *FakeLocalMediaTrack) GetTemporalLayerForSpatialFpsArgsForCall(i int) (mime.MimeType, int32, uint32) { + fake.getTemporalLayerForSpatialFpsMutex.RLock() + defer fake.getTemporalLayerForSpatialFpsMutex.RUnlock() + argsForCall := fake.getTemporalLayerForSpatialFpsArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeLocalMediaTrack) GetTemporalLayerForSpatialFpsReturns(result1 int32) { + fake.getTemporalLayerForSpatialFpsMutex.Lock() + defer fake.getTemporalLayerForSpatialFpsMutex.Unlock() + fake.GetTemporalLayerForSpatialFpsStub = nil + fake.getTemporalLayerForSpatialFpsReturns = struct { + result1 int32 + }{result1} +} + +func (fake *FakeLocalMediaTrack) GetTemporalLayerForSpatialFpsReturnsOnCall(i int, result1 int32) { + fake.getTemporalLayerForSpatialFpsMutex.Lock() + defer fake.getTemporalLayerForSpatialFpsMutex.Unlock() + fake.GetTemporalLayerForSpatialFpsStub = nil + if fake.getTemporalLayerForSpatialFpsReturnsOnCall == nil { + fake.getTemporalLayerForSpatialFpsReturnsOnCall = make(map[int]struct { + result1 int32 + }) + } + fake.getTemporalLayerForSpatialFpsReturnsOnCall[i] = struct { + result1 int32 + }{result1} +} + +func (fake *FakeLocalMediaTrack) GetTrackStats() *livekit.RTPStats { + fake.getTrackStatsMutex.Lock() + ret, specificReturn := fake.getTrackStatsReturnsOnCall[len(fake.getTrackStatsArgsForCall)] + fake.getTrackStatsArgsForCall = append(fake.getTrackStatsArgsForCall, struct { + }{}) + stub := fake.GetTrackStatsStub + fakeReturns := fake.getTrackStatsReturns + fake.recordInvocation("GetTrackStats", []interface{}{}) + fake.getTrackStatsMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalMediaTrack) GetTrackStatsCallCount() int { + fake.getTrackStatsMutex.RLock() + defer fake.getTrackStatsMutex.RUnlock() + return len(fake.getTrackStatsArgsForCall) +} + +func (fake *FakeLocalMediaTrack) GetTrackStatsCalls(stub func() *livekit.RTPStats) { + fake.getTrackStatsMutex.Lock() + defer fake.getTrackStatsMutex.Unlock() + fake.GetTrackStatsStub = stub +} + +func (fake *FakeLocalMediaTrack) GetTrackStatsReturns(result1 *livekit.RTPStats) { + fake.getTrackStatsMutex.Lock() + defer fake.getTrackStatsMutex.Unlock() + fake.GetTrackStatsStub = nil + fake.getTrackStatsReturns = struct { + result1 *livekit.RTPStats + }{result1} +} + +func (fake *FakeLocalMediaTrack) GetTrackStatsReturnsOnCall(i int, result1 *livekit.RTPStats) { + fake.getTrackStatsMutex.Lock() + defer fake.getTrackStatsMutex.Unlock() + fake.GetTrackStatsStub = nil + if fake.getTrackStatsReturnsOnCall == nil { + fake.getTrackStatsReturnsOnCall = make(map[int]struct { + result1 *livekit.RTPStats + }) + } + fake.getTrackStatsReturnsOnCall[i] = struct { + result1 *livekit.RTPStats + }{result1} +} + +func (fake *FakeLocalMediaTrack) HasSdpCid(arg1 string) bool { + fake.hasSdpCidMutex.Lock() + ret, specificReturn := fake.hasSdpCidReturnsOnCall[len(fake.hasSdpCidArgsForCall)] + fake.hasSdpCidArgsForCall = append(fake.hasSdpCidArgsForCall, struct { + arg1 string + }{arg1}) + stub := fake.HasSdpCidStub + fakeReturns := fake.hasSdpCidReturns + fake.recordInvocation("HasSdpCid", []interface{}{arg1}) + fake.hasSdpCidMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalMediaTrack) HasSdpCidCallCount() int { + fake.hasSdpCidMutex.RLock() + defer fake.hasSdpCidMutex.RUnlock() + return len(fake.hasSdpCidArgsForCall) +} + +func (fake *FakeLocalMediaTrack) HasSdpCidCalls(stub func(string) bool) { + fake.hasSdpCidMutex.Lock() + defer fake.hasSdpCidMutex.Unlock() + fake.HasSdpCidStub = stub +} + +func (fake *FakeLocalMediaTrack) HasSdpCidArgsForCall(i int) string { + fake.hasSdpCidMutex.RLock() + defer fake.hasSdpCidMutex.RUnlock() + argsForCall := fake.hasSdpCidArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalMediaTrack) HasSdpCidReturns(result1 bool) { + fake.hasSdpCidMutex.Lock() + defer fake.hasSdpCidMutex.Unlock() + fake.HasSdpCidStub = nil + fake.hasSdpCidReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalMediaTrack) HasSdpCidReturnsOnCall(i int, result1 bool) { + fake.hasSdpCidMutex.Lock() + defer fake.hasSdpCidMutex.Unlock() + fake.HasSdpCidStub = nil + if fake.hasSdpCidReturnsOnCall == nil { + fake.hasSdpCidReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.hasSdpCidReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalMediaTrack) HasSignalCid(arg1 string) bool { + fake.hasSignalCidMutex.Lock() + ret, specificReturn := fake.hasSignalCidReturnsOnCall[len(fake.hasSignalCidArgsForCall)] + fake.hasSignalCidArgsForCall = append(fake.hasSignalCidArgsForCall, struct { + arg1 string + }{arg1}) + stub := fake.HasSignalCidStub + fakeReturns := fake.hasSignalCidReturns + fake.recordInvocation("HasSignalCid", []interface{}{arg1}) + fake.hasSignalCidMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalMediaTrack) HasSignalCidCallCount() int { + fake.hasSignalCidMutex.RLock() + defer fake.hasSignalCidMutex.RUnlock() + return len(fake.hasSignalCidArgsForCall) +} + +func (fake *FakeLocalMediaTrack) HasSignalCidCalls(stub func(string) bool) { + fake.hasSignalCidMutex.Lock() + defer fake.hasSignalCidMutex.Unlock() + fake.HasSignalCidStub = stub +} + +func (fake *FakeLocalMediaTrack) HasSignalCidArgsForCall(i int) string { + fake.hasSignalCidMutex.RLock() + defer fake.hasSignalCidMutex.RUnlock() + argsForCall := fake.hasSignalCidArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalMediaTrack) HasSignalCidReturns(result1 bool) { + fake.hasSignalCidMutex.Lock() + defer fake.hasSignalCidMutex.Unlock() + fake.HasSignalCidStub = nil + fake.hasSignalCidReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalMediaTrack) HasSignalCidReturnsOnCall(i int, result1 bool) { + fake.hasSignalCidMutex.Lock() + defer fake.hasSignalCidMutex.Unlock() + fake.HasSignalCidStub = nil + if fake.hasSignalCidReturnsOnCall == nil { + fake.hasSignalCidReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.hasSignalCidReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalMediaTrack) ID() livekit.TrackID { + fake.iDMutex.Lock() + ret, specificReturn := fake.iDReturnsOnCall[len(fake.iDArgsForCall)] + fake.iDArgsForCall = append(fake.iDArgsForCall, struct { + }{}) + stub := fake.IDStub + fakeReturns := fake.iDReturns + fake.recordInvocation("ID", []interface{}{}) + fake.iDMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalMediaTrack) IDCallCount() int { + fake.iDMutex.RLock() + defer fake.iDMutex.RUnlock() + return len(fake.iDArgsForCall) +} + +func (fake *FakeLocalMediaTrack) IDCalls(stub func() livekit.TrackID) { + fake.iDMutex.Lock() + defer fake.iDMutex.Unlock() + fake.IDStub = stub +} + +func (fake *FakeLocalMediaTrack) IDReturns(result1 livekit.TrackID) { + fake.iDMutex.Lock() + defer fake.iDMutex.Unlock() + fake.IDStub = nil + fake.iDReturns = struct { + result1 livekit.TrackID + }{result1} +} + +func (fake *FakeLocalMediaTrack) IDReturnsOnCall(i int, result1 livekit.TrackID) { + fake.iDMutex.Lock() + defer fake.iDMutex.Unlock() + fake.IDStub = nil + if fake.iDReturnsOnCall == nil { + fake.iDReturnsOnCall = make(map[int]struct { + result1 livekit.TrackID + }) + } + fake.iDReturnsOnCall[i] = struct { + result1 livekit.TrackID + }{result1} +} + +func (fake *FakeLocalMediaTrack) IsEncrypted() bool { + fake.isEncryptedMutex.Lock() + ret, specificReturn := fake.isEncryptedReturnsOnCall[len(fake.isEncryptedArgsForCall)] + fake.isEncryptedArgsForCall = append(fake.isEncryptedArgsForCall, struct { + }{}) + stub := fake.IsEncryptedStub + fakeReturns := fake.isEncryptedReturns + fake.recordInvocation("IsEncrypted", []interface{}{}) + fake.isEncryptedMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalMediaTrack) IsEncryptedCallCount() int { + fake.isEncryptedMutex.RLock() + defer fake.isEncryptedMutex.RUnlock() + return len(fake.isEncryptedArgsForCall) +} + +func (fake *FakeLocalMediaTrack) IsEncryptedCalls(stub func() bool) { + fake.isEncryptedMutex.Lock() + defer fake.isEncryptedMutex.Unlock() + fake.IsEncryptedStub = stub +} + +func (fake *FakeLocalMediaTrack) IsEncryptedReturns(result1 bool) { + fake.isEncryptedMutex.Lock() + defer fake.isEncryptedMutex.Unlock() + fake.IsEncryptedStub = nil + fake.isEncryptedReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalMediaTrack) IsEncryptedReturnsOnCall(i int, result1 bool) { + fake.isEncryptedMutex.Lock() + defer fake.isEncryptedMutex.Unlock() + fake.IsEncryptedStub = nil + if fake.isEncryptedReturnsOnCall == nil { + fake.isEncryptedReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.isEncryptedReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalMediaTrack) IsMuted() bool { + fake.isMutedMutex.Lock() + ret, specificReturn := fake.isMutedReturnsOnCall[len(fake.isMutedArgsForCall)] + fake.isMutedArgsForCall = append(fake.isMutedArgsForCall, struct { + }{}) + stub := fake.IsMutedStub + fakeReturns := fake.isMutedReturns + fake.recordInvocation("IsMuted", []interface{}{}) + fake.isMutedMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalMediaTrack) IsMutedCallCount() int { + fake.isMutedMutex.RLock() + defer fake.isMutedMutex.RUnlock() + return len(fake.isMutedArgsForCall) +} + +func (fake *FakeLocalMediaTrack) IsMutedCalls(stub func() bool) { + fake.isMutedMutex.Lock() + defer fake.isMutedMutex.Unlock() + fake.IsMutedStub = stub +} + +func (fake *FakeLocalMediaTrack) IsMutedReturns(result1 bool) { + fake.isMutedMutex.Lock() + defer fake.isMutedMutex.Unlock() + fake.IsMutedStub = nil + fake.isMutedReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalMediaTrack) IsMutedReturnsOnCall(i int, result1 bool) { + fake.isMutedMutex.Lock() + defer fake.isMutedMutex.Unlock() + fake.IsMutedStub = nil + if fake.isMutedReturnsOnCall == nil { + fake.isMutedReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.isMutedReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalMediaTrack) IsOpen() bool { + fake.isOpenMutex.Lock() + ret, specificReturn := fake.isOpenReturnsOnCall[len(fake.isOpenArgsForCall)] + fake.isOpenArgsForCall = append(fake.isOpenArgsForCall, struct { + }{}) + stub := fake.IsOpenStub + fakeReturns := fake.isOpenReturns + fake.recordInvocation("IsOpen", []interface{}{}) + fake.isOpenMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalMediaTrack) IsOpenCallCount() int { + fake.isOpenMutex.RLock() + defer fake.isOpenMutex.RUnlock() + return len(fake.isOpenArgsForCall) +} + +func (fake *FakeLocalMediaTrack) IsOpenCalls(stub func() bool) { + fake.isOpenMutex.Lock() + defer fake.isOpenMutex.Unlock() + fake.IsOpenStub = stub +} + +func (fake *FakeLocalMediaTrack) IsOpenReturns(result1 bool) { + fake.isOpenMutex.Lock() + defer fake.isOpenMutex.Unlock() + fake.IsOpenStub = nil + fake.isOpenReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalMediaTrack) IsOpenReturnsOnCall(i int, result1 bool) { + fake.isOpenMutex.Lock() + defer fake.isOpenMutex.Unlock() + fake.IsOpenStub = nil + if fake.isOpenReturnsOnCall == nil { + fake.isOpenReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.isOpenReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalMediaTrack) IsSubscriber(arg1 livekit.ParticipantID) bool { + fake.isSubscriberMutex.Lock() + ret, specificReturn := fake.isSubscriberReturnsOnCall[len(fake.isSubscriberArgsForCall)] + fake.isSubscriberArgsForCall = append(fake.isSubscriberArgsForCall, struct { + arg1 livekit.ParticipantID + }{arg1}) + stub := fake.IsSubscriberStub + fakeReturns := fake.isSubscriberReturns + fake.recordInvocation("IsSubscriber", []interface{}{arg1}) + fake.isSubscriberMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalMediaTrack) IsSubscriberCallCount() int { + fake.isSubscriberMutex.RLock() + defer fake.isSubscriberMutex.RUnlock() + return len(fake.isSubscriberArgsForCall) +} + +func (fake *FakeLocalMediaTrack) IsSubscriberCalls(stub func(livekit.ParticipantID) bool) { + fake.isSubscriberMutex.Lock() + defer fake.isSubscriberMutex.Unlock() + fake.IsSubscriberStub = stub +} + +func (fake *FakeLocalMediaTrack) IsSubscriberArgsForCall(i int) livekit.ParticipantID { + fake.isSubscriberMutex.RLock() + defer fake.isSubscriberMutex.RUnlock() + argsForCall := fake.isSubscriberArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalMediaTrack) IsSubscriberReturns(result1 bool) { + fake.isSubscriberMutex.Lock() + defer fake.isSubscriberMutex.Unlock() + fake.IsSubscriberStub = nil + fake.isSubscriberReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalMediaTrack) IsSubscriberReturnsOnCall(i int, result1 bool) { + fake.isSubscriberMutex.Lock() + defer fake.isSubscriberMutex.Unlock() + fake.IsSubscriberStub = nil + if fake.isSubscriberReturnsOnCall == nil { + fake.isSubscriberReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.isSubscriberReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalMediaTrack) Kind() livekit.TrackType { + fake.kindMutex.Lock() + ret, specificReturn := fake.kindReturnsOnCall[len(fake.kindArgsForCall)] + fake.kindArgsForCall = append(fake.kindArgsForCall, struct { + }{}) + stub := fake.KindStub + fakeReturns := fake.kindReturns + fake.recordInvocation("Kind", []interface{}{}) + fake.kindMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalMediaTrack) KindCallCount() int { + fake.kindMutex.RLock() + defer fake.kindMutex.RUnlock() + return len(fake.kindArgsForCall) +} + +func (fake *FakeLocalMediaTrack) KindCalls(stub func() livekit.TrackType) { + fake.kindMutex.Lock() + defer fake.kindMutex.Unlock() + fake.KindStub = stub +} + +func (fake *FakeLocalMediaTrack) KindReturns(result1 livekit.TrackType) { + fake.kindMutex.Lock() + defer fake.kindMutex.Unlock() + fake.KindStub = nil + fake.kindReturns = struct { + result1 livekit.TrackType + }{result1} +} + +func (fake *FakeLocalMediaTrack) KindReturnsOnCall(i int, result1 livekit.TrackType) { + fake.kindMutex.Lock() + defer fake.kindMutex.Unlock() + fake.KindStub = nil + if fake.kindReturnsOnCall == nil { + fake.kindReturnsOnCall = make(map[int]struct { + result1 livekit.TrackType + }) + } + fake.kindReturnsOnCall[i] = struct { + result1 livekit.TrackType + }{result1} +} + +func (fake *FakeLocalMediaTrack) Logger() logger.Logger { + fake.loggerMutex.Lock() + ret, specificReturn := fake.loggerReturnsOnCall[len(fake.loggerArgsForCall)] + fake.loggerArgsForCall = append(fake.loggerArgsForCall, struct { + }{}) + stub := fake.LoggerStub + fakeReturns := fake.loggerReturns + fake.recordInvocation("Logger", []interface{}{}) + fake.loggerMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalMediaTrack) LoggerCallCount() int { + fake.loggerMutex.RLock() + defer fake.loggerMutex.RUnlock() + return len(fake.loggerArgsForCall) +} + +func (fake *FakeLocalMediaTrack) LoggerCalls(stub func() logger.Logger) { + fake.loggerMutex.Lock() + defer fake.loggerMutex.Unlock() + fake.LoggerStub = stub +} + +func (fake *FakeLocalMediaTrack) LoggerReturns(result1 logger.Logger) { + fake.loggerMutex.Lock() + defer fake.loggerMutex.Unlock() + fake.LoggerStub = nil + fake.loggerReturns = struct { + result1 logger.Logger + }{result1} +} + +func (fake *FakeLocalMediaTrack) LoggerReturnsOnCall(i int, result1 logger.Logger) { + fake.loggerMutex.Lock() + defer fake.loggerMutex.Unlock() + fake.LoggerStub = nil + if fake.loggerReturnsOnCall == nil { + fake.loggerReturnsOnCall = make(map[int]struct { + result1 logger.Logger + }) + } + fake.loggerReturnsOnCall[i] = struct { + result1 logger.Logger + }{result1} +} + +func (fake *FakeLocalMediaTrack) Name() string { + fake.nameMutex.Lock() + ret, specificReturn := fake.nameReturnsOnCall[len(fake.nameArgsForCall)] + fake.nameArgsForCall = append(fake.nameArgsForCall, struct { + }{}) + stub := fake.NameStub + fakeReturns := fake.nameReturns + fake.recordInvocation("Name", []interface{}{}) + fake.nameMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalMediaTrack) NameCallCount() int { + fake.nameMutex.RLock() + defer fake.nameMutex.RUnlock() + return len(fake.nameArgsForCall) +} + +func (fake *FakeLocalMediaTrack) NameCalls(stub func() string) { + fake.nameMutex.Lock() + defer fake.nameMutex.Unlock() + fake.NameStub = stub +} + +func (fake *FakeLocalMediaTrack) NameReturns(result1 string) { + fake.nameMutex.Lock() + defer fake.nameMutex.Unlock() + fake.NameStub = nil + fake.nameReturns = struct { + result1 string + }{result1} +} + +func (fake *FakeLocalMediaTrack) NameReturnsOnCall(i int, result1 string) { + fake.nameMutex.Lock() + defer fake.nameMutex.Unlock() + fake.NameStub = nil + if fake.nameReturnsOnCall == nil { + fake.nameReturnsOnCall = make(map[int]struct { + result1 string + }) + } + fake.nameReturnsOnCall[i] = struct { + result1 string + }{result1} +} + +func (fake *FakeLocalMediaTrack) NotifySubscriberNodeMaxQuality(arg1 livekit.NodeID, arg2 []types.SubscribedCodecQuality) { + var arg2Copy []types.SubscribedCodecQuality + if arg2 != nil { + arg2Copy = make([]types.SubscribedCodecQuality, len(arg2)) + copy(arg2Copy, arg2) + } + fake.notifySubscriberNodeMaxQualityMutex.Lock() + fake.notifySubscriberNodeMaxQualityArgsForCall = append(fake.notifySubscriberNodeMaxQualityArgsForCall, struct { + arg1 livekit.NodeID + arg2 []types.SubscribedCodecQuality + }{arg1, arg2Copy}) + stub := fake.NotifySubscriberNodeMaxQualityStub + fake.recordInvocation("NotifySubscriberNodeMaxQuality", []interface{}{arg1, arg2Copy}) + fake.notifySubscriberNodeMaxQualityMutex.Unlock() + if stub != nil { + fake.NotifySubscriberNodeMaxQualityStub(arg1, arg2) + } +} + +func (fake *FakeLocalMediaTrack) NotifySubscriberNodeMaxQualityCallCount() int { + fake.notifySubscriberNodeMaxQualityMutex.RLock() + defer fake.notifySubscriberNodeMaxQualityMutex.RUnlock() + return len(fake.notifySubscriberNodeMaxQualityArgsForCall) +} + +func (fake *FakeLocalMediaTrack) NotifySubscriberNodeMaxQualityCalls(stub func(livekit.NodeID, []types.SubscribedCodecQuality)) { + fake.notifySubscriberNodeMaxQualityMutex.Lock() + defer fake.notifySubscriberNodeMaxQualityMutex.Unlock() + fake.NotifySubscriberNodeMaxQualityStub = stub +} + +func (fake *FakeLocalMediaTrack) NotifySubscriberNodeMaxQualityArgsForCall(i int) (livekit.NodeID, []types.SubscribedCodecQuality) { + fake.notifySubscriberNodeMaxQualityMutex.RLock() + defer fake.notifySubscriberNodeMaxQualityMutex.RUnlock() + argsForCall := fake.notifySubscriberNodeMaxQualityArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeLocalMediaTrack) NotifySubscriberNodeMediaLoss(arg1 livekit.NodeID, arg2 uint8) { + fake.notifySubscriberNodeMediaLossMutex.Lock() + fake.notifySubscriberNodeMediaLossArgsForCall = append(fake.notifySubscriberNodeMediaLossArgsForCall, struct { + arg1 livekit.NodeID + arg2 uint8 + }{arg1, arg2}) + stub := fake.NotifySubscriberNodeMediaLossStub + fake.recordInvocation("NotifySubscriberNodeMediaLoss", []interface{}{arg1, arg2}) + fake.notifySubscriberNodeMediaLossMutex.Unlock() + if stub != nil { + fake.NotifySubscriberNodeMediaLossStub(arg1, arg2) + } +} + +func (fake *FakeLocalMediaTrack) NotifySubscriberNodeMediaLossCallCount() int { + fake.notifySubscriberNodeMediaLossMutex.RLock() + defer fake.notifySubscriberNodeMediaLossMutex.RUnlock() + return len(fake.notifySubscriberNodeMediaLossArgsForCall) +} + +func (fake *FakeLocalMediaTrack) NotifySubscriberNodeMediaLossCalls(stub func(livekit.NodeID, uint8)) { + fake.notifySubscriberNodeMediaLossMutex.Lock() + defer fake.notifySubscriberNodeMediaLossMutex.Unlock() + fake.NotifySubscriberNodeMediaLossStub = stub +} + +func (fake *FakeLocalMediaTrack) NotifySubscriberNodeMediaLossArgsForCall(i int) (livekit.NodeID, uint8) { + fake.notifySubscriberNodeMediaLossMutex.RLock() + defer fake.notifySubscriberNodeMediaLossMutex.RUnlock() + argsForCall := fake.notifySubscriberNodeMediaLossArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeLocalMediaTrack) NotifySubscriptionNode(arg1 livekit.NodeID, arg2 []*livekit.SubscribedAudioCodec) { + var arg2Copy []*livekit.SubscribedAudioCodec + if arg2 != nil { + arg2Copy = make([]*livekit.SubscribedAudioCodec, len(arg2)) + copy(arg2Copy, arg2) + } + fake.notifySubscriptionNodeMutex.Lock() + fake.notifySubscriptionNodeArgsForCall = append(fake.notifySubscriptionNodeArgsForCall, struct { + arg1 livekit.NodeID + arg2 []*livekit.SubscribedAudioCodec + }{arg1, arg2Copy}) + stub := fake.NotifySubscriptionNodeStub + fake.recordInvocation("NotifySubscriptionNode", []interface{}{arg1, arg2Copy}) + fake.notifySubscriptionNodeMutex.Unlock() + if stub != nil { + fake.NotifySubscriptionNodeStub(arg1, arg2) + } +} + +func (fake *FakeLocalMediaTrack) NotifySubscriptionNodeCallCount() int { + fake.notifySubscriptionNodeMutex.RLock() + defer fake.notifySubscriptionNodeMutex.RUnlock() + return len(fake.notifySubscriptionNodeArgsForCall) +} + +func (fake *FakeLocalMediaTrack) NotifySubscriptionNodeCalls(stub func(livekit.NodeID, []*livekit.SubscribedAudioCodec)) { + fake.notifySubscriptionNodeMutex.Lock() + defer fake.notifySubscriptionNodeMutex.Unlock() + fake.NotifySubscriptionNodeStub = stub +} + +func (fake *FakeLocalMediaTrack) NotifySubscriptionNodeArgsForCall(i int) (livekit.NodeID, []*livekit.SubscribedAudioCodec) { + fake.notifySubscriptionNodeMutex.RLock() + defer fake.notifySubscriptionNodeMutex.RUnlock() + argsForCall := fake.notifySubscriptionNodeArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeLocalMediaTrack) OnTrackSubscribed() { + fake.onTrackSubscribedMutex.Lock() + fake.onTrackSubscribedArgsForCall = append(fake.onTrackSubscribedArgsForCall, struct { + }{}) + stub := fake.OnTrackSubscribedStub + fake.recordInvocation("OnTrackSubscribed", []interface{}{}) + fake.onTrackSubscribedMutex.Unlock() + if stub != nil { + fake.OnTrackSubscribedStub() + } +} + +func (fake *FakeLocalMediaTrack) OnTrackSubscribedCallCount() int { + fake.onTrackSubscribedMutex.RLock() + defer fake.onTrackSubscribedMutex.RUnlock() + return len(fake.onTrackSubscribedArgsForCall) +} + +func (fake *FakeLocalMediaTrack) OnTrackSubscribedCalls(stub func()) { + fake.onTrackSubscribedMutex.Lock() + defer fake.onTrackSubscribedMutex.Unlock() + fake.OnTrackSubscribedStub = stub +} + +func (fake *FakeLocalMediaTrack) PublisherID() livekit.ParticipantID { + fake.publisherIDMutex.Lock() + ret, specificReturn := fake.publisherIDReturnsOnCall[len(fake.publisherIDArgsForCall)] + fake.publisherIDArgsForCall = append(fake.publisherIDArgsForCall, struct { + }{}) + stub := fake.PublisherIDStub + fakeReturns := fake.publisherIDReturns + fake.recordInvocation("PublisherID", []interface{}{}) + fake.publisherIDMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalMediaTrack) PublisherIDCallCount() int { + fake.publisherIDMutex.RLock() + defer fake.publisherIDMutex.RUnlock() + return len(fake.publisherIDArgsForCall) +} + +func (fake *FakeLocalMediaTrack) PublisherIDCalls(stub func() livekit.ParticipantID) { + fake.publisherIDMutex.Lock() + defer fake.publisherIDMutex.Unlock() + fake.PublisherIDStub = stub +} + +func (fake *FakeLocalMediaTrack) PublisherIDReturns(result1 livekit.ParticipantID) { + fake.publisherIDMutex.Lock() + defer fake.publisherIDMutex.Unlock() + fake.PublisherIDStub = nil + fake.publisherIDReturns = struct { + result1 livekit.ParticipantID + }{result1} +} + +func (fake *FakeLocalMediaTrack) PublisherIDReturnsOnCall(i int, result1 livekit.ParticipantID) { + fake.publisherIDMutex.Lock() + defer fake.publisherIDMutex.Unlock() + fake.PublisherIDStub = nil + if fake.publisherIDReturnsOnCall == nil { + fake.publisherIDReturnsOnCall = make(map[int]struct { + result1 livekit.ParticipantID + }) + } + fake.publisherIDReturnsOnCall[i] = struct { + result1 livekit.ParticipantID + }{result1} +} + +func (fake *FakeLocalMediaTrack) PublisherIdentity() livekit.ParticipantIdentity { + fake.publisherIdentityMutex.Lock() + ret, specificReturn := fake.publisherIdentityReturnsOnCall[len(fake.publisherIdentityArgsForCall)] + fake.publisherIdentityArgsForCall = append(fake.publisherIdentityArgsForCall, struct { + }{}) + stub := fake.PublisherIdentityStub + fakeReturns := fake.publisherIdentityReturns + fake.recordInvocation("PublisherIdentity", []interface{}{}) + fake.publisherIdentityMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalMediaTrack) PublisherIdentityCallCount() int { + fake.publisherIdentityMutex.RLock() + defer fake.publisherIdentityMutex.RUnlock() + return len(fake.publisherIdentityArgsForCall) +} + +func (fake *FakeLocalMediaTrack) PublisherIdentityCalls(stub func() livekit.ParticipantIdentity) { + fake.publisherIdentityMutex.Lock() + defer fake.publisherIdentityMutex.Unlock() + fake.PublisherIdentityStub = stub +} + +func (fake *FakeLocalMediaTrack) PublisherIdentityReturns(result1 livekit.ParticipantIdentity) { + fake.publisherIdentityMutex.Lock() + defer fake.publisherIdentityMutex.Unlock() + fake.PublisherIdentityStub = nil + fake.publisherIdentityReturns = struct { + result1 livekit.ParticipantIdentity + }{result1} +} + +func (fake *FakeLocalMediaTrack) PublisherIdentityReturnsOnCall(i int, result1 livekit.ParticipantIdentity) { + fake.publisherIdentityMutex.Lock() + defer fake.publisherIdentityMutex.Unlock() + fake.PublisherIdentityStub = nil + if fake.publisherIdentityReturnsOnCall == nil { + fake.publisherIdentityReturnsOnCall = make(map[int]struct { + result1 livekit.ParticipantIdentity + }) + } + fake.publisherIdentityReturnsOnCall[i] = struct { + result1 livekit.ParticipantIdentity + }{result1} +} + +func (fake *FakeLocalMediaTrack) PublisherVersion() uint32 { + fake.publisherVersionMutex.Lock() + ret, specificReturn := fake.publisherVersionReturnsOnCall[len(fake.publisherVersionArgsForCall)] + fake.publisherVersionArgsForCall = append(fake.publisherVersionArgsForCall, struct { + }{}) + stub := fake.PublisherVersionStub + fakeReturns := fake.publisherVersionReturns + fake.recordInvocation("PublisherVersion", []interface{}{}) + fake.publisherVersionMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalMediaTrack) PublisherVersionCallCount() int { + fake.publisherVersionMutex.RLock() + defer fake.publisherVersionMutex.RUnlock() + return len(fake.publisherVersionArgsForCall) +} + +func (fake *FakeLocalMediaTrack) PublisherVersionCalls(stub func() uint32) { + fake.publisherVersionMutex.Lock() + defer fake.publisherVersionMutex.Unlock() + fake.PublisherVersionStub = stub +} + +func (fake *FakeLocalMediaTrack) PublisherVersionReturns(result1 uint32) { + fake.publisherVersionMutex.Lock() + defer fake.publisherVersionMutex.Unlock() + fake.PublisherVersionStub = nil + fake.publisherVersionReturns = struct { + result1 uint32 + }{result1} +} + +func (fake *FakeLocalMediaTrack) PublisherVersionReturnsOnCall(i int, result1 uint32) { + fake.publisherVersionMutex.Lock() + defer fake.publisherVersionMutex.Unlock() + fake.PublisherVersionStub = nil + if fake.publisherVersionReturnsOnCall == nil { + fake.publisherVersionReturnsOnCall = make(map[int]struct { + result1 uint32 + }) + } + fake.publisherVersionReturnsOnCall[i] = struct { + result1 uint32 + }{result1} +} + +func (fake *FakeLocalMediaTrack) Receivers() []sfu.TrackReceiver { + fake.receiversMutex.Lock() + ret, specificReturn := fake.receiversReturnsOnCall[len(fake.receiversArgsForCall)] + fake.receiversArgsForCall = append(fake.receiversArgsForCall, struct { + }{}) + stub := fake.ReceiversStub + fakeReturns := fake.receiversReturns + fake.recordInvocation("Receivers", []interface{}{}) + fake.receiversMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalMediaTrack) ReceiversCallCount() int { + fake.receiversMutex.RLock() + defer fake.receiversMutex.RUnlock() + return len(fake.receiversArgsForCall) +} + +func (fake *FakeLocalMediaTrack) ReceiversCalls(stub func() []sfu.TrackReceiver) { + fake.receiversMutex.Lock() + defer fake.receiversMutex.Unlock() + fake.ReceiversStub = stub +} + +func (fake *FakeLocalMediaTrack) ReceiversReturns(result1 []sfu.TrackReceiver) { + fake.receiversMutex.Lock() + defer fake.receiversMutex.Unlock() + fake.ReceiversStub = nil + fake.receiversReturns = struct { + result1 []sfu.TrackReceiver + }{result1} +} + +func (fake *FakeLocalMediaTrack) ReceiversReturnsOnCall(i int, result1 []sfu.TrackReceiver) { + fake.receiversMutex.Lock() + defer fake.receiversMutex.Unlock() + fake.ReceiversStub = nil + if fake.receiversReturnsOnCall == nil { + fake.receiversReturnsOnCall = make(map[int]struct { + result1 []sfu.TrackReceiver + }) + } + fake.receiversReturnsOnCall[i] = struct { + result1 []sfu.TrackReceiver + }{result1} +} + +func (fake *FakeLocalMediaTrack) RemoveSubscriber(arg1 livekit.ParticipantID, arg2 bool) { + fake.removeSubscriberMutex.Lock() + fake.removeSubscriberArgsForCall = append(fake.removeSubscriberArgsForCall, struct { + arg1 livekit.ParticipantID + arg2 bool + }{arg1, arg2}) + stub := fake.RemoveSubscriberStub + fake.recordInvocation("RemoveSubscriber", []interface{}{arg1, arg2}) + fake.removeSubscriberMutex.Unlock() + if stub != nil { + fake.RemoveSubscriberStub(arg1, arg2) + } +} + +func (fake *FakeLocalMediaTrack) RemoveSubscriberCallCount() int { + fake.removeSubscriberMutex.RLock() + defer fake.removeSubscriberMutex.RUnlock() + return len(fake.removeSubscriberArgsForCall) +} + +func (fake *FakeLocalMediaTrack) RemoveSubscriberCalls(stub func(livekit.ParticipantID, bool)) { + fake.removeSubscriberMutex.Lock() + defer fake.removeSubscriberMutex.Unlock() + fake.RemoveSubscriberStub = stub +} + +func (fake *FakeLocalMediaTrack) RemoveSubscriberArgsForCall(i int) (livekit.ParticipantID, bool) { + fake.removeSubscriberMutex.RLock() + defer fake.removeSubscriberMutex.RUnlock() + argsForCall := fake.removeSubscriberArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeLocalMediaTrack) Restart() { + fake.restartMutex.Lock() + fake.restartArgsForCall = append(fake.restartArgsForCall, struct { + }{}) + stub := fake.RestartStub + fake.recordInvocation("Restart", []interface{}{}) + fake.restartMutex.Unlock() + if stub != nil { + fake.RestartStub() + } +} + +func (fake *FakeLocalMediaTrack) RestartCallCount() int { + fake.restartMutex.RLock() + defer fake.restartMutex.RUnlock() + return len(fake.restartArgsForCall) +} + +func (fake *FakeLocalMediaTrack) RestartCalls(stub func()) { + fake.restartMutex.Lock() + defer fake.restartMutex.Unlock() + fake.RestartStub = stub +} + +func (fake *FakeLocalMediaTrack) RevokeDisallowedSubscribers(arg1 []livekit.ParticipantIdentity) []livekit.ParticipantIdentity { + var arg1Copy []livekit.ParticipantIdentity + if arg1 != nil { + arg1Copy = make([]livekit.ParticipantIdentity, len(arg1)) + copy(arg1Copy, arg1) + } + fake.revokeDisallowedSubscribersMutex.Lock() + ret, specificReturn := fake.revokeDisallowedSubscribersReturnsOnCall[len(fake.revokeDisallowedSubscribersArgsForCall)] + fake.revokeDisallowedSubscribersArgsForCall = append(fake.revokeDisallowedSubscribersArgsForCall, struct { + arg1 []livekit.ParticipantIdentity + }{arg1Copy}) + stub := fake.RevokeDisallowedSubscribersStub + fakeReturns := fake.revokeDisallowedSubscribersReturns + fake.recordInvocation("RevokeDisallowedSubscribers", []interface{}{arg1Copy}) + fake.revokeDisallowedSubscribersMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalMediaTrack) RevokeDisallowedSubscribersCallCount() int { + fake.revokeDisallowedSubscribersMutex.RLock() + defer fake.revokeDisallowedSubscribersMutex.RUnlock() + return len(fake.revokeDisallowedSubscribersArgsForCall) +} + +func (fake *FakeLocalMediaTrack) RevokeDisallowedSubscribersCalls(stub func([]livekit.ParticipantIdentity) []livekit.ParticipantIdentity) { + fake.revokeDisallowedSubscribersMutex.Lock() + defer fake.revokeDisallowedSubscribersMutex.Unlock() + fake.RevokeDisallowedSubscribersStub = stub +} + +func (fake *FakeLocalMediaTrack) RevokeDisallowedSubscribersArgsForCall(i int) []livekit.ParticipantIdentity { + fake.revokeDisallowedSubscribersMutex.RLock() + defer fake.revokeDisallowedSubscribersMutex.RUnlock() + argsForCall := fake.revokeDisallowedSubscribersArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalMediaTrack) RevokeDisallowedSubscribersReturns(result1 []livekit.ParticipantIdentity) { + fake.revokeDisallowedSubscribersMutex.Lock() + defer fake.revokeDisallowedSubscribersMutex.Unlock() + fake.RevokeDisallowedSubscribersStub = nil + fake.revokeDisallowedSubscribersReturns = struct { + result1 []livekit.ParticipantIdentity + }{result1} +} + +func (fake *FakeLocalMediaTrack) RevokeDisallowedSubscribersReturnsOnCall(i int, result1 []livekit.ParticipantIdentity) { + fake.revokeDisallowedSubscribersMutex.Lock() + defer fake.revokeDisallowedSubscribersMutex.Unlock() + fake.RevokeDisallowedSubscribersStub = nil + if fake.revokeDisallowedSubscribersReturnsOnCall == nil { + fake.revokeDisallowedSubscribersReturnsOnCall = make(map[int]struct { + result1 []livekit.ParticipantIdentity + }) + } + fake.revokeDisallowedSubscribersReturnsOnCall[i] = struct { + result1 []livekit.ParticipantIdentity + }{result1} +} + +func (fake *FakeLocalMediaTrack) SetMuted(arg1 bool) { + fake.setMutedMutex.Lock() + fake.setMutedArgsForCall = append(fake.setMutedArgsForCall, struct { + arg1 bool + }{arg1}) + stub := fake.SetMutedStub + fake.recordInvocation("SetMuted", []interface{}{arg1}) + fake.setMutedMutex.Unlock() + if stub != nil { + fake.SetMutedStub(arg1) + } +} + +func (fake *FakeLocalMediaTrack) SetMutedCallCount() int { + fake.setMutedMutex.RLock() + defer fake.setMutedMutex.RUnlock() + return len(fake.setMutedArgsForCall) +} + +func (fake *FakeLocalMediaTrack) SetMutedCalls(stub func(bool)) { + fake.setMutedMutex.Lock() + defer fake.setMutedMutex.Unlock() + fake.SetMutedStub = stub +} + +func (fake *FakeLocalMediaTrack) SetMutedArgsForCall(i int) bool { + fake.setMutedMutex.RLock() + defer fake.setMutedMutex.RUnlock() + argsForCall := fake.setMutedArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalMediaTrack) SetRTT(arg1 uint32) { + fake.setRTTMutex.Lock() + fake.setRTTArgsForCall = append(fake.setRTTArgsForCall, struct { + arg1 uint32 + }{arg1}) + stub := fake.SetRTTStub + fake.recordInvocation("SetRTT", []interface{}{arg1}) + fake.setRTTMutex.Unlock() + if stub != nil { + fake.SetRTTStub(arg1) + } +} + +func (fake *FakeLocalMediaTrack) SetRTTCallCount() int { + fake.setRTTMutex.RLock() + defer fake.setRTTMutex.RUnlock() + return len(fake.setRTTArgsForCall) +} + +func (fake *FakeLocalMediaTrack) SetRTTCalls(stub func(uint32)) { + fake.setRTTMutex.Lock() + defer fake.setRTTMutex.Unlock() + fake.SetRTTStub = stub +} + +func (fake *FakeLocalMediaTrack) SetRTTArgsForCall(i int) uint32 { + fake.setRTTMutex.RLock() + defer fake.setRTTMutex.RUnlock() + argsForCall := fake.setRTTArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalMediaTrack) Source() livekit.TrackSource { + fake.sourceMutex.Lock() + ret, specificReturn := fake.sourceReturnsOnCall[len(fake.sourceArgsForCall)] + fake.sourceArgsForCall = append(fake.sourceArgsForCall, struct { + }{}) + stub := fake.SourceStub + fakeReturns := fake.sourceReturns + fake.recordInvocation("Source", []interface{}{}) + fake.sourceMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalMediaTrack) SourceCallCount() int { + fake.sourceMutex.RLock() + defer fake.sourceMutex.RUnlock() + return len(fake.sourceArgsForCall) +} + +func (fake *FakeLocalMediaTrack) SourceCalls(stub func() livekit.TrackSource) { + fake.sourceMutex.Lock() + defer fake.sourceMutex.Unlock() + fake.SourceStub = stub +} + +func (fake *FakeLocalMediaTrack) SourceReturns(result1 livekit.TrackSource) { + fake.sourceMutex.Lock() + defer fake.sourceMutex.Unlock() + fake.SourceStub = nil + fake.sourceReturns = struct { + result1 livekit.TrackSource + }{result1} +} + +func (fake *FakeLocalMediaTrack) SourceReturnsOnCall(i int, result1 livekit.TrackSource) { + fake.sourceMutex.Lock() + defer fake.sourceMutex.Unlock() + fake.SourceStub = nil + if fake.sourceReturnsOnCall == nil { + fake.sourceReturnsOnCall = make(map[int]struct { + result1 livekit.TrackSource + }) + } + fake.sourceReturnsOnCall[i] = struct { + result1 livekit.TrackSource + }{result1} +} + +func (fake *FakeLocalMediaTrack) Stream() string { + fake.streamMutex.Lock() + ret, specificReturn := fake.streamReturnsOnCall[len(fake.streamArgsForCall)] + fake.streamArgsForCall = append(fake.streamArgsForCall, struct { + }{}) + stub := fake.StreamStub + fakeReturns := fake.streamReturns + fake.recordInvocation("Stream", []interface{}{}) + fake.streamMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalMediaTrack) StreamCallCount() int { + fake.streamMutex.RLock() + defer fake.streamMutex.RUnlock() + return len(fake.streamArgsForCall) +} + +func (fake *FakeLocalMediaTrack) StreamCalls(stub func() string) { + fake.streamMutex.Lock() + defer fake.streamMutex.Unlock() + fake.StreamStub = stub +} + +func (fake *FakeLocalMediaTrack) StreamReturns(result1 string) { + fake.streamMutex.Lock() + defer fake.streamMutex.Unlock() + fake.StreamStub = nil + fake.streamReturns = struct { + result1 string + }{result1} +} + +func (fake *FakeLocalMediaTrack) StreamReturnsOnCall(i int, result1 string) { + fake.streamMutex.Lock() + defer fake.streamMutex.Unlock() + fake.StreamStub = nil + if fake.streamReturnsOnCall == nil { + fake.streamReturnsOnCall = make(map[int]struct { + result1 string + }) + } + fake.streamReturnsOnCall[i] = struct { + result1 string + }{result1} +} + +func (fake *FakeLocalMediaTrack) ToProto() *livekit.TrackInfo { + fake.toProtoMutex.Lock() + ret, specificReturn := fake.toProtoReturnsOnCall[len(fake.toProtoArgsForCall)] + fake.toProtoArgsForCall = append(fake.toProtoArgsForCall, struct { + }{}) + stub := fake.ToProtoStub + fakeReturns := fake.toProtoReturns + fake.recordInvocation("ToProto", []interface{}{}) + fake.toProtoMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalMediaTrack) ToProtoCallCount() int { + fake.toProtoMutex.RLock() + defer fake.toProtoMutex.RUnlock() + return len(fake.toProtoArgsForCall) +} + +func (fake *FakeLocalMediaTrack) ToProtoCalls(stub func() *livekit.TrackInfo) { + fake.toProtoMutex.Lock() + defer fake.toProtoMutex.Unlock() + fake.ToProtoStub = stub +} + +func (fake *FakeLocalMediaTrack) ToProtoReturns(result1 *livekit.TrackInfo) { + fake.toProtoMutex.Lock() + defer fake.toProtoMutex.Unlock() + fake.ToProtoStub = nil + fake.toProtoReturns = struct { + result1 *livekit.TrackInfo + }{result1} +} + +func (fake *FakeLocalMediaTrack) ToProtoReturnsOnCall(i int, result1 *livekit.TrackInfo) { + fake.toProtoMutex.Lock() + defer fake.toProtoMutex.Unlock() + fake.ToProtoStub = nil + if fake.toProtoReturnsOnCall == nil { + fake.toProtoReturnsOnCall = make(map[int]struct { + result1 *livekit.TrackInfo + }) + } + fake.toProtoReturnsOnCall[i] = struct { + result1 *livekit.TrackInfo + }{result1} +} + +func (fake *FakeLocalMediaTrack) UpdateAudioTrack(arg1 *livekit.UpdateLocalAudioTrack) { + fake.updateAudioTrackMutex.Lock() + fake.updateAudioTrackArgsForCall = append(fake.updateAudioTrackArgsForCall, struct { + arg1 *livekit.UpdateLocalAudioTrack + }{arg1}) + stub := fake.UpdateAudioTrackStub + fake.recordInvocation("UpdateAudioTrack", []interface{}{arg1}) + fake.updateAudioTrackMutex.Unlock() + if stub != nil { + fake.UpdateAudioTrackStub(arg1) + } +} + +func (fake *FakeLocalMediaTrack) UpdateAudioTrackCallCount() int { + fake.updateAudioTrackMutex.RLock() + defer fake.updateAudioTrackMutex.RUnlock() + return len(fake.updateAudioTrackArgsForCall) +} + +func (fake *FakeLocalMediaTrack) UpdateAudioTrackCalls(stub func(*livekit.UpdateLocalAudioTrack)) { + fake.updateAudioTrackMutex.Lock() + defer fake.updateAudioTrackMutex.Unlock() + fake.UpdateAudioTrackStub = stub +} + +func (fake *FakeLocalMediaTrack) UpdateAudioTrackArgsForCall(i int) *livekit.UpdateLocalAudioTrack { + fake.updateAudioTrackMutex.RLock() + defer fake.updateAudioTrackMutex.RUnlock() + argsForCall := fake.updateAudioTrackArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalMediaTrack) UpdateTrackInfo(arg1 *livekit.TrackInfo) { + fake.updateTrackInfoMutex.Lock() + fake.updateTrackInfoArgsForCall = append(fake.updateTrackInfoArgsForCall, struct { + arg1 *livekit.TrackInfo + }{arg1}) + stub := fake.UpdateTrackInfoStub + fake.recordInvocation("UpdateTrackInfo", []interface{}{arg1}) + fake.updateTrackInfoMutex.Unlock() + if stub != nil { + fake.UpdateTrackInfoStub(arg1) + } +} + +func (fake *FakeLocalMediaTrack) UpdateTrackInfoCallCount() int { + fake.updateTrackInfoMutex.RLock() + defer fake.updateTrackInfoMutex.RUnlock() + return len(fake.updateTrackInfoArgsForCall) +} + +func (fake *FakeLocalMediaTrack) UpdateTrackInfoCalls(stub func(*livekit.TrackInfo)) { + fake.updateTrackInfoMutex.Lock() + defer fake.updateTrackInfoMutex.Unlock() + fake.UpdateTrackInfoStub = stub +} + +func (fake *FakeLocalMediaTrack) UpdateTrackInfoArgsForCall(i int) *livekit.TrackInfo { + fake.updateTrackInfoMutex.RLock() + defer fake.updateTrackInfoMutex.RUnlock() + argsForCall := fake.updateTrackInfoArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalMediaTrack) UpdateVideoTrack(arg1 *livekit.UpdateLocalVideoTrack) { + fake.updateVideoTrackMutex.Lock() + fake.updateVideoTrackArgsForCall = append(fake.updateVideoTrackArgsForCall, struct { + arg1 *livekit.UpdateLocalVideoTrack + }{arg1}) + stub := fake.UpdateVideoTrackStub + fake.recordInvocation("UpdateVideoTrack", []interface{}{arg1}) + fake.updateVideoTrackMutex.Unlock() + if stub != nil { + fake.UpdateVideoTrackStub(arg1) + } +} + +func (fake *FakeLocalMediaTrack) UpdateVideoTrackCallCount() int { + fake.updateVideoTrackMutex.RLock() + defer fake.updateVideoTrackMutex.RUnlock() + return len(fake.updateVideoTrackArgsForCall) +} + +func (fake *FakeLocalMediaTrack) UpdateVideoTrackCalls(stub func(*livekit.UpdateLocalVideoTrack)) { + fake.updateVideoTrackMutex.Lock() + defer fake.updateVideoTrackMutex.Unlock() + fake.UpdateVideoTrackStub = stub +} + +func (fake *FakeLocalMediaTrack) UpdateVideoTrackArgsForCall(i int) *livekit.UpdateLocalVideoTrack { + fake.updateVideoTrackMutex.RLock() + defer fake.updateVideoTrackMutex.RUnlock() + argsForCall := fake.updateVideoTrackArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalMediaTrack) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeLocalMediaTrack) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ types.LocalMediaTrack = new(FakeLocalMediaTrack) diff --git a/livekit/pkg/rtc/types/typesfakes/fake_local_participant.go b/livekit/pkg/rtc/types/typesfakes/fake_local_participant.go new file mode 100644 index 0000000..494d685 --- /dev/null +++ b/livekit/pkg/rtc/types/typesfakes/fake_local_participant.go @@ -0,0 +1,9468 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package typesfakes + +import ( + "sync" + "time" + + "github.com/livekit/livekit-server/pkg/routing" + "github.com/livekit/livekit-server/pkg/rtc/datatrack" + "github.com/livekit/livekit-server/pkg/rtc/types" + "github.com/livekit/livekit-server/pkg/sfu" + "github.com/livekit/livekit-server/pkg/sfu/buffer" + "github.com/livekit/livekit-server/pkg/sfu/pacer" + "github.com/livekit/livekit-server/pkg/telemetry" + "github.com/livekit/protocol/auth" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/observability/roomobs" + "github.com/livekit/protocol/utils" + "github.com/pion/rtcp" + webrtc "github.com/pion/webrtc/v4" + "google.golang.org/protobuf/proto" +) + +type FakeLocalParticipant struct { + ActiveAtStub func() time.Time + activeAtMutex sync.RWMutex + activeAtArgsForCall []struct { + } + activeAtReturns struct { + result1 time.Time + } + activeAtReturnsOnCall map[int]struct { + result1 time.Time + } + AddOnCloseStub func(string, func(types.LocalParticipant)) + addOnCloseMutex sync.RWMutex + addOnCloseArgsForCall []struct { + arg1 string + arg2 func(types.LocalParticipant) + } + AddTrackStub func(*livekit.AddTrackRequest) + addTrackMutex sync.RWMutex + addTrackArgsForCall []struct { + arg1 *livekit.AddTrackRequest + } + AddTrackLocalStub func(webrtc.TrackLocal, types.AddTrackParams) (*webrtc.RTPSender, *webrtc.RTPTransceiver, error) + addTrackLocalMutex sync.RWMutex + addTrackLocalArgsForCall []struct { + arg1 webrtc.TrackLocal + arg2 types.AddTrackParams + } + addTrackLocalReturns struct { + result1 *webrtc.RTPSender + result2 *webrtc.RTPTransceiver + result3 error + } + addTrackLocalReturnsOnCall map[int]struct { + result1 *webrtc.RTPSender + result2 *webrtc.RTPTransceiver + result3 error + } + AddTransceiverFromTrackLocalStub func(webrtc.TrackLocal, types.AddTrackParams) (*webrtc.RTPSender, *webrtc.RTPTransceiver, error) + addTransceiverFromTrackLocalMutex sync.RWMutex + addTransceiverFromTrackLocalArgsForCall []struct { + arg1 webrtc.TrackLocal + arg2 types.AddTrackParams + } + addTransceiverFromTrackLocalReturns struct { + result1 *webrtc.RTPSender + result2 *webrtc.RTPTransceiver + result3 error + } + addTransceiverFromTrackLocalReturnsOnCall map[int]struct { + result1 *webrtc.RTPSender + result2 *webrtc.RTPTransceiver + result3 error + } + CacheDownTrackStub func(livekit.TrackID, *webrtc.RTPTransceiver, sfu.DownTrackState) + cacheDownTrackMutex sync.RWMutex + cacheDownTrackArgsForCall []struct { + arg1 livekit.TrackID + arg2 *webrtc.RTPTransceiver + arg3 sfu.DownTrackState + } + CanPublishStub func() bool + canPublishMutex sync.RWMutex + canPublishArgsForCall []struct { + } + canPublishReturns struct { + result1 bool + } + canPublishReturnsOnCall map[int]struct { + result1 bool + } + CanPublishDataStub func() bool + canPublishDataMutex sync.RWMutex + canPublishDataArgsForCall []struct { + } + canPublishDataReturns struct { + result1 bool + } + canPublishDataReturnsOnCall map[int]struct { + result1 bool + } + CanPublishSourceStub func(livekit.TrackSource) bool + canPublishSourceMutex sync.RWMutex + canPublishSourceArgsForCall []struct { + arg1 livekit.TrackSource + } + canPublishSourceReturns struct { + result1 bool + } + canPublishSourceReturnsOnCall map[int]struct { + result1 bool + } + CanSkipBroadcastStub func() bool + canSkipBroadcastMutex sync.RWMutex + canSkipBroadcastArgsForCall []struct { + } + canSkipBroadcastReturns struct { + result1 bool + } + canSkipBroadcastReturnsOnCall map[int]struct { + result1 bool + } + CanSubscribeStub func() bool + canSubscribeMutex sync.RWMutex + canSubscribeArgsForCall []struct { + } + canSubscribeReturns struct { + result1 bool + } + canSubscribeReturnsOnCall map[int]struct { + result1 bool + } + ClaimGrantsStub func() *auth.ClaimGrants + claimGrantsMutex sync.RWMutex + claimGrantsArgsForCall []struct { + } + claimGrantsReturns struct { + result1 *auth.ClaimGrants + } + claimGrantsReturnsOnCall map[int]struct { + result1 *auth.ClaimGrants + } + ClearParticipantListenerStub func() + clearParticipantListenerMutex sync.RWMutex + clearParticipantListenerArgsForCall []struct { + } + CloseStub func(bool, types.ParticipantCloseReason, bool) error + closeMutex sync.RWMutex + closeArgsForCall []struct { + arg1 bool + arg2 types.ParticipantCloseReason + arg3 bool + } + closeReturns struct { + result1 error + } + closeReturnsOnCall map[int]struct { + result1 error + } + CloseReasonStub func() types.ParticipantCloseReason + closeReasonMutex sync.RWMutex + closeReasonArgsForCall []struct { + } + closeReasonReturns struct { + result1 types.ParticipantCloseReason + } + closeReasonReturnsOnCall map[int]struct { + result1 types.ParticipantCloseReason + } + CloseSignalConnectionStub func(types.SignallingCloseReason) + closeSignalConnectionMutex sync.RWMutex + closeSignalConnectionArgsForCall []struct { + arg1 types.SignallingCloseReason + } + ConnectedAtStub func() time.Time + connectedAtMutex sync.RWMutex + connectedAtArgsForCall []struct { + } + connectedAtReturns struct { + result1 time.Time + } + connectedAtReturnsOnCall map[int]struct { + result1 time.Time + } + DebugInfoStub func() map[string]any + debugInfoMutex sync.RWMutex + debugInfoArgsForCall []struct { + } + debugInfoReturns struct { + result1 map[string]any + } + debugInfoReturnsOnCall map[int]struct { + result1 map[string]any + } + DisconnectedStub func() <-chan struct{} + disconnectedMutex sync.RWMutex + disconnectedArgsForCall []struct { + } + disconnectedReturns struct { + result1 <-chan struct{} + } + disconnectedReturnsOnCall map[int]struct { + result1 <-chan struct{} + } + GetAdaptiveStreamStub func() bool + getAdaptiveStreamMutex sync.RWMutex + getAdaptiveStreamArgsForCall []struct { + } + getAdaptiveStreamReturns struct { + result1 bool + } + getAdaptiveStreamReturnsOnCall map[int]struct { + result1 bool + } + GetAnswerStub func() (webrtc.SessionDescription, uint32, error) + getAnswerMutex sync.RWMutex + getAnswerArgsForCall []struct { + } + getAnswerReturns struct { + result1 webrtc.SessionDescription + result2 uint32 + result3 error + } + getAnswerReturnsOnCall map[int]struct { + result1 webrtc.SessionDescription + result2 uint32 + result3 error + } + GetAudioLevelStub func() (float64, bool) + getAudioLevelMutex sync.RWMutex + getAudioLevelArgsForCall []struct { + } + getAudioLevelReturns struct { + result1 float64 + result2 bool + } + getAudioLevelReturnsOnCall map[int]struct { + result1 float64 + result2 bool + } + GetBufferFactoryStub func() *buffer.Factory + getBufferFactoryMutex sync.RWMutex + getBufferFactoryArgsForCall []struct { + } + getBufferFactoryReturns struct { + result1 *buffer.Factory + } + getBufferFactoryReturnsOnCall map[int]struct { + result1 *buffer.Factory + } + GetCachedDownTrackStub func(livekit.TrackID) (*webrtc.RTPTransceiver, sfu.DownTrackState) + getCachedDownTrackMutex sync.RWMutex + getCachedDownTrackArgsForCall []struct { + arg1 livekit.TrackID + } + getCachedDownTrackReturns struct { + result1 *webrtc.RTPTransceiver + result2 sfu.DownTrackState + } + getCachedDownTrackReturnsOnCall map[int]struct { + result1 *webrtc.RTPTransceiver + result2 sfu.DownTrackState + } + GetClientConfigurationStub func() *livekit.ClientConfiguration + getClientConfigurationMutex sync.RWMutex + getClientConfigurationArgsForCall []struct { + } + getClientConfigurationReturns struct { + result1 *livekit.ClientConfiguration + } + getClientConfigurationReturnsOnCall map[int]struct { + result1 *livekit.ClientConfiguration + } + GetClientInfoStub func() *livekit.ClientInfo + getClientInfoMutex sync.RWMutex + getClientInfoArgsForCall []struct { + } + getClientInfoReturns struct { + result1 *livekit.ClientInfo + } + getClientInfoReturnsOnCall map[int]struct { + result1 *livekit.ClientInfo + } + GetConnectionQualityStub func() *livekit.ConnectionQualityInfo + getConnectionQualityMutex sync.RWMutex + getConnectionQualityArgsForCall []struct { + } + getConnectionQualityReturns struct { + result1 *livekit.ConnectionQualityInfo + } + getConnectionQualityReturnsOnCall map[int]struct { + result1 *livekit.ConnectionQualityInfo + } + GetCountryStub func() string + getCountryMutex sync.RWMutex + getCountryArgsForCall []struct { + } + getCountryReturns struct { + result1 string + } + getCountryReturnsOnCall map[int]struct { + result1 string + } + GetDataTrackTransportStub func() types.DataTrackTransport + getDataTrackTransportMutex sync.RWMutex + getDataTrackTransportArgsForCall []struct { + } + getDataTrackTransportReturns struct { + result1 types.DataTrackTransport + } + getDataTrackTransportReturnsOnCall map[int]struct { + result1 types.DataTrackTransport + } + GetDisableSenderReportPassThroughStub func() bool + getDisableSenderReportPassThroughMutex sync.RWMutex + getDisableSenderReportPassThroughArgsForCall []struct { + } + getDisableSenderReportPassThroughReturns struct { + result1 bool + } + getDisableSenderReportPassThroughReturnsOnCall map[int]struct { + result1 bool + } + GetEnabledPublishCodecsStub func() []*livekit.Codec + getEnabledPublishCodecsMutex sync.RWMutex + getEnabledPublishCodecsArgsForCall []struct { + } + getEnabledPublishCodecsReturns struct { + result1 []*livekit.Codec + } + getEnabledPublishCodecsReturnsOnCall map[int]struct { + result1 []*livekit.Codec + } + GetICEConfigStub func() *livekit.ICEConfig + getICEConfigMutex sync.RWMutex + getICEConfigArgsForCall []struct { + } + getICEConfigReturns struct { + result1 *livekit.ICEConfig + } + getICEConfigReturnsOnCall map[int]struct { + result1 *livekit.ICEConfig + } + GetICEConnectionInfoStub func() []*types.ICEConnectionInfo + getICEConnectionInfoMutex sync.RWMutex + getICEConnectionInfoArgsForCall []struct { + } + getICEConnectionInfoReturns struct { + result1 []*types.ICEConnectionInfo + } + getICEConnectionInfoReturnsOnCall map[int]struct { + result1 []*types.ICEConnectionInfo + } + GetLastReliableSequenceStub func(bool) uint32 + getLastReliableSequenceMutex sync.RWMutex + getLastReliableSequenceArgsForCall []struct { + arg1 bool + } + getLastReliableSequenceReturns struct { + result1 uint32 + } + getLastReliableSequenceReturnsOnCall map[int]struct { + result1 uint32 + } + GetLoggerStub func() logger.Logger + getLoggerMutex sync.RWMutex + getLoggerArgsForCall []struct { + } + getLoggerReturns struct { + result1 logger.Logger + } + getLoggerReturnsOnCall map[int]struct { + result1 logger.Logger + } + GetLoggerResolverStub func() logger.DeferredFieldResolver + getLoggerResolverMutex sync.RWMutex + getLoggerResolverArgsForCall []struct { + } + getLoggerResolverReturns struct { + result1 logger.DeferredFieldResolver + } + getLoggerResolverReturnsOnCall map[int]struct { + result1 logger.DeferredFieldResolver + } + GetNextSubscribedDataTrackHandleStub func() uint16 + getNextSubscribedDataTrackHandleMutex sync.RWMutex + getNextSubscribedDataTrackHandleArgsForCall []struct { + } + getNextSubscribedDataTrackHandleReturns struct { + result1 uint16 + } + getNextSubscribedDataTrackHandleReturnsOnCall map[int]struct { + result1 uint16 + } + GetPacerStub func() pacer.Pacer + getPacerMutex sync.RWMutex + getPacerArgsForCall []struct { + } + getPacerReturns struct { + result1 pacer.Pacer + } + getPacerReturnsOnCall map[int]struct { + result1 pacer.Pacer + } + GetParticipantListenerStub func() types.ParticipantListener + getParticipantListenerMutex sync.RWMutex + getParticipantListenerArgsForCall []struct { + } + getParticipantListenerReturns struct { + result1 types.ParticipantListener + } + getParticipantListenerReturnsOnCall map[int]struct { + result1 types.ParticipantListener + } + GetPendingTrackStub func(livekit.TrackID) *livekit.TrackInfo + getPendingTrackMutex sync.RWMutex + getPendingTrackArgsForCall []struct { + arg1 livekit.TrackID + } + getPendingTrackReturns struct { + result1 *livekit.TrackInfo + } + getPendingTrackReturnsOnCall map[int]struct { + result1 *livekit.TrackInfo + } + GetPlayoutDelayConfigStub func() *livekit.PlayoutDelay + getPlayoutDelayConfigMutex sync.RWMutex + getPlayoutDelayConfigArgsForCall []struct { + } + getPlayoutDelayConfigReturns struct { + result1 *livekit.PlayoutDelay + } + getPlayoutDelayConfigReturnsOnCall map[int]struct { + result1 *livekit.PlayoutDelay + } + GetPublishedDataTrackStub func(uint16) types.DataTrack + getPublishedDataTrackMutex sync.RWMutex + getPublishedDataTrackArgsForCall []struct { + arg1 uint16 + } + getPublishedDataTrackReturns struct { + result1 types.DataTrack + } + getPublishedDataTrackReturnsOnCall map[int]struct { + result1 types.DataTrack + } + GetPublishedDataTracksStub func() []types.DataTrack + getPublishedDataTracksMutex sync.RWMutex + getPublishedDataTracksArgsForCall []struct { + } + getPublishedDataTracksReturns struct { + result1 []types.DataTrack + } + getPublishedDataTracksReturnsOnCall map[int]struct { + result1 []types.DataTrack + } + GetPublishedTrackStub func(livekit.TrackID) types.MediaTrack + getPublishedTrackMutex sync.RWMutex + getPublishedTrackArgsForCall []struct { + arg1 livekit.TrackID + } + getPublishedTrackReturns struct { + result1 types.MediaTrack + } + getPublishedTrackReturnsOnCall map[int]struct { + result1 types.MediaTrack + } + GetPublishedTracksStub func() []types.MediaTrack + getPublishedTracksMutex sync.RWMutex + getPublishedTracksArgsForCall []struct { + } + getPublishedTracksReturns struct { + result1 []types.MediaTrack + } + getPublishedTracksReturnsOnCall map[int]struct { + result1 []types.MediaTrack + } + GetPublisherICESessionUfragStub func() (string, error) + getPublisherICESessionUfragMutex sync.RWMutex + getPublisherICESessionUfragArgsForCall []struct { + } + getPublisherICESessionUfragReturns struct { + result1 string + result2 error + } + getPublisherICESessionUfragReturnsOnCall map[int]struct { + result1 string + result2 error + } + GetReporterStub func() roomobs.ParticipantSessionReporter + getReporterMutex sync.RWMutex + getReporterArgsForCall []struct { + } + getReporterReturns struct { + result1 roomobs.ParticipantSessionReporter + } + getReporterReturnsOnCall map[int]struct { + result1 roomobs.ParticipantSessionReporter + } + GetReporterResolverStub func() roomobs.ParticipantReporterResolver + getReporterResolverMutex sync.RWMutex + getReporterResolverArgsForCall []struct { + } + getReporterResolverReturns struct { + result1 roomobs.ParticipantReporterResolver + } + getReporterResolverReturnsOnCall map[int]struct { + result1 roomobs.ParticipantReporterResolver + } + GetResponseSinkStub func() routing.MessageSink + getResponseSinkMutex sync.RWMutex + getResponseSinkArgsForCall []struct { + } + getResponseSinkReturns struct { + result1 routing.MessageSink + } + getResponseSinkReturnsOnCall map[int]struct { + result1 routing.MessageSink + } + GetSubscribedParticipantsStub func() []livekit.ParticipantID + getSubscribedParticipantsMutex sync.RWMutex + getSubscribedParticipantsArgsForCall []struct { + } + getSubscribedParticipantsReturns struct { + result1 []livekit.ParticipantID + } + getSubscribedParticipantsReturnsOnCall map[int]struct { + result1 []livekit.ParticipantID + } + GetSubscribedTracksStub func() []types.SubscribedTrack + getSubscribedTracksMutex sync.RWMutex + getSubscribedTracksArgsForCall []struct { + } + getSubscribedTracksReturns struct { + result1 []types.SubscribedTrack + } + getSubscribedTracksReturnsOnCall map[int]struct { + result1 []types.SubscribedTrack + } + GetTrailerStub func() []byte + getTrailerMutex sync.RWMutex + getTrailerArgsForCall []struct { + } + getTrailerReturns struct { + result1 []byte + } + getTrailerReturnsOnCall map[int]struct { + result1 []byte + } + HandleAnswerStub func(*livekit.SessionDescription) + handleAnswerMutex sync.RWMutex + handleAnswerArgsForCall []struct { + arg1 *livekit.SessionDescription + } + HandleICERestartSDPFragmentStub func(string) (string, error) + handleICERestartSDPFragmentMutex sync.RWMutex + handleICERestartSDPFragmentArgsForCall []struct { + arg1 string + } + handleICERestartSDPFragmentReturns struct { + result1 string + result2 error + } + handleICERestartSDPFragmentReturnsOnCall map[int]struct { + result1 string + result2 error + } + HandleICETrickleStub func(*livekit.TrickleRequest) + handleICETrickleMutex sync.RWMutex + handleICETrickleArgsForCall []struct { + arg1 *livekit.TrickleRequest + } + HandleICETrickleSDPFragmentStub func(string) error + handleICETrickleSDPFragmentMutex sync.RWMutex + handleICETrickleSDPFragmentArgsForCall []struct { + arg1 string + } + handleICETrickleSDPFragmentReturns struct { + result1 error + } + handleICETrickleSDPFragmentReturnsOnCall map[int]struct { + result1 error + } + HandleLeaveRequestStub func(types.ParticipantCloseReason) + handleLeaveRequestMutex sync.RWMutex + handleLeaveRequestArgsForCall []struct { + arg1 types.ParticipantCloseReason + } + HandleMetricsStub func(livekit.ParticipantID, *livekit.MetricsBatch) error + handleMetricsMutex sync.RWMutex + handleMetricsArgsForCall []struct { + arg1 livekit.ParticipantID + arg2 *livekit.MetricsBatch + } + handleMetricsReturns struct { + result1 error + } + handleMetricsReturnsOnCall map[int]struct { + result1 error + } + HandleOfferStub func(*livekit.SessionDescription) error + handleOfferMutex sync.RWMutex + handleOfferArgsForCall []struct { + arg1 *livekit.SessionDescription + } + handleOfferReturns struct { + result1 error + } + handleOfferReturnsOnCall map[int]struct { + result1 error + } + HandlePublishDataTrackRequestStub func(*livekit.PublishDataTrackRequest) + handlePublishDataTrackRequestMutex sync.RWMutex + handlePublishDataTrackRequestArgsForCall []struct { + arg1 *livekit.PublishDataTrackRequest + } + HandleReceivedDataTrackMessageStub func([]byte, *datatrack.Packet, int64) + handleReceivedDataTrackMessageMutex sync.RWMutex + handleReceivedDataTrackMessageArgsForCall []struct { + arg1 []byte + arg2 *datatrack.Packet + arg3 int64 + } + HandleReceiverReportStub func(*sfu.DownTrack, *rtcp.ReceiverReport) + handleReceiverReportMutex sync.RWMutex + handleReceiverReportArgsForCall []struct { + arg1 *sfu.DownTrack + arg2 *rtcp.ReceiverReport + } + HandleReconnectAndSendResponseStub func(livekit.ReconnectReason, *livekit.ReconnectResponse) error + handleReconnectAndSendResponseMutex sync.RWMutex + handleReconnectAndSendResponseArgsForCall []struct { + arg1 livekit.ReconnectReason + arg2 *livekit.ReconnectResponse + } + handleReconnectAndSendResponseReturns struct { + result1 error + } + handleReconnectAndSendResponseReturnsOnCall map[int]struct { + result1 error + } + HandleSignalMessageStub func(proto.Message) error + handleSignalMessageMutex sync.RWMutex + handleSignalMessageArgsForCall []struct { + arg1 proto.Message + } + handleSignalMessageReturns struct { + result1 error + } + handleSignalMessageReturnsOnCall map[int]struct { + result1 error + } + HandleSignalSourceCloseStub func() + handleSignalSourceCloseMutex sync.RWMutex + handleSignalSourceCloseArgsForCall []struct { + } + HandleSimulateScenarioStub func(*livekit.SimulateScenario) error + handleSimulateScenarioMutex sync.RWMutex + handleSimulateScenarioArgsForCall []struct { + arg1 *livekit.SimulateScenario + } + handleSimulateScenarioReturns struct { + result1 error + } + handleSimulateScenarioReturnsOnCall map[int]struct { + result1 error + } + HandleSyncStateStub func(*livekit.SyncState) error + handleSyncStateMutex sync.RWMutex + handleSyncStateArgsForCall []struct { + arg1 *livekit.SyncState + } + handleSyncStateReturns struct { + result1 error + } + handleSyncStateReturnsOnCall map[int]struct { + result1 error + } + HandleUnpublishDataTrackRequestStub func(*livekit.UnpublishDataTrackRequest) + handleUnpublishDataTrackRequestMutex sync.RWMutex + handleUnpublishDataTrackRequestArgsForCall []struct { + arg1 *livekit.UnpublishDataTrackRequest + } + HandleUpdateDataSubscriptionStub func(*livekit.UpdateDataSubscription) + handleUpdateDataSubscriptionMutex sync.RWMutex + handleUpdateDataSubscriptionArgsForCall []struct { + arg1 *livekit.UpdateDataSubscription + } + HandleUpdateSubscriptionPermissionStub func(*livekit.SubscriptionPermission) error + handleUpdateSubscriptionPermissionMutex sync.RWMutex + handleUpdateSubscriptionPermissionArgsForCall []struct { + arg1 *livekit.SubscriptionPermission + } + handleUpdateSubscriptionPermissionReturns struct { + result1 error + } + handleUpdateSubscriptionPermissionReturnsOnCall map[int]struct { + result1 error + } + HandleUpdateSubscriptionsStub func([]livekit.TrackID, []*livekit.ParticipantTracks, bool) + handleUpdateSubscriptionsMutex sync.RWMutex + handleUpdateSubscriptionsArgsForCall []struct { + arg1 []livekit.TrackID + arg2 []*livekit.ParticipantTracks + arg3 bool + } + HasConnectedStub func() bool + hasConnectedMutex sync.RWMutex + hasConnectedArgsForCall []struct { + } + hasConnectedReturns struct { + result1 bool + } + hasConnectedReturnsOnCall map[int]struct { + result1 bool + } + HasPermissionStub func(livekit.TrackID, livekit.ParticipantIdentity) bool + hasPermissionMutex sync.RWMutex + hasPermissionArgsForCall []struct { + arg1 livekit.TrackID + arg2 livekit.ParticipantIdentity + } + hasPermissionReturns struct { + result1 bool + } + hasPermissionReturnsOnCall map[int]struct { + result1 bool + } + HiddenStub func() bool + hiddenMutex sync.RWMutex + hiddenArgsForCall []struct { + } + hiddenReturns struct { + result1 bool + } + hiddenReturnsOnCall map[int]struct { + result1 bool + } + ICERestartStub func(*livekit.ICEConfig) + iCERestartMutex sync.RWMutex + iCERestartArgsForCall []struct { + arg1 *livekit.ICEConfig + } + IDStub func() livekit.ParticipantID + iDMutex sync.RWMutex + iDArgsForCall []struct { + } + iDReturns struct { + result1 livekit.ParticipantID + } + iDReturnsOnCall map[int]struct { + result1 livekit.ParticipantID + } + IdentityStub func() livekit.ParticipantIdentity + identityMutex sync.RWMutex + identityArgsForCall []struct { + } + identityReturns struct { + result1 livekit.ParticipantIdentity + } + identityReturnsOnCall map[int]struct { + result1 livekit.ParticipantIdentity + } + IsAgentStub func() bool + isAgentMutex sync.RWMutex + isAgentArgsForCall []struct { + } + isAgentReturns struct { + result1 bool + } + isAgentReturnsOnCall map[int]struct { + result1 bool + } + IsClosedStub func() bool + isClosedMutex sync.RWMutex + isClosedArgsForCall []struct { + } + isClosedReturns struct { + result1 bool + } + isClosedReturnsOnCall map[int]struct { + result1 bool + } + IsDependentStub func() bool + isDependentMutex sync.RWMutex + isDependentArgsForCall []struct { + } + isDependentReturns struct { + result1 bool + } + isDependentReturnsOnCall map[int]struct { + result1 bool + } + IsDisconnectedStub func() bool + isDisconnectedMutex sync.RWMutex + isDisconnectedArgsForCall []struct { + } + isDisconnectedReturns struct { + result1 bool + } + isDisconnectedReturnsOnCall map[int]struct { + result1 bool + } + IsIdleStub func() bool + isIdleMutex sync.RWMutex + isIdleArgsForCall []struct { + } + isIdleReturns struct { + result1 bool + } + isIdleReturnsOnCall map[int]struct { + result1 bool + } + IsPublisherStub func() bool + isPublisherMutex sync.RWMutex + isPublisherArgsForCall []struct { + } + isPublisherReturns struct { + result1 bool + } + isPublisherReturnsOnCall map[int]struct { + result1 bool + } + IsReadyStub func() bool + isReadyMutex sync.RWMutex + isReadyArgsForCall []struct { + } + isReadyReturns struct { + result1 bool + } + isReadyReturnsOnCall map[int]struct { + result1 bool + } + IsReconnectStub func() bool + isReconnectMutex sync.RWMutex + isReconnectArgsForCall []struct { + } + isReconnectReturns struct { + result1 bool + } + isReconnectReturnsOnCall map[int]struct { + result1 bool + } + IsRecorderStub func() bool + isRecorderMutex sync.RWMutex + isRecorderArgsForCall []struct { + } + isRecorderReturns struct { + result1 bool + } + isRecorderReturnsOnCall map[int]struct { + result1 bool + } + IsSubscribedToStub func(livekit.ParticipantID) bool + isSubscribedToMutex sync.RWMutex + isSubscribedToArgsForCall []struct { + arg1 livekit.ParticipantID + } + isSubscribedToReturns struct { + result1 bool + } + isSubscribedToReturnsOnCall map[int]struct { + result1 bool + } + IsTrackNameSubscribedStub func(livekit.ParticipantIdentity, string) bool + isTrackNameSubscribedMutex sync.RWMutex + isTrackNameSubscribedArgsForCall []struct { + arg1 livekit.ParticipantIdentity + arg2 string + } + isTrackNameSubscribedReturns struct { + result1 bool + } + isTrackNameSubscribedReturnsOnCall map[int]struct { + result1 bool + } + IsUsingSinglePeerConnectionStub func() bool + isUsingSinglePeerConnectionMutex sync.RWMutex + isUsingSinglePeerConnectionArgsForCall []struct { + } + isUsingSinglePeerConnectionReturns struct { + result1 bool + } + isUsingSinglePeerConnectionReturnsOnCall map[int]struct { + result1 bool + } + IssueFullReconnectStub func(types.ParticipantCloseReason) + issueFullReconnectMutex sync.RWMutex + issueFullReconnectArgsForCall []struct { + arg1 types.ParticipantCloseReason + } + KindStub func() livekit.ParticipantInfo_Kind + kindMutex sync.RWMutex + kindArgsForCall []struct { + } + kindReturns struct { + result1 livekit.ParticipantInfo_Kind + } + kindReturnsOnCall map[int]struct { + result1 livekit.ParticipantInfo_Kind + } + MaybeStartMigrationStub func(bool, func()) bool + maybeStartMigrationMutex sync.RWMutex + maybeStartMigrationArgsForCall []struct { + arg1 bool + arg2 func() + } + maybeStartMigrationReturns struct { + result1 bool + } + maybeStartMigrationReturnsOnCall map[int]struct { + result1 bool + } + MigrateStateStub func() types.MigrateState + migrateStateMutex sync.RWMutex + migrateStateArgsForCall []struct { + } + migrateStateReturns struct { + result1 types.MigrateState + } + migrateStateReturnsOnCall map[int]struct { + result1 types.MigrateState + } + MoveToRoomStub func(types.MoveToRoomParams) + moveToRoomMutex sync.RWMutex + moveToRoomArgsForCall []struct { + arg1 types.MoveToRoomParams + } + NegotiateStub func(bool) + negotiateMutex sync.RWMutex + negotiateArgsForCall []struct { + arg1 bool + } + NotifyMigrationStub func() + notifyMigrationMutex sync.RWMutex + notifyMigrationArgsForCall []struct { + } + OnClaimsChangedStub func(func(types.LocalParticipant)) + onClaimsChangedMutex sync.RWMutex + onClaimsChangedArgsForCall []struct { + arg1 func(types.LocalParticipant) + } + OnICEConfigChangedStub func(func(participant types.LocalParticipant, iceConfig *livekit.ICEConfig)) + onICEConfigChangedMutex sync.RWMutex + onICEConfigChangedArgsForCall []struct { + arg1 func(participant types.LocalParticipant, iceConfig *livekit.ICEConfig) + } + PerformRpcStub func(*livekit.PerformRpcRequest, chan string, chan error) + performRpcMutex sync.RWMutex + performRpcArgsForCall []struct { + arg1 *livekit.PerformRpcRequest + arg2 chan string + arg3 chan error + } + ProtocolVersionStub func() types.ProtocolVersion + protocolVersionMutex sync.RWMutex + protocolVersionArgsForCall []struct { + } + protocolVersionReturns struct { + result1 types.ProtocolVersion + } + protocolVersionReturnsOnCall map[int]struct { + result1 types.ProtocolVersion + } + RemovePublishedDataTrackStub func(types.DataTrack) + removePublishedDataTrackMutex sync.RWMutex + removePublishedDataTrackArgsForCall []struct { + arg1 types.DataTrack + } + RemovePublishedTrackStub func(types.MediaTrack, bool) + removePublishedTrackMutex sync.RWMutex + removePublishedTrackArgsForCall []struct { + arg1 types.MediaTrack + arg2 bool + } + RemoveTrackLocalStub func(*webrtc.RTPSender) error + removeTrackLocalMutex sync.RWMutex + removeTrackLocalArgsForCall []struct { + arg1 *webrtc.RTPSender + } + removeTrackLocalReturns struct { + result1 error + } + removeTrackLocalReturnsOnCall map[int]struct { + result1 error + } + SendConnectionQualityUpdateStub func(*livekit.ConnectionQualityUpdate) error + sendConnectionQualityUpdateMutex sync.RWMutex + sendConnectionQualityUpdateArgsForCall []struct { + arg1 *livekit.ConnectionQualityUpdate + } + sendConnectionQualityUpdateReturns struct { + result1 error + } + sendConnectionQualityUpdateReturnsOnCall map[int]struct { + result1 error + } + SendDataMessageStub func(livekit.DataPacket_Kind, []byte, livekit.ParticipantID, uint32) error + sendDataMessageMutex sync.RWMutex + sendDataMessageArgsForCall []struct { + arg1 livekit.DataPacket_Kind + arg2 []byte + arg3 livekit.ParticipantID + arg4 uint32 + } + sendDataMessageReturns struct { + result1 error + } + sendDataMessageReturnsOnCall map[int]struct { + result1 error + } + SendDataMessageUnlabeledStub func([]byte, bool, livekit.ParticipantIdentity) error + sendDataMessageUnlabeledMutex sync.RWMutex + sendDataMessageUnlabeledArgsForCall []struct { + arg1 []byte + arg2 bool + arg3 livekit.ParticipantIdentity + } + sendDataMessageUnlabeledReturns struct { + result1 error + } + sendDataMessageUnlabeledReturnsOnCall map[int]struct { + result1 error + } + SendDataTrackSubscriberHandlesStub func(map[uint32]*livekit.DataTrackSubscriberHandles_PublishedDataTrack) error + sendDataTrackSubscriberHandlesMutex sync.RWMutex + sendDataTrackSubscriberHandlesArgsForCall []struct { + arg1 map[uint32]*livekit.DataTrackSubscriberHandles_PublishedDataTrack + } + sendDataTrackSubscriberHandlesReturns struct { + result1 error + } + sendDataTrackSubscriberHandlesReturnsOnCall map[int]struct { + result1 error + } + SendJoinResponseStub func(*livekit.JoinResponse) error + sendJoinResponseMutex sync.RWMutex + sendJoinResponseArgsForCall []struct { + arg1 *livekit.JoinResponse + } + sendJoinResponseReturns struct { + result1 error + } + sendJoinResponseReturnsOnCall map[int]struct { + result1 error + } + SendParticipantUpdateStub func([]*livekit.ParticipantInfo) error + sendParticipantUpdateMutex sync.RWMutex + sendParticipantUpdateArgsForCall []struct { + arg1 []*livekit.ParticipantInfo + } + sendParticipantUpdateReturns struct { + result1 error + } + sendParticipantUpdateReturnsOnCall map[int]struct { + result1 error + } + SendRefreshTokenStub func(string) error + sendRefreshTokenMutex sync.RWMutex + sendRefreshTokenArgsForCall []struct { + arg1 string + } + sendRefreshTokenReturns struct { + result1 error + } + sendRefreshTokenReturnsOnCall map[int]struct { + result1 error + } + SendRoomMovedResponseStub func(*livekit.RoomMovedResponse) error + sendRoomMovedResponseMutex sync.RWMutex + sendRoomMovedResponseArgsForCall []struct { + arg1 *livekit.RoomMovedResponse + } + sendRoomMovedResponseReturns struct { + result1 error + } + sendRoomMovedResponseReturnsOnCall map[int]struct { + result1 error + } + SendRoomUpdateStub func(*livekit.Room) error + sendRoomUpdateMutex sync.RWMutex + sendRoomUpdateArgsForCall []struct { + arg1 *livekit.Room + } + sendRoomUpdateReturns struct { + result1 error + } + sendRoomUpdateReturnsOnCall map[int]struct { + result1 error + } + SendSpeakerUpdateStub func([]*livekit.SpeakerInfo, bool) error + sendSpeakerUpdateMutex sync.RWMutex + sendSpeakerUpdateArgsForCall []struct { + arg1 []*livekit.SpeakerInfo + arg2 bool + } + sendSpeakerUpdateReturns struct { + result1 error + } + sendSpeakerUpdateReturnsOnCall map[int]struct { + result1 error + } + SendSubscriptionPermissionUpdateStub func(livekit.ParticipantID, livekit.TrackID, bool) error + sendSubscriptionPermissionUpdateMutex sync.RWMutex + sendSubscriptionPermissionUpdateArgsForCall []struct { + arg1 livekit.ParticipantID + arg2 livekit.TrackID + arg3 bool + } + sendSubscriptionPermissionUpdateReturns struct { + result1 error + } + sendSubscriptionPermissionUpdateReturnsOnCall map[int]struct { + result1 error + } + SetAttributesStub func(map[string]string) + setAttributesMutex sync.RWMutex + setAttributesArgsForCall []struct { + arg1 map[string]string + } + SetICEConfigStub func(*livekit.ICEConfig) + setICEConfigMutex sync.RWMutex + setICEConfigArgsForCall []struct { + arg1 *livekit.ICEConfig + } + SetMetadataStub func(string) + setMetadataMutex sync.RWMutex + setMetadataArgsForCall []struct { + arg1 string + } + SetMigrateInfoStub func(*webrtc.SessionDescription, *webrtc.SessionDescription, []*livekit.TrackPublishedResponse, []*livekit.DataChannelInfo, []*livekit.DataChannelReceiveState, []*livekit.PublishDataTrackResponse) + setMigrateInfoMutex sync.RWMutex + setMigrateInfoArgsForCall []struct { + arg1 *webrtc.SessionDescription + arg2 *webrtc.SessionDescription + arg3 []*livekit.TrackPublishedResponse + arg4 []*livekit.DataChannelInfo + arg5 []*livekit.DataChannelReceiveState + arg6 []*livekit.PublishDataTrackResponse + } + SetMigrateStateStub func(types.MigrateState) + setMigrateStateMutex sync.RWMutex + setMigrateStateArgsForCall []struct { + arg1 types.MigrateState + } + SetNameStub func(string) + setNameMutex sync.RWMutex + setNameArgsForCall []struct { + arg1 string + } + SetPermissionStub func(*livekit.ParticipantPermission) bool + setPermissionMutex sync.RWMutex + setPermissionArgsForCall []struct { + arg1 *livekit.ParticipantPermission + } + setPermissionReturns struct { + result1 bool + } + setPermissionReturnsOnCall map[int]struct { + result1 bool + } + SetSignalSourceValidStub func(bool) + setSignalSourceValidMutex sync.RWMutex + setSignalSourceValidArgsForCall []struct { + arg1 bool + } + SetSubscriberAllowPauseStub func(bool) + setSubscriberAllowPauseMutex sync.RWMutex + setSubscriberAllowPauseArgsForCall []struct { + arg1 bool + } + SetSubscriberChannelCapacityStub func(int64) + setSubscriberChannelCapacityMutex sync.RWMutex + setSubscriberChannelCapacityArgsForCall []struct { + arg1 int64 + } + SetTrackMutedStub func(*livekit.MuteTrackRequest, bool) *livekit.TrackInfo + setTrackMutedMutex sync.RWMutex + setTrackMutedArgsForCall []struct { + arg1 *livekit.MuteTrackRequest + arg2 bool + } + setTrackMutedReturns struct { + result1 *livekit.TrackInfo + } + setTrackMutedReturnsOnCall map[int]struct { + result1 *livekit.TrackInfo + } + StateStub func() livekit.ParticipantInfo_State + stateMutex sync.RWMutex + stateArgsForCall []struct { + } + stateReturns struct { + result1 livekit.ParticipantInfo_State + } + stateReturnsOnCall map[int]struct { + result1 livekit.ParticipantInfo_State + } + StopAndGetSubscribedTracksForwarderStateStub func() map[livekit.TrackID]*livekit.RTPForwarderState + stopAndGetSubscribedTracksForwarderStateMutex sync.RWMutex + stopAndGetSubscribedTracksForwarderStateArgsForCall []struct { + } + stopAndGetSubscribedTracksForwarderStateReturns struct { + result1 map[livekit.TrackID]*livekit.RTPForwarderState + } + stopAndGetSubscribedTracksForwarderStateReturnsOnCall map[int]struct { + result1 map[livekit.TrackID]*livekit.RTPForwarderState + } + SubscribeToDataTrackStub func(livekit.TrackID) + subscribeToDataTrackMutex sync.RWMutex + subscribeToDataTrackArgsForCall []struct { + arg1 livekit.TrackID + } + SubscribeToTrackStub func(livekit.TrackID, bool) + subscribeToTrackMutex sync.RWMutex + subscribeToTrackArgsForCall []struct { + arg1 livekit.TrackID + arg2 bool + } + SubscriberAsPrimaryStub func() bool + subscriberAsPrimaryMutex sync.RWMutex + subscriberAsPrimaryArgsForCall []struct { + } + subscriberAsPrimaryReturns struct { + result1 bool + } + subscriberAsPrimaryReturnsOnCall map[int]struct { + result1 bool + } + SubscriptionPermissionStub func() (*livekit.SubscriptionPermission, utils.TimedVersion) + subscriptionPermissionMutex sync.RWMutex + subscriptionPermissionArgsForCall []struct { + } + subscriptionPermissionReturns struct { + result1 *livekit.SubscriptionPermission + result2 utils.TimedVersion + } + subscriptionPermissionReturnsOnCall map[int]struct { + result1 *livekit.SubscriptionPermission + result2 utils.TimedVersion + } + SupportsCodecChangeStub func() bool + supportsCodecChangeMutex sync.RWMutex + supportsCodecChangeArgsForCall []struct { + } + supportsCodecChangeReturns struct { + result1 bool + } + supportsCodecChangeReturnsOnCall map[int]struct { + result1 bool + } + SupportsMovingStub func() error + supportsMovingMutex sync.RWMutex + supportsMovingArgsForCall []struct { + } + supportsMovingReturns struct { + result1 error + } + supportsMovingReturnsOnCall map[int]struct { + result1 error + } + SupportsSyncStreamIDStub func() bool + supportsSyncStreamIDMutex sync.RWMutex + supportsSyncStreamIDArgsForCall []struct { + } + supportsSyncStreamIDReturns struct { + result1 bool + } + supportsSyncStreamIDReturnsOnCall map[int]struct { + result1 bool + } + SupportsTransceiverReuseStub func() bool + supportsTransceiverReuseMutex sync.RWMutex + supportsTransceiverReuseArgsForCall []struct { + } + supportsTransceiverReuseReturns struct { + result1 bool + } + supportsTransceiverReuseReturnsOnCall map[int]struct { + result1 bool + } + SwapResponseSinkStub func(routing.MessageSink, types.SignallingCloseReason) + swapResponseSinkMutex sync.RWMutex + swapResponseSinkArgsForCall []struct { + arg1 routing.MessageSink + arg2 types.SignallingCloseReason + } + TelemetryGuardStub func() *telemetry.ReferenceGuard + telemetryGuardMutex sync.RWMutex + telemetryGuardArgsForCall []struct { + } + telemetryGuardReturns struct { + result1 *telemetry.ReferenceGuard + } + telemetryGuardReturnsOnCall map[int]struct { + result1 *telemetry.ReferenceGuard + } + ToProtoStub func() *livekit.ParticipantInfo + toProtoMutex sync.RWMutex + toProtoArgsForCall []struct { + } + toProtoReturns struct { + result1 *livekit.ParticipantInfo + } + toProtoReturnsOnCall map[int]struct { + result1 *livekit.ParticipantInfo + } + ToProtoWithVersionStub func() (*livekit.ParticipantInfo, utils.TimedVersion) + toProtoWithVersionMutex sync.RWMutex + toProtoWithVersionArgsForCall []struct { + } + toProtoWithVersionReturns struct { + result1 *livekit.ParticipantInfo + result2 utils.TimedVersion + } + toProtoWithVersionReturnsOnCall map[int]struct { + result1 *livekit.ParticipantInfo + result2 utils.TimedVersion + } + UncacheDownTrackStub func(*webrtc.RTPTransceiver) + uncacheDownTrackMutex sync.RWMutex + uncacheDownTrackArgsForCall []struct { + arg1 *webrtc.RTPTransceiver + } + UnsubscribeFromDataTrackStub func(livekit.TrackID) + unsubscribeFromDataTrackMutex sync.RWMutex + unsubscribeFromDataTrackArgsForCall []struct { + arg1 livekit.TrackID + } + UnsubscribeFromTrackStub func(livekit.TrackID) + unsubscribeFromTrackMutex sync.RWMutex + unsubscribeFromTrackArgsForCall []struct { + arg1 livekit.TrackID + } + UpdateAudioTrackStub func(*livekit.UpdateLocalAudioTrack) error + updateAudioTrackMutex sync.RWMutex + updateAudioTrackArgsForCall []struct { + arg1 *livekit.UpdateLocalAudioTrack + } + updateAudioTrackReturns struct { + result1 error + } + updateAudioTrackReturnsOnCall map[int]struct { + result1 error + } + UpdateDataTrackSubscriptionOptionsStub func(livekit.TrackID, *livekit.DataTrackSubscriptionOptions) + updateDataTrackSubscriptionOptionsMutex sync.RWMutex + updateDataTrackSubscriptionOptionsArgsForCall []struct { + arg1 livekit.TrackID + arg2 *livekit.DataTrackSubscriptionOptions + } + UpdateLastSeenSignalStub func() + updateLastSeenSignalMutex sync.RWMutex + updateLastSeenSignalArgsForCall []struct { + } + UpdateMediaLossStub func(livekit.NodeID, livekit.TrackID, uint32) error + updateMediaLossMutex sync.RWMutex + updateMediaLossArgsForCall []struct { + arg1 livekit.NodeID + arg2 livekit.TrackID + arg3 uint32 + } + updateMediaLossReturns struct { + result1 error + } + updateMediaLossReturnsOnCall map[int]struct { + result1 error + } + UpdateMediaRTTStub func(uint32) + updateMediaRTTMutex sync.RWMutex + updateMediaRTTArgsForCall []struct { + arg1 uint32 + } + UpdateMetadataStub func(*livekit.UpdateParticipantMetadata, bool) error + updateMetadataMutex sync.RWMutex + updateMetadataArgsForCall []struct { + arg1 *livekit.UpdateParticipantMetadata + arg2 bool + } + updateMetadataReturns struct { + result1 error + } + updateMetadataReturnsOnCall map[int]struct { + result1 error + } + UpdateSignalingRTTStub func(uint32) + updateSignalingRTTMutex sync.RWMutex + updateSignalingRTTArgsForCall []struct { + arg1 uint32 + } + UpdateSubscribedAudioCodecsStub func(livekit.NodeID, livekit.TrackID, []*livekit.SubscribedAudioCodec) error + updateSubscribedAudioCodecsMutex sync.RWMutex + updateSubscribedAudioCodecsArgsForCall []struct { + arg1 livekit.NodeID + arg2 livekit.TrackID + arg3 []*livekit.SubscribedAudioCodec + } + updateSubscribedAudioCodecsReturns struct { + result1 error + } + updateSubscribedAudioCodecsReturnsOnCall map[int]struct { + result1 error + } + UpdateSubscribedQualityStub func(livekit.NodeID, livekit.TrackID, []types.SubscribedCodecQuality) error + updateSubscribedQualityMutex sync.RWMutex + updateSubscribedQualityArgsForCall []struct { + arg1 livekit.NodeID + arg2 livekit.TrackID + arg3 []types.SubscribedCodecQuality + } + updateSubscribedQualityReturns struct { + result1 error + } + updateSubscribedQualityReturnsOnCall map[int]struct { + result1 error + } + UpdateSubscribedTrackSettingsStub func(livekit.TrackID, *livekit.UpdateTrackSettings) + updateSubscribedTrackSettingsMutex sync.RWMutex + updateSubscribedTrackSettingsArgsForCall []struct { + arg1 livekit.TrackID + arg2 *livekit.UpdateTrackSettings + } + UpdateSubscriptionPermissionStub func(*livekit.SubscriptionPermission, utils.TimedVersion, func(participantID livekit.ParticipantID) types.LocalParticipant) error + updateSubscriptionPermissionMutex sync.RWMutex + updateSubscriptionPermissionArgsForCall []struct { + arg1 *livekit.SubscriptionPermission + arg2 utils.TimedVersion + arg3 func(participantID livekit.ParticipantID) types.LocalParticipant + } + updateSubscriptionPermissionReturns struct { + result1 error + } + updateSubscriptionPermissionReturnsOnCall map[int]struct { + result1 error + } + UpdateVideoTrackStub func(*livekit.UpdateLocalVideoTrack) error + updateVideoTrackMutex sync.RWMutex + updateVideoTrackArgsForCall []struct { + arg1 *livekit.UpdateLocalVideoTrack + } + updateVideoTrackReturns struct { + result1 error + } + updateVideoTrackReturnsOnCall map[int]struct { + result1 error + } + VerifyStub func() bool + verifyMutex sync.RWMutex + verifyArgsForCall []struct { + } + verifyReturns struct { + result1 bool + } + verifyReturnsOnCall map[int]struct { + result1 bool + } + VerifySubscribeParticipantInfoStub func(livekit.ParticipantID, uint32) + verifySubscribeParticipantInfoMutex sync.RWMutex + verifySubscribeParticipantInfoArgsForCall []struct { + arg1 livekit.ParticipantID + arg2 uint32 + } + VersionStub func() utils.TimedVersion + versionMutex sync.RWMutex + versionArgsForCall []struct { + } + versionReturns struct { + result1 utils.TimedVersion + } + versionReturnsOnCall map[int]struct { + result1 utils.TimedVersion + } + WaitUntilSubscribedStub func(time.Duration) error + waitUntilSubscribedMutex sync.RWMutex + waitUntilSubscribedArgsForCall []struct { + arg1 time.Duration + } + waitUntilSubscribedReturns struct { + result1 error + } + waitUntilSubscribedReturnsOnCall map[int]struct { + result1 error + } + WriteSubscriberRTCPStub func([]rtcp.Packet) error + writeSubscriberRTCPMutex sync.RWMutex + writeSubscriberRTCPArgsForCall []struct { + arg1 []rtcp.Packet + } + writeSubscriberRTCPReturns struct { + result1 error + } + writeSubscriberRTCPReturnsOnCall map[int]struct { + result1 error + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeLocalParticipant) ActiveAt() time.Time { + fake.activeAtMutex.Lock() + ret, specificReturn := fake.activeAtReturnsOnCall[len(fake.activeAtArgsForCall)] + fake.activeAtArgsForCall = append(fake.activeAtArgsForCall, struct { + }{}) + stub := fake.ActiveAtStub + fakeReturns := fake.activeAtReturns + fake.recordInvocation("ActiveAt", []interface{}{}) + fake.activeAtMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) ActiveAtCallCount() int { + fake.activeAtMutex.RLock() + defer fake.activeAtMutex.RUnlock() + return len(fake.activeAtArgsForCall) +} + +func (fake *FakeLocalParticipant) ActiveAtCalls(stub func() time.Time) { + fake.activeAtMutex.Lock() + defer fake.activeAtMutex.Unlock() + fake.ActiveAtStub = stub +} + +func (fake *FakeLocalParticipant) ActiveAtReturns(result1 time.Time) { + fake.activeAtMutex.Lock() + defer fake.activeAtMutex.Unlock() + fake.ActiveAtStub = nil + fake.activeAtReturns = struct { + result1 time.Time + }{result1} +} + +func (fake *FakeLocalParticipant) ActiveAtReturnsOnCall(i int, result1 time.Time) { + fake.activeAtMutex.Lock() + defer fake.activeAtMutex.Unlock() + fake.ActiveAtStub = nil + if fake.activeAtReturnsOnCall == nil { + fake.activeAtReturnsOnCall = make(map[int]struct { + result1 time.Time + }) + } + fake.activeAtReturnsOnCall[i] = struct { + result1 time.Time + }{result1} +} + +func (fake *FakeLocalParticipant) AddOnClose(arg1 string, arg2 func(types.LocalParticipant)) { + fake.addOnCloseMutex.Lock() + fake.addOnCloseArgsForCall = append(fake.addOnCloseArgsForCall, struct { + arg1 string + arg2 func(types.LocalParticipant) + }{arg1, arg2}) + stub := fake.AddOnCloseStub + fake.recordInvocation("AddOnClose", []interface{}{arg1, arg2}) + fake.addOnCloseMutex.Unlock() + if stub != nil { + fake.AddOnCloseStub(arg1, arg2) + } +} + +func (fake *FakeLocalParticipant) AddOnCloseCallCount() int { + fake.addOnCloseMutex.RLock() + defer fake.addOnCloseMutex.RUnlock() + return len(fake.addOnCloseArgsForCall) +} + +func (fake *FakeLocalParticipant) AddOnCloseCalls(stub func(string, func(types.LocalParticipant))) { + fake.addOnCloseMutex.Lock() + defer fake.addOnCloseMutex.Unlock() + fake.AddOnCloseStub = stub +} + +func (fake *FakeLocalParticipant) AddOnCloseArgsForCall(i int) (string, func(types.LocalParticipant)) { + fake.addOnCloseMutex.RLock() + defer fake.addOnCloseMutex.RUnlock() + argsForCall := fake.addOnCloseArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeLocalParticipant) AddTrack(arg1 *livekit.AddTrackRequest) { + fake.addTrackMutex.Lock() + fake.addTrackArgsForCall = append(fake.addTrackArgsForCall, struct { + arg1 *livekit.AddTrackRequest + }{arg1}) + stub := fake.AddTrackStub + fake.recordInvocation("AddTrack", []interface{}{arg1}) + fake.addTrackMutex.Unlock() + if stub != nil { + fake.AddTrackStub(arg1) + } +} + +func (fake *FakeLocalParticipant) AddTrackCallCount() int { + fake.addTrackMutex.RLock() + defer fake.addTrackMutex.RUnlock() + return len(fake.addTrackArgsForCall) +} + +func (fake *FakeLocalParticipant) AddTrackCalls(stub func(*livekit.AddTrackRequest)) { + fake.addTrackMutex.Lock() + defer fake.addTrackMutex.Unlock() + fake.AddTrackStub = stub +} + +func (fake *FakeLocalParticipant) AddTrackArgsForCall(i int) *livekit.AddTrackRequest { + fake.addTrackMutex.RLock() + defer fake.addTrackMutex.RUnlock() + argsForCall := fake.addTrackArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalParticipant) AddTrackLocal(arg1 webrtc.TrackLocal, arg2 types.AddTrackParams) (*webrtc.RTPSender, *webrtc.RTPTransceiver, error) { + fake.addTrackLocalMutex.Lock() + ret, specificReturn := fake.addTrackLocalReturnsOnCall[len(fake.addTrackLocalArgsForCall)] + fake.addTrackLocalArgsForCall = append(fake.addTrackLocalArgsForCall, struct { + arg1 webrtc.TrackLocal + arg2 types.AddTrackParams + }{arg1, arg2}) + stub := fake.AddTrackLocalStub + fakeReturns := fake.addTrackLocalReturns + fake.recordInvocation("AddTrackLocal", []interface{}{arg1, arg2}) + fake.addTrackLocalMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1, ret.result2, ret.result3 + } + return fakeReturns.result1, fakeReturns.result2, fakeReturns.result3 +} + +func (fake *FakeLocalParticipant) AddTrackLocalCallCount() int { + fake.addTrackLocalMutex.RLock() + defer fake.addTrackLocalMutex.RUnlock() + return len(fake.addTrackLocalArgsForCall) +} + +func (fake *FakeLocalParticipant) AddTrackLocalCalls(stub func(webrtc.TrackLocal, types.AddTrackParams) (*webrtc.RTPSender, *webrtc.RTPTransceiver, error)) { + fake.addTrackLocalMutex.Lock() + defer fake.addTrackLocalMutex.Unlock() + fake.AddTrackLocalStub = stub +} + +func (fake *FakeLocalParticipant) AddTrackLocalArgsForCall(i int) (webrtc.TrackLocal, types.AddTrackParams) { + fake.addTrackLocalMutex.RLock() + defer fake.addTrackLocalMutex.RUnlock() + argsForCall := fake.addTrackLocalArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeLocalParticipant) AddTrackLocalReturns(result1 *webrtc.RTPSender, result2 *webrtc.RTPTransceiver, result3 error) { + fake.addTrackLocalMutex.Lock() + defer fake.addTrackLocalMutex.Unlock() + fake.AddTrackLocalStub = nil + fake.addTrackLocalReturns = struct { + result1 *webrtc.RTPSender + result2 *webrtc.RTPTransceiver + result3 error + }{result1, result2, result3} +} + +func (fake *FakeLocalParticipant) AddTrackLocalReturnsOnCall(i int, result1 *webrtc.RTPSender, result2 *webrtc.RTPTransceiver, result3 error) { + fake.addTrackLocalMutex.Lock() + defer fake.addTrackLocalMutex.Unlock() + fake.AddTrackLocalStub = nil + if fake.addTrackLocalReturnsOnCall == nil { + fake.addTrackLocalReturnsOnCall = make(map[int]struct { + result1 *webrtc.RTPSender + result2 *webrtc.RTPTransceiver + result3 error + }) + } + fake.addTrackLocalReturnsOnCall[i] = struct { + result1 *webrtc.RTPSender + result2 *webrtc.RTPTransceiver + result3 error + }{result1, result2, result3} +} + +func (fake *FakeLocalParticipant) AddTransceiverFromTrackLocal(arg1 webrtc.TrackLocal, arg2 types.AddTrackParams) (*webrtc.RTPSender, *webrtc.RTPTransceiver, error) { + fake.addTransceiverFromTrackLocalMutex.Lock() + ret, specificReturn := fake.addTransceiverFromTrackLocalReturnsOnCall[len(fake.addTransceiverFromTrackLocalArgsForCall)] + fake.addTransceiverFromTrackLocalArgsForCall = append(fake.addTransceiverFromTrackLocalArgsForCall, struct { + arg1 webrtc.TrackLocal + arg2 types.AddTrackParams + }{arg1, arg2}) + stub := fake.AddTransceiverFromTrackLocalStub + fakeReturns := fake.addTransceiverFromTrackLocalReturns + fake.recordInvocation("AddTransceiverFromTrackLocal", []interface{}{arg1, arg2}) + fake.addTransceiverFromTrackLocalMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1, ret.result2, ret.result3 + } + return fakeReturns.result1, fakeReturns.result2, fakeReturns.result3 +} + +func (fake *FakeLocalParticipant) AddTransceiverFromTrackLocalCallCount() int { + fake.addTransceiverFromTrackLocalMutex.RLock() + defer fake.addTransceiverFromTrackLocalMutex.RUnlock() + return len(fake.addTransceiverFromTrackLocalArgsForCall) +} + +func (fake *FakeLocalParticipant) AddTransceiverFromTrackLocalCalls(stub func(webrtc.TrackLocal, types.AddTrackParams) (*webrtc.RTPSender, *webrtc.RTPTransceiver, error)) { + fake.addTransceiverFromTrackLocalMutex.Lock() + defer fake.addTransceiverFromTrackLocalMutex.Unlock() + fake.AddTransceiverFromTrackLocalStub = stub +} + +func (fake *FakeLocalParticipant) AddTransceiverFromTrackLocalArgsForCall(i int) (webrtc.TrackLocal, types.AddTrackParams) { + fake.addTransceiverFromTrackLocalMutex.RLock() + defer fake.addTransceiverFromTrackLocalMutex.RUnlock() + argsForCall := fake.addTransceiverFromTrackLocalArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeLocalParticipant) AddTransceiverFromTrackLocalReturns(result1 *webrtc.RTPSender, result2 *webrtc.RTPTransceiver, result3 error) { + fake.addTransceiverFromTrackLocalMutex.Lock() + defer fake.addTransceiverFromTrackLocalMutex.Unlock() + fake.AddTransceiverFromTrackLocalStub = nil + fake.addTransceiverFromTrackLocalReturns = struct { + result1 *webrtc.RTPSender + result2 *webrtc.RTPTransceiver + result3 error + }{result1, result2, result3} +} + +func (fake *FakeLocalParticipant) AddTransceiverFromTrackLocalReturnsOnCall(i int, result1 *webrtc.RTPSender, result2 *webrtc.RTPTransceiver, result3 error) { + fake.addTransceiverFromTrackLocalMutex.Lock() + defer fake.addTransceiverFromTrackLocalMutex.Unlock() + fake.AddTransceiverFromTrackLocalStub = nil + if fake.addTransceiverFromTrackLocalReturnsOnCall == nil { + fake.addTransceiverFromTrackLocalReturnsOnCall = make(map[int]struct { + result1 *webrtc.RTPSender + result2 *webrtc.RTPTransceiver + result3 error + }) + } + fake.addTransceiverFromTrackLocalReturnsOnCall[i] = struct { + result1 *webrtc.RTPSender + result2 *webrtc.RTPTransceiver + result3 error + }{result1, result2, result3} +} + +func (fake *FakeLocalParticipant) CacheDownTrack(arg1 livekit.TrackID, arg2 *webrtc.RTPTransceiver, arg3 sfu.DownTrackState) { + fake.cacheDownTrackMutex.Lock() + fake.cacheDownTrackArgsForCall = append(fake.cacheDownTrackArgsForCall, struct { + arg1 livekit.TrackID + arg2 *webrtc.RTPTransceiver + arg3 sfu.DownTrackState + }{arg1, arg2, arg3}) + stub := fake.CacheDownTrackStub + fake.recordInvocation("CacheDownTrack", []interface{}{arg1, arg2, arg3}) + fake.cacheDownTrackMutex.Unlock() + if stub != nil { + fake.CacheDownTrackStub(arg1, arg2, arg3) + } +} + +func (fake *FakeLocalParticipant) CacheDownTrackCallCount() int { + fake.cacheDownTrackMutex.RLock() + defer fake.cacheDownTrackMutex.RUnlock() + return len(fake.cacheDownTrackArgsForCall) +} + +func (fake *FakeLocalParticipant) CacheDownTrackCalls(stub func(livekit.TrackID, *webrtc.RTPTransceiver, sfu.DownTrackState)) { + fake.cacheDownTrackMutex.Lock() + defer fake.cacheDownTrackMutex.Unlock() + fake.CacheDownTrackStub = stub +} + +func (fake *FakeLocalParticipant) CacheDownTrackArgsForCall(i int) (livekit.TrackID, *webrtc.RTPTransceiver, sfu.DownTrackState) { + fake.cacheDownTrackMutex.RLock() + defer fake.cacheDownTrackMutex.RUnlock() + argsForCall := fake.cacheDownTrackArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeLocalParticipant) CanPublish() bool { + fake.canPublishMutex.Lock() + ret, specificReturn := fake.canPublishReturnsOnCall[len(fake.canPublishArgsForCall)] + fake.canPublishArgsForCall = append(fake.canPublishArgsForCall, struct { + }{}) + stub := fake.CanPublishStub + fakeReturns := fake.canPublishReturns + fake.recordInvocation("CanPublish", []interface{}{}) + fake.canPublishMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) CanPublishCallCount() int { + fake.canPublishMutex.RLock() + defer fake.canPublishMutex.RUnlock() + return len(fake.canPublishArgsForCall) +} + +func (fake *FakeLocalParticipant) CanPublishCalls(stub func() bool) { + fake.canPublishMutex.Lock() + defer fake.canPublishMutex.Unlock() + fake.CanPublishStub = stub +} + +func (fake *FakeLocalParticipant) CanPublishReturns(result1 bool) { + fake.canPublishMutex.Lock() + defer fake.canPublishMutex.Unlock() + fake.CanPublishStub = nil + fake.canPublishReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalParticipant) CanPublishReturnsOnCall(i int, result1 bool) { + fake.canPublishMutex.Lock() + defer fake.canPublishMutex.Unlock() + fake.CanPublishStub = nil + if fake.canPublishReturnsOnCall == nil { + fake.canPublishReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.canPublishReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalParticipant) CanPublishData() bool { + fake.canPublishDataMutex.Lock() + ret, specificReturn := fake.canPublishDataReturnsOnCall[len(fake.canPublishDataArgsForCall)] + fake.canPublishDataArgsForCall = append(fake.canPublishDataArgsForCall, struct { + }{}) + stub := fake.CanPublishDataStub + fakeReturns := fake.canPublishDataReturns + fake.recordInvocation("CanPublishData", []interface{}{}) + fake.canPublishDataMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) CanPublishDataCallCount() int { + fake.canPublishDataMutex.RLock() + defer fake.canPublishDataMutex.RUnlock() + return len(fake.canPublishDataArgsForCall) +} + +func (fake *FakeLocalParticipant) CanPublishDataCalls(stub func() bool) { + fake.canPublishDataMutex.Lock() + defer fake.canPublishDataMutex.Unlock() + fake.CanPublishDataStub = stub +} + +func (fake *FakeLocalParticipant) CanPublishDataReturns(result1 bool) { + fake.canPublishDataMutex.Lock() + defer fake.canPublishDataMutex.Unlock() + fake.CanPublishDataStub = nil + fake.canPublishDataReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalParticipant) CanPublishDataReturnsOnCall(i int, result1 bool) { + fake.canPublishDataMutex.Lock() + defer fake.canPublishDataMutex.Unlock() + fake.CanPublishDataStub = nil + if fake.canPublishDataReturnsOnCall == nil { + fake.canPublishDataReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.canPublishDataReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalParticipant) CanPublishSource(arg1 livekit.TrackSource) bool { + fake.canPublishSourceMutex.Lock() + ret, specificReturn := fake.canPublishSourceReturnsOnCall[len(fake.canPublishSourceArgsForCall)] + fake.canPublishSourceArgsForCall = append(fake.canPublishSourceArgsForCall, struct { + arg1 livekit.TrackSource + }{arg1}) + stub := fake.CanPublishSourceStub + fakeReturns := fake.canPublishSourceReturns + fake.recordInvocation("CanPublishSource", []interface{}{arg1}) + fake.canPublishSourceMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) CanPublishSourceCallCount() int { + fake.canPublishSourceMutex.RLock() + defer fake.canPublishSourceMutex.RUnlock() + return len(fake.canPublishSourceArgsForCall) +} + +func (fake *FakeLocalParticipant) CanPublishSourceCalls(stub func(livekit.TrackSource) bool) { + fake.canPublishSourceMutex.Lock() + defer fake.canPublishSourceMutex.Unlock() + fake.CanPublishSourceStub = stub +} + +func (fake *FakeLocalParticipant) CanPublishSourceArgsForCall(i int) livekit.TrackSource { + fake.canPublishSourceMutex.RLock() + defer fake.canPublishSourceMutex.RUnlock() + argsForCall := fake.canPublishSourceArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalParticipant) CanPublishSourceReturns(result1 bool) { + fake.canPublishSourceMutex.Lock() + defer fake.canPublishSourceMutex.Unlock() + fake.CanPublishSourceStub = nil + fake.canPublishSourceReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalParticipant) CanPublishSourceReturnsOnCall(i int, result1 bool) { + fake.canPublishSourceMutex.Lock() + defer fake.canPublishSourceMutex.Unlock() + fake.CanPublishSourceStub = nil + if fake.canPublishSourceReturnsOnCall == nil { + fake.canPublishSourceReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.canPublishSourceReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalParticipant) CanSkipBroadcast() bool { + fake.canSkipBroadcastMutex.Lock() + ret, specificReturn := fake.canSkipBroadcastReturnsOnCall[len(fake.canSkipBroadcastArgsForCall)] + fake.canSkipBroadcastArgsForCall = append(fake.canSkipBroadcastArgsForCall, struct { + }{}) + stub := fake.CanSkipBroadcastStub + fakeReturns := fake.canSkipBroadcastReturns + fake.recordInvocation("CanSkipBroadcast", []interface{}{}) + fake.canSkipBroadcastMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) CanSkipBroadcastCallCount() int { + fake.canSkipBroadcastMutex.RLock() + defer fake.canSkipBroadcastMutex.RUnlock() + return len(fake.canSkipBroadcastArgsForCall) +} + +func (fake *FakeLocalParticipant) CanSkipBroadcastCalls(stub func() bool) { + fake.canSkipBroadcastMutex.Lock() + defer fake.canSkipBroadcastMutex.Unlock() + fake.CanSkipBroadcastStub = stub +} + +func (fake *FakeLocalParticipant) CanSkipBroadcastReturns(result1 bool) { + fake.canSkipBroadcastMutex.Lock() + defer fake.canSkipBroadcastMutex.Unlock() + fake.CanSkipBroadcastStub = nil + fake.canSkipBroadcastReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalParticipant) CanSkipBroadcastReturnsOnCall(i int, result1 bool) { + fake.canSkipBroadcastMutex.Lock() + defer fake.canSkipBroadcastMutex.Unlock() + fake.CanSkipBroadcastStub = nil + if fake.canSkipBroadcastReturnsOnCall == nil { + fake.canSkipBroadcastReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.canSkipBroadcastReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalParticipant) CanSubscribe() bool { + fake.canSubscribeMutex.Lock() + ret, specificReturn := fake.canSubscribeReturnsOnCall[len(fake.canSubscribeArgsForCall)] + fake.canSubscribeArgsForCall = append(fake.canSubscribeArgsForCall, struct { + }{}) + stub := fake.CanSubscribeStub + fakeReturns := fake.canSubscribeReturns + fake.recordInvocation("CanSubscribe", []interface{}{}) + fake.canSubscribeMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) CanSubscribeCallCount() int { + fake.canSubscribeMutex.RLock() + defer fake.canSubscribeMutex.RUnlock() + return len(fake.canSubscribeArgsForCall) +} + +func (fake *FakeLocalParticipant) CanSubscribeCalls(stub func() bool) { + fake.canSubscribeMutex.Lock() + defer fake.canSubscribeMutex.Unlock() + fake.CanSubscribeStub = stub +} + +func (fake *FakeLocalParticipant) CanSubscribeReturns(result1 bool) { + fake.canSubscribeMutex.Lock() + defer fake.canSubscribeMutex.Unlock() + fake.CanSubscribeStub = nil + fake.canSubscribeReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalParticipant) CanSubscribeReturnsOnCall(i int, result1 bool) { + fake.canSubscribeMutex.Lock() + defer fake.canSubscribeMutex.Unlock() + fake.CanSubscribeStub = nil + if fake.canSubscribeReturnsOnCall == nil { + fake.canSubscribeReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.canSubscribeReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalParticipant) ClaimGrants() *auth.ClaimGrants { + fake.claimGrantsMutex.Lock() + ret, specificReturn := fake.claimGrantsReturnsOnCall[len(fake.claimGrantsArgsForCall)] + fake.claimGrantsArgsForCall = append(fake.claimGrantsArgsForCall, struct { + }{}) + stub := fake.ClaimGrantsStub + fakeReturns := fake.claimGrantsReturns + fake.recordInvocation("ClaimGrants", []interface{}{}) + fake.claimGrantsMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) ClaimGrantsCallCount() int { + fake.claimGrantsMutex.RLock() + defer fake.claimGrantsMutex.RUnlock() + return len(fake.claimGrantsArgsForCall) +} + +func (fake *FakeLocalParticipant) ClaimGrantsCalls(stub func() *auth.ClaimGrants) { + fake.claimGrantsMutex.Lock() + defer fake.claimGrantsMutex.Unlock() + fake.ClaimGrantsStub = stub +} + +func (fake *FakeLocalParticipant) ClaimGrantsReturns(result1 *auth.ClaimGrants) { + fake.claimGrantsMutex.Lock() + defer fake.claimGrantsMutex.Unlock() + fake.ClaimGrantsStub = nil + fake.claimGrantsReturns = struct { + result1 *auth.ClaimGrants + }{result1} +} + +func (fake *FakeLocalParticipant) ClaimGrantsReturnsOnCall(i int, result1 *auth.ClaimGrants) { + fake.claimGrantsMutex.Lock() + defer fake.claimGrantsMutex.Unlock() + fake.ClaimGrantsStub = nil + if fake.claimGrantsReturnsOnCall == nil { + fake.claimGrantsReturnsOnCall = make(map[int]struct { + result1 *auth.ClaimGrants + }) + } + fake.claimGrantsReturnsOnCall[i] = struct { + result1 *auth.ClaimGrants + }{result1} +} + +func (fake *FakeLocalParticipant) ClearParticipantListener() { + fake.clearParticipantListenerMutex.Lock() + fake.clearParticipantListenerArgsForCall = append(fake.clearParticipantListenerArgsForCall, struct { + }{}) + stub := fake.ClearParticipantListenerStub + fake.recordInvocation("ClearParticipantListener", []interface{}{}) + fake.clearParticipantListenerMutex.Unlock() + if stub != nil { + fake.ClearParticipantListenerStub() + } +} + +func (fake *FakeLocalParticipant) ClearParticipantListenerCallCount() int { + fake.clearParticipantListenerMutex.RLock() + defer fake.clearParticipantListenerMutex.RUnlock() + return len(fake.clearParticipantListenerArgsForCall) +} + +func (fake *FakeLocalParticipant) ClearParticipantListenerCalls(stub func()) { + fake.clearParticipantListenerMutex.Lock() + defer fake.clearParticipantListenerMutex.Unlock() + fake.ClearParticipantListenerStub = stub +} + +func (fake *FakeLocalParticipant) Close(arg1 bool, arg2 types.ParticipantCloseReason, arg3 bool) error { + fake.closeMutex.Lock() + ret, specificReturn := fake.closeReturnsOnCall[len(fake.closeArgsForCall)] + fake.closeArgsForCall = append(fake.closeArgsForCall, struct { + arg1 bool + arg2 types.ParticipantCloseReason + arg3 bool + }{arg1, arg2, arg3}) + stub := fake.CloseStub + fakeReturns := fake.closeReturns + fake.recordInvocation("Close", []interface{}{arg1, arg2, arg3}) + fake.closeMutex.Unlock() + if stub != nil { + return stub(arg1, arg2, arg3) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) CloseCallCount() int { + fake.closeMutex.RLock() + defer fake.closeMutex.RUnlock() + return len(fake.closeArgsForCall) +} + +func (fake *FakeLocalParticipant) CloseCalls(stub func(bool, types.ParticipantCloseReason, bool) error) { + fake.closeMutex.Lock() + defer fake.closeMutex.Unlock() + fake.CloseStub = stub +} + +func (fake *FakeLocalParticipant) CloseArgsForCall(i int) (bool, types.ParticipantCloseReason, bool) { + fake.closeMutex.RLock() + defer fake.closeMutex.RUnlock() + argsForCall := fake.closeArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeLocalParticipant) CloseReturns(result1 error) { + fake.closeMutex.Lock() + defer fake.closeMutex.Unlock() + fake.CloseStub = nil + fake.closeReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipant) CloseReturnsOnCall(i int, result1 error) { + fake.closeMutex.Lock() + defer fake.closeMutex.Unlock() + fake.CloseStub = nil + if fake.closeReturnsOnCall == nil { + fake.closeReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.closeReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipant) CloseReason() types.ParticipantCloseReason { + fake.closeReasonMutex.Lock() + ret, specificReturn := fake.closeReasonReturnsOnCall[len(fake.closeReasonArgsForCall)] + fake.closeReasonArgsForCall = append(fake.closeReasonArgsForCall, struct { + }{}) + stub := fake.CloseReasonStub + fakeReturns := fake.closeReasonReturns + fake.recordInvocation("CloseReason", []interface{}{}) + fake.closeReasonMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) CloseReasonCallCount() int { + fake.closeReasonMutex.RLock() + defer fake.closeReasonMutex.RUnlock() + return len(fake.closeReasonArgsForCall) +} + +func (fake *FakeLocalParticipant) CloseReasonCalls(stub func() types.ParticipantCloseReason) { + fake.closeReasonMutex.Lock() + defer fake.closeReasonMutex.Unlock() + fake.CloseReasonStub = stub +} + +func (fake *FakeLocalParticipant) CloseReasonReturns(result1 types.ParticipantCloseReason) { + fake.closeReasonMutex.Lock() + defer fake.closeReasonMutex.Unlock() + fake.CloseReasonStub = nil + fake.closeReasonReturns = struct { + result1 types.ParticipantCloseReason + }{result1} +} + +func (fake *FakeLocalParticipant) CloseReasonReturnsOnCall(i int, result1 types.ParticipantCloseReason) { + fake.closeReasonMutex.Lock() + defer fake.closeReasonMutex.Unlock() + fake.CloseReasonStub = nil + if fake.closeReasonReturnsOnCall == nil { + fake.closeReasonReturnsOnCall = make(map[int]struct { + result1 types.ParticipantCloseReason + }) + } + fake.closeReasonReturnsOnCall[i] = struct { + result1 types.ParticipantCloseReason + }{result1} +} + +func (fake *FakeLocalParticipant) CloseSignalConnection(arg1 types.SignallingCloseReason) { + fake.closeSignalConnectionMutex.Lock() + fake.closeSignalConnectionArgsForCall = append(fake.closeSignalConnectionArgsForCall, struct { + arg1 types.SignallingCloseReason + }{arg1}) + stub := fake.CloseSignalConnectionStub + fake.recordInvocation("CloseSignalConnection", []interface{}{arg1}) + fake.closeSignalConnectionMutex.Unlock() + if stub != nil { + fake.CloseSignalConnectionStub(arg1) + } +} + +func (fake *FakeLocalParticipant) CloseSignalConnectionCallCount() int { + fake.closeSignalConnectionMutex.RLock() + defer fake.closeSignalConnectionMutex.RUnlock() + return len(fake.closeSignalConnectionArgsForCall) +} + +func (fake *FakeLocalParticipant) CloseSignalConnectionCalls(stub func(types.SignallingCloseReason)) { + fake.closeSignalConnectionMutex.Lock() + defer fake.closeSignalConnectionMutex.Unlock() + fake.CloseSignalConnectionStub = stub +} + +func (fake *FakeLocalParticipant) CloseSignalConnectionArgsForCall(i int) types.SignallingCloseReason { + fake.closeSignalConnectionMutex.RLock() + defer fake.closeSignalConnectionMutex.RUnlock() + argsForCall := fake.closeSignalConnectionArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalParticipant) ConnectedAt() time.Time { + fake.connectedAtMutex.Lock() + ret, specificReturn := fake.connectedAtReturnsOnCall[len(fake.connectedAtArgsForCall)] + fake.connectedAtArgsForCall = append(fake.connectedAtArgsForCall, struct { + }{}) + stub := fake.ConnectedAtStub + fakeReturns := fake.connectedAtReturns + fake.recordInvocation("ConnectedAt", []interface{}{}) + fake.connectedAtMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) ConnectedAtCallCount() int { + fake.connectedAtMutex.RLock() + defer fake.connectedAtMutex.RUnlock() + return len(fake.connectedAtArgsForCall) +} + +func (fake *FakeLocalParticipant) ConnectedAtCalls(stub func() time.Time) { + fake.connectedAtMutex.Lock() + defer fake.connectedAtMutex.Unlock() + fake.ConnectedAtStub = stub +} + +func (fake *FakeLocalParticipant) ConnectedAtReturns(result1 time.Time) { + fake.connectedAtMutex.Lock() + defer fake.connectedAtMutex.Unlock() + fake.ConnectedAtStub = nil + fake.connectedAtReturns = struct { + result1 time.Time + }{result1} +} + +func (fake *FakeLocalParticipant) ConnectedAtReturnsOnCall(i int, result1 time.Time) { + fake.connectedAtMutex.Lock() + defer fake.connectedAtMutex.Unlock() + fake.ConnectedAtStub = nil + if fake.connectedAtReturnsOnCall == nil { + fake.connectedAtReturnsOnCall = make(map[int]struct { + result1 time.Time + }) + } + fake.connectedAtReturnsOnCall[i] = struct { + result1 time.Time + }{result1} +} + +func (fake *FakeLocalParticipant) DebugInfo() map[string]any { + fake.debugInfoMutex.Lock() + ret, specificReturn := fake.debugInfoReturnsOnCall[len(fake.debugInfoArgsForCall)] + fake.debugInfoArgsForCall = append(fake.debugInfoArgsForCall, struct { + }{}) + stub := fake.DebugInfoStub + fakeReturns := fake.debugInfoReturns + fake.recordInvocation("DebugInfo", []interface{}{}) + fake.debugInfoMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) DebugInfoCallCount() int { + fake.debugInfoMutex.RLock() + defer fake.debugInfoMutex.RUnlock() + return len(fake.debugInfoArgsForCall) +} + +func (fake *FakeLocalParticipant) DebugInfoCalls(stub func() map[string]any) { + fake.debugInfoMutex.Lock() + defer fake.debugInfoMutex.Unlock() + fake.DebugInfoStub = stub +} + +func (fake *FakeLocalParticipant) DebugInfoReturns(result1 map[string]any) { + fake.debugInfoMutex.Lock() + defer fake.debugInfoMutex.Unlock() + fake.DebugInfoStub = nil + fake.debugInfoReturns = struct { + result1 map[string]any + }{result1} +} + +func (fake *FakeLocalParticipant) DebugInfoReturnsOnCall(i int, result1 map[string]any) { + fake.debugInfoMutex.Lock() + defer fake.debugInfoMutex.Unlock() + fake.DebugInfoStub = nil + if fake.debugInfoReturnsOnCall == nil { + fake.debugInfoReturnsOnCall = make(map[int]struct { + result1 map[string]any + }) + } + fake.debugInfoReturnsOnCall[i] = struct { + result1 map[string]any + }{result1} +} + +func (fake *FakeLocalParticipant) Disconnected() <-chan struct{} { + fake.disconnectedMutex.Lock() + ret, specificReturn := fake.disconnectedReturnsOnCall[len(fake.disconnectedArgsForCall)] + fake.disconnectedArgsForCall = append(fake.disconnectedArgsForCall, struct { + }{}) + stub := fake.DisconnectedStub + fakeReturns := fake.disconnectedReturns + fake.recordInvocation("Disconnected", []interface{}{}) + fake.disconnectedMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) DisconnectedCallCount() int { + fake.disconnectedMutex.RLock() + defer fake.disconnectedMutex.RUnlock() + return len(fake.disconnectedArgsForCall) +} + +func (fake *FakeLocalParticipant) DisconnectedCalls(stub func() <-chan struct{}) { + fake.disconnectedMutex.Lock() + defer fake.disconnectedMutex.Unlock() + fake.DisconnectedStub = stub +} + +func (fake *FakeLocalParticipant) DisconnectedReturns(result1 <-chan struct{}) { + fake.disconnectedMutex.Lock() + defer fake.disconnectedMutex.Unlock() + fake.DisconnectedStub = nil + fake.disconnectedReturns = struct { + result1 <-chan struct{} + }{result1} +} + +func (fake *FakeLocalParticipant) DisconnectedReturnsOnCall(i int, result1 <-chan struct{}) { + fake.disconnectedMutex.Lock() + defer fake.disconnectedMutex.Unlock() + fake.DisconnectedStub = nil + if fake.disconnectedReturnsOnCall == nil { + fake.disconnectedReturnsOnCall = make(map[int]struct { + result1 <-chan struct{} + }) + } + fake.disconnectedReturnsOnCall[i] = struct { + result1 <-chan struct{} + }{result1} +} + +func (fake *FakeLocalParticipant) GetAdaptiveStream() bool { + fake.getAdaptiveStreamMutex.Lock() + ret, specificReturn := fake.getAdaptiveStreamReturnsOnCall[len(fake.getAdaptiveStreamArgsForCall)] + fake.getAdaptiveStreamArgsForCall = append(fake.getAdaptiveStreamArgsForCall, struct { + }{}) + stub := fake.GetAdaptiveStreamStub + fakeReturns := fake.getAdaptiveStreamReturns + fake.recordInvocation("GetAdaptiveStream", []interface{}{}) + fake.getAdaptiveStreamMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) GetAdaptiveStreamCallCount() int { + fake.getAdaptiveStreamMutex.RLock() + defer fake.getAdaptiveStreamMutex.RUnlock() + return len(fake.getAdaptiveStreamArgsForCall) +} + +func (fake *FakeLocalParticipant) GetAdaptiveStreamCalls(stub func() bool) { + fake.getAdaptiveStreamMutex.Lock() + defer fake.getAdaptiveStreamMutex.Unlock() + fake.GetAdaptiveStreamStub = stub +} + +func (fake *FakeLocalParticipant) GetAdaptiveStreamReturns(result1 bool) { + fake.getAdaptiveStreamMutex.Lock() + defer fake.getAdaptiveStreamMutex.Unlock() + fake.GetAdaptiveStreamStub = nil + fake.getAdaptiveStreamReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalParticipant) GetAdaptiveStreamReturnsOnCall(i int, result1 bool) { + fake.getAdaptiveStreamMutex.Lock() + defer fake.getAdaptiveStreamMutex.Unlock() + fake.GetAdaptiveStreamStub = nil + if fake.getAdaptiveStreamReturnsOnCall == nil { + fake.getAdaptiveStreamReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.getAdaptiveStreamReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalParticipant) GetAnswer() (webrtc.SessionDescription, uint32, error) { + fake.getAnswerMutex.Lock() + ret, specificReturn := fake.getAnswerReturnsOnCall[len(fake.getAnswerArgsForCall)] + fake.getAnswerArgsForCall = append(fake.getAnswerArgsForCall, struct { + }{}) + stub := fake.GetAnswerStub + fakeReturns := fake.getAnswerReturns + fake.recordInvocation("GetAnswer", []interface{}{}) + fake.getAnswerMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1, ret.result2, ret.result3 + } + return fakeReturns.result1, fakeReturns.result2, fakeReturns.result3 +} + +func (fake *FakeLocalParticipant) GetAnswerCallCount() int { + fake.getAnswerMutex.RLock() + defer fake.getAnswerMutex.RUnlock() + return len(fake.getAnswerArgsForCall) +} + +func (fake *FakeLocalParticipant) GetAnswerCalls(stub func() (webrtc.SessionDescription, uint32, error)) { + fake.getAnswerMutex.Lock() + defer fake.getAnswerMutex.Unlock() + fake.GetAnswerStub = stub +} + +func (fake *FakeLocalParticipant) GetAnswerReturns(result1 webrtc.SessionDescription, result2 uint32, result3 error) { + fake.getAnswerMutex.Lock() + defer fake.getAnswerMutex.Unlock() + fake.GetAnswerStub = nil + fake.getAnswerReturns = struct { + result1 webrtc.SessionDescription + result2 uint32 + result3 error + }{result1, result2, result3} +} + +func (fake *FakeLocalParticipant) GetAnswerReturnsOnCall(i int, result1 webrtc.SessionDescription, result2 uint32, result3 error) { + fake.getAnswerMutex.Lock() + defer fake.getAnswerMutex.Unlock() + fake.GetAnswerStub = nil + if fake.getAnswerReturnsOnCall == nil { + fake.getAnswerReturnsOnCall = make(map[int]struct { + result1 webrtc.SessionDescription + result2 uint32 + result3 error + }) + } + fake.getAnswerReturnsOnCall[i] = struct { + result1 webrtc.SessionDescription + result2 uint32 + result3 error + }{result1, result2, result3} +} + +func (fake *FakeLocalParticipant) GetAudioLevel() (float64, bool) { + fake.getAudioLevelMutex.Lock() + ret, specificReturn := fake.getAudioLevelReturnsOnCall[len(fake.getAudioLevelArgsForCall)] + fake.getAudioLevelArgsForCall = append(fake.getAudioLevelArgsForCall, struct { + }{}) + stub := fake.GetAudioLevelStub + fakeReturns := fake.getAudioLevelReturns + fake.recordInvocation("GetAudioLevel", []interface{}{}) + fake.getAudioLevelMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeLocalParticipant) GetAudioLevelCallCount() int { + fake.getAudioLevelMutex.RLock() + defer fake.getAudioLevelMutex.RUnlock() + return len(fake.getAudioLevelArgsForCall) +} + +func (fake *FakeLocalParticipant) GetAudioLevelCalls(stub func() (float64, bool)) { + fake.getAudioLevelMutex.Lock() + defer fake.getAudioLevelMutex.Unlock() + fake.GetAudioLevelStub = stub +} + +func (fake *FakeLocalParticipant) GetAudioLevelReturns(result1 float64, result2 bool) { + fake.getAudioLevelMutex.Lock() + defer fake.getAudioLevelMutex.Unlock() + fake.GetAudioLevelStub = nil + fake.getAudioLevelReturns = struct { + result1 float64 + result2 bool + }{result1, result2} +} + +func (fake *FakeLocalParticipant) GetAudioLevelReturnsOnCall(i int, result1 float64, result2 bool) { + fake.getAudioLevelMutex.Lock() + defer fake.getAudioLevelMutex.Unlock() + fake.GetAudioLevelStub = nil + if fake.getAudioLevelReturnsOnCall == nil { + fake.getAudioLevelReturnsOnCall = make(map[int]struct { + result1 float64 + result2 bool + }) + } + fake.getAudioLevelReturnsOnCall[i] = struct { + result1 float64 + result2 bool + }{result1, result2} +} + +func (fake *FakeLocalParticipant) GetBufferFactory() *buffer.Factory { + fake.getBufferFactoryMutex.Lock() + ret, specificReturn := fake.getBufferFactoryReturnsOnCall[len(fake.getBufferFactoryArgsForCall)] + fake.getBufferFactoryArgsForCall = append(fake.getBufferFactoryArgsForCall, struct { + }{}) + stub := fake.GetBufferFactoryStub + fakeReturns := fake.getBufferFactoryReturns + fake.recordInvocation("GetBufferFactory", []interface{}{}) + fake.getBufferFactoryMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) GetBufferFactoryCallCount() int { + fake.getBufferFactoryMutex.RLock() + defer fake.getBufferFactoryMutex.RUnlock() + return len(fake.getBufferFactoryArgsForCall) +} + +func (fake *FakeLocalParticipant) GetBufferFactoryCalls(stub func() *buffer.Factory) { + fake.getBufferFactoryMutex.Lock() + defer fake.getBufferFactoryMutex.Unlock() + fake.GetBufferFactoryStub = stub +} + +func (fake *FakeLocalParticipant) GetBufferFactoryReturns(result1 *buffer.Factory) { + fake.getBufferFactoryMutex.Lock() + defer fake.getBufferFactoryMutex.Unlock() + fake.GetBufferFactoryStub = nil + fake.getBufferFactoryReturns = struct { + result1 *buffer.Factory + }{result1} +} + +func (fake *FakeLocalParticipant) GetBufferFactoryReturnsOnCall(i int, result1 *buffer.Factory) { + fake.getBufferFactoryMutex.Lock() + defer fake.getBufferFactoryMutex.Unlock() + fake.GetBufferFactoryStub = nil + if fake.getBufferFactoryReturnsOnCall == nil { + fake.getBufferFactoryReturnsOnCall = make(map[int]struct { + result1 *buffer.Factory + }) + } + fake.getBufferFactoryReturnsOnCall[i] = struct { + result1 *buffer.Factory + }{result1} +} + +func (fake *FakeLocalParticipant) GetCachedDownTrack(arg1 livekit.TrackID) (*webrtc.RTPTransceiver, sfu.DownTrackState) { + fake.getCachedDownTrackMutex.Lock() + ret, specificReturn := fake.getCachedDownTrackReturnsOnCall[len(fake.getCachedDownTrackArgsForCall)] + fake.getCachedDownTrackArgsForCall = append(fake.getCachedDownTrackArgsForCall, struct { + arg1 livekit.TrackID + }{arg1}) + stub := fake.GetCachedDownTrackStub + fakeReturns := fake.getCachedDownTrackReturns + fake.recordInvocation("GetCachedDownTrack", []interface{}{arg1}) + fake.getCachedDownTrackMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeLocalParticipant) GetCachedDownTrackCallCount() int { + fake.getCachedDownTrackMutex.RLock() + defer fake.getCachedDownTrackMutex.RUnlock() + return len(fake.getCachedDownTrackArgsForCall) +} + +func (fake *FakeLocalParticipant) GetCachedDownTrackCalls(stub func(livekit.TrackID) (*webrtc.RTPTransceiver, sfu.DownTrackState)) { + fake.getCachedDownTrackMutex.Lock() + defer fake.getCachedDownTrackMutex.Unlock() + fake.GetCachedDownTrackStub = stub +} + +func (fake *FakeLocalParticipant) GetCachedDownTrackArgsForCall(i int) livekit.TrackID { + fake.getCachedDownTrackMutex.RLock() + defer fake.getCachedDownTrackMutex.RUnlock() + argsForCall := fake.getCachedDownTrackArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalParticipant) GetCachedDownTrackReturns(result1 *webrtc.RTPTransceiver, result2 sfu.DownTrackState) { + fake.getCachedDownTrackMutex.Lock() + defer fake.getCachedDownTrackMutex.Unlock() + fake.GetCachedDownTrackStub = nil + fake.getCachedDownTrackReturns = struct { + result1 *webrtc.RTPTransceiver + result2 sfu.DownTrackState + }{result1, result2} +} + +func (fake *FakeLocalParticipant) GetCachedDownTrackReturnsOnCall(i int, result1 *webrtc.RTPTransceiver, result2 sfu.DownTrackState) { + fake.getCachedDownTrackMutex.Lock() + defer fake.getCachedDownTrackMutex.Unlock() + fake.GetCachedDownTrackStub = nil + if fake.getCachedDownTrackReturnsOnCall == nil { + fake.getCachedDownTrackReturnsOnCall = make(map[int]struct { + result1 *webrtc.RTPTransceiver + result2 sfu.DownTrackState + }) + } + fake.getCachedDownTrackReturnsOnCall[i] = struct { + result1 *webrtc.RTPTransceiver + result2 sfu.DownTrackState + }{result1, result2} +} + +func (fake *FakeLocalParticipant) GetClientConfiguration() *livekit.ClientConfiguration { + fake.getClientConfigurationMutex.Lock() + ret, specificReturn := fake.getClientConfigurationReturnsOnCall[len(fake.getClientConfigurationArgsForCall)] + fake.getClientConfigurationArgsForCall = append(fake.getClientConfigurationArgsForCall, struct { + }{}) + stub := fake.GetClientConfigurationStub + fakeReturns := fake.getClientConfigurationReturns + fake.recordInvocation("GetClientConfiguration", []interface{}{}) + fake.getClientConfigurationMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) GetClientConfigurationCallCount() int { + fake.getClientConfigurationMutex.RLock() + defer fake.getClientConfigurationMutex.RUnlock() + return len(fake.getClientConfigurationArgsForCall) +} + +func (fake *FakeLocalParticipant) GetClientConfigurationCalls(stub func() *livekit.ClientConfiguration) { + fake.getClientConfigurationMutex.Lock() + defer fake.getClientConfigurationMutex.Unlock() + fake.GetClientConfigurationStub = stub +} + +func (fake *FakeLocalParticipant) GetClientConfigurationReturns(result1 *livekit.ClientConfiguration) { + fake.getClientConfigurationMutex.Lock() + defer fake.getClientConfigurationMutex.Unlock() + fake.GetClientConfigurationStub = nil + fake.getClientConfigurationReturns = struct { + result1 *livekit.ClientConfiguration + }{result1} +} + +func (fake *FakeLocalParticipant) GetClientConfigurationReturnsOnCall(i int, result1 *livekit.ClientConfiguration) { + fake.getClientConfigurationMutex.Lock() + defer fake.getClientConfigurationMutex.Unlock() + fake.GetClientConfigurationStub = nil + if fake.getClientConfigurationReturnsOnCall == nil { + fake.getClientConfigurationReturnsOnCall = make(map[int]struct { + result1 *livekit.ClientConfiguration + }) + } + fake.getClientConfigurationReturnsOnCall[i] = struct { + result1 *livekit.ClientConfiguration + }{result1} +} + +func (fake *FakeLocalParticipant) GetClientInfo() *livekit.ClientInfo { + fake.getClientInfoMutex.Lock() + ret, specificReturn := fake.getClientInfoReturnsOnCall[len(fake.getClientInfoArgsForCall)] + fake.getClientInfoArgsForCall = append(fake.getClientInfoArgsForCall, struct { + }{}) + stub := fake.GetClientInfoStub + fakeReturns := fake.getClientInfoReturns + fake.recordInvocation("GetClientInfo", []interface{}{}) + fake.getClientInfoMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) GetClientInfoCallCount() int { + fake.getClientInfoMutex.RLock() + defer fake.getClientInfoMutex.RUnlock() + return len(fake.getClientInfoArgsForCall) +} + +func (fake *FakeLocalParticipant) GetClientInfoCalls(stub func() *livekit.ClientInfo) { + fake.getClientInfoMutex.Lock() + defer fake.getClientInfoMutex.Unlock() + fake.GetClientInfoStub = stub +} + +func (fake *FakeLocalParticipant) GetClientInfoReturns(result1 *livekit.ClientInfo) { + fake.getClientInfoMutex.Lock() + defer fake.getClientInfoMutex.Unlock() + fake.GetClientInfoStub = nil + fake.getClientInfoReturns = struct { + result1 *livekit.ClientInfo + }{result1} +} + +func (fake *FakeLocalParticipant) GetClientInfoReturnsOnCall(i int, result1 *livekit.ClientInfo) { + fake.getClientInfoMutex.Lock() + defer fake.getClientInfoMutex.Unlock() + fake.GetClientInfoStub = nil + if fake.getClientInfoReturnsOnCall == nil { + fake.getClientInfoReturnsOnCall = make(map[int]struct { + result1 *livekit.ClientInfo + }) + } + fake.getClientInfoReturnsOnCall[i] = struct { + result1 *livekit.ClientInfo + }{result1} +} + +func (fake *FakeLocalParticipant) GetConnectionQuality() *livekit.ConnectionQualityInfo { + fake.getConnectionQualityMutex.Lock() + ret, specificReturn := fake.getConnectionQualityReturnsOnCall[len(fake.getConnectionQualityArgsForCall)] + fake.getConnectionQualityArgsForCall = append(fake.getConnectionQualityArgsForCall, struct { + }{}) + stub := fake.GetConnectionQualityStub + fakeReturns := fake.getConnectionQualityReturns + fake.recordInvocation("GetConnectionQuality", []interface{}{}) + fake.getConnectionQualityMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) GetConnectionQualityCallCount() int { + fake.getConnectionQualityMutex.RLock() + defer fake.getConnectionQualityMutex.RUnlock() + return len(fake.getConnectionQualityArgsForCall) +} + +func (fake *FakeLocalParticipant) GetConnectionQualityCalls(stub func() *livekit.ConnectionQualityInfo) { + fake.getConnectionQualityMutex.Lock() + defer fake.getConnectionQualityMutex.Unlock() + fake.GetConnectionQualityStub = stub +} + +func (fake *FakeLocalParticipant) GetConnectionQualityReturns(result1 *livekit.ConnectionQualityInfo) { + fake.getConnectionQualityMutex.Lock() + defer fake.getConnectionQualityMutex.Unlock() + fake.GetConnectionQualityStub = nil + fake.getConnectionQualityReturns = struct { + result1 *livekit.ConnectionQualityInfo + }{result1} +} + +func (fake *FakeLocalParticipant) GetConnectionQualityReturnsOnCall(i int, result1 *livekit.ConnectionQualityInfo) { + fake.getConnectionQualityMutex.Lock() + defer fake.getConnectionQualityMutex.Unlock() + fake.GetConnectionQualityStub = nil + if fake.getConnectionQualityReturnsOnCall == nil { + fake.getConnectionQualityReturnsOnCall = make(map[int]struct { + result1 *livekit.ConnectionQualityInfo + }) + } + fake.getConnectionQualityReturnsOnCall[i] = struct { + result1 *livekit.ConnectionQualityInfo + }{result1} +} + +func (fake *FakeLocalParticipant) GetCountry() string { + fake.getCountryMutex.Lock() + ret, specificReturn := fake.getCountryReturnsOnCall[len(fake.getCountryArgsForCall)] + fake.getCountryArgsForCall = append(fake.getCountryArgsForCall, struct { + }{}) + stub := fake.GetCountryStub + fakeReturns := fake.getCountryReturns + fake.recordInvocation("GetCountry", []interface{}{}) + fake.getCountryMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) GetCountryCallCount() int { + fake.getCountryMutex.RLock() + defer fake.getCountryMutex.RUnlock() + return len(fake.getCountryArgsForCall) +} + +func (fake *FakeLocalParticipant) GetCountryCalls(stub func() string) { + fake.getCountryMutex.Lock() + defer fake.getCountryMutex.Unlock() + fake.GetCountryStub = stub +} + +func (fake *FakeLocalParticipant) GetCountryReturns(result1 string) { + fake.getCountryMutex.Lock() + defer fake.getCountryMutex.Unlock() + fake.GetCountryStub = nil + fake.getCountryReturns = struct { + result1 string + }{result1} +} + +func (fake *FakeLocalParticipant) GetCountryReturnsOnCall(i int, result1 string) { + fake.getCountryMutex.Lock() + defer fake.getCountryMutex.Unlock() + fake.GetCountryStub = nil + if fake.getCountryReturnsOnCall == nil { + fake.getCountryReturnsOnCall = make(map[int]struct { + result1 string + }) + } + fake.getCountryReturnsOnCall[i] = struct { + result1 string + }{result1} +} + +func (fake *FakeLocalParticipant) GetDataTrackTransport() types.DataTrackTransport { + fake.getDataTrackTransportMutex.Lock() + ret, specificReturn := fake.getDataTrackTransportReturnsOnCall[len(fake.getDataTrackTransportArgsForCall)] + fake.getDataTrackTransportArgsForCall = append(fake.getDataTrackTransportArgsForCall, struct { + }{}) + stub := fake.GetDataTrackTransportStub + fakeReturns := fake.getDataTrackTransportReturns + fake.recordInvocation("GetDataTrackTransport", []interface{}{}) + fake.getDataTrackTransportMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) GetDataTrackTransportCallCount() int { + fake.getDataTrackTransportMutex.RLock() + defer fake.getDataTrackTransportMutex.RUnlock() + return len(fake.getDataTrackTransportArgsForCall) +} + +func (fake *FakeLocalParticipant) GetDataTrackTransportCalls(stub func() types.DataTrackTransport) { + fake.getDataTrackTransportMutex.Lock() + defer fake.getDataTrackTransportMutex.Unlock() + fake.GetDataTrackTransportStub = stub +} + +func (fake *FakeLocalParticipant) GetDataTrackTransportReturns(result1 types.DataTrackTransport) { + fake.getDataTrackTransportMutex.Lock() + defer fake.getDataTrackTransportMutex.Unlock() + fake.GetDataTrackTransportStub = nil + fake.getDataTrackTransportReturns = struct { + result1 types.DataTrackTransport + }{result1} +} + +func (fake *FakeLocalParticipant) GetDataTrackTransportReturnsOnCall(i int, result1 types.DataTrackTransport) { + fake.getDataTrackTransportMutex.Lock() + defer fake.getDataTrackTransportMutex.Unlock() + fake.GetDataTrackTransportStub = nil + if fake.getDataTrackTransportReturnsOnCall == nil { + fake.getDataTrackTransportReturnsOnCall = make(map[int]struct { + result1 types.DataTrackTransport + }) + } + fake.getDataTrackTransportReturnsOnCall[i] = struct { + result1 types.DataTrackTransport + }{result1} +} + +func (fake *FakeLocalParticipant) GetDisableSenderReportPassThrough() bool { + fake.getDisableSenderReportPassThroughMutex.Lock() + ret, specificReturn := fake.getDisableSenderReportPassThroughReturnsOnCall[len(fake.getDisableSenderReportPassThroughArgsForCall)] + fake.getDisableSenderReportPassThroughArgsForCall = append(fake.getDisableSenderReportPassThroughArgsForCall, struct { + }{}) + stub := fake.GetDisableSenderReportPassThroughStub + fakeReturns := fake.getDisableSenderReportPassThroughReturns + fake.recordInvocation("GetDisableSenderReportPassThrough", []interface{}{}) + fake.getDisableSenderReportPassThroughMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) GetDisableSenderReportPassThroughCallCount() int { + fake.getDisableSenderReportPassThroughMutex.RLock() + defer fake.getDisableSenderReportPassThroughMutex.RUnlock() + return len(fake.getDisableSenderReportPassThroughArgsForCall) +} + +func (fake *FakeLocalParticipant) GetDisableSenderReportPassThroughCalls(stub func() bool) { + fake.getDisableSenderReportPassThroughMutex.Lock() + defer fake.getDisableSenderReportPassThroughMutex.Unlock() + fake.GetDisableSenderReportPassThroughStub = stub +} + +func (fake *FakeLocalParticipant) GetDisableSenderReportPassThroughReturns(result1 bool) { + fake.getDisableSenderReportPassThroughMutex.Lock() + defer fake.getDisableSenderReportPassThroughMutex.Unlock() + fake.GetDisableSenderReportPassThroughStub = nil + fake.getDisableSenderReportPassThroughReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalParticipant) GetDisableSenderReportPassThroughReturnsOnCall(i int, result1 bool) { + fake.getDisableSenderReportPassThroughMutex.Lock() + defer fake.getDisableSenderReportPassThroughMutex.Unlock() + fake.GetDisableSenderReportPassThroughStub = nil + if fake.getDisableSenderReportPassThroughReturnsOnCall == nil { + fake.getDisableSenderReportPassThroughReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.getDisableSenderReportPassThroughReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalParticipant) GetEnabledPublishCodecs() []*livekit.Codec { + fake.getEnabledPublishCodecsMutex.Lock() + ret, specificReturn := fake.getEnabledPublishCodecsReturnsOnCall[len(fake.getEnabledPublishCodecsArgsForCall)] + fake.getEnabledPublishCodecsArgsForCall = append(fake.getEnabledPublishCodecsArgsForCall, struct { + }{}) + stub := fake.GetEnabledPublishCodecsStub + fakeReturns := fake.getEnabledPublishCodecsReturns + fake.recordInvocation("GetEnabledPublishCodecs", []interface{}{}) + fake.getEnabledPublishCodecsMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) GetEnabledPublishCodecsCallCount() int { + fake.getEnabledPublishCodecsMutex.RLock() + defer fake.getEnabledPublishCodecsMutex.RUnlock() + return len(fake.getEnabledPublishCodecsArgsForCall) +} + +func (fake *FakeLocalParticipant) GetEnabledPublishCodecsCalls(stub func() []*livekit.Codec) { + fake.getEnabledPublishCodecsMutex.Lock() + defer fake.getEnabledPublishCodecsMutex.Unlock() + fake.GetEnabledPublishCodecsStub = stub +} + +func (fake *FakeLocalParticipant) GetEnabledPublishCodecsReturns(result1 []*livekit.Codec) { + fake.getEnabledPublishCodecsMutex.Lock() + defer fake.getEnabledPublishCodecsMutex.Unlock() + fake.GetEnabledPublishCodecsStub = nil + fake.getEnabledPublishCodecsReturns = struct { + result1 []*livekit.Codec + }{result1} +} + +func (fake *FakeLocalParticipant) GetEnabledPublishCodecsReturnsOnCall(i int, result1 []*livekit.Codec) { + fake.getEnabledPublishCodecsMutex.Lock() + defer fake.getEnabledPublishCodecsMutex.Unlock() + fake.GetEnabledPublishCodecsStub = nil + if fake.getEnabledPublishCodecsReturnsOnCall == nil { + fake.getEnabledPublishCodecsReturnsOnCall = make(map[int]struct { + result1 []*livekit.Codec + }) + } + fake.getEnabledPublishCodecsReturnsOnCall[i] = struct { + result1 []*livekit.Codec + }{result1} +} + +func (fake *FakeLocalParticipant) GetICEConfig() *livekit.ICEConfig { + fake.getICEConfigMutex.Lock() + ret, specificReturn := fake.getICEConfigReturnsOnCall[len(fake.getICEConfigArgsForCall)] + fake.getICEConfigArgsForCall = append(fake.getICEConfigArgsForCall, struct { + }{}) + stub := fake.GetICEConfigStub + fakeReturns := fake.getICEConfigReturns + fake.recordInvocation("GetICEConfig", []interface{}{}) + fake.getICEConfigMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) GetICEConfigCallCount() int { + fake.getICEConfigMutex.RLock() + defer fake.getICEConfigMutex.RUnlock() + return len(fake.getICEConfigArgsForCall) +} + +func (fake *FakeLocalParticipant) GetICEConfigCalls(stub func() *livekit.ICEConfig) { + fake.getICEConfigMutex.Lock() + defer fake.getICEConfigMutex.Unlock() + fake.GetICEConfigStub = stub +} + +func (fake *FakeLocalParticipant) GetICEConfigReturns(result1 *livekit.ICEConfig) { + fake.getICEConfigMutex.Lock() + defer fake.getICEConfigMutex.Unlock() + fake.GetICEConfigStub = nil + fake.getICEConfigReturns = struct { + result1 *livekit.ICEConfig + }{result1} +} + +func (fake *FakeLocalParticipant) GetICEConfigReturnsOnCall(i int, result1 *livekit.ICEConfig) { + fake.getICEConfigMutex.Lock() + defer fake.getICEConfigMutex.Unlock() + fake.GetICEConfigStub = nil + if fake.getICEConfigReturnsOnCall == nil { + fake.getICEConfigReturnsOnCall = make(map[int]struct { + result1 *livekit.ICEConfig + }) + } + fake.getICEConfigReturnsOnCall[i] = struct { + result1 *livekit.ICEConfig + }{result1} +} + +func (fake *FakeLocalParticipant) GetICEConnectionInfo() []*types.ICEConnectionInfo { + fake.getICEConnectionInfoMutex.Lock() + ret, specificReturn := fake.getICEConnectionInfoReturnsOnCall[len(fake.getICEConnectionInfoArgsForCall)] + fake.getICEConnectionInfoArgsForCall = append(fake.getICEConnectionInfoArgsForCall, struct { + }{}) + stub := fake.GetICEConnectionInfoStub + fakeReturns := fake.getICEConnectionInfoReturns + fake.recordInvocation("GetICEConnectionInfo", []interface{}{}) + fake.getICEConnectionInfoMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) GetICEConnectionInfoCallCount() int { + fake.getICEConnectionInfoMutex.RLock() + defer fake.getICEConnectionInfoMutex.RUnlock() + return len(fake.getICEConnectionInfoArgsForCall) +} + +func (fake *FakeLocalParticipant) GetICEConnectionInfoCalls(stub func() []*types.ICEConnectionInfo) { + fake.getICEConnectionInfoMutex.Lock() + defer fake.getICEConnectionInfoMutex.Unlock() + fake.GetICEConnectionInfoStub = stub +} + +func (fake *FakeLocalParticipant) GetICEConnectionInfoReturns(result1 []*types.ICEConnectionInfo) { + fake.getICEConnectionInfoMutex.Lock() + defer fake.getICEConnectionInfoMutex.Unlock() + fake.GetICEConnectionInfoStub = nil + fake.getICEConnectionInfoReturns = struct { + result1 []*types.ICEConnectionInfo + }{result1} +} + +func (fake *FakeLocalParticipant) GetICEConnectionInfoReturnsOnCall(i int, result1 []*types.ICEConnectionInfo) { + fake.getICEConnectionInfoMutex.Lock() + defer fake.getICEConnectionInfoMutex.Unlock() + fake.GetICEConnectionInfoStub = nil + if fake.getICEConnectionInfoReturnsOnCall == nil { + fake.getICEConnectionInfoReturnsOnCall = make(map[int]struct { + result1 []*types.ICEConnectionInfo + }) + } + fake.getICEConnectionInfoReturnsOnCall[i] = struct { + result1 []*types.ICEConnectionInfo + }{result1} +} + +func (fake *FakeLocalParticipant) GetLastReliableSequence(arg1 bool) uint32 { + fake.getLastReliableSequenceMutex.Lock() + ret, specificReturn := fake.getLastReliableSequenceReturnsOnCall[len(fake.getLastReliableSequenceArgsForCall)] + fake.getLastReliableSequenceArgsForCall = append(fake.getLastReliableSequenceArgsForCall, struct { + arg1 bool + }{arg1}) + stub := fake.GetLastReliableSequenceStub + fakeReturns := fake.getLastReliableSequenceReturns + fake.recordInvocation("GetLastReliableSequence", []interface{}{arg1}) + fake.getLastReliableSequenceMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) GetLastReliableSequenceCallCount() int { + fake.getLastReliableSequenceMutex.RLock() + defer fake.getLastReliableSequenceMutex.RUnlock() + return len(fake.getLastReliableSequenceArgsForCall) +} + +func (fake *FakeLocalParticipant) GetLastReliableSequenceCalls(stub func(bool) uint32) { + fake.getLastReliableSequenceMutex.Lock() + defer fake.getLastReliableSequenceMutex.Unlock() + fake.GetLastReliableSequenceStub = stub +} + +func (fake *FakeLocalParticipant) GetLastReliableSequenceArgsForCall(i int) bool { + fake.getLastReliableSequenceMutex.RLock() + defer fake.getLastReliableSequenceMutex.RUnlock() + argsForCall := fake.getLastReliableSequenceArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalParticipant) GetLastReliableSequenceReturns(result1 uint32) { + fake.getLastReliableSequenceMutex.Lock() + defer fake.getLastReliableSequenceMutex.Unlock() + fake.GetLastReliableSequenceStub = nil + fake.getLastReliableSequenceReturns = struct { + result1 uint32 + }{result1} +} + +func (fake *FakeLocalParticipant) GetLastReliableSequenceReturnsOnCall(i int, result1 uint32) { + fake.getLastReliableSequenceMutex.Lock() + defer fake.getLastReliableSequenceMutex.Unlock() + fake.GetLastReliableSequenceStub = nil + if fake.getLastReliableSequenceReturnsOnCall == nil { + fake.getLastReliableSequenceReturnsOnCall = make(map[int]struct { + result1 uint32 + }) + } + fake.getLastReliableSequenceReturnsOnCall[i] = struct { + result1 uint32 + }{result1} +} + +func (fake *FakeLocalParticipant) GetLogger() logger.Logger { + fake.getLoggerMutex.Lock() + ret, specificReturn := fake.getLoggerReturnsOnCall[len(fake.getLoggerArgsForCall)] + fake.getLoggerArgsForCall = append(fake.getLoggerArgsForCall, struct { + }{}) + stub := fake.GetLoggerStub + fakeReturns := fake.getLoggerReturns + fake.recordInvocation("GetLogger", []interface{}{}) + fake.getLoggerMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) GetLoggerCallCount() int { + fake.getLoggerMutex.RLock() + defer fake.getLoggerMutex.RUnlock() + return len(fake.getLoggerArgsForCall) +} + +func (fake *FakeLocalParticipant) GetLoggerCalls(stub func() logger.Logger) { + fake.getLoggerMutex.Lock() + defer fake.getLoggerMutex.Unlock() + fake.GetLoggerStub = stub +} + +func (fake *FakeLocalParticipant) GetLoggerReturns(result1 logger.Logger) { + fake.getLoggerMutex.Lock() + defer fake.getLoggerMutex.Unlock() + fake.GetLoggerStub = nil + fake.getLoggerReturns = struct { + result1 logger.Logger + }{result1} +} + +func (fake *FakeLocalParticipant) GetLoggerReturnsOnCall(i int, result1 logger.Logger) { + fake.getLoggerMutex.Lock() + defer fake.getLoggerMutex.Unlock() + fake.GetLoggerStub = nil + if fake.getLoggerReturnsOnCall == nil { + fake.getLoggerReturnsOnCall = make(map[int]struct { + result1 logger.Logger + }) + } + fake.getLoggerReturnsOnCall[i] = struct { + result1 logger.Logger + }{result1} +} + +func (fake *FakeLocalParticipant) GetLoggerResolver() logger.DeferredFieldResolver { + fake.getLoggerResolverMutex.Lock() + ret, specificReturn := fake.getLoggerResolverReturnsOnCall[len(fake.getLoggerResolverArgsForCall)] + fake.getLoggerResolverArgsForCall = append(fake.getLoggerResolverArgsForCall, struct { + }{}) + stub := fake.GetLoggerResolverStub + fakeReturns := fake.getLoggerResolverReturns + fake.recordInvocation("GetLoggerResolver", []interface{}{}) + fake.getLoggerResolverMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) GetLoggerResolverCallCount() int { + fake.getLoggerResolverMutex.RLock() + defer fake.getLoggerResolverMutex.RUnlock() + return len(fake.getLoggerResolverArgsForCall) +} + +func (fake *FakeLocalParticipant) GetLoggerResolverCalls(stub func() logger.DeferredFieldResolver) { + fake.getLoggerResolverMutex.Lock() + defer fake.getLoggerResolverMutex.Unlock() + fake.GetLoggerResolverStub = stub +} + +func (fake *FakeLocalParticipant) GetLoggerResolverReturns(result1 logger.DeferredFieldResolver) { + fake.getLoggerResolverMutex.Lock() + defer fake.getLoggerResolverMutex.Unlock() + fake.GetLoggerResolverStub = nil + fake.getLoggerResolverReturns = struct { + result1 logger.DeferredFieldResolver + }{result1} +} + +func (fake *FakeLocalParticipant) GetLoggerResolverReturnsOnCall(i int, result1 logger.DeferredFieldResolver) { + fake.getLoggerResolverMutex.Lock() + defer fake.getLoggerResolverMutex.Unlock() + fake.GetLoggerResolverStub = nil + if fake.getLoggerResolverReturnsOnCall == nil { + fake.getLoggerResolverReturnsOnCall = make(map[int]struct { + result1 logger.DeferredFieldResolver + }) + } + fake.getLoggerResolverReturnsOnCall[i] = struct { + result1 logger.DeferredFieldResolver + }{result1} +} + +func (fake *FakeLocalParticipant) GetNextSubscribedDataTrackHandle() uint16 { + fake.getNextSubscribedDataTrackHandleMutex.Lock() + ret, specificReturn := fake.getNextSubscribedDataTrackHandleReturnsOnCall[len(fake.getNextSubscribedDataTrackHandleArgsForCall)] + fake.getNextSubscribedDataTrackHandleArgsForCall = append(fake.getNextSubscribedDataTrackHandleArgsForCall, struct { + }{}) + stub := fake.GetNextSubscribedDataTrackHandleStub + fakeReturns := fake.getNextSubscribedDataTrackHandleReturns + fake.recordInvocation("GetNextSubscribedDataTrackHandle", []interface{}{}) + fake.getNextSubscribedDataTrackHandleMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) GetNextSubscribedDataTrackHandleCallCount() int { + fake.getNextSubscribedDataTrackHandleMutex.RLock() + defer fake.getNextSubscribedDataTrackHandleMutex.RUnlock() + return len(fake.getNextSubscribedDataTrackHandleArgsForCall) +} + +func (fake *FakeLocalParticipant) GetNextSubscribedDataTrackHandleCalls(stub func() uint16) { + fake.getNextSubscribedDataTrackHandleMutex.Lock() + defer fake.getNextSubscribedDataTrackHandleMutex.Unlock() + fake.GetNextSubscribedDataTrackHandleStub = stub +} + +func (fake *FakeLocalParticipant) GetNextSubscribedDataTrackHandleReturns(result1 uint16) { + fake.getNextSubscribedDataTrackHandleMutex.Lock() + defer fake.getNextSubscribedDataTrackHandleMutex.Unlock() + fake.GetNextSubscribedDataTrackHandleStub = nil + fake.getNextSubscribedDataTrackHandleReturns = struct { + result1 uint16 + }{result1} +} + +func (fake *FakeLocalParticipant) GetNextSubscribedDataTrackHandleReturnsOnCall(i int, result1 uint16) { + fake.getNextSubscribedDataTrackHandleMutex.Lock() + defer fake.getNextSubscribedDataTrackHandleMutex.Unlock() + fake.GetNextSubscribedDataTrackHandleStub = nil + if fake.getNextSubscribedDataTrackHandleReturnsOnCall == nil { + fake.getNextSubscribedDataTrackHandleReturnsOnCall = make(map[int]struct { + result1 uint16 + }) + } + fake.getNextSubscribedDataTrackHandleReturnsOnCall[i] = struct { + result1 uint16 + }{result1} +} + +func (fake *FakeLocalParticipant) GetPacer() pacer.Pacer { + fake.getPacerMutex.Lock() + ret, specificReturn := fake.getPacerReturnsOnCall[len(fake.getPacerArgsForCall)] + fake.getPacerArgsForCall = append(fake.getPacerArgsForCall, struct { + }{}) + stub := fake.GetPacerStub + fakeReturns := fake.getPacerReturns + fake.recordInvocation("GetPacer", []interface{}{}) + fake.getPacerMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) GetPacerCallCount() int { + fake.getPacerMutex.RLock() + defer fake.getPacerMutex.RUnlock() + return len(fake.getPacerArgsForCall) +} + +func (fake *FakeLocalParticipant) GetPacerCalls(stub func() pacer.Pacer) { + fake.getPacerMutex.Lock() + defer fake.getPacerMutex.Unlock() + fake.GetPacerStub = stub +} + +func (fake *FakeLocalParticipant) GetPacerReturns(result1 pacer.Pacer) { + fake.getPacerMutex.Lock() + defer fake.getPacerMutex.Unlock() + fake.GetPacerStub = nil + fake.getPacerReturns = struct { + result1 pacer.Pacer + }{result1} +} + +func (fake *FakeLocalParticipant) GetPacerReturnsOnCall(i int, result1 pacer.Pacer) { + fake.getPacerMutex.Lock() + defer fake.getPacerMutex.Unlock() + fake.GetPacerStub = nil + if fake.getPacerReturnsOnCall == nil { + fake.getPacerReturnsOnCall = make(map[int]struct { + result1 pacer.Pacer + }) + } + fake.getPacerReturnsOnCall[i] = struct { + result1 pacer.Pacer + }{result1} +} + +func (fake *FakeLocalParticipant) GetParticipantListener() types.ParticipantListener { + fake.getParticipantListenerMutex.Lock() + ret, specificReturn := fake.getParticipantListenerReturnsOnCall[len(fake.getParticipantListenerArgsForCall)] + fake.getParticipantListenerArgsForCall = append(fake.getParticipantListenerArgsForCall, struct { + }{}) + stub := fake.GetParticipantListenerStub + fakeReturns := fake.getParticipantListenerReturns + fake.recordInvocation("GetParticipantListener", []interface{}{}) + fake.getParticipantListenerMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) GetParticipantListenerCallCount() int { + fake.getParticipantListenerMutex.RLock() + defer fake.getParticipantListenerMutex.RUnlock() + return len(fake.getParticipantListenerArgsForCall) +} + +func (fake *FakeLocalParticipant) GetParticipantListenerCalls(stub func() types.ParticipantListener) { + fake.getParticipantListenerMutex.Lock() + defer fake.getParticipantListenerMutex.Unlock() + fake.GetParticipantListenerStub = stub +} + +func (fake *FakeLocalParticipant) GetParticipantListenerReturns(result1 types.ParticipantListener) { + fake.getParticipantListenerMutex.Lock() + defer fake.getParticipantListenerMutex.Unlock() + fake.GetParticipantListenerStub = nil + fake.getParticipantListenerReturns = struct { + result1 types.ParticipantListener + }{result1} +} + +func (fake *FakeLocalParticipant) GetParticipantListenerReturnsOnCall(i int, result1 types.ParticipantListener) { + fake.getParticipantListenerMutex.Lock() + defer fake.getParticipantListenerMutex.Unlock() + fake.GetParticipantListenerStub = nil + if fake.getParticipantListenerReturnsOnCall == nil { + fake.getParticipantListenerReturnsOnCall = make(map[int]struct { + result1 types.ParticipantListener + }) + } + fake.getParticipantListenerReturnsOnCall[i] = struct { + result1 types.ParticipantListener + }{result1} +} + +func (fake *FakeLocalParticipant) GetPendingTrack(arg1 livekit.TrackID) *livekit.TrackInfo { + fake.getPendingTrackMutex.Lock() + ret, specificReturn := fake.getPendingTrackReturnsOnCall[len(fake.getPendingTrackArgsForCall)] + fake.getPendingTrackArgsForCall = append(fake.getPendingTrackArgsForCall, struct { + arg1 livekit.TrackID + }{arg1}) + stub := fake.GetPendingTrackStub + fakeReturns := fake.getPendingTrackReturns + fake.recordInvocation("GetPendingTrack", []interface{}{arg1}) + fake.getPendingTrackMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) GetPendingTrackCallCount() int { + fake.getPendingTrackMutex.RLock() + defer fake.getPendingTrackMutex.RUnlock() + return len(fake.getPendingTrackArgsForCall) +} + +func (fake *FakeLocalParticipant) GetPendingTrackCalls(stub func(livekit.TrackID) *livekit.TrackInfo) { + fake.getPendingTrackMutex.Lock() + defer fake.getPendingTrackMutex.Unlock() + fake.GetPendingTrackStub = stub +} + +func (fake *FakeLocalParticipant) GetPendingTrackArgsForCall(i int) livekit.TrackID { + fake.getPendingTrackMutex.RLock() + defer fake.getPendingTrackMutex.RUnlock() + argsForCall := fake.getPendingTrackArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalParticipant) GetPendingTrackReturns(result1 *livekit.TrackInfo) { + fake.getPendingTrackMutex.Lock() + defer fake.getPendingTrackMutex.Unlock() + fake.GetPendingTrackStub = nil + fake.getPendingTrackReturns = struct { + result1 *livekit.TrackInfo + }{result1} +} + +func (fake *FakeLocalParticipant) GetPendingTrackReturnsOnCall(i int, result1 *livekit.TrackInfo) { + fake.getPendingTrackMutex.Lock() + defer fake.getPendingTrackMutex.Unlock() + fake.GetPendingTrackStub = nil + if fake.getPendingTrackReturnsOnCall == nil { + fake.getPendingTrackReturnsOnCall = make(map[int]struct { + result1 *livekit.TrackInfo + }) + } + fake.getPendingTrackReturnsOnCall[i] = struct { + result1 *livekit.TrackInfo + }{result1} +} + +func (fake *FakeLocalParticipant) GetPlayoutDelayConfig() *livekit.PlayoutDelay { + fake.getPlayoutDelayConfigMutex.Lock() + ret, specificReturn := fake.getPlayoutDelayConfigReturnsOnCall[len(fake.getPlayoutDelayConfigArgsForCall)] + fake.getPlayoutDelayConfigArgsForCall = append(fake.getPlayoutDelayConfigArgsForCall, struct { + }{}) + stub := fake.GetPlayoutDelayConfigStub + fakeReturns := fake.getPlayoutDelayConfigReturns + fake.recordInvocation("GetPlayoutDelayConfig", []interface{}{}) + fake.getPlayoutDelayConfigMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) GetPlayoutDelayConfigCallCount() int { + fake.getPlayoutDelayConfigMutex.RLock() + defer fake.getPlayoutDelayConfigMutex.RUnlock() + return len(fake.getPlayoutDelayConfigArgsForCall) +} + +func (fake *FakeLocalParticipant) GetPlayoutDelayConfigCalls(stub func() *livekit.PlayoutDelay) { + fake.getPlayoutDelayConfigMutex.Lock() + defer fake.getPlayoutDelayConfigMutex.Unlock() + fake.GetPlayoutDelayConfigStub = stub +} + +func (fake *FakeLocalParticipant) GetPlayoutDelayConfigReturns(result1 *livekit.PlayoutDelay) { + fake.getPlayoutDelayConfigMutex.Lock() + defer fake.getPlayoutDelayConfigMutex.Unlock() + fake.GetPlayoutDelayConfigStub = nil + fake.getPlayoutDelayConfigReturns = struct { + result1 *livekit.PlayoutDelay + }{result1} +} + +func (fake *FakeLocalParticipant) GetPlayoutDelayConfigReturnsOnCall(i int, result1 *livekit.PlayoutDelay) { + fake.getPlayoutDelayConfigMutex.Lock() + defer fake.getPlayoutDelayConfigMutex.Unlock() + fake.GetPlayoutDelayConfigStub = nil + if fake.getPlayoutDelayConfigReturnsOnCall == nil { + fake.getPlayoutDelayConfigReturnsOnCall = make(map[int]struct { + result1 *livekit.PlayoutDelay + }) + } + fake.getPlayoutDelayConfigReturnsOnCall[i] = struct { + result1 *livekit.PlayoutDelay + }{result1} +} + +func (fake *FakeLocalParticipant) GetPublishedDataTrack(arg1 uint16) types.DataTrack { + fake.getPublishedDataTrackMutex.Lock() + ret, specificReturn := fake.getPublishedDataTrackReturnsOnCall[len(fake.getPublishedDataTrackArgsForCall)] + fake.getPublishedDataTrackArgsForCall = append(fake.getPublishedDataTrackArgsForCall, struct { + arg1 uint16 + }{arg1}) + stub := fake.GetPublishedDataTrackStub + fakeReturns := fake.getPublishedDataTrackReturns + fake.recordInvocation("GetPublishedDataTrack", []interface{}{arg1}) + fake.getPublishedDataTrackMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) GetPublishedDataTrackCallCount() int { + fake.getPublishedDataTrackMutex.RLock() + defer fake.getPublishedDataTrackMutex.RUnlock() + return len(fake.getPublishedDataTrackArgsForCall) +} + +func (fake *FakeLocalParticipant) GetPublishedDataTrackCalls(stub func(uint16) types.DataTrack) { + fake.getPublishedDataTrackMutex.Lock() + defer fake.getPublishedDataTrackMutex.Unlock() + fake.GetPublishedDataTrackStub = stub +} + +func (fake *FakeLocalParticipant) GetPublishedDataTrackArgsForCall(i int) uint16 { + fake.getPublishedDataTrackMutex.RLock() + defer fake.getPublishedDataTrackMutex.RUnlock() + argsForCall := fake.getPublishedDataTrackArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalParticipant) GetPublishedDataTrackReturns(result1 types.DataTrack) { + fake.getPublishedDataTrackMutex.Lock() + defer fake.getPublishedDataTrackMutex.Unlock() + fake.GetPublishedDataTrackStub = nil + fake.getPublishedDataTrackReturns = struct { + result1 types.DataTrack + }{result1} +} + +func (fake *FakeLocalParticipant) GetPublishedDataTrackReturnsOnCall(i int, result1 types.DataTrack) { + fake.getPublishedDataTrackMutex.Lock() + defer fake.getPublishedDataTrackMutex.Unlock() + fake.GetPublishedDataTrackStub = nil + if fake.getPublishedDataTrackReturnsOnCall == nil { + fake.getPublishedDataTrackReturnsOnCall = make(map[int]struct { + result1 types.DataTrack + }) + } + fake.getPublishedDataTrackReturnsOnCall[i] = struct { + result1 types.DataTrack + }{result1} +} + +func (fake *FakeLocalParticipant) GetPublishedDataTracks() []types.DataTrack { + fake.getPublishedDataTracksMutex.Lock() + ret, specificReturn := fake.getPublishedDataTracksReturnsOnCall[len(fake.getPublishedDataTracksArgsForCall)] + fake.getPublishedDataTracksArgsForCall = append(fake.getPublishedDataTracksArgsForCall, struct { + }{}) + stub := fake.GetPublishedDataTracksStub + fakeReturns := fake.getPublishedDataTracksReturns + fake.recordInvocation("GetPublishedDataTracks", []interface{}{}) + fake.getPublishedDataTracksMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) GetPublishedDataTracksCallCount() int { + fake.getPublishedDataTracksMutex.RLock() + defer fake.getPublishedDataTracksMutex.RUnlock() + return len(fake.getPublishedDataTracksArgsForCall) +} + +func (fake *FakeLocalParticipant) GetPublishedDataTracksCalls(stub func() []types.DataTrack) { + fake.getPublishedDataTracksMutex.Lock() + defer fake.getPublishedDataTracksMutex.Unlock() + fake.GetPublishedDataTracksStub = stub +} + +func (fake *FakeLocalParticipant) GetPublishedDataTracksReturns(result1 []types.DataTrack) { + fake.getPublishedDataTracksMutex.Lock() + defer fake.getPublishedDataTracksMutex.Unlock() + fake.GetPublishedDataTracksStub = nil + fake.getPublishedDataTracksReturns = struct { + result1 []types.DataTrack + }{result1} +} + +func (fake *FakeLocalParticipant) GetPublishedDataTracksReturnsOnCall(i int, result1 []types.DataTrack) { + fake.getPublishedDataTracksMutex.Lock() + defer fake.getPublishedDataTracksMutex.Unlock() + fake.GetPublishedDataTracksStub = nil + if fake.getPublishedDataTracksReturnsOnCall == nil { + fake.getPublishedDataTracksReturnsOnCall = make(map[int]struct { + result1 []types.DataTrack + }) + } + fake.getPublishedDataTracksReturnsOnCall[i] = struct { + result1 []types.DataTrack + }{result1} +} + +func (fake *FakeLocalParticipant) GetPublishedTrack(arg1 livekit.TrackID) types.MediaTrack { + fake.getPublishedTrackMutex.Lock() + ret, specificReturn := fake.getPublishedTrackReturnsOnCall[len(fake.getPublishedTrackArgsForCall)] + fake.getPublishedTrackArgsForCall = append(fake.getPublishedTrackArgsForCall, struct { + arg1 livekit.TrackID + }{arg1}) + stub := fake.GetPublishedTrackStub + fakeReturns := fake.getPublishedTrackReturns + fake.recordInvocation("GetPublishedTrack", []interface{}{arg1}) + fake.getPublishedTrackMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) GetPublishedTrackCallCount() int { + fake.getPublishedTrackMutex.RLock() + defer fake.getPublishedTrackMutex.RUnlock() + return len(fake.getPublishedTrackArgsForCall) +} + +func (fake *FakeLocalParticipant) GetPublishedTrackCalls(stub func(livekit.TrackID) types.MediaTrack) { + fake.getPublishedTrackMutex.Lock() + defer fake.getPublishedTrackMutex.Unlock() + fake.GetPublishedTrackStub = stub +} + +func (fake *FakeLocalParticipant) GetPublishedTrackArgsForCall(i int) livekit.TrackID { + fake.getPublishedTrackMutex.RLock() + defer fake.getPublishedTrackMutex.RUnlock() + argsForCall := fake.getPublishedTrackArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalParticipant) GetPublishedTrackReturns(result1 types.MediaTrack) { + fake.getPublishedTrackMutex.Lock() + defer fake.getPublishedTrackMutex.Unlock() + fake.GetPublishedTrackStub = nil + fake.getPublishedTrackReturns = struct { + result1 types.MediaTrack + }{result1} +} + +func (fake *FakeLocalParticipant) GetPublishedTrackReturnsOnCall(i int, result1 types.MediaTrack) { + fake.getPublishedTrackMutex.Lock() + defer fake.getPublishedTrackMutex.Unlock() + fake.GetPublishedTrackStub = nil + if fake.getPublishedTrackReturnsOnCall == nil { + fake.getPublishedTrackReturnsOnCall = make(map[int]struct { + result1 types.MediaTrack + }) + } + fake.getPublishedTrackReturnsOnCall[i] = struct { + result1 types.MediaTrack + }{result1} +} + +func (fake *FakeLocalParticipant) GetPublishedTracks() []types.MediaTrack { + fake.getPublishedTracksMutex.Lock() + ret, specificReturn := fake.getPublishedTracksReturnsOnCall[len(fake.getPublishedTracksArgsForCall)] + fake.getPublishedTracksArgsForCall = append(fake.getPublishedTracksArgsForCall, struct { + }{}) + stub := fake.GetPublishedTracksStub + fakeReturns := fake.getPublishedTracksReturns + fake.recordInvocation("GetPublishedTracks", []interface{}{}) + fake.getPublishedTracksMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) GetPublishedTracksCallCount() int { + fake.getPublishedTracksMutex.RLock() + defer fake.getPublishedTracksMutex.RUnlock() + return len(fake.getPublishedTracksArgsForCall) +} + +func (fake *FakeLocalParticipant) GetPublishedTracksCalls(stub func() []types.MediaTrack) { + fake.getPublishedTracksMutex.Lock() + defer fake.getPublishedTracksMutex.Unlock() + fake.GetPublishedTracksStub = stub +} + +func (fake *FakeLocalParticipant) GetPublishedTracksReturns(result1 []types.MediaTrack) { + fake.getPublishedTracksMutex.Lock() + defer fake.getPublishedTracksMutex.Unlock() + fake.GetPublishedTracksStub = nil + fake.getPublishedTracksReturns = struct { + result1 []types.MediaTrack + }{result1} +} + +func (fake *FakeLocalParticipant) GetPublishedTracksReturnsOnCall(i int, result1 []types.MediaTrack) { + fake.getPublishedTracksMutex.Lock() + defer fake.getPublishedTracksMutex.Unlock() + fake.GetPublishedTracksStub = nil + if fake.getPublishedTracksReturnsOnCall == nil { + fake.getPublishedTracksReturnsOnCall = make(map[int]struct { + result1 []types.MediaTrack + }) + } + fake.getPublishedTracksReturnsOnCall[i] = struct { + result1 []types.MediaTrack + }{result1} +} + +func (fake *FakeLocalParticipant) GetPublisherICESessionUfrag() (string, error) { + fake.getPublisherICESessionUfragMutex.Lock() + ret, specificReturn := fake.getPublisherICESessionUfragReturnsOnCall[len(fake.getPublisherICESessionUfragArgsForCall)] + fake.getPublisherICESessionUfragArgsForCall = append(fake.getPublisherICESessionUfragArgsForCall, struct { + }{}) + stub := fake.GetPublisherICESessionUfragStub + fakeReturns := fake.getPublisherICESessionUfragReturns + fake.recordInvocation("GetPublisherICESessionUfrag", []interface{}{}) + fake.getPublisherICESessionUfragMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeLocalParticipant) GetPublisherICESessionUfragCallCount() int { + fake.getPublisherICESessionUfragMutex.RLock() + defer fake.getPublisherICESessionUfragMutex.RUnlock() + return len(fake.getPublisherICESessionUfragArgsForCall) +} + +func (fake *FakeLocalParticipant) GetPublisherICESessionUfragCalls(stub func() (string, error)) { + fake.getPublisherICESessionUfragMutex.Lock() + defer fake.getPublisherICESessionUfragMutex.Unlock() + fake.GetPublisherICESessionUfragStub = stub +} + +func (fake *FakeLocalParticipant) GetPublisherICESessionUfragReturns(result1 string, result2 error) { + fake.getPublisherICESessionUfragMutex.Lock() + defer fake.getPublisherICESessionUfragMutex.Unlock() + fake.GetPublisherICESessionUfragStub = nil + fake.getPublisherICESessionUfragReturns = struct { + result1 string + result2 error + }{result1, result2} +} + +func (fake *FakeLocalParticipant) GetPublisherICESessionUfragReturnsOnCall(i int, result1 string, result2 error) { + fake.getPublisherICESessionUfragMutex.Lock() + defer fake.getPublisherICESessionUfragMutex.Unlock() + fake.GetPublisherICESessionUfragStub = nil + if fake.getPublisherICESessionUfragReturnsOnCall == nil { + fake.getPublisherICESessionUfragReturnsOnCall = make(map[int]struct { + result1 string + result2 error + }) + } + fake.getPublisherICESessionUfragReturnsOnCall[i] = struct { + result1 string + result2 error + }{result1, result2} +} + +func (fake *FakeLocalParticipant) GetReporter() roomobs.ParticipantSessionReporter { + fake.getReporterMutex.Lock() + ret, specificReturn := fake.getReporterReturnsOnCall[len(fake.getReporterArgsForCall)] + fake.getReporterArgsForCall = append(fake.getReporterArgsForCall, struct { + }{}) + stub := fake.GetReporterStub + fakeReturns := fake.getReporterReturns + fake.recordInvocation("GetReporter", []interface{}{}) + fake.getReporterMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) GetReporterCallCount() int { + fake.getReporterMutex.RLock() + defer fake.getReporterMutex.RUnlock() + return len(fake.getReporterArgsForCall) +} + +func (fake *FakeLocalParticipant) GetReporterCalls(stub func() roomobs.ParticipantSessionReporter) { + fake.getReporterMutex.Lock() + defer fake.getReporterMutex.Unlock() + fake.GetReporterStub = stub +} + +func (fake *FakeLocalParticipant) GetReporterReturns(result1 roomobs.ParticipantSessionReporter) { + fake.getReporterMutex.Lock() + defer fake.getReporterMutex.Unlock() + fake.GetReporterStub = nil + fake.getReporterReturns = struct { + result1 roomobs.ParticipantSessionReporter + }{result1} +} + +func (fake *FakeLocalParticipant) GetReporterReturnsOnCall(i int, result1 roomobs.ParticipantSessionReporter) { + fake.getReporterMutex.Lock() + defer fake.getReporterMutex.Unlock() + fake.GetReporterStub = nil + if fake.getReporterReturnsOnCall == nil { + fake.getReporterReturnsOnCall = make(map[int]struct { + result1 roomobs.ParticipantSessionReporter + }) + } + fake.getReporterReturnsOnCall[i] = struct { + result1 roomobs.ParticipantSessionReporter + }{result1} +} + +func (fake *FakeLocalParticipant) GetReporterResolver() roomobs.ParticipantReporterResolver { + fake.getReporterResolverMutex.Lock() + ret, specificReturn := fake.getReporterResolverReturnsOnCall[len(fake.getReporterResolverArgsForCall)] + fake.getReporterResolverArgsForCall = append(fake.getReporterResolverArgsForCall, struct { + }{}) + stub := fake.GetReporterResolverStub + fakeReturns := fake.getReporterResolverReturns + fake.recordInvocation("GetReporterResolver", []interface{}{}) + fake.getReporterResolverMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) GetReporterResolverCallCount() int { + fake.getReporterResolverMutex.RLock() + defer fake.getReporterResolverMutex.RUnlock() + return len(fake.getReporterResolverArgsForCall) +} + +func (fake *FakeLocalParticipant) GetReporterResolverCalls(stub func() roomobs.ParticipantReporterResolver) { + fake.getReporterResolverMutex.Lock() + defer fake.getReporterResolverMutex.Unlock() + fake.GetReporterResolverStub = stub +} + +func (fake *FakeLocalParticipant) GetReporterResolverReturns(result1 roomobs.ParticipantReporterResolver) { + fake.getReporterResolverMutex.Lock() + defer fake.getReporterResolverMutex.Unlock() + fake.GetReporterResolverStub = nil + fake.getReporterResolverReturns = struct { + result1 roomobs.ParticipantReporterResolver + }{result1} +} + +func (fake *FakeLocalParticipant) GetReporterResolverReturnsOnCall(i int, result1 roomobs.ParticipantReporterResolver) { + fake.getReporterResolverMutex.Lock() + defer fake.getReporterResolverMutex.Unlock() + fake.GetReporterResolverStub = nil + if fake.getReporterResolverReturnsOnCall == nil { + fake.getReporterResolverReturnsOnCall = make(map[int]struct { + result1 roomobs.ParticipantReporterResolver + }) + } + fake.getReporterResolverReturnsOnCall[i] = struct { + result1 roomobs.ParticipantReporterResolver + }{result1} +} + +func (fake *FakeLocalParticipant) GetResponseSink() routing.MessageSink { + fake.getResponseSinkMutex.Lock() + ret, specificReturn := fake.getResponseSinkReturnsOnCall[len(fake.getResponseSinkArgsForCall)] + fake.getResponseSinkArgsForCall = append(fake.getResponseSinkArgsForCall, struct { + }{}) + stub := fake.GetResponseSinkStub + fakeReturns := fake.getResponseSinkReturns + fake.recordInvocation("GetResponseSink", []interface{}{}) + fake.getResponseSinkMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) GetResponseSinkCallCount() int { + fake.getResponseSinkMutex.RLock() + defer fake.getResponseSinkMutex.RUnlock() + return len(fake.getResponseSinkArgsForCall) +} + +func (fake *FakeLocalParticipant) GetResponseSinkCalls(stub func() routing.MessageSink) { + fake.getResponseSinkMutex.Lock() + defer fake.getResponseSinkMutex.Unlock() + fake.GetResponseSinkStub = stub +} + +func (fake *FakeLocalParticipant) GetResponseSinkReturns(result1 routing.MessageSink) { + fake.getResponseSinkMutex.Lock() + defer fake.getResponseSinkMutex.Unlock() + fake.GetResponseSinkStub = nil + fake.getResponseSinkReturns = struct { + result1 routing.MessageSink + }{result1} +} + +func (fake *FakeLocalParticipant) GetResponseSinkReturnsOnCall(i int, result1 routing.MessageSink) { + fake.getResponseSinkMutex.Lock() + defer fake.getResponseSinkMutex.Unlock() + fake.GetResponseSinkStub = nil + if fake.getResponseSinkReturnsOnCall == nil { + fake.getResponseSinkReturnsOnCall = make(map[int]struct { + result1 routing.MessageSink + }) + } + fake.getResponseSinkReturnsOnCall[i] = struct { + result1 routing.MessageSink + }{result1} +} + +func (fake *FakeLocalParticipant) GetSubscribedParticipants() []livekit.ParticipantID { + fake.getSubscribedParticipantsMutex.Lock() + ret, specificReturn := fake.getSubscribedParticipantsReturnsOnCall[len(fake.getSubscribedParticipantsArgsForCall)] + fake.getSubscribedParticipantsArgsForCall = append(fake.getSubscribedParticipantsArgsForCall, struct { + }{}) + stub := fake.GetSubscribedParticipantsStub + fakeReturns := fake.getSubscribedParticipantsReturns + fake.recordInvocation("GetSubscribedParticipants", []interface{}{}) + fake.getSubscribedParticipantsMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) GetSubscribedParticipantsCallCount() int { + fake.getSubscribedParticipantsMutex.RLock() + defer fake.getSubscribedParticipantsMutex.RUnlock() + return len(fake.getSubscribedParticipantsArgsForCall) +} + +func (fake *FakeLocalParticipant) GetSubscribedParticipantsCalls(stub func() []livekit.ParticipantID) { + fake.getSubscribedParticipantsMutex.Lock() + defer fake.getSubscribedParticipantsMutex.Unlock() + fake.GetSubscribedParticipantsStub = stub +} + +func (fake *FakeLocalParticipant) GetSubscribedParticipantsReturns(result1 []livekit.ParticipantID) { + fake.getSubscribedParticipantsMutex.Lock() + defer fake.getSubscribedParticipantsMutex.Unlock() + fake.GetSubscribedParticipantsStub = nil + fake.getSubscribedParticipantsReturns = struct { + result1 []livekit.ParticipantID + }{result1} +} + +func (fake *FakeLocalParticipant) GetSubscribedParticipantsReturnsOnCall(i int, result1 []livekit.ParticipantID) { + fake.getSubscribedParticipantsMutex.Lock() + defer fake.getSubscribedParticipantsMutex.Unlock() + fake.GetSubscribedParticipantsStub = nil + if fake.getSubscribedParticipantsReturnsOnCall == nil { + fake.getSubscribedParticipantsReturnsOnCall = make(map[int]struct { + result1 []livekit.ParticipantID + }) + } + fake.getSubscribedParticipantsReturnsOnCall[i] = struct { + result1 []livekit.ParticipantID + }{result1} +} + +func (fake *FakeLocalParticipant) GetSubscribedTracks() []types.SubscribedTrack { + fake.getSubscribedTracksMutex.Lock() + ret, specificReturn := fake.getSubscribedTracksReturnsOnCall[len(fake.getSubscribedTracksArgsForCall)] + fake.getSubscribedTracksArgsForCall = append(fake.getSubscribedTracksArgsForCall, struct { + }{}) + stub := fake.GetSubscribedTracksStub + fakeReturns := fake.getSubscribedTracksReturns + fake.recordInvocation("GetSubscribedTracks", []interface{}{}) + fake.getSubscribedTracksMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) GetSubscribedTracksCallCount() int { + fake.getSubscribedTracksMutex.RLock() + defer fake.getSubscribedTracksMutex.RUnlock() + return len(fake.getSubscribedTracksArgsForCall) +} + +func (fake *FakeLocalParticipant) GetSubscribedTracksCalls(stub func() []types.SubscribedTrack) { + fake.getSubscribedTracksMutex.Lock() + defer fake.getSubscribedTracksMutex.Unlock() + fake.GetSubscribedTracksStub = stub +} + +func (fake *FakeLocalParticipant) GetSubscribedTracksReturns(result1 []types.SubscribedTrack) { + fake.getSubscribedTracksMutex.Lock() + defer fake.getSubscribedTracksMutex.Unlock() + fake.GetSubscribedTracksStub = nil + fake.getSubscribedTracksReturns = struct { + result1 []types.SubscribedTrack + }{result1} +} + +func (fake *FakeLocalParticipant) GetSubscribedTracksReturnsOnCall(i int, result1 []types.SubscribedTrack) { + fake.getSubscribedTracksMutex.Lock() + defer fake.getSubscribedTracksMutex.Unlock() + fake.GetSubscribedTracksStub = nil + if fake.getSubscribedTracksReturnsOnCall == nil { + fake.getSubscribedTracksReturnsOnCall = make(map[int]struct { + result1 []types.SubscribedTrack + }) + } + fake.getSubscribedTracksReturnsOnCall[i] = struct { + result1 []types.SubscribedTrack + }{result1} +} + +func (fake *FakeLocalParticipant) GetTrailer() []byte { + fake.getTrailerMutex.Lock() + ret, specificReturn := fake.getTrailerReturnsOnCall[len(fake.getTrailerArgsForCall)] + fake.getTrailerArgsForCall = append(fake.getTrailerArgsForCall, struct { + }{}) + stub := fake.GetTrailerStub + fakeReturns := fake.getTrailerReturns + fake.recordInvocation("GetTrailer", []interface{}{}) + fake.getTrailerMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) GetTrailerCallCount() int { + fake.getTrailerMutex.RLock() + defer fake.getTrailerMutex.RUnlock() + return len(fake.getTrailerArgsForCall) +} + +func (fake *FakeLocalParticipant) GetTrailerCalls(stub func() []byte) { + fake.getTrailerMutex.Lock() + defer fake.getTrailerMutex.Unlock() + fake.GetTrailerStub = stub +} + +func (fake *FakeLocalParticipant) GetTrailerReturns(result1 []byte) { + fake.getTrailerMutex.Lock() + defer fake.getTrailerMutex.Unlock() + fake.GetTrailerStub = nil + fake.getTrailerReturns = struct { + result1 []byte + }{result1} +} + +func (fake *FakeLocalParticipant) GetTrailerReturnsOnCall(i int, result1 []byte) { + fake.getTrailerMutex.Lock() + defer fake.getTrailerMutex.Unlock() + fake.GetTrailerStub = nil + if fake.getTrailerReturnsOnCall == nil { + fake.getTrailerReturnsOnCall = make(map[int]struct { + result1 []byte + }) + } + fake.getTrailerReturnsOnCall[i] = struct { + result1 []byte + }{result1} +} + +func (fake *FakeLocalParticipant) HandleAnswer(arg1 *livekit.SessionDescription) { + fake.handleAnswerMutex.Lock() + fake.handleAnswerArgsForCall = append(fake.handleAnswerArgsForCall, struct { + arg1 *livekit.SessionDescription + }{arg1}) + stub := fake.HandleAnswerStub + fake.recordInvocation("HandleAnswer", []interface{}{arg1}) + fake.handleAnswerMutex.Unlock() + if stub != nil { + fake.HandleAnswerStub(arg1) + } +} + +func (fake *FakeLocalParticipant) HandleAnswerCallCount() int { + fake.handleAnswerMutex.RLock() + defer fake.handleAnswerMutex.RUnlock() + return len(fake.handleAnswerArgsForCall) +} + +func (fake *FakeLocalParticipant) HandleAnswerCalls(stub func(*livekit.SessionDescription)) { + fake.handleAnswerMutex.Lock() + defer fake.handleAnswerMutex.Unlock() + fake.HandleAnswerStub = stub +} + +func (fake *FakeLocalParticipant) HandleAnswerArgsForCall(i int) *livekit.SessionDescription { + fake.handleAnswerMutex.RLock() + defer fake.handleAnswerMutex.RUnlock() + argsForCall := fake.handleAnswerArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalParticipant) HandleICERestartSDPFragment(arg1 string) (string, error) { + fake.handleICERestartSDPFragmentMutex.Lock() + ret, specificReturn := fake.handleICERestartSDPFragmentReturnsOnCall[len(fake.handleICERestartSDPFragmentArgsForCall)] + fake.handleICERestartSDPFragmentArgsForCall = append(fake.handleICERestartSDPFragmentArgsForCall, struct { + arg1 string + }{arg1}) + stub := fake.HandleICERestartSDPFragmentStub + fakeReturns := fake.handleICERestartSDPFragmentReturns + fake.recordInvocation("HandleICERestartSDPFragment", []interface{}{arg1}) + fake.handleICERestartSDPFragmentMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeLocalParticipant) HandleICERestartSDPFragmentCallCount() int { + fake.handleICERestartSDPFragmentMutex.RLock() + defer fake.handleICERestartSDPFragmentMutex.RUnlock() + return len(fake.handleICERestartSDPFragmentArgsForCall) +} + +func (fake *FakeLocalParticipant) HandleICERestartSDPFragmentCalls(stub func(string) (string, error)) { + fake.handleICERestartSDPFragmentMutex.Lock() + defer fake.handleICERestartSDPFragmentMutex.Unlock() + fake.HandleICERestartSDPFragmentStub = stub +} + +func (fake *FakeLocalParticipant) HandleICERestartSDPFragmentArgsForCall(i int) string { + fake.handleICERestartSDPFragmentMutex.RLock() + defer fake.handleICERestartSDPFragmentMutex.RUnlock() + argsForCall := fake.handleICERestartSDPFragmentArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalParticipant) HandleICERestartSDPFragmentReturns(result1 string, result2 error) { + fake.handleICERestartSDPFragmentMutex.Lock() + defer fake.handleICERestartSDPFragmentMutex.Unlock() + fake.HandleICERestartSDPFragmentStub = nil + fake.handleICERestartSDPFragmentReturns = struct { + result1 string + result2 error + }{result1, result2} +} + +func (fake *FakeLocalParticipant) HandleICERestartSDPFragmentReturnsOnCall(i int, result1 string, result2 error) { + fake.handleICERestartSDPFragmentMutex.Lock() + defer fake.handleICERestartSDPFragmentMutex.Unlock() + fake.HandleICERestartSDPFragmentStub = nil + if fake.handleICERestartSDPFragmentReturnsOnCall == nil { + fake.handleICERestartSDPFragmentReturnsOnCall = make(map[int]struct { + result1 string + result2 error + }) + } + fake.handleICERestartSDPFragmentReturnsOnCall[i] = struct { + result1 string + result2 error + }{result1, result2} +} + +func (fake *FakeLocalParticipant) HandleICETrickle(arg1 *livekit.TrickleRequest) { + fake.handleICETrickleMutex.Lock() + fake.handleICETrickleArgsForCall = append(fake.handleICETrickleArgsForCall, struct { + arg1 *livekit.TrickleRequest + }{arg1}) + stub := fake.HandleICETrickleStub + fake.recordInvocation("HandleICETrickle", []interface{}{arg1}) + fake.handleICETrickleMutex.Unlock() + if stub != nil { + fake.HandleICETrickleStub(arg1) + } +} + +func (fake *FakeLocalParticipant) HandleICETrickleCallCount() int { + fake.handleICETrickleMutex.RLock() + defer fake.handleICETrickleMutex.RUnlock() + return len(fake.handleICETrickleArgsForCall) +} + +func (fake *FakeLocalParticipant) HandleICETrickleCalls(stub func(*livekit.TrickleRequest)) { + fake.handleICETrickleMutex.Lock() + defer fake.handleICETrickleMutex.Unlock() + fake.HandleICETrickleStub = stub +} + +func (fake *FakeLocalParticipant) HandleICETrickleArgsForCall(i int) *livekit.TrickleRequest { + fake.handleICETrickleMutex.RLock() + defer fake.handleICETrickleMutex.RUnlock() + argsForCall := fake.handleICETrickleArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalParticipant) HandleICETrickleSDPFragment(arg1 string) error { + fake.handleICETrickleSDPFragmentMutex.Lock() + ret, specificReturn := fake.handleICETrickleSDPFragmentReturnsOnCall[len(fake.handleICETrickleSDPFragmentArgsForCall)] + fake.handleICETrickleSDPFragmentArgsForCall = append(fake.handleICETrickleSDPFragmentArgsForCall, struct { + arg1 string + }{arg1}) + stub := fake.HandleICETrickleSDPFragmentStub + fakeReturns := fake.handleICETrickleSDPFragmentReturns + fake.recordInvocation("HandleICETrickleSDPFragment", []interface{}{arg1}) + fake.handleICETrickleSDPFragmentMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) HandleICETrickleSDPFragmentCallCount() int { + fake.handleICETrickleSDPFragmentMutex.RLock() + defer fake.handleICETrickleSDPFragmentMutex.RUnlock() + return len(fake.handleICETrickleSDPFragmentArgsForCall) +} + +func (fake *FakeLocalParticipant) HandleICETrickleSDPFragmentCalls(stub func(string) error) { + fake.handleICETrickleSDPFragmentMutex.Lock() + defer fake.handleICETrickleSDPFragmentMutex.Unlock() + fake.HandleICETrickleSDPFragmentStub = stub +} + +func (fake *FakeLocalParticipant) HandleICETrickleSDPFragmentArgsForCall(i int) string { + fake.handleICETrickleSDPFragmentMutex.RLock() + defer fake.handleICETrickleSDPFragmentMutex.RUnlock() + argsForCall := fake.handleICETrickleSDPFragmentArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalParticipant) HandleICETrickleSDPFragmentReturns(result1 error) { + fake.handleICETrickleSDPFragmentMutex.Lock() + defer fake.handleICETrickleSDPFragmentMutex.Unlock() + fake.HandleICETrickleSDPFragmentStub = nil + fake.handleICETrickleSDPFragmentReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipant) HandleICETrickleSDPFragmentReturnsOnCall(i int, result1 error) { + fake.handleICETrickleSDPFragmentMutex.Lock() + defer fake.handleICETrickleSDPFragmentMutex.Unlock() + fake.HandleICETrickleSDPFragmentStub = nil + if fake.handleICETrickleSDPFragmentReturnsOnCall == nil { + fake.handleICETrickleSDPFragmentReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.handleICETrickleSDPFragmentReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipant) HandleLeaveRequest(arg1 types.ParticipantCloseReason) { + fake.handleLeaveRequestMutex.Lock() + fake.handleLeaveRequestArgsForCall = append(fake.handleLeaveRequestArgsForCall, struct { + arg1 types.ParticipantCloseReason + }{arg1}) + stub := fake.HandleLeaveRequestStub + fake.recordInvocation("HandleLeaveRequest", []interface{}{arg1}) + fake.handleLeaveRequestMutex.Unlock() + if stub != nil { + fake.HandleLeaveRequestStub(arg1) + } +} + +func (fake *FakeLocalParticipant) HandleLeaveRequestCallCount() int { + fake.handleLeaveRequestMutex.RLock() + defer fake.handleLeaveRequestMutex.RUnlock() + return len(fake.handleLeaveRequestArgsForCall) +} + +func (fake *FakeLocalParticipant) HandleLeaveRequestCalls(stub func(types.ParticipantCloseReason)) { + fake.handleLeaveRequestMutex.Lock() + defer fake.handleLeaveRequestMutex.Unlock() + fake.HandleLeaveRequestStub = stub +} + +func (fake *FakeLocalParticipant) HandleLeaveRequestArgsForCall(i int) types.ParticipantCloseReason { + fake.handleLeaveRequestMutex.RLock() + defer fake.handleLeaveRequestMutex.RUnlock() + argsForCall := fake.handleLeaveRequestArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalParticipant) HandleMetrics(arg1 livekit.ParticipantID, arg2 *livekit.MetricsBatch) error { + fake.handleMetricsMutex.Lock() + ret, specificReturn := fake.handleMetricsReturnsOnCall[len(fake.handleMetricsArgsForCall)] + fake.handleMetricsArgsForCall = append(fake.handleMetricsArgsForCall, struct { + arg1 livekit.ParticipantID + arg2 *livekit.MetricsBatch + }{arg1, arg2}) + stub := fake.HandleMetricsStub + fakeReturns := fake.handleMetricsReturns + fake.recordInvocation("HandleMetrics", []interface{}{arg1, arg2}) + fake.handleMetricsMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) HandleMetricsCallCount() int { + fake.handleMetricsMutex.RLock() + defer fake.handleMetricsMutex.RUnlock() + return len(fake.handleMetricsArgsForCall) +} + +func (fake *FakeLocalParticipant) HandleMetricsCalls(stub func(livekit.ParticipantID, *livekit.MetricsBatch) error) { + fake.handleMetricsMutex.Lock() + defer fake.handleMetricsMutex.Unlock() + fake.HandleMetricsStub = stub +} + +func (fake *FakeLocalParticipant) HandleMetricsArgsForCall(i int) (livekit.ParticipantID, *livekit.MetricsBatch) { + fake.handleMetricsMutex.RLock() + defer fake.handleMetricsMutex.RUnlock() + argsForCall := fake.handleMetricsArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeLocalParticipant) HandleMetricsReturns(result1 error) { + fake.handleMetricsMutex.Lock() + defer fake.handleMetricsMutex.Unlock() + fake.HandleMetricsStub = nil + fake.handleMetricsReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipant) HandleMetricsReturnsOnCall(i int, result1 error) { + fake.handleMetricsMutex.Lock() + defer fake.handleMetricsMutex.Unlock() + fake.HandleMetricsStub = nil + if fake.handleMetricsReturnsOnCall == nil { + fake.handleMetricsReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.handleMetricsReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipant) HandleOffer(arg1 *livekit.SessionDescription) error { + fake.handleOfferMutex.Lock() + ret, specificReturn := fake.handleOfferReturnsOnCall[len(fake.handleOfferArgsForCall)] + fake.handleOfferArgsForCall = append(fake.handleOfferArgsForCall, struct { + arg1 *livekit.SessionDescription + }{arg1}) + stub := fake.HandleOfferStub + fakeReturns := fake.handleOfferReturns + fake.recordInvocation("HandleOffer", []interface{}{arg1}) + fake.handleOfferMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) HandleOfferCallCount() int { + fake.handleOfferMutex.RLock() + defer fake.handleOfferMutex.RUnlock() + return len(fake.handleOfferArgsForCall) +} + +func (fake *FakeLocalParticipant) HandleOfferCalls(stub func(*livekit.SessionDescription) error) { + fake.handleOfferMutex.Lock() + defer fake.handleOfferMutex.Unlock() + fake.HandleOfferStub = stub +} + +func (fake *FakeLocalParticipant) HandleOfferArgsForCall(i int) *livekit.SessionDescription { + fake.handleOfferMutex.RLock() + defer fake.handleOfferMutex.RUnlock() + argsForCall := fake.handleOfferArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalParticipant) HandleOfferReturns(result1 error) { + fake.handleOfferMutex.Lock() + defer fake.handleOfferMutex.Unlock() + fake.HandleOfferStub = nil + fake.handleOfferReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipant) HandleOfferReturnsOnCall(i int, result1 error) { + fake.handleOfferMutex.Lock() + defer fake.handleOfferMutex.Unlock() + fake.HandleOfferStub = nil + if fake.handleOfferReturnsOnCall == nil { + fake.handleOfferReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.handleOfferReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipant) HandlePublishDataTrackRequest(arg1 *livekit.PublishDataTrackRequest) { + fake.handlePublishDataTrackRequestMutex.Lock() + fake.handlePublishDataTrackRequestArgsForCall = append(fake.handlePublishDataTrackRequestArgsForCall, struct { + arg1 *livekit.PublishDataTrackRequest + }{arg1}) + stub := fake.HandlePublishDataTrackRequestStub + fake.recordInvocation("HandlePublishDataTrackRequest", []interface{}{arg1}) + fake.handlePublishDataTrackRequestMutex.Unlock() + if stub != nil { + fake.HandlePublishDataTrackRequestStub(arg1) + } +} + +func (fake *FakeLocalParticipant) HandlePublishDataTrackRequestCallCount() int { + fake.handlePublishDataTrackRequestMutex.RLock() + defer fake.handlePublishDataTrackRequestMutex.RUnlock() + return len(fake.handlePublishDataTrackRequestArgsForCall) +} + +func (fake *FakeLocalParticipant) HandlePublishDataTrackRequestCalls(stub func(*livekit.PublishDataTrackRequest)) { + fake.handlePublishDataTrackRequestMutex.Lock() + defer fake.handlePublishDataTrackRequestMutex.Unlock() + fake.HandlePublishDataTrackRequestStub = stub +} + +func (fake *FakeLocalParticipant) HandlePublishDataTrackRequestArgsForCall(i int) *livekit.PublishDataTrackRequest { + fake.handlePublishDataTrackRequestMutex.RLock() + defer fake.handlePublishDataTrackRequestMutex.RUnlock() + argsForCall := fake.handlePublishDataTrackRequestArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalParticipant) HandleReceivedDataTrackMessage(arg1 []byte, arg2 *datatrack.Packet, arg3 int64) { + var arg1Copy []byte + if arg1 != nil { + arg1Copy = make([]byte, len(arg1)) + copy(arg1Copy, arg1) + } + fake.handleReceivedDataTrackMessageMutex.Lock() + fake.handleReceivedDataTrackMessageArgsForCall = append(fake.handleReceivedDataTrackMessageArgsForCall, struct { + arg1 []byte + arg2 *datatrack.Packet + arg3 int64 + }{arg1Copy, arg2, arg3}) + stub := fake.HandleReceivedDataTrackMessageStub + fake.recordInvocation("HandleReceivedDataTrackMessage", []interface{}{arg1Copy, arg2, arg3}) + fake.handleReceivedDataTrackMessageMutex.Unlock() + if stub != nil { + fake.HandleReceivedDataTrackMessageStub(arg1, arg2, arg3) + } +} + +func (fake *FakeLocalParticipant) HandleReceivedDataTrackMessageCallCount() int { + fake.handleReceivedDataTrackMessageMutex.RLock() + defer fake.handleReceivedDataTrackMessageMutex.RUnlock() + return len(fake.handleReceivedDataTrackMessageArgsForCall) +} + +func (fake *FakeLocalParticipant) HandleReceivedDataTrackMessageCalls(stub func([]byte, *datatrack.Packet, int64)) { + fake.handleReceivedDataTrackMessageMutex.Lock() + defer fake.handleReceivedDataTrackMessageMutex.Unlock() + fake.HandleReceivedDataTrackMessageStub = stub +} + +func (fake *FakeLocalParticipant) HandleReceivedDataTrackMessageArgsForCall(i int) ([]byte, *datatrack.Packet, int64) { + fake.handleReceivedDataTrackMessageMutex.RLock() + defer fake.handleReceivedDataTrackMessageMutex.RUnlock() + argsForCall := fake.handleReceivedDataTrackMessageArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeLocalParticipant) HandleReceiverReport(arg1 *sfu.DownTrack, arg2 *rtcp.ReceiverReport) { + fake.handleReceiverReportMutex.Lock() + fake.handleReceiverReportArgsForCall = append(fake.handleReceiverReportArgsForCall, struct { + arg1 *sfu.DownTrack + arg2 *rtcp.ReceiverReport + }{arg1, arg2}) + stub := fake.HandleReceiverReportStub + fake.recordInvocation("HandleReceiverReport", []interface{}{arg1, arg2}) + fake.handleReceiverReportMutex.Unlock() + if stub != nil { + fake.HandleReceiverReportStub(arg1, arg2) + } +} + +func (fake *FakeLocalParticipant) HandleReceiverReportCallCount() int { + fake.handleReceiverReportMutex.RLock() + defer fake.handleReceiverReportMutex.RUnlock() + return len(fake.handleReceiverReportArgsForCall) +} + +func (fake *FakeLocalParticipant) HandleReceiverReportCalls(stub func(*sfu.DownTrack, *rtcp.ReceiverReport)) { + fake.handleReceiverReportMutex.Lock() + defer fake.handleReceiverReportMutex.Unlock() + fake.HandleReceiverReportStub = stub +} + +func (fake *FakeLocalParticipant) HandleReceiverReportArgsForCall(i int) (*sfu.DownTrack, *rtcp.ReceiverReport) { + fake.handleReceiverReportMutex.RLock() + defer fake.handleReceiverReportMutex.RUnlock() + argsForCall := fake.handleReceiverReportArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeLocalParticipant) HandleReconnectAndSendResponse(arg1 livekit.ReconnectReason, arg2 *livekit.ReconnectResponse) error { + fake.handleReconnectAndSendResponseMutex.Lock() + ret, specificReturn := fake.handleReconnectAndSendResponseReturnsOnCall[len(fake.handleReconnectAndSendResponseArgsForCall)] + fake.handleReconnectAndSendResponseArgsForCall = append(fake.handleReconnectAndSendResponseArgsForCall, struct { + arg1 livekit.ReconnectReason + arg2 *livekit.ReconnectResponse + }{arg1, arg2}) + stub := fake.HandleReconnectAndSendResponseStub + fakeReturns := fake.handleReconnectAndSendResponseReturns + fake.recordInvocation("HandleReconnectAndSendResponse", []interface{}{arg1, arg2}) + fake.handleReconnectAndSendResponseMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) HandleReconnectAndSendResponseCallCount() int { + fake.handleReconnectAndSendResponseMutex.RLock() + defer fake.handleReconnectAndSendResponseMutex.RUnlock() + return len(fake.handleReconnectAndSendResponseArgsForCall) +} + +func (fake *FakeLocalParticipant) HandleReconnectAndSendResponseCalls(stub func(livekit.ReconnectReason, *livekit.ReconnectResponse) error) { + fake.handleReconnectAndSendResponseMutex.Lock() + defer fake.handleReconnectAndSendResponseMutex.Unlock() + fake.HandleReconnectAndSendResponseStub = stub +} + +func (fake *FakeLocalParticipant) HandleReconnectAndSendResponseArgsForCall(i int) (livekit.ReconnectReason, *livekit.ReconnectResponse) { + fake.handleReconnectAndSendResponseMutex.RLock() + defer fake.handleReconnectAndSendResponseMutex.RUnlock() + argsForCall := fake.handleReconnectAndSendResponseArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeLocalParticipant) HandleReconnectAndSendResponseReturns(result1 error) { + fake.handleReconnectAndSendResponseMutex.Lock() + defer fake.handleReconnectAndSendResponseMutex.Unlock() + fake.HandleReconnectAndSendResponseStub = nil + fake.handleReconnectAndSendResponseReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipant) HandleReconnectAndSendResponseReturnsOnCall(i int, result1 error) { + fake.handleReconnectAndSendResponseMutex.Lock() + defer fake.handleReconnectAndSendResponseMutex.Unlock() + fake.HandleReconnectAndSendResponseStub = nil + if fake.handleReconnectAndSendResponseReturnsOnCall == nil { + fake.handleReconnectAndSendResponseReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.handleReconnectAndSendResponseReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipant) HandleSignalMessage(arg1 proto.Message) error { + fake.handleSignalMessageMutex.Lock() + ret, specificReturn := fake.handleSignalMessageReturnsOnCall[len(fake.handleSignalMessageArgsForCall)] + fake.handleSignalMessageArgsForCall = append(fake.handleSignalMessageArgsForCall, struct { + arg1 proto.Message + }{arg1}) + stub := fake.HandleSignalMessageStub + fakeReturns := fake.handleSignalMessageReturns + fake.recordInvocation("HandleSignalMessage", []interface{}{arg1}) + fake.handleSignalMessageMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) HandleSignalMessageCallCount() int { + fake.handleSignalMessageMutex.RLock() + defer fake.handleSignalMessageMutex.RUnlock() + return len(fake.handleSignalMessageArgsForCall) +} + +func (fake *FakeLocalParticipant) HandleSignalMessageCalls(stub func(proto.Message) error) { + fake.handleSignalMessageMutex.Lock() + defer fake.handleSignalMessageMutex.Unlock() + fake.HandleSignalMessageStub = stub +} + +func (fake *FakeLocalParticipant) HandleSignalMessageArgsForCall(i int) proto.Message { + fake.handleSignalMessageMutex.RLock() + defer fake.handleSignalMessageMutex.RUnlock() + argsForCall := fake.handleSignalMessageArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalParticipant) HandleSignalMessageReturns(result1 error) { + fake.handleSignalMessageMutex.Lock() + defer fake.handleSignalMessageMutex.Unlock() + fake.HandleSignalMessageStub = nil + fake.handleSignalMessageReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipant) HandleSignalMessageReturnsOnCall(i int, result1 error) { + fake.handleSignalMessageMutex.Lock() + defer fake.handleSignalMessageMutex.Unlock() + fake.HandleSignalMessageStub = nil + if fake.handleSignalMessageReturnsOnCall == nil { + fake.handleSignalMessageReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.handleSignalMessageReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipant) HandleSignalSourceClose() { + fake.handleSignalSourceCloseMutex.Lock() + fake.handleSignalSourceCloseArgsForCall = append(fake.handleSignalSourceCloseArgsForCall, struct { + }{}) + stub := fake.HandleSignalSourceCloseStub + fake.recordInvocation("HandleSignalSourceClose", []interface{}{}) + fake.handleSignalSourceCloseMutex.Unlock() + if stub != nil { + fake.HandleSignalSourceCloseStub() + } +} + +func (fake *FakeLocalParticipant) HandleSignalSourceCloseCallCount() int { + fake.handleSignalSourceCloseMutex.RLock() + defer fake.handleSignalSourceCloseMutex.RUnlock() + return len(fake.handleSignalSourceCloseArgsForCall) +} + +func (fake *FakeLocalParticipant) HandleSignalSourceCloseCalls(stub func()) { + fake.handleSignalSourceCloseMutex.Lock() + defer fake.handleSignalSourceCloseMutex.Unlock() + fake.HandleSignalSourceCloseStub = stub +} + +func (fake *FakeLocalParticipant) HandleSimulateScenario(arg1 *livekit.SimulateScenario) error { + fake.handleSimulateScenarioMutex.Lock() + ret, specificReturn := fake.handleSimulateScenarioReturnsOnCall[len(fake.handleSimulateScenarioArgsForCall)] + fake.handleSimulateScenarioArgsForCall = append(fake.handleSimulateScenarioArgsForCall, struct { + arg1 *livekit.SimulateScenario + }{arg1}) + stub := fake.HandleSimulateScenarioStub + fakeReturns := fake.handleSimulateScenarioReturns + fake.recordInvocation("HandleSimulateScenario", []interface{}{arg1}) + fake.handleSimulateScenarioMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) HandleSimulateScenarioCallCount() int { + fake.handleSimulateScenarioMutex.RLock() + defer fake.handleSimulateScenarioMutex.RUnlock() + return len(fake.handleSimulateScenarioArgsForCall) +} + +func (fake *FakeLocalParticipant) HandleSimulateScenarioCalls(stub func(*livekit.SimulateScenario) error) { + fake.handleSimulateScenarioMutex.Lock() + defer fake.handleSimulateScenarioMutex.Unlock() + fake.HandleSimulateScenarioStub = stub +} + +func (fake *FakeLocalParticipant) HandleSimulateScenarioArgsForCall(i int) *livekit.SimulateScenario { + fake.handleSimulateScenarioMutex.RLock() + defer fake.handleSimulateScenarioMutex.RUnlock() + argsForCall := fake.handleSimulateScenarioArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalParticipant) HandleSimulateScenarioReturns(result1 error) { + fake.handleSimulateScenarioMutex.Lock() + defer fake.handleSimulateScenarioMutex.Unlock() + fake.HandleSimulateScenarioStub = nil + fake.handleSimulateScenarioReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipant) HandleSimulateScenarioReturnsOnCall(i int, result1 error) { + fake.handleSimulateScenarioMutex.Lock() + defer fake.handleSimulateScenarioMutex.Unlock() + fake.HandleSimulateScenarioStub = nil + if fake.handleSimulateScenarioReturnsOnCall == nil { + fake.handleSimulateScenarioReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.handleSimulateScenarioReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipant) HandleSyncState(arg1 *livekit.SyncState) error { + fake.handleSyncStateMutex.Lock() + ret, specificReturn := fake.handleSyncStateReturnsOnCall[len(fake.handleSyncStateArgsForCall)] + fake.handleSyncStateArgsForCall = append(fake.handleSyncStateArgsForCall, struct { + arg1 *livekit.SyncState + }{arg1}) + stub := fake.HandleSyncStateStub + fakeReturns := fake.handleSyncStateReturns + fake.recordInvocation("HandleSyncState", []interface{}{arg1}) + fake.handleSyncStateMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) HandleSyncStateCallCount() int { + fake.handleSyncStateMutex.RLock() + defer fake.handleSyncStateMutex.RUnlock() + return len(fake.handleSyncStateArgsForCall) +} + +func (fake *FakeLocalParticipant) HandleSyncStateCalls(stub func(*livekit.SyncState) error) { + fake.handleSyncStateMutex.Lock() + defer fake.handleSyncStateMutex.Unlock() + fake.HandleSyncStateStub = stub +} + +func (fake *FakeLocalParticipant) HandleSyncStateArgsForCall(i int) *livekit.SyncState { + fake.handleSyncStateMutex.RLock() + defer fake.handleSyncStateMutex.RUnlock() + argsForCall := fake.handleSyncStateArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalParticipant) HandleSyncStateReturns(result1 error) { + fake.handleSyncStateMutex.Lock() + defer fake.handleSyncStateMutex.Unlock() + fake.HandleSyncStateStub = nil + fake.handleSyncStateReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipant) HandleSyncStateReturnsOnCall(i int, result1 error) { + fake.handleSyncStateMutex.Lock() + defer fake.handleSyncStateMutex.Unlock() + fake.HandleSyncStateStub = nil + if fake.handleSyncStateReturnsOnCall == nil { + fake.handleSyncStateReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.handleSyncStateReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipant) HandleUnpublishDataTrackRequest(arg1 *livekit.UnpublishDataTrackRequest) { + fake.handleUnpublishDataTrackRequestMutex.Lock() + fake.handleUnpublishDataTrackRequestArgsForCall = append(fake.handleUnpublishDataTrackRequestArgsForCall, struct { + arg1 *livekit.UnpublishDataTrackRequest + }{arg1}) + stub := fake.HandleUnpublishDataTrackRequestStub + fake.recordInvocation("HandleUnpublishDataTrackRequest", []interface{}{arg1}) + fake.handleUnpublishDataTrackRequestMutex.Unlock() + if stub != nil { + fake.HandleUnpublishDataTrackRequestStub(arg1) + } +} + +func (fake *FakeLocalParticipant) HandleUnpublishDataTrackRequestCallCount() int { + fake.handleUnpublishDataTrackRequestMutex.RLock() + defer fake.handleUnpublishDataTrackRequestMutex.RUnlock() + return len(fake.handleUnpublishDataTrackRequestArgsForCall) +} + +func (fake *FakeLocalParticipant) HandleUnpublishDataTrackRequestCalls(stub func(*livekit.UnpublishDataTrackRequest)) { + fake.handleUnpublishDataTrackRequestMutex.Lock() + defer fake.handleUnpublishDataTrackRequestMutex.Unlock() + fake.HandleUnpublishDataTrackRequestStub = stub +} + +func (fake *FakeLocalParticipant) HandleUnpublishDataTrackRequestArgsForCall(i int) *livekit.UnpublishDataTrackRequest { + fake.handleUnpublishDataTrackRequestMutex.RLock() + defer fake.handleUnpublishDataTrackRequestMutex.RUnlock() + argsForCall := fake.handleUnpublishDataTrackRequestArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalParticipant) HandleUpdateDataSubscription(arg1 *livekit.UpdateDataSubscription) { + fake.handleUpdateDataSubscriptionMutex.Lock() + fake.handleUpdateDataSubscriptionArgsForCall = append(fake.handleUpdateDataSubscriptionArgsForCall, struct { + arg1 *livekit.UpdateDataSubscription + }{arg1}) + stub := fake.HandleUpdateDataSubscriptionStub + fake.recordInvocation("HandleUpdateDataSubscription", []interface{}{arg1}) + fake.handleUpdateDataSubscriptionMutex.Unlock() + if stub != nil { + fake.HandleUpdateDataSubscriptionStub(arg1) + } +} + +func (fake *FakeLocalParticipant) HandleUpdateDataSubscriptionCallCount() int { + fake.handleUpdateDataSubscriptionMutex.RLock() + defer fake.handleUpdateDataSubscriptionMutex.RUnlock() + return len(fake.handleUpdateDataSubscriptionArgsForCall) +} + +func (fake *FakeLocalParticipant) HandleUpdateDataSubscriptionCalls(stub func(*livekit.UpdateDataSubscription)) { + fake.handleUpdateDataSubscriptionMutex.Lock() + defer fake.handleUpdateDataSubscriptionMutex.Unlock() + fake.HandleUpdateDataSubscriptionStub = stub +} + +func (fake *FakeLocalParticipant) HandleUpdateDataSubscriptionArgsForCall(i int) *livekit.UpdateDataSubscription { + fake.handleUpdateDataSubscriptionMutex.RLock() + defer fake.handleUpdateDataSubscriptionMutex.RUnlock() + argsForCall := fake.handleUpdateDataSubscriptionArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalParticipant) HandleUpdateSubscriptionPermission(arg1 *livekit.SubscriptionPermission) error { + fake.handleUpdateSubscriptionPermissionMutex.Lock() + ret, specificReturn := fake.handleUpdateSubscriptionPermissionReturnsOnCall[len(fake.handleUpdateSubscriptionPermissionArgsForCall)] + fake.handleUpdateSubscriptionPermissionArgsForCall = append(fake.handleUpdateSubscriptionPermissionArgsForCall, struct { + arg1 *livekit.SubscriptionPermission + }{arg1}) + stub := fake.HandleUpdateSubscriptionPermissionStub + fakeReturns := fake.handleUpdateSubscriptionPermissionReturns + fake.recordInvocation("HandleUpdateSubscriptionPermission", []interface{}{arg1}) + fake.handleUpdateSubscriptionPermissionMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) HandleUpdateSubscriptionPermissionCallCount() int { + fake.handleUpdateSubscriptionPermissionMutex.RLock() + defer fake.handleUpdateSubscriptionPermissionMutex.RUnlock() + return len(fake.handleUpdateSubscriptionPermissionArgsForCall) +} + +func (fake *FakeLocalParticipant) HandleUpdateSubscriptionPermissionCalls(stub func(*livekit.SubscriptionPermission) error) { + fake.handleUpdateSubscriptionPermissionMutex.Lock() + defer fake.handleUpdateSubscriptionPermissionMutex.Unlock() + fake.HandleUpdateSubscriptionPermissionStub = stub +} + +func (fake *FakeLocalParticipant) HandleUpdateSubscriptionPermissionArgsForCall(i int) *livekit.SubscriptionPermission { + fake.handleUpdateSubscriptionPermissionMutex.RLock() + defer fake.handleUpdateSubscriptionPermissionMutex.RUnlock() + argsForCall := fake.handleUpdateSubscriptionPermissionArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalParticipant) HandleUpdateSubscriptionPermissionReturns(result1 error) { + fake.handleUpdateSubscriptionPermissionMutex.Lock() + defer fake.handleUpdateSubscriptionPermissionMutex.Unlock() + fake.HandleUpdateSubscriptionPermissionStub = nil + fake.handleUpdateSubscriptionPermissionReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipant) HandleUpdateSubscriptionPermissionReturnsOnCall(i int, result1 error) { + fake.handleUpdateSubscriptionPermissionMutex.Lock() + defer fake.handleUpdateSubscriptionPermissionMutex.Unlock() + fake.HandleUpdateSubscriptionPermissionStub = nil + if fake.handleUpdateSubscriptionPermissionReturnsOnCall == nil { + fake.handleUpdateSubscriptionPermissionReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.handleUpdateSubscriptionPermissionReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipant) HandleUpdateSubscriptions(arg1 []livekit.TrackID, arg2 []*livekit.ParticipantTracks, arg3 bool) { + var arg1Copy []livekit.TrackID + if arg1 != nil { + arg1Copy = make([]livekit.TrackID, len(arg1)) + copy(arg1Copy, arg1) + } + var arg2Copy []*livekit.ParticipantTracks + if arg2 != nil { + arg2Copy = make([]*livekit.ParticipantTracks, len(arg2)) + copy(arg2Copy, arg2) + } + fake.handleUpdateSubscriptionsMutex.Lock() + fake.handleUpdateSubscriptionsArgsForCall = append(fake.handleUpdateSubscriptionsArgsForCall, struct { + arg1 []livekit.TrackID + arg2 []*livekit.ParticipantTracks + arg3 bool + }{arg1Copy, arg2Copy, arg3}) + stub := fake.HandleUpdateSubscriptionsStub + fake.recordInvocation("HandleUpdateSubscriptions", []interface{}{arg1Copy, arg2Copy, arg3}) + fake.handleUpdateSubscriptionsMutex.Unlock() + if stub != nil { + fake.HandleUpdateSubscriptionsStub(arg1, arg2, arg3) + } +} + +func (fake *FakeLocalParticipant) HandleUpdateSubscriptionsCallCount() int { + fake.handleUpdateSubscriptionsMutex.RLock() + defer fake.handleUpdateSubscriptionsMutex.RUnlock() + return len(fake.handleUpdateSubscriptionsArgsForCall) +} + +func (fake *FakeLocalParticipant) HandleUpdateSubscriptionsCalls(stub func([]livekit.TrackID, []*livekit.ParticipantTracks, bool)) { + fake.handleUpdateSubscriptionsMutex.Lock() + defer fake.handleUpdateSubscriptionsMutex.Unlock() + fake.HandleUpdateSubscriptionsStub = stub +} + +func (fake *FakeLocalParticipant) HandleUpdateSubscriptionsArgsForCall(i int) ([]livekit.TrackID, []*livekit.ParticipantTracks, bool) { + fake.handleUpdateSubscriptionsMutex.RLock() + defer fake.handleUpdateSubscriptionsMutex.RUnlock() + argsForCall := fake.handleUpdateSubscriptionsArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeLocalParticipant) HasConnected() bool { + fake.hasConnectedMutex.Lock() + ret, specificReturn := fake.hasConnectedReturnsOnCall[len(fake.hasConnectedArgsForCall)] + fake.hasConnectedArgsForCall = append(fake.hasConnectedArgsForCall, struct { + }{}) + stub := fake.HasConnectedStub + fakeReturns := fake.hasConnectedReturns + fake.recordInvocation("HasConnected", []interface{}{}) + fake.hasConnectedMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) HasConnectedCallCount() int { + fake.hasConnectedMutex.RLock() + defer fake.hasConnectedMutex.RUnlock() + return len(fake.hasConnectedArgsForCall) +} + +func (fake *FakeLocalParticipant) HasConnectedCalls(stub func() bool) { + fake.hasConnectedMutex.Lock() + defer fake.hasConnectedMutex.Unlock() + fake.HasConnectedStub = stub +} + +func (fake *FakeLocalParticipant) HasConnectedReturns(result1 bool) { + fake.hasConnectedMutex.Lock() + defer fake.hasConnectedMutex.Unlock() + fake.HasConnectedStub = nil + fake.hasConnectedReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalParticipant) HasConnectedReturnsOnCall(i int, result1 bool) { + fake.hasConnectedMutex.Lock() + defer fake.hasConnectedMutex.Unlock() + fake.HasConnectedStub = nil + if fake.hasConnectedReturnsOnCall == nil { + fake.hasConnectedReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.hasConnectedReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalParticipant) HasPermission(arg1 livekit.TrackID, arg2 livekit.ParticipantIdentity) bool { + fake.hasPermissionMutex.Lock() + ret, specificReturn := fake.hasPermissionReturnsOnCall[len(fake.hasPermissionArgsForCall)] + fake.hasPermissionArgsForCall = append(fake.hasPermissionArgsForCall, struct { + arg1 livekit.TrackID + arg2 livekit.ParticipantIdentity + }{arg1, arg2}) + stub := fake.HasPermissionStub + fakeReturns := fake.hasPermissionReturns + fake.recordInvocation("HasPermission", []interface{}{arg1, arg2}) + fake.hasPermissionMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) HasPermissionCallCount() int { + fake.hasPermissionMutex.RLock() + defer fake.hasPermissionMutex.RUnlock() + return len(fake.hasPermissionArgsForCall) +} + +func (fake *FakeLocalParticipant) HasPermissionCalls(stub func(livekit.TrackID, livekit.ParticipantIdentity) bool) { + fake.hasPermissionMutex.Lock() + defer fake.hasPermissionMutex.Unlock() + fake.HasPermissionStub = stub +} + +func (fake *FakeLocalParticipant) HasPermissionArgsForCall(i int) (livekit.TrackID, livekit.ParticipantIdentity) { + fake.hasPermissionMutex.RLock() + defer fake.hasPermissionMutex.RUnlock() + argsForCall := fake.hasPermissionArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeLocalParticipant) HasPermissionReturns(result1 bool) { + fake.hasPermissionMutex.Lock() + defer fake.hasPermissionMutex.Unlock() + fake.HasPermissionStub = nil + fake.hasPermissionReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalParticipant) HasPermissionReturnsOnCall(i int, result1 bool) { + fake.hasPermissionMutex.Lock() + defer fake.hasPermissionMutex.Unlock() + fake.HasPermissionStub = nil + if fake.hasPermissionReturnsOnCall == nil { + fake.hasPermissionReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.hasPermissionReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalParticipant) Hidden() bool { + fake.hiddenMutex.Lock() + ret, specificReturn := fake.hiddenReturnsOnCall[len(fake.hiddenArgsForCall)] + fake.hiddenArgsForCall = append(fake.hiddenArgsForCall, struct { + }{}) + stub := fake.HiddenStub + fakeReturns := fake.hiddenReturns + fake.recordInvocation("Hidden", []interface{}{}) + fake.hiddenMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) HiddenCallCount() int { + fake.hiddenMutex.RLock() + defer fake.hiddenMutex.RUnlock() + return len(fake.hiddenArgsForCall) +} + +func (fake *FakeLocalParticipant) HiddenCalls(stub func() bool) { + fake.hiddenMutex.Lock() + defer fake.hiddenMutex.Unlock() + fake.HiddenStub = stub +} + +func (fake *FakeLocalParticipant) HiddenReturns(result1 bool) { + fake.hiddenMutex.Lock() + defer fake.hiddenMutex.Unlock() + fake.HiddenStub = nil + fake.hiddenReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalParticipant) HiddenReturnsOnCall(i int, result1 bool) { + fake.hiddenMutex.Lock() + defer fake.hiddenMutex.Unlock() + fake.HiddenStub = nil + if fake.hiddenReturnsOnCall == nil { + fake.hiddenReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.hiddenReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalParticipant) ICERestart(arg1 *livekit.ICEConfig) { + fake.iCERestartMutex.Lock() + fake.iCERestartArgsForCall = append(fake.iCERestartArgsForCall, struct { + arg1 *livekit.ICEConfig + }{arg1}) + stub := fake.ICERestartStub + fake.recordInvocation("ICERestart", []interface{}{arg1}) + fake.iCERestartMutex.Unlock() + if stub != nil { + fake.ICERestartStub(arg1) + } +} + +func (fake *FakeLocalParticipant) ICERestartCallCount() int { + fake.iCERestartMutex.RLock() + defer fake.iCERestartMutex.RUnlock() + return len(fake.iCERestartArgsForCall) +} + +func (fake *FakeLocalParticipant) ICERestartCalls(stub func(*livekit.ICEConfig)) { + fake.iCERestartMutex.Lock() + defer fake.iCERestartMutex.Unlock() + fake.ICERestartStub = stub +} + +func (fake *FakeLocalParticipant) ICERestartArgsForCall(i int) *livekit.ICEConfig { + fake.iCERestartMutex.RLock() + defer fake.iCERestartMutex.RUnlock() + argsForCall := fake.iCERestartArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalParticipant) ID() livekit.ParticipantID { + fake.iDMutex.Lock() + ret, specificReturn := fake.iDReturnsOnCall[len(fake.iDArgsForCall)] + fake.iDArgsForCall = append(fake.iDArgsForCall, struct { + }{}) + stub := fake.IDStub + fakeReturns := fake.iDReturns + fake.recordInvocation("ID", []interface{}{}) + fake.iDMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) IDCallCount() int { + fake.iDMutex.RLock() + defer fake.iDMutex.RUnlock() + return len(fake.iDArgsForCall) +} + +func (fake *FakeLocalParticipant) IDCalls(stub func() livekit.ParticipantID) { + fake.iDMutex.Lock() + defer fake.iDMutex.Unlock() + fake.IDStub = stub +} + +func (fake *FakeLocalParticipant) IDReturns(result1 livekit.ParticipantID) { + fake.iDMutex.Lock() + defer fake.iDMutex.Unlock() + fake.IDStub = nil + fake.iDReturns = struct { + result1 livekit.ParticipantID + }{result1} +} + +func (fake *FakeLocalParticipant) IDReturnsOnCall(i int, result1 livekit.ParticipantID) { + fake.iDMutex.Lock() + defer fake.iDMutex.Unlock() + fake.IDStub = nil + if fake.iDReturnsOnCall == nil { + fake.iDReturnsOnCall = make(map[int]struct { + result1 livekit.ParticipantID + }) + } + fake.iDReturnsOnCall[i] = struct { + result1 livekit.ParticipantID + }{result1} +} + +func (fake *FakeLocalParticipant) Identity() livekit.ParticipantIdentity { + fake.identityMutex.Lock() + ret, specificReturn := fake.identityReturnsOnCall[len(fake.identityArgsForCall)] + fake.identityArgsForCall = append(fake.identityArgsForCall, struct { + }{}) + stub := fake.IdentityStub + fakeReturns := fake.identityReturns + fake.recordInvocation("Identity", []interface{}{}) + fake.identityMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) IdentityCallCount() int { + fake.identityMutex.RLock() + defer fake.identityMutex.RUnlock() + return len(fake.identityArgsForCall) +} + +func (fake *FakeLocalParticipant) IdentityCalls(stub func() livekit.ParticipantIdentity) { + fake.identityMutex.Lock() + defer fake.identityMutex.Unlock() + fake.IdentityStub = stub +} + +func (fake *FakeLocalParticipant) IdentityReturns(result1 livekit.ParticipantIdentity) { + fake.identityMutex.Lock() + defer fake.identityMutex.Unlock() + fake.IdentityStub = nil + fake.identityReturns = struct { + result1 livekit.ParticipantIdentity + }{result1} +} + +func (fake *FakeLocalParticipant) IdentityReturnsOnCall(i int, result1 livekit.ParticipantIdentity) { + fake.identityMutex.Lock() + defer fake.identityMutex.Unlock() + fake.IdentityStub = nil + if fake.identityReturnsOnCall == nil { + fake.identityReturnsOnCall = make(map[int]struct { + result1 livekit.ParticipantIdentity + }) + } + fake.identityReturnsOnCall[i] = struct { + result1 livekit.ParticipantIdentity + }{result1} +} + +func (fake *FakeLocalParticipant) IsAgent() bool { + fake.isAgentMutex.Lock() + ret, specificReturn := fake.isAgentReturnsOnCall[len(fake.isAgentArgsForCall)] + fake.isAgentArgsForCall = append(fake.isAgentArgsForCall, struct { + }{}) + stub := fake.IsAgentStub + fakeReturns := fake.isAgentReturns + fake.recordInvocation("IsAgent", []interface{}{}) + fake.isAgentMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) IsAgentCallCount() int { + fake.isAgentMutex.RLock() + defer fake.isAgentMutex.RUnlock() + return len(fake.isAgentArgsForCall) +} + +func (fake *FakeLocalParticipant) IsAgentCalls(stub func() bool) { + fake.isAgentMutex.Lock() + defer fake.isAgentMutex.Unlock() + fake.IsAgentStub = stub +} + +func (fake *FakeLocalParticipant) IsAgentReturns(result1 bool) { + fake.isAgentMutex.Lock() + defer fake.isAgentMutex.Unlock() + fake.IsAgentStub = nil + fake.isAgentReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalParticipant) IsAgentReturnsOnCall(i int, result1 bool) { + fake.isAgentMutex.Lock() + defer fake.isAgentMutex.Unlock() + fake.IsAgentStub = nil + if fake.isAgentReturnsOnCall == nil { + fake.isAgentReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.isAgentReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalParticipant) IsClosed() bool { + fake.isClosedMutex.Lock() + ret, specificReturn := fake.isClosedReturnsOnCall[len(fake.isClosedArgsForCall)] + fake.isClosedArgsForCall = append(fake.isClosedArgsForCall, struct { + }{}) + stub := fake.IsClosedStub + fakeReturns := fake.isClosedReturns + fake.recordInvocation("IsClosed", []interface{}{}) + fake.isClosedMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) IsClosedCallCount() int { + fake.isClosedMutex.RLock() + defer fake.isClosedMutex.RUnlock() + return len(fake.isClosedArgsForCall) +} + +func (fake *FakeLocalParticipant) IsClosedCalls(stub func() bool) { + fake.isClosedMutex.Lock() + defer fake.isClosedMutex.Unlock() + fake.IsClosedStub = stub +} + +func (fake *FakeLocalParticipant) IsClosedReturns(result1 bool) { + fake.isClosedMutex.Lock() + defer fake.isClosedMutex.Unlock() + fake.IsClosedStub = nil + fake.isClosedReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalParticipant) IsClosedReturnsOnCall(i int, result1 bool) { + fake.isClosedMutex.Lock() + defer fake.isClosedMutex.Unlock() + fake.IsClosedStub = nil + if fake.isClosedReturnsOnCall == nil { + fake.isClosedReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.isClosedReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalParticipant) IsDependent() bool { + fake.isDependentMutex.Lock() + ret, specificReturn := fake.isDependentReturnsOnCall[len(fake.isDependentArgsForCall)] + fake.isDependentArgsForCall = append(fake.isDependentArgsForCall, struct { + }{}) + stub := fake.IsDependentStub + fakeReturns := fake.isDependentReturns + fake.recordInvocation("IsDependent", []interface{}{}) + fake.isDependentMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) IsDependentCallCount() int { + fake.isDependentMutex.RLock() + defer fake.isDependentMutex.RUnlock() + return len(fake.isDependentArgsForCall) +} + +func (fake *FakeLocalParticipant) IsDependentCalls(stub func() bool) { + fake.isDependentMutex.Lock() + defer fake.isDependentMutex.Unlock() + fake.IsDependentStub = stub +} + +func (fake *FakeLocalParticipant) IsDependentReturns(result1 bool) { + fake.isDependentMutex.Lock() + defer fake.isDependentMutex.Unlock() + fake.IsDependentStub = nil + fake.isDependentReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalParticipant) IsDependentReturnsOnCall(i int, result1 bool) { + fake.isDependentMutex.Lock() + defer fake.isDependentMutex.Unlock() + fake.IsDependentStub = nil + if fake.isDependentReturnsOnCall == nil { + fake.isDependentReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.isDependentReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalParticipant) IsDisconnected() bool { + fake.isDisconnectedMutex.Lock() + ret, specificReturn := fake.isDisconnectedReturnsOnCall[len(fake.isDisconnectedArgsForCall)] + fake.isDisconnectedArgsForCall = append(fake.isDisconnectedArgsForCall, struct { + }{}) + stub := fake.IsDisconnectedStub + fakeReturns := fake.isDisconnectedReturns + fake.recordInvocation("IsDisconnected", []interface{}{}) + fake.isDisconnectedMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) IsDisconnectedCallCount() int { + fake.isDisconnectedMutex.RLock() + defer fake.isDisconnectedMutex.RUnlock() + return len(fake.isDisconnectedArgsForCall) +} + +func (fake *FakeLocalParticipant) IsDisconnectedCalls(stub func() bool) { + fake.isDisconnectedMutex.Lock() + defer fake.isDisconnectedMutex.Unlock() + fake.IsDisconnectedStub = stub +} + +func (fake *FakeLocalParticipant) IsDisconnectedReturns(result1 bool) { + fake.isDisconnectedMutex.Lock() + defer fake.isDisconnectedMutex.Unlock() + fake.IsDisconnectedStub = nil + fake.isDisconnectedReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalParticipant) IsDisconnectedReturnsOnCall(i int, result1 bool) { + fake.isDisconnectedMutex.Lock() + defer fake.isDisconnectedMutex.Unlock() + fake.IsDisconnectedStub = nil + if fake.isDisconnectedReturnsOnCall == nil { + fake.isDisconnectedReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.isDisconnectedReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalParticipant) IsIdle() bool { + fake.isIdleMutex.Lock() + ret, specificReturn := fake.isIdleReturnsOnCall[len(fake.isIdleArgsForCall)] + fake.isIdleArgsForCall = append(fake.isIdleArgsForCall, struct { + }{}) + stub := fake.IsIdleStub + fakeReturns := fake.isIdleReturns + fake.recordInvocation("IsIdle", []interface{}{}) + fake.isIdleMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) IsIdleCallCount() int { + fake.isIdleMutex.RLock() + defer fake.isIdleMutex.RUnlock() + return len(fake.isIdleArgsForCall) +} + +func (fake *FakeLocalParticipant) IsIdleCalls(stub func() bool) { + fake.isIdleMutex.Lock() + defer fake.isIdleMutex.Unlock() + fake.IsIdleStub = stub +} + +func (fake *FakeLocalParticipant) IsIdleReturns(result1 bool) { + fake.isIdleMutex.Lock() + defer fake.isIdleMutex.Unlock() + fake.IsIdleStub = nil + fake.isIdleReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalParticipant) IsIdleReturnsOnCall(i int, result1 bool) { + fake.isIdleMutex.Lock() + defer fake.isIdleMutex.Unlock() + fake.IsIdleStub = nil + if fake.isIdleReturnsOnCall == nil { + fake.isIdleReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.isIdleReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalParticipant) IsPublisher() bool { + fake.isPublisherMutex.Lock() + ret, specificReturn := fake.isPublisherReturnsOnCall[len(fake.isPublisherArgsForCall)] + fake.isPublisherArgsForCall = append(fake.isPublisherArgsForCall, struct { + }{}) + stub := fake.IsPublisherStub + fakeReturns := fake.isPublisherReturns + fake.recordInvocation("IsPublisher", []interface{}{}) + fake.isPublisherMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) IsPublisherCallCount() int { + fake.isPublisherMutex.RLock() + defer fake.isPublisherMutex.RUnlock() + return len(fake.isPublisherArgsForCall) +} + +func (fake *FakeLocalParticipant) IsPublisherCalls(stub func() bool) { + fake.isPublisherMutex.Lock() + defer fake.isPublisherMutex.Unlock() + fake.IsPublisherStub = stub +} + +func (fake *FakeLocalParticipant) IsPublisherReturns(result1 bool) { + fake.isPublisherMutex.Lock() + defer fake.isPublisherMutex.Unlock() + fake.IsPublisherStub = nil + fake.isPublisherReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalParticipant) IsPublisherReturnsOnCall(i int, result1 bool) { + fake.isPublisherMutex.Lock() + defer fake.isPublisherMutex.Unlock() + fake.IsPublisherStub = nil + if fake.isPublisherReturnsOnCall == nil { + fake.isPublisherReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.isPublisherReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalParticipant) IsReady() bool { + fake.isReadyMutex.Lock() + ret, specificReturn := fake.isReadyReturnsOnCall[len(fake.isReadyArgsForCall)] + fake.isReadyArgsForCall = append(fake.isReadyArgsForCall, struct { + }{}) + stub := fake.IsReadyStub + fakeReturns := fake.isReadyReturns + fake.recordInvocation("IsReady", []interface{}{}) + fake.isReadyMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) IsReadyCallCount() int { + fake.isReadyMutex.RLock() + defer fake.isReadyMutex.RUnlock() + return len(fake.isReadyArgsForCall) +} + +func (fake *FakeLocalParticipant) IsReadyCalls(stub func() bool) { + fake.isReadyMutex.Lock() + defer fake.isReadyMutex.Unlock() + fake.IsReadyStub = stub +} + +func (fake *FakeLocalParticipant) IsReadyReturns(result1 bool) { + fake.isReadyMutex.Lock() + defer fake.isReadyMutex.Unlock() + fake.IsReadyStub = nil + fake.isReadyReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalParticipant) IsReadyReturnsOnCall(i int, result1 bool) { + fake.isReadyMutex.Lock() + defer fake.isReadyMutex.Unlock() + fake.IsReadyStub = nil + if fake.isReadyReturnsOnCall == nil { + fake.isReadyReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.isReadyReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalParticipant) IsReconnect() bool { + fake.isReconnectMutex.Lock() + ret, specificReturn := fake.isReconnectReturnsOnCall[len(fake.isReconnectArgsForCall)] + fake.isReconnectArgsForCall = append(fake.isReconnectArgsForCall, struct { + }{}) + stub := fake.IsReconnectStub + fakeReturns := fake.isReconnectReturns + fake.recordInvocation("IsReconnect", []interface{}{}) + fake.isReconnectMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) IsReconnectCallCount() int { + fake.isReconnectMutex.RLock() + defer fake.isReconnectMutex.RUnlock() + return len(fake.isReconnectArgsForCall) +} + +func (fake *FakeLocalParticipant) IsReconnectCalls(stub func() bool) { + fake.isReconnectMutex.Lock() + defer fake.isReconnectMutex.Unlock() + fake.IsReconnectStub = stub +} + +func (fake *FakeLocalParticipant) IsReconnectReturns(result1 bool) { + fake.isReconnectMutex.Lock() + defer fake.isReconnectMutex.Unlock() + fake.IsReconnectStub = nil + fake.isReconnectReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalParticipant) IsReconnectReturnsOnCall(i int, result1 bool) { + fake.isReconnectMutex.Lock() + defer fake.isReconnectMutex.Unlock() + fake.IsReconnectStub = nil + if fake.isReconnectReturnsOnCall == nil { + fake.isReconnectReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.isReconnectReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalParticipant) IsRecorder() bool { + fake.isRecorderMutex.Lock() + ret, specificReturn := fake.isRecorderReturnsOnCall[len(fake.isRecorderArgsForCall)] + fake.isRecorderArgsForCall = append(fake.isRecorderArgsForCall, struct { + }{}) + stub := fake.IsRecorderStub + fakeReturns := fake.isRecorderReturns + fake.recordInvocation("IsRecorder", []interface{}{}) + fake.isRecorderMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) IsRecorderCallCount() int { + fake.isRecorderMutex.RLock() + defer fake.isRecorderMutex.RUnlock() + return len(fake.isRecorderArgsForCall) +} + +func (fake *FakeLocalParticipant) IsRecorderCalls(stub func() bool) { + fake.isRecorderMutex.Lock() + defer fake.isRecorderMutex.Unlock() + fake.IsRecorderStub = stub +} + +func (fake *FakeLocalParticipant) IsRecorderReturns(result1 bool) { + fake.isRecorderMutex.Lock() + defer fake.isRecorderMutex.Unlock() + fake.IsRecorderStub = nil + fake.isRecorderReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalParticipant) IsRecorderReturnsOnCall(i int, result1 bool) { + fake.isRecorderMutex.Lock() + defer fake.isRecorderMutex.Unlock() + fake.IsRecorderStub = nil + if fake.isRecorderReturnsOnCall == nil { + fake.isRecorderReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.isRecorderReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalParticipant) IsSubscribedTo(arg1 livekit.ParticipantID) bool { + fake.isSubscribedToMutex.Lock() + ret, specificReturn := fake.isSubscribedToReturnsOnCall[len(fake.isSubscribedToArgsForCall)] + fake.isSubscribedToArgsForCall = append(fake.isSubscribedToArgsForCall, struct { + arg1 livekit.ParticipantID + }{arg1}) + stub := fake.IsSubscribedToStub + fakeReturns := fake.isSubscribedToReturns + fake.recordInvocation("IsSubscribedTo", []interface{}{arg1}) + fake.isSubscribedToMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) IsSubscribedToCallCount() int { + fake.isSubscribedToMutex.RLock() + defer fake.isSubscribedToMutex.RUnlock() + return len(fake.isSubscribedToArgsForCall) +} + +func (fake *FakeLocalParticipant) IsSubscribedToCalls(stub func(livekit.ParticipantID) bool) { + fake.isSubscribedToMutex.Lock() + defer fake.isSubscribedToMutex.Unlock() + fake.IsSubscribedToStub = stub +} + +func (fake *FakeLocalParticipant) IsSubscribedToArgsForCall(i int) livekit.ParticipantID { + fake.isSubscribedToMutex.RLock() + defer fake.isSubscribedToMutex.RUnlock() + argsForCall := fake.isSubscribedToArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalParticipant) IsSubscribedToReturns(result1 bool) { + fake.isSubscribedToMutex.Lock() + defer fake.isSubscribedToMutex.Unlock() + fake.IsSubscribedToStub = nil + fake.isSubscribedToReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalParticipant) IsSubscribedToReturnsOnCall(i int, result1 bool) { + fake.isSubscribedToMutex.Lock() + defer fake.isSubscribedToMutex.Unlock() + fake.IsSubscribedToStub = nil + if fake.isSubscribedToReturnsOnCall == nil { + fake.isSubscribedToReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.isSubscribedToReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalParticipant) IsTrackNameSubscribed(arg1 livekit.ParticipantIdentity, arg2 string) bool { + fake.isTrackNameSubscribedMutex.Lock() + ret, specificReturn := fake.isTrackNameSubscribedReturnsOnCall[len(fake.isTrackNameSubscribedArgsForCall)] + fake.isTrackNameSubscribedArgsForCall = append(fake.isTrackNameSubscribedArgsForCall, struct { + arg1 livekit.ParticipantIdentity + arg2 string + }{arg1, arg2}) + stub := fake.IsTrackNameSubscribedStub + fakeReturns := fake.isTrackNameSubscribedReturns + fake.recordInvocation("IsTrackNameSubscribed", []interface{}{arg1, arg2}) + fake.isTrackNameSubscribedMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) IsTrackNameSubscribedCallCount() int { + fake.isTrackNameSubscribedMutex.RLock() + defer fake.isTrackNameSubscribedMutex.RUnlock() + return len(fake.isTrackNameSubscribedArgsForCall) +} + +func (fake *FakeLocalParticipant) IsTrackNameSubscribedCalls(stub func(livekit.ParticipantIdentity, string) bool) { + fake.isTrackNameSubscribedMutex.Lock() + defer fake.isTrackNameSubscribedMutex.Unlock() + fake.IsTrackNameSubscribedStub = stub +} + +func (fake *FakeLocalParticipant) IsTrackNameSubscribedArgsForCall(i int) (livekit.ParticipantIdentity, string) { + fake.isTrackNameSubscribedMutex.RLock() + defer fake.isTrackNameSubscribedMutex.RUnlock() + argsForCall := fake.isTrackNameSubscribedArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeLocalParticipant) IsTrackNameSubscribedReturns(result1 bool) { + fake.isTrackNameSubscribedMutex.Lock() + defer fake.isTrackNameSubscribedMutex.Unlock() + fake.IsTrackNameSubscribedStub = nil + fake.isTrackNameSubscribedReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalParticipant) IsTrackNameSubscribedReturnsOnCall(i int, result1 bool) { + fake.isTrackNameSubscribedMutex.Lock() + defer fake.isTrackNameSubscribedMutex.Unlock() + fake.IsTrackNameSubscribedStub = nil + if fake.isTrackNameSubscribedReturnsOnCall == nil { + fake.isTrackNameSubscribedReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.isTrackNameSubscribedReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalParticipant) IsUsingSinglePeerConnection() bool { + fake.isUsingSinglePeerConnectionMutex.Lock() + ret, specificReturn := fake.isUsingSinglePeerConnectionReturnsOnCall[len(fake.isUsingSinglePeerConnectionArgsForCall)] + fake.isUsingSinglePeerConnectionArgsForCall = append(fake.isUsingSinglePeerConnectionArgsForCall, struct { + }{}) + stub := fake.IsUsingSinglePeerConnectionStub + fakeReturns := fake.isUsingSinglePeerConnectionReturns + fake.recordInvocation("IsUsingSinglePeerConnection", []interface{}{}) + fake.isUsingSinglePeerConnectionMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) IsUsingSinglePeerConnectionCallCount() int { + fake.isUsingSinglePeerConnectionMutex.RLock() + defer fake.isUsingSinglePeerConnectionMutex.RUnlock() + return len(fake.isUsingSinglePeerConnectionArgsForCall) +} + +func (fake *FakeLocalParticipant) IsUsingSinglePeerConnectionCalls(stub func() bool) { + fake.isUsingSinglePeerConnectionMutex.Lock() + defer fake.isUsingSinglePeerConnectionMutex.Unlock() + fake.IsUsingSinglePeerConnectionStub = stub +} + +func (fake *FakeLocalParticipant) IsUsingSinglePeerConnectionReturns(result1 bool) { + fake.isUsingSinglePeerConnectionMutex.Lock() + defer fake.isUsingSinglePeerConnectionMutex.Unlock() + fake.IsUsingSinglePeerConnectionStub = nil + fake.isUsingSinglePeerConnectionReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalParticipant) IsUsingSinglePeerConnectionReturnsOnCall(i int, result1 bool) { + fake.isUsingSinglePeerConnectionMutex.Lock() + defer fake.isUsingSinglePeerConnectionMutex.Unlock() + fake.IsUsingSinglePeerConnectionStub = nil + if fake.isUsingSinglePeerConnectionReturnsOnCall == nil { + fake.isUsingSinglePeerConnectionReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.isUsingSinglePeerConnectionReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalParticipant) IssueFullReconnect(arg1 types.ParticipantCloseReason) { + fake.issueFullReconnectMutex.Lock() + fake.issueFullReconnectArgsForCall = append(fake.issueFullReconnectArgsForCall, struct { + arg1 types.ParticipantCloseReason + }{arg1}) + stub := fake.IssueFullReconnectStub + fake.recordInvocation("IssueFullReconnect", []interface{}{arg1}) + fake.issueFullReconnectMutex.Unlock() + if stub != nil { + fake.IssueFullReconnectStub(arg1) + } +} + +func (fake *FakeLocalParticipant) IssueFullReconnectCallCount() int { + fake.issueFullReconnectMutex.RLock() + defer fake.issueFullReconnectMutex.RUnlock() + return len(fake.issueFullReconnectArgsForCall) +} + +func (fake *FakeLocalParticipant) IssueFullReconnectCalls(stub func(types.ParticipantCloseReason)) { + fake.issueFullReconnectMutex.Lock() + defer fake.issueFullReconnectMutex.Unlock() + fake.IssueFullReconnectStub = stub +} + +func (fake *FakeLocalParticipant) IssueFullReconnectArgsForCall(i int) types.ParticipantCloseReason { + fake.issueFullReconnectMutex.RLock() + defer fake.issueFullReconnectMutex.RUnlock() + argsForCall := fake.issueFullReconnectArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalParticipant) Kind() livekit.ParticipantInfo_Kind { + fake.kindMutex.Lock() + ret, specificReturn := fake.kindReturnsOnCall[len(fake.kindArgsForCall)] + fake.kindArgsForCall = append(fake.kindArgsForCall, struct { + }{}) + stub := fake.KindStub + fakeReturns := fake.kindReturns + fake.recordInvocation("Kind", []interface{}{}) + fake.kindMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) KindCallCount() int { + fake.kindMutex.RLock() + defer fake.kindMutex.RUnlock() + return len(fake.kindArgsForCall) +} + +func (fake *FakeLocalParticipant) KindCalls(stub func() livekit.ParticipantInfo_Kind) { + fake.kindMutex.Lock() + defer fake.kindMutex.Unlock() + fake.KindStub = stub +} + +func (fake *FakeLocalParticipant) KindReturns(result1 livekit.ParticipantInfo_Kind) { + fake.kindMutex.Lock() + defer fake.kindMutex.Unlock() + fake.KindStub = nil + fake.kindReturns = struct { + result1 livekit.ParticipantInfo_Kind + }{result1} +} + +func (fake *FakeLocalParticipant) KindReturnsOnCall(i int, result1 livekit.ParticipantInfo_Kind) { + fake.kindMutex.Lock() + defer fake.kindMutex.Unlock() + fake.KindStub = nil + if fake.kindReturnsOnCall == nil { + fake.kindReturnsOnCall = make(map[int]struct { + result1 livekit.ParticipantInfo_Kind + }) + } + fake.kindReturnsOnCall[i] = struct { + result1 livekit.ParticipantInfo_Kind + }{result1} +} + +func (fake *FakeLocalParticipant) MaybeStartMigration(arg1 bool, arg2 func()) bool { + fake.maybeStartMigrationMutex.Lock() + ret, specificReturn := fake.maybeStartMigrationReturnsOnCall[len(fake.maybeStartMigrationArgsForCall)] + fake.maybeStartMigrationArgsForCall = append(fake.maybeStartMigrationArgsForCall, struct { + arg1 bool + arg2 func() + }{arg1, arg2}) + stub := fake.MaybeStartMigrationStub + fakeReturns := fake.maybeStartMigrationReturns + fake.recordInvocation("MaybeStartMigration", []interface{}{arg1, arg2}) + fake.maybeStartMigrationMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) MaybeStartMigrationCallCount() int { + fake.maybeStartMigrationMutex.RLock() + defer fake.maybeStartMigrationMutex.RUnlock() + return len(fake.maybeStartMigrationArgsForCall) +} + +func (fake *FakeLocalParticipant) MaybeStartMigrationCalls(stub func(bool, func()) bool) { + fake.maybeStartMigrationMutex.Lock() + defer fake.maybeStartMigrationMutex.Unlock() + fake.MaybeStartMigrationStub = stub +} + +func (fake *FakeLocalParticipant) MaybeStartMigrationArgsForCall(i int) (bool, func()) { + fake.maybeStartMigrationMutex.RLock() + defer fake.maybeStartMigrationMutex.RUnlock() + argsForCall := fake.maybeStartMigrationArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeLocalParticipant) MaybeStartMigrationReturns(result1 bool) { + fake.maybeStartMigrationMutex.Lock() + defer fake.maybeStartMigrationMutex.Unlock() + fake.MaybeStartMigrationStub = nil + fake.maybeStartMigrationReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalParticipant) MaybeStartMigrationReturnsOnCall(i int, result1 bool) { + fake.maybeStartMigrationMutex.Lock() + defer fake.maybeStartMigrationMutex.Unlock() + fake.MaybeStartMigrationStub = nil + if fake.maybeStartMigrationReturnsOnCall == nil { + fake.maybeStartMigrationReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.maybeStartMigrationReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalParticipant) MigrateState() types.MigrateState { + fake.migrateStateMutex.Lock() + ret, specificReturn := fake.migrateStateReturnsOnCall[len(fake.migrateStateArgsForCall)] + fake.migrateStateArgsForCall = append(fake.migrateStateArgsForCall, struct { + }{}) + stub := fake.MigrateStateStub + fakeReturns := fake.migrateStateReturns + fake.recordInvocation("MigrateState", []interface{}{}) + fake.migrateStateMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) MigrateStateCallCount() int { + fake.migrateStateMutex.RLock() + defer fake.migrateStateMutex.RUnlock() + return len(fake.migrateStateArgsForCall) +} + +func (fake *FakeLocalParticipant) MigrateStateCalls(stub func() types.MigrateState) { + fake.migrateStateMutex.Lock() + defer fake.migrateStateMutex.Unlock() + fake.MigrateStateStub = stub +} + +func (fake *FakeLocalParticipant) MigrateStateReturns(result1 types.MigrateState) { + fake.migrateStateMutex.Lock() + defer fake.migrateStateMutex.Unlock() + fake.MigrateStateStub = nil + fake.migrateStateReturns = struct { + result1 types.MigrateState + }{result1} +} + +func (fake *FakeLocalParticipant) MigrateStateReturnsOnCall(i int, result1 types.MigrateState) { + fake.migrateStateMutex.Lock() + defer fake.migrateStateMutex.Unlock() + fake.MigrateStateStub = nil + if fake.migrateStateReturnsOnCall == nil { + fake.migrateStateReturnsOnCall = make(map[int]struct { + result1 types.MigrateState + }) + } + fake.migrateStateReturnsOnCall[i] = struct { + result1 types.MigrateState + }{result1} +} + +func (fake *FakeLocalParticipant) MoveToRoom(arg1 types.MoveToRoomParams) { + fake.moveToRoomMutex.Lock() + fake.moveToRoomArgsForCall = append(fake.moveToRoomArgsForCall, struct { + arg1 types.MoveToRoomParams + }{arg1}) + stub := fake.MoveToRoomStub + fake.recordInvocation("MoveToRoom", []interface{}{arg1}) + fake.moveToRoomMutex.Unlock() + if stub != nil { + fake.MoveToRoomStub(arg1) + } +} + +func (fake *FakeLocalParticipant) MoveToRoomCallCount() int { + fake.moveToRoomMutex.RLock() + defer fake.moveToRoomMutex.RUnlock() + return len(fake.moveToRoomArgsForCall) +} + +func (fake *FakeLocalParticipant) MoveToRoomCalls(stub func(types.MoveToRoomParams)) { + fake.moveToRoomMutex.Lock() + defer fake.moveToRoomMutex.Unlock() + fake.MoveToRoomStub = stub +} + +func (fake *FakeLocalParticipant) MoveToRoomArgsForCall(i int) types.MoveToRoomParams { + fake.moveToRoomMutex.RLock() + defer fake.moveToRoomMutex.RUnlock() + argsForCall := fake.moveToRoomArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalParticipant) Negotiate(arg1 bool) { + fake.negotiateMutex.Lock() + fake.negotiateArgsForCall = append(fake.negotiateArgsForCall, struct { + arg1 bool + }{arg1}) + stub := fake.NegotiateStub + fake.recordInvocation("Negotiate", []interface{}{arg1}) + fake.negotiateMutex.Unlock() + if stub != nil { + fake.NegotiateStub(arg1) + } +} + +func (fake *FakeLocalParticipant) NegotiateCallCount() int { + fake.negotiateMutex.RLock() + defer fake.negotiateMutex.RUnlock() + return len(fake.negotiateArgsForCall) +} + +func (fake *FakeLocalParticipant) NegotiateCalls(stub func(bool)) { + fake.negotiateMutex.Lock() + defer fake.negotiateMutex.Unlock() + fake.NegotiateStub = stub +} + +func (fake *FakeLocalParticipant) NegotiateArgsForCall(i int) bool { + fake.negotiateMutex.RLock() + defer fake.negotiateMutex.RUnlock() + argsForCall := fake.negotiateArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalParticipant) NotifyMigration() { + fake.notifyMigrationMutex.Lock() + fake.notifyMigrationArgsForCall = append(fake.notifyMigrationArgsForCall, struct { + }{}) + stub := fake.NotifyMigrationStub + fake.recordInvocation("NotifyMigration", []interface{}{}) + fake.notifyMigrationMutex.Unlock() + if stub != nil { + fake.NotifyMigrationStub() + } +} + +func (fake *FakeLocalParticipant) NotifyMigrationCallCount() int { + fake.notifyMigrationMutex.RLock() + defer fake.notifyMigrationMutex.RUnlock() + return len(fake.notifyMigrationArgsForCall) +} + +func (fake *FakeLocalParticipant) NotifyMigrationCalls(stub func()) { + fake.notifyMigrationMutex.Lock() + defer fake.notifyMigrationMutex.Unlock() + fake.NotifyMigrationStub = stub +} + +func (fake *FakeLocalParticipant) OnClaimsChanged(arg1 func(types.LocalParticipant)) { + fake.onClaimsChangedMutex.Lock() + fake.onClaimsChangedArgsForCall = append(fake.onClaimsChangedArgsForCall, struct { + arg1 func(types.LocalParticipant) + }{arg1}) + stub := fake.OnClaimsChangedStub + fake.recordInvocation("OnClaimsChanged", []interface{}{arg1}) + fake.onClaimsChangedMutex.Unlock() + if stub != nil { + fake.OnClaimsChangedStub(arg1) + } +} + +func (fake *FakeLocalParticipant) OnClaimsChangedCallCount() int { + fake.onClaimsChangedMutex.RLock() + defer fake.onClaimsChangedMutex.RUnlock() + return len(fake.onClaimsChangedArgsForCall) +} + +func (fake *FakeLocalParticipant) OnClaimsChangedCalls(stub func(func(types.LocalParticipant))) { + fake.onClaimsChangedMutex.Lock() + defer fake.onClaimsChangedMutex.Unlock() + fake.OnClaimsChangedStub = stub +} + +func (fake *FakeLocalParticipant) OnClaimsChangedArgsForCall(i int) func(types.LocalParticipant) { + fake.onClaimsChangedMutex.RLock() + defer fake.onClaimsChangedMutex.RUnlock() + argsForCall := fake.onClaimsChangedArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalParticipant) OnICEConfigChanged(arg1 func(participant types.LocalParticipant, iceConfig *livekit.ICEConfig)) { + fake.onICEConfigChangedMutex.Lock() + fake.onICEConfigChangedArgsForCall = append(fake.onICEConfigChangedArgsForCall, struct { + arg1 func(participant types.LocalParticipant, iceConfig *livekit.ICEConfig) + }{arg1}) + stub := fake.OnICEConfigChangedStub + fake.recordInvocation("OnICEConfigChanged", []interface{}{arg1}) + fake.onICEConfigChangedMutex.Unlock() + if stub != nil { + fake.OnICEConfigChangedStub(arg1) + } +} + +func (fake *FakeLocalParticipant) OnICEConfigChangedCallCount() int { + fake.onICEConfigChangedMutex.RLock() + defer fake.onICEConfigChangedMutex.RUnlock() + return len(fake.onICEConfigChangedArgsForCall) +} + +func (fake *FakeLocalParticipant) OnICEConfigChangedCalls(stub func(func(participant types.LocalParticipant, iceConfig *livekit.ICEConfig))) { + fake.onICEConfigChangedMutex.Lock() + defer fake.onICEConfigChangedMutex.Unlock() + fake.OnICEConfigChangedStub = stub +} + +func (fake *FakeLocalParticipant) OnICEConfigChangedArgsForCall(i int) func(participant types.LocalParticipant, iceConfig *livekit.ICEConfig) { + fake.onICEConfigChangedMutex.RLock() + defer fake.onICEConfigChangedMutex.RUnlock() + argsForCall := fake.onICEConfigChangedArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalParticipant) PerformRpc(arg1 *livekit.PerformRpcRequest, arg2 chan string, arg3 chan error) { + fake.performRpcMutex.Lock() + fake.performRpcArgsForCall = append(fake.performRpcArgsForCall, struct { + arg1 *livekit.PerformRpcRequest + arg2 chan string + arg3 chan error + }{arg1, arg2, arg3}) + stub := fake.PerformRpcStub + fake.recordInvocation("PerformRpc", []interface{}{arg1, arg2, arg3}) + fake.performRpcMutex.Unlock() + if stub != nil { + fake.PerformRpcStub(arg1, arg2, arg3) + } +} + +func (fake *FakeLocalParticipant) PerformRpcCallCount() int { + fake.performRpcMutex.RLock() + defer fake.performRpcMutex.RUnlock() + return len(fake.performRpcArgsForCall) +} + +func (fake *FakeLocalParticipant) PerformRpcCalls(stub func(*livekit.PerformRpcRequest, chan string, chan error)) { + fake.performRpcMutex.Lock() + defer fake.performRpcMutex.Unlock() + fake.PerformRpcStub = stub +} + +func (fake *FakeLocalParticipant) PerformRpcArgsForCall(i int) (*livekit.PerformRpcRequest, chan string, chan error) { + fake.performRpcMutex.RLock() + defer fake.performRpcMutex.RUnlock() + argsForCall := fake.performRpcArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeLocalParticipant) ProtocolVersion() types.ProtocolVersion { + fake.protocolVersionMutex.Lock() + ret, specificReturn := fake.protocolVersionReturnsOnCall[len(fake.protocolVersionArgsForCall)] + fake.protocolVersionArgsForCall = append(fake.protocolVersionArgsForCall, struct { + }{}) + stub := fake.ProtocolVersionStub + fakeReturns := fake.protocolVersionReturns + fake.recordInvocation("ProtocolVersion", []interface{}{}) + fake.protocolVersionMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) ProtocolVersionCallCount() int { + fake.protocolVersionMutex.RLock() + defer fake.protocolVersionMutex.RUnlock() + return len(fake.protocolVersionArgsForCall) +} + +func (fake *FakeLocalParticipant) ProtocolVersionCalls(stub func() types.ProtocolVersion) { + fake.protocolVersionMutex.Lock() + defer fake.protocolVersionMutex.Unlock() + fake.ProtocolVersionStub = stub +} + +func (fake *FakeLocalParticipant) ProtocolVersionReturns(result1 types.ProtocolVersion) { + fake.protocolVersionMutex.Lock() + defer fake.protocolVersionMutex.Unlock() + fake.ProtocolVersionStub = nil + fake.protocolVersionReturns = struct { + result1 types.ProtocolVersion + }{result1} +} + +func (fake *FakeLocalParticipant) ProtocolVersionReturnsOnCall(i int, result1 types.ProtocolVersion) { + fake.protocolVersionMutex.Lock() + defer fake.protocolVersionMutex.Unlock() + fake.ProtocolVersionStub = nil + if fake.protocolVersionReturnsOnCall == nil { + fake.protocolVersionReturnsOnCall = make(map[int]struct { + result1 types.ProtocolVersion + }) + } + fake.protocolVersionReturnsOnCall[i] = struct { + result1 types.ProtocolVersion + }{result1} +} + +func (fake *FakeLocalParticipant) RemovePublishedDataTrack(arg1 types.DataTrack) { + fake.removePublishedDataTrackMutex.Lock() + fake.removePublishedDataTrackArgsForCall = append(fake.removePublishedDataTrackArgsForCall, struct { + arg1 types.DataTrack + }{arg1}) + stub := fake.RemovePublishedDataTrackStub + fake.recordInvocation("RemovePublishedDataTrack", []interface{}{arg1}) + fake.removePublishedDataTrackMutex.Unlock() + if stub != nil { + fake.RemovePublishedDataTrackStub(arg1) + } +} + +func (fake *FakeLocalParticipant) RemovePublishedDataTrackCallCount() int { + fake.removePublishedDataTrackMutex.RLock() + defer fake.removePublishedDataTrackMutex.RUnlock() + return len(fake.removePublishedDataTrackArgsForCall) +} + +func (fake *FakeLocalParticipant) RemovePublishedDataTrackCalls(stub func(types.DataTrack)) { + fake.removePublishedDataTrackMutex.Lock() + defer fake.removePublishedDataTrackMutex.Unlock() + fake.RemovePublishedDataTrackStub = stub +} + +func (fake *FakeLocalParticipant) RemovePublishedDataTrackArgsForCall(i int) types.DataTrack { + fake.removePublishedDataTrackMutex.RLock() + defer fake.removePublishedDataTrackMutex.RUnlock() + argsForCall := fake.removePublishedDataTrackArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalParticipant) RemovePublishedTrack(arg1 types.MediaTrack, arg2 bool) { + fake.removePublishedTrackMutex.Lock() + fake.removePublishedTrackArgsForCall = append(fake.removePublishedTrackArgsForCall, struct { + arg1 types.MediaTrack + arg2 bool + }{arg1, arg2}) + stub := fake.RemovePublishedTrackStub + fake.recordInvocation("RemovePublishedTrack", []interface{}{arg1, arg2}) + fake.removePublishedTrackMutex.Unlock() + if stub != nil { + fake.RemovePublishedTrackStub(arg1, arg2) + } +} + +func (fake *FakeLocalParticipant) RemovePublishedTrackCallCount() int { + fake.removePublishedTrackMutex.RLock() + defer fake.removePublishedTrackMutex.RUnlock() + return len(fake.removePublishedTrackArgsForCall) +} + +func (fake *FakeLocalParticipant) RemovePublishedTrackCalls(stub func(types.MediaTrack, bool)) { + fake.removePublishedTrackMutex.Lock() + defer fake.removePublishedTrackMutex.Unlock() + fake.RemovePublishedTrackStub = stub +} + +func (fake *FakeLocalParticipant) RemovePublishedTrackArgsForCall(i int) (types.MediaTrack, bool) { + fake.removePublishedTrackMutex.RLock() + defer fake.removePublishedTrackMutex.RUnlock() + argsForCall := fake.removePublishedTrackArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeLocalParticipant) RemoveTrackLocal(arg1 *webrtc.RTPSender) error { + fake.removeTrackLocalMutex.Lock() + ret, specificReturn := fake.removeTrackLocalReturnsOnCall[len(fake.removeTrackLocalArgsForCall)] + fake.removeTrackLocalArgsForCall = append(fake.removeTrackLocalArgsForCall, struct { + arg1 *webrtc.RTPSender + }{arg1}) + stub := fake.RemoveTrackLocalStub + fakeReturns := fake.removeTrackLocalReturns + fake.recordInvocation("RemoveTrackLocal", []interface{}{arg1}) + fake.removeTrackLocalMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) RemoveTrackLocalCallCount() int { + fake.removeTrackLocalMutex.RLock() + defer fake.removeTrackLocalMutex.RUnlock() + return len(fake.removeTrackLocalArgsForCall) +} + +func (fake *FakeLocalParticipant) RemoveTrackLocalCalls(stub func(*webrtc.RTPSender) error) { + fake.removeTrackLocalMutex.Lock() + defer fake.removeTrackLocalMutex.Unlock() + fake.RemoveTrackLocalStub = stub +} + +func (fake *FakeLocalParticipant) RemoveTrackLocalArgsForCall(i int) *webrtc.RTPSender { + fake.removeTrackLocalMutex.RLock() + defer fake.removeTrackLocalMutex.RUnlock() + argsForCall := fake.removeTrackLocalArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalParticipant) RemoveTrackLocalReturns(result1 error) { + fake.removeTrackLocalMutex.Lock() + defer fake.removeTrackLocalMutex.Unlock() + fake.RemoveTrackLocalStub = nil + fake.removeTrackLocalReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipant) RemoveTrackLocalReturnsOnCall(i int, result1 error) { + fake.removeTrackLocalMutex.Lock() + defer fake.removeTrackLocalMutex.Unlock() + fake.RemoveTrackLocalStub = nil + if fake.removeTrackLocalReturnsOnCall == nil { + fake.removeTrackLocalReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.removeTrackLocalReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipant) SendConnectionQualityUpdate(arg1 *livekit.ConnectionQualityUpdate) error { + fake.sendConnectionQualityUpdateMutex.Lock() + ret, specificReturn := fake.sendConnectionQualityUpdateReturnsOnCall[len(fake.sendConnectionQualityUpdateArgsForCall)] + fake.sendConnectionQualityUpdateArgsForCall = append(fake.sendConnectionQualityUpdateArgsForCall, struct { + arg1 *livekit.ConnectionQualityUpdate + }{arg1}) + stub := fake.SendConnectionQualityUpdateStub + fakeReturns := fake.sendConnectionQualityUpdateReturns + fake.recordInvocation("SendConnectionQualityUpdate", []interface{}{arg1}) + fake.sendConnectionQualityUpdateMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) SendConnectionQualityUpdateCallCount() int { + fake.sendConnectionQualityUpdateMutex.RLock() + defer fake.sendConnectionQualityUpdateMutex.RUnlock() + return len(fake.sendConnectionQualityUpdateArgsForCall) +} + +func (fake *FakeLocalParticipant) SendConnectionQualityUpdateCalls(stub func(*livekit.ConnectionQualityUpdate) error) { + fake.sendConnectionQualityUpdateMutex.Lock() + defer fake.sendConnectionQualityUpdateMutex.Unlock() + fake.SendConnectionQualityUpdateStub = stub +} + +func (fake *FakeLocalParticipant) SendConnectionQualityUpdateArgsForCall(i int) *livekit.ConnectionQualityUpdate { + fake.sendConnectionQualityUpdateMutex.RLock() + defer fake.sendConnectionQualityUpdateMutex.RUnlock() + argsForCall := fake.sendConnectionQualityUpdateArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalParticipant) SendConnectionQualityUpdateReturns(result1 error) { + fake.sendConnectionQualityUpdateMutex.Lock() + defer fake.sendConnectionQualityUpdateMutex.Unlock() + fake.SendConnectionQualityUpdateStub = nil + fake.sendConnectionQualityUpdateReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipant) SendConnectionQualityUpdateReturnsOnCall(i int, result1 error) { + fake.sendConnectionQualityUpdateMutex.Lock() + defer fake.sendConnectionQualityUpdateMutex.Unlock() + fake.SendConnectionQualityUpdateStub = nil + if fake.sendConnectionQualityUpdateReturnsOnCall == nil { + fake.sendConnectionQualityUpdateReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.sendConnectionQualityUpdateReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipant) SendDataMessage(arg1 livekit.DataPacket_Kind, arg2 []byte, arg3 livekit.ParticipantID, arg4 uint32) error { + var arg2Copy []byte + if arg2 != nil { + arg2Copy = make([]byte, len(arg2)) + copy(arg2Copy, arg2) + } + fake.sendDataMessageMutex.Lock() + ret, specificReturn := fake.sendDataMessageReturnsOnCall[len(fake.sendDataMessageArgsForCall)] + fake.sendDataMessageArgsForCall = append(fake.sendDataMessageArgsForCall, struct { + arg1 livekit.DataPacket_Kind + arg2 []byte + arg3 livekit.ParticipantID + arg4 uint32 + }{arg1, arg2Copy, arg3, arg4}) + stub := fake.SendDataMessageStub + fakeReturns := fake.sendDataMessageReturns + fake.recordInvocation("SendDataMessage", []interface{}{arg1, arg2Copy, arg3, arg4}) + fake.sendDataMessageMutex.Unlock() + if stub != nil { + return stub(arg1, arg2, arg3, arg4) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) SendDataMessageCallCount() int { + fake.sendDataMessageMutex.RLock() + defer fake.sendDataMessageMutex.RUnlock() + return len(fake.sendDataMessageArgsForCall) +} + +func (fake *FakeLocalParticipant) SendDataMessageCalls(stub func(livekit.DataPacket_Kind, []byte, livekit.ParticipantID, uint32) error) { + fake.sendDataMessageMutex.Lock() + defer fake.sendDataMessageMutex.Unlock() + fake.SendDataMessageStub = stub +} + +func (fake *FakeLocalParticipant) SendDataMessageArgsForCall(i int) (livekit.DataPacket_Kind, []byte, livekit.ParticipantID, uint32) { + fake.sendDataMessageMutex.RLock() + defer fake.sendDataMessageMutex.RUnlock() + argsForCall := fake.sendDataMessageArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3, argsForCall.arg4 +} + +func (fake *FakeLocalParticipant) SendDataMessageReturns(result1 error) { + fake.sendDataMessageMutex.Lock() + defer fake.sendDataMessageMutex.Unlock() + fake.SendDataMessageStub = nil + fake.sendDataMessageReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipant) SendDataMessageReturnsOnCall(i int, result1 error) { + fake.sendDataMessageMutex.Lock() + defer fake.sendDataMessageMutex.Unlock() + fake.SendDataMessageStub = nil + if fake.sendDataMessageReturnsOnCall == nil { + fake.sendDataMessageReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.sendDataMessageReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipant) SendDataMessageUnlabeled(arg1 []byte, arg2 bool, arg3 livekit.ParticipantIdentity) error { + var arg1Copy []byte + if arg1 != nil { + arg1Copy = make([]byte, len(arg1)) + copy(arg1Copy, arg1) + } + fake.sendDataMessageUnlabeledMutex.Lock() + ret, specificReturn := fake.sendDataMessageUnlabeledReturnsOnCall[len(fake.sendDataMessageUnlabeledArgsForCall)] + fake.sendDataMessageUnlabeledArgsForCall = append(fake.sendDataMessageUnlabeledArgsForCall, struct { + arg1 []byte + arg2 bool + arg3 livekit.ParticipantIdentity + }{arg1Copy, arg2, arg3}) + stub := fake.SendDataMessageUnlabeledStub + fakeReturns := fake.sendDataMessageUnlabeledReturns + fake.recordInvocation("SendDataMessageUnlabeled", []interface{}{arg1Copy, arg2, arg3}) + fake.sendDataMessageUnlabeledMutex.Unlock() + if stub != nil { + return stub(arg1, arg2, arg3) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) SendDataMessageUnlabeledCallCount() int { + fake.sendDataMessageUnlabeledMutex.RLock() + defer fake.sendDataMessageUnlabeledMutex.RUnlock() + return len(fake.sendDataMessageUnlabeledArgsForCall) +} + +func (fake *FakeLocalParticipant) SendDataMessageUnlabeledCalls(stub func([]byte, bool, livekit.ParticipantIdentity) error) { + fake.sendDataMessageUnlabeledMutex.Lock() + defer fake.sendDataMessageUnlabeledMutex.Unlock() + fake.SendDataMessageUnlabeledStub = stub +} + +func (fake *FakeLocalParticipant) SendDataMessageUnlabeledArgsForCall(i int) ([]byte, bool, livekit.ParticipantIdentity) { + fake.sendDataMessageUnlabeledMutex.RLock() + defer fake.sendDataMessageUnlabeledMutex.RUnlock() + argsForCall := fake.sendDataMessageUnlabeledArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeLocalParticipant) SendDataMessageUnlabeledReturns(result1 error) { + fake.sendDataMessageUnlabeledMutex.Lock() + defer fake.sendDataMessageUnlabeledMutex.Unlock() + fake.SendDataMessageUnlabeledStub = nil + fake.sendDataMessageUnlabeledReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipant) SendDataMessageUnlabeledReturnsOnCall(i int, result1 error) { + fake.sendDataMessageUnlabeledMutex.Lock() + defer fake.sendDataMessageUnlabeledMutex.Unlock() + fake.SendDataMessageUnlabeledStub = nil + if fake.sendDataMessageUnlabeledReturnsOnCall == nil { + fake.sendDataMessageUnlabeledReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.sendDataMessageUnlabeledReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipant) SendDataTrackSubscriberHandles(arg1 map[uint32]*livekit.DataTrackSubscriberHandles_PublishedDataTrack) error { + fake.sendDataTrackSubscriberHandlesMutex.Lock() + ret, specificReturn := fake.sendDataTrackSubscriberHandlesReturnsOnCall[len(fake.sendDataTrackSubscriberHandlesArgsForCall)] + fake.sendDataTrackSubscriberHandlesArgsForCall = append(fake.sendDataTrackSubscriberHandlesArgsForCall, struct { + arg1 map[uint32]*livekit.DataTrackSubscriberHandles_PublishedDataTrack + }{arg1}) + stub := fake.SendDataTrackSubscriberHandlesStub + fakeReturns := fake.sendDataTrackSubscriberHandlesReturns + fake.recordInvocation("SendDataTrackSubscriberHandles", []interface{}{arg1}) + fake.sendDataTrackSubscriberHandlesMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) SendDataTrackSubscriberHandlesCallCount() int { + fake.sendDataTrackSubscriberHandlesMutex.RLock() + defer fake.sendDataTrackSubscriberHandlesMutex.RUnlock() + return len(fake.sendDataTrackSubscriberHandlesArgsForCall) +} + +func (fake *FakeLocalParticipant) SendDataTrackSubscriberHandlesCalls(stub func(map[uint32]*livekit.DataTrackSubscriberHandles_PublishedDataTrack) error) { + fake.sendDataTrackSubscriberHandlesMutex.Lock() + defer fake.sendDataTrackSubscriberHandlesMutex.Unlock() + fake.SendDataTrackSubscriberHandlesStub = stub +} + +func (fake *FakeLocalParticipant) SendDataTrackSubscriberHandlesArgsForCall(i int) map[uint32]*livekit.DataTrackSubscriberHandles_PublishedDataTrack { + fake.sendDataTrackSubscriberHandlesMutex.RLock() + defer fake.sendDataTrackSubscriberHandlesMutex.RUnlock() + argsForCall := fake.sendDataTrackSubscriberHandlesArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalParticipant) SendDataTrackSubscriberHandlesReturns(result1 error) { + fake.sendDataTrackSubscriberHandlesMutex.Lock() + defer fake.sendDataTrackSubscriberHandlesMutex.Unlock() + fake.SendDataTrackSubscriberHandlesStub = nil + fake.sendDataTrackSubscriberHandlesReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipant) SendDataTrackSubscriberHandlesReturnsOnCall(i int, result1 error) { + fake.sendDataTrackSubscriberHandlesMutex.Lock() + defer fake.sendDataTrackSubscriberHandlesMutex.Unlock() + fake.SendDataTrackSubscriberHandlesStub = nil + if fake.sendDataTrackSubscriberHandlesReturnsOnCall == nil { + fake.sendDataTrackSubscriberHandlesReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.sendDataTrackSubscriberHandlesReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipant) SendJoinResponse(arg1 *livekit.JoinResponse) error { + fake.sendJoinResponseMutex.Lock() + ret, specificReturn := fake.sendJoinResponseReturnsOnCall[len(fake.sendJoinResponseArgsForCall)] + fake.sendJoinResponseArgsForCall = append(fake.sendJoinResponseArgsForCall, struct { + arg1 *livekit.JoinResponse + }{arg1}) + stub := fake.SendJoinResponseStub + fakeReturns := fake.sendJoinResponseReturns + fake.recordInvocation("SendJoinResponse", []interface{}{arg1}) + fake.sendJoinResponseMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) SendJoinResponseCallCount() int { + fake.sendJoinResponseMutex.RLock() + defer fake.sendJoinResponseMutex.RUnlock() + return len(fake.sendJoinResponseArgsForCall) +} + +func (fake *FakeLocalParticipant) SendJoinResponseCalls(stub func(*livekit.JoinResponse) error) { + fake.sendJoinResponseMutex.Lock() + defer fake.sendJoinResponseMutex.Unlock() + fake.SendJoinResponseStub = stub +} + +func (fake *FakeLocalParticipant) SendJoinResponseArgsForCall(i int) *livekit.JoinResponse { + fake.sendJoinResponseMutex.RLock() + defer fake.sendJoinResponseMutex.RUnlock() + argsForCall := fake.sendJoinResponseArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalParticipant) SendJoinResponseReturns(result1 error) { + fake.sendJoinResponseMutex.Lock() + defer fake.sendJoinResponseMutex.Unlock() + fake.SendJoinResponseStub = nil + fake.sendJoinResponseReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipant) SendJoinResponseReturnsOnCall(i int, result1 error) { + fake.sendJoinResponseMutex.Lock() + defer fake.sendJoinResponseMutex.Unlock() + fake.SendJoinResponseStub = nil + if fake.sendJoinResponseReturnsOnCall == nil { + fake.sendJoinResponseReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.sendJoinResponseReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipant) SendParticipantUpdate(arg1 []*livekit.ParticipantInfo) error { + var arg1Copy []*livekit.ParticipantInfo + if arg1 != nil { + arg1Copy = make([]*livekit.ParticipantInfo, len(arg1)) + copy(arg1Copy, arg1) + } + fake.sendParticipantUpdateMutex.Lock() + ret, specificReturn := fake.sendParticipantUpdateReturnsOnCall[len(fake.sendParticipantUpdateArgsForCall)] + fake.sendParticipantUpdateArgsForCall = append(fake.sendParticipantUpdateArgsForCall, struct { + arg1 []*livekit.ParticipantInfo + }{arg1Copy}) + stub := fake.SendParticipantUpdateStub + fakeReturns := fake.sendParticipantUpdateReturns + fake.recordInvocation("SendParticipantUpdate", []interface{}{arg1Copy}) + fake.sendParticipantUpdateMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) SendParticipantUpdateCallCount() int { + fake.sendParticipantUpdateMutex.RLock() + defer fake.sendParticipantUpdateMutex.RUnlock() + return len(fake.sendParticipantUpdateArgsForCall) +} + +func (fake *FakeLocalParticipant) SendParticipantUpdateCalls(stub func([]*livekit.ParticipantInfo) error) { + fake.sendParticipantUpdateMutex.Lock() + defer fake.sendParticipantUpdateMutex.Unlock() + fake.SendParticipantUpdateStub = stub +} + +func (fake *FakeLocalParticipant) SendParticipantUpdateArgsForCall(i int) []*livekit.ParticipantInfo { + fake.sendParticipantUpdateMutex.RLock() + defer fake.sendParticipantUpdateMutex.RUnlock() + argsForCall := fake.sendParticipantUpdateArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalParticipant) SendParticipantUpdateReturns(result1 error) { + fake.sendParticipantUpdateMutex.Lock() + defer fake.sendParticipantUpdateMutex.Unlock() + fake.SendParticipantUpdateStub = nil + fake.sendParticipantUpdateReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipant) SendParticipantUpdateReturnsOnCall(i int, result1 error) { + fake.sendParticipantUpdateMutex.Lock() + defer fake.sendParticipantUpdateMutex.Unlock() + fake.SendParticipantUpdateStub = nil + if fake.sendParticipantUpdateReturnsOnCall == nil { + fake.sendParticipantUpdateReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.sendParticipantUpdateReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipant) SendRefreshToken(arg1 string) error { + fake.sendRefreshTokenMutex.Lock() + ret, specificReturn := fake.sendRefreshTokenReturnsOnCall[len(fake.sendRefreshTokenArgsForCall)] + fake.sendRefreshTokenArgsForCall = append(fake.sendRefreshTokenArgsForCall, struct { + arg1 string + }{arg1}) + stub := fake.SendRefreshTokenStub + fakeReturns := fake.sendRefreshTokenReturns + fake.recordInvocation("SendRefreshToken", []interface{}{arg1}) + fake.sendRefreshTokenMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) SendRefreshTokenCallCount() int { + fake.sendRefreshTokenMutex.RLock() + defer fake.sendRefreshTokenMutex.RUnlock() + return len(fake.sendRefreshTokenArgsForCall) +} + +func (fake *FakeLocalParticipant) SendRefreshTokenCalls(stub func(string) error) { + fake.sendRefreshTokenMutex.Lock() + defer fake.sendRefreshTokenMutex.Unlock() + fake.SendRefreshTokenStub = stub +} + +func (fake *FakeLocalParticipant) SendRefreshTokenArgsForCall(i int) string { + fake.sendRefreshTokenMutex.RLock() + defer fake.sendRefreshTokenMutex.RUnlock() + argsForCall := fake.sendRefreshTokenArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalParticipant) SendRefreshTokenReturns(result1 error) { + fake.sendRefreshTokenMutex.Lock() + defer fake.sendRefreshTokenMutex.Unlock() + fake.SendRefreshTokenStub = nil + fake.sendRefreshTokenReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipant) SendRefreshTokenReturnsOnCall(i int, result1 error) { + fake.sendRefreshTokenMutex.Lock() + defer fake.sendRefreshTokenMutex.Unlock() + fake.SendRefreshTokenStub = nil + if fake.sendRefreshTokenReturnsOnCall == nil { + fake.sendRefreshTokenReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.sendRefreshTokenReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipant) SendRoomMovedResponse(arg1 *livekit.RoomMovedResponse) error { + fake.sendRoomMovedResponseMutex.Lock() + ret, specificReturn := fake.sendRoomMovedResponseReturnsOnCall[len(fake.sendRoomMovedResponseArgsForCall)] + fake.sendRoomMovedResponseArgsForCall = append(fake.sendRoomMovedResponseArgsForCall, struct { + arg1 *livekit.RoomMovedResponse + }{arg1}) + stub := fake.SendRoomMovedResponseStub + fakeReturns := fake.sendRoomMovedResponseReturns + fake.recordInvocation("SendRoomMovedResponse", []interface{}{arg1}) + fake.sendRoomMovedResponseMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) SendRoomMovedResponseCallCount() int { + fake.sendRoomMovedResponseMutex.RLock() + defer fake.sendRoomMovedResponseMutex.RUnlock() + return len(fake.sendRoomMovedResponseArgsForCall) +} + +func (fake *FakeLocalParticipant) SendRoomMovedResponseCalls(stub func(*livekit.RoomMovedResponse) error) { + fake.sendRoomMovedResponseMutex.Lock() + defer fake.sendRoomMovedResponseMutex.Unlock() + fake.SendRoomMovedResponseStub = stub +} + +func (fake *FakeLocalParticipant) SendRoomMovedResponseArgsForCall(i int) *livekit.RoomMovedResponse { + fake.sendRoomMovedResponseMutex.RLock() + defer fake.sendRoomMovedResponseMutex.RUnlock() + argsForCall := fake.sendRoomMovedResponseArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalParticipant) SendRoomMovedResponseReturns(result1 error) { + fake.sendRoomMovedResponseMutex.Lock() + defer fake.sendRoomMovedResponseMutex.Unlock() + fake.SendRoomMovedResponseStub = nil + fake.sendRoomMovedResponseReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipant) SendRoomMovedResponseReturnsOnCall(i int, result1 error) { + fake.sendRoomMovedResponseMutex.Lock() + defer fake.sendRoomMovedResponseMutex.Unlock() + fake.SendRoomMovedResponseStub = nil + if fake.sendRoomMovedResponseReturnsOnCall == nil { + fake.sendRoomMovedResponseReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.sendRoomMovedResponseReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipant) SendRoomUpdate(arg1 *livekit.Room) error { + fake.sendRoomUpdateMutex.Lock() + ret, specificReturn := fake.sendRoomUpdateReturnsOnCall[len(fake.sendRoomUpdateArgsForCall)] + fake.sendRoomUpdateArgsForCall = append(fake.sendRoomUpdateArgsForCall, struct { + arg1 *livekit.Room + }{arg1}) + stub := fake.SendRoomUpdateStub + fakeReturns := fake.sendRoomUpdateReturns + fake.recordInvocation("SendRoomUpdate", []interface{}{arg1}) + fake.sendRoomUpdateMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) SendRoomUpdateCallCount() int { + fake.sendRoomUpdateMutex.RLock() + defer fake.sendRoomUpdateMutex.RUnlock() + return len(fake.sendRoomUpdateArgsForCall) +} + +func (fake *FakeLocalParticipant) SendRoomUpdateCalls(stub func(*livekit.Room) error) { + fake.sendRoomUpdateMutex.Lock() + defer fake.sendRoomUpdateMutex.Unlock() + fake.SendRoomUpdateStub = stub +} + +func (fake *FakeLocalParticipant) SendRoomUpdateArgsForCall(i int) *livekit.Room { + fake.sendRoomUpdateMutex.RLock() + defer fake.sendRoomUpdateMutex.RUnlock() + argsForCall := fake.sendRoomUpdateArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalParticipant) SendRoomUpdateReturns(result1 error) { + fake.sendRoomUpdateMutex.Lock() + defer fake.sendRoomUpdateMutex.Unlock() + fake.SendRoomUpdateStub = nil + fake.sendRoomUpdateReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipant) SendRoomUpdateReturnsOnCall(i int, result1 error) { + fake.sendRoomUpdateMutex.Lock() + defer fake.sendRoomUpdateMutex.Unlock() + fake.SendRoomUpdateStub = nil + if fake.sendRoomUpdateReturnsOnCall == nil { + fake.sendRoomUpdateReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.sendRoomUpdateReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipant) SendSpeakerUpdate(arg1 []*livekit.SpeakerInfo, arg2 bool) error { + var arg1Copy []*livekit.SpeakerInfo + if arg1 != nil { + arg1Copy = make([]*livekit.SpeakerInfo, len(arg1)) + copy(arg1Copy, arg1) + } + fake.sendSpeakerUpdateMutex.Lock() + ret, specificReturn := fake.sendSpeakerUpdateReturnsOnCall[len(fake.sendSpeakerUpdateArgsForCall)] + fake.sendSpeakerUpdateArgsForCall = append(fake.sendSpeakerUpdateArgsForCall, struct { + arg1 []*livekit.SpeakerInfo + arg2 bool + }{arg1Copy, arg2}) + stub := fake.SendSpeakerUpdateStub + fakeReturns := fake.sendSpeakerUpdateReturns + fake.recordInvocation("SendSpeakerUpdate", []interface{}{arg1Copy, arg2}) + fake.sendSpeakerUpdateMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) SendSpeakerUpdateCallCount() int { + fake.sendSpeakerUpdateMutex.RLock() + defer fake.sendSpeakerUpdateMutex.RUnlock() + return len(fake.sendSpeakerUpdateArgsForCall) +} + +func (fake *FakeLocalParticipant) SendSpeakerUpdateCalls(stub func([]*livekit.SpeakerInfo, bool) error) { + fake.sendSpeakerUpdateMutex.Lock() + defer fake.sendSpeakerUpdateMutex.Unlock() + fake.SendSpeakerUpdateStub = stub +} + +func (fake *FakeLocalParticipant) SendSpeakerUpdateArgsForCall(i int) ([]*livekit.SpeakerInfo, bool) { + fake.sendSpeakerUpdateMutex.RLock() + defer fake.sendSpeakerUpdateMutex.RUnlock() + argsForCall := fake.sendSpeakerUpdateArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeLocalParticipant) SendSpeakerUpdateReturns(result1 error) { + fake.sendSpeakerUpdateMutex.Lock() + defer fake.sendSpeakerUpdateMutex.Unlock() + fake.SendSpeakerUpdateStub = nil + fake.sendSpeakerUpdateReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipant) SendSpeakerUpdateReturnsOnCall(i int, result1 error) { + fake.sendSpeakerUpdateMutex.Lock() + defer fake.sendSpeakerUpdateMutex.Unlock() + fake.SendSpeakerUpdateStub = nil + if fake.sendSpeakerUpdateReturnsOnCall == nil { + fake.sendSpeakerUpdateReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.sendSpeakerUpdateReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipant) SendSubscriptionPermissionUpdate(arg1 livekit.ParticipantID, arg2 livekit.TrackID, arg3 bool) error { + fake.sendSubscriptionPermissionUpdateMutex.Lock() + ret, specificReturn := fake.sendSubscriptionPermissionUpdateReturnsOnCall[len(fake.sendSubscriptionPermissionUpdateArgsForCall)] + fake.sendSubscriptionPermissionUpdateArgsForCall = append(fake.sendSubscriptionPermissionUpdateArgsForCall, struct { + arg1 livekit.ParticipantID + arg2 livekit.TrackID + arg3 bool + }{arg1, arg2, arg3}) + stub := fake.SendSubscriptionPermissionUpdateStub + fakeReturns := fake.sendSubscriptionPermissionUpdateReturns + fake.recordInvocation("SendSubscriptionPermissionUpdate", []interface{}{arg1, arg2, arg3}) + fake.sendSubscriptionPermissionUpdateMutex.Unlock() + if stub != nil { + return stub(arg1, arg2, arg3) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) SendSubscriptionPermissionUpdateCallCount() int { + fake.sendSubscriptionPermissionUpdateMutex.RLock() + defer fake.sendSubscriptionPermissionUpdateMutex.RUnlock() + return len(fake.sendSubscriptionPermissionUpdateArgsForCall) +} + +func (fake *FakeLocalParticipant) SendSubscriptionPermissionUpdateCalls(stub func(livekit.ParticipantID, livekit.TrackID, bool) error) { + fake.sendSubscriptionPermissionUpdateMutex.Lock() + defer fake.sendSubscriptionPermissionUpdateMutex.Unlock() + fake.SendSubscriptionPermissionUpdateStub = stub +} + +func (fake *FakeLocalParticipant) SendSubscriptionPermissionUpdateArgsForCall(i int) (livekit.ParticipantID, livekit.TrackID, bool) { + fake.sendSubscriptionPermissionUpdateMutex.RLock() + defer fake.sendSubscriptionPermissionUpdateMutex.RUnlock() + argsForCall := fake.sendSubscriptionPermissionUpdateArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeLocalParticipant) SendSubscriptionPermissionUpdateReturns(result1 error) { + fake.sendSubscriptionPermissionUpdateMutex.Lock() + defer fake.sendSubscriptionPermissionUpdateMutex.Unlock() + fake.SendSubscriptionPermissionUpdateStub = nil + fake.sendSubscriptionPermissionUpdateReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipant) SendSubscriptionPermissionUpdateReturnsOnCall(i int, result1 error) { + fake.sendSubscriptionPermissionUpdateMutex.Lock() + defer fake.sendSubscriptionPermissionUpdateMutex.Unlock() + fake.SendSubscriptionPermissionUpdateStub = nil + if fake.sendSubscriptionPermissionUpdateReturnsOnCall == nil { + fake.sendSubscriptionPermissionUpdateReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.sendSubscriptionPermissionUpdateReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipant) SetAttributes(arg1 map[string]string) { + fake.setAttributesMutex.Lock() + fake.setAttributesArgsForCall = append(fake.setAttributesArgsForCall, struct { + arg1 map[string]string + }{arg1}) + stub := fake.SetAttributesStub + fake.recordInvocation("SetAttributes", []interface{}{arg1}) + fake.setAttributesMutex.Unlock() + if stub != nil { + fake.SetAttributesStub(arg1) + } +} + +func (fake *FakeLocalParticipant) SetAttributesCallCount() int { + fake.setAttributesMutex.RLock() + defer fake.setAttributesMutex.RUnlock() + return len(fake.setAttributesArgsForCall) +} + +func (fake *FakeLocalParticipant) SetAttributesCalls(stub func(map[string]string)) { + fake.setAttributesMutex.Lock() + defer fake.setAttributesMutex.Unlock() + fake.SetAttributesStub = stub +} + +func (fake *FakeLocalParticipant) SetAttributesArgsForCall(i int) map[string]string { + fake.setAttributesMutex.RLock() + defer fake.setAttributesMutex.RUnlock() + argsForCall := fake.setAttributesArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalParticipant) SetICEConfig(arg1 *livekit.ICEConfig) { + fake.setICEConfigMutex.Lock() + fake.setICEConfigArgsForCall = append(fake.setICEConfigArgsForCall, struct { + arg1 *livekit.ICEConfig + }{arg1}) + stub := fake.SetICEConfigStub + fake.recordInvocation("SetICEConfig", []interface{}{arg1}) + fake.setICEConfigMutex.Unlock() + if stub != nil { + fake.SetICEConfigStub(arg1) + } +} + +func (fake *FakeLocalParticipant) SetICEConfigCallCount() int { + fake.setICEConfigMutex.RLock() + defer fake.setICEConfigMutex.RUnlock() + return len(fake.setICEConfigArgsForCall) +} + +func (fake *FakeLocalParticipant) SetICEConfigCalls(stub func(*livekit.ICEConfig)) { + fake.setICEConfigMutex.Lock() + defer fake.setICEConfigMutex.Unlock() + fake.SetICEConfigStub = stub +} + +func (fake *FakeLocalParticipant) SetICEConfigArgsForCall(i int) *livekit.ICEConfig { + fake.setICEConfigMutex.RLock() + defer fake.setICEConfigMutex.RUnlock() + argsForCall := fake.setICEConfigArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalParticipant) SetMetadata(arg1 string) { + fake.setMetadataMutex.Lock() + fake.setMetadataArgsForCall = append(fake.setMetadataArgsForCall, struct { + arg1 string + }{arg1}) + stub := fake.SetMetadataStub + fake.recordInvocation("SetMetadata", []interface{}{arg1}) + fake.setMetadataMutex.Unlock() + if stub != nil { + fake.SetMetadataStub(arg1) + } +} + +func (fake *FakeLocalParticipant) SetMetadataCallCount() int { + fake.setMetadataMutex.RLock() + defer fake.setMetadataMutex.RUnlock() + return len(fake.setMetadataArgsForCall) +} + +func (fake *FakeLocalParticipant) SetMetadataCalls(stub func(string)) { + fake.setMetadataMutex.Lock() + defer fake.setMetadataMutex.Unlock() + fake.SetMetadataStub = stub +} + +func (fake *FakeLocalParticipant) SetMetadataArgsForCall(i int) string { + fake.setMetadataMutex.RLock() + defer fake.setMetadataMutex.RUnlock() + argsForCall := fake.setMetadataArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalParticipant) SetMigrateInfo(arg1 *webrtc.SessionDescription, arg2 *webrtc.SessionDescription, arg3 []*livekit.TrackPublishedResponse, arg4 []*livekit.DataChannelInfo, arg5 []*livekit.DataChannelReceiveState, arg6 []*livekit.PublishDataTrackResponse) { + var arg3Copy []*livekit.TrackPublishedResponse + if arg3 != nil { + arg3Copy = make([]*livekit.TrackPublishedResponse, len(arg3)) + copy(arg3Copy, arg3) + } + var arg4Copy []*livekit.DataChannelInfo + if arg4 != nil { + arg4Copy = make([]*livekit.DataChannelInfo, len(arg4)) + copy(arg4Copy, arg4) + } + var arg5Copy []*livekit.DataChannelReceiveState + if arg5 != nil { + arg5Copy = make([]*livekit.DataChannelReceiveState, len(arg5)) + copy(arg5Copy, arg5) + } + var arg6Copy []*livekit.PublishDataTrackResponse + if arg6 != nil { + arg6Copy = make([]*livekit.PublishDataTrackResponse, len(arg6)) + copy(arg6Copy, arg6) + } + fake.setMigrateInfoMutex.Lock() + fake.setMigrateInfoArgsForCall = append(fake.setMigrateInfoArgsForCall, struct { + arg1 *webrtc.SessionDescription + arg2 *webrtc.SessionDescription + arg3 []*livekit.TrackPublishedResponse + arg4 []*livekit.DataChannelInfo + arg5 []*livekit.DataChannelReceiveState + arg6 []*livekit.PublishDataTrackResponse + }{arg1, arg2, arg3Copy, arg4Copy, arg5Copy, arg6Copy}) + stub := fake.SetMigrateInfoStub + fake.recordInvocation("SetMigrateInfo", []interface{}{arg1, arg2, arg3Copy, arg4Copy, arg5Copy, arg6Copy}) + fake.setMigrateInfoMutex.Unlock() + if stub != nil { + fake.SetMigrateInfoStub(arg1, arg2, arg3, arg4, arg5, arg6) + } +} + +func (fake *FakeLocalParticipant) SetMigrateInfoCallCount() int { + fake.setMigrateInfoMutex.RLock() + defer fake.setMigrateInfoMutex.RUnlock() + return len(fake.setMigrateInfoArgsForCall) +} + +func (fake *FakeLocalParticipant) SetMigrateInfoCalls(stub func(*webrtc.SessionDescription, *webrtc.SessionDescription, []*livekit.TrackPublishedResponse, []*livekit.DataChannelInfo, []*livekit.DataChannelReceiveState, []*livekit.PublishDataTrackResponse)) { + fake.setMigrateInfoMutex.Lock() + defer fake.setMigrateInfoMutex.Unlock() + fake.SetMigrateInfoStub = stub +} + +func (fake *FakeLocalParticipant) SetMigrateInfoArgsForCall(i int) (*webrtc.SessionDescription, *webrtc.SessionDescription, []*livekit.TrackPublishedResponse, []*livekit.DataChannelInfo, []*livekit.DataChannelReceiveState, []*livekit.PublishDataTrackResponse) { + fake.setMigrateInfoMutex.RLock() + defer fake.setMigrateInfoMutex.RUnlock() + argsForCall := fake.setMigrateInfoArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3, argsForCall.arg4, argsForCall.arg5, argsForCall.arg6 +} + +func (fake *FakeLocalParticipant) SetMigrateState(arg1 types.MigrateState) { + fake.setMigrateStateMutex.Lock() + fake.setMigrateStateArgsForCall = append(fake.setMigrateStateArgsForCall, struct { + arg1 types.MigrateState + }{arg1}) + stub := fake.SetMigrateStateStub + fake.recordInvocation("SetMigrateState", []interface{}{arg1}) + fake.setMigrateStateMutex.Unlock() + if stub != nil { + fake.SetMigrateStateStub(arg1) + } +} + +func (fake *FakeLocalParticipant) SetMigrateStateCallCount() int { + fake.setMigrateStateMutex.RLock() + defer fake.setMigrateStateMutex.RUnlock() + return len(fake.setMigrateStateArgsForCall) +} + +func (fake *FakeLocalParticipant) SetMigrateStateCalls(stub func(types.MigrateState)) { + fake.setMigrateStateMutex.Lock() + defer fake.setMigrateStateMutex.Unlock() + fake.SetMigrateStateStub = stub +} + +func (fake *FakeLocalParticipant) SetMigrateStateArgsForCall(i int) types.MigrateState { + fake.setMigrateStateMutex.RLock() + defer fake.setMigrateStateMutex.RUnlock() + argsForCall := fake.setMigrateStateArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalParticipant) SetName(arg1 string) { + fake.setNameMutex.Lock() + fake.setNameArgsForCall = append(fake.setNameArgsForCall, struct { + arg1 string + }{arg1}) + stub := fake.SetNameStub + fake.recordInvocation("SetName", []interface{}{arg1}) + fake.setNameMutex.Unlock() + if stub != nil { + fake.SetNameStub(arg1) + } +} + +func (fake *FakeLocalParticipant) SetNameCallCount() int { + fake.setNameMutex.RLock() + defer fake.setNameMutex.RUnlock() + return len(fake.setNameArgsForCall) +} + +func (fake *FakeLocalParticipant) SetNameCalls(stub func(string)) { + fake.setNameMutex.Lock() + defer fake.setNameMutex.Unlock() + fake.SetNameStub = stub +} + +func (fake *FakeLocalParticipant) SetNameArgsForCall(i int) string { + fake.setNameMutex.RLock() + defer fake.setNameMutex.RUnlock() + argsForCall := fake.setNameArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalParticipant) SetPermission(arg1 *livekit.ParticipantPermission) bool { + fake.setPermissionMutex.Lock() + ret, specificReturn := fake.setPermissionReturnsOnCall[len(fake.setPermissionArgsForCall)] + fake.setPermissionArgsForCall = append(fake.setPermissionArgsForCall, struct { + arg1 *livekit.ParticipantPermission + }{arg1}) + stub := fake.SetPermissionStub + fakeReturns := fake.setPermissionReturns + fake.recordInvocation("SetPermission", []interface{}{arg1}) + fake.setPermissionMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) SetPermissionCallCount() int { + fake.setPermissionMutex.RLock() + defer fake.setPermissionMutex.RUnlock() + return len(fake.setPermissionArgsForCall) +} + +func (fake *FakeLocalParticipant) SetPermissionCalls(stub func(*livekit.ParticipantPermission) bool) { + fake.setPermissionMutex.Lock() + defer fake.setPermissionMutex.Unlock() + fake.SetPermissionStub = stub +} + +func (fake *FakeLocalParticipant) SetPermissionArgsForCall(i int) *livekit.ParticipantPermission { + fake.setPermissionMutex.RLock() + defer fake.setPermissionMutex.RUnlock() + argsForCall := fake.setPermissionArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalParticipant) SetPermissionReturns(result1 bool) { + fake.setPermissionMutex.Lock() + defer fake.setPermissionMutex.Unlock() + fake.SetPermissionStub = nil + fake.setPermissionReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalParticipant) SetPermissionReturnsOnCall(i int, result1 bool) { + fake.setPermissionMutex.Lock() + defer fake.setPermissionMutex.Unlock() + fake.SetPermissionStub = nil + if fake.setPermissionReturnsOnCall == nil { + fake.setPermissionReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.setPermissionReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalParticipant) SetSignalSourceValid(arg1 bool) { + fake.setSignalSourceValidMutex.Lock() + fake.setSignalSourceValidArgsForCall = append(fake.setSignalSourceValidArgsForCall, struct { + arg1 bool + }{arg1}) + stub := fake.SetSignalSourceValidStub + fake.recordInvocation("SetSignalSourceValid", []interface{}{arg1}) + fake.setSignalSourceValidMutex.Unlock() + if stub != nil { + fake.SetSignalSourceValidStub(arg1) + } +} + +func (fake *FakeLocalParticipant) SetSignalSourceValidCallCount() int { + fake.setSignalSourceValidMutex.RLock() + defer fake.setSignalSourceValidMutex.RUnlock() + return len(fake.setSignalSourceValidArgsForCall) +} + +func (fake *FakeLocalParticipant) SetSignalSourceValidCalls(stub func(bool)) { + fake.setSignalSourceValidMutex.Lock() + defer fake.setSignalSourceValidMutex.Unlock() + fake.SetSignalSourceValidStub = stub +} + +func (fake *FakeLocalParticipant) SetSignalSourceValidArgsForCall(i int) bool { + fake.setSignalSourceValidMutex.RLock() + defer fake.setSignalSourceValidMutex.RUnlock() + argsForCall := fake.setSignalSourceValidArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalParticipant) SetSubscriberAllowPause(arg1 bool) { + fake.setSubscriberAllowPauseMutex.Lock() + fake.setSubscriberAllowPauseArgsForCall = append(fake.setSubscriberAllowPauseArgsForCall, struct { + arg1 bool + }{arg1}) + stub := fake.SetSubscriberAllowPauseStub + fake.recordInvocation("SetSubscriberAllowPause", []interface{}{arg1}) + fake.setSubscriberAllowPauseMutex.Unlock() + if stub != nil { + fake.SetSubscriberAllowPauseStub(arg1) + } +} + +func (fake *FakeLocalParticipant) SetSubscriberAllowPauseCallCount() int { + fake.setSubscriberAllowPauseMutex.RLock() + defer fake.setSubscriberAllowPauseMutex.RUnlock() + return len(fake.setSubscriberAllowPauseArgsForCall) +} + +func (fake *FakeLocalParticipant) SetSubscriberAllowPauseCalls(stub func(bool)) { + fake.setSubscriberAllowPauseMutex.Lock() + defer fake.setSubscriberAllowPauseMutex.Unlock() + fake.SetSubscriberAllowPauseStub = stub +} + +func (fake *FakeLocalParticipant) SetSubscriberAllowPauseArgsForCall(i int) bool { + fake.setSubscriberAllowPauseMutex.RLock() + defer fake.setSubscriberAllowPauseMutex.RUnlock() + argsForCall := fake.setSubscriberAllowPauseArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalParticipant) SetSubscriberChannelCapacity(arg1 int64) { + fake.setSubscriberChannelCapacityMutex.Lock() + fake.setSubscriberChannelCapacityArgsForCall = append(fake.setSubscriberChannelCapacityArgsForCall, struct { + arg1 int64 + }{arg1}) + stub := fake.SetSubscriberChannelCapacityStub + fake.recordInvocation("SetSubscriberChannelCapacity", []interface{}{arg1}) + fake.setSubscriberChannelCapacityMutex.Unlock() + if stub != nil { + fake.SetSubscriberChannelCapacityStub(arg1) + } +} + +func (fake *FakeLocalParticipant) SetSubscriberChannelCapacityCallCount() int { + fake.setSubscriberChannelCapacityMutex.RLock() + defer fake.setSubscriberChannelCapacityMutex.RUnlock() + return len(fake.setSubscriberChannelCapacityArgsForCall) +} + +func (fake *FakeLocalParticipant) SetSubscriberChannelCapacityCalls(stub func(int64)) { + fake.setSubscriberChannelCapacityMutex.Lock() + defer fake.setSubscriberChannelCapacityMutex.Unlock() + fake.SetSubscriberChannelCapacityStub = stub +} + +func (fake *FakeLocalParticipant) SetSubscriberChannelCapacityArgsForCall(i int) int64 { + fake.setSubscriberChannelCapacityMutex.RLock() + defer fake.setSubscriberChannelCapacityMutex.RUnlock() + argsForCall := fake.setSubscriberChannelCapacityArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalParticipant) SetTrackMuted(arg1 *livekit.MuteTrackRequest, arg2 bool) *livekit.TrackInfo { + fake.setTrackMutedMutex.Lock() + ret, specificReturn := fake.setTrackMutedReturnsOnCall[len(fake.setTrackMutedArgsForCall)] + fake.setTrackMutedArgsForCall = append(fake.setTrackMutedArgsForCall, struct { + arg1 *livekit.MuteTrackRequest + arg2 bool + }{arg1, arg2}) + stub := fake.SetTrackMutedStub + fakeReturns := fake.setTrackMutedReturns + fake.recordInvocation("SetTrackMuted", []interface{}{arg1, arg2}) + fake.setTrackMutedMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) SetTrackMutedCallCount() int { + fake.setTrackMutedMutex.RLock() + defer fake.setTrackMutedMutex.RUnlock() + return len(fake.setTrackMutedArgsForCall) +} + +func (fake *FakeLocalParticipant) SetTrackMutedCalls(stub func(*livekit.MuteTrackRequest, bool) *livekit.TrackInfo) { + fake.setTrackMutedMutex.Lock() + defer fake.setTrackMutedMutex.Unlock() + fake.SetTrackMutedStub = stub +} + +func (fake *FakeLocalParticipant) SetTrackMutedArgsForCall(i int) (*livekit.MuteTrackRequest, bool) { + fake.setTrackMutedMutex.RLock() + defer fake.setTrackMutedMutex.RUnlock() + argsForCall := fake.setTrackMutedArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeLocalParticipant) SetTrackMutedReturns(result1 *livekit.TrackInfo) { + fake.setTrackMutedMutex.Lock() + defer fake.setTrackMutedMutex.Unlock() + fake.SetTrackMutedStub = nil + fake.setTrackMutedReturns = struct { + result1 *livekit.TrackInfo + }{result1} +} + +func (fake *FakeLocalParticipant) SetTrackMutedReturnsOnCall(i int, result1 *livekit.TrackInfo) { + fake.setTrackMutedMutex.Lock() + defer fake.setTrackMutedMutex.Unlock() + fake.SetTrackMutedStub = nil + if fake.setTrackMutedReturnsOnCall == nil { + fake.setTrackMutedReturnsOnCall = make(map[int]struct { + result1 *livekit.TrackInfo + }) + } + fake.setTrackMutedReturnsOnCall[i] = struct { + result1 *livekit.TrackInfo + }{result1} +} + +func (fake *FakeLocalParticipant) State() livekit.ParticipantInfo_State { + fake.stateMutex.Lock() + ret, specificReturn := fake.stateReturnsOnCall[len(fake.stateArgsForCall)] + fake.stateArgsForCall = append(fake.stateArgsForCall, struct { + }{}) + stub := fake.StateStub + fakeReturns := fake.stateReturns + fake.recordInvocation("State", []interface{}{}) + fake.stateMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) StateCallCount() int { + fake.stateMutex.RLock() + defer fake.stateMutex.RUnlock() + return len(fake.stateArgsForCall) +} + +func (fake *FakeLocalParticipant) StateCalls(stub func() livekit.ParticipantInfo_State) { + fake.stateMutex.Lock() + defer fake.stateMutex.Unlock() + fake.StateStub = stub +} + +func (fake *FakeLocalParticipant) StateReturns(result1 livekit.ParticipantInfo_State) { + fake.stateMutex.Lock() + defer fake.stateMutex.Unlock() + fake.StateStub = nil + fake.stateReturns = struct { + result1 livekit.ParticipantInfo_State + }{result1} +} + +func (fake *FakeLocalParticipant) StateReturnsOnCall(i int, result1 livekit.ParticipantInfo_State) { + fake.stateMutex.Lock() + defer fake.stateMutex.Unlock() + fake.StateStub = nil + if fake.stateReturnsOnCall == nil { + fake.stateReturnsOnCall = make(map[int]struct { + result1 livekit.ParticipantInfo_State + }) + } + fake.stateReturnsOnCall[i] = struct { + result1 livekit.ParticipantInfo_State + }{result1} +} + +func (fake *FakeLocalParticipant) StopAndGetSubscribedTracksForwarderState() map[livekit.TrackID]*livekit.RTPForwarderState { + fake.stopAndGetSubscribedTracksForwarderStateMutex.Lock() + ret, specificReturn := fake.stopAndGetSubscribedTracksForwarderStateReturnsOnCall[len(fake.stopAndGetSubscribedTracksForwarderStateArgsForCall)] + fake.stopAndGetSubscribedTracksForwarderStateArgsForCall = append(fake.stopAndGetSubscribedTracksForwarderStateArgsForCall, struct { + }{}) + stub := fake.StopAndGetSubscribedTracksForwarderStateStub + fakeReturns := fake.stopAndGetSubscribedTracksForwarderStateReturns + fake.recordInvocation("StopAndGetSubscribedTracksForwarderState", []interface{}{}) + fake.stopAndGetSubscribedTracksForwarderStateMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) StopAndGetSubscribedTracksForwarderStateCallCount() int { + fake.stopAndGetSubscribedTracksForwarderStateMutex.RLock() + defer fake.stopAndGetSubscribedTracksForwarderStateMutex.RUnlock() + return len(fake.stopAndGetSubscribedTracksForwarderStateArgsForCall) +} + +func (fake *FakeLocalParticipant) StopAndGetSubscribedTracksForwarderStateCalls(stub func() map[livekit.TrackID]*livekit.RTPForwarderState) { + fake.stopAndGetSubscribedTracksForwarderStateMutex.Lock() + defer fake.stopAndGetSubscribedTracksForwarderStateMutex.Unlock() + fake.StopAndGetSubscribedTracksForwarderStateStub = stub +} + +func (fake *FakeLocalParticipant) StopAndGetSubscribedTracksForwarderStateReturns(result1 map[livekit.TrackID]*livekit.RTPForwarderState) { + fake.stopAndGetSubscribedTracksForwarderStateMutex.Lock() + defer fake.stopAndGetSubscribedTracksForwarderStateMutex.Unlock() + fake.StopAndGetSubscribedTracksForwarderStateStub = nil + fake.stopAndGetSubscribedTracksForwarderStateReturns = struct { + result1 map[livekit.TrackID]*livekit.RTPForwarderState + }{result1} +} + +func (fake *FakeLocalParticipant) StopAndGetSubscribedTracksForwarderStateReturnsOnCall(i int, result1 map[livekit.TrackID]*livekit.RTPForwarderState) { + fake.stopAndGetSubscribedTracksForwarderStateMutex.Lock() + defer fake.stopAndGetSubscribedTracksForwarderStateMutex.Unlock() + fake.StopAndGetSubscribedTracksForwarderStateStub = nil + if fake.stopAndGetSubscribedTracksForwarderStateReturnsOnCall == nil { + fake.stopAndGetSubscribedTracksForwarderStateReturnsOnCall = make(map[int]struct { + result1 map[livekit.TrackID]*livekit.RTPForwarderState + }) + } + fake.stopAndGetSubscribedTracksForwarderStateReturnsOnCall[i] = struct { + result1 map[livekit.TrackID]*livekit.RTPForwarderState + }{result1} +} + +func (fake *FakeLocalParticipant) SubscribeToDataTrack(arg1 livekit.TrackID) { + fake.subscribeToDataTrackMutex.Lock() + fake.subscribeToDataTrackArgsForCall = append(fake.subscribeToDataTrackArgsForCall, struct { + arg1 livekit.TrackID + }{arg1}) + stub := fake.SubscribeToDataTrackStub + fake.recordInvocation("SubscribeToDataTrack", []interface{}{arg1}) + fake.subscribeToDataTrackMutex.Unlock() + if stub != nil { + fake.SubscribeToDataTrackStub(arg1) + } +} + +func (fake *FakeLocalParticipant) SubscribeToDataTrackCallCount() int { + fake.subscribeToDataTrackMutex.RLock() + defer fake.subscribeToDataTrackMutex.RUnlock() + return len(fake.subscribeToDataTrackArgsForCall) +} + +func (fake *FakeLocalParticipant) SubscribeToDataTrackCalls(stub func(livekit.TrackID)) { + fake.subscribeToDataTrackMutex.Lock() + defer fake.subscribeToDataTrackMutex.Unlock() + fake.SubscribeToDataTrackStub = stub +} + +func (fake *FakeLocalParticipant) SubscribeToDataTrackArgsForCall(i int) livekit.TrackID { + fake.subscribeToDataTrackMutex.RLock() + defer fake.subscribeToDataTrackMutex.RUnlock() + argsForCall := fake.subscribeToDataTrackArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalParticipant) SubscribeToTrack(arg1 livekit.TrackID, arg2 bool) { + fake.subscribeToTrackMutex.Lock() + fake.subscribeToTrackArgsForCall = append(fake.subscribeToTrackArgsForCall, struct { + arg1 livekit.TrackID + arg2 bool + }{arg1, arg2}) + stub := fake.SubscribeToTrackStub + fake.recordInvocation("SubscribeToTrack", []interface{}{arg1, arg2}) + fake.subscribeToTrackMutex.Unlock() + if stub != nil { + fake.SubscribeToTrackStub(arg1, arg2) + } +} + +func (fake *FakeLocalParticipant) SubscribeToTrackCallCount() int { + fake.subscribeToTrackMutex.RLock() + defer fake.subscribeToTrackMutex.RUnlock() + return len(fake.subscribeToTrackArgsForCall) +} + +func (fake *FakeLocalParticipant) SubscribeToTrackCalls(stub func(livekit.TrackID, bool)) { + fake.subscribeToTrackMutex.Lock() + defer fake.subscribeToTrackMutex.Unlock() + fake.SubscribeToTrackStub = stub +} + +func (fake *FakeLocalParticipant) SubscribeToTrackArgsForCall(i int) (livekit.TrackID, bool) { + fake.subscribeToTrackMutex.RLock() + defer fake.subscribeToTrackMutex.RUnlock() + argsForCall := fake.subscribeToTrackArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeLocalParticipant) SubscriberAsPrimary() bool { + fake.subscriberAsPrimaryMutex.Lock() + ret, specificReturn := fake.subscriberAsPrimaryReturnsOnCall[len(fake.subscriberAsPrimaryArgsForCall)] + fake.subscriberAsPrimaryArgsForCall = append(fake.subscriberAsPrimaryArgsForCall, struct { + }{}) + stub := fake.SubscriberAsPrimaryStub + fakeReturns := fake.subscriberAsPrimaryReturns + fake.recordInvocation("SubscriberAsPrimary", []interface{}{}) + fake.subscriberAsPrimaryMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) SubscriberAsPrimaryCallCount() int { + fake.subscriberAsPrimaryMutex.RLock() + defer fake.subscriberAsPrimaryMutex.RUnlock() + return len(fake.subscriberAsPrimaryArgsForCall) +} + +func (fake *FakeLocalParticipant) SubscriberAsPrimaryCalls(stub func() bool) { + fake.subscriberAsPrimaryMutex.Lock() + defer fake.subscriberAsPrimaryMutex.Unlock() + fake.SubscriberAsPrimaryStub = stub +} + +func (fake *FakeLocalParticipant) SubscriberAsPrimaryReturns(result1 bool) { + fake.subscriberAsPrimaryMutex.Lock() + defer fake.subscriberAsPrimaryMutex.Unlock() + fake.SubscriberAsPrimaryStub = nil + fake.subscriberAsPrimaryReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalParticipant) SubscriberAsPrimaryReturnsOnCall(i int, result1 bool) { + fake.subscriberAsPrimaryMutex.Lock() + defer fake.subscriberAsPrimaryMutex.Unlock() + fake.SubscriberAsPrimaryStub = nil + if fake.subscriberAsPrimaryReturnsOnCall == nil { + fake.subscriberAsPrimaryReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.subscriberAsPrimaryReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalParticipant) SubscriptionPermission() (*livekit.SubscriptionPermission, utils.TimedVersion) { + fake.subscriptionPermissionMutex.Lock() + ret, specificReturn := fake.subscriptionPermissionReturnsOnCall[len(fake.subscriptionPermissionArgsForCall)] + fake.subscriptionPermissionArgsForCall = append(fake.subscriptionPermissionArgsForCall, struct { + }{}) + stub := fake.SubscriptionPermissionStub + fakeReturns := fake.subscriptionPermissionReturns + fake.recordInvocation("SubscriptionPermission", []interface{}{}) + fake.subscriptionPermissionMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeLocalParticipant) SubscriptionPermissionCallCount() int { + fake.subscriptionPermissionMutex.RLock() + defer fake.subscriptionPermissionMutex.RUnlock() + return len(fake.subscriptionPermissionArgsForCall) +} + +func (fake *FakeLocalParticipant) SubscriptionPermissionCalls(stub func() (*livekit.SubscriptionPermission, utils.TimedVersion)) { + fake.subscriptionPermissionMutex.Lock() + defer fake.subscriptionPermissionMutex.Unlock() + fake.SubscriptionPermissionStub = stub +} + +func (fake *FakeLocalParticipant) SubscriptionPermissionReturns(result1 *livekit.SubscriptionPermission, result2 utils.TimedVersion) { + fake.subscriptionPermissionMutex.Lock() + defer fake.subscriptionPermissionMutex.Unlock() + fake.SubscriptionPermissionStub = nil + fake.subscriptionPermissionReturns = struct { + result1 *livekit.SubscriptionPermission + result2 utils.TimedVersion + }{result1, result2} +} + +func (fake *FakeLocalParticipant) SubscriptionPermissionReturnsOnCall(i int, result1 *livekit.SubscriptionPermission, result2 utils.TimedVersion) { + fake.subscriptionPermissionMutex.Lock() + defer fake.subscriptionPermissionMutex.Unlock() + fake.SubscriptionPermissionStub = nil + if fake.subscriptionPermissionReturnsOnCall == nil { + fake.subscriptionPermissionReturnsOnCall = make(map[int]struct { + result1 *livekit.SubscriptionPermission + result2 utils.TimedVersion + }) + } + fake.subscriptionPermissionReturnsOnCall[i] = struct { + result1 *livekit.SubscriptionPermission + result2 utils.TimedVersion + }{result1, result2} +} + +func (fake *FakeLocalParticipant) SupportsCodecChange() bool { + fake.supportsCodecChangeMutex.Lock() + ret, specificReturn := fake.supportsCodecChangeReturnsOnCall[len(fake.supportsCodecChangeArgsForCall)] + fake.supportsCodecChangeArgsForCall = append(fake.supportsCodecChangeArgsForCall, struct { + }{}) + stub := fake.SupportsCodecChangeStub + fakeReturns := fake.supportsCodecChangeReturns + fake.recordInvocation("SupportsCodecChange", []interface{}{}) + fake.supportsCodecChangeMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) SupportsCodecChangeCallCount() int { + fake.supportsCodecChangeMutex.RLock() + defer fake.supportsCodecChangeMutex.RUnlock() + return len(fake.supportsCodecChangeArgsForCall) +} + +func (fake *FakeLocalParticipant) SupportsCodecChangeCalls(stub func() bool) { + fake.supportsCodecChangeMutex.Lock() + defer fake.supportsCodecChangeMutex.Unlock() + fake.SupportsCodecChangeStub = stub +} + +func (fake *FakeLocalParticipant) SupportsCodecChangeReturns(result1 bool) { + fake.supportsCodecChangeMutex.Lock() + defer fake.supportsCodecChangeMutex.Unlock() + fake.SupportsCodecChangeStub = nil + fake.supportsCodecChangeReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalParticipant) SupportsCodecChangeReturnsOnCall(i int, result1 bool) { + fake.supportsCodecChangeMutex.Lock() + defer fake.supportsCodecChangeMutex.Unlock() + fake.SupportsCodecChangeStub = nil + if fake.supportsCodecChangeReturnsOnCall == nil { + fake.supportsCodecChangeReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.supportsCodecChangeReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalParticipant) SupportsMoving() error { + fake.supportsMovingMutex.Lock() + ret, specificReturn := fake.supportsMovingReturnsOnCall[len(fake.supportsMovingArgsForCall)] + fake.supportsMovingArgsForCall = append(fake.supportsMovingArgsForCall, struct { + }{}) + stub := fake.SupportsMovingStub + fakeReturns := fake.supportsMovingReturns + fake.recordInvocation("SupportsMoving", []interface{}{}) + fake.supportsMovingMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) SupportsMovingCallCount() int { + fake.supportsMovingMutex.RLock() + defer fake.supportsMovingMutex.RUnlock() + return len(fake.supportsMovingArgsForCall) +} + +func (fake *FakeLocalParticipant) SupportsMovingCalls(stub func() error) { + fake.supportsMovingMutex.Lock() + defer fake.supportsMovingMutex.Unlock() + fake.SupportsMovingStub = stub +} + +func (fake *FakeLocalParticipant) SupportsMovingReturns(result1 error) { + fake.supportsMovingMutex.Lock() + defer fake.supportsMovingMutex.Unlock() + fake.SupportsMovingStub = nil + fake.supportsMovingReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipant) SupportsMovingReturnsOnCall(i int, result1 error) { + fake.supportsMovingMutex.Lock() + defer fake.supportsMovingMutex.Unlock() + fake.SupportsMovingStub = nil + if fake.supportsMovingReturnsOnCall == nil { + fake.supportsMovingReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.supportsMovingReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipant) SupportsSyncStreamID() bool { + fake.supportsSyncStreamIDMutex.Lock() + ret, specificReturn := fake.supportsSyncStreamIDReturnsOnCall[len(fake.supportsSyncStreamIDArgsForCall)] + fake.supportsSyncStreamIDArgsForCall = append(fake.supportsSyncStreamIDArgsForCall, struct { + }{}) + stub := fake.SupportsSyncStreamIDStub + fakeReturns := fake.supportsSyncStreamIDReturns + fake.recordInvocation("SupportsSyncStreamID", []interface{}{}) + fake.supportsSyncStreamIDMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) SupportsSyncStreamIDCallCount() int { + fake.supportsSyncStreamIDMutex.RLock() + defer fake.supportsSyncStreamIDMutex.RUnlock() + return len(fake.supportsSyncStreamIDArgsForCall) +} + +func (fake *FakeLocalParticipant) SupportsSyncStreamIDCalls(stub func() bool) { + fake.supportsSyncStreamIDMutex.Lock() + defer fake.supportsSyncStreamIDMutex.Unlock() + fake.SupportsSyncStreamIDStub = stub +} + +func (fake *FakeLocalParticipant) SupportsSyncStreamIDReturns(result1 bool) { + fake.supportsSyncStreamIDMutex.Lock() + defer fake.supportsSyncStreamIDMutex.Unlock() + fake.SupportsSyncStreamIDStub = nil + fake.supportsSyncStreamIDReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalParticipant) SupportsSyncStreamIDReturnsOnCall(i int, result1 bool) { + fake.supportsSyncStreamIDMutex.Lock() + defer fake.supportsSyncStreamIDMutex.Unlock() + fake.SupportsSyncStreamIDStub = nil + if fake.supportsSyncStreamIDReturnsOnCall == nil { + fake.supportsSyncStreamIDReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.supportsSyncStreamIDReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalParticipant) SupportsTransceiverReuse() bool { + fake.supportsTransceiverReuseMutex.Lock() + ret, specificReturn := fake.supportsTransceiverReuseReturnsOnCall[len(fake.supportsTransceiverReuseArgsForCall)] + fake.supportsTransceiverReuseArgsForCall = append(fake.supportsTransceiverReuseArgsForCall, struct { + }{}) + stub := fake.SupportsTransceiverReuseStub + fakeReturns := fake.supportsTransceiverReuseReturns + fake.recordInvocation("SupportsTransceiverReuse", []interface{}{}) + fake.supportsTransceiverReuseMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) SupportsTransceiverReuseCallCount() int { + fake.supportsTransceiverReuseMutex.RLock() + defer fake.supportsTransceiverReuseMutex.RUnlock() + return len(fake.supportsTransceiverReuseArgsForCall) +} + +func (fake *FakeLocalParticipant) SupportsTransceiverReuseCalls(stub func() bool) { + fake.supportsTransceiverReuseMutex.Lock() + defer fake.supportsTransceiverReuseMutex.Unlock() + fake.SupportsTransceiverReuseStub = stub +} + +func (fake *FakeLocalParticipant) SupportsTransceiverReuseReturns(result1 bool) { + fake.supportsTransceiverReuseMutex.Lock() + defer fake.supportsTransceiverReuseMutex.Unlock() + fake.SupportsTransceiverReuseStub = nil + fake.supportsTransceiverReuseReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalParticipant) SupportsTransceiverReuseReturnsOnCall(i int, result1 bool) { + fake.supportsTransceiverReuseMutex.Lock() + defer fake.supportsTransceiverReuseMutex.Unlock() + fake.SupportsTransceiverReuseStub = nil + if fake.supportsTransceiverReuseReturnsOnCall == nil { + fake.supportsTransceiverReuseReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.supportsTransceiverReuseReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalParticipant) SwapResponseSink(arg1 routing.MessageSink, arg2 types.SignallingCloseReason) { + fake.swapResponseSinkMutex.Lock() + fake.swapResponseSinkArgsForCall = append(fake.swapResponseSinkArgsForCall, struct { + arg1 routing.MessageSink + arg2 types.SignallingCloseReason + }{arg1, arg2}) + stub := fake.SwapResponseSinkStub + fake.recordInvocation("SwapResponseSink", []interface{}{arg1, arg2}) + fake.swapResponseSinkMutex.Unlock() + if stub != nil { + fake.SwapResponseSinkStub(arg1, arg2) + } +} + +func (fake *FakeLocalParticipant) SwapResponseSinkCallCount() int { + fake.swapResponseSinkMutex.RLock() + defer fake.swapResponseSinkMutex.RUnlock() + return len(fake.swapResponseSinkArgsForCall) +} + +func (fake *FakeLocalParticipant) SwapResponseSinkCalls(stub func(routing.MessageSink, types.SignallingCloseReason)) { + fake.swapResponseSinkMutex.Lock() + defer fake.swapResponseSinkMutex.Unlock() + fake.SwapResponseSinkStub = stub +} + +func (fake *FakeLocalParticipant) SwapResponseSinkArgsForCall(i int) (routing.MessageSink, types.SignallingCloseReason) { + fake.swapResponseSinkMutex.RLock() + defer fake.swapResponseSinkMutex.RUnlock() + argsForCall := fake.swapResponseSinkArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeLocalParticipant) TelemetryGuard() *telemetry.ReferenceGuard { + fake.telemetryGuardMutex.Lock() + ret, specificReturn := fake.telemetryGuardReturnsOnCall[len(fake.telemetryGuardArgsForCall)] + fake.telemetryGuardArgsForCall = append(fake.telemetryGuardArgsForCall, struct { + }{}) + stub := fake.TelemetryGuardStub + fakeReturns := fake.telemetryGuardReturns + fake.recordInvocation("TelemetryGuard", []interface{}{}) + fake.telemetryGuardMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) TelemetryGuardCallCount() int { + fake.telemetryGuardMutex.RLock() + defer fake.telemetryGuardMutex.RUnlock() + return len(fake.telemetryGuardArgsForCall) +} + +func (fake *FakeLocalParticipant) TelemetryGuardCalls(stub func() *telemetry.ReferenceGuard) { + fake.telemetryGuardMutex.Lock() + defer fake.telemetryGuardMutex.Unlock() + fake.TelemetryGuardStub = stub +} + +func (fake *FakeLocalParticipant) TelemetryGuardReturns(result1 *telemetry.ReferenceGuard) { + fake.telemetryGuardMutex.Lock() + defer fake.telemetryGuardMutex.Unlock() + fake.TelemetryGuardStub = nil + fake.telemetryGuardReturns = struct { + result1 *telemetry.ReferenceGuard + }{result1} +} + +func (fake *FakeLocalParticipant) TelemetryGuardReturnsOnCall(i int, result1 *telemetry.ReferenceGuard) { + fake.telemetryGuardMutex.Lock() + defer fake.telemetryGuardMutex.Unlock() + fake.TelemetryGuardStub = nil + if fake.telemetryGuardReturnsOnCall == nil { + fake.telemetryGuardReturnsOnCall = make(map[int]struct { + result1 *telemetry.ReferenceGuard + }) + } + fake.telemetryGuardReturnsOnCall[i] = struct { + result1 *telemetry.ReferenceGuard + }{result1} +} + +func (fake *FakeLocalParticipant) ToProto() *livekit.ParticipantInfo { + fake.toProtoMutex.Lock() + ret, specificReturn := fake.toProtoReturnsOnCall[len(fake.toProtoArgsForCall)] + fake.toProtoArgsForCall = append(fake.toProtoArgsForCall, struct { + }{}) + stub := fake.ToProtoStub + fakeReturns := fake.toProtoReturns + fake.recordInvocation("ToProto", []interface{}{}) + fake.toProtoMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) ToProtoCallCount() int { + fake.toProtoMutex.RLock() + defer fake.toProtoMutex.RUnlock() + return len(fake.toProtoArgsForCall) +} + +func (fake *FakeLocalParticipant) ToProtoCalls(stub func() *livekit.ParticipantInfo) { + fake.toProtoMutex.Lock() + defer fake.toProtoMutex.Unlock() + fake.ToProtoStub = stub +} + +func (fake *FakeLocalParticipant) ToProtoReturns(result1 *livekit.ParticipantInfo) { + fake.toProtoMutex.Lock() + defer fake.toProtoMutex.Unlock() + fake.ToProtoStub = nil + fake.toProtoReturns = struct { + result1 *livekit.ParticipantInfo + }{result1} +} + +func (fake *FakeLocalParticipant) ToProtoReturnsOnCall(i int, result1 *livekit.ParticipantInfo) { + fake.toProtoMutex.Lock() + defer fake.toProtoMutex.Unlock() + fake.ToProtoStub = nil + if fake.toProtoReturnsOnCall == nil { + fake.toProtoReturnsOnCall = make(map[int]struct { + result1 *livekit.ParticipantInfo + }) + } + fake.toProtoReturnsOnCall[i] = struct { + result1 *livekit.ParticipantInfo + }{result1} +} + +func (fake *FakeLocalParticipant) ToProtoWithVersion() (*livekit.ParticipantInfo, utils.TimedVersion) { + fake.toProtoWithVersionMutex.Lock() + ret, specificReturn := fake.toProtoWithVersionReturnsOnCall[len(fake.toProtoWithVersionArgsForCall)] + fake.toProtoWithVersionArgsForCall = append(fake.toProtoWithVersionArgsForCall, struct { + }{}) + stub := fake.ToProtoWithVersionStub + fakeReturns := fake.toProtoWithVersionReturns + fake.recordInvocation("ToProtoWithVersion", []interface{}{}) + fake.toProtoWithVersionMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeLocalParticipant) ToProtoWithVersionCallCount() int { + fake.toProtoWithVersionMutex.RLock() + defer fake.toProtoWithVersionMutex.RUnlock() + return len(fake.toProtoWithVersionArgsForCall) +} + +func (fake *FakeLocalParticipant) ToProtoWithVersionCalls(stub func() (*livekit.ParticipantInfo, utils.TimedVersion)) { + fake.toProtoWithVersionMutex.Lock() + defer fake.toProtoWithVersionMutex.Unlock() + fake.ToProtoWithVersionStub = stub +} + +func (fake *FakeLocalParticipant) ToProtoWithVersionReturns(result1 *livekit.ParticipantInfo, result2 utils.TimedVersion) { + fake.toProtoWithVersionMutex.Lock() + defer fake.toProtoWithVersionMutex.Unlock() + fake.ToProtoWithVersionStub = nil + fake.toProtoWithVersionReturns = struct { + result1 *livekit.ParticipantInfo + result2 utils.TimedVersion + }{result1, result2} +} + +func (fake *FakeLocalParticipant) ToProtoWithVersionReturnsOnCall(i int, result1 *livekit.ParticipantInfo, result2 utils.TimedVersion) { + fake.toProtoWithVersionMutex.Lock() + defer fake.toProtoWithVersionMutex.Unlock() + fake.ToProtoWithVersionStub = nil + if fake.toProtoWithVersionReturnsOnCall == nil { + fake.toProtoWithVersionReturnsOnCall = make(map[int]struct { + result1 *livekit.ParticipantInfo + result2 utils.TimedVersion + }) + } + fake.toProtoWithVersionReturnsOnCall[i] = struct { + result1 *livekit.ParticipantInfo + result2 utils.TimedVersion + }{result1, result2} +} + +func (fake *FakeLocalParticipant) UncacheDownTrack(arg1 *webrtc.RTPTransceiver) { + fake.uncacheDownTrackMutex.Lock() + fake.uncacheDownTrackArgsForCall = append(fake.uncacheDownTrackArgsForCall, struct { + arg1 *webrtc.RTPTransceiver + }{arg1}) + stub := fake.UncacheDownTrackStub + fake.recordInvocation("UncacheDownTrack", []interface{}{arg1}) + fake.uncacheDownTrackMutex.Unlock() + if stub != nil { + fake.UncacheDownTrackStub(arg1) + } +} + +func (fake *FakeLocalParticipant) UncacheDownTrackCallCount() int { + fake.uncacheDownTrackMutex.RLock() + defer fake.uncacheDownTrackMutex.RUnlock() + return len(fake.uncacheDownTrackArgsForCall) +} + +func (fake *FakeLocalParticipant) UncacheDownTrackCalls(stub func(*webrtc.RTPTransceiver)) { + fake.uncacheDownTrackMutex.Lock() + defer fake.uncacheDownTrackMutex.Unlock() + fake.UncacheDownTrackStub = stub +} + +func (fake *FakeLocalParticipant) UncacheDownTrackArgsForCall(i int) *webrtc.RTPTransceiver { + fake.uncacheDownTrackMutex.RLock() + defer fake.uncacheDownTrackMutex.RUnlock() + argsForCall := fake.uncacheDownTrackArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalParticipant) UnsubscribeFromDataTrack(arg1 livekit.TrackID) { + fake.unsubscribeFromDataTrackMutex.Lock() + fake.unsubscribeFromDataTrackArgsForCall = append(fake.unsubscribeFromDataTrackArgsForCall, struct { + arg1 livekit.TrackID + }{arg1}) + stub := fake.UnsubscribeFromDataTrackStub + fake.recordInvocation("UnsubscribeFromDataTrack", []interface{}{arg1}) + fake.unsubscribeFromDataTrackMutex.Unlock() + if stub != nil { + fake.UnsubscribeFromDataTrackStub(arg1) + } +} + +func (fake *FakeLocalParticipant) UnsubscribeFromDataTrackCallCount() int { + fake.unsubscribeFromDataTrackMutex.RLock() + defer fake.unsubscribeFromDataTrackMutex.RUnlock() + return len(fake.unsubscribeFromDataTrackArgsForCall) +} + +func (fake *FakeLocalParticipant) UnsubscribeFromDataTrackCalls(stub func(livekit.TrackID)) { + fake.unsubscribeFromDataTrackMutex.Lock() + defer fake.unsubscribeFromDataTrackMutex.Unlock() + fake.UnsubscribeFromDataTrackStub = stub +} + +func (fake *FakeLocalParticipant) UnsubscribeFromDataTrackArgsForCall(i int) livekit.TrackID { + fake.unsubscribeFromDataTrackMutex.RLock() + defer fake.unsubscribeFromDataTrackMutex.RUnlock() + argsForCall := fake.unsubscribeFromDataTrackArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalParticipant) UnsubscribeFromTrack(arg1 livekit.TrackID) { + fake.unsubscribeFromTrackMutex.Lock() + fake.unsubscribeFromTrackArgsForCall = append(fake.unsubscribeFromTrackArgsForCall, struct { + arg1 livekit.TrackID + }{arg1}) + stub := fake.UnsubscribeFromTrackStub + fake.recordInvocation("UnsubscribeFromTrack", []interface{}{arg1}) + fake.unsubscribeFromTrackMutex.Unlock() + if stub != nil { + fake.UnsubscribeFromTrackStub(arg1) + } +} + +func (fake *FakeLocalParticipant) UnsubscribeFromTrackCallCount() int { + fake.unsubscribeFromTrackMutex.RLock() + defer fake.unsubscribeFromTrackMutex.RUnlock() + return len(fake.unsubscribeFromTrackArgsForCall) +} + +func (fake *FakeLocalParticipant) UnsubscribeFromTrackCalls(stub func(livekit.TrackID)) { + fake.unsubscribeFromTrackMutex.Lock() + defer fake.unsubscribeFromTrackMutex.Unlock() + fake.UnsubscribeFromTrackStub = stub +} + +func (fake *FakeLocalParticipant) UnsubscribeFromTrackArgsForCall(i int) livekit.TrackID { + fake.unsubscribeFromTrackMutex.RLock() + defer fake.unsubscribeFromTrackMutex.RUnlock() + argsForCall := fake.unsubscribeFromTrackArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalParticipant) UpdateAudioTrack(arg1 *livekit.UpdateLocalAudioTrack) error { + fake.updateAudioTrackMutex.Lock() + ret, specificReturn := fake.updateAudioTrackReturnsOnCall[len(fake.updateAudioTrackArgsForCall)] + fake.updateAudioTrackArgsForCall = append(fake.updateAudioTrackArgsForCall, struct { + arg1 *livekit.UpdateLocalAudioTrack + }{arg1}) + stub := fake.UpdateAudioTrackStub + fakeReturns := fake.updateAudioTrackReturns + fake.recordInvocation("UpdateAudioTrack", []interface{}{arg1}) + fake.updateAudioTrackMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) UpdateAudioTrackCallCount() int { + fake.updateAudioTrackMutex.RLock() + defer fake.updateAudioTrackMutex.RUnlock() + return len(fake.updateAudioTrackArgsForCall) +} + +func (fake *FakeLocalParticipant) UpdateAudioTrackCalls(stub func(*livekit.UpdateLocalAudioTrack) error) { + fake.updateAudioTrackMutex.Lock() + defer fake.updateAudioTrackMutex.Unlock() + fake.UpdateAudioTrackStub = stub +} + +func (fake *FakeLocalParticipant) UpdateAudioTrackArgsForCall(i int) *livekit.UpdateLocalAudioTrack { + fake.updateAudioTrackMutex.RLock() + defer fake.updateAudioTrackMutex.RUnlock() + argsForCall := fake.updateAudioTrackArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalParticipant) UpdateAudioTrackReturns(result1 error) { + fake.updateAudioTrackMutex.Lock() + defer fake.updateAudioTrackMutex.Unlock() + fake.UpdateAudioTrackStub = nil + fake.updateAudioTrackReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipant) UpdateAudioTrackReturnsOnCall(i int, result1 error) { + fake.updateAudioTrackMutex.Lock() + defer fake.updateAudioTrackMutex.Unlock() + fake.UpdateAudioTrackStub = nil + if fake.updateAudioTrackReturnsOnCall == nil { + fake.updateAudioTrackReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.updateAudioTrackReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipant) UpdateDataTrackSubscriptionOptions(arg1 livekit.TrackID, arg2 *livekit.DataTrackSubscriptionOptions) { + fake.updateDataTrackSubscriptionOptionsMutex.Lock() + fake.updateDataTrackSubscriptionOptionsArgsForCall = append(fake.updateDataTrackSubscriptionOptionsArgsForCall, struct { + arg1 livekit.TrackID + arg2 *livekit.DataTrackSubscriptionOptions + }{arg1, arg2}) + stub := fake.UpdateDataTrackSubscriptionOptionsStub + fake.recordInvocation("UpdateDataTrackSubscriptionOptions", []interface{}{arg1, arg2}) + fake.updateDataTrackSubscriptionOptionsMutex.Unlock() + if stub != nil { + fake.UpdateDataTrackSubscriptionOptionsStub(arg1, arg2) + } +} + +func (fake *FakeLocalParticipant) UpdateDataTrackSubscriptionOptionsCallCount() int { + fake.updateDataTrackSubscriptionOptionsMutex.RLock() + defer fake.updateDataTrackSubscriptionOptionsMutex.RUnlock() + return len(fake.updateDataTrackSubscriptionOptionsArgsForCall) +} + +func (fake *FakeLocalParticipant) UpdateDataTrackSubscriptionOptionsCalls(stub func(livekit.TrackID, *livekit.DataTrackSubscriptionOptions)) { + fake.updateDataTrackSubscriptionOptionsMutex.Lock() + defer fake.updateDataTrackSubscriptionOptionsMutex.Unlock() + fake.UpdateDataTrackSubscriptionOptionsStub = stub +} + +func (fake *FakeLocalParticipant) UpdateDataTrackSubscriptionOptionsArgsForCall(i int) (livekit.TrackID, *livekit.DataTrackSubscriptionOptions) { + fake.updateDataTrackSubscriptionOptionsMutex.RLock() + defer fake.updateDataTrackSubscriptionOptionsMutex.RUnlock() + argsForCall := fake.updateDataTrackSubscriptionOptionsArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeLocalParticipant) UpdateLastSeenSignal() { + fake.updateLastSeenSignalMutex.Lock() + fake.updateLastSeenSignalArgsForCall = append(fake.updateLastSeenSignalArgsForCall, struct { + }{}) + stub := fake.UpdateLastSeenSignalStub + fake.recordInvocation("UpdateLastSeenSignal", []interface{}{}) + fake.updateLastSeenSignalMutex.Unlock() + if stub != nil { + fake.UpdateLastSeenSignalStub() + } +} + +func (fake *FakeLocalParticipant) UpdateLastSeenSignalCallCount() int { + fake.updateLastSeenSignalMutex.RLock() + defer fake.updateLastSeenSignalMutex.RUnlock() + return len(fake.updateLastSeenSignalArgsForCall) +} + +func (fake *FakeLocalParticipant) UpdateLastSeenSignalCalls(stub func()) { + fake.updateLastSeenSignalMutex.Lock() + defer fake.updateLastSeenSignalMutex.Unlock() + fake.UpdateLastSeenSignalStub = stub +} + +func (fake *FakeLocalParticipant) UpdateMediaLoss(arg1 livekit.NodeID, arg2 livekit.TrackID, arg3 uint32) error { + fake.updateMediaLossMutex.Lock() + ret, specificReturn := fake.updateMediaLossReturnsOnCall[len(fake.updateMediaLossArgsForCall)] + fake.updateMediaLossArgsForCall = append(fake.updateMediaLossArgsForCall, struct { + arg1 livekit.NodeID + arg2 livekit.TrackID + arg3 uint32 + }{arg1, arg2, arg3}) + stub := fake.UpdateMediaLossStub + fakeReturns := fake.updateMediaLossReturns + fake.recordInvocation("UpdateMediaLoss", []interface{}{arg1, arg2, arg3}) + fake.updateMediaLossMutex.Unlock() + if stub != nil { + return stub(arg1, arg2, arg3) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) UpdateMediaLossCallCount() int { + fake.updateMediaLossMutex.RLock() + defer fake.updateMediaLossMutex.RUnlock() + return len(fake.updateMediaLossArgsForCall) +} + +func (fake *FakeLocalParticipant) UpdateMediaLossCalls(stub func(livekit.NodeID, livekit.TrackID, uint32) error) { + fake.updateMediaLossMutex.Lock() + defer fake.updateMediaLossMutex.Unlock() + fake.UpdateMediaLossStub = stub +} + +func (fake *FakeLocalParticipant) UpdateMediaLossArgsForCall(i int) (livekit.NodeID, livekit.TrackID, uint32) { + fake.updateMediaLossMutex.RLock() + defer fake.updateMediaLossMutex.RUnlock() + argsForCall := fake.updateMediaLossArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeLocalParticipant) UpdateMediaLossReturns(result1 error) { + fake.updateMediaLossMutex.Lock() + defer fake.updateMediaLossMutex.Unlock() + fake.UpdateMediaLossStub = nil + fake.updateMediaLossReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipant) UpdateMediaLossReturnsOnCall(i int, result1 error) { + fake.updateMediaLossMutex.Lock() + defer fake.updateMediaLossMutex.Unlock() + fake.UpdateMediaLossStub = nil + if fake.updateMediaLossReturnsOnCall == nil { + fake.updateMediaLossReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.updateMediaLossReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipant) UpdateMediaRTT(arg1 uint32) { + fake.updateMediaRTTMutex.Lock() + fake.updateMediaRTTArgsForCall = append(fake.updateMediaRTTArgsForCall, struct { + arg1 uint32 + }{arg1}) + stub := fake.UpdateMediaRTTStub + fake.recordInvocation("UpdateMediaRTT", []interface{}{arg1}) + fake.updateMediaRTTMutex.Unlock() + if stub != nil { + fake.UpdateMediaRTTStub(arg1) + } +} + +func (fake *FakeLocalParticipant) UpdateMediaRTTCallCount() int { + fake.updateMediaRTTMutex.RLock() + defer fake.updateMediaRTTMutex.RUnlock() + return len(fake.updateMediaRTTArgsForCall) +} + +func (fake *FakeLocalParticipant) UpdateMediaRTTCalls(stub func(uint32)) { + fake.updateMediaRTTMutex.Lock() + defer fake.updateMediaRTTMutex.Unlock() + fake.UpdateMediaRTTStub = stub +} + +func (fake *FakeLocalParticipant) UpdateMediaRTTArgsForCall(i int) uint32 { + fake.updateMediaRTTMutex.RLock() + defer fake.updateMediaRTTMutex.RUnlock() + argsForCall := fake.updateMediaRTTArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalParticipant) UpdateMetadata(arg1 *livekit.UpdateParticipantMetadata, arg2 bool) error { + fake.updateMetadataMutex.Lock() + ret, specificReturn := fake.updateMetadataReturnsOnCall[len(fake.updateMetadataArgsForCall)] + fake.updateMetadataArgsForCall = append(fake.updateMetadataArgsForCall, struct { + arg1 *livekit.UpdateParticipantMetadata + arg2 bool + }{arg1, arg2}) + stub := fake.UpdateMetadataStub + fakeReturns := fake.updateMetadataReturns + fake.recordInvocation("UpdateMetadata", []interface{}{arg1, arg2}) + fake.updateMetadataMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) UpdateMetadataCallCount() int { + fake.updateMetadataMutex.RLock() + defer fake.updateMetadataMutex.RUnlock() + return len(fake.updateMetadataArgsForCall) +} + +func (fake *FakeLocalParticipant) UpdateMetadataCalls(stub func(*livekit.UpdateParticipantMetadata, bool) error) { + fake.updateMetadataMutex.Lock() + defer fake.updateMetadataMutex.Unlock() + fake.UpdateMetadataStub = stub +} + +func (fake *FakeLocalParticipant) UpdateMetadataArgsForCall(i int) (*livekit.UpdateParticipantMetadata, bool) { + fake.updateMetadataMutex.RLock() + defer fake.updateMetadataMutex.RUnlock() + argsForCall := fake.updateMetadataArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeLocalParticipant) UpdateMetadataReturns(result1 error) { + fake.updateMetadataMutex.Lock() + defer fake.updateMetadataMutex.Unlock() + fake.UpdateMetadataStub = nil + fake.updateMetadataReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipant) UpdateMetadataReturnsOnCall(i int, result1 error) { + fake.updateMetadataMutex.Lock() + defer fake.updateMetadataMutex.Unlock() + fake.UpdateMetadataStub = nil + if fake.updateMetadataReturnsOnCall == nil { + fake.updateMetadataReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.updateMetadataReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipant) UpdateSignalingRTT(arg1 uint32) { + fake.updateSignalingRTTMutex.Lock() + fake.updateSignalingRTTArgsForCall = append(fake.updateSignalingRTTArgsForCall, struct { + arg1 uint32 + }{arg1}) + stub := fake.UpdateSignalingRTTStub + fake.recordInvocation("UpdateSignalingRTT", []interface{}{arg1}) + fake.updateSignalingRTTMutex.Unlock() + if stub != nil { + fake.UpdateSignalingRTTStub(arg1) + } +} + +func (fake *FakeLocalParticipant) UpdateSignalingRTTCallCount() int { + fake.updateSignalingRTTMutex.RLock() + defer fake.updateSignalingRTTMutex.RUnlock() + return len(fake.updateSignalingRTTArgsForCall) +} + +func (fake *FakeLocalParticipant) UpdateSignalingRTTCalls(stub func(uint32)) { + fake.updateSignalingRTTMutex.Lock() + defer fake.updateSignalingRTTMutex.Unlock() + fake.UpdateSignalingRTTStub = stub +} + +func (fake *FakeLocalParticipant) UpdateSignalingRTTArgsForCall(i int) uint32 { + fake.updateSignalingRTTMutex.RLock() + defer fake.updateSignalingRTTMutex.RUnlock() + argsForCall := fake.updateSignalingRTTArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalParticipant) UpdateSubscribedAudioCodecs(arg1 livekit.NodeID, arg2 livekit.TrackID, arg3 []*livekit.SubscribedAudioCodec) error { + var arg3Copy []*livekit.SubscribedAudioCodec + if arg3 != nil { + arg3Copy = make([]*livekit.SubscribedAudioCodec, len(arg3)) + copy(arg3Copy, arg3) + } + fake.updateSubscribedAudioCodecsMutex.Lock() + ret, specificReturn := fake.updateSubscribedAudioCodecsReturnsOnCall[len(fake.updateSubscribedAudioCodecsArgsForCall)] + fake.updateSubscribedAudioCodecsArgsForCall = append(fake.updateSubscribedAudioCodecsArgsForCall, struct { + arg1 livekit.NodeID + arg2 livekit.TrackID + arg3 []*livekit.SubscribedAudioCodec + }{arg1, arg2, arg3Copy}) + stub := fake.UpdateSubscribedAudioCodecsStub + fakeReturns := fake.updateSubscribedAudioCodecsReturns + fake.recordInvocation("UpdateSubscribedAudioCodecs", []interface{}{arg1, arg2, arg3Copy}) + fake.updateSubscribedAudioCodecsMutex.Unlock() + if stub != nil { + return stub(arg1, arg2, arg3) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) UpdateSubscribedAudioCodecsCallCount() int { + fake.updateSubscribedAudioCodecsMutex.RLock() + defer fake.updateSubscribedAudioCodecsMutex.RUnlock() + return len(fake.updateSubscribedAudioCodecsArgsForCall) +} + +func (fake *FakeLocalParticipant) UpdateSubscribedAudioCodecsCalls(stub func(livekit.NodeID, livekit.TrackID, []*livekit.SubscribedAudioCodec) error) { + fake.updateSubscribedAudioCodecsMutex.Lock() + defer fake.updateSubscribedAudioCodecsMutex.Unlock() + fake.UpdateSubscribedAudioCodecsStub = stub +} + +func (fake *FakeLocalParticipant) UpdateSubscribedAudioCodecsArgsForCall(i int) (livekit.NodeID, livekit.TrackID, []*livekit.SubscribedAudioCodec) { + fake.updateSubscribedAudioCodecsMutex.RLock() + defer fake.updateSubscribedAudioCodecsMutex.RUnlock() + argsForCall := fake.updateSubscribedAudioCodecsArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeLocalParticipant) UpdateSubscribedAudioCodecsReturns(result1 error) { + fake.updateSubscribedAudioCodecsMutex.Lock() + defer fake.updateSubscribedAudioCodecsMutex.Unlock() + fake.UpdateSubscribedAudioCodecsStub = nil + fake.updateSubscribedAudioCodecsReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipant) UpdateSubscribedAudioCodecsReturnsOnCall(i int, result1 error) { + fake.updateSubscribedAudioCodecsMutex.Lock() + defer fake.updateSubscribedAudioCodecsMutex.Unlock() + fake.UpdateSubscribedAudioCodecsStub = nil + if fake.updateSubscribedAudioCodecsReturnsOnCall == nil { + fake.updateSubscribedAudioCodecsReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.updateSubscribedAudioCodecsReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipant) UpdateSubscribedQuality(arg1 livekit.NodeID, arg2 livekit.TrackID, arg3 []types.SubscribedCodecQuality) error { + var arg3Copy []types.SubscribedCodecQuality + if arg3 != nil { + arg3Copy = make([]types.SubscribedCodecQuality, len(arg3)) + copy(arg3Copy, arg3) + } + fake.updateSubscribedQualityMutex.Lock() + ret, specificReturn := fake.updateSubscribedQualityReturnsOnCall[len(fake.updateSubscribedQualityArgsForCall)] + fake.updateSubscribedQualityArgsForCall = append(fake.updateSubscribedQualityArgsForCall, struct { + arg1 livekit.NodeID + arg2 livekit.TrackID + arg3 []types.SubscribedCodecQuality + }{arg1, arg2, arg3Copy}) + stub := fake.UpdateSubscribedQualityStub + fakeReturns := fake.updateSubscribedQualityReturns + fake.recordInvocation("UpdateSubscribedQuality", []interface{}{arg1, arg2, arg3Copy}) + fake.updateSubscribedQualityMutex.Unlock() + if stub != nil { + return stub(arg1, arg2, arg3) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) UpdateSubscribedQualityCallCount() int { + fake.updateSubscribedQualityMutex.RLock() + defer fake.updateSubscribedQualityMutex.RUnlock() + return len(fake.updateSubscribedQualityArgsForCall) +} + +func (fake *FakeLocalParticipant) UpdateSubscribedQualityCalls(stub func(livekit.NodeID, livekit.TrackID, []types.SubscribedCodecQuality) error) { + fake.updateSubscribedQualityMutex.Lock() + defer fake.updateSubscribedQualityMutex.Unlock() + fake.UpdateSubscribedQualityStub = stub +} + +func (fake *FakeLocalParticipant) UpdateSubscribedQualityArgsForCall(i int) (livekit.NodeID, livekit.TrackID, []types.SubscribedCodecQuality) { + fake.updateSubscribedQualityMutex.RLock() + defer fake.updateSubscribedQualityMutex.RUnlock() + argsForCall := fake.updateSubscribedQualityArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeLocalParticipant) UpdateSubscribedQualityReturns(result1 error) { + fake.updateSubscribedQualityMutex.Lock() + defer fake.updateSubscribedQualityMutex.Unlock() + fake.UpdateSubscribedQualityStub = nil + fake.updateSubscribedQualityReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipant) UpdateSubscribedQualityReturnsOnCall(i int, result1 error) { + fake.updateSubscribedQualityMutex.Lock() + defer fake.updateSubscribedQualityMutex.Unlock() + fake.UpdateSubscribedQualityStub = nil + if fake.updateSubscribedQualityReturnsOnCall == nil { + fake.updateSubscribedQualityReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.updateSubscribedQualityReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipant) UpdateSubscribedTrackSettings(arg1 livekit.TrackID, arg2 *livekit.UpdateTrackSettings) { + fake.updateSubscribedTrackSettingsMutex.Lock() + fake.updateSubscribedTrackSettingsArgsForCall = append(fake.updateSubscribedTrackSettingsArgsForCall, struct { + arg1 livekit.TrackID + arg2 *livekit.UpdateTrackSettings + }{arg1, arg2}) + stub := fake.UpdateSubscribedTrackSettingsStub + fake.recordInvocation("UpdateSubscribedTrackSettings", []interface{}{arg1, arg2}) + fake.updateSubscribedTrackSettingsMutex.Unlock() + if stub != nil { + fake.UpdateSubscribedTrackSettingsStub(arg1, arg2) + } +} + +func (fake *FakeLocalParticipant) UpdateSubscribedTrackSettingsCallCount() int { + fake.updateSubscribedTrackSettingsMutex.RLock() + defer fake.updateSubscribedTrackSettingsMutex.RUnlock() + return len(fake.updateSubscribedTrackSettingsArgsForCall) +} + +func (fake *FakeLocalParticipant) UpdateSubscribedTrackSettingsCalls(stub func(livekit.TrackID, *livekit.UpdateTrackSettings)) { + fake.updateSubscribedTrackSettingsMutex.Lock() + defer fake.updateSubscribedTrackSettingsMutex.Unlock() + fake.UpdateSubscribedTrackSettingsStub = stub +} + +func (fake *FakeLocalParticipant) UpdateSubscribedTrackSettingsArgsForCall(i int) (livekit.TrackID, *livekit.UpdateTrackSettings) { + fake.updateSubscribedTrackSettingsMutex.RLock() + defer fake.updateSubscribedTrackSettingsMutex.RUnlock() + argsForCall := fake.updateSubscribedTrackSettingsArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeLocalParticipant) UpdateSubscriptionPermission(arg1 *livekit.SubscriptionPermission, arg2 utils.TimedVersion, arg3 func(participantID livekit.ParticipantID) types.LocalParticipant) error { + fake.updateSubscriptionPermissionMutex.Lock() + ret, specificReturn := fake.updateSubscriptionPermissionReturnsOnCall[len(fake.updateSubscriptionPermissionArgsForCall)] + fake.updateSubscriptionPermissionArgsForCall = append(fake.updateSubscriptionPermissionArgsForCall, struct { + arg1 *livekit.SubscriptionPermission + arg2 utils.TimedVersion + arg3 func(participantID livekit.ParticipantID) types.LocalParticipant + }{arg1, arg2, arg3}) + stub := fake.UpdateSubscriptionPermissionStub + fakeReturns := fake.updateSubscriptionPermissionReturns + fake.recordInvocation("UpdateSubscriptionPermission", []interface{}{arg1, arg2, arg3}) + fake.updateSubscriptionPermissionMutex.Unlock() + if stub != nil { + return stub(arg1, arg2, arg3) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) UpdateSubscriptionPermissionCallCount() int { + fake.updateSubscriptionPermissionMutex.RLock() + defer fake.updateSubscriptionPermissionMutex.RUnlock() + return len(fake.updateSubscriptionPermissionArgsForCall) +} + +func (fake *FakeLocalParticipant) UpdateSubscriptionPermissionCalls(stub func(*livekit.SubscriptionPermission, utils.TimedVersion, func(participantID livekit.ParticipantID) types.LocalParticipant) error) { + fake.updateSubscriptionPermissionMutex.Lock() + defer fake.updateSubscriptionPermissionMutex.Unlock() + fake.UpdateSubscriptionPermissionStub = stub +} + +func (fake *FakeLocalParticipant) UpdateSubscriptionPermissionArgsForCall(i int) (*livekit.SubscriptionPermission, utils.TimedVersion, func(participantID livekit.ParticipantID) types.LocalParticipant) { + fake.updateSubscriptionPermissionMutex.RLock() + defer fake.updateSubscriptionPermissionMutex.RUnlock() + argsForCall := fake.updateSubscriptionPermissionArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeLocalParticipant) UpdateSubscriptionPermissionReturns(result1 error) { + fake.updateSubscriptionPermissionMutex.Lock() + defer fake.updateSubscriptionPermissionMutex.Unlock() + fake.UpdateSubscriptionPermissionStub = nil + fake.updateSubscriptionPermissionReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipant) UpdateSubscriptionPermissionReturnsOnCall(i int, result1 error) { + fake.updateSubscriptionPermissionMutex.Lock() + defer fake.updateSubscriptionPermissionMutex.Unlock() + fake.UpdateSubscriptionPermissionStub = nil + if fake.updateSubscriptionPermissionReturnsOnCall == nil { + fake.updateSubscriptionPermissionReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.updateSubscriptionPermissionReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipant) UpdateVideoTrack(arg1 *livekit.UpdateLocalVideoTrack) error { + fake.updateVideoTrackMutex.Lock() + ret, specificReturn := fake.updateVideoTrackReturnsOnCall[len(fake.updateVideoTrackArgsForCall)] + fake.updateVideoTrackArgsForCall = append(fake.updateVideoTrackArgsForCall, struct { + arg1 *livekit.UpdateLocalVideoTrack + }{arg1}) + stub := fake.UpdateVideoTrackStub + fakeReturns := fake.updateVideoTrackReturns + fake.recordInvocation("UpdateVideoTrack", []interface{}{arg1}) + fake.updateVideoTrackMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) UpdateVideoTrackCallCount() int { + fake.updateVideoTrackMutex.RLock() + defer fake.updateVideoTrackMutex.RUnlock() + return len(fake.updateVideoTrackArgsForCall) +} + +func (fake *FakeLocalParticipant) UpdateVideoTrackCalls(stub func(*livekit.UpdateLocalVideoTrack) error) { + fake.updateVideoTrackMutex.Lock() + defer fake.updateVideoTrackMutex.Unlock() + fake.UpdateVideoTrackStub = stub +} + +func (fake *FakeLocalParticipant) UpdateVideoTrackArgsForCall(i int) *livekit.UpdateLocalVideoTrack { + fake.updateVideoTrackMutex.RLock() + defer fake.updateVideoTrackMutex.RUnlock() + argsForCall := fake.updateVideoTrackArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalParticipant) UpdateVideoTrackReturns(result1 error) { + fake.updateVideoTrackMutex.Lock() + defer fake.updateVideoTrackMutex.Unlock() + fake.UpdateVideoTrackStub = nil + fake.updateVideoTrackReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipant) UpdateVideoTrackReturnsOnCall(i int, result1 error) { + fake.updateVideoTrackMutex.Lock() + defer fake.updateVideoTrackMutex.Unlock() + fake.UpdateVideoTrackStub = nil + if fake.updateVideoTrackReturnsOnCall == nil { + fake.updateVideoTrackReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.updateVideoTrackReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipant) Verify() bool { + fake.verifyMutex.Lock() + ret, specificReturn := fake.verifyReturnsOnCall[len(fake.verifyArgsForCall)] + fake.verifyArgsForCall = append(fake.verifyArgsForCall, struct { + }{}) + stub := fake.VerifyStub + fakeReturns := fake.verifyReturns + fake.recordInvocation("Verify", []interface{}{}) + fake.verifyMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) VerifyCallCount() int { + fake.verifyMutex.RLock() + defer fake.verifyMutex.RUnlock() + return len(fake.verifyArgsForCall) +} + +func (fake *FakeLocalParticipant) VerifyCalls(stub func() bool) { + fake.verifyMutex.Lock() + defer fake.verifyMutex.Unlock() + fake.VerifyStub = stub +} + +func (fake *FakeLocalParticipant) VerifyReturns(result1 bool) { + fake.verifyMutex.Lock() + defer fake.verifyMutex.Unlock() + fake.VerifyStub = nil + fake.verifyReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalParticipant) VerifyReturnsOnCall(i int, result1 bool) { + fake.verifyMutex.Lock() + defer fake.verifyMutex.Unlock() + fake.VerifyStub = nil + if fake.verifyReturnsOnCall == nil { + fake.verifyReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.verifyReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalParticipant) VerifySubscribeParticipantInfo(arg1 livekit.ParticipantID, arg2 uint32) { + fake.verifySubscribeParticipantInfoMutex.Lock() + fake.verifySubscribeParticipantInfoArgsForCall = append(fake.verifySubscribeParticipantInfoArgsForCall, struct { + arg1 livekit.ParticipantID + arg2 uint32 + }{arg1, arg2}) + stub := fake.VerifySubscribeParticipantInfoStub + fake.recordInvocation("VerifySubscribeParticipantInfo", []interface{}{arg1, arg2}) + fake.verifySubscribeParticipantInfoMutex.Unlock() + if stub != nil { + fake.VerifySubscribeParticipantInfoStub(arg1, arg2) + } +} + +func (fake *FakeLocalParticipant) VerifySubscribeParticipantInfoCallCount() int { + fake.verifySubscribeParticipantInfoMutex.RLock() + defer fake.verifySubscribeParticipantInfoMutex.RUnlock() + return len(fake.verifySubscribeParticipantInfoArgsForCall) +} + +func (fake *FakeLocalParticipant) VerifySubscribeParticipantInfoCalls(stub func(livekit.ParticipantID, uint32)) { + fake.verifySubscribeParticipantInfoMutex.Lock() + defer fake.verifySubscribeParticipantInfoMutex.Unlock() + fake.VerifySubscribeParticipantInfoStub = stub +} + +func (fake *FakeLocalParticipant) VerifySubscribeParticipantInfoArgsForCall(i int) (livekit.ParticipantID, uint32) { + fake.verifySubscribeParticipantInfoMutex.RLock() + defer fake.verifySubscribeParticipantInfoMutex.RUnlock() + argsForCall := fake.verifySubscribeParticipantInfoArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeLocalParticipant) Version() utils.TimedVersion { + fake.versionMutex.Lock() + ret, specificReturn := fake.versionReturnsOnCall[len(fake.versionArgsForCall)] + fake.versionArgsForCall = append(fake.versionArgsForCall, struct { + }{}) + stub := fake.VersionStub + fakeReturns := fake.versionReturns + fake.recordInvocation("Version", []interface{}{}) + fake.versionMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) VersionCallCount() int { + fake.versionMutex.RLock() + defer fake.versionMutex.RUnlock() + return len(fake.versionArgsForCall) +} + +func (fake *FakeLocalParticipant) VersionCalls(stub func() utils.TimedVersion) { + fake.versionMutex.Lock() + defer fake.versionMutex.Unlock() + fake.VersionStub = stub +} + +func (fake *FakeLocalParticipant) VersionReturns(result1 utils.TimedVersion) { + fake.versionMutex.Lock() + defer fake.versionMutex.Unlock() + fake.VersionStub = nil + fake.versionReturns = struct { + result1 utils.TimedVersion + }{result1} +} + +func (fake *FakeLocalParticipant) VersionReturnsOnCall(i int, result1 utils.TimedVersion) { + fake.versionMutex.Lock() + defer fake.versionMutex.Unlock() + fake.VersionStub = nil + if fake.versionReturnsOnCall == nil { + fake.versionReturnsOnCall = make(map[int]struct { + result1 utils.TimedVersion + }) + } + fake.versionReturnsOnCall[i] = struct { + result1 utils.TimedVersion + }{result1} +} + +func (fake *FakeLocalParticipant) WaitUntilSubscribed(arg1 time.Duration) error { + fake.waitUntilSubscribedMutex.Lock() + ret, specificReturn := fake.waitUntilSubscribedReturnsOnCall[len(fake.waitUntilSubscribedArgsForCall)] + fake.waitUntilSubscribedArgsForCall = append(fake.waitUntilSubscribedArgsForCall, struct { + arg1 time.Duration + }{arg1}) + stub := fake.WaitUntilSubscribedStub + fakeReturns := fake.waitUntilSubscribedReturns + fake.recordInvocation("WaitUntilSubscribed", []interface{}{arg1}) + fake.waitUntilSubscribedMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) WaitUntilSubscribedCallCount() int { + fake.waitUntilSubscribedMutex.RLock() + defer fake.waitUntilSubscribedMutex.RUnlock() + return len(fake.waitUntilSubscribedArgsForCall) +} + +func (fake *FakeLocalParticipant) WaitUntilSubscribedCalls(stub func(time.Duration) error) { + fake.waitUntilSubscribedMutex.Lock() + defer fake.waitUntilSubscribedMutex.Unlock() + fake.WaitUntilSubscribedStub = stub +} + +func (fake *FakeLocalParticipant) WaitUntilSubscribedArgsForCall(i int) time.Duration { + fake.waitUntilSubscribedMutex.RLock() + defer fake.waitUntilSubscribedMutex.RUnlock() + argsForCall := fake.waitUntilSubscribedArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalParticipant) WaitUntilSubscribedReturns(result1 error) { + fake.waitUntilSubscribedMutex.Lock() + defer fake.waitUntilSubscribedMutex.Unlock() + fake.WaitUntilSubscribedStub = nil + fake.waitUntilSubscribedReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipant) WaitUntilSubscribedReturnsOnCall(i int, result1 error) { + fake.waitUntilSubscribedMutex.Lock() + defer fake.waitUntilSubscribedMutex.Unlock() + fake.WaitUntilSubscribedStub = nil + if fake.waitUntilSubscribedReturnsOnCall == nil { + fake.waitUntilSubscribedReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.waitUntilSubscribedReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipant) WriteSubscriberRTCP(arg1 []rtcp.Packet) error { + var arg1Copy []rtcp.Packet + if arg1 != nil { + arg1Copy = make([]rtcp.Packet, len(arg1)) + copy(arg1Copy, arg1) + } + fake.writeSubscriberRTCPMutex.Lock() + ret, specificReturn := fake.writeSubscriberRTCPReturnsOnCall[len(fake.writeSubscriberRTCPArgsForCall)] + fake.writeSubscriberRTCPArgsForCall = append(fake.writeSubscriberRTCPArgsForCall, struct { + arg1 []rtcp.Packet + }{arg1Copy}) + stub := fake.WriteSubscriberRTCPStub + fakeReturns := fake.writeSubscriberRTCPReturns + fake.recordInvocation("WriteSubscriberRTCP", []interface{}{arg1Copy}) + fake.writeSubscriberRTCPMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipant) WriteSubscriberRTCPCallCount() int { + fake.writeSubscriberRTCPMutex.RLock() + defer fake.writeSubscriberRTCPMutex.RUnlock() + return len(fake.writeSubscriberRTCPArgsForCall) +} + +func (fake *FakeLocalParticipant) WriteSubscriberRTCPCalls(stub func([]rtcp.Packet) error) { + fake.writeSubscriberRTCPMutex.Lock() + defer fake.writeSubscriberRTCPMutex.Unlock() + fake.WriteSubscriberRTCPStub = stub +} + +func (fake *FakeLocalParticipant) WriteSubscriberRTCPArgsForCall(i int) []rtcp.Packet { + fake.writeSubscriberRTCPMutex.RLock() + defer fake.writeSubscriberRTCPMutex.RUnlock() + argsForCall := fake.writeSubscriberRTCPArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalParticipant) WriteSubscriberRTCPReturns(result1 error) { + fake.writeSubscriberRTCPMutex.Lock() + defer fake.writeSubscriberRTCPMutex.Unlock() + fake.WriteSubscriberRTCPStub = nil + fake.writeSubscriberRTCPReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipant) WriteSubscriberRTCPReturnsOnCall(i int, result1 error) { + fake.writeSubscriberRTCPMutex.Lock() + defer fake.writeSubscriberRTCPMutex.Unlock() + fake.WriteSubscriberRTCPStub = nil + if fake.writeSubscriberRTCPReturnsOnCall == nil { + fake.writeSubscriberRTCPReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.writeSubscriberRTCPReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipant) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeLocalParticipant) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ types.LocalParticipant = new(FakeLocalParticipant) diff --git a/livekit/pkg/rtc/types/typesfakes/fake_local_participant_helper.go b/livekit/pkg/rtc/types/typesfakes/fake_local_participant_helper.go new file mode 100644 index 0000000..4676dc6 --- /dev/null +++ b/livekit/pkg/rtc/types/typesfakes/fake_local_participant_helper.go @@ -0,0 +1,542 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package typesfakes + +import ( + "sync" + + "github.com/livekit/livekit-server/pkg/rtc/types" + "github.com/livekit/protocol/livekit" +) + +type FakeLocalParticipantHelper struct { + GetCachedReliableDataMessageStub func(map[livekit.ParticipantID]uint32) []*types.DataMessageCache + getCachedReliableDataMessageMutex sync.RWMutex + getCachedReliableDataMessageArgsForCall []struct { + arg1 map[livekit.ParticipantID]uint32 + } + getCachedReliableDataMessageReturns struct { + result1 []*types.DataMessageCache + } + getCachedReliableDataMessageReturnsOnCall map[int]struct { + result1 []*types.DataMessageCache + } + GetParticipantInfoStub func(livekit.ParticipantID) *livekit.ParticipantInfo + getParticipantInfoMutex sync.RWMutex + getParticipantInfoArgsForCall []struct { + arg1 livekit.ParticipantID + } + getParticipantInfoReturns struct { + result1 *livekit.ParticipantInfo + } + getParticipantInfoReturnsOnCall map[int]struct { + result1 *livekit.ParticipantInfo + } + GetRegionSettingsStub func(string) *livekit.RegionSettings + getRegionSettingsMutex sync.RWMutex + getRegionSettingsArgsForCall []struct { + arg1 string + } + getRegionSettingsReturns struct { + result1 *livekit.RegionSettings + } + getRegionSettingsReturnsOnCall map[int]struct { + result1 *livekit.RegionSettings + } + GetSubscriberForwarderStateStub func(types.LocalParticipant) (map[livekit.TrackID]*livekit.RTPForwarderState, error) + getSubscriberForwarderStateMutex sync.RWMutex + getSubscriberForwarderStateArgsForCall []struct { + arg1 types.LocalParticipant + } + getSubscriberForwarderStateReturns struct { + result1 map[livekit.TrackID]*livekit.RTPForwarderState + result2 error + } + getSubscriberForwarderStateReturnsOnCall map[int]struct { + result1 map[livekit.TrackID]*livekit.RTPForwarderState + result2 error + } + ResolveDataTrackStub func(types.LocalParticipant, livekit.TrackID) types.DataResolverResult + resolveDataTrackMutex sync.RWMutex + resolveDataTrackArgsForCall []struct { + arg1 types.LocalParticipant + arg2 livekit.TrackID + } + resolveDataTrackReturns struct { + result1 types.DataResolverResult + } + resolveDataTrackReturnsOnCall map[int]struct { + result1 types.DataResolverResult + } + ResolveMediaTrackStub func(types.LocalParticipant, livekit.TrackID) types.MediaResolverResult + resolveMediaTrackMutex sync.RWMutex + resolveMediaTrackArgsForCall []struct { + arg1 types.LocalParticipant + arg2 livekit.TrackID + } + resolveMediaTrackReturns struct { + result1 types.MediaResolverResult + } + resolveMediaTrackReturnsOnCall map[int]struct { + result1 types.MediaResolverResult + } + ShouldRegressCodecStub func() bool + shouldRegressCodecMutex sync.RWMutex + shouldRegressCodecArgsForCall []struct { + } + shouldRegressCodecReturns struct { + result1 bool + } + shouldRegressCodecReturnsOnCall map[int]struct { + result1 bool + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeLocalParticipantHelper) GetCachedReliableDataMessage(arg1 map[livekit.ParticipantID]uint32) []*types.DataMessageCache { + fake.getCachedReliableDataMessageMutex.Lock() + ret, specificReturn := fake.getCachedReliableDataMessageReturnsOnCall[len(fake.getCachedReliableDataMessageArgsForCall)] + fake.getCachedReliableDataMessageArgsForCall = append(fake.getCachedReliableDataMessageArgsForCall, struct { + arg1 map[livekit.ParticipantID]uint32 + }{arg1}) + stub := fake.GetCachedReliableDataMessageStub + fakeReturns := fake.getCachedReliableDataMessageReturns + fake.recordInvocation("GetCachedReliableDataMessage", []interface{}{arg1}) + fake.getCachedReliableDataMessageMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipantHelper) GetCachedReliableDataMessageCallCount() int { + fake.getCachedReliableDataMessageMutex.RLock() + defer fake.getCachedReliableDataMessageMutex.RUnlock() + return len(fake.getCachedReliableDataMessageArgsForCall) +} + +func (fake *FakeLocalParticipantHelper) GetCachedReliableDataMessageCalls(stub func(map[livekit.ParticipantID]uint32) []*types.DataMessageCache) { + fake.getCachedReliableDataMessageMutex.Lock() + defer fake.getCachedReliableDataMessageMutex.Unlock() + fake.GetCachedReliableDataMessageStub = stub +} + +func (fake *FakeLocalParticipantHelper) GetCachedReliableDataMessageArgsForCall(i int) map[livekit.ParticipantID]uint32 { + fake.getCachedReliableDataMessageMutex.RLock() + defer fake.getCachedReliableDataMessageMutex.RUnlock() + argsForCall := fake.getCachedReliableDataMessageArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalParticipantHelper) GetCachedReliableDataMessageReturns(result1 []*types.DataMessageCache) { + fake.getCachedReliableDataMessageMutex.Lock() + defer fake.getCachedReliableDataMessageMutex.Unlock() + fake.GetCachedReliableDataMessageStub = nil + fake.getCachedReliableDataMessageReturns = struct { + result1 []*types.DataMessageCache + }{result1} +} + +func (fake *FakeLocalParticipantHelper) GetCachedReliableDataMessageReturnsOnCall(i int, result1 []*types.DataMessageCache) { + fake.getCachedReliableDataMessageMutex.Lock() + defer fake.getCachedReliableDataMessageMutex.Unlock() + fake.GetCachedReliableDataMessageStub = nil + if fake.getCachedReliableDataMessageReturnsOnCall == nil { + fake.getCachedReliableDataMessageReturnsOnCall = make(map[int]struct { + result1 []*types.DataMessageCache + }) + } + fake.getCachedReliableDataMessageReturnsOnCall[i] = struct { + result1 []*types.DataMessageCache + }{result1} +} + +func (fake *FakeLocalParticipantHelper) GetParticipantInfo(arg1 livekit.ParticipantID) *livekit.ParticipantInfo { + fake.getParticipantInfoMutex.Lock() + ret, specificReturn := fake.getParticipantInfoReturnsOnCall[len(fake.getParticipantInfoArgsForCall)] + fake.getParticipantInfoArgsForCall = append(fake.getParticipantInfoArgsForCall, struct { + arg1 livekit.ParticipantID + }{arg1}) + stub := fake.GetParticipantInfoStub + fakeReturns := fake.getParticipantInfoReturns + fake.recordInvocation("GetParticipantInfo", []interface{}{arg1}) + fake.getParticipantInfoMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipantHelper) GetParticipantInfoCallCount() int { + fake.getParticipantInfoMutex.RLock() + defer fake.getParticipantInfoMutex.RUnlock() + return len(fake.getParticipantInfoArgsForCall) +} + +func (fake *FakeLocalParticipantHelper) GetParticipantInfoCalls(stub func(livekit.ParticipantID) *livekit.ParticipantInfo) { + fake.getParticipantInfoMutex.Lock() + defer fake.getParticipantInfoMutex.Unlock() + fake.GetParticipantInfoStub = stub +} + +func (fake *FakeLocalParticipantHelper) GetParticipantInfoArgsForCall(i int) livekit.ParticipantID { + fake.getParticipantInfoMutex.RLock() + defer fake.getParticipantInfoMutex.RUnlock() + argsForCall := fake.getParticipantInfoArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalParticipantHelper) GetParticipantInfoReturns(result1 *livekit.ParticipantInfo) { + fake.getParticipantInfoMutex.Lock() + defer fake.getParticipantInfoMutex.Unlock() + fake.GetParticipantInfoStub = nil + fake.getParticipantInfoReturns = struct { + result1 *livekit.ParticipantInfo + }{result1} +} + +func (fake *FakeLocalParticipantHelper) GetParticipantInfoReturnsOnCall(i int, result1 *livekit.ParticipantInfo) { + fake.getParticipantInfoMutex.Lock() + defer fake.getParticipantInfoMutex.Unlock() + fake.GetParticipantInfoStub = nil + if fake.getParticipantInfoReturnsOnCall == nil { + fake.getParticipantInfoReturnsOnCall = make(map[int]struct { + result1 *livekit.ParticipantInfo + }) + } + fake.getParticipantInfoReturnsOnCall[i] = struct { + result1 *livekit.ParticipantInfo + }{result1} +} + +func (fake *FakeLocalParticipantHelper) GetRegionSettings(arg1 string) *livekit.RegionSettings { + fake.getRegionSettingsMutex.Lock() + ret, specificReturn := fake.getRegionSettingsReturnsOnCall[len(fake.getRegionSettingsArgsForCall)] + fake.getRegionSettingsArgsForCall = append(fake.getRegionSettingsArgsForCall, struct { + arg1 string + }{arg1}) + stub := fake.GetRegionSettingsStub + fakeReturns := fake.getRegionSettingsReturns + fake.recordInvocation("GetRegionSettings", []interface{}{arg1}) + fake.getRegionSettingsMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipantHelper) GetRegionSettingsCallCount() int { + fake.getRegionSettingsMutex.RLock() + defer fake.getRegionSettingsMutex.RUnlock() + return len(fake.getRegionSettingsArgsForCall) +} + +func (fake *FakeLocalParticipantHelper) GetRegionSettingsCalls(stub func(string) *livekit.RegionSettings) { + fake.getRegionSettingsMutex.Lock() + defer fake.getRegionSettingsMutex.Unlock() + fake.GetRegionSettingsStub = stub +} + +func (fake *FakeLocalParticipantHelper) GetRegionSettingsArgsForCall(i int) string { + fake.getRegionSettingsMutex.RLock() + defer fake.getRegionSettingsMutex.RUnlock() + argsForCall := fake.getRegionSettingsArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalParticipantHelper) GetRegionSettingsReturns(result1 *livekit.RegionSettings) { + fake.getRegionSettingsMutex.Lock() + defer fake.getRegionSettingsMutex.Unlock() + fake.GetRegionSettingsStub = nil + fake.getRegionSettingsReturns = struct { + result1 *livekit.RegionSettings + }{result1} +} + +func (fake *FakeLocalParticipantHelper) GetRegionSettingsReturnsOnCall(i int, result1 *livekit.RegionSettings) { + fake.getRegionSettingsMutex.Lock() + defer fake.getRegionSettingsMutex.Unlock() + fake.GetRegionSettingsStub = nil + if fake.getRegionSettingsReturnsOnCall == nil { + fake.getRegionSettingsReturnsOnCall = make(map[int]struct { + result1 *livekit.RegionSettings + }) + } + fake.getRegionSettingsReturnsOnCall[i] = struct { + result1 *livekit.RegionSettings + }{result1} +} + +func (fake *FakeLocalParticipantHelper) GetSubscriberForwarderState(arg1 types.LocalParticipant) (map[livekit.TrackID]*livekit.RTPForwarderState, error) { + fake.getSubscriberForwarderStateMutex.Lock() + ret, specificReturn := fake.getSubscriberForwarderStateReturnsOnCall[len(fake.getSubscriberForwarderStateArgsForCall)] + fake.getSubscriberForwarderStateArgsForCall = append(fake.getSubscriberForwarderStateArgsForCall, struct { + arg1 types.LocalParticipant + }{arg1}) + stub := fake.GetSubscriberForwarderStateStub + fakeReturns := fake.getSubscriberForwarderStateReturns + fake.recordInvocation("GetSubscriberForwarderState", []interface{}{arg1}) + fake.getSubscriberForwarderStateMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeLocalParticipantHelper) GetSubscriberForwarderStateCallCount() int { + fake.getSubscriberForwarderStateMutex.RLock() + defer fake.getSubscriberForwarderStateMutex.RUnlock() + return len(fake.getSubscriberForwarderStateArgsForCall) +} + +func (fake *FakeLocalParticipantHelper) GetSubscriberForwarderStateCalls(stub func(types.LocalParticipant) (map[livekit.TrackID]*livekit.RTPForwarderState, error)) { + fake.getSubscriberForwarderStateMutex.Lock() + defer fake.getSubscriberForwarderStateMutex.Unlock() + fake.GetSubscriberForwarderStateStub = stub +} + +func (fake *FakeLocalParticipantHelper) GetSubscriberForwarderStateArgsForCall(i int) types.LocalParticipant { + fake.getSubscriberForwarderStateMutex.RLock() + defer fake.getSubscriberForwarderStateMutex.RUnlock() + argsForCall := fake.getSubscriberForwarderStateArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalParticipantHelper) GetSubscriberForwarderStateReturns(result1 map[livekit.TrackID]*livekit.RTPForwarderState, result2 error) { + fake.getSubscriberForwarderStateMutex.Lock() + defer fake.getSubscriberForwarderStateMutex.Unlock() + fake.GetSubscriberForwarderStateStub = nil + fake.getSubscriberForwarderStateReturns = struct { + result1 map[livekit.TrackID]*livekit.RTPForwarderState + result2 error + }{result1, result2} +} + +func (fake *FakeLocalParticipantHelper) GetSubscriberForwarderStateReturnsOnCall(i int, result1 map[livekit.TrackID]*livekit.RTPForwarderState, result2 error) { + fake.getSubscriberForwarderStateMutex.Lock() + defer fake.getSubscriberForwarderStateMutex.Unlock() + fake.GetSubscriberForwarderStateStub = nil + if fake.getSubscriberForwarderStateReturnsOnCall == nil { + fake.getSubscriberForwarderStateReturnsOnCall = make(map[int]struct { + result1 map[livekit.TrackID]*livekit.RTPForwarderState + result2 error + }) + } + fake.getSubscriberForwarderStateReturnsOnCall[i] = struct { + result1 map[livekit.TrackID]*livekit.RTPForwarderState + result2 error + }{result1, result2} +} + +func (fake *FakeLocalParticipantHelper) ResolveDataTrack(arg1 types.LocalParticipant, arg2 livekit.TrackID) types.DataResolverResult { + fake.resolveDataTrackMutex.Lock() + ret, specificReturn := fake.resolveDataTrackReturnsOnCall[len(fake.resolveDataTrackArgsForCall)] + fake.resolveDataTrackArgsForCall = append(fake.resolveDataTrackArgsForCall, struct { + arg1 types.LocalParticipant + arg2 livekit.TrackID + }{arg1, arg2}) + stub := fake.ResolveDataTrackStub + fakeReturns := fake.resolveDataTrackReturns + fake.recordInvocation("ResolveDataTrack", []interface{}{arg1, arg2}) + fake.resolveDataTrackMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipantHelper) ResolveDataTrackCallCount() int { + fake.resolveDataTrackMutex.RLock() + defer fake.resolveDataTrackMutex.RUnlock() + return len(fake.resolveDataTrackArgsForCall) +} + +func (fake *FakeLocalParticipantHelper) ResolveDataTrackCalls(stub func(types.LocalParticipant, livekit.TrackID) types.DataResolverResult) { + fake.resolveDataTrackMutex.Lock() + defer fake.resolveDataTrackMutex.Unlock() + fake.ResolveDataTrackStub = stub +} + +func (fake *FakeLocalParticipantHelper) ResolveDataTrackArgsForCall(i int) (types.LocalParticipant, livekit.TrackID) { + fake.resolveDataTrackMutex.RLock() + defer fake.resolveDataTrackMutex.RUnlock() + argsForCall := fake.resolveDataTrackArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeLocalParticipantHelper) ResolveDataTrackReturns(result1 types.DataResolverResult) { + fake.resolveDataTrackMutex.Lock() + defer fake.resolveDataTrackMutex.Unlock() + fake.ResolveDataTrackStub = nil + fake.resolveDataTrackReturns = struct { + result1 types.DataResolverResult + }{result1} +} + +func (fake *FakeLocalParticipantHelper) ResolveDataTrackReturnsOnCall(i int, result1 types.DataResolverResult) { + fake.resolveDataTrackMutex.Lock() + defer fake.resolveDataTrackMutex.Unlock() + fake.ResolveDataTrackStub = nil + if fake.resolveDataTrackReturnsOnCall == nil { + fake.resolveDataTrackReturnsOnCall = make(map[int]struct { + result1 types.DataResolverResult + }) + } + fake.resolveDataTrackReturnsOnCall[i] = struct { + result1 types.DataResolverResult + }{result1} +} + +func (fake *FakeLocalParticipantHelper) ResolveMediaTrack(arg1 types.LocalParticipant, arg2 livekit.TrackID) types.MediaResolverResult { + fake.resolveMediaTrackMutex.Lock() + ret, specificReturn := fake.resolveMediaTrackReturnsOnCall[len(fake.resolveMediaTrackArgsForCall)] + fake.resolveMediaTrackArgsForCall = append(fake.resolveMediaTrackArgsForCall, struct { + arg1 types.LocalParticipant + arg2 livekit.TrackID + }{arg1, arg2}) + stub := fake.ResolveMediaTrackStub + fakeReturns := fake.resolveMediaTrackReturns + fake.recordInvocation("ResolveMediaTrack", []interface{}{arg1, arg2}) + fake.resolveMediaTrackMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipantHelper) ResolveMediaTrackCallCount() int { + fake.resolveMediaTrackMutex.RLock() + defer fake.resolveMediaTrackMutex.RUnlock() + return len(fake.resolveMediaTrackArgsForCall) +} + +func (fake *FakeLocalParticipantHelper) ResolveMediaTrackCalls(stub func(types.LocalParticipant, livekit.TrackID) types.MediaResolverResult) { + fake.resolveMediaTrackMutex.Lock() + defer fake.resolveMediaTrackMutex.Unlock() + fake.ResolveMediaTrackStub = stub +} + +func (fake *FakeLocalParticipantHelper) ResolveMediaTrackArgsForCall(i int) (types.LocalParticipant, livekit.TrackID) { + fake.resolveMediaTrackMutex.RLock() + defer fake.resolveMediaTrackMutex.RUnlock() + argsForCall := fake.resolveMediaTrackArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeLocalParticipantHelper) ResolveMediaTrackReturns(result1 types.MediaResolverResult) { + fake.resolveMediaTrackMutex.Lock() + defer fake.resolveMediaTrackMutex.Unlock() + fake.ResolveMediaTrackStub = nil + fake.resolveMediaTrackReturns = struct { + result1 types.MediaResolverResult + }{result1} +} + +func (fake *FakeLocalParticipantHelper) ResolveMediaTrackReturnsOnCall(i int, result1 types.MediaResolverResult) { + fake.resolveMediaTrackMutex.Lock() + defer fake.resolveMediaTrackMutex.Unlock() + fake.ResolveMediaTrackStub = nil + if fake.resolveMediaTrackReturnsOnCall == nil { + fake.resolveMediaTrackReturnsOnCall = make(map[int]struct { + result1 types.MediaResolverResult + }) + } + fake.resolveMediaTrackReturnsOnCall[i] = struct { + result1 types.MediaResolverResult + }{result1} +} + +func (fake *FakeLocalParticipantHelper) ShouldRegressCodec() bool { + fake.shouldRegressCodecMutex.Lock() + ret, specificReturn := fake.shouldRegressCodecReturnsOnCall[len(fake.shouldRegressCodecArgsForCall)] + fake.shouldRegressCodecArgsForCall = append(fake.shouldRegressCodecArgsForCall, struct { + }{}) + stub := fake.ShouldRegressCodecStub + fakeReturns := fake.shouldRegressCodecReturns + fake.recordInvocation("ShouldRegressCodec", []interface{}{}) + fake.shouldRegressCodecMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipantHelper) ShouldRegressCodecCallCount() int { + fake.shouldRegressCodecMutex.RLock() + defer fake.shouldRegressCodecMutex.RUnlock() + return len(fake.shouldRegressCodecArgsForCall) +} + +func (fake *FakeLocalParticipantHelper) ShouldRegressCodecCalls(stub func() bool) { + fake.shouldRegressCodecMutex.Lock() + defer fake.shouldRegressCodecMutex.Unlock() + fake.ShouldRegressCodecStub = stub +} + +func (fake *FakeLocalParticipantHelper) ShouldRegressCodecReturns(result1 bool) { + fake.shouldRegressCodecMutex.Lock() + defer fake.shouldRegressCodecMutex.Unlock() + fake.ShouldRegressCodecStub = nil + fake.shouldRegressCodecReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalParticipantHelper) ShouldRegressCodecReturnsOnCall(i int, result1 bool) { + fake.shouldRegressCodecMutex.Lock() + defer fake.shouldRegressCodecMutex.Unlock() + fake.ShouldRegressCodecStub = nil + if fake.shouldRegressCodecReturnsOnCall == nil { + fake.shouldRegressCodecReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.shouldRegressCodecReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + +func (fake *FakeLocalParticipantHelper) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeLocalParticipantHelper) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ types.LocalParticipantHelper = new(FakeLocalParticipantHelper) diff --git a/livekit/pkg/rtc/types/typesfakes/fake_local_participant_listener.go b/livekit/pkg/rtc/types/typesfakes/fake_local_participant_listener.go new file mode 100644 index 0000000..4cd0fef --- /dev/null +++ b/livekit/pkg/rtc/types/typesfakes/fake_local_participant_listener.go @@ -0,0 +1,948 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package typesfakes + +import ( + "sync" + + "github.com/livekit/livekit-server/pkg/rtc/datatrack" + "github.com/livekit/livekit-server/pkg/rtc/types" + "github.com/livekit/protocol/livekit" +) + +type FakeLocalParticipantListener struct { + OnDataMessageStub func(types.LocalParticipant, livekit.DataPacket_Kind, *livekit.DataPacket) + onDataMessageMutex sync.RWMutex + onDataMessageArgsForCall []struct { + arg1 types.LocalParticipant + arg2 livekit.DataPacket_Kind + arg3 *livekit.DataPacket + } + OnDataMessageUnlabeledStub func(types.LocalParticipant, []byte) + onDataMessageUnlabeledMutex sync.RWMutex + onDataMessageUnlabeledArgsForCall []struct { + arg1 types.LocalParticipant + arg2 []byte + } + OnDataTrackMessageStub func(types.Participant, []byte, *datatrack.Packet) + onDataTrackMessageMutex sync.RWMutex + onDataTrackMessageArgsForCall []struct { + arg1 types.Participant + arg2 []byte + arg3 *datatrack.Packet + } + OnDataTrackPublishedStub func(types.Participant, types.DataTrack) + onDataTrackPublishedMutex sync.RWMutex + onDataTrackPublishedArgsForCall []struct { + arg1 types.Participant + arg2 types.DataTrack + } + OnDataTrackUnpublishedStub func(types.Participant, types.DataTrack) + onDataTrackUnpublishedMutex sync.RWMutex + onDataTrackUnpublishedArgsForCall []struct { + arg1 types.Participant + arg2 types.DataTrack + } + OnLeaveStub func(types.LocalParticipant, types.ParticipantCloseReason) + onLeaveMutex sync.RWMutex + onLeaveArgsForCall []struct { + arg1 types.LocalParticipant + arg2 types.ParticipantCloseReason + } + OnMetricsStub func(types.Participant, *livekit.DataPacket) + onMetricsMutex sync.RWMutex + onMetricsArgsForCall []struct { + arg1 types.Participant + arg2 *livekit.DataPacket + } + OnMigrateStateChangeStub func(types.LocalParticipant, types.MigrateState) + onMigrateStateChangeMutex sync.RWMutex + onMigrateStateChangeArgsForCall []struct { + arg1 types.LocalParticipant + arg2 types.MigrateState + } + OnParticipantUpdateStub func(types.Participant) + onParticipantUpdateMutex sync.RWMutex + onParticipantUpdateArgsForCall []struct { + arg1 types.Participant + } + OnSimulateScenarioStub func(types.LocalParticipant, *livekit.SimulateScenario) error + onSimulateScenarioMutex sync.RWMutex + onSimulateScenarioArgsForCall []struct { + arg1 types.LocalParticipant + arg2 *livekit.SimulateScenario + } + onSimulateScenarioReturns struct { + result1 error + } + onSimulateScenarioReturnsOnCall map[int]struct { + result1 error + } + OnStateChangeStub func(types.LocalParticipant) + onStateChangeMutex sync.RWMutex + onStateChangeArgsForCall []struct { + arg1 types.LocalParticipant + } + OnSubscribeStatusChangedStub func(types.LocalParticipant, livekit.ParticipantID, bool) + onSubscribeStatusChangedMutex sync.RWMutex + onSubscribeStatusChangedArgsForCall []struct { + arg1 types.LocalParticipant + arg2 livekit.ParticipantID + arg3 bool + } + OnSubscriberReadyStub func(types.LocalParticipant) + onSubscriberReadyMutex sync.RWMutex + onSubscriberReadyArgsForCall []struct { + arg1 types.LocalParticipant + } + OnSyncStateStub func(types.LocalParticipant, *livekit.SyncState) error + onSyncStateMutex sync.RWMutex + onSyncStateArgsForCall []struct { + arg1 types.LocalParticipant + arg2 *livekit.SyncState + } + onSyncStateReturns struct { + result1 error + } + onSyncStateReturnsOnCall map[int]struct { + result1 error + } + OnTrackPublishedStub func(types.Participant, types.MediaTrack) + onTrackPublishedMutex sync.RWMutex + onTrackPublishedArgsForCall []struct { + arg1 types.Participant + arg2 types.MediaTrack + } + OnTrackUnpublishedStub func(types.Participant, types.MediaTrack) + onTrackUnpublishedMutex sync.RWMutex + onTrackUnpublishedArgsForCall []struct { + arg1 types.Participant + arg2 types.MediaTrack + } + OnTrackUpdatedStub func(types.Participant, types.MediaTrack) + onTrackUpdatedMutex sync.RWMutex + onTrackUpdatedArgsForCall []struct { + arg1 types.Participant + arg2 types.MediaTrack + } + OnUpdateDataSubscriptionsStub func(types.LocalParticipant, *livekit.UpdateDataSubscription) + onUpdateDataSubscriptionsMutex sync.RWMutex + onUpdateDataSubscriptionsArgsForCall []struct { + arg1 types.LocalParticipant + arg2 *livekit.UpdateDataSubscription + } + OnUpdateSubscriptionPermissionStub func(types.LocalParticipant, *livekit.SubscriptionPermission) error + onUpdateSubscriptionPermissionMutex sync.RWMutex + onUpdateSubscriptionPermissionArgsForCall []struct { + arg1 types.LocalParticipant + arg2 *livekit.SubscriptionPermission + } + onUpdateSubscriptionPermissionReturns struct { + result1 error + } + onUpdateSubscriptionPermissionReturnsOnCall map[int]struct { + result1 error + } + OnUpdateSubscriptionsStub func(types.LocalParticipant, []livekit.TrackID, []*livekit.ParticipantTracks, bool) + onUpdateSubscriptionsMutex sync.RWMutex + onUpdateSubscriptionsArgsForCall []struct { + arg1 types.LocalParticipant + arg2 []livekit.TrackID + arg3 []*livekit.ParticipantTracks + arg4 bool + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeLocalParticipantListener) OnDataMessage(arg1 types.LocalParticipant, arg2 livekit.DataPacket_Kind, arg3 *livekit.DataPacket) { + fake.onDataMessageMutex.Lock() + fake.onDataMessageArgsForCall = append(fake.onDataMessageArgsForCall, struct { + arg1 types.LocalParticipant + arg2 livekit.DataPacket_Kind + arg3 *livekit.DataPacket + }{arg1, arg2, arg3}) + stub := fake.OnDataMessageStub + fake.recordInvocation("OnDataMessage", []interface{}{arg1, arg2, arg3}) + fake.onDataMessageMutex.Unlock() + if stub != nil { + fake.OnDataMessageStub(arg1, arg2, arg3) + } +} + +func (fake *FakeLocalParticipantListener) OnDataMessageCallCount() int { + fake.onDataMessageMutex.RLock() + defer fake.onDataMessageMutex.RUnlock() + return len(fake.onDataMessageArgsForCall) +} + +func (fake *FakeLocalParticipantListener) OnDataMessageCalls(stub func(types.LocalParticipant, livekit.DataPacket_Kind, *livekit.DataPacket)) { + fake.onDataMessageMutex.Lock() + defer fake.onDataMessageMutex.Unlock() + fake.OnDataMessageStub = stub +} + +func (fake *FakeLocalParticipantListener) OnDataMessageArgsForCall(i int) (types.LocalParticipant, livekit.DataPacket_Kind, *livekit.DataPacket) { + fake.onDataMessageMutex.RLock() + defer fake.onDataMessageMutex.RUnlock() + argsForCall := fake.onDataMessageArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeLocalParticipantListener) OnDataMessageUnlabeled(arg1 types.LocalParticipant, arg2 []byte) { + var arg2Copy []byte + if arg2 != nil { + arg2Copy = make([]byte, len(arg2)) + copy(arg2Copy, arg2) + } + fake.onDataMessageUnlabeledMutex.Lock() + fake.onDataMessageUnlabeledArgsForCall = append(fake.onDataMessageUnlabeledArgsForCall, struct { + arg1 types.LocalParticipant + arg2 []byte + }{arg1, arg2Copy}) + stub := fake.OnDataMessageUnlabeledStub + fake.recordInvocation("OnDataMessageUnlabeled", []interface{}{arg1, arg2Copy}) + fake.onDataMessageUnlabeledMutex.Unlock() + if stub != nil { + fake.OnDataMessageUnlabeledStub(arg1, arg2) + } +} + +func (fake *FakeLocalParticipantListener) OnDataMessageUnlabeledCallCount() int { + fake.onDataMessageUnlabeledMutex.RLock() + defer fake.onDataMessageUnlabeledMutex.RUnlock() + return len(fake.onDataMessageUnlabeledArgsForCall) +} + +func (fake *FakeLocalParticipantListener) OnDataMessageUnlabeledCalls(stub func(types.LocalParticipant, []byte)) { + fake.onDataMessageUnlabeledMutex.Lock() + defer fake.onDataMessageUnlabeledMutex.Unlock() + fake.OnDataMessageUnlabeledStub = stub +} + +func (fake *FakeLocalParticipantListener) OnDataMessageUnlabeledArgsForCall(i int) (types.LocalParticipant, []byte) { + fake.onDataMessageUnlabeledMutex.RLock() + defer fake.onDataMessageUnlabeledMutex.RUnlock() + argsForCall := fake.onDataMessageUnlabeledArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeLocalParticipantListener) OnDataTrackMessage(arg1 types.Participant, arg2 []byte, arg3 *datatrack.Packet) { + var arg2Copy []byte + if arg2 != nil { + arg2Copy = make([]byte, len(arg2)) + copy(arg2Copy, arg2) + } + fake.onDataTrackMessageMutex.Lock() + fake.onDataTrackMessageArgsForCall = append(fake.onDataTrackMessageArgsForCall, struct { + arg1 types.Participant + arg2 []byte + arg3 *datatrack.Packet + }{arg1, arg2Copy, arg3}) + stub := fake.OnDataTrackMessageStub + fake.recordInvocation("OnDataTrackMessage", []interface{}{arg1, arg2Copy, arg3}) + fake.onDataTrackMessageMutex.Unlock() + if stub != nil { + fake.OnDataTrackMessageStub(arg1, arg2, arg3) + } +} + +func (fake *FakeLocalParticipantListener) OnDataTrackMessageCallCount() int { + fake.onDataTrackMessageMutex.RLock() + defer fake.onDataTrackMessageMutex.RUnlock() + return len(fake.onDataTrackMessageArgsForCall) +} + +func (fake *FakeLocalParticipantListener) OnDataTrackMessageCalls(stub func(types.Participant, []byte, *datatrack.Packet)) { + fake.onDataTrackMessageMutex.Lock() + defer fake.onDataTrackMessageMutex.Unlock() + fake.OnDataTrackMessageStub = stub +} + +func (fake *FakeLocalParticipantListener) OnDataTrackMessageArgsForCall(i int) (types.Participant, []byte, *datatrack.Packet) { + fake.onDataTrackMessageMutex.RLock() + defer fake.onDataTrackMessageMutex.RUnlock() + argsForCall := fake.onDataTrackMessageArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeLocalParticipantListener) OnDataTrackPublished(arg1 types.Participant, arg2 types.DataTrack) { + fake.onDataTrackPublishedMutex.Lock() + fake.onDataTrackPublishedArgsForCall = append(fake.onDataTrackPublishedArgsForCall, struct { + arg1 types.Participant + arg2 types.DataTrack + }{arg1, arg2}) + stub := fake.OnDataTrackPublishedStub + fake.recordInvocation("OnDataTrackPublished", []interface{}{arg1, arg2}) + fake.onDataTrackPublishedMutex.Unlock() + if stub != nil { + fake.OnDataTrackPublishedStub(arg1, arg2) + } +} + +func (fake *FakeLocalParticipantListener) OnDataTrackPublishedCallCount() int { + fake.onDataTrackPublishedMutex.RLock() + defer fake.onDataTrackPublishedMutex.RUnlock() + return len(fake.onDataTrackPublishedArgsForCall) +} + +func (fake *FakeLocalParticipantListener) OnDataTrackPublishedCalls(stub func(types.Participant, types.DataTrack)) { + fake.onDataTrackPublishedMutex.Lock() + defer fake.onDataTrackPublishedMutex.Unlock() + fake.OnDataTrackPublishedStub = stub +} + +func (fake *FakeLocalParticipantListener) OnDataTrackPublishedArgsForCall(i int) (types.Participant, types.DataTrack) { + fake.onDataTrackPublishedMutex.RLock() + defer fake.onDataTrackPublishedMutex.RUnlock() + argsForCall := fake.onDataTrackPublishedArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeLocalParticipantListener) OnDataTrackUnpublished(arg1 types.Participant, arg2 types.DataTrack) { + fake.onDataTrackUnpublishedMutex.Lock() + fake.onDataTrackUnpublishedArgsForCall = append(fake.onDataTrackUnpublishedArgsForCall, struct { + arg1 types.Participant + arg2 types.DataTrack + }{arg1, arg2}) + stub := fake.OnDataTrackUnpublishedStub + fake.recordInvocation("OnDataTrackUnpublished", []interface{}{arg1, arg2}) + fake.onDataTrackUnpublishedMutex.Unlock() + if stub != nil { + fake.OnDataTrackUnpublishedStub(arg1, arg2) + } +} + +func (fake *FakeLocalParticipantListener) OnDataTrackUnpublishedCallCount() int { + fake.onDataTrackUnpublishedMutex.RLock() + defer fake.onDataTrackUnpublishedMutex.RUnlock() + return len(fake.onDataTrackUnpublishedArgsForCall) +} + +func (fake *FakeLocalParticipantListener) OnDataTrackUnpublishedCalls(stub func(types.Participant, types.DataTrack)) { + fake.onDataTrackUnpublishedMutex.Lock() + defer fake.onDataTrackUnpublishedMutex.Unlock() + fake.OnDataTrackUnpublishedStub = stub +} + +func (fake *FakeLocalParticipantListener) OnDataTrackUnpublishedArgsForCall(i int) (types.Participant, types.DataTrack) { + fake.onDataTrackUnpublishedMutex.RLock() + defer fake.onDataTrackUnpublishedMutex.RUnlock() + argsForCall := fake.onDataTrackUnpublishedArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeLocalParticipantListener) OnLeave(arg1 types.LocalParticipant, arg2 types.ParticipantCloseReason) { + fake.onLeaveMutex.Lock() + fake.onLeaveArgsForCall = append(fake.onLeaveArgsForCall, struct { + arg1 types.LocalParticipant + arg2 types.ParticipantCloseReason + }{arg1, arg2}) + stub := fake.OnLeaveStub + fake.recordInvocation("OnLeave", []interface{}{arg1, arg2}) + fake.onLeaveMutex.Unlock() + if stub != nil { + fake.OnLeaveStub(arg1, arg2) + } +} + +func (fake *FakeLocalParticipantListener) OnLeaveCallCount() int { + fake.onLeaveMutex.RLock() + defer fake.onLeaveMutex.RUnlock() + return len(fake.onLeaveArgsForCall) +} + +func (fake *FakeLocalParticipantListener) OnLeaveCalls(stub func(types.LocalParticipant, types.ParticipantCloseReason)) { + fake.onLeaveMutex.Lock() + defer fake.onLeaveMutex.Unlock() + fake.OnLeaveStub = stub +} + +func (fake *FakeLocalParticipantListener) OnLeaveArgsForCall(i int) (types.LocalParticipant, types.ParticipantCloseReason) { + fake.onLeaveMutex.RLock() + defer fake.onLeaveMutex.RUnlock() + argsForCall := fake.onLeaveArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeLocalParticipantListener) OnMetrics(arg1 types.Participant, arg2 *livekit.DataPacket) { + fake.onMetricsMutex.Lock() + fake.onMetricsArgsForCall = append(fake.onMetricsArgsForCall, struct { + arg1 types.Participant + arg2 *livekit.DataPacket + }{arg1, arg2}) + stub := fake.OnMetricsStub + fake.recordInvocation("OnMetrics", []interface{}{arg1, arg2}) + fake.onMetricsMutex.Unlock() + if stub != nil { + fake.OnMetricsStub(arg1, arg2) + } +} + +func (fake *FakeLocalParticipantListener) OnMetricsCallCount() int { + fake.onMetricsMutex.RLock() + defer fake.onMetricsMutex.RUnlock() + return len(fake.onMetricsArgsForCall) +} + +func (fake *FakeLocalParticipantListener) OnMetricsCalls(stub func(types.Participant, *livekit.DataPacket)) { + fake.onMetricsMutex.Lock() + defer fake.onMetricsMutex.Unlock() + fake.OnMetricsStub = stub +} + +func (fake *FakeLocalParticipantListener) OnMetricsArgsForCall(i int) (types.Participant, *livekit.DataPacket) { + fake.onMetricsMutex.RLock() + defer fake.onMetricsMutex.RUnlock() + argsForCall := fake.onMetricsArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeLocalParticipantListener) OnMigrateStateChange(arg1 types.LocalParticipant, arg2 types.MigrateState) { + fake.onMigrateStateChangeMutex.Lock() + fake.onMigrateStateChangeArgsForCall = append(fake.onMigrateStateChangeArgsForCall, struct { + arg1 types.LocalParticipant + arg2 types.MigrateState + }{arg1, arg2}) + stub := fake.OnMigrateStateChangeStub + fake.recordInvocation("OnMigrateStateChange", []interface{}{arg1, arg2}) + fake.onMigrateStateChangeMutex.Unlock() + if stub != nil { + fake.OnMigrateStateChangeStub(arg1, arg2) + } +} + +func (fake *FakeLocalParticipantListener) OnMigrateStateChangeCallCount() int { + fake.onMigrateStateChangeMutex.RLock() + defer fake.onMigrateStateChangeMutex.RUnlock() + return len(fake.onMigrateStateChangeArgsForCall) +} + +func (fake *FakeLocalParticipantListener) OnMigrateStateChangeCalls(stub func(types.LocalParticipant, types.MigrateState)) { + fake.onMigrateStateChangeMutex.Lock() + defer fake.onMigrateStateChangeMutex.Unlock() + fake.OnMigrateStateChangeStub = stub +} + +func (fake *FakeLocalParticipantListener) OnMigrateStateChangeArgsForCall(i int) (types.LocalParticipant, types.MigrateState) { + fake.onMigrateStateChangeMutex.RLock() + defer fake.onMigrateStateChangeMutex.RUnlock() + argsForCall := fake.onMigrateStateChangeArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeLocalParticipantListener) OnParticipantUpdate(arg1 types.Participant) { + fake.onParticipantUpdateMutex.Lock() + fake.onParticipantUpdateArgsForCall = append(fake.onParticipantUpdateArgsForCall, struct { + arg1 types.Participant + }{arg1}) + stub := fake.OnParticipantUpdateStub + fake.recordInvocation("OnParticipantUpdate", []interface{}{arg1}) + fake.onParticipantUpdateMutex.Unlock() + if stub != nil { + fake.OnParticipantUpdateStub(arg1) + } +} + +func (fake *FakeLocalParticipantListener) OnParticipantUpdateCallCount() int { + fake.onParticipantUpdateMutex.RLock() + defer fake.onParticipantUpdateMutex.RUnlock() + return len(fake.onParticipantUpdateArgsForCall) +} + +func (fake *FakeLocalParticipantListener) OnParticipantUpdateCalls(stub func(types.Participant)) { + fake.onParticipantUpdateMutex.Lock() + defer fake.onParticipantUpdateMutex.Unlock() + fake.OnParticipantUpdateStub = stub +} + +func (fake *FakeLocalParticipantListener) OnParticipantUpdateArgsForCall(i int) types.Participant { + fake.onParticipantUpdateMutex.RLock() + defer fake.onParticipantUpdateMutex.RUnlock() + argsForCall := fake.onParticipantUpdateArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalParticipantListener) OnSimulateScenario(arg1 types.LocalParticipant, arg2 *livekit.SimulateScenario) error { + fake.onSimulateScenarioMutex.Lock() + ret, specificReturn := fake.onSimulateScenarioReturnsOnCall[len(fake.onSimulateScenarioArgsForCall)] + fake.onSimulateScenarioArgsForCall = append(fake.onSimulateScenarioArgsForCall, struct { + arg1 types.LocalParticipant + arg2 *livekit.SimulateScenario + }{arg1, arg2}) + stub := fake.OnSimulateScenarioStub + fakeReturns := fake.onSimulateScenarioReturns + fake.recordInvocation("OnSimulateScenario", []interface{}{arg1, arg2}) + fake.onSimulateScenarioMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipantListener) OnSimulateScenarioCallCount() int { + fake.onSimulateScenarioMutex.RLock() + defer fake.onSimulateScenarioMutex.RUnlock() + return len(fake.onSimulateScenarioArgsForCall) +} + +func (fake *FakeLocalParticipantListener) OnSimulateScenarioCalls(stub func(types.LocalParticipant, *livekit.SimulateScenario) error) { + fake.onSimulateScenarioMutex.Lock() + defer fake.onSimulateScenarioMutex.Unlock() + fake.OnSimulateScenarioStub = stub +} + +func (fake *FakeLocalParticipantListener) OnSimulateScenarioArgsForCall(i int) (types.LocalParticipant, *livekit.SimulateScenario) { + fake.onSimulateScenarioMutex.RLock() + defer fake.onSimulateScenarioMutex.RUnlock() + argsForCall := fake.onSimulateScenarioArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeLocalParticipantListener) OnSimulateScenarioReturns(result1 error) { + fake.onSimulateScenarioMutex.Lock() + defer fake.onSimulateScenarioMutex.Unlock() + fake.OnSimulateScenarioStub = nil + fake.onSimulateScenarioReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipantListener) OnSimulateScenarioReturnsOnCall(i int, result1 error) { + fake.onSimulateScenarioMutex.Lock() + defer fake.onSimulateScenarioMutex.Unlock() + fake.OnSimulateScenarioStub = nil + if fake.onSimulateScenarioReturnsOnCall == nil { + fake.onSimulateScenarioReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.onSimulateScenarioReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipantListener) OnStateChange(arg1 types.LocalParticipant) { + fake.onStateChangeMutex.Lock() + fake.onStateChangeArgsForCall = append(fake.onStateChangeArgsForCall, struct { + arg1 types.LocalParticipant + }{arg1}) + stub := fake.OnStateChangeStub + fake.recordInvocation("OnStateChange", []interface{}{arg1}) + fake.onStateChangeMutex.Unlock() + if stub != nil { + fake.OnStateChangeStub(arg1) + } +} + +func (fake *FakeLocalParticipantListener) OnStateChangeCallCount() int { + fake.onStateChangeMutex.RLock() + defer fake.onStateChangeMutex.RUnlock() + return len(fake.onStateChangeArgsForCall) +} + +func (fake *FakeLocalParticipantListener) OnStateChangeCalls(stub func(types.LocalParticipant)) { + fake.onStateChangeMutex.Lock() + defer fake.onStateChangeMutex.Unlock() + fake.OnStateChangeStub = stub +} + +func (fake *FakeLocalParticipantListener) OnStateChangeArgsForCall(i int) types.LocalParticipant { + fake.onStateChangeMutex.RLock() + defer fake.onStateChangeMutex.RUnlock() + argsForCall := fake.onStateChangeArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalParticipantListener) OnSubscribeStatusChanged(arg1 types.LocalParticipant, arg2 livekit.ParticipantID, arg3 bool) { + fake.onSubscribeStatusChangedMutex.Lock() + fake.onSubscribeStatusChangedArgsForCall = append(fake.onSubscribeStatusChangedArgsForCall, struct { + arg1 types.LocalParticipant + arg2 livekit.ParticipantID + arg3 bool + }{arg1, arg2, arg3}) + stub := fake.OnSubscribeStatusChangedStub + fake.recordInvocation("OnSubscribeStatusChanged", []interface{}{arg1, arg2, arg3}) + fake.onSubscribeStatusChangedMutex.Unlock() + if stub != nil { + fake.OnSubscribeStatusChangedStub(arg1, arg2, arg3) + } +} + +func (fake *FakeLocalParticipantListener) OnSubscribeStatusChangedCallCount() int { + fake.onSubscribeStatusChangedMutex.RLock() + defer fake.onSubscribeStatusChangedMutex.RUnlock() + return len(fake.onSubscribeStatusChangedArgsForCall) +} + +func (fake *FakeLocalParticipantListener) OnSubscribeStatusChangedCalls(stub func(types.LocalParticipant, livekit.ParticipantID, bool)) { + fake.onSubscribeStatusChangedMutex.Lock() + defer fake.onSubscribeStatusChangedMutex.Unlock() + fake.OnSubscribeStatusChangedStub = stub +} + +func (fake *FakeLocalParticipantListener) OnSubscribeStatusChangedArgsForCall(i int) (types.LocalParticipant, livekit.ParticipantID, bool) { + fake.onSubscribeStatusChangedMutex.RLock() + defer fake.onSubscribeStatusChangedMutex.RUnlock() + argsForCall := fake.onSubscribeStatusChangedArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeLocalParticipantListener) OnSubscriberReady(arg1 types.LocalParticipant) { + fake.onSubscriberReadyMutex.Lock() + fake.onSubscriberReadyArgsForCall = append(fake.onSubscriberReadyArgsForCall, struct { + arg1 types.LocalParticipant + }{arg1}) + stub := fake.OnSubscriberReadyStub + fake.recordInvocation("OnSubscriberReady", []interface{}{arg1}) + fake.onSubscriberReadyMutex.Unlock() + if stub != nil { + fake.OnSubscriberReadyStub(arg1) + } +} + +func (fake *FakeLocalParticipantListener) OnSubscriberReadyCallCount() int { + fake.onSubscriberReadyMutex.RLock() + defer fake.onSubscriberReadyMutex.RUnlock() + return len(fake.onSubscriberReadyArgsForCall) +} + +func (fake *FakeLocalParticipantListener) OnSubscriberReadyCalls(stub func(types.LocalParticipant)) { + fake.onSubscriberReadyMutex.Lock() + defer fake.onSubscriberReadyMutex.Unlock() + fake.OnSubscriberReadyStub = stub +} + +func (fake *FakeLocalParticipantListener) OnSubscriberReadyArgsForCall(i int) types.LocalParticipant { + fake.onSubscriberReadyMutex.RLock() + defer fake.onSubscriberReadyMutex.RUnlock() + argsForCall := fake.onSubscriberReadyArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeLocalParticipantListener) OnSyncState(arg1 types.LocalParticipant, arg2 *livekit.SyncState) error { + fake.onSyncStateMutex.Lock() + ret, specificReturn := fake.onSyncStateReturnsOnCall[len(fake.onSyncStateArgsForCall)] + fake.onSyncStateArgsForCall = append(fake.onSyncStateArgsForCall, struct { + arg1 types.LocalParticipant + arg2 *livekit.SyncState + }{arg1, arg2}) + stub := fake.OnSyncStateStub + fakeReturns := fake.onSyncStateReturns + fake.recordInvocation("OnSyncState", []interface{}{arg1, arg2}) + fake.onSyncStateMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipantListener) OnSyncStateCallCount() int { + fake.onSyncStateMutex.RLock() + defer fake.onSyncStateMutex.RUnlock() + return len(fake.onSyncStateArgsForCall) +} + +func (fake *FakeLocalParticipantListener) OnSyncStateCalls(stub func(types.LocalParticipant, *livekit.SyncState) error) { + fake.onSyncStateMutex.Lock() + defer fake.onSyncStateMutex.Unlock() + fake.OnSyncStateStub = stub +} + +func (fake *FakeLocalParticipantListener) OnSyncStateArgsForCall(i int) (types.LocalParticipant, *livekit.SyncState) { + fake.onSyncStateMutex.RLock() + defer fake.onSyncStateMutex.RUnlock() + argsForCall := fake.onSyncStateArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeLocalParticipantListener) OnSyncStateReturns(result1 error) { + fake.onSyncStateMutex.Lock() + defer fake.onSyncStateMutex.Unlock() + fake.OnSyncStateStub = nil + fake.onSyncStateReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipantListener) OnSyncStateReturnsOnCall(i int, result1 error) { + fake.onSyncStateMutex.Lock() + defer fake.onSyncStateMutex.Unlock() + fake.OnSyncStateStub = nil + if fake.onSyncStateReturnsOnCall == nil { + fake.onSyncStateReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.onSyncStateReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipantListener) OnTrackPublished(arg1 types.Participant, arg2 types.MediaTrack) { + fake.onTrackPublishedMutex.Lock() + fake.onTrackPublishedArgsForCall = append(fake.onTrackPublishedArgsForCall, struct { + arg1 types.Participant + arg2 types.MediaTrack + }{arg1, arg2}) + stub := fake.OnTrackPublishedStub + fake.recordInvocation("OnTrackPublished", []interface{}{arg1, arg2}) + fake.onTrackPublishedMutex.Unlock() + if stub != nil { + fake.OnTrackPublishedStub(arg1, arg2) + } +} + +func (fake *FakeLocalParticipantListener) OnTrackPublishedCallCount() int { + fake.onTrackPublishedMutex.RLock() + defer fake.onTrackPublishedMutex.RUnlock() + return len(fake.onTrackPublishedArgsForCall) +} + +func (fake *FakeLocalParticipantListener) OnTrackPublishedCalls(stub func(types.Participant, types.MediaTrack)) { + fake.onTrackPublishedMutex.Lock() + defer fake.onTrackPublishedMutex.Unlock() + fake.OnTrackPublishedStub = stub +} + +func (fake *FakeLocalParticipantListener) OnTrackPublishedArgsForCall(i int) (types.Participant, types.MediaTrack) { + fake.onTrackPublishedMutex.RLock() + defer fake.onTrackPublishedMutex.RUnlock() + argsForCall := fake.onTrackPublishedArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeLocalParticipantListener) OnTrackUnpublished(arg1 types.Participant, arg2 types.MediaTrack) { + fake.onTrackUnpublishedMutex.Lock() + fake.onTrackUnpublishedArgsForCall = append(fake.onTrackUnpublishedArgsForCall, struct { + arg1 types.Participant + arg2 types.MediaTrack + }{arg1, arg2}) + stub := fake.OnTrackUnpublishedStub + fake.recordInvocation("OnTrackUnpublished", []interface{}{arg1, arg2}) + fake.onTrackUnpublishedMutex.Unlock() + if stub != nil { + fake.OnTrackUnpublishedStub(arg1, arg2) + } +} + +func (fake *FakeLocalParticipantListener) OnTrackUnpublishedCallCount() int { + fake.onTrackUnpublishedMutex.RLock() + defer fake.onTrackUnpublishedMutex.RUnlock() + return len(fake.onTrackUnpublishedArgsForCall) +} + +func (fake *FakeLocalParticipantListener) OnTrackUnpublishedCalls(stub func(types.Participant, types.MediaTrack)) { + fake.onTrackUnpublishedMutex.Lock() + defer fake.onTrackUnpublishedMutex.Unlock() + fake.OnTrackUnpublishedStub = stub +} + +func (fake *FakeLocalParticipantListener) OnTrackUnpublishedArgsForCall(i int) (types.Participant, types.MediaTrack) { + fake.onTrackUnpublishedMutex.RLock() + defer fake.onTrackUnpublishedMutex.RUnlock() + argsForCall := fake.onTrackUnpublishedArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeLocalParticipantListener) OnTrackUpdated(arg1 types.Participant, arg2 types.MediaTrack) { + fake.onTrackUpdatedMutex.Lock() + fake.onTrackUpdatedArgsForCall = append(fake.onTrackUpdatedArgsForCall, struct { + arg1 types.Participant + arg2 types.MediaTrack + }{arg1, arg2}) + stub := fake.OnTrackUpdatedStub + fake.recordInvocation("OnTrackUpdated", []interface{}{arg1, arg2}) + fake.onTrackUpdatedMutex.Unlock() + if stub != nil { + fake.OnTrackUpdatedStub(arg1, arg2) + } +} + +func (fake *FakeLocalParticipantListener) OnTrackUpdatedCallCount() int { + fake.onTrackUpdatedMutex.RLock() + defer fake.onTrackUpdatedMutex.RUnlock() + return len(fake.onTrackUpdatedArgsForCall) +} + +func (fake *FakeLocalParticipantListener) OnTrackUpdatedCalls(stub func(types.Participant, types.MediaTrack)) { + fake.onTrackUpdatedMutex.Lock() + defer fake.onTrackUpdatedMutex.Unlock() + fake.OnTrackUpdatedStub = stub +} + +func (fake *FakeLocalParticipantListener) OnTrackUpdatedArgsForCall(i int) (types.Participant, types.MediaTrack) { + fake.onTrackUpdatedMutex.RLock() + defer fake.onTrackUpdatedMutex.RUnlock() + argsForCall := fake.onTrackUpdatedArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeLocalParticipantListener) OnUpdateDataSubscriptions(arg1 types.LocalParticipant, arg2 *livekit.UpdateDataSubscription) { + fake.onUpdateDataSubscriptionsMutex.Lock() + fake.onUpdateDataSubscriptionsArgsForCall = append(fake.onUpdateDataSubscriptionsArgsForCall, struct { + arg1 types.LocalParticipant + arg2 *livekit.UpdateDataSubscription + }{arg1, arg2}) + stub := fake.OnUpdateDataSubscriptionsStub + fake.recordInvocation("OnUpdateDataSubscriptions", []interface{}{arg1, arg2}) + fake.onUpdateDataSubscriptionsMutex.Unlock() + if stub != nil { + fake.OnUpdateDataSubscriptionsStub(arg1, arg2) + } +} + +func (fake *FakeLocalParticipantListener) OnUpdateDataSubscriptionsCallCount() int { + fake.onUpdateDataSubscriptionsMutex.RLock() + defer fake.onUpdateDataSubscriptionsMutex.RUnlock() + return len(fake.onUpdateDataSubscriptionsArgsForCall) +} + +func (fake *FakeLocalParticipantListener) OnUpdateDataSubscriptionsCalls(stub func(types.LocalParticipant, *livekit.UpdateDataSubscription)) { + fake.onUpdateDataSubscriptionsMutex.Lock() + defer fake.onUpdateDataSubscriptionsMutex.Unlock() + fake.OnUpdateDataSubscriptionsStub = stub +} + +func (fake *FakeLocalParticipantListener) OnUpdateDataSubscriptionsArgsForCall(i int) (types.LocalParticipant, *livekit.UpdateDataSubscription) { + fake.onUpdateDataSubscriptionsMutex.RLock() + defer fake.onUpdateDataSubscriptionsMutex.RUnlock() + argsForCall := fake.onUpdateDataSubscriptionsArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeLocalParticipantListener) OnUpdateSubscriptionPermission(arg1 types.LocalParticipant, arg2 *livekit.SubscriptionPermission) error { + fake.onUpdateSubscriptionPermissionMutex.Lock() + ret, specificReturn := fake.onUpdateSubscriptionPermissionReturnsOnCall[len(fake.onUpdateSubscriptionPermissionArgsForCall)] + fake.onUpdateSubscriptionPermissionArgsForCall = append(fake.onUpdateSubscriptionPermissionArgsForCall, struct { + arg1 types.LocalParticipant + arg2 *livekit.SubscriptionPermission + }{arg1, arg2}) + stub := fake.OnUpdateSubscriptionPermissionStub + fakeReturns := fake.onUpdateSubscriptionPermissionReturns + fake.recordInvocation("OnUpdateSubscriptionPermission", []interface{}{arg1, arg2}) + fake.onUpdateSubscriptionPermissionMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeLocalParticipantListener) OnUpdateSubscriptionPermissionCallCount() int { + fake.onUpdateSubscriptionPermissionMutex.RLock() + defer fake.onUpdateSubscriptionPermissionMutex.RUnlock() + return len(fake.onUpdateSubscriptionPermissionArgsForCall) +} + +func (fake *FakeLocalParticipantListener) OnUpdateSubscriptionPermissionCalls(stub func(types.LocalParticipant, *livekit.SubscriptionPermission) error) { + fake.onUpdateSubscriptionPermissionMutex.Lock() + defer fake.onUpdateSubscriptionPermissionMutex.Unlock() + fake.OnUpdateSubscriptionPermissionStub = stub +} + +func (fake *FakeLocalParticipantListener) OnUpdateSubscriptionPermissionArgsForCall(i int) (types.LocalParticipant, *livekit.SubscriptionPermission) { + fake.onUpdateSubscriptionPermissionMutex.RLock() + defer fake.onUpdateSubscriptionPermissionMutex.RUnlock() + argsForCall := fake.onUpdateSubscriptionPermissionArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeLocalParticipantListener) OnUpdateSubscriptionPermissionReturns(result1 error) { + fake.onUpdateSubscriptionPermissionMutex.Lock() + defer fake.onUpdateSubscriptionPermissionMutex.Unlock() + fake.OnUpdateSubscriptionPermissionStub = nil + fake.onUpdateSubscriptionPermissionReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipantListener) OnUpdateSubscriptionPermissionReturnsOnCall(i int, result1 error) { + fake.onUpdateSubscriptionPermissionMutex.Lock() + defer fake.onUpdateSubscriptionPermissionMutex.Unlock() + fake.OnUpdateSubscriptionPermissionStub = nil + if fake.onUpdateSubscriptionPermissionReturnsOnCall == nil { + fake.onUpdateSubscriptionPermissionReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.onUpdateSubscriptionPermissionReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeLocalParticipantListener) OnUpdateSubscriptions(arg1 types.LocalParticipant, arg2 []livekit.TrackID, arg3 []*livekit.ParticipantTracks, arg4 bool) { + var arg2Copy []livekit.TrackID + if arg2 != nil { + arg2Copy = make([]livekit.TrackID, len(arg2)) + copy(arg2Copy, arg2) + } + var arg3Copy []*livekit.ParticipantTracks + if arg3 != nil { + arg3Copy = make([]*livekit.ParticipantTracks, len(arg3)) + copy(arg3Copy, arg3) + } + fake.onUpdateSubscriptionsMutex.Lock() + fake.onUpdateSubscriptionsArgsForCall = append(fake.onUpdateSubscriptionsArgsForCall, struct { + arg1 types.LocalParticipant + arg2 []livekit.TrackID + arg3 []*livekit.ParticipantTracks + arg4 bool + }{arg1, arg2Copy, arg3Copy, arg4}) + stub := fake.OnUpdateSubscriptionsStub + fake.recordInvocation("OnUpdateSubscriptions", []interface{}{arg1, arg2Copy, arg3Copy, arg4}) + fake.onUpdateSubscriptionsMutex.Unlock() + if stub != nil { + fake.OnUpdateSubscriptionsStub(arg1, arg2, arg3, arg4) + } +} + +func (fake *FakeLocalParticipantListener) OnUpdateSubscriptionsCallCount() int { + fake.onUpdateSubscriptionsMutex.RLock() + defer fake.onUpdateSubscriptionsMutex.RUnlock() + return len(fake.onUpdateSubscriptionsArgsForCall) +} + +func (fake *FakeLocalParticipantListener) OnUpdateSubscriptionsCalls(stub func(types.LocalParticipant, []livekit.TrackID, []*livekit.ParticipantTracks, bool)) { + fake.onUpdateSubscriptionsMutex.Lock() + defer fake.onUpdateSubscriptionsMutex.Unlock() + fake.OnUpdateSubscriptionsStub = stub +} + +func (fake *FakeLocalParticipantListener) OnUpdateSubscriptionsArgsForCall(i int) (types.LocalParticipant, []livekit.TrackID, []*livekit.ParticipantTracks, bool) { + fake.onUpdateSubscriptionsMutex.RLock() + defer fake.onUpdateSubscriptionsMutex.RUnlock() + argsForCall := fake.onUpdateSubscriptionsArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3, argsForCall.arg4 +} + +func (fake *FakeLocalParticipantListener) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeLocalParticipantListener) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ types.LocalParticipantListener = new(FakeLocalParticipantListener) diff --git a/livekit/pkg/rtc/types/typesfakes/fake_media_track.go b/livekit/pkg/rtc/types/typesfakes/fake_media_track.go new file mode 100644 index 0000000..ba44585 --- /dev/null +++ b/livekit/pkg/rtc/types/typesfakes/fake_media_track.go @@ -0,0 +1,1821 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package typesfakes + +import ( + "sync" + + "github.com/livekit/livekit-server/pkg/rtc/types" + "github.com/livekit/livekit-server/pkg/sfu" + "github.com/livekit/livekit-server/pkg/sfu/mime" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" +) + +type FakeMediaTrack struct { + AddOnCloseStub func(func(isExpectedToResume bool)) + addOnCloseMutex sync.RWMutex + addOnCloseArgsForCall []struct { + arg1 func(isExpectedToResume bool) + } + AddSubscriberStub func(types.LocalParticipant) (types.SubscribedTrack, error) + addSubscriberMutex sync.RWMutex + addSubscriberArgsForCall []struct { + arg1 types.LocalParticipant + } + addSubscriberReturns struct { + result1 types.SubscribedTrack + result2 error + } + addSubscriberReturnsOnCall map[int]struct { + result1 types.SubscribedTrack + result2 error + } + ClearAllReceiversStub func(bool) + clearAllReceiversMutex sync.RWMutex + clearAllReceiversArgsForCall []struct { + arg1 bool + } + CloseStub func(bool) + closeMutex sync.RWMutex + closeArgsForCall []struct { + arg1 bool + } + GetAllSubscribersStub func() []livekit.ParticipantID + getAllSubscribersMutex sync.RWMutex + getAllSubscribersArgsForCall []struct { + } + getAllSubscribersReturns struct { + result1 []livekit.ParticipantID + } + getAllSubscribersReturnsOnCall map[int]struct { + result1 []livekit.ParticipantID + } + GetAudioLevelStub func() (float64, bool) + getAudioLevelMutex sync.RWMutex + getAudioLevelArgsForCall []struct { + } + getAudioLevelReturns struct { + result1 float64 + result2 bool + } + getAudioLevelReturnsOnCall map[int]struct { + result1 float64 + result2 bool + } + GetNumSubscribersStub func() int + getNumSubscribersMutex sync.RWMutex + getNumSubscribersArgsForCall []struct { + } + getNumSubscribersReturns struct { + result1 int + } + getNumSubscribersReturnsOnCall map[int]struct { + result1 int + } + GetQualityForDimensionStub func(mime.MimeType, uint32, uint32) livekit.VideoQuality + getQualityForDimensionMutex sync.RWMutex + getQualityForDimensionArgsForCall []struct { + arg1 mime.MimeType + arg2 uint32 + arg3 uint32 + } + getQualityForDimensionReturns struct { + result1 livekit.VideoQuality + } + getQualityForDimensionReturnsOnCall map[int]struct { + result1 livekit.VideoQuality + } + GetTemporalLayerForSpatialFpsStub func(mime.MimeType, int32, uint32) int32 + getTemporalLayerForSpatialFpsMutex sync.RWMutex + getTemporalLayerForSpatialFpsArgsForCall []struct { + arg1 mime.MimeType + arg2 int32 + arg3 uint32 + } + getTemporalLayerForSpatialFpsReturns struct { + result1 int32 + } + getTemporalLayerForSpatialFpsReturnsOnCall map[int]struct { + result1 int32 + } + IDStub func() livekit.TrackID + iDMutex sync.RWMutex + iDArgsForCall []struct { + } + iDReturns struct { + result1 livekit.TrackID + } + iDReturnsOnCall map[int]struct { + result1 livekit.TrackID + } + IsEncryptedStub func() bool + isEncryptedMutex sync.RWMutex + isEncryptedArgsForCall []struct { + } + isEncryptedReturns struct { + result1 bool + } + isEncryptedReturnsOnCall map[int]struct { + result1 bool + } + IsMutedStub func() bool + isMutedMutex sync.RWMutex + isMutedArgsForCall []struct { + } + isMutedReturns struct { + result1 bool + } + isMutedReturnsOnCall map[int]struct { + result1 bool + } + IsOpenStub func() bool + isOpenMutex sync.RWMutex + isOpenArgsForCall []struct { + } + isOpenReturns struct { + result1 bool + } + isOpenReturnsOnCall map[int]struct { + result1 bool + } + IsSubscriberStub func(livekit.ParticipantID) bool + isSubscriberMutex sync.RWMutex + isSubscriberArgsForCall []struct { + arg1 livekit.ParticipantID + } + isSubscriberReturns struct { + result1 bool + } + isSubscriberReturnsOnCall map[int]struct { + result1 bool + } + KindStub func() livekit.TrackType + kindMutex sync.RWMutex + kindArgsForCall []struct { + } + kindReturns struct { + result1 livekit.TrackType + } + kindReturnsOnCall map[int]struct { + result1 livekit.TrackType + } + LoggerStub func() logger.Logger + loggerMutex sync.RWMutex + loggerArgsForCall []struct { + } + loggerReturns struct { + result1 logger.Logger + } + loggerReturnsOnCall map[int]struct { + result1 logger.Logger + } + NameStub func() string + nameMutex sync.RWMutex + nameArgsForCall []struct { + } + nameReturns struct { + result1 string + } + nameReturnsOnCall map[int]struct { + result1 string + } + OnTrackSubscribedStub func() + onTrackSubscribedMutex sync.RWMutex + onTrackSubscribedArgsForCall []struct { + } + PublisherIDStub func() livekit.ParticipantID + publisherIDMutex sync.RWMutex + publisherIDArgsForCall []struct { + } + publisherIDReturns struct { + result1 livekit.ParticipantID + } + publisherIDReturnsOnCall map[int]struct { + result1 livekit.ParticipantID + } + PublisherIdentityStub func() livekit.ParticipantIdentity + publisherIdentityMutex sync.RWMutex + publisherIdentityArgsForCall []struct { + } + publisherIdentityReturns struct { + result1 livekit.ParticipantIdentity + } + publisherIdentityReturnsOnCall map[int]struct { + result1 livekit.ParticipantIdentity + } + PublisherVersionStub func() uint32 + publisherVersionMutex sync.RWMutex + publisherVersionArgsForCall []struct { + } + publisherVersionReturns struct { + result1 uint32 + } + publisherVersionReturnsOnCall map[int]struct { + result1 uint32 + } + ReceiversStub func() []sfu.TrackReceiver + receiversMutex sync.RWMutex + receiversArgsForCall []struct { + } + receiversReturns struct { + result1 []sfu.TrackReceiver + } + receiversReturnsOnCall map[int]struct { + result1 []sfu.TrackReceiver + } + RemoveSubscriberStub func(livekit.ParticipantID, bool) + removeSubscriberMutex sync.RWMutex + removeSubscriberArgsForCall []struct { + arg1 livekit.ParticipantID + arg2 bool + } + RevokeDisallowedSubscribersStub func([]livekit.ParticipantIdentity) []livekit.ParticipantIdentity + revokeDisallowedSubscribersMutex sync.RWMutex + revokeDisallowedSubscribersArgsForCall []struct { + arg1 []livekit.ParticipantIdentity + } + revokeDisallowedSubscribersReturns struct { + result1 []livekit.ParticipantIdentity + } + revokeDisallowedSubscribersReturnsOnCall map[int]struct { + result1 []livekit.ParticipantIdentity + } + SetMutedStub func(bool) + setMutedMutex sync.RWMutex + setMutedArgsForCall []struct { + arg1 bool + } + SourceStub func() livekit.TrackSource + sourceMutex sync.RWMutex + sourceArgsForCall []struct { + } + sourceReturns struct { + result1 livekit.TrackSource + } + sourceReturnsOnCall map[int]struct { + result1 livekit.TrackSource + } + StreamStub func() string + streamMutex sync.RWMutex + streamArgsForCall []struct { + } + streamReturns struct { + result1 string + } + streamReturnsOnCall map[int]struct { + result1 string + } + ToProtoStub func() *livekit.TrackInfo + toProtoMutex sync.RWMutex + toProtoArgsForCall []struct { + } + toProtoReturns struct { + result1 *livekit.TrackInfo + } + toProtoReturnsOnCall map[int]struct { + result1 *livekit.TrackInfo + } + UpdateAudioTrackStub func(*livekit.UpdateLocalAudioTrack) + updateAudioTrackMutex sync.RWMutex + updateAudioTrackArgsForCall []struct { + arg1 *livekit.UpdateLocalAudioTrack + } + UpdateTrackInfoStub func(*livekit.TrackInfo) + updateTrackInfoMutex sync.RWMutex + updateTrackInfoArgsForCall []struct { + arg1 *livekit.TrackInfo + } + UpdateVideoTrackStub func(*livekit.UpdateLocalVideoTrack) + updateVideoTrackMutex sync.RWMutex + updateVideoTrackArgsForCall []struct { + arg1 *livekit.UpdateLocalVideoTrack + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeMediaTrack) AddOnClose(arg1 func(isExpectedToResume bool)) { + fake.addOnCloseMutex.Lock() + fake.addOnCloseArgsForCall = append(fake.addOnCloseArgsForCall, struct { + arg1 func(isExpectedToResume bool) + }{arg1}) + stub := fake.AddOnCloseStub + fake.recordInvocation("AddOnClose", []interface{}{arg1}) + fake.addOnCloseMutex.Unlock() + if stub != nil { + fake.AddOnCloseStub(arg1) + } +} + +func (fake *FakeMediaTrack) AddOnCloseCallCount() int { + fake.addOnCloseMutex.RLock() + defer fake.addOnCloseMutex.RUnlock() + return len(fake.addOnCloseArgsForCall) +} + +func (fake *FakeMediaTrack) AddOnCloseCalls(stub func(func(isExpectedToResume bool))) { + fake.addOnCloseMutex.Lock() + defer fake.addOnCloseMutex.Unlock() + fake.AddOnCloseStub = stub +} + +func (fake *FakeMediaTrack) AddOnCloseArgsForCall(i int) func(isExpectedToResume bool) { + fake.addOnCloseMutex.RLock() + defer fake.addOnCloseMutex.RUnlock() + argsForCall := fake.addOnCloseArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeMediaTrack) AddSubscriber(arg1 types.LocalParticipant) (types.SubscribedTrack, error) { + fake.addSubscriberMutex.Lock() + ret, specificReturn := fake.addSubscriberReturnsOnCall[len(fake.addSubscriberArgsForCall)] + fake.addSubscriberArgsForCall = append(fake.addSubscriberArgsForCall, struct { + arg1 types.LocalParticipant + }{arg1}) + stub := fake.AddSubscriberStub + fakeReturns := fake.addSubscriberReturns + fake.recordInvocation("AddSubscriber", []interface{}{arg1}) + fake.addSubscriberMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeMediaTrack) AddSubscriberCallCount() int { + fake.addSubscriberMutex.RLock() + defer fake.addSubscriberMutex.RUnlock() + return len(fake.addSubscriberArgsForCall) +} + +func (fake *FakeMediaTrack) AddSubscriberCalls(stub func(types.LocalParticipant) (types.SubscribedTrack, error)) { + fake.addSubscriberMutex.Lock() + defer fake.addSubscriberMutex.Unlock() + fake.AddSubscriberStub = stub +} + +func (fake *FakeMediaTrack) AddSubscriberArgsForCall(i int) types.LocalParticipant { + fake.addSubscriberMutex.RLock() + defer fake.addSubscriberMutex.RUnlock() + argsForCall := fake.addSubscriberArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeMediaTrack) AddSubscriberReturns(result1 types.SubscribedTrack, result2 error) { + fake.addSubscriberMutex.Lock() + defer fake.addSubscriberMutex.Unlock() + fake.AddSubscriberStub = nil + fake.addSubscriberReturns = struct { + result1 types.SubscribedTrack + result2 error + }{result1, result2} +} + +func (fake *FakeMediaTrack) AddSubscriberReturnsOnCall(i int, result1 types.SubscribedTrack, result2 error) { + fake.addSubscriberMutex.Lock() + defer fake.addSubscriberMutex.Unlock() + fake.AddSubscriberStub = nil + if fake.addSubscriberReturnsOnCall == nil { + fake.addSubscriberReturnsOnCall = make(map[int]struct { + result1 types.SubscribedTrack + result2 error + }) + } + fake.addSubscriberReturnsOnCall[i] = struct { + result1 types.SubscribedTrack + result2 error + }{result1, result2} +} + +func (fake *FakeMediaTrack) ClearAllReceivers(arg1 bool) { + fake.clearAllReceiversMutex.Lock() + fake.clearAllReceiversArgsForCall = append(fake.clearAllReceiversArgsForCall, struct { + arg1 bool + }{arg1}) + stub := fake.ClearAllReceiversStub + fake.recordInvocation("ClearAllReceivers", []interface{}{arg1}) + fake.clearAllReceiversMutex.Unlock() + if stub != nil { + fake.ClearAllReceiversStub(arg1) + } +} + +func (fake *FakeMediaTrack) ClearAllReceiversCallCount() int { + fake.clearAllReceiversMutex.RLock() + defer fake.clearAllReceiversMutex.RUnlock() + return len(fake.clearAllReceiversArgsForCall) +} + +func (fake *FakeMediaTrack) ClearAllReceiversCalls(stub func(bool)) { + fake.clearAllReceiversMutex.Lock() + defer fake.clearAllReceiversMutex.Unlock() + fake.ClearAllReceiversStub = stub +} + +func (fake *FakeMediaTrack) ClearAllReceiversArgsForCall(i int) bool { + fake.clearAllReceiversMutex.RLock() + defer fake.clearAllReceiversMutex.RUnlock() + argsForCall := fake.clearAllReceiversArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeMediaTrack) Close(arg1 bool) { + fake.closeMutex.Lock() + fake.closeArgsForCall = append(fake.closeArgsForCall, struct { + arg1 bool + }{arg1}) + stub := fake.CloseStub + fake.recordInvocation("Close", []interface{}{arg1}) + fake.closeMutex.Unlock() + if stub != nil { + fake.CloseStub(arg1) + } +} + +func (fake *FakeMediaTrack) CloseCallCount() int { + fake.closeMutex.RLock() + defer fake.closeMutex.RUnlock() + return len(fake.closeArgsForCall) +} + +func (fake *FakeMediaTrack) CloseCalls(stub func(bool)) { + fake.closeMutex.Lock() + defer fake.closeMutex.Unlock() + fake.CloseStub = stub +} + +func (fake *FakeMediaTrack) CloseArgsForCall(i int) bool { + fake.closeMutex.RLock() + defer fake.closeMutex.RUnlock() + argsForCall := fake.closeArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeMediaTrack) GetAllSubscribers() []livekit.ParticipantID { + fake.getAllSubscribersMutex.Lock() + ret, specificReturn := fake.getAllSubscribersReturnsOnCall[len(fake.getAllSubscribersArgsForCall)] + fake.getAllSubscribersArgsForCall = append(fake.getAllSubscribersArgsForCall, struct { + }{}) + stub := fake.GetAllSubscribersStub + fakeReturns := fake.getAllSubscribersReturns + fake.recordInvocation("GetAllSubscribers", []interface{}{}) + fake.getAllSubscribersMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeMediaTrack) GetAllSubscribersCallCount() int { + fake.getAllSubscribersMutex.RLock() + defer fake.getAllSubscribersMutex.RUnlock() + return len(fake.getAllSubscribersArgsForCall) +} + +func (fake *FakeMediaTrack) GetAllSubscribersCalls(stub func() []livekit.ParticipantID) { + fake.getAllSubscribersMutex.Lock() + defer fake.getAllSubscribersMutex.Unlock() + fake.GetAllSubscribersStub = stub +} + +func (fake *FakeMediaTrack) GetAllSubscribersReturns(result1 []livekit.ParticipantID) { + fake.getAllSubscribersMutex.Lock() + defer fake.getAllSubscribersMutex.Unlock() + fake.GetAllSubscribersStub = nil + fake.getAllSubscribersReturns = struct { + result1 []livekit.ParticipantID + }{result1} +} + +func (fake *FakeMediaTrack) GetAllSubscribersReturnsOnCall(i int, result1 []livekit.ParticipantID) { + fake.getAllSubscribersMutex.Lock() + defer fake.getAllSubscribersMutex.Unlock() + fake.GetAllSubscribersStub = nil + if fake.getAllSubscribersReturnsOnCall == nil { + fake.getAllSubscribersReturnsOnCall = make(map[int]struct { + result1 []livekit.ParticipantID + }) + } + fake.getAllSubscribersReturnsOnCall[i] = struct { + result1 []livekit.ParticipantID + }{result1} +} + +func (fake *FakeMediaTrack) GetAudioLevel() (float64, bool) { + fake.getAudioLevelMutex.Lock() + ret, specificReturn := fake.getAudioLevelReturnsOnCall[len(fake.getAudioLevelArgsForCall)] + fake.getAudioLevelArgsForCall = append(fake.getAudioLevelArgsForCall, struct { + }{}) + stub := fake.GetAudioLevelStub + fakeReturns := fake.getAudioLevelReturns + fake.recordInvocation("GetAudioLevel", []interface{}{}) + fake.getAudioLevelMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeMediaTrack) GetAudioLevelCallCount() int { + fake.getAudioLevelMutex.RLock() + defer fake.getAudioLevelMutex.RUnlock() + return len(fake.getAudioLevelArgsForCall) +} + +func (fake *FakeMediaTrack) GetAudioLevelCalls(stub func() (float64, bool)) { + fake.getAudioLevelMutex.Lock() + defer fake.getAudioLevelMutex.Unlock() + fake.GetAudioLevelStub = stub +} + +func (fake *FakeMediaTrack) GetAudioLevelReturns(result1 float64, result2 bool) { + fake.getAudioLevelMutex.Lock() + defer fake.getAudioLevelMutex.Unlock() + fake.GetAudioLevelStub = nil + fake.getAudioLevelReturns = struct { + result1 float64 + result2 bool + }{result1, result2} +} + +func (fake *FakeMediaTrack) GetAudioLevelReturnsOnCall(i int, result1 float64, result2 bool) { + fake.getAudioLevelMutex.Lock() + defer fake.getAudioLevelMutex.Unlock() + fake.GetAudioLevelStub = nil + if fake.getAudioLevelReturnsOnCall == nil { + fake.getAudioLevelReturnsOnCall = make(map[int]struct { + result1 float64 + result2 bool + }) + } + fake.getAudioLevelReturnsOnCall[i] = struct { + result1 float64 + result2 bool + }{result1, result2} +} + +func (fake *FakeMediaTrack) GetNumSubscribers() int { + fake.getNumSubscribersMutex.Lock() + ret, specificReturn := fake.getNumSubscribersReturnsOnCall[len(fake.getNumSubscribersArgsForCall)] + fake.getNumSubscribersArgsForCall = append(fake.getNumSubscribersArgsForCall, struct { + }{}) + stub := fake.GetNumSubscribersStub + fakeReturns := fake.getNumSubscribersReturns + fake.recordInvocation("GetNumSubscribers", []interface{}{}) + fake.getNumSubscribersMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeMediaTrack) GetNumSubscribersCallCount() int { + fake.getNumSubscribersMutex.RLock() + defer fake.getNumSubscribersMutex.RUnlock() + return len(fake.getNumSubscribersArgsForCall) +} + +func (fake *FakeMediaTrack) GetNumSubscribersCalls(stub func() int) { + fake.getNumSubscribersMutex.Lock() + defer fake.getNumSubscribersMutex.Unlock() + fake.GetNumSubscribersStub = stub +} + +func (fake *FakeMediaTrack) GetNumSubscribersReturns(result1 int) { + fake.getNumSubscribersMutex.Lock() + defer fake.getNumSubscribersMutex.Unlock() + fake.GetNumSubscribersStub = nil + fake.getNumSubscribersReturns = struct { + result1 int + }{result1} +} + +func (fake *FakeMediaTrack) GetNumSubscribersReturnsOnCall(i int, result1 int) { + fake.getNumSubscribersMutex.Lock() + defer fake.getNumSubscribersMutex.Unlock() + fake.GetNumSubscribersStub = nil + if fake.getNumSubscribersReturnsOnCall == nil { + fake.getNumSubscribersReturnsOnCall = make(map[int]struct { + result1 int + }) + } + fake.getNumSubscribersReturnsOnCall[i] = struct { + result1 int + }{result1} +} + +func (fake *FakeMediaTrack) GetQualityForDimension(arg1 mime.MimeType, arg2 uint32, arg3 uint32) livekit.VideoQuality { + fake.getQualityForDimensionMutex.Lock() + ret, specificReturn := fake.getQualityForDimensionReturnsOnCall[len(fake.getQualityForDimensionArgsForCall)] + fake.getQualityForDimensionArgsForCall = append(fake.getQualityForDimensionArgsForCall, struct { + arg1 mime.MimeType + arg2 uint32 + arg3 uint32 + }{arg1, arg2, arg3}) + stub := fake.GetQualityForDimensionStub + fakeReturns := fake.getQualityForDimensionReturns + fake.recordInvocation("GetQualityForDimension", []interface{}{arg1, arg2, arg3}) + fake.getQualityForDimensionMutex.Unlock() + if stub != nil { + return stub(arg1, arg2, arg3) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeMediaTrack) GetQualityForDimensionCallCount() int { + fake.getQualityForDimensionMutex.RLock() + defer fake.getQualityForDimensionMutex.RUnlock() + return len(fake.getQualityForDimensionArgsForCall) +} + +func (fake *FakeMediaTrack) GetQualityForDimensionCalls(stub func(mime.MimeType, uint32, uint32) livekit.VideoQuality) { + fake.getQualityForDimensionMutex.Lock() + defer fake.getQualityForDimensionMutex.Unlock() + fake.GetQualityForDimensionStub = stub +} + +func (fake *FakeMediaTrack) GetQualityForDimensionArgsForCall(i int) (mime.MimeType, uint32, uint32) { + fake.getQualityForDimensionMutex.RLock() + defer fake.getQualityForDimensionMutex.RUnlock() + argsForCall := fake.getQualityForDimensionArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeMediaTrack) GetQualityForDimensionReturns(result1 livekit.VideoQuality) { + fake.getQualityForDimensionMutex.Lock() + defer fake.getQualityForDimensionMutex.Unlock() + fake.GetQualityForDimensionStub = nil + fake.getQualityForDimensionReturns = struct { + result1 livekit.VideoQuality + }{result1} +} + +func (fake *FakeMediaTrack) GetQualityForDimensionReturnsOnCall(i int, result1 livekit.VideoQuality) { + fake.getQualityForDimensionMutex.Lock() + defer fake.getQualityForDimensionMutex.Unlock() + fake.GetQualityForDimensionStub = nil + if fake.getQualityForDimensionReturnsOnCall == nil { + fake.getQualityForDimensionReturnsOnCall = make(map[int]struct { + result1 livekit.VideoQuality + }) + } + fake.getQualityForDimensionReturnsOnCall[i] = struct { + result1 livekit.VideoQuality + }{result1} +} + +func (fake *FakeMediaTrack) GetTemporalLayerForSpatialFps(arg1 mime.MimeType, arg2 int32, arg3 uint32) int32 { + fake.getTemporalLayerForSpatialFpsMutex.Lock() + ret, specificReturn := fake.getTemporalLayerForSpatialFpsReturnsOnCall[len(fake.getTemporalLayerForSpatialFpsArgsForCall)] + fake.getTemporalLayerForSpatialFpsArgsForCall = append(fake.getTemporalLayerForSpatialFpsArgsForCall, struct { + arg1 mime.MimeType + arg2 int32 + arg3 uint32 + }{arg1, arg2, arg3}) + stub := fake.GetTemporalLayerForSpatialFpsStub + fakeReturns := fake.getTemporalLayerForSpatialFpsReturns + fake.recordInvocation("GetTemporalLayerForSpatialFps", []interface{}{arg1, arg2, arg3}) + fake.getTemporalLayerForSpatialFpsMutex.Unlock() + if stub != nil { + return stub(arg1, arg2, arg3) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeMediaTrack) GetTemporalLayerForSpatialFpsCallCount() int { + fake.getTemporalLayerForSpatialFpsMutex.RLock() + defer fake.getTemporalLayerForSpatialFpsMutex.RUnlock() + return len(fake.getTemporalLayerForSpatialFpsArgsForCall) +} + +func (fake *FakeMediaTrack) GetTemporalLayerForSpatialFpsCalls(stub func(mime.MimeType, int32, uint32) int32) { + fake.getTemporalLayerForSpatialFpsMutex.Lock() + defer fake.getTemporalLayerForSpatialFpsMutex.Unlock() + fake.GetTemporalLayerForSpatialFpsStub = stub +} + +func (fake *FakeMediaTrack) GetTemporalLayerForSpatialFpsArgsForCall(i int) (mime.MimeType, int32, uint32) { + fake.getTemporalLayerForSpatialFpsMutex.RLock() + defer fake.getTemporalLayerForSpatialFpsMutex.RUnlock() + argsForCall := fake.getTemporalLayerForSpatialFpsArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeMediaTrack) GetTemporalLayerForSpatialFpsReturns(result1 int32) { + fake.getTemporalLayerForSpatialFpsMutex.Lock() + defer fake.getTemporalLayerForSpatialFpsMutex.Unlock() + fake.GetTemporalLayerForSpatialFpsStub = nil + fake.getTemporalLayerForSpatialFpsReturns = struct { + result1 int32 + }{result1} +} + +func (fake *FakeMediaTrack) GetTemporalLayerForSpatialFpsReturnsOnCall(i int, result1 int32) { + fake.getTemporalLayerForSpatialFpsMutex.Lock() + defer fake.getTemporalLayerForSpatialFpsMutex.Unlock() + fake.GetTemporalLayerForSpatialFpsStub = nil + if fake.getTemporalLayerForSpatialFpsReturnsOnCall == nil { + fake.getTemporalLayerForSpatialFpsReturnsOnCall = make(map[int]struct { + result1 int32 + }) + } + fake.getTemporalLayerForSpatialFpsReturnsOnCall[i] = struct { + result1 int32 + }{result1} +} + +func (fake *FakeMediaTrack) ID() livekit.TrackID { + fake.iDMutex.Lock() + ret, specificReturn := fake.iDReturnsOnCall[len(fake.iDArgsForCall)] + fake.iDArgsForCall = append(fake.iDArgsForCall, struct { + }{}) + stub := fake.IDStub + fakeReturns := fake.iDReturns + fake.recordInvocation("ID", []interface{}{}) + fake.iDMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeMediaTrack) IDCallCount() int { + fake.iDMutex.RLock() + defer fake.iDMutex.RUnlock() + return len(fake.iDArgsForCall) +} + +func (fake *FakeMediaTrack) IDCalls(stub func() livekit.TrackID) { + fake.iDMutex.Lock() + defer fake.iDMutex.Unlock() + fake.IDStub = stub +} + +func (fake *FakeMediaTrack) IDReturns(result1 livekit.TrackID) { + fake.iDMutex.Lock() + defer fake.iDMutex.Unlock() + fake.IDStub = nil + fake.iDReturns = struct { + result1 livekit.TrackID + }{result1} +} + +func (fake *FakeMediaTrack) IDReturnsOnCall(i int, result1 livekit.TrackID) { + fake.iDMutex.Lock() + defer fake.iDMutex.Unlock() + fake.IDStub = nil + if fake.iDReturnsOnCall == nil { + fake.iDReturnsOnCall = make(map[int]struct { + result1 livekit.TrackID + }) + } + fake.iDReturnsOnCall[i] = struct { + result1 livekit.TrackID + }{result1} +} + +func (fake *FakeMediaTrack) IsEncrypted() bool { + fake.isEncryptedMutex.Lock() + ret, specificReturn := fake.isEncryptedReturnsOnCall[len(fake.isEncryptedArgsForCall)] + fake.isEncryptedArgsForCall = append(fake.isEncryptedArgsForCall, struct { + }{}) + stub := fake.IsEncryptedStub + fakeReturns := fake.isEncryptedReturns + fake.recordInvocation("IsEncrypted", []interface{}{}) + fake.isEncryptedMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeMediaTrack) IsEncryptedCallCount() int { + fake.isEncryptedMutex.RLock() + defer fake.isEncryptedMutex.RUnlock() + return len(fake.isEncryptedArgsForCall) +} + +func (fake *FakeMediaTrack) IsEncryptedCalls(stub func() bool) { + fake.isEncryptedMutex.Lock() + defer fake.isEncryptedMutex.Unlock() + fake.IsEncryptedStub = stub +} + +func (fake *FakeMediaTrack) IsEncryptedReturns(result1 bool) { + fake.isEncryptedMutex.Lock() + defer fake.isEncryptedMutex.Unlock() + fake.IsEncryptedStub = nil + fake.isEncryptedReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeMediaTrack) IsEncryptedReturnsOnCall(i int, result1 bool) { + fake.isEncryptedMutex.Lock() + defer fake.isEncryptedMutex.Unlock() + fake.IsEncryptedStub = nil + if fake.isEncryptedReturnsOnCall == nil { + fake.isEncryptedReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.isEncryptedReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + +func (fake *FakeMediaTrack) IsMuted() bool { + fake.isMutedMutex.Lock() + ret, specificReturn := fake.isMutedReturnsOnCall[len(fake.isMutedArgsForCall)] + fake.isMutedArgsForCall = append(fake.isMutedArgsForCall, struct { + }{}) + stub := fake.IsMutedStub + fakeReturns := fake.isMutedReturns + fake.recordInvocation("IsMuted", []interface{}{}) + fake.isMutedMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeMediaTrack) IsMutedCallCount() int { + fake.isMutedMutex.RLock() + defer fake.isMutedMutex.RUnlock() + return len(fake.isMutedArgsForCall) +} + +func (fake *FakeMediaTrack) IsMutedCalls(stub func() bool) { + fake.isMutedMutex.Lock() + defer fake.isMutedMutex.Unlock() + fake.IsMutedStub = stub +} + +func (fake *FakeMediaTrack) IsMutedReturns(result1 bool) { + fake.isMutedMutex.Lock() + defer fake.isMutedMutex.Unlock() + fake.IsMutedStub = nil + fake.isMutedReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeMediaTrack) IsMutedReturnsOnCall(i int, result1 bool) { + fake.isMutedMutex.Lock() + defer fake.isMutedMutex.Unlock() + fake.IsMutedStub = nil + if fake.isMutedReturnsOnCall == nil { + fake.isMutedReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.isMutedReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + +func (fake *FakeMediaTrack) IsOpen() bool { + fake.isOpenMutex.Lock() + ret, specificReturn := fake.isOpenReturnsOnCall[len(fake.isOpenArgsForCall)] + fake.isOpenArgsForCall = append(fake.isOpenArgsForCall, struct { + }{}) + stub := fake.IsOpenStub + fakeReturns := fake.isOpenReturns + fake.recordInvocation("IsOpen", []interface{}{}) + fake.isOpenMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeMediaTrack) IsOpenCallCount() int { + fake.isOpenMutex.RLock() + defer fake.isOpenMutex.RUnlock() + return len(fake.isOpenArgsForCall) +} + +func (fake *FakeMediaTrack) IsOpenCalls(stub func() bool) { + fake.isOpenMutex.Lock() + defer fake.isOpenMutex.Unlock() + fake.IsOpenStub = stub +} + +func (fake *FakeMediaTrack) IsOpenReturns(result1 bool) { + fake.isOpenMutex.Lock() + defer fake.isOpenMutex.Unlock() + fake.IsOpenStub = nil + fake.isOpenReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeMediaTrack) IsOpenReturnsOnCall(i int, result1 bool) { + fake.isOpenMutex.Lock() + defer fake.isOpenMutex.Unlock() + fake.IsOpenStub = nil + if fake.isOpenReturnsOnCall == nil { + fake.isOpenReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.isOpenReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + +func (fake *FakeMediaTrack) IsSubscriber(arg1 livekit.ParticipantID) bool { + fake.isSubscriberMutex.Lock() + ret, specificReturn := fake.isSubscriberReturnsOnCall[len(fake.isSubscriberArgsForCall)] + fake.isSubscriberArgsForCall = append(fake.isSubscriberArgsForCall, struct { + arg1 livekit.ParticipantID + }{arg1}) + stub := fake.IsSubscriberStub + fakeReturns := fake.isSubscriberReturns + fake.recordInvocation("IsSubscriber", []interface{}{arg1}) + fake.isSubscriberMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeMediaTrack) IsSubscriberCallCount() int { + fake.isSubscriberMutex.RLock() + defer fake.isSubscriberMutex.RUnlock() + return len(fake.isSubscriberArgsForCall) +} + +func (fake *FakeMediaTrack) IsSubscriberCalls(stub func(livekit.ParticipantID) bool) { + fake.isSubscriberMutex.Lock() + defer fake.isSubscriberMutex.Unlock() + fake.IsSubscriberStub = stub +} + +func (fake *FakeMediaTrack) IsSubscriberArgsForCall(i int) livekit.ParticipantID { + fake.isSubscriberMutex.RLock() + defer fake.isSubscriberMutex.RUnlock() + argsForCall := fake.isSubscriberArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeMediaTrack) IsSubscriberReturns(result1 bool) { + fake.isSubscriberMutex.Lock() + defer fake.isSubscriberMutex.Unlock() + fake.IsSubscriberStub = nil + fake.isSubscriberReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeMediaTrack) IsSubscriberReturnsOnCall(i int, result1 bool) { + fake.isSubscriberMutex.Lock() + defer fake.isSubscriberMutex.Unlock() + fake.IsSubscriberStub = nil + if fake.isSubscriberReturnsOnCall == nil { + fake.isSubscriberReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.isSubscriberReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + +func (fake *FakeMediaTrack) Kind() livekit.TrackType { + fake.kindMutex.Lock() + ret, specificReturn := fake.kindReturnsOnCall[len(fake.kindArgsForCall)] + fake.kindArgsForCall = append(fake.kindArgsForCall, struct { + }{}) + stub := fake.KindStub + fakeReturns := fake.kindReturns + fake.recordInvocation("Kind", []interface{}{}) + fake.kindMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeMediaTrack) KindCallCount() int { + fake.kindMutex.RLock() + defer fake.kindMutex.RUnlock() + return len(fake.kindArgsForCall) +} + +func (fake *FakeMediaTrack) KindCalls(stub func() livekit.TrackType) { + fake.kindMutex.Lock() + defer fake.kindMutex.Unlock() + fake.KindStub = stub +} + +func (fake *FakeMediaTrack) KindReturns(result1 livekit.TrackType) { + fake.kindMutex.Lock() + defer fake.kindMutex.Unlock() + fake.KindStub = nil + fake.kindReturns = struct { + result1 livekit.TrackType + }{result1} +} + +func (fake *FakeMediaTrack) KindReturnsOnCall(i int, result1 livekit.TrackType) { + fake.kindMutex.Lock() + defer fake.kindMutex.Unlock() + fake.KindStub = nil + if fake.kindReturnsOnCall == nil { + fake.kindReturnsOnCall = make(map[int]struct { + result1 livekit.TrackType + }) + } + fake.kindReturnsOnCall[i] = struct { + result1 livekit.TrackType + }{result1} +} + +func (fake *FakeMediaTrack) Logger() logger.Logger { + fake.loggerMutex.Lock() + ret, specificReturn := fake.loggerReturnsOnCall[len(fake.loggerArgsForCall)] + fake.loggerArgsForCall = append(fake.loggerArgsForCall, struct { + }{}) + stub := fake.LoggerStub + fakeReturns := fake.loggerReturns + fake.recordInvocation("Logger", []interface{}{}) + fake.loggerMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeMediaTrack) LoggerCallCount() int { + fake.loggerMutex.RLock() + defer fake.loggerMutex.RUnlock() + return len(fake.loggerArgsForCall) +} + +func (fake *FakeMediaTrack) LoggerCalls(stub func() logger.Logger) { + fake.loggerMutex.Lock() + defer fake.loggerMutex.Unlock() + fake.LoggerStub = stub +} + +func (fake *FakeMediaTrack) LoggerReturns(result1 logger.Logger) { + fake.loggerMutex.Lock() + defer fake.loggerMutex.Unlock() + fake.LoggerStub = nil + fake.loggerReturns = struct { + result1 logger.Logger + }{result1} +} + +func (fake *FakeMediaTrack) LoggerReturnsOnCall(i int, result1 logger.Logger) { + fake.loggerMutex.Lock() + defer fake.loggerMutex.Unlock() + fake.LoggerStub = nil + if fake.loggerReturnsOnCall == nil { + fake.loggerReturnsOnCall = make(map[int]struct { + result1 logger.Logger + }) + } + fake.loggerReturnsOnCall[i] = struct { + result1 logger.Logger + }{result1} +} + +func (fake *FakeMediaTrack) Name() string { + fake.nameMutex.Lock() + ret, specificReturn := fake.nameReturnsOnCall[len(fake.nameArgsForCall)] + fake.nameArgsForCall = append(fake.nameArgsForCall, struct { + }{}) + stub := fake.NameStub + fakeReturns := fake.nameReturns + fake.recordInvocation("Name", []interface{}{}) + fake.nameMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeMediaTrack) NameCallCount() int { + fake.nameMutex.RLock() + defer fake.nameMutex.RUnlock() + return len(fake.nameArgsForCall) +} + +func (fake *FakeMediaTrack) NameCalls(stub func() string) { + fake.nameMutex.Lock() + defer fake.nameMutex.Unlock() + fake.NameStub = stub +} + +func (fake *FakeMediaTrack) NameReturns(result1 string) { + fake.nameMutex.Lock() + defer fake.nameMutex.Unlock() + fake.NameStub = nil + fake.nameReturns = struct { + result1 string + }{result1} +} + +func (fake *FakeMediaTrack) NameReturnsOnCall(i int, result1 string) { + fake.nameMutex.Lock() + defer fake.nameMutex.Unlock() + fake.NameStub = nil + if fake.nameReturnsOnCall == nil { + fake.nameReturnsOnCall = make(map[int]struct { + result1 string + }) + } + fake.nameReturnsOnCall[i] = struct { + result1 string + }{result1} +} + +func (fake *FakeMediaTrack) OnTrackSubscribed() { + fake.onTrackSubscribedMutex.Lock() + fake.onTrackSubscribedArgsForCall = append(fake.onTrackSubscribedArgsForCall, struct { + }{}) + stub := fake.OnTrackSubscribedStub + fake.recordInvocation("OnTrackSubscribed", []interface{}{}) + fake.onTrackSubscribedMutex.Unlock() + if stub != nil { + fake.OnTrackSubscribedStub() + } +} + +func (fake *FakeMediaTrack) OnTrackSubscribedCallCount() int { + fake.onTrackSubscribedMutex.RLock() + defer fake.onTrackSubscribedMutex.RUnlock() + return len(fake.onTrackSubscribedArgsForCall) +} + +func (fake *FakeMediaTrack) OnTrackSubscribedCalls(stub func()) { + fake.onTrackSubscribedMutex.Lock() + defer fake.onTrackSubscribedMutex.Unlock() + fake.OnTrackSubscribedStub = stub +} + +func (fake *FakeMediaTrack) PublisherID() livekit.ParticipantID { + fake.publisherIDMutex.Lock() + ret, specificReturn := fake.publisherIDReturnsOnCall[len(fake.publisherIDArgsForCall)] + fake.publisherIDArgsForCall = append(fake.publisherIDArgsForCall, struct { + }{}) + stub := fake.PublisherIDStub + fakeReturns := fake.publisherIDReturns + fake.recordInvocation("PublisherID", []interface{}{}) + fake.publisherIDMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeMediaTrack) PublisherIDCallCount() int { + fake.publisherIDMutex.RLock() + defer fake.publisherIDMutex.RUnlock() + return len(fake.publisherIDArgsForCall) +} + +func (fake *FakeMediaTrack) PublisherIDCalls(stub func() livekit.ParticipantID) { + fake.publisherIDMutex.Lock() + defer fake.publisherIDMutex.Unlock() + fake.PublisherIDStub = stub +} + +func (fake *FakeMediaTrack) PublisherIDReturns(result1 livekit.ParticipantID) { + fake.publisherIDMutex.Lock() + defer fake.publisherIDMutex.Unlock() + fake.PublisherIDStub = nil + fake.publisherIDReturns = struct { + result1 livekit.ParticipantID + }{result1} +} + +func (fake *FakeMediaTrack) PublisherIDReturnsOnCall(i int, result1 livekit.ParticipantID) { + fake.publisherIDMutex.Lock() + defer fake.publisherIDMutex.Unlock() + fake.PublisherIDStub = nil + if fake.publisherIDReturnsOnCall == nil { + fake.publisherIDReturnsOnCall = make(map[int]struct { + result1 livekit.ParticipantID + }) + } + fake.publisherIDReturnsOnCall[i] = struct { + result1 livekit.ParticipantID + }{result1} +} + +func (fake *FakeMediaTrack) PublisherIdentity() livekit.ParticipantIdentity { + fake.publisherIdentityMutex.Lock() + ret, specificReturn := fake.publisherIdentityReturnsOnCall[len(fake.publisherIdentityArgsForCall)] + fake.publisherIdentityArgsForCall = append(fake.publisherIdentityArgsForCall, struct { + }{}) + stub := fake.PublisherIdentityStub + fakeReturns := fake.publisherIdentityReturns + fake.recordInvocation("PublisherIdentity", []interface{}{}) + fake.publisherIdentityMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeMediaTrack) PublisherIdentityCallCount() int { + fake.publisherIdentityMutex.RLock() + defer fake.publisherIdentityMutex.RUnlock() + return len(fake.publisherIdentityArgsForCall) +} + +func (fake *FakeMediaTrack) PublisherIdentityCalls(stub func() livekit.ParticipantIdentity) { + fake.publisherIdentityMutex.Lock() + defer fake.publisherIdentityMutex.Unlock() + fake.PublisherIdentityStub = stub +} + +func (fake *FakeMediaTrack) PublisherIdentityReturns(result1 livekit.ParticipantIdentity) { + fake.publisherIdentityMutex.Lock() + defer fake.publisherIdentityMutex.Unlock() + fake.PublisherIdentityStub = nil + fake.publisherIdentityReturns = struct { + result1 livekit.ParticipantIdentity + }{result1} +} + +func (fake *FakeMediaTrack) PublisherIdentityReturnsOnCall(i int, result1 livekit.ParticipantIdentity) { + fake.publisherIdentityMutex.Lock() + defer fake.publisherIdentityMutex.Unlock() + fake.PublisherIdentityStub = nil + if fake.publisherIdentityReturnsOnCall == nil { + fake.publisherIdentityReturnsOnCall = make(map[int]struct { + result1 livekit.ParticipantIdentity + }) + } + fake.publisherIdentityReturnsOnCall[i] = struct { + result1 livekit.ParticipantIdentity + }{result1} +} + +func (fake *FakeMediaTrack) PublisherVersion() uint32 { + fake.publisherVersionMutex.Lock() + ret, specificReturn := fake.publisherVersionReturnsOnCall[len(fake.publisherVersionArgsForCall)] + fake.publisherVersionArgsForCall = append(fake.publisherVersionArgsForCall, struct { + }{}) + stub := fake.PublisherVersionStub + fakeReturns := fake.publisherVersionReturns + fake.recordInvocation("PublisherVersion", []interface{}{}) + fake.publisherVersionMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeMediaTrack) PublisherVersionCallCount() int { + fake.publisherVersionMutex.RLock() + defer fake.publisherVersionMutex.RUnlock() + return len(fake.publisherVersionArgsForCall) +} + +func (fake *FakeMediaTrack) PublisherVersionCalls(stub func() uint32) { + fake.publisherVersionMutex.Lock() + defer fake.publisherVersionMutex.Unlock() + fake.PublisherVersionStub = stub +} + +func (fake *FakeMediaTrack) PublisherVersionReturns(result1 uint32) { + fake.publisherVersionMutex.Lock() + defer fake.publisherVersionMutex.Unlock() + fake.PublisherVersionStub = nil + fake.publisherVersionReturns = struct { + result1 uint32 + }{result1} +} + +func (fake *FakeMediaTrack) PublisherVersionReturnsOnCall(i int, result1 uint32) { + fake.publisherVersionMutex.Lock() + defer fake.publisherVersionMutex.Unlock() + fake.PublisherVersionStub = nil + if fake.publisherVersionReturnsOnCall == nil { + fake.publisherVersionReturnsOnCall = make(map[int]struct { + result1 uint32 + }) + } + fake.publisherVersionReturnsOnCall[i] = struct { + result1 uint32 + }{result1} +} + +func (fake *FakeMediaTrack) Receivers() []sfu.TrackReceiver { + fake.receiversMutex.Lock() + ret, specificReturn := fake.receiversReturnsOnCall[len(fake.receiversArgsForCall)] + fake.receiversArgsForCall = append(fake.receiversArgsForCall, struct { + }{}) + stub := fake.ReceiversStub + fakeReturns := fake.receiversReturns + fake.recordInvocation("Receivers", []interface{}{}) + fake.receiversMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeMediaTrack) ReceiversCallCount() int { + fake.receiversMutex.RLock() + defer fake.receiversMutex.RUnlock() + return len(fake.receiversArgsForCall) +} + +func (fake *FakeMediaTrack) ReceiversCalls(stub func() []sfu.TrackReceiver) { + fake.receiversMutex.Lock() + defer fake.receiversMutex.Unlock() + fake.ReceiversStub = stub +} + +func (fake *FakeMediaTrack) ReceiversReturns(result1 []sfu.TrackReceiver) { + fake.receiversMutex.Lock() + defer fake.receiversMutex.Unlock() + fake.ReceiversStub = nil + fake.receiversReturns = struct { + result1 []sfu.TrackReceiver + }{result1} +} + +func (fake *FakeMediaTrack) ReceiversReturnsOnCall(i int, result1 []sfu.TrackReceiver) { + fake.receiversMutex.Lock() + defer fake.receiversMutex.Unlock() + fake.ReceiversStub = nil + if fake.receiversReturnsOnCall == nil { + fake.receiversReturnsOnCall = make(map[int]struct { + result1 []sfu.TrackReceiver + }) + } + fake.receiversReturnsOnCall[i] = struct { + result1 []sfu.TrackReceiver + }{result1} +} + +func (fake *FakeMediaTrack) RemoveSubscriber(arg1 livekit.ParticipantID, arg2 bool) { + fake.removeSubscriberMutex.Lock() + fake.removeSubscriberArgsForCall = append(fake.removeSubscriberArgsForCall, struct { + arg1 livekit.ParticipantID + arg2 bool + }{arg1, arg2}) + stub := fake.RemoveSubscriberStub + fake.recordInvocation("RemoveSubscriber", []interface{}{arg1, arg2}) + fake.removeSubscriberMutex.Unlock() + if stub != nil { + fake.RemoveSubscriberStub(arg1, arg2) + } +} + +func (fake *FakeMediaTrack) RemoveSubscriberCallCount() int { + fake.removeSubscriberMutex.RLock() + defer fake.removeSubscriberMutex.RUnlock() + return len(fake.removeSubscriberArgsForCall) +} + +func (fake *FakeMediaTrack) RemoveSubscriberCalls(stub func(livekit.ParticipantID, bool)) { + fake.removeSubscriberMutex.Lock() + defer fake.removeSubscriberMutex.Unlock() + fake.RemoveSubscriberStub = stub +} + +func (fake *FakeMediaTrack) RemoveSubscriberArgsForCall(i int) (livekit.ParticipantID, bool) { + fake.removeSubscriberMutex.RLock() + defer fake.removeSubscriberMutex.RUnlock() + argsForCall := fake.removeSubscriberArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeMediaTrack) RevokeDisallowedSubscribers(arg1 []livekit.ParticipantIdentity) []livekit.ParticipantIdentity { + var arg1Copy []livekit.ParticipantIdentity + if arg1 != nil { + arg1Copy = make([]livekit.ParticipantIdentity, len(arg1)) + copy(arg1Copy, arg1) + } + fake.revokeDisallowedSubscribersMutex.Lock() + ret, specificReturn := fake.revokeDisallowedSubscribersReturnsOnCall[len(fake.revokeDisallowedSubscribersArgsForCall)] + fake.revokeDisallowedSubscribersArgsForCall = append(fake.revokeDisallowedSubscribersArgsForCall, struct { + arg1 []livekit.ParticipantIdentity + }{arg1Copy}) + stub := fake.RevokeDisallowedSubscribersStub + fakeReturns := fake.revokeDisallowedSubscribersReturns + fake.recordInvocation("RevokeDisallowedSubscribers", []interface{}{arg1Copy}) + fake.revokeDisallowedSubscribersMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeMediaTrack) RevokeDisallowedSubscribersCallCount() int { + fake.revokeDisallowedSubscribersMutex.RLock() + defer fake.revokeDisallowedSubscribersMutex.RUnlock() + return len(fake.revokeDisallowedSubscribersArgsForCall) +} + +func (fake *FakeMediaTrack) RevokeDisallowedSubscribersCalls(stub func([]livekit.ParticipantIdentity) []livekit.ParticipantIdentity) { + fake.revokeDisallowedSubscribersMutex.Lock() + defer fake.revokeDisallowedSubscribersMutex.Unlock() + fake.RevokeDisallowedSubscribersStub = stub +} + +func (fake *FakeMediaTrack) RevokeDisallowedSubscribersArgsForCall(i int) []livekit.ParticipantIdentity { + fake.revokeDisallowedSubscribersMutex.RLock() + defer fake.revokeDisallowedSubscribersMutex.RUnlock() + argsForCall := fake.revokeDisallowedSubscribersArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeMediaTrack) RevokeDisallowedSubscribersReturns(result1 []livekit.ParticipantIdentity) { + fake.revokeDisallowedSubscribersMutex.Lock() + defer fake.revokeDisallowedSubscribersMutex.Unlock() + fake.RevokeDisallowedSubscribersStub = nil + fake.revokeDisallowedSubscribersReturns = struct { + result1 []livekit.ParticipantIdentity + }{result1} +} + +func (fake *FakeMediaTrack) RevokeDisallowedSubscribersReturnsOnCall(i int, result1 []livekit.ParticipantIdentity) { + fake.revokeDisallowedSubscribersMutex.Lock() + defer fake.revokeDisallowedSubscribersMutex.Unlock() + fake.RevokeDisallowedSubscribersStub = nil + if fake.revokeDisallowedSubscribersReturnsOnCall == nil { + fake.revokeDisallowedSubscribersReturnsOnCall = make(map[int]struct { + result1 []livekit.ParticipantIdentity + }) + } + fake.revokeDisallowedSubscribersReturnsOnCall[i] = struct { + result1 []livekit.ParticipantIdentity + }{result1} +} + +func (fake *FakeMediaTrack) SetMuted(arg1 bool) { + fake.setMutedMutex.Lock() + fake.setMutedArgsForCall = append(fake.setMutedArgsForCall, struct { + arg1 bool + }{arg1}) + stub := fake.SetMutedStub + fake.recordInvocation("SetMuted", []interface{}{arg1}) + fake.setMutedMutex.Unlock() + if stub != nil { + fake.SetMutedStub(arg1) + } +} + +func (fake *FakeMediaTrack) SetMutedCallCount() int { + fake.setMutedMutex.RLock() + defer fake.setMutedMutex.RUnlock() + return len(fake.setMutedArgsForCall) +} + +func (fake *FakeMediaTrack) SetMutedCalls(stub func(bool)) { + fake.setMutedMutex.Lock() + defer fake.setMutedMutex.Unlock() + fake.SetMutedStub = stub +} + +func (fake *FakeMediaTrack) SetMutedArgsForCall(i int) bool { + fake.setMutedMutex.RLock() + defer fake.setMutedMutex.RUnlock() + argsForCall := fake.setMutedArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeMediaTrack) Source() livekit.TrackSource { + fake.sourceMutex.Lock() + ret, specificReturn := fake.sourceReturnsOnCall[len(fake.sourceArgsForCall)] + fake.sourceArgsForCall = append(fake.sourceArgsForCall, struct { + }{}) + stub := fake.SourceStub + fakeReturns := fake.sourceReturns + fake.recordInvocation("Source", []interface{}{}) + fake.sourceMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeMediaTrack) SourceCallCount() int { + fake.sourceMutex.RLock() + defer fake.sourceMutex.RUnlock() + return len(fake.sourceArgsForCall) +} + +func (fake *FakeMediaTrack) SourceCalls(stub func() livekit.TrackSource) { + fake.sourceMutex.Lock() + defer fake.sourceMutex.Unlock() + fake.SourceStub = stub +} + +func (fake *FakeMediaTrack) SourceReturns(result1 livekit.TrackSource) { + fake.sourceMutex.Lock() + defer fake.sourceMutex.Unlock() + fake.SourceStub = nil + fake.sourceReturns = struct { + result1 livekit.TrackSource + }{result1} +} + +func (fake *FakeMediaTrack) SourceReturnsOnCall(i int, result1 livekit.TrackSource) { + fake.sourceMutex.Lock() + defer fake.sourceMutex.Unlock() + fake.SourceStub = nil + if fake.sourceReturnsOnCall == nil { + fake.sourceReturnsOnCall = make(map[int]struct { + result1 livekit.TrackSource + }) + } + fake.sourceReturnsOnCall[i] = struct { + result1 livekit.TrackSource + }{result1} +} + +func (fake *FakeMediaTrack) Stream() string { + fake.streamMutex.Lock() + ret, specificReturn := fake.streamReturnsOnCall[len(fake.streamArgsForCall)] + fake.streamArgsForCall = append(fake.streamArgsForCall, struct { + }{}) + stub := fake.StreamStub + fakeReturns := fake.streamReturns + fake.recordInvocation("Stream", []interface{}{}) + fake.streamMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeMediaTrack) StreamCallCount() int { + fake.streamMutex.RLock() + defer fake.streamMutex.RUnlock() + return len(fake.streamArgsForCall) +} + +func (fake *FakeMediaTrack) StreamCalls(stub func() string) { + fake.streamMutex.Lock() + defer fake.streamMutex.Unlock() + fake.StreamStub = stub +} + +func (fake *FakeMediaTrack) StreamReturns(result1 string) { + fake.streamMutex.Lock() + defer fake.streamMutex.Unlock() + fake.StreamStub = nil + fake.streamReturns = struct { + result1 string + }{result1} +} + +func (fake *FakeMediaTrack) StreamReturnsOnCall(i int, result1 string) { + fake.streamMutex.Lock() + defer fake.streamMutex.Unlock() + fake.StreamStub = nil + if fake.streamReturnsOnCall == nil { + fake.streamReturnsOnCall = make(map[int]struct { + result1 string + }) + } + fake.streamReturnsOnCall[i] = struct { + result1 string + }{result1} +} + +func (fake *FakeMediaTrack) ToProto() *livekit.TrackInfo { + fake.toProtoMutex.Lock() + ret, specificReturn := fake.toProtoReturnsOnCall[len(fake.toProtoArgsForCall)] + fake.toProtoArgsForCall = append(fake.toProtoArgsForCall, struct { + }{}) + stub := fake.ToProtoStub + fakeReturns := fake.toProtoReturns + fake.recordInvocation("ToProto", []interface{}{}) + fake.toProtoMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeMediaTrack) ToProtoCallCount() int { + fake.toProtoMutex.RLock() + defer fake.toProtoMutex.RUnlock() + return len(fake.toProtoArgsForCall) +} + +func (fake *FakeMediaTrack) ToProtoCalls(stub func() *livekit.TrackInfo) { + fake.toProtoMutex.Lock() + defer fake.toProtoMutex.Unlock() + fake.ToProtoStub = stub +} + +func (fake *FakeMediaTrack) ToProtoReturns(result1 *livekit.TrackInfo) { + fake.toProtoMutex.Lock() + defer fake.toProtoMutex.Unlock() + fake.ToProtoStub = nil + fake.toProtoReturns = struct { + result1 *livekit.TrackInfo + }{result1} +} + +func (fake *FakeMediaTrack) ToProtoReturnsOnCall(i int, result1 *livekit.TrackInfo) { + fake.toProtoMutex.Lock() + defer fake.toProtoMutex.Unlock() + fake.ToProtoStub = nil + if fake.toProtoReturnsOnCall == nil { + fake.toProtoReturnsOnCall = make(map[int]struct { + result1 *livekit.TrackInfo + }) + } + fake.toProtoReturnsOnCall[i] = struct { + result1 *livekit.TrackInfo + }{result1} +} + +func (fake *FakeMediaTrack) UpdateAudioTrack(arg1 *livekit.UpdateLocalAudioTrack) { + fake.updateAudioTrackMutex.Lock() + fake.updateAudioTrackArgsForCall = append(fake.updateAudioTrackArgsForCall, struct { + arg1 *livekit.UpdateLocalAudioTrack + }{arg1}) + stub := fake.UpdateAudioTrackStub + fake.recordInvocation("UpdateAudioTrack", []interface{}{arg1}) + fake.updateAudioTrackMutex.Unlock() + if stub != nil { + fake.UpdateAudioTrackStub(arg1) + } +} + +func (fake *FakeMediaTrack) UpdateAudioTrackCallCount() int { + fake.updateAudioTrackMutex.RLock() + defer fake.updateAudioTrackMutex.RUnlock() + return len(fake.updateAudioTrackArgsForCall) +} + +func (fake *FakeMediaTrack) UpdateAudioTrackCalls(stub func(*livekit.UpdateLocalAudioTrack)) { + fake.updateAudioTrackMutex.Lock() + defer fake.updateAudioTrackMutex.Unlock() + fake.UpdateAudioTrackStub = stub +} + +func (fake *FakeMediaTrack) UpdateAudioTrackArgsForCall(i int) *livekit.UpdateLocalAudioTrack { + fake.updateAudioTrackMutex.RLock() + defer fake.updateAudioTrackMutex.RUnlock() + argsForCall := fake.updateAudioTrackArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeMediaTrack) UpdateTrackInfo(arg1 *livekit.TrackInfo) { + fake.updateTrackInfoMutex.Lock() + fake.updateTrackInfoArgsForCall = append(fake.updateTrackInfoArgsForCall, struct { + arg1 *livekit.TrackInfo + }{arg1}) + stub := fake.UpdateTrackInfoStub + fake.recordInvocation("UpdateTrackInfo", []interface{}{arg1}) + fake.updateTrackInfoMutex.Unlock() + if stub != nil { + fake.UpdateTrackInfoStub(arg1) + } +} + +func (fake *FakeMediaTrack) UpdateTrackInfoCallCount() int { + fake.updateTrackInfoMutex.RLock() + defer fake.updateTrackInfoMutex.RUnlock() + return len(fake.updateTrackInfoArgsForCall) +} + +func (fake *FakeMediaTrack) UpdateTrackInfoCalls(stub func(*livekit.TrackInfo)) { + fake.updateTrackInfoMutex.Lock() + defer fake.updateTrackInfoMutex.Unlock() + fake.UpdateTrackInfoStub = stub +} + +func (fake *FakeMediaTrack) UpdateTrackInfoArgsForCall(i int) *livekit.TrackInfo { + fake.updateTrackInfoMutex.RLock() + defer fake.updateTrackInfoMutex.RUnlock() + argsForCall := fake.updateTrackInfoArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeMediaTrack) UpdateVideoTrack(arg1 *livekit.UpdateLocalVideoTrack) { + fake.updateVideoTrackMutex.Lock() + fake.updateVideoTrackArgsForCall = append(fake.updateVideoTrackArgsForCall, struct { + arg1 *livekit.UpdateLocalVideoTrack + }{arg1}) + stub := fake.UpdateVideoTrackStub + fake.recordInvocation("UpdateVideoTrack", []interface{}{arg1}) + fake.updateVideoTrackMutex.Unlock() + if stub != nil { + fake.UpdateVideoTrackStub(arg1) + } +} + +func (fake *FakeMediaTrack) UpdateVideoTrackCallCount() int { + fake.updateVideoTrackMutex.RLock() + defer fake.updateVideoTrackMutex.RUnlock() + return len(fake.updateVideoTrackArgsForCall) +} + +func (fake *FakeMediaTrack) UpdateVideoTrackCalls(stub func(*livekit.UpdateLocalVideoTrack)) { + fake.updateVideoTrackMutex.Lock() + defer fake.updateVideoTrackMutex.Unlock() + fake.UpdateVideoTrackStub = stub +} + +func (fake *FakeMediaTrack) UpdateVideoTrackArgsForCall(i int) *livekit.UpdateLocalVideoTrack { + fake.updateVideoTrackMutex.RLock() + defer fake.updateVideoTrackMutex.RUnlock() + argsForCall := fake.updateVideoTrackArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeMediaTrack) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeMediaTrack) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ types.MediaTrack = new(FakeMediaTrack) diff --git a/livekit/pkg/rtc/types/typesfakes/fake_participant.go b/livekit/pkg/rtc/types/typesfakes/fake_participant.go new file mode 100644 index 0000000..ed38762 --- /dev/null +++ b/livekit/pkg/rtc/types/typesfakes/fake_participant.go @@ -0,0 +1,2124 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package typesfakes + +import ( + "sync" + "time" + + "github.com/livekit/livekit-server/pkg/rtc/datatrack" + "github.com/livekit/livekit-server/pkg/rtc/types" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/utils" +) + +type FakeParticipant struct { + CanSkipBroadcastStub func() bool + canSkipBroadcastMutex sync.RWMutex + canSkipBroadcastArgsForCall []struct { + } + canSkipBroadcastReturns struct { + result1 bool + } + canSkipBroadcastReturnsOnCall map[int]struct { + result1 bool + } + CloseStub func(bool, types.ParticipantCloseReason, bool) error + closeMutex sync.RWMutex + closeArgsForCall []struct { + arg1 bool + arg2 types.ParticipantCloseReason + arg3 bool + } + closeReturns struct { + result1 error + } + closeReturnsOnCall map[int]struct { + result1 error + } + CloseReasonStub func() types.ParticipantCloseReason + closeReasonMutex sync.RWMutex + closeReasonArgsForCall []struct { + } + closeReasonReturns struct { + result1 types.ParticipantCloseReason + } + closeReasonReturnsOnCall map[int]struct { + result1 types.ParticipantCloseReason + } + ConnectedAtStub func() time.Time + connectedAtMutex sync.RWMutex + connectedAtArgsForCall []struct { + } + connectedAtReturns struct { + result1 time.Time + } + connectedAtReturnsOnCall map[int]struct { + result1 time.Time + } + DebugInfoStub func() map[string]any + debugInfoMutex sync.RWMutex + debugInfoArgsForCall []struct { + } + debugInfoReturns struct { + result1 map[string]any + } + debugInfoReturnsOnCall map[int]struct { + result1 map[string]any + } + GetAudioLevelStub func() (float64, bool) + getAudioLevelMutex sync.RWMutex + getAudioLevelArgsForCall []struct { + } + getAudioLevelReturns struct { + result1 float64 + result2 bool + } + getAudioLevelReturnsOnCall map[int]struct { + result1 float64 + result2 bool + } + GetLoggerStub func() logger.Logger + getLoggerMutex sync.RWMutex + getLoggerArgsForCall []struct { + } + getLoggerReturns struct { + result1 logger.Logger + } + getLoggerReturnsOnCall map[int]struct { + result1 logger.Logger + } + GetParticipantListenerStub func() types.ParticipantListener + getParticipantListenerMutex sync.RWMutex + getParticipantListenerArgsForCall []struct { + } + getParticipantListenerReturns struct { + result1 types.ParticipantListener + } + getParticipantListenerReturnsOnCall map[int]struct { + result1 types.ParticipantListener + } + GetPublishedDataTrackStub func(uint16) types.DataTrack + getPublishedDataTrackMutex sync.RWMutex + getPublishedDataTrackArgsForCall []struct { + arg1 uint16 + } + getPublishedDataTrackReturns struct { + result1 types.DataTrack + } + getPublishedDataTrackReturnsOnCall map[int]struct { + result1 types.DataTrack + } + GetPublishedDataTracksStub func() []types.DataTrack + getPublishedDataTracksMutex sync.RWMutex + getPublishedDataTracksArgsForCall []struct { + } + getPublishedDataTracksReturns struct { + result1 []types.DataTrack + } + getPublishedDataTracksReturnsOnCall map[int]struct { + result1 []types.DataTrack + } + GetPublishedTrackStub func(livekit.TrackID) types.MediaTrack + getPublishedTrackMutex sync.RWMutex + getPublishedTrackArgsForCall []struct { + arg1 livekit.TrackID + } + getPublishedTrackReturns struct { + result1 types.MediaTrack + } + getPublishedTrackReturnsOnCall map[int]struct { + result1 types.MediaTrack + } + GetPublishedTracksStub func() []types.MediaTrack + getPublishedTracksMutex sync.RWMutex + getPublishedTracksArgsForCall []struct { + } + getPublishedTracksReturns struct { + result1 []types.MediaTrack + } + getPublishedTracksReturnsOnCall map[int]struct { + result1 []types.MediaTrack + } + HandleReceivedDataTrackMessageStub func([]byte, *datatrack.Packet, int64) + handleReceivedDataTrackMessageMutex sync.RWMutex + handleReceivedDataTrackMessageArgsForCall []struct { + arg1 []byte + arg2 *datatrack.Packet + arg3 int64 + } + HasPermissionStub func(livekit.TrackID, livekit.ParticipantIdentity) bool + hasPermissionMutex sync.RWMutex + hasPermissionArgsForCall []struct { + arg1 livekit.TrackID + arg2 livekit.ParticipantIdentity + } + hasPermissionReturns struct { + result1 bool + } + hasPermissionReturnsOnCall map[int]struct { + result1 bool + } + HiddenStub func() bool + hiddenMutex sync.RWMutex + hiddenArgsForCall []struct { + } + hiddenReturns struct { + result1 bool + } + hiddenReturnsOnCall map[int]struct { + result1 bool + } + IDStub func() livekit.ParticipantID + iDMutex sync.RWMutex + iDArgsForCall []struct { + } + iDReturns struct { + result1 livekit.ParticipantID + } + iDReturnsOnCall map[int]struct { + result1 livekit.ParticipantID + } + IdentityStub func() livekit.ParticipantIdentity + identityMutex sync.RWMutex + identityArgsForCall []struct { + } + identityReturns struct { + result1 livekit.ParticipantIdentity + } + identityReturnsOnCall map[int]struct { + result1 livekit.ParticipantIdentity + } + IsAgentStub func() bool + isAgentMutex sync.RWMutex + isAgentArgsForCall []struct { + } + isAgentReturns struct { + result1 bool + } + isAgentReturnsOnCall map[int]struct { + result1 bool + } + IsClosedStub func() bool + isClosedMutex sync.RWMutex + isClosedArgsForCall []struct { + } + isClosedReturns struct { + result1 bool + } + isClosedReturnsOnCall map[int]struct { + result1 bool + } + IsDependentStub func() bool + isDependentMutex sync.RWMutex + isDependentArgsForCall []struct { + } + isDependentReturns struct { + result1 bool + } + isDependentReturnsOnCall map[int]struct { + result1 bool + } + IsDisconnectedStub func() bool + isDisconnectedMutex sync.RWMutex + isDisconnectedArgsForCall []struct { + } + isDisconnectedReturns struct { + result1 bool + } + isDisconnectedReturnsOnCall map[int]struct { + result1 bool + } + IsPublisherStub func() bool + isPublisherMutex sync.RWMutex + isPublisherArgsForCall []struct { + } + isPublisherReturns struct { + result1 bool + } + isPublisherReturnsOnCall map[int]struct { + result1 bool + } + IsRecorderStub func() bool + isRecorderMutex sync.RWMutex + isRecorderArgsForCall []struct { + } + isRecorderReturns struct { + result1 bool + } + isRecorderReturnsOnCall map[int]struct { + result1 bool + } + KindStub func() livekit.ParticipantInfo_Kind + kindMutex sync.RWMutex + kindArgsForCall []struct { + } + kindReturns struct { + result1 livekit.ParticipantInfo_Kind + } + kindReturnsOnCall map[int]struct { + result1 livekit.ParticipantInfo_Kind + } + MigrateStateStub func() types.MigrateState + migrateStateMutex sync.RWMutex + migrateStateArgsForCall []struct { + } + migrateStateReturns struct { + result1 types.MigrateState + } + migrateStateReturnsOnCall map[int]struct { + result1 types.MigrateState + } + RemovePublishedDataTrackStub func(types.DataTrack) + removePublishedDataTrackMutex sync.RWMutex + removePublishedDataTrackArgsForCall []struct { + arg1 types.DataTrack + } + RemovePublishedTrackStub func(types.MediaTrack, bool) + removePublishedTrackMutex sync.RWMutex + removePublishedTrackArgsForCall []struct { + arg1 types.MediaTrack + arg2 bool + } + StateStub func() livekit.ParticipantInfo_State + stateMutex sync.RWMutex + stateArgsForCall []struct { + } + stateReturns struct { + result1 livekit.ParticipantInfo_State + } + stateReturnsOnCall map[int]struct { + result1 livekit.ParticipantInfo_State + } + SubscriptionPermissionStub func() (*livekit.SubscriptionPermission, utils.TimedVersion) + subscriptionPermissionMutex sync.RWMutex + subscriptionPermissionArgsForCall []struct { + } + subscriptionPermissionReturns struct { + result1 *livekit.SubscriptionPermission + result2 utils.TimedVersion + } + subscriptionPermissionReturnsOnCall map[int]struct { + result1 *livekit.SubscriptionPermission + result2 utils.TimedVersion + } + ToProtoStub func() *livekit.ParticipantInfo + toProtoMutex sync.RWMutex + toProtoArgsForCall []struct { + } + toProtoReturns struct { + result1 *livekit.ParticipantInfo + } + toProtoReturnsOnCall map[int]struct { + result1 *livekit.ParticipantInfo + } + ToProtoWithVersionStub func() (*livekit.ParticipantInfo, utils.TimedVersion) + toProtoWithVersionMutex sync.RWMutex + toProtoWithVersionArgsForCall []struct { + } + toProtoWithVersionReturns struct { + result1 *livekit.ParticipantInfo + result2 utils.TimedVersion + } + toProtoWithVersionReturnsOnCall map[int]struct { + result1 *livekit.ParticipantInfo + result2 utils.TimedVersion + } + UpdateSubscriptionPermissionStub func(*livekit.SubscriptionPermission, utils.TimedVersion, func(participantID livekit.ParticipantID) types.LocalParticipant) error + updateSubscriptionPermissionMutex sync.RWMutex + updateSubscriptionPermissionArgsForCall []struct { + arg1 *livekit.SubscriptionPermission + arg2 utils.TimedVersion + arg3 func(participantID livekit.ParticipantID) types.LocalParticipant + } + updateSubscriptionPermissionReturns struct { + result1 error + } + updateSubscriptionPermissionReturnsOnCall map[int]struct { + result1 error + } + VersionStub func() utils.TimedVersion + versionMutex sync.RWMutex + versionArgsForCall []struct { + } + versionReturns struct { + result1 utils.TimedVersion + } + versionReturnsOnCall map[int]struct { + result1 utils.TimedVersion + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeParticipant) CanSkipBroadcast() bool { + fake.canSkipBroadcastMutex.Lock() + ret, specificReturn := fake.canSkipBroadcastReturnsOnCall[len(fake.canSkipBroadcastArgsForCall)] + fake.canSkipBroadcastArgsForCall = append(fake.canSkipBroadcastArgsForCall, struct { + }{}) + stub := fake.CanSkipBroadcastStub + fakeReturns := fake.canSkipBroadcastReturns + fake.recordInvocation("CanSkipBroadcast", []interface{}{}) + fake.canSkipBroadcastMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeParticipant) CanSkipBroadcastCallCount() int { + fake.canSkipBroadcastMutex.RLock() + defer fake.canSkipBroadcastMutex.RUnlock() + return len(fake.canSkipBroadcastArgsForCall) +} + +func (fake *FakeParticipant) CanSkipBroadcastCalls(stub func() bool) { + fake.canSkipBroadcastMutex.Lock() + defer fake.canSkipBroadcastMutex.Unlock() + fake.CanSkipBroadcastStub = stub +} + +func (fake *FakeParticipant) CanSkipBroadcastReturns(result1 bool) { + fake.canSkipBroadcastMutex.Lock() + defer fake.canSkipBroadcastMutex.Unlock() + fake.CanSkipBroadcastStub = nil + fake.canSkipBroadcastReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeParticipant) CanSkipBroadcastReturnsOnCall(i int, result1 bool) { + fake.canSkipBroadcastMutex.Lock() + defer fake.canSkipBroadcastMutex.Unlock() + fake.CanSkipBroadcastStub = nil + if fake.canSkipBroadcastReturnsOnCall == nil { + fake.canSkipBroadcastReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.canSkipBroadcastReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + +func (fake *FakeParticipant) Close(arg1 bool, arg2 types.ParticipantCloseReason, arg3 bool) error { + fake.closeMutex.Lock() + ret, specificReturn := fake.closeReturnsOnCall[len(fake.closeArgsForCall)] + fake.closeArgsForCall = append(fake.closeArgsForCall, struct { + arg1 bool + arg2 types.ParticipantCloseReason + arg3 bool + }{arg1, arg2, arg3}) + stub := fake.CloseStub + fakeReturns := fake.closeReturns + fake.recordInvocation("Close", []interface{}{arg1, arg2, arg3}) + fake.closeMutex.Unlock() + if stub != nil { + return stub(arg1, arg2, arg3) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeParticipant) CloseCallCount() int { + fake.closeMutex.RLock() + defer fake.closeMutex.RUnlock() + return len(fake.closeArgsForCall) +} + +func (fake *FakeParticipant) CloseCalls(stub func(bool, types.ParticipantCloseReason, bool) error) { + fake.closeMutex.Lock() + defer fake.closeMutex.Unlock() + fake.CloseStub = stub +} + +func (fake *FakeParticipant) CloseArgsForCall(i int) (bool, types.ParticipantCloseReason, bool) { + fake.closeMutex.RLock() + defer fake.closeMutex.RUnlock() + argsForCall := fake.closeArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeParticipant) CloseReturns(result1 error) { + fake.closeMutex.Lock() + defer fake.closeMutex.Unlock() + fake.CloseStub = nil + fake.closeReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeParticipant) CloseReturnsOnCall(i int, result1 error) { + fake.closeMutex.Lock() + defer fake.closeMutex.Unlock() + fake.CloseStub = nil + if fake.closeReturnsOnCall == nil { + fake.closeReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.closeReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeParticipant) CloseReason() types.ParticipantCloseReason { + fake.closeReasonMutex.Lock() + ret, specificReturn := fake.closeReasonReturnsOnCall[len(fake.closeReasonArgsForCall)] + fake.closeReasonArgsForCall = append(fake.closeReasonArgsForCall, struct { + }{}) + stub := fake.CloseReasonStub + fakeReturns := fake.closeReasonReturns + fake.recordInvocation("CloseReason", []interface{}{}) + fake.closeReasonMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeParticipant) CloseReasonCallCount() int { + fake.closeReasonMutex.RLock() + defer fake.closeReasonMutex.RUnlock() + return len(fake.closeReasonArgsForCall) +} + +func (fake *FakeParticipant) CloseReasonCalls(stub func() types.ParticipantCloseReason) { + fake.closeReasonMutex.Lock() + defer fake.closeReasonMutex.Unlock() + fake.CloseReasonStub = stub +} + +func (fake *FakeParticipant) CloseReasonReturns(result1 types.ParticipantCloseReason) { + fake.closeReasonMutex.Lock() + defer fake.closeReasonMutex.Unlock() + fake.CloseReasonStub = nil + fake.closeReasonReturns = struct { + result1 types.ParticipantCloseReason + }{result1} +} + +func (fake *FakeParticipant) CloseReasonReturnsOnCall(i int, result1 types.ParticipantCloseReason) { + fake.closeReasonMutex.Lock() + defer fake.closeReasonMutex.Unlock() + fake.CloseReasonStub = nil + if fake.closeReasonReturnsOnCall == nil { + fake.closeReasonReturnsOnCall = make(map[int]struct { + result1 types.ParticipantCloseReason + }) + } + fake.closeReasonReturnsOnCall[i] = struct { + result1 types.ParticipantCloseReason + }{result1} +} + +func (fake *FakeParticipant) ConnectedAt() time.Time { + fake.connectedAtMutex.Lock() + ret, specificReturn := fake.connectedAtReturnsOnCall[len(fake.connectedAtArgsForCall)] + fake.connectedAtArgsForCall = append(fake.connectedAtArgsForCall, struct { + }{}) + stub := fake.ConnectedAtStub + fakeReturns := fake.connectedAtReturns + fake.recordInvocation("ConnectedAt", []interface{}{}) + fake.connectedAtMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeParticipant) ConnectedAtCallCount() int { + fake.connectedAtMutex.RLock() + defer fake.connectedAtMutex.RUnlock() + return len(fake.connectedAtArgsForCall) +} + +func (fake *FakeParticipant) ConnectedAtCalls(stub func() time.Time) { + fake.connectedAtMutex.Lock() + defer fake.connectedAtMutex.Unlock() + fake.ConnectedAtStub = stub +} + +func (fake *FakeParticipant) ConnectedAtReturns(result1 time.Time) { + fake.connectedAtMutex.Lock() + defer fake.connectedAtMutex.Unlock() + fake.ConnectedAtStub = nil + fake.connectedAtReturns = struct { + result1 time.Time + }{result1} +} + +func (fake *FakeParticipant) ConnectedAtReturnsOnCall(i int, result1 time.Time) { + fake.connectedAtMutex.Lock() + defer fake.connectedAtMutex.Unlock() + fake.ConnectedAtStub = nil + if fake.connectedAtReturnsOnCall == nil { + fake.connectedAtReturnsOnCall = make(map[int]struct { + result1 time.Time + }) + } + fake.connectedAtReturnsOnCall[i] = struct { + result1 time.Time + }{result1} +} + +func (fake *FakeParticipant) DebugInfo() map[string]any { + fake.debugInfoMutex.Lock() + ret, specificReturn := fake.debugInfoReturnsOnCall[len(fake.debugInfoArgsForCall)] + fake.debugInfoArgsForCall = append(fake.debugInfoArgsForCall, struct { + }{}) + stub := fake.DebugInfoStub + fakeReturns := fake.debugInfoReturns + fake.recordInvocation("DebugInfo", []interface{}{}) + fake.debugInfoMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeParticipant) DebugInfoCallCount() int { + fake.debugInfoMutex.RLock() + defer fake.debugInfoMutex.RUnlock() + return len(fake.debugInfoArgsForCall) +} + +func (fake *FakeParticipant) DebugInfoCalls(stub func() map[string]any) { + fake.debugInfoMutex.Lock() + defer fake.debugInfoMutex.Unlock() + fake.DebugInfoStub = stub +} + +func (fake *FakeParticipant) DebugInfoReturns(result1 map[string]any) { + fake.debugInfoMutex.Lock() + defer fake.debugInfoMutex.Unlock() + fake.DebugInfoStub = nil + fake.debugInfoReturns = struct { + result1 map[string]any + }{result1} +} + +func (fake *FakeParticipant) DebugInfoReturnsOnCall(i int, result1 map[string]any) { + fake.debugInfoMutex.Lock() + defer fake.debugInfoMutex.Unlock() + fake.DebugInfoStub = nil + if fake.debugInfoReturnsOnCall == nil { + fake.debugInfoReturnsOnCall = make(map[int]struct { + result1 map[string]any + }) + } + fake.debugInfoReturnsOnCall[i] = struct { + result1 map[string]any + }{result1} +} + +func (fake *FakeParticipant) GetAudioLevel() (float64, bool) { + fake.getAudioLevelMutex.Lock() + ret, specificReturn := fake.getAudioLevelReturnsOnCall[len(fake.getAudioLevelArgsForCall)] + fake.getAudioLevelArgsForCall = append(fake.getAudioLevelArgsForCall, struct { + }{}) + stub := fake.GetAudioLevelStub + fakeReturns := fake.getAudioLevelReturns + fake.recordInvocation("GetAudioLevel", []interface{}{}) + fake.getAudioLevelMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeParticipant) GetAudioLevelCallCount() int { + fake.getAudioLevelMutex.RLock() + defer fake.getAudioLevelMutex.RUnlock() + return len(fake.getAudioLevelArgsForCall) +} + +func (fake *FakeParticipant) GetAudioLevelCalls(stub func() (float64, bool)) { + fake.getAudioLevelMutex.Lock() + defer fake.getAudioLevelMutex.Unlock() + fake.GetAudioLevelStub = stub +} + +func (fake *FakeParticipant) GetAudioLevelReturns(result1 float64, result2 bool) { + fake.getAudioLevelMutex.Lock() + defer fake.getAudioLevelMutex.Unlock() + fake.GetAudioLevelStub = nil + fake.getAudioLevelReturns = struct { + result1 float64 + result2 bool + }{result1, result2} +} + +func (fake *FakeParticipant) GetAudioLevelReturnsOnCall(i int, result1 float64, result2 bool) { + fake.getAudioLevelMutex.Lock() + defer fake.getAudioLevelMutex.Unlock() + fake.GetAudioLevelStub = nil + if fake.getAudioLevelReturnsOnCall == nil { + fake.getAudioLevelReturnsOnCall = make(map[int]struct { + result1 float64 + result2 bool + }) + } + fake.getAudioLevelReturnsOnCall[i] = struct { + result1 float64 + result2 bool + }{result1, result2} +} + +func (fake *FakeParticipant) GetLogger() logger.Logger { + fake.getLoggerMutex.Lock() + ret, specificReturn := fake.getLoggerReturnsOnCall[len(fake.getLoggerArgsForCall)] + fake.getLoggerArgsForCall = append(fake.getLoggerArgsForCall, struct { + }{}) + stub := fake.GetLoggerStub + fakeReturns := fake.getLoggerReturns + fake.recordInvocation("GetLogger", []interface{}{}) + fake.getLoggerMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeParticipant) GetLoggerCallCount() int { + fake.getLoggerMutex.RLock() + defer fake.getLoggerMutex.RUnlock() + return len(fake.getLoggerArgsForCall) +} + +func (fake *FakeParticipant) GetLoggerCalls(stub func() logger.Logger) { + fake.getLoggerMutex.Lock() + defer fake.getLoggerMutex.Unlock() + fake.GetLoggerStub = stub +} + +func (fake *FakeParticipant) GetLoggerReturns(result1 logger.Logger) { + fake.getLoggerMutex.Lock() + defer fake.getLoggerMutex.Unlock() + fake.GetLoggerStub = nil + fake.getLoggerReturns = struct { + result1 logger.Logger + }{result1} +} + +func (fake *FakeParticipant) GetLoggerReturnsOnCall(i int, result1 logger.Logger) { + fake.getLoggerMutex.Lock() + defer fake.getLoggerMutex.Unlock() + fake.GetLoggerStub = nil + if fake.getLoggerReturnsOnCall == nil { + fake.getLoggerReturnsOnCall = make(map[int]struct { + result1 logger.Logger + }) + } + fake.getLoggerReturnsOnCall[i] = struct { + result1 logger.Logger + }{result1} +} + +func (fake *FakeParticipant) GetParticipantListener() types.ParticipantListener { + fake.getParticipantListenerMutex.Lock() + ret, specificReturn := fake.getParticipantListenerReturnsOnCall[len(fake.getParticipantListenerArgsForCall)] + fake.getParticipantListenerArgsForCall = append(fake.getParticipantListenerArgsForCall, struct { + }{}) + stub := fake.GetParticipantListenerStub + fakeReturns := fake.getParticipantListenerReturns + fake.recordInvocation("GetParticipantListener", []interface{}{}) + fake.getParticipantListenerMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeParticipant) GetParticipantListenerCallCount() int { + fake.getParticipantListenerMutex.RLock() + defer fake.getParticipantListenerMutex.RUnlock() + return len(fake.getParticipantListenerArgsForCall) +} + +func (fake *FakeParticipant) GetParticipantListenerCalls(stub func() types.ParticipantListener) { + fake.getParticipantListenerMutex.Lock() + defer fake.getParticipantListenerMutex.Unlock() + fake.GetParticipantListenerStub = stub +} + +func (fake *FakeParticipant) GetParticipantListenerReturns(result1 types.ParticipantListener) { + fake.getParticipantListenerMutex.Lock() + defer fake.getParticipantListenerMutex.Unlock() + fake.GetParticipantListenerStub = nil + fake.getParticipantListenerReturns = struct { + result1 types.ParticipantListener + }{result1} +} + +func (fake *FakeParticipant) GetParticipantListenerReturnsOnCall(i int, result1 types.ParticipantListener) { + fake.getParticipantListenerMutex.Lock() + defer fake.getParticipantListenerMutex.Unlock() + fake.GetParticipantListenerStub = nil + if fake.getParticipantListenerReturnsOnCall == nil { + fake.getParticipantListenerReturnsOnCall = make(map[int]struct { + result1 types.ParticipantListener + }) + } + fake.getParticipantListenerReturnsOnCall[i] = struct { + result1 types.ParticipantListener + }{result1} +} + +func (fake *FakeParticipant) GetPublishedDataTrack(arg1 uint16) types.DataTrack { + fake.getPublishedDataTrackMutex.Lock() + ret, specificReturn := fake.getPublishedDataTrackReturnsOnCall[len(fake.getPublishedDataTrackArgsForCall)] + fake.getPublishedDataTrackArgsForCall = append(fake.getPublishedDataTrackArgsForCall, struct { + arg1 uint16 + }{arg1}) + stub := fake.GetPublishedDataTrackStub + fakeReturns := fake.getPublishedDataTrackReturns + fake.recordInvocation("GetPublishedDataTrack", []interface{}{arg1}) + fake.getPublishedDataTrackMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeParticipant) GetPublishedDataTrackCallCount() int { + fake.getPublishedDataTrackMutex.RLock() + defer fake.getPublishedDataTrackMutex.RUnlock() + return len(fake.getPublishedDataTrackArgsForCall) +} + +func (fake *FakeParticipant) GetPublishedDataTrackCalls(stub func(uint16) types.DataTrack) { + fake.getPublishedDataTrackMutex.Lock() + defer fake.getPublishedDataTrackMutex.Unlock() + fake.GetPublishedDataTrackStub = stub +} + +func (fake *FakeParticipant) GetPublishedDataTrackArgsForCall(i int) uint16 { + fake.getPublishedDataTrackMutex.RLock() + defer fake.getPublishedDataTrackMutex.RUnlock() + argsForCall := fake.getPublishedDataTrackArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeParticipant) GetPublishedDataTrackReturns(result1 types.DataTrack) { + fake.getPublishedDataTrackMutex.Lock() + defer fake.getPublishedDataTrackMutex.Unlock() + fake.GetPublishedDataTrackStub = nil + fake.getPublishedDataTrackReturns = struct { + result1 types.DataTrack + }{result1} +} + +func (fake *FakeParticipant) GetPublishedDataTrackReturnsOnCall(i int, result1 types.DataTrack) { + fake.getPublishedDataTrackMutex.Lock() + defer fake.getPublishedDataTrackMutex.Unlock() + fake.GetPublishedDataTrackStub = nil + if fake.getPublishedDataTrackReturnsOnCall == nil { + fake.getPublishedDataTrackReturnsOnCall = make(map[int]struct { + result1 types.DataTrack + }) + } + fake.getPublishedDataTrackReturnsOnCall[i] = struct { + result1 types.DataTrack + }{result1} +} + +func (fake *FakeParticipant) GetPublishedDataTracks() []types.DataTrack { + fake.getPublishedDataTracksMutex.Lock() + ret, specificReturn := fake.getPublishedDataTracksReturnsOnCall[len(fake.getPublishedDataTracksArgsForCall)] + fake.getPublishedDataTracksArgsForCall = append(fake.getPublishedDataTracksArgsForCall, struct { + }{}) + stub := fake.GetPublishedDataTracksStub + fakeReturns := fake.getPublishedDataTracksReturns + fake.recordInvocation("GetPublishedDataTracks", []interface{}{}) + fake.getPublishedDataTracksMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeParticipant) GetPublishedDataTracksCallCount() int { + fake.getPublishedDataTracksMutex.RLock() + defer fake.getPublishedDataTracksMutex.RUnlock() + return len(fake.getPublishedDataTracksArgsForCall) +} + +func (fake *FakeParticipant) GetPublishedDataTracksCalls(stub func() []types.DataTrack) { + fake.getPublishedDataTracksMutex.Lock() + defer fake.getPublishedDataTracksMutex.Unlock() + fake.GetPublishedDataTracksStub = stub +} + +func (fake *FakeParticipant) GetPublishedDataTracksReturns(result1 []types.DataTrack) { + fake.getPublishedDataTracksMutex.Lock() + defer fake.getPublishedDataTracksMutex.Unlock() + fake.GetPublishedDataTracksStub = nil + fake.getPublishedDataTracksReturns = struct { + result1 []types.DataTrack + }{result1} +} + +func (fake *FakeParticipant) GetPublishedDataTracksReturnsOnCall(i int, result1 []types.DataTrack) { + fake.getPublishedDataTracksMutex.Lock() + defer fake.getPublishedDataTracksMutex.Unlock() + fake.GetPublishedDataTracksStub = nil + if fake.getPublishedDataTracksReturnsOnCall == nil { + fake.getPublishedDataTracksReturnsOnCall = make(map[int]struct { + result1 []types.DataTrack + }) + } + fake.getPublishedDataTracksReturnsOnCall[i] = struct { + result1 []types.DataTrack + }{result1} +} + +func (fake *FakeParticipant) GetPublishedTrack(arg1 livekit.TrackID) types.MediaTrack { + fake.getPublishedTrackMutex.Lock() + ret, specificReturn := fake.getPublishedTrackReturnsOnCall[len(fake.getPublishedTrackArgsForCall)] + fake.getPublishedTrackArgsForCall = append(fake.getPublishedTrackArgsForCall, struct { + arg1 livekit.TrackID + }{arg1}) + stub := fake.GetPublishedTrackStub + fakeReturns := fake.getPublishedTrackReturns + fake.recordInvocation("GetPublishedTrack", []interface{}{arg1}) + fake.getPublishedTrackMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeParticipant) GetPublishedTrackCallCount() int { + fake.getPublishedTrackMutex.RLock() + defer fake.getPublishedTrackMutex.RUnlock() + return len(fake.getPublishedTrackArgsForCall) +} + +func (fake *FakeParticipant) GetPublishedTrackCalls(stub func(livekit.TrackID) types.MediaTrack) { + fake.getPublishedTrackMutex.Lock() + defer fake.getPublishedTrackMutex.Unlock() + fake.GetPublishedTrackStub = stub +} + +func (fake *FakeParticipant) GetPublishedTrackArgsForCall(i int) livekit.TrackID { + fake.getPublishedTrackMutex.RLock() + defer fake.getPublishedTrackMutex.RUnlock() + argsForCall := fake.getPublishedTrackArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeParticipant) GetPublishedTrackReturns(result1 types.MediaTrack) { + fake.getPublishedTrackMutex.Lock() + defer fake.getPublishedTrackMutex.Unlock() + fake.GetPublishedTrackStub = nil + fake.getPublishedTrackReturns = struct { + result1 types.MediaTrack + }{result1} +} + +func (fake *FakeParticipant) GetPublishedTrackReturnsOnCall(i int, result1 types.MediaTrack) { + fake.getPublishedTrackMutex.Lock() + defer fake.getPublishedTrackMutex.Unlock() + fake.GetPublishedTrackStub = nil + if fake.getPublishedTrackReturnsOnCall == nil { + fake.getPublishedTrackReturnsOnCall = make(map[int]struct { + result1 types.MediaTrack + }) + } + fake.getPublishedTrackReturnsOnCall[i] = struct { + result1 types.MediaTrack + }{result1} +} + +func (fake *FakeParticipant) GetPublishedTracks() []types.MediaTrack { + fake.getPublishedTracksMutex.Lock() + ret, specificReturn := fake.getPublishedTracksReturnsOnCall[len(fake.getPublishedTracksArgsForCall)] + fake.getPublishedTracksArgsForCall = append(fake.getPublishedTracksArgsForCall, struct { + }{}) + stub := fake.GetPublishedTracksStub + fakeReturns := fake.getPublishedTracksReturns + fake.recordInvocation("GetPublishedTracks", []interface{}{}) + fake.getPublishedTracksMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeParticipant) GetPublishedTracksCallCount() int { + fake.getPublishedTracksMutex.RLock() + defer fake.getPublishedTracksMutex.RUnlock() + return len(fake.getPublishedTracksArgsForCall) +} + +func (fake *FakeParticipant) GetPublishedTracksCalls(stub func() []types.MediaTrack) { + fake.getPublishedTracksMutex.Lock() + defer fake.getPublishedTracksMutex.Unlock() + fake.GetPublishedTracksStub = stub +} + +func (fake *FakeParticipant) GetPublishedTracksReturns(result1 []types.MediaTrack) { + fake.getPublishedTracksMutex.Lock() + defer fake.getPublishedTracksMutex.Unlock() + fake.GetPublishedTracksStub = nil + fake.getPublishedTracksReturns = struct { + result1 []types.MediaTrack + }{result1} +} + +func (fake *FakeParticipant) GetPublishedTracksReturnsOnCall(i int, result1 []types.MediaTrack) { + fake.getPublishedTracksMutex.Lock() + defer fake.getPublishedTracksMutex.Unlock() + fake.GetPublishedTracksStub = nil + if fake.getPublishedTracksReturnsOnCall == nil { + fake.getPublishedTracksReturnsOnCall = make(map[int]struct { + result1 []types.MediaTrack + }) + } + fake.getPublishedTracksReturnsOnCall[i] = struct { + result1 []types.MediaTrack + }{result1} +} + +func (fake *FakeParticipant) HandleReceivedDataTrackMessage(arg1 []byte, arg2 *datatrack.Packet, arg3 int64) { + var arg1Copy []byte + if arg1 != nil { + arg1Copy = make([]byte, len(arg1)) + copy(arg1Copy, arg1) + } + fake.handleReceivedDataTrackMessageMutex.Lock() + fake.handleReceivedDataTrackMessageArgsForCall = append(fake.handleReceivedDataTrackMessageArgsForCall, struct { + arg1 []byte + arg2 *datatrack.Packet + arg3 int64 + }{arg1Copy, arg2, arg3}) + stub := fake.HandleReceivedDataTrackMessageStub + fake.recordInvocation("HandleReceivedDataTrackMessage", []interface{}{arg1Copy, arg2, arg3}) + fake.handleReceivedDataTrackMessageMutex.Unlock() + if stub != nil { + fake.HandleReceivedDataTrackMessageStub(arg1, arg2, arg3) + } +} + +func (fake *FakeParticipant) HandleReceivedDataTrackMessageCallCount() int { + fake.handleReceivedDataTrackMessageMutex.RLock() + defer fake.handleReceivedDataTrackMessageMutex.RUnlock() + return len(fake.handleReceivedDataTrackMessageArgsForCall) +} + +func (fake *FakeParticipant) HandleReceivedDataTrackMessageCalls(stub func([]byte, *datatrack.Packet, int64)) { + fake.handleReceivedDataTrackMessageMutex.Lock() + defer fake.handleReceivedDataTrackMessageMutex.Unlock() + fake.HandleReceivedDataTrackMessageStub = stub +} + +func (fake *FakeParticipant) HandleReceivedDataTrackMessageArgsForCall(i int) ([]byte, *datatrack.Packet, int64) { + fake.handleReceivedDataTrackMessageMutex.RLock() + defer fake.handleReceivedDataTrackMessageMutex.RUnlock() + argsForCall := fake.handleReceivedDataTrackMessageArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeParticipant) HasPermission(arg1 livekit.TrackID, arg2 livekit.ParticipantIdentity) bool { + fake.hasPermissionMutex.Lock() + ret, specificReturn := fake.hasPermissionReturnsOnCall[len(fake.hasPermissionArgsForCall)] + fake.hasPermissionArgsForCall = append(fake.hasPermissionArgsForCall, struct { + arg1 livekit.TrackID + arg2 livekit.ParticipantIdentity + }{arg1, arg2}) + stub := fake.HasPermissionStub + fakeReturns := fake.hasPermissionReturns + fake.recordInvocation("HasPermission", []interface{}{arg1, arg2}) + fake.hasPermissionMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeParticipant) HasPermissionCallCount() int { + fake.hasPermissionMutex.RLock() + defer fake.hasPermissionMutex.RUnlock() + return len(fake.hasPermissionArgsForCall) +} + +func (fake *FakeParticipant) HasPermissionCalls(stub func(livekit.TrackID, livekit.ParticipantIdentity) bool) { + fake.hasPermissionMutex.Lock() + defer fake.hasPermissionMutex.Unlock() + fake.HasPermissionStub = stub +} + +func (fake *FakeParticipant) HasPermissionArgsForCall(i int) (livekit.TrackID, livekit.ParticipantIdentity) { + fake.hasPermissionMutex.RLock() + defer fake.hasPermissionMutex.RUnlock() + argsForCall := fake.hasPermissionArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeParticipant) HasPermissionReturns(result1 bool) { + fake.hasPermissionMutex.Lock() + defer fake.hasPermissionMutex.Unlock() + fake.HasPermissionStub = nil + fake.hasPermissionReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeParticipant) HasPermissionReturnsOnCall(i int, result1 bool) { + fake.hasPermissionMutex.Lock() + defer fake.hasPermissionMutex.Unlock() + fake.HasPermissionStub = nil + if fake.hasPermissionReturnsOnCall == nil { + fake.hasPermissionReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.hasPermissionReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + +func (fake *FakeParticipant) Hidden() bool { + fake.hiddenMutex.Lock() + ret, specificReturn := fake.hiddenReturnsOnCall[len(fake.hiddenArgsForCall)] + fake.hiddenArgsForCall = append(fake.hiddenArgsForCall, struct { + }{}) + stub := fake.HiddenStub + fakeReturns := fake.hiddenReturns + fake.recordInvocation("Hidden", []interface{}{}) + fake.hiddenMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeParticipant) HiddenCallCount() int { + fake.hiddenMutex.RLock() + defer fake.hiddenMutex.RUnlock() + return len(fake.hiddenArgsForCall) +} + +func (fake *FakeParticipant) HiddenCalls(stub func() bool) { + fake.hiddenMutex.Lock() + defer fake.hiddenMutex.Unlock() + fake.HiddenStub = stub +} + +func (fake *FakeParticipant) HiddenReturns(result1 bool) { + fake.hiddenMutex.Lock() + defer fake.hiddenMutex.Unlock() + fake.HiddenStub = nil + fake.hiddenReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeParticipant) HiddenReturnsOnCall(i int, result1 bool) { + fake.hiddenMutex.Lock() + defer fake.hiddenMutex.Unlock() + fake.HiddenStub = nil + if fake.hiddenReturnsOnCall == nil { + fake.hiddenReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.hiddenReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + +func (fake *FakeParticipant) ID() livekit.ParticipantID { + fake.iDMutex.Lock() + ret, specificReturn := fake.iDReturnsOnCall[len(fake.iDArgsForCall)] + fake.iDArgsForCall = append(fake.iDArgsForCall, struct { + }{}) + stub := fake.IDStub + fakeReturns := fake.iDReturns + fake.recordInvocation("ID", []interface{}{}) + fake.iDMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeParticipant) IDCallCount() int { + fake.iDMutex.RLock() + defer fake.iDMutex.RUnlock() + return len(fake.iDArgsForCall) +} + +func (fake *FakeParticipant) IDCalls(stub func() livekit.ParticipantID) { + fake.iDMutex.Lock() + defer fake.iDMutex.Unlock() + fake.IDStub = stub +} + +func (fake *FakeParticipant) IDReturns(result1 livekit.ParticipantID) { + fake.iDMutex.Lock() + defer fake.iDMutex.Unlock() + fake.IDStub = nil + fake.iDReturns = struct { + result1 livekit.ParticipantID + }{result1} +} + +func (fake *FakeParticipant) IDReturnsOnCall(i int, result1 livekit.ParticipantID) { + fake.iDMutex.Lock() + defer fake.iDMutex.Unlock() + fake.IDStub = nil + if fake.iDReturnsOnCall == nil { + fake.iDReturnsOnCall = make(map[int]struct { + result1 livekit.ParticipantID + }) + } + fake.iDReturnsOnCall[i] = struct { + result1 livekit.ParticipantID + }{result1} +} + +func (fake *FakeParticipant) Identity() livekit.ParticipantIdentity { + fake.identityMutex.Lock() + ret, specificReturn := fake.identityReturnsOnCall[len(fake.identityArgsForCall)] + fake.identityArgsForCall = append(fake.identityArgsForCall, struct { + }{}) + stub := fake.IdentityStub + fakeReturns := fake.identityReturns + fake.recordInvocation("Identity", []interface{}{}) + fake.identityMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeParticipant) IdentityCallCount() int { + fake.identityMutex.RLock() + defer fake.identityMutex.RUnlock() + return len(fake.identityArgsForCall) +} + +func (fake *FakeParticipant) IdentityCalls(stub func() livekit.ParticipantIdentity) { + fake.identityMutex.Lock() + defer fake.identityMutex.Unlock() + fake.IdentityStub = stub +} + +func (fake *FakeParticipant) IdentityReturns(result1 livekit.ParticipantIdentity) { + fake.identityMutex.Lock() + defer fake.identityMutex.Unlock() + fake.IdentityStub = nil + fake.identityReturns = struct { + result1 livekit.ParticipantIdentity + }{result1} +} + +func (fake *FakeParticipant) IdentityReturnsOnCall(i int, result1 livekit.ParticipantIdentity) { + fake.identityMutex.Lock() + defer fake.identityMutex.Unlock() + fake.IdentityStub = nil + if fake.identityReturnsOnCall == nil { + fake.identityReturnsOnCall = make(map[int]struct { + result1 livekit.ParticipantIdentity + }) + } + fake.identityReturnsOnCall[i] = struct { + result1 livekit.ParticipantIdentity + }{result1} +} + +func (fake *FakeParticipant) IsAgent() bool { + fake.isAgentMutex.Lock() + ret, specificReturn := fake.isAgentReturnsOnCall[len(fake.isAgentArgsForCall)] + fake.isAgentArgsForCall = append(fake.isAgentArgsForCall, struct { + }{}) + stub := fake.IsAgentStub + fakeReturns := fake.isAgentReturns + fake.recordInvocation("IsAgent", []interface{}{}) + fake.isAgentMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeParticipant) IsAgentCallCount() int { + fake.isAgentMutex.RLock() + defer fake.isAgentMutex.RUnlock() + return len(fake.isAgentArgsForCall) +} + +func (fake *FakeParticipant) IsAgentCalls(stub func() bool) { + fake.isAgentMutex.Lock() + defer fake.isAgentMutex.Unlock() + fake.IsAgentStub = stub +} + +func (fake *FakeParticipant) IsAgentReturns(result1 bool) { + fake.isAgentMutex.Lock() + defer fake.isAgentMutex.Unlock() + fake.IsAgentStub = nil + fake.isAgentReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeParticipant) IsAgentReturnsOnCall(i int, result1 bool) { + fake.isAgentMutex.Lock() + defer fake.isAgentMutex.Unlock() + fake.IsAgentStub = nil + if fake.isAgentReturnsOnCall == nil { + fake.isAgentReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.isAgentReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + +func (fake *FakeParticipant) IsClosed() bool { + fake.isClosedMutex.Lock() + ret, specificReturn := fake.isClosedReturnsOnCall[len(fake.isClosedArgsForCall)] + fake.isClosedArgsForCall = append(fake.isClosedArgsForCall, struct { + }{}) + stub := fake.IsClosedStub + fakeReturns := fake.isClosedReturns + fake.recordInvocation("IsClosed", []interface{}{}) + fake.isClosedMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeParticipant) IsClosedCallCount() int { + fake.isClosedMutex.RLock() + defer fake.isClosedMutex.RUnlock() + return len(fake.isClosedArgsForCall) +} + +func (fake *FakeParticipant) IsClosedCalls(stub func() bool) { + fake.isClosedMutex.Lock() + defer fake.isClosedMutex.Unlock() + fake.IsClosedStub = stub +} + +func (fake *FakeParticipant) IsClosedReturns(result1 bool) { + fake.isClosedMutex.Lock() + defer fake.isClosedMutex.Unlock() + fake.IsClosedStub = nil + fake.isClosedReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeParticipant) IsClosedReturnsOnCall(i int, result1 bool) { + fake.isClosedMutex.Lock() + defer fake.isClosedMutex.Unlock() + fake.IsClosedStub = nil + if fake.isClosedReturnsOnCall == nil { + fake.isClosedReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.isClosedReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + +func (fake *FakeParticipant) IsDependent() bool { + fake.isDependentMutex.Lock() + ret, specificReturn := fake.isDependentReturnsOnCall[len(fake.isDependentArgsForCall)] + fake.isDependentArgsForCall = append(fake.isDependentArgsForCall, struct { + }{}) + stub := fake.IsDependentStub + fakeReturns := fake.isDependentReturns + fake.recordInvocation("IsDependent", []interface{}{}) + fake.isDependentMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeParticipant) IsDependentCallCount() int { + fake.isDependentMutex.RLock() + defer fake.isDependentMutex.RUnlock() + return len(fake.isDependentArgsForCall) +} + +func (fake *FakeParticipant) IsDependentCalls(stub func() bool) { + fake.isDependentMutex.Lock() + defer fake.isDependentMutex.Unlock() + fake.IsDependentStub = stub +} + +func (fake *FakeParticipant) IsDependentReturns(result1 bool) { + fake.isDependentMutex.Lock() + defer fake.isDependentMutex.Unlock() + fake.IsDependentStub = nil + fake.isDependentReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeParticipant) IsDependentReturnsOnCall(i int, result1 bool) { + fake.isDependentMutex.Lock() + defer fake.isDependentMutex.Unlock() + fake.IsDependentStub = nil + if fake.isDependentReturnsOnCall == nil { + fake.isDependentReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.isDependentReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + +func (fake *FakeParticipant) IsDisconnected() bool { + fake.isDisconnectedMutex.Lock() + ret, specificReturn := fake.isDisconnectedReturnsOnCall[len(fake.isDisconnectedArgsForCall)] + fake.isDisconnectedArgsForCall = append(fake.isDisconnectedArgsForCall, struct { + }{}) + stub := fake.IsDisconnectedStub + fakeReturns := fake.isDisconnectedReturns + fake.recordInvocation("IsDisconnected", []interface{}{}) + fake.isDisconnectedMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeParticipant) IsDisconnectedCallCount() int { + fake.isDisconnectedMutex.RLock() + defer fake.isDisconnectedMutex.RUnlock() + return len(fake.isDisconnectedArgsForCall) +} + +func (fake *FakeParticipant) IsDisconnectedCalls(stub func() bool) { + fake.isDisconnectedMutex.Lock() + defer fake.isDisconnectedMutex.Unlock() + fake.IsDisconnectedStub = stub +} + +func (fake *FakeParticipant) IsDisconnectedReturns(result1 bool) { + fake.isDisconnectedMutex.Lock() + defer fake.isDisconnectedMutex.Unlock() + fake.IsDisconnectedStub = nil + fake.isDisconnectedReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeParticipant) IsDisconnectedReturnsOnCall(i int, result1 bool) { + fake.isDisconnectedMutex.Lock() + defer fake.isDisconnectedMutex.Unlock() + fake.IsDisconnectedStub = nil + if fake.isDisconnectedReturnsOnCall == nil { + fake.isDisconnectedReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.isDisconnectedReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + +func (fake *FakeParticipant) IsPublisher() bool { + fake.isPublisherMutex.Lock() + ret, specificReturn := fake.isPublisherReturnsOnCall[len(fake.isPublisherArgsForCall)] + fake.isPublisherArgsForCall = append(fake.isPublisherArgsForCall, struct { + }{}) + stub := fake.IsPublisherStub + fakeReturns := fake.isPublisherReturns + fake.recordInvocation("IsPublisher", []interface{}{}) + fake.isPublisherMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeParticipant) IsPublisherCallCount() int { + fake.isPublisherMutex.RLock() + defer fake.isPublisherMutex.RUnlock() + return len(fake.isPublisherArgsForCall) +} + +func (fake *FakeParticipant) IsPublisherCalls(stub func() bool) { + fake.isPublisherMutex.Lock() + defer fake.isPublisherMutex.Unlock() + fake.IsPublisherStub = stub +} + +func (fake *FakeParticipant) IsPublisherReturns(result1 bool) { + fake.isPublisherMutex.Lock() + defer fake.isPublisherMutex.Unlock() + fake.IsPublisherStub = nil + fake.isPublisherReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeParticipant) IsPublisherReturnsOnCall(i int, result1 bool) { + fake.isPublisherMutex.Lock() + defer fake.isPublisherMutex.Unlock() + fake.IsPublisherStub = nil + if fake.isPublisherReturnsOnCall == nil { + fake.isPublisherReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.isPublisherReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + +func (fake *FakeParticipant) IsRecorder() bool { + fake.isRecorderMutex.Lock() + ret, specificReturn := fake.isRecorderReturnsOnCall[len(fake.isRecorderArgsForCall)] + fake.isRecorderArgsForCall = append(fake.isRecorderArgsForCall, struct { + }{}) + stub := fake.IsRecorderStub + fakeReturns := fake.isRecorderReturns + fake.recordInvocation("IsRecorder", []interface{}{}) + fake.isRecorderMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeParticipant) IsRecorderCallCount() int { + fake.isRecorderMutex.RLock() + defer fake.isRecorderMutex.RUnlock() + return len(fake.isRecorderArgsForCall) +} + +func (fake *FakeParticipant) IsRecorderCalls(stub func() bool) { + fake.isRecorderMutex.Lock() + defer fake.isRecorderMutex.Unlock() + fake.IsRecorderStub = stub +} + +func (fake *FakeParticipant) IsRecorderReturns(result1 bool) { + fake.isRecorderMutex.Lock() + defer fake.isRecorderMutex.Unlock() + fake.IsRecorderStub = nil + fake.isRecorderReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeParticipant) IsRecorderReturnsOnCall(i int, result1 bool) { + fake.isRecorderMutex.Lock() + defer fake.isRecorderMutex.Unlock() + fake.IsRecorderStub = nil + if fake.isRecorderReturnsOnCall == nil { + fake.isRecorderReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.isRecorderReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + +func (fake *FakeParticipant) Kind() livekit.ParticipantInfo_Kind { + fake.kindMutex.Lock() + ret, specificReturn := fake.kindReturnsOnCall[len(fake.kindArgsForCall)] + fake.kindArgsForCall = append(fake.kindArgsForCall, struct { + }{}) + stub := fake.KindStub + fakeReturns := fake.kindReturns + fake.recordInvocation("Kind", []interface{}{}) + fake.kindMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeParticipant) KindCallCount() int { + fake.kindMutex.RLock() + defer fake.kindMutex.RUnlock() + return len(fake.kindArgsForCall) +} + +func (fake *FakeParticipant) KindCalls(stub func() livekit.ParticipantInfo_Kind) { + fake.kindMutex.Lock() + defer fake.kindMutex.Unlock() + fake.KindStub = stub +} + +func (fake *FakeParticipant) KindReturns(result1 livekit.ParticipantInfo_Kind) { + fake.kindMutex.Lock() + defer fake.kindMutex.Unlock() + fake.KindStub = nil + fake.kindReturns = struct { + result1 livekit.ParticipantInfo_Kind + }{result1} +} + +func (fake *FakeParticipant) KindReturnsOnCall(i int, result1 livekit.ParticipantInfo_Kind) { + fake.kindMutex.Lock() + defer fake.kindMutex.Unlock() + fake.KindStub = nil + if fake.kindReturnsOnCall == nil { + fake.kindReturnsOnCall = make(map[int]struct { + result1 livekit.ParticipantInfo_Kind + }) + } + fake.kindReturnsOnCall[i] = struct { + result1 livekit.ParticipantInfo_Kind + }{result1} +} + +func (fake *FakeParticipant) MigrateState() types.MigrateState { + fake.migrateStateMutex.Lock() + ret, specificReturn := fake.migrateStateReturnsOnCall[len(fake.migrateStateArgsForCall)] + fake.migrateStateArgsForCall = append(fake.migrateStateArgsForCall, struct { + }{}) + stub := fake.MigrateStateStub + fakeReturns := fake.migrateStateReturns + fake.recordInvocation("MigrateState", []interface{}{}) + fake.migrateStateMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeParticipant) MigrateStateCallCount() int { + fake.migrateStateMutex.RLock() + defer fake.migrateStateMutex.RUnlock() + return len(fake.migrateStateArgsForCall) +} + +func (fake *FakeParticipant) MigrateStateCalls(stub func() types.MigrateState) { + fake.migrateStateMutex.Lock() + defer fake.migrateStateMutex.Unlock() + fake.MigrateStateStub = stub +} + +func (fake *FakeParticipant) MigrateStateReturns(result1 types.MigrateState) { + fake.migrateStateMutex.Lock() + defer fake.migrateStateMutex.Unlock() + fake.MigrateStateStub = nil + fake.migrateStateReturns = struct { + result1 types.MigrateState + }{result1} +} + +func (fake *FakeParticipant) MigrateStateReturnsOnCall(i int, result1 types.MigrateState) { + fake.migrateStateMutex.Lock() + defer fake.migrateStateMutex.Unlock() + fake.MigrateStateStub = nil + if fake.migrateStateReturnsOnCall == nil { + fake.migrateStateReturnsOnCall = make(map[int]struct { + result1 types.MigrateState + }) + } + fake.migrateStateReturnsOnCall[i] = struct { + result1 types.MigrateState + }{result1} +} + +func (fake *FakeParticipant) RemovePublishedDataTrack(arg1 types.DataTrack) { + fake.removePublishedDataTrackMutex.Lock() + fake.removePublishedDataTrackArgsForCall = append(fake.removePublishedDataTrackArgsForCall, struct { + arg1 types.DataTrack + }{arg1}) + stub := fake.RemovePublishedDataTrackStub + fake.recordInvocation("RemovePublishedDataTrack", []interface{}{arg1}) + fake.removePublishedDataTrackMutex.Unlock() + if stub != nil { + fake.RemovePublishedDataTrackStub(arg1) + } +} + +func (fake *FakeParticipant) RemovePublishedDataTrackCallCount() int { + fake.removePublishedDataTrackMutex.RLock() + defer fake.removePublishedDataTrackMutex.RUnlock() + return len(fake.removePublishedDataTrackArgsForCall) +} + +func (fake *FakeParticipant) RemovePublishedDataTrackCalls(stub func(types.DataTrack)) { + fake.removePublishedDataTrackMutex.Lock() + defer fake.removePublishedDataTrackMutex.Unlock() + fake.RemovePublishedDataTrackStub = stub +} + +func (fake *FakeParticipant) RemovePublishedDataTrackArgsForCall(i int) types.DataTrack { + fake.removePublishedDataTrackMutex.RLock() + defer fake.removePublishedDataTrackMutex.RUnlock() + argsForCall := fake.removePublishedDataTrackArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeParticipant) RemovePublishedTrack(arg1 types.MediaTrack, arg2 bool) { + fake.removePublishedTrackMutex.Lock() + fake.removePublishedTrackArgsForCall = append(fake.removePublishedTrackArgsForCall, struct { + arg1 types.MediaTrack + arg2 bool + }{arg1, arg2}) + stub := fake.RemovePublishedTrackStub + fake.recordInvocation("RemovePublishedTrack", []interface{}{arg1, arg2}) + fake.removePublishedTrackMutex.Unlock() + if stub != nil { + fake.RemovePublishedTrackStub(arg1, arg2) + } +} + +func (fake *FakeParticipant) RemovePublishedTrackCallCount() int { + fake.removePublishedTrackMutex.RLock() + defer fake.removePublishedTrackMutex.RUnlock() + return len(fake.removePublishedTrackArgsForCall) +} + +func (fake *FakeParticipant) RemovePublishedTrackCalls(stub func(types.MediaTrack, bool)) { + fake.removePublishedTrackMutex.Lock() + defer fake.removePublishedTrackMutex.Unlock() + fake.RemovePublishedTrackStub = stub +} + +func (fake *FakeParticipant) RemovePublishedTrackArgsForCall(i int) (types.MediaTrack, bool) { + fake.removePublishedTrackMutex.RLock() + defer fake.removePublishedTrackMutex.RUnlock() + argsForCall := fake.removePublishedTrackArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeParticipant) State() livekit.ParticipantInfo_State { + fake.stateMutex.Lock() + ret, specificReturn := fake.stateReturnsOnCall[len(fake.stateArgsForCall)] + fake.stateArgsForCall = append(fake.stateArgsForCall, struct { + }{}) + stub := fake.StateStub + fakeReturns := fake.stateReturns + fake.recordInvocation("State", []interface{}{}) + fake.stateMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeParticipant) StateCallCount() int { + fake.stateMutex.RLock() + defer fake.stateMutex.RUnlock() + return len(fake.stateArgsForCall) +} + +func (fake *FakeParticipant) StateCalls(stub func() livekit.ParticipantInfo_State) { + fake.stateMutex.Lock() + defer fake.stateMutex.Unlock() + fake.StateStub = stub +} + +func (fake *FakeParticipant) StateReturns(result1 livekit.ParticipantInfo_State) { + fake.stateMutex.Lock() + defer fake.stateMutex.Unlock() + fake.StateStub = nil + fake.stateReturns = struct { + result1 livekit.ParticipantInfo_State + }{result1} +} + +func (fake *FakeParticipant) StateReturnsOnCall(i int, result1 livekit.ParticipantInfo_State) { + fake.stateMutex.Lock() + defer fake.stateMutex.Unlock() + fake.StateStub = nil + if fake.stateReturnsOnCall == nil { + fake.stateReturnsOnCall = make(map[int]struct { + result1 livekit.ParticipantInfo_State + }) + } + fake.stateReturnsOnCall[i] = struct { + result1 livekit.ParticipantInfo_State + }{result1} +} + +func (fake *FakeParticipant) SubscriptionPermission() (*livekit.SubscriptionPermission, utils.TimedVersion) { + fake.subscriptionPermissionMutex.Lock() + ret, specificReturn := fake.subscriptionPermissionReturnsOnCall[len(fake.subscriptionPermissionArgsForCall)] + fake.subscriptionPermissionArgsForCall = append(fake.subscriptionPermissionArgsForCall, struct { + }{}) + stub := fake.SubscriptionPermissionStub + fakeReturns := fake.subscriptionPermissionReturns + fake.recordInvocation("SubscriptionPermission", []interface{}{}) + fake.subscriptionPermissionMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeParticipant) SubscriptionPermissionCallCount() int { + fake.subscriptionPermissionMutex.RLock() + defer fake.subscriptionPermissionMutex.RUnlock() + return len(fake.subscriptionPermissionArgsForCall) +} + +func (fake *FakeParticipant) SubscriptionPermissionCalls(stub func() (*livekit.SubscriptionPermission, utils.TimedVersion)) { + fake.subscriptionPermissionMutex.Lock() + defer fake.subscriptionPermissionMutex.Unlock() + fake.SubscriptionPermissionStub = stub +} + +func (fake *FakeParticipant) SubscriptionPermissionReturns(result1 *livekit.SubscriptionPermission, result2 utils.TimedVersion) { + fake.subscriptionPermissionMutex.Lock() + defer fake.subscriptionPermissionMutex.Unlock() + fake.SubscriptionPermissionStub = nil + fake.subscriptionPermissionReturns = struct { + result1 *livekit.SubscriptionPermission + result2 utils.TimedVersion + }{result1, result2} +} + +func (fake *FakeParticipant) SubscriptionPermissionReturnsOnCall(i int, result1 *livekit.SubscriptionPermission, result2 utils.TimedVersion) { + fake.subscriptionPermissionMutex.Lock() + defer fake.subscriptionPermissionMutex.Unlock() + fake.SubscriptionPermissionStub = nil + if fake.subscriptionPermissionReturnsOnCall == nil { + fake.subscriptionPermissionReturnsOnCall = make(map[int]struct { + result1 *livekit.SubscriptionPermission + result2 utils.TimedVersion + }) + } + fake.subscriptionPermissionReturnsOnCall[i] = struct { + result1 *livekit.SubscriptionPermission + result2 utils.TimedVersion + }{result1, result2} +} + +func (fake *FakeParticipant) ToProto() *livekit.ParticipantInfo { + fake.toProtoMutex.Lock() + ret, specificReturn := fake.toProtoReturnsOnCall[len(fake.toProtoArgsForCall)] + fake.toProtoArgsForCall = append(fake.toProtoArgsForCall, struct { + }{}) + stub := fake.ToProtoStub + fakeReturns := fake.toProtoReturns + fake.recordInvocation("ToProto", []interface{}{}) + fake.toProtoMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeParticipant) ToProtoCallCount() int { + fake.toProtoMutex.RLock() + defer fake.toProtoMutex.RUnlock() + return len(fake.toProtoArgsForCall) +} + +func (fake *FakeParticipant) ToProtoCalls(stub func() *livekit.ParticipantInfo) { + fake.toProtoMutex.Lock() + defer fake.toProtoMutex.Unlock() + fake.ToProtoStub = stub +} + +func (fake *FakeParticipant) ToProtoReturns(result1 *livekit.ParticipantInfo) { + fake.toProtoMutex.Lock() + defer fake.toProtoMutex.Unlock() + fake.ToProtoStub = nil + fake.toProtoReturns = struct { + result1 *livekit.ParticipantInfo + }{result1} +} + +func (fake *FakeParticipant) ToProtoReturnsOnCall(i int, result1 *livekit.ParticipantInfo) { + fake.toProtoMutex.Lock() + defer fake.toProtoMutex.Unlock() + fake.ToProtoStub = nil + if fake.toProtoReturnsOnCall == nil { + fake.toProtoReturnsOnCall = make(map[int]struct { + result1 *livekit.ParticipantInfo + }) + } + fake.toProtoReturnsOnCall[i] = struct { + result1 *livekit.ParticipantInfo + }{result1} +} + +func (fake *FakeParticipant) ToProtoWithVersion() (*livekit.ParticipantInfo, utils.TimedVersion) { + fake.toProtoWithVersionMutex.Lock() + ret, specificReturn := fake.toProtoWithVersionReturnsOnCall[len(fake.toProtoWithVersionArgsForCall)] + fake.toProtoWithVersionArgsForCall = append(fake.toProtoWithVersionArgsForCall, struct { + }{}) + stub := fake.ToProtoWithVersionStub + fakeReturns := fake.toProtoWithVersionReturns + fake.recordInvocation("ToProtoWithVersion", []interface{}{}) + fake.toProtoWithVersionMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeParticipant) ToProtoWithVersionCallCount() int { + fake.toProtoWithVersionMutex.RLock() + defer fake.toProtoWithVersionMutex.RUnlock() + return len(fake.toProtoWithVersionArgsForCall) +} + +func (fake *FakeParticipant) ToProtoWithVersionCalls(stub func() (*livekit.ParticipantInfo, utils.TimedVersion)) { + fake.toProtoWithVersionMutex.Lock() + defer fake.toProtoWithVersionMutex.Unlock() + fake.ToProtoWithVersionStub = stub +} + +func (fake *FakeParticipant) ToProtoWithVersionReturns(result1 *livekit.ParticipantInfo, result2 utils.TimedVersion) { + fake.toProtoWithVersionMutex.Lock() + defer fake.toProtoWithVersionMutex.Unlock() + fake.ToProtoWithVersionStub = nil + fake.toProtoWithVersionReturns = struct { + result1 *livekit.ParticipantInfo + result2 utils.TimedVersion + }{result1, result2} +} + +func (fake *FakeParticipant) ToProtoWithVersionReturnsOnCall(i int, result1 *livekit.ParticipantInfo, result2 utils.TimedVersion) { + fake.toProtoWithVersionMutex.Lock() + defer fake.toProtoWithVersionMutex.Unlock() + fake.ToProtoWithVersionStub = nil + if fake.toProtoWithVersionReturnsOnCall == nil { + fake.toProtoWithVersionReturnsOnCall = make(map[int]struct { + result1 *livekit.ParticipantInfo + result2 utils.TimedVersion + }) + } + fake.toProtoWithVersionReturnsOnCall[i] = struct { + result1 *livekit.ParticipantInfo + result2 utils.TimedVersion + }{result1, result2} +} + +func (fake *FakeParticipant) UpdateSubscriptionPermission(arg1 *livekit.SubscriptionPermission, arg2 utils.TimedVersion, arg3 func(participantID livekit.ParticipantID) types.LocalParticipant) error { + fake.updateSubscriptionPermissionMutex.Lock() + ret, specificReturn := fake.updateSubscriptionPermissionReturnsOnCall[len(fake.updateSubscriptionPermissionArgsForCall)] + fake.updateSubscriptionPermissionArgsForCall = append(fake.updateSubscriptionPermissionArgsForCall, struct { + arg1 *livekit.SubscriptionPermission + arg2 utils.TimedVersion + arg3 func(participantID livekit.ParticipantID) types.LocalParticipant + }{arg1, arg2, arg3}) + stub := fake.UpdateSubscriptionPermissionStub + fakeReturns := fake.updateSubscriptionPermissionReturns + fake.recordInvocation("UpdateSubscriptionPermission", []interface{}{arg1, arg2, arg3}) + fake.updateSubscriptionPermissionMutex.Unlock() + if stub != nil { + return stub(arg1, arg2, arg3) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeParticipant) UpdateSubscriptionPermissionCallCount() int { + fake.updateSubscriptionPermissionMutex.RLock() + defer fake.updateSubscriptionPermissionMutex.RUnlock() + return len(fake.updateSubscriptionPermissionArgsForCall) +} + +func (fake *FakeParticipant) UpdateSubscriptionPermissionCalls(stub func(*livekit.SubscriptionPermission, utils.TimedVersion, func(participantID livekit.ParticipantID) types.LocalParticipant) error) { + fake.updateSubscriptionPermissionMutex.Lock() + defer fake.updateSubscriptionPermissionMutex.Unlock() + fake.UpdateSubscriptionPermissionStub = stub +} + +func (fake *FakeParticipant) UpdateSubscriptionPermissionArgsForCall(i int) (*livekit.SubscriptionPermission, utils.TimedVersion, func(participantID livekit.ParticipantID) types.LocalParticipant) { + fake.updateSubscriptionPermissionMutex.RLock() + defer fake.updateSubscriptionPermissionMutex.RUnlock() + argsForCall := fake.updateSubscriptionPermissionArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeParticipant) UpdateSubscriptionPermissionReturns(result1 error) { + fake.updateSubscriptionPermissionMutex.Lock() + defer fake.updateSubscriptionPermissionMutex.Unlock() + fake.UpdateSubscriptionPermissionStub = nil + fake.updateSubscriptionPermissionReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeParticipant) UpdateSubscriptionPermissionReturnsOnCall(i int, result1 error) { + fake.updateSubscriptionPermissionMutex.Lock() + defer fake.updateSubscriptionPermissionMutex.Unlock() + fake.UpdateSubscriptionPermissionStub = nil + if fake.updateSubscriptionPermissionReturnsOnCall == nil { + fake.updateSubscriptionPermissionReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.updateSubscriptionPermissionReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeParticipant) Version() utils.TimedVersion { + fake.versionMutex.Lock() + ret, specificReturn := fake.versionReturnsOnCall[len(fake.versionArgsForCall)] + fake.versionArgsForCall = append(fake.versionArgsForCall, struct { + }{}) + stub := fake.VersionStub + fakeReturns := fake.versionReturns + fake.recordInvocation("Version", []interface{}{}) + fake.versionMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeParticipant) VersionCallCount() int { + fake.versionMutex.RLock() + defer fake.versionMutex.RUnlock() + return len(fake.versionArgsForCall) +} + +func (fake *FakeParticipant) VersionCalls(stub func() utils.TimedVersion) { + fake.versionMutex.Lock() + defer fake.versionMutex.Unlock() + fake.VersionStub = stub +} + +func (fake *FakeParticipant) VersionReturns(result1 utils.TimedVersion) { + fake.versionMutex.Lock() + defer fake.versionMutex.Unlock() + fake.VersionStub = nil + fake.versionReturns = struct { + result1 utils.TimedVersion + }{result1} +} + +func (fake *FakeParticipant) VersionReturnsOnCall(i int, result1 utils.TimedVersion) { + fake.versionMutex.Lock() + defer fake.versionMutex.Unlock() + fake.VersionStub = nil + if fake.versionReturnsOnCall == nil { + fake.versionReturnsOnCall = make(map[int]struct { + result1 utils.TimedVersion + }) + } + fake.versionReturnsOnCall[i] = struct { + result1 utils.TimedVersion + }{result1} +} + +func (fake *FakeParticipant) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeParticipant) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ types.Participant = new(FakeParticipant) diff --git a/livekit/pkg/rtc/types/typesfakes/fake_participant_listener.go b/livekit/pkg/rtc/types/typesfakes/fake_participant_listener.go new file mode 100644 index 0000000..a05a3e0 --- /dev/null +++ b/livekit/pkg/rtc/types/typesfakes/fake_participant_listener.go @@ -0,0 +1,356 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package typesfakes + +import ( + "sync" + + "github.com/livekit/livekit-server/pkg/rtc/datatrack" + "github.com/livekit/livekit-server/pkg/rtc/types" + "github.com/livekit/protocol/livekit" +) + +type FakeParticipantListener struct { + OnDataTrackMessageStub func(types.Participant, []byte, *datatrack.Packet) + onDataTrackMessageMutex sync.RWMutex + onDataTrackMessageArgsForCall []struct { + arg1 types.Participant + arg2 []byte + arg3 *datatrack.Packet + } + OnDataTrackPublishedStub func(types.Participant, types.DataTrack) + onDataTrackPublishedMutex sync.RWMutex + onDataTrackPublishedArgsForCall []struct { + arg1 types.Participant + arg2 types.DataTrack + } + OnDataTrackUnpublishedStub func(types.Participant, types.DataTrack) + onDataTrackUnpublishedMutex sync.RWMutex + onDataTrackUnpublishedArgsForCall []struct { + arg1 types.Participant + arg2 types.DataTrack + } + OnMetricsStub func(types.Participant, *livekit.DataPacket) + onMetricsMutex sync.RWMutex + onMetricsArgsForCall []struct { + arg1 types.Participant + arg2 *livekit.DataPacket + } + OnParticipantUpdateStub func(types.Participant) + onParticipantUpdateMutex sync.RWMutex + onParticipantUpdateArgsForCall []struct { + arg1 types.Participant + } + OnTrackPublishedStub func(types.Participant, types.MediaTrack) + onTrackPublishedMutex sync.RWMutex + onTrackPublishedArgsForCall []struct { + arg1 types.Participant + arg2 types.MediaTrack + } + OnTrackUnpublishedStub func(types.Participant, types.MediaTrack) + onTrackUnpublishedMutex sync.RWMutex + onTrackUnpublishedArgsForCall []struct { + arg1 types.Participant + arg2 types.MediaTrack + } + OnTrackUpdatedStub func(types.Participant, types.MediaTrack) + onTrackUpdatedMutex sync.RWMutex + onTrackUpdatedArgsForCall []struct { + arg1 types.Participant + arg2 types.MediaTrack + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeParticipantListener) OnDataTrackMessage(arg1 types.Participant, arg2 []byte, arg3 *datatrack.Packet) { + var arg2Copy []byte + if arg2 != nil { + arg2Copy = make([]byte, len(arg2)) + copy(arg2Copy, arg2) + } + fake.onDataTrackMessageMutex.Lock() + fake.onDataTrackMessageArgsForCall = append(fake.onDataTrackMessageArgsForCall, struct { + arg1 types.Participant + arg2 []byte + arg3 *datatrack.Packet + }{arg1, arg2Copy, arg3}) + stub := fake.OnDataTrackMessageStub + fake.recordInvocation("OnDataTrackMessage", []interface{}{arg1, arg2Copy, arg3}) + fake.onDataTrackMessageMutex.Unlock() + if stub != nil { + fake.OnDataTrackMessageStub(arg1, arg2, arg3) + } +} + +func (fake *FakeParticipantListener) OnDataTrackMessageCallCount() int { + fake.onDataTrackMessageMutex.RLock() + defer fake.onDataTrackMessageMutex.RUnlock() + return len(fake.onDataTrackMessageArgsForCall) +} + +func (fake *FakeParticipantListener) OnDataTrackMessageCalls(stub func(types.Participant, []byte, *datatrack.Packet)) { + fake.onDataTrackMessageMutex.Lock() + defer fake.onDataTrackMessageMutex.Unlock() + fake.OnDataTrackMessageStub = stub +} + +func (fake *FakeParticipantListener) OnDataTrackMessageArgsForCall(i int) (types.Participant, []byte, *datatrack.Packet) { + fake.onDataTrackMessageMutex.RLock() + defer fake.onDataTrackMessageMutex.RUnlock() + argsForCall := fake.onDataTrackMessageArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeParticipantListener) OnDataTrackPublished(arg1 types.Participant, arg2 types.DataTrack) { + fake.onDataTrackPublishedMutex.Lock() + fake.onDataTrackPublishedArgsForCall = append(fake.onDataTrackPublishedArgsForCall, struct { + arg1 types.Participant + arg2 types.DataTrack + }{arg1, arg2}) + stub := fake.OnDataTrackPublishedStub + fake.recordInvocation("OnDataTrackPublished", []interface{}{arg1, arg2}) + fake.onDataTrackPublishedMutex.Unlock() + if stub != nil { + fake.OnDataTrackPublishedStub(arg1, arg2) + } +} + +func (fake *FakeParticipantListener) OnDataTrackPublishedCallCount() int { + fake.onDataTrackPublishedMutex.RLock() + defer fake.onDataTrackPublishedMutex.RUnlock() + return len(fake.onDataTrackPublishedArgsForCall) +} + +func (fake *FakeParticipantListener) OnDataTrackPublishedCalls(stub func(types.Participant, types.DataTrack)) { + fake.onDataTrackPublishedMutex.Lock() + defer fake.onDataTrackPublishedMutex.Unlock() + fake.OnDataTrackPublishedStub = stub +} + +func (fake *FakeParticipantListener) OnDataTrackPublishedArgsForCall(i int) (types.Participant, types.DataTrack) { + fake.onDataTrackPublishedMutex.RLock() + defer fake.onDataTrackPublishedMutex.RUnlock() + argsForCall := fake.onDataTrackPublishedArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeParticipantListener) OnDataTrackUnpublished(arg1 types.Participant, arg2 types.DataTrack) { + fake.onDataTrackUnpublishedMutex.Lock() + fake.onDataTrackUnpublishedArgsForCall = append(fake.onDataTrackUnpublishedArgsForCall, struct { + arg1 types.Participant + arg2 types.DataTrack + }{arg1, arg2}) + stub := fake.OnDataTrackUnpublishedStub + fake.recordInvocation("OnDataTrackUnpublished", []interface{}{arg1, arg2}) + fake.onDataTrackUnpublishedMutex.Unlock() + if stub != nil { + fake.OnDataTrackUnpublishedStub(arg1, arg2) + } +} + +func (fake *FakeParticipantListener) OnDataTrackUnpublishedCallCount() int { + fake.onDataTrackUnpublishedMutex.RLock() + defer fake.onDataTrackUnpublishedMutex.RUnlock() + return len(fake.onDataTrackUnpublishedArgsForCall) +} + +func (fake *FakeParticipantListener) OnDataTrackUnpublishedCalls(stub func(types.Participant, types.DataTrack)) { + fake.onDataTrackUnpublishedMutex.Lock() + defer fake.onDataTrackUnpublishedMutex.Unlock() + fake.OnDataTrackUnpublishedStub = stub +} + +func (fake *FakeParticipantListener) OnDataTrackUnpublishedArgsForCall(i int) (types.Participant, types.DataTrack) { + fake.onDataTrackUnpublishedMutex.RLock() + defer fake.onDataTrackUnpublishedMutex.RUnlock() + argsForCall := fake.onDataTrackUnpublishedArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeParticipantListener) OnMetrics(arg1 types.Participant, arg2 *livekit.DataPacket) { + fake.onMetricsMutex.Lock() + fake.onMetricsArgsForCall = append(fake.onMetricsArgsForCall, struct { + arg1 types.Participant + arg2 *livekit.DataPacket + }{arg1, arg2}) + stub := fake.OnMetricsStub + fake.recordInvocation("OnMetrics", []interface{}{arg1, arg2}) + fake.onMetricsMutex.Unlock() + if stub != nil { + fake.OnMetricsStub(arg1, arg2) + } +} + +func (fake *FakeParticipantListener) OnMetricsCallCount() int { + fake.onMetricsMutex.RLock() + defer fake.onMetricsMutex.RUnlock() + return len(fake.onMetricsArgsForCall) +} + +func (fake *FakeParticipantListener) OnMetricsCalls(stub func(types.Participant, *livekit.DataPacket)) { + fake.onMetricsMutex.Lock() + defer fake.onMetricsMutex.Unlock() + fake.OnMetricsStub = stub +} + +func (fake *FakeParticipantListener) OnMetricsArgsForCall(i int) (types.Participant, *livekit.DataPacket) { + fake.onMetricsMutex.RLock() + defer fake.onMetricsMutex.RUnlock() + argsForCall := fake.onMetricsArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeParticipantListener) OnParticipantUpdate(arg1 types.Participant) { + fake.onParticipantUpdateMutex.Lock() + fake.onParticipantUpdateArgsForCall = append(fake.onParticipantUpdateArgsForCall, struct { + arg1 types.Participant + }{arg1}) + stub := fake.OnParticipantUpdateStub + fake.recordInvocation("OnParticipantUpdate", []interface{}{arg1}) + fake.onParticipantUpdateMutex.Unlock() + if stub != nil { + fake.OnParticipantUpdateStub(arg1) + } +} + +func (fake *FakeParticipantListener) OnParticipantUpdateCallCount() int { + fake.onParticipantUpdateMutex.RLock() + defer fake.onParticipantUpdateMutex.RUnlock() + return len(fake.onParticipantUpdateArgsForCall) +} + +func (fake *FakeParticipantListener) OnParticipantUpdateCalls(stub func(types.Participant)) { + fake.onParticipantUpdateMutex.Lock() + defer fake.onParticipantUpdateMutex.Unlock() + fake.OnParticipantUpdateStub = stub +} + +func (fake *FakeParticipantListener) OnParticipantUpdateArgsForCall(i int) types.Participant { + fake.onParticipantUpdateMutex.RLock() + defer fake.onParticipantUpdateMutex.RUnlock() + argsForCall := fake.onParticipantUpdateArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeParticipantListener) OnTrackPublished(arg1 types.Participant, arg2 types.MediaTrack) { + fake.onTrackPublishedMutex.Lock() + fake.onTrackPublishedArgsForCall = append(fake.onTrackPublishedArgsForCall, struct { + arg1 types.Participant + arg2 types.MediaTrack + }{arg1, arg2}) + stub := fake.OnTrackPublishedStub + fake.recordInvocation("OnTrackPublished", []interface{}{arg1, arg2}) + fake.onTrackPublishedMutex.Unlock() + if stub != nil { + fake.OnTrackPublishedStub(arg1, arg2) + } +} + +func (fake *FakeParticipantListener) OnTrackPublishedCallCount() int { + fake.onTrackPublishedMutex.RLock() + defer fake.onTrackPublishedMutex.RUnlock() + return len(fake.onTrackPublishedArgsForCall) +} + +func (fake *FakeParticipantListener) OnTrackPublishedCalls(stub func(types.Participant, types.MediaTrack)) { + fake.onTrackPublishedMutex.Lock() + defer fake.onTrackPublishedMutex.Unlock() + fake.OnTrackPublishedStub = stub +} + +func (fake *FakeParticipantListener) OnTrackPublishedArgsForCall(i int) (types.Participant, types.MediaTrack) { + fake.onTrackPublishedMutex.RLock() + defer fake.onTrackPublishedMutex.RUnlock() + argsForCall := fake.onTrackPublishedArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeParticipantListener) OnTrackUnpublished(arg1 types.Participant, arg2 types.MediaTrack) { + fake.onTrackUnpublishedMutex.Lock() + fake.onTrackUnpublishedArgsForCall = append(fake.onTrackUnpublishedArgsForCall, struct { + arg1 types.Participant + arg2 types.MediaTrack + }{arg1, arg2}) + stub := fake.OnTrackUnpublishedStub + fake.recordInvocation("OnTrackUnpublished", []interface{}{arg1, arg2}) + fake.onTrackUnpublishedMutex.Unlock() + if stub != nil { + fake.OnTrackUnpublishedStub(arg1, arg2) + } +} + +func (fake *FakeParticipantListener) OnTrackUnpublishedCallCount() int { + fake.onTrackUnpublishedMutex.RLock() + defer fake.onTrackUnpublishedMutex.RUnlock() + return len(fake.onTrackUnpublishedArgsForCall) +} + +func (fake *FakeParticipantListener) OnTrackUnpublishedCalls(stub func(types.Participant, types.MediaTrack)) { + fake.onTrackUnpublishedMutex.Lock() + defer fake.onTrackUnpublishedMutex.Unlock() + fake.OnTrackUnpublishedStub = stub +} + +func (fake *FakeParticipantListener) OnTrackUnpublishedArgsForCall(i int) (types.Participant, types.MediaTrack) { + fake.onTrackUnpublishedMutex.RLock() + defer fake.onTrackUnpublishedMutex.RUnlock() + argsForCall := fake.onTrackUnpublishedArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeParticipantListener) OnTrackUpdated(arg1 types.Participant, arg2 types.MediaTrack) { + fake.onTrackUpdatedMutex.Lock() + fake.onTrackUpdatedArgsForCall = append(fake.onTrackUpdatedArgsForCall, struct { + arg1 types.Participant + arg2 types.MediaTrack + }{arg1, arg2}) + stub := fake.OnTrackUpdatedStub + fake.recordInvocation("OnTrackUpdated", []interface{}{arg1, arg2}) + fake.onTrackUpdatedMutex.Unlock() + if stub != nil { + fake.OnTrackUpdatedStub(arg1, arg2) + } +} + +func (fake *FakeParticipantListener) OnTrackUpdatedCallCount() int { + fake.onTrackUpdatedMutex.RLock() + defer fake.onTrackUpdatedMutex.RUnlock() + return len(fake.onTrackUpdatedArgsForCall) +} + +func (fake *FakeParticipantListener) OnTrackUpdatedCalls(stub func(types.Participant, types.MediaTrack)) { + fake.onTrackUpdatedMutex.Lock() + defer fake.onTrackUpdatedMutex.Unlock() + fake.OnTrackUpdatedStub = stub +} + +func (fake *FakeParticipantListener) OnTrackUpdatedArgsForCall(i int) (types.Participant, types.MediaTrack) { + fake.onTrackUpdatedMutex.RLock() + defer fake.onTrackUpdatedMutex.RUnlock() + argsForCall := fake.onTrackUpdatedArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeParticipantListener) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeParticipantListener) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ types.ParticipantListener = new(FakeParticipantListener) diff --git a/livekit/pkg/rtc/types/typesfakes/fake_room.go b/livekit/pkg/rtc/types/typesfakes/fake_room.go new file mode 100644 index 0000000..292fad0 --- /dev/null +++ b/livekit/pkg/rtc/types/typesfakes/fake_room.go @@ -0,0 +1,541 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package typesfakes + +import ( + "sync" + + "github.com/livekit/livekit-server/pkg/rtc/types" + "github.com/livekit/protocol/livekit" +) + +type FakeRoom struct { + GetLocalParticipantsStub func() []types.LocalParticipant + getLocalParticipantsMutex sync.RWMutex + getLocalParticipantsArgsForCall []struct { + } + getLocalParticipantsReturns struct { + result1 []types.LocalParticipant + } + getLocalParticipantsReturnsOnCall map[int]struct { + result1 []types.LocalParticipant + } + IDStub func() livekit.RoomID + iDMutex sync.RWMutex + iDArgsForCall []struct { + } + iDReturns struct { + result1 livekit.RoomID + } + iDReturnsOnCall map[int]struct { + result1 livekit.RoomID + } + IsDataMessageUserPacketDuplicateStub func(*livekit.UserPacket) bool + isDataMessageUserPacketDuplicateMutex sync.RWMutex + isDataMessageUserPacketDuplicateArgsForCall []struct { + arg1 *livekit.UserPacket + } + isDataMessageUserPacketDuplicateReturns struct { + result1 bool + } + isDataMessageUserPacketDuplicateReturnsOnCall map[int]struct { + result1 bool + } + NameStub func() livekit.RoomName + nameMutex sync.RWMutex + nameArgsForCall []struct { + } + nameReturns struct { + result1 livekit.RoomName + } + nameReturnsOnCall map[int]struct { + result1 livekit.RoomName + } + RemoveParticipantStub func(livekit.ParticipantIdentity, livekit.ParticipantID, types.ParticipantCloseReason) + removeParticipantMutex sync.RWMutex + removeParticipantArgsForCall []struct { + arg1 livekit.ParticipantIdentity + arg2 livekit.ParticipantID + arg3 types.ParticipantCloseReason + } + ResolveDataTrackForSubscriberStub func(types.LocalParticipant, livekit.TrackID) types.DataResolverResult + resolveDataTrackForSubscriberMutex sync.RWMutex + resolveDataTrackForSubscriberArgsForCall []struct { + arg1 types.LocalParticipant + arg2 livekit.TrackID + } + resolveDataTrackForSubscriberReturns struct { + result1 types.DataResolverResult + } + resolveDataTrackForSubscriberReturnsOnCall map[int]struct { + result1 types.DataResolverResult + } + ResolveMediaTrackForSubscriberStub func(types.LocalParticipant, livekit.TrackID) types.MediaResolverResult + resolveMediaTrackForSubscriberMutex sync.RWMutex + resolveMediaTrackForSubscriberArgsForCall []struct { + arg1 types.LocalParticipant + arg2 livekit.TrackID + } + resolveMediaTrackForSubscriberReturns struct { + result1 types.MediaResolverResult + } + resolveMediaTrackForSubscriberReturnsOnCall map[int]struct { + result1 types.MediaResolverResult + } + UpdateSubscriptionsStub func(types.LocalParticipant, []livekit.TrackID, []*livekit.ParticipantTracks, bool) + updateSubscriptionsMutex sync.RWMutex + updateSubscriptionsArgsForCall []struct { + arg1 types.LocalParticipant + arg2 []livekit.TrackID + arg3 []*livekit.ParticipantTracks + arg4 bool + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeRoom) GetLocalParticipants() []types.LocalParticipant { + fake.getLocalParticipantsMutex.Lock() + ret, specificReturn := fake.getLocalParticipantsReturnsOnCall[len(fake.getLocalParticipantsArgsForCall)] + fake.getLocalParticipantsArgsForCall = append(fake.getLocalParticipantsArgsForCall, struct { + }{}) + stub := fake.GetLocalParticipantsStub + fakeReturns := fake.getLocalParticipantsReturns + fake.recordInvocation("GetLocalParticipants", []interface{}{}) + fake.getLocalParticipantsMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeRoom) GetLocalParticipantsCallCount() int { + fake.getLocalParticipantsMutex.RLock() + defer fake.getLocalParticipantsMutex.RUnlock() + return len(fake.getLocalParticipantsArgsForCall) +} + +func (fake *FakeRoom) GetLocalParticipantsCalls(stub func() []types.LocalParticipant) { + fake.getLocalParticipantsMutex.Lock() + defer fake.getLocalParticipantsMutex.Unlock() + fake.GetLocalParticipantsStub = stub +} + +func (fake *FakeRoom) GetLocalParticipantsReturns(result1 []types.LocalParticipant) { + fake.getLocalParticipantsMutex.Lock() + defer fake.getLocalParticipantsMutex.Unlock() + fake.GetLocalParticipantsStub = nil + fake.getLocalParticipantsReturns = struct { + result1 []types.LocalParticipant + }{result1} +} + +func (fake *FakeRoom) GetLocalParticipantsReturnsOnCall(i int, result1 []types.LocalParticipant) { + fake.getLocalParticipantsMutex.Lock() + defer fake.getLocalParticipantsMutex.Unlock() + fake.GetLocalParticipantsStub = nil + if fake.getLocalParticipantsReturnsOnCall == nil { + fake.getLocalParticipantsReturnsOnCall = make(map[int]struct { + result1 []types.LocalParticipant + }) + } + fake.getLocalParticipantsReturnsOnCall[i] = struct { + result1 []types.LocalParticipant + }{result1} +} + +func (fake *FakeRoom) ID() livekit.RoomID { + fake.iDMutex.Lock() + ret, specificReturn := fake.iDReturnsOnCall[len(fake.iDArgsForCall)] + fake.iDArgsForCall = append(fake.iDArgsForCall, struct { + }{}) + stub := fake.IDStub + fakeReturns := fake.iDReturns + fake.recordInvocation("ID", []interface{}{}) + fake.iDMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeRoom) IDCallCount() int { + fake.iDMutex.RLock() + defer fake.iDMutex.RUnlock() + return len(fake.iDArgsForCall) +} + +func (fake *FakeRoom) IDCalls(stub func() livekit.RoomID) { + fake.iDMutex.Lock() + defer fake.iDMutex.Unlock() + fake.IDStub = stub +} + +func (fake *FakeRoom) IDReturns(result1 livekit.RoomID) { + fake.iDMutex.Lock() + defer fake.iDMutex.Unlock() + fake.IDStub = nil + fake.iDReturns = struct { + result1 livekit.RoomID + }{result1} +} + +func (fake *FakeRoom) IDReturnsOnCall(i int, result1 livekit.RoomID) { + fake.iDMutex.Lock() + defer fake.iDMutex.Unlock() + fake.IDStub = nil + if fake.iDReturnsOnCall == nil { + fake.iDReturnsOnCall = make(map[int]struct { + result1 livekit.RoomID + }) + } + fake.iDReturnsOnCall[i] = struct { + result1 livekit.RoomID + }{result1} +} + +func (fake *FakeRoom) IsDataMessageUserPacketDuplicate(arg1 *livekit.UserPacket) bool { + fake.isDataMessageUserPacketDuplicateMutex.Lock() + ret, specificReturn := fake.isDataMessageUserPacketDuplicateReturnsOnCall[len(fake.isDataMessageUserPacketDuplicateArgsForCall)] + fake.isDataMessageUserPacketDuplicateArgsForCall = append(fake.isDataMessageUserPacketDuplicateArgsForCall, struct { + arg1 *livekit.UserPacket + }{arg1}) + stub := fake.IsDataMessageUserPacketDuplicateStub + fakeReturns := fake.isDataMessageUserPacketDuplicateReturns + fake.recordInvocation("IsDataMessageUserPacketDuplicate", []interface{}{arg1}) + fake.isDataMessageUserPacketDuplicateMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeRoom) IsDataMessageUserPacketDuplicateCallCount() int { + fake.isDataMessageUserPacketDuplicateMutex.RLock() + defer fake.isDataMessageUserPacketDuplicateMutex.RUnlock() + return len(fake.isDataMessageUserPacketDuplicateArgsForCall) +} + +func (fake *FakeRoom) IsDataMessageUserPacketDuplicateCalls(stub func(*livekit.UserPacket) bool) { + fake.isDataMessageUserPacketDuplicateMutex.Lock() + defer fake.isDataMessageUserPacketDuplicateMutex.Unlock() + fake.IsDataMessageUserPacketDuplicateStub = stub +} + +func (fake *FakeRoom) IsDataMessageUserPacketDuplicateArgsForCall(i int) *livekit.UserPacket { + fake.isDataMessageUserPacketDuplicateMutex.RLock() + defer fake.isDataMessageUserPacketDuplicateMutex.RUnlock() + argsForCall := fake.isDataMessageUserPacketDuplicateArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeRoom) IsDataMessageUserPacketDuplicateReturns(result1 bool) { + fake.isDataMessageUserPacketDuplicateMutex.Lock() + defer fake.isDataMessageUserPacketDuplicateMutex.Unlock() + fake.IsDataMessageUserPacketDuplicateStub = nil + fake.isDataMessageUserPacketDuplicateReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeRoom) IsDataMessageUserPacketDuplicateReturnsOnCall(i int, result1 bool) { + fake.isDataMessageUserPacketDuplicateMutex.Lock() + defer fake.isDataMessageUserPacketDuplicateMutex.Unlock() + fake.IsDataMessageUserPacketDuplicateStub = nil + if fake.isDataMessageUserPacketDuplicateReturnsOnCall == nil { + fake.isDataMessageUserPacketDuplicateReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.isDataMessageUserPacketDuplicateReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + +func (fake *FakeRoom) Name() livekit.RoomName { + fake.nameMutex.Lock() + ret, specificReturn := fake.nameReturnsOnCall[len(fake.nameArgsForCall)] + fake.nameArgsForCall = append(fake.nameArgsForCall, struct { + }{}) + stub := fake.NameStub + fakeReturns := fake.nameReturns + fake.recordInvocation("Name", []interface{}{}) + fake.nameMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeRoom) NameCallCount() int { + fake.nameMutex.RLock() + defer fake.nameMutex.RUnlock() + return len(fake.nameArgsForCall) +} + +func (fake *FakeRoom) NameCalls(stub func() livekit.RoomName) { + fake.nameMutex.Lock() + defer fake.nameMutex.Unlock() + fake.NameStub = stub +} + +func (fake *FakeRoom) NameReturns(result1 livekit.RoomName) { + fake.nameMutex.Lock() + defer fake.nameMutex.Unlock() + fake.NameStub = nil + fake.nameReturns = struct { + result1 livekit.RoomName + }{result1} +} + +func (fake *FakeRoom) NameReturnsOnCall(i int, result1 livekit.RoomName) { + fake.nameMutex.Lock() + defer fake.nameMutex.Unlock() + fake.NameStub = nil + if fake.nameReturnsOnCall == nil { + fake.nameReturnsOnCall = make(map[int]struct { + result1 livekit.RoomName + }) + } + fake.nameReturnsOnCall[i] = struct { + result1 livekit.RoomName + }{result1} +} + +func (fake *FakeRoom) RemoveParticipant(arg1 livekit.ParticipantIdentity, arg2 livekit.ParticipantID, arg3 types.ParticipantCloseReason) { + fake.removeParticipantMutex.Lock() + fake.removeParticipantArgsForCall = append(fake.removeParticipantArgsForCall, struct { + arg1 livekit.ParticipantIdentity + arg2 livekit.ParticipantID + arg3 types.ParticipantCloseReason + }{arg1, arg2, arg3}) + stub := fake.RemoveParticipantStub + fake.recordInvocation("RemoveParticipant", []interface{}{arg1, arg2, arg3}) + fake.removeParticipantMutex.Unlock() + if stub != nil { + fake.RemoveParticipantStub(arg1, arg2, arg3) + } +} + +func (fake *FakeRoom) RemoveParticipantCallCount() int { + fake.removeParticipantMutex.RLock() + defer fake.removeParticipantMutex.RUnlock() + return len(fake.removeParticipantArgsForCall) +} + +func (fake *FakeRoom) RemoveParticipantCalls(stub func(livekit.ParticipantIdentity, livekit.ParticipantID, types.ParticipantCloseReason)) { + fake.removeParticipantMutex.Lock() + defer fake.removeParticipantMutex.Unlock() + fake.RemoveParticipantStub = stub +} + +func (fake *FakeRoom) RemoveParticipantArgsForCall(i int) (livekit.ParticipantIdentity, livekit.ParticipantID, types.ParticipantCloseReason) { + fake.removeParticipantMutex.RLock() + defer fake.removeParticipantMutex.RUnlock() + argsForCall := fake.removeParticipantArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeRoom) ResolveDataTrackForSubscriber(arg1 types.LocalParticipant, arg2 livekit.TrackID) types.DataResolverResult { + fake.resolveDataTrackForSubscriberMutex.Lock() + ret, specificReturn := fake.resolveDataTrackForSubscriberReturnsOnCall[len(fake.resolveDataTrackForSubscriberArgsForCall)] + fake.resolveDataTrackForSubscriberArgsForCall = append(fake.resolveDataTrackForSubscriberArgsForCall, struct { + arg1 types.LocalParticipant + arg2 livekit.TrackID + }{arg1, arg2}) + stub := fake.ResolveDataTrackForSubscriberStub + fakeReturns := fake.resolveDataTrackForSubscriberReturns + fake.recordInvocation("ResolveDataTrackForSubscriber", []interface{}{arg1, arg2}) + fake.resolveDataTrackForSubscriberMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeRoom) ResolveDataTrackForSubscriberCallCount() int { + fake.resolveDataTrackForSubscriberMutex.RLock() + defer fake.resolveDataTrackForSubscriberMutex.RUnlock() + return len(fake.resolveDataTrackForSubscriberArgsForCall) +} + +func (fake *FakeRoom) ResolveDataTrackForSubscriberCalls(stub func(types.LocalParticipant, livekit.TrackID) types.DataResolverResult) { + fake.resolveDataTrackForSubscriberMutex.Lock() + defer fake.resolveDataTrackForSubscriberMutex.Unlock() + fake.ResolveDataTrackForSubscriberStub = stub +} + +func (fake *FakeRoom) ResolveDataTrackForSubscriberArgsForCall(i int) (types.LocalParticipant, livekit.TrackID) { + fake.resolveDataTrackForSubscriberMutex.RLock() + defer fake.resolveDataTrackForSubscriberMutex.RUnlock() + argsForCall := fake.resolveDataTrackForSubscriberArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeRoom) ResolveDataTrackForSubscriberReturns(result1 types.DataResolverResult) { + fake.resolveDataTrackForSubscriberMutex.Lock() + defer fake.resolveDataTrackForSubscriberMutex.Unlock() + fake.ResolveDataTrackForSubscriberStub = nil + fake.resolveDataTrackForSubscriberReturns = struct { + result1 types.DataResolverResult + }{result1} +} + +func (fake *FakeRoom) ResolveDataTrackForSubscriberReturnsOnCall(i int, result1 types.DataResolverResult) { + fake.resolveDataTrackForSubscriberMutex.Lock() + defer fake.resolveDataTrackForSubscriberMutex.Unlock() + fake.ResolveDataTrackForSubscriberStub = nil + if fake.resolveDataTrackForSubscriberReturnsOnCall == nil { + fake.resolveDataTrackForSubscriberReturnsOnCall = make(map[int]struct { + result1 types.DataResolverResult + }) + } + fake.resolveDataTrackForSubscriberReturnsOnCall[i] = struct { + result1 types.DataResolverResult + }{result1} +} + +func (fake *FakeRoom) ResolveMediaTrackForSubscriber(arg1 types.LocalParticipant, arg2 livekit.TrackID) types.MediaResolverResult { + fake.resolveMediaTrackForSubscriberMutex.Lock() + ret, specificReturn := fake.resolveMediaTrackForSubscriberReturnsOnCall[len(fake.resolveMediaTrackForSubscriberArgsForCall)] + fake.resolveMediaTrackForSubscriberArgsForCall = append(fake.resolveMediaTrackForSubscriberArgsForCall, struct { + arg1 types.LocalParticipant + arg2 livekit.TrackID + }{arg1, arg2}) + stub := fake.ResolveMediaTrackForSubscriberStub + fakeReturns := fake.resolveMediaTrackForSubscriberReturns + fake.recordInvocation("ResolveMediaTrackForSubscriber", []interface{}{arg1, arg2}) + fake.resolveMediaTrackForSubscriberMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeRoom) ResolveMediaTrackForSubscriberCallCount() int { + fake.resolveMediaTrackForSubscriberMutex.RLock() + defer fake.resolveMediaTrackForSubscriberMutex.RUnlock() + return len(fake.resolveMediaTrackForSubscriberArgsForCall) +} + +func (fake *FakeRoom) ResolveMediaTrackForSubscriberCalls(stub func(types.LocalParticipant, livekit.TrackID) types.MediaResolverResult) { + fake.resolveMediaTrackForSubscriberMutex.Lock() + defer fake.resolveMediaTrackForSubscriberMutex.Unlock() + fake.ResolveMediaTrackForSubscriberStub = stub +} + +func (fake *FakeRoom) ResolveMediaTrackForSubscriberArgsForCall(i int) (types.LocalParticipant, livekit.TrackID) { + fake.resolveMediaTrackForSubscriberMutex.RLock() + defer fake.resolveMediaTrackForSubscriberMutex.RUnlock() + argsForCall := fake.resolveMediaTrackForSubscriberArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeRoom) ResolveMediaTrackForSubscriberReturns(result1 types.MediaResolverResult) { + fake.resolveMediaTrackForSubscriberMutex.Lock() + defer fake.resolveMediaTrackForSubscriberMutex.Unlock() + fake.ResolveMediaTrackForSubscriberStub = nil + fake.resolveMediaTrackForSubscriberReturns = struct { + result1 types.MediaResolverResult + }{result1} +} + +func (fake *FakeRoom) ResolveMediaTrackForSubscriberReturnsOnCall(i int, result1 types.MediaResolverResult) { + fake.resolveMediaTrackForSubscriberMutex.Lock() + defer fake.resolveMediaTrackForSubscriberMutex.Unlock() + fake.ResolveMediaTrackForSubscriberStub = nil + if fake.resolveMediaTrackForSubscriberReturnsOnCall == nil { + fake.resolveMediaTrackForSubscriberReturnsOnCall = make(map[int]struct { + result1 types.MediaResolverResult + }) + } + fake.resolveMediaTrackForSubscriberReturnsOnCall[i] = struct { + result1 types.MediaResolverResult + }{result1} +} + +func (fake *FakeRoom) UpdateSubscriptions(arg1 types.LocalParticipant, arg2 []livekit.TrackID, arg3 []*livekit.ParticipantTracks, arg4 bool) { + var arg2Copy []livekit.TrackID + if arg2 != nil { + arg2Copy = make([]livekit.TrackID, len(arg2)) + copy(arg2Copy, arg2) + } + var arg3Copy []*livekit.ParticipantTracks + if arg3 != nil { + arg3Copy = make([]*livekit.ParticipantTracks, len(arg3)) + copy(arg3Copy, arg3) + } + fake.updateSubscriptionsMutex.Lock() + fake.updateSubscriptionsArgsForCall = append(fake.updateSubscriptionsArgsForCall, struct { + arg1 types.LocalParticipant + arg2 []livekit.TrackID + arg3 []*livekit.ParticipantTracks + arg4 bool + }{arg1, arg2Copy, arg3Copy, arg4}) + stub := fake.UpdateSubscriptionsStub + fake.recordInvocation("UpdateSubscriptions", []interface{}{arg1, arg2Copy, arg3Copy, arg4}) + fake.updateSubscriptionsMutex.Unlock() + if stub != nil { + fake.UpdateSubscriptionsStub(arg1, arg2, arg3, arg4) + } +} + +func (fake *FakeRoom) UpdateSubscriptionsCallCount() int { + fake.updateSubscriptionsMutex.RLock() + defer fake.updateSubscriptionsMutex.RUnlock() + return len(fake.updateSubscriptionsArgsForCall) +} + +func (fake *FakeRoom) UpdateSubscriptionsCalls(stub func(types.LocalParticipant, []livekit.TrackID, []*livekit.ParticipantTracks, bool)) { + fake.updateSubscriptionsMutex.Lock() + defer fake.updateSubscriptionsMutex.Unlock() + fake.UpdateSubscriptionsStub = stub +} + +func (fake *FakeRoom) UpdateSubscriptionsArgsForCall(i int) (types.LocalParticipant, []livekit.TrackID, []*livekit.ParticipantTracks, bool) { + fake.updateSubscriptionsMutex.RLock() + defer fake.updateSubscriptionsMutex.RUnlock() + argsForCall := fake.updateSubscriptionsArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3, argsForCall.arg4 +} + +func (fake *FakeRoom) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeRoom) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ types.Room = new(FakeRoom) diff --git a/livekit/pkg/rtc/types/typesfakes/fake_subscribed_track.go b/livekit/pkg/rtc/types/typesfakes/fake_subscribed_track.go new file mode 100644 index 0000000..2123548 --- /dev/null +++ b/livekit/pkg/rtc/types/typesfakes/fake_subscribed_track.go @@ -0,0 +1,1074 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package typesfakes + +import ( + "sync" + + "github.com/livekit/livekit-server/pkg/rtc/types" + "github.com/livekit/livekit-server/pkg/sfu" + "github.com/livekit/protocol/livekit" + webrtc "github.com/pion/webrtc/v4" +) + +type FakeSubscribedTrack struct { + AddOnBindStub func(func(error)) + addOnBindMutex sync.RWMutex + addOnBindArgsForCall []struct { + arg1 func(error) + } + CloseStub func(bool) + closeMutex sync.RWMutex + closeArgsForCall []struct { + arg1 bool + } + DownTrackStub func() *sfu.DownTrack + downTrackMutex sync.RWMutex + downTrackArgsForCall []struct { + } + downTrackReturns struct { + result1 *sfu.DownTrack + } + downTrackReturnsOnCall map[int]struct { + result1 *sfu.DownTrack + } + IDStub func() livekit.TrackID + iDMutex sync.RWMutex + iDArgsForCall []struct { + } + iDReturns struct { + result1 livekit.TrackID + } + iDReturnsOnCall map[int]struct { + result1 livekit.TrackID + } + IsBoundStub func() bool + isBoundMutex sync.RWMutex + isBoundArgsForCall []struct { + } + isBoundReturns struct { + result1 bool + } + isBoundReturnsOnCall map[int]struct { + result1 bool + } + IsMutedStub func() bool + isMutedMutex sync.RWMutex + isMutedArgsForCall []struct { + } + isMutedReturns struct { + result1 bool + } + isMutedReturnsOnCall map[int]struct { + result1 bool + } + MediaTrackStub func() types.MediaTrack + mediaTrackMutex sync.RWMutex + mediaTrackArgsForCall []struct { + } + mediaTrackReturns struct { + result1 types.MediaTrack + } + mediaTrackReturnsOnCall map[int]struct { + result1 types.MediaTrack + } + NeedsNegotiationStub func() bool + needsNegotiationMutex sync.RWMutex + needsNegotiationArgsForCall []struct { + } + needsNegotiationReturns struct { + result1 bool + } + needsNegotiationReturnsOnCall map[int]struct { + result1 bool + } + OnCloseStub func(func(isExpectedToResume bool)) + onCloseMutex sync.RWMutex + onCloseArgsForCall []struct { + arg1 func(isExpectedToResume bool) + } + PublisherIDStub func() livekit.ParticipantID + publisherIDMutex sync.RWMutex + publisherIDArgsForCall []struct { + } + publisherIDReturns struct { + result1 livekit.ParticipantID + } + publisherIDReturnsOnCall map[int]struct { + result1 livekit.ParticipantID + } + PublisherIdentityStub func() livekit.ParticipantIdentity + publisherIdentityMutex sync.RWMutex + publisherIdentityArgsForCall []struct { + } + publisherIdentityReturns struct { + result1 livekit.ParticipantIdentity + } + publisherIdentityReturnsOnCall map[int]struct { + result1 livekit.ParticipantIdentity + } + PublisherVersionStub func() uint32 + publisherVersionMutex sync.RWMutex + publisherVersionArgsForCall []struct { + } + publisherVersionReturns struct { + result1 uint32 + } + publisherVersionReturnsOnCall map[int]struct { + result1 uint32 + } + RTPSenderStub func() *webrtc.RTPSender + rTPSenderMutex sync.RWMutex + rTPSenderArgsForCall []struct { + } + rTPSenderReturns struct { + result1 *webrtc.RTPSender + } + rTPSenderReturnsOnCall map[int]struct { + result1 *webrtc.RTPSender + } + SetPublisherMutedStub func(bool) + setPublisherMutedMutex sync.RWMutex + setPublisherMutedArgsForCall []struct { + arg1 bool + } + SubscriberStub func() types.LocalParticipant + subscriberMutex sync.RWMutex + subscriberArgsForCall []struct { + } + subscriberReturns struct { + result1 types.LocalParticipant + } + subscriberReturnsOnCall map[int]struct { + result1 types.LocalParticipant + } + SubscriberIDStub func() livekit.ParticipantID + subscriberIDMutex sync.RWMutex + subscriberIDArgsForCall []struct { + } + subscriberIDReturns struct { + result1 livekit.ParticipantID + } + subscriberIDReturnsOnCall map[int]struct { + result1 livekit.ParticipantID + } + SubscriberIdentityStub func() livekit.ParticipantIdentity + subscriberIdentityMutex sync.RWMutex + subscriberIdentityArgsForCall []struct { + } + subscriberIdentityReturns struct { + result1 livekit.ParticipantIdentity + } + subscriberIdentityReturnsOnCall map[int]struct { + result1 livekit.ParticipantIdentity + } + UpdateSubscriberSettingsStub func(*livekit.UpdateTrackSettings, bool) + updateSubscriberSettingsMutex sync.RWMutex + updateSubscriberSettingsArgsForCall []struct { + arg1 *livekit.UpdateTrackSettings + arg2 bool + } + UpdateVideoLayerStub func() + updateVideoLayerMutex sync.RWMutex + updateVideoLayerArgsForCall []struct { + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeSubscribedTrack) AddOnBind(arg1 func(error)) { + fake.addOnBindMutex.Lock() + fake.addOnBindArgsForCall = append(fake.addOnBindArgsForCall, struct { + arg1 func(error) + }{arg1}) + stub := fake.AddOnBindStub + fake.recordInvocation("AddOnBind", []interface{}{arg1}) + fake.addOnBindMutex.Unlock() + if stub != nil { + fake.AddOnBindStub(arg1) + } +} + +func (fake *FakeSubscribedTrack) AddOnBindCallCount() int { + fake.addOnBindMutex.RLock() + defer fake.addOnBindMutex.RUnlock() + return len(fake.addOnBindArgsForCall) +} + +func (fake *FakeSubscribedTrack) AddOnBindCalls(stub func(func(error))) { + fake.addOnBindMutex.Lock() + defer fake.addOnBindMutex.Unlock() + fake.AddOnBindStub = stub +} + +func (fake *FakeSubscribedTrack) AddOnBindArgsForCall(i int) func(error) { + fake.addOnBindMutex.RLock() + defer fake.addOnBindMutex.RUnlock() + argsForCall := fake.addOnBindArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeSubscribedTrack) Close(arg1 bool) { + fake.closeMutex.Lock() + fake.closeArgsForCall = append(fake.closeArgsForCall, struct { + arg1 bool + }{arg1}) + stub := fake.CloseStub + fake.recordInvocation("Close", []interface{}{arg1}) + fake.closeMutex.Unlock() + if stub != nil { + fake.CloseStub(arg1) + } +} + +func (fake *FakeSubscribedTrack) CloseCallCount() int { + fake.closeMutex.RLock() + defer fake.closeMutex.RUnlock() + return len(fake.closeArgsForCall) +} + +func (fake *FakeSubscribedTrack) CloseCalls(stub func(bool)) { + fake.closeMutex.Lock() + defer fake.closeMutex.Unlock() + fake.CloseStub = stub +} + +func (fake *FakeSubscribedTrack) CloseArgsForCall(i int) bool { + fake.closeMutex.RLock() + defer fake.closeMutex.RUnlock() + argsForCall := fake.closeArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeSubscribedTrack) DownTrack() *sfu.DownTrack { + fake.downTrackMutex.Lock() + ret, specificReturn := fake.downTrackReturnsOnCall[len(fake.downTrackArgsForCall)] + fake.downTrackArgsForCall = append(fake.downTrackArgsForCall, struct { + }{}) + stub := fake.DownTrackStub + fakeReturns := fake.downTrackReturns + fake.recordInvocation("DownTrack", []interface{}{}) + fake.downTrackMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeSubscribedTrack) DownTrackCallCount() int { + fake.downTrackMutex.RLock() + defer fake.downTrackMutex.RUnlock() + return len(fake.downTrackArgsForCall) +} + +func (fake *FakeSubscribedTrack) DownTrackCalls(stub func() *sfu.DownTrack) { + fake.downTrackMutex.Lock() + defer fake.downTrackMutex.Unlock() + fake.DownTrackStub = stub +} + +func (fake *FakeSubscribedTrack) DownTrackReturns(result1 *sfu.DownTrack) { + fake.downTrackMutex.Lock() + defer fake.downTrackMutex.Unlock() + fake.DownTrackStub = nil + fake.downTrackReturns = struct { + result1 *sfu.DownTrack + }{result1} +} + +func (fake *FakeSubscribedTrack) DownTrackReturnsOnCall(i int, result1 *sfu.DownTrack) { + fake.downTrackMutex.Lock() + defer fake.downTrackMutex.Unlock() + fake.DownTrackStub = nil + if fake.downTrackReturnsOnCall == nil { + fake.downTrackReturnsOnCall = make(map[int]struct { + result1 *sfu.DownTrack + }) + } + fake.downTrackReturnsOnCall[i] = struct { + result1 *sfu.DownTrack + }{result1} +} + +func (fake *FakeSubscribedTrack) ID() livekit.TrackID { + fake.iDMutex.Lock() + ret, specificReturn := fake.iDReturnsOnCall[len(fake.iDArgsForCall)] + fake.iDArgsForCall = append(fake.iDArgsForCall, struct { + }{}) + stub := fake.IDStub + fakeReturns := fake.iDReturns + fake.recordInvocation("ID", []interface{}{}) + fake.iDMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeSubscribedTrack) IDCallCount() int { + fake.iDMutex.RLock() + defer fake.iDMutex.RUnlock() + return len(fake.iDArgsForCall) +} + +func (fake *FakeSubscribedTrack) IDCalls(stub func() livekit.TrackID) { + fake.iDMutex.Lock() + defer fake.iDMutex.Unlock() + fake.IDStub = stub +} + +func (fake *FakeSubscribedTrack) IDReturns(result1 livekit.TrackID) { + fake.iDMutex.Lock() + defer fake.iDMutex.Unlock() + fake.IDStub = nil + fake.iDReturns = struct { + result1 livekit.TrackID + }{result1} +} + +func (fake *FakeSubscribedTrack) IDReturnsOnCall(i int, result1 livekit.TrackID) { + fake.iDMutex.Lock() + defer fake.iDMutex.Unlock() + fake.IDStub = nil + if fake.iDReturnsOnCall == nil { + fake.iDReturnsOnCall = make(map[int]struct { + result1 livekit.TrackID + }) + } + fake.iDReturnsOnCall[i] = struct { + result1 livekit.TrackID + }{result1} +} + +func (fake *FakeSubscribedTrack) IsBound() bool { + fake.isBoundMutex.Lock() + ret, specificReturn := fake.isBoundReturnsOnCall[len(fake.isBoundArgsForCall)] + fake.isBoundArgsForCall = append(fake.isBoundArgsForCall, struct { + }{}) + stub := fake.IsBoundStub + fakeReturns := fake.isBoundReturns + fake.recordInvocation("IsBound", []interface{}{}) + fake.isBoundMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeSubscribedTrack) IsBoundCallCount() int { + fake.isBoundMutex.RLock() + defer fake.isBoundMutex.RUnlock() + return len(fake.isBoundArgsForCall) +} + +func (fake *FakeSubscribedTrack) IsBoundCalls(stub func() bool) { + fake.isBoundMutex.Lock() + defer fake.isBoundMutex.Unlock() + fake.IsBoundStub = stub +} + +func (fake *FakeSubscribedTrack) IsBoundReturns(result1 bool) { + fake.isBoundMutex.Lock() + defer fake.isBoundMutex.Unlock() + fake.IsBoundStub = nil + fake.isBoundReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeSubscribedTrack) IsBoundReturnsOnCall(i int, result1 bool) { + fake.isBoundMutex.Lock() + defer fake.isBoundMutex.Unlock() + fake.IsBoundStub = nil + if fake.isBoundReturnsOnCall == nil { + fake.isBoundReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.isBoundReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + +func (fake *FakeSubscribedTrack) IsMuted() bool { + fake.isMutedMutex.Lock() + ret, specificReturn := fake.isMutedReturnsOnCall[len(fake.isMutedArgsForCall)] + fake.isMutedArgsForCall = append(fake.isMutedArgsForCall, struct { + }{}) + stub := fake.IsMutedStub + fakeReturns := fake.isMutedReturns + fake.recordInvocation("IsMuted", []interface{}{}) + fake.isMutedMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeSubscribedTrack) IsMutedCallCount() int { + fake.isMutedMutex.RLock() + defer fake.isMutedMutex.RUnlock() + return len(fake.isMutedArgsForCall) +} + +func (fake *FakeSubscribedTrack) IsMutedCalls(stub func() bool) { + fake.isMutedMutex.Lock() + defer fake.isMutedMutex.Unlock() + fake.IsMutedStub = stub +} + +func (fake *FakeSubscribedTrack) IsMutedReturns(result1 bool) { + fake.isMutedMutex.Lock() + defer fake.isMutedMutex.Unlock() + fake.IsMutedStub = nil + fake.isMutedReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeSubscribedTrack) IsMutedReturnsOnCall(i int, result1 bool) { + fake.isMutedMutex.Lock() + defer fake.isMutedMutex.Unlock() + fake.IsMutedStub = nil + if fake.isMutedReturnsOnCall == nil { + fake.isMutedReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.isMutedReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + +func (fake *FakeSubscribedTrack) MediaTrack() types.MediaTrack { + fake.mediaTrackMutex.Lock() + ret, specificReturn := fake.mediaTrackReturnsOnCall[len(fake.mediaTrackArgsForCall)] + fake.mediaTrackArgsForCall = append(fake.mediaTrackArgsForCall, struct { + }{}) + stub := fake.MediaTrackStub + fakeReturns := fake.mediaTrackReturns + fake.recordInvocation("MediaTrack", []interface{}{}) + fake.mediaTrackMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeSubscribedTrack) MediaTrackCallCount() int { + fake.mediaTrackMutex.RLock() + defer fake.mediaTrackMutex.RUnlock() + return len(fake.mediaTrackArgsForCall) +} + +func (fake *FakeSubscribedTrack) MediaTrackCalls(stub func() types.MediaTrack) { + fake.mediaTrackMutex.Lock() + defer fake.mediaTrackMutex.Unlock() + fake.MediaTrackStub = stub +} + +func (fake *FakeSubscribedTrack) MediaTrackReturns(result1 types.MediaTrack) { + fake.mediaTrackMutex.Lock() + defer fake.mediaTrackMutex.Unlock() + fake.MediaTrackStub = nil + fake.mediaTrackReturns = struct { + result1 types.MediaTrack + }{result1} +} + +func (fake *FakeSubscribedTrack) MediaTrackReturnsOnCall(i int, result1 types.MediaTrack) { + fake.mediaTrackMutex.Lock() + defer fake.mediaTrackMutex.Unlock() + fake.MediaTrackStub = nil + if fake.mediaTrackReturnsOnCall == nil { + fake.mediaTrackReturnsOnCall = make(map[int]struct { + result1 types.MediaTrack + }) + } + fake.mediaTrackReturnsOnCall[i] = struct { + result1 types.MediaTrack + }{result1} +} + +func (fake *FakeSubscribedTrack) NeedsNegotiation() bool { + fake.needsNegotiationMutex.Lock() + ret, specificReturn := fake.needsNegotiationReturnsOnCall[len(fake.needsNegotiationArgsForCall)] + fake.needsNegotiationArgsForCall = append(fake.needsNegotiationArgsForCall, struct { + }{}) + stub := fake.NeedsNegotiationStub + fakeReturns := fake.needsNegotiationReturns + fake.recordInvocation("NeedsNegotiation", []interface{}{}) + fake.needsNegotiationMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeSubscribedTrack) NeedsNegotiationCallCount() int { + fake.needsNegotiationMutex.RLock() + defer fake.needsNegotiationMutex.RUnlock() + return len(fake.needsNegotiationArgsForCall) +} + +func (fake *FakeSubscribedTrack) NeedsNegotiationCalls(stub func() bool) { + fake.needsNegotiationMutex.Lock() + defer fake.needsNegotiationMutex.Unlock() + fake.NeedsNegotiationStub = stub +} + +func (fake *FakeSubscribedTrack) NeedsNegotiationReturns(result1 bool) { + fake.needsNegotiationMutex.Lock() + defer fake.needsNegotiationMutex.Unlock() + fake.NeedsNegotiationStub = nil + fake.needsNegotiationReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeSubscribedTrack) NeedsNegotiationReturnsOnCall(i int, result1 bool) { + fake.needsNegotiationMutex.Lock() + defer fake.needsNegotiationMutex.Unlock() + fake.NeedsNegotiationStub = nil + if fake.needsNegotiationReturnsOnCall == nil { + fake.needsNegotiationReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.needsNegotiationReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + +func (fake *FakeSubscribedTrack) OnClose(arg1 func(isExpectedToResume bool)) { + fake.onCloseMutex.Lock() + fake.onCloseArgsForCall = append(fake.onCloseArgsForCall, struct { + arg1 func(isExpectedToResume bool) + }{arg1}) + stub := fake.OnCloseStub + fake.recordInvocation("OnClose", []interface{}{arg1}) + fake.onCloseMutex.Unlock() + if stub != nil { + fake.OnCloseStub(arg1) + } +} + +func (fake *FakeSubscribedTrack) OnCloseCallCount() int { + fake.onCloseMutex.RLock() + defer fake.onCloseMutex.RUnlock() + return len(fake.onCloseArgsForCall) +} + +func (fake *FakeSubscribedTrack) OnCloseCalls(stub func(func(isExpectedToResume bool))) { + fake.onCloseMutex.Lock() + defer fake.onCloseMutex.Unlock() + fake.OnCloseStub = stub +} + +func (fake *FakeSubscribedTrack) OnCloseArgsForCall(i int) func(isExpectedToResume bool) { + fake.onCloseMutex.RLock() + defer fake.onCloseMutex.RUnlock() + argsForCall := fake.onCloseArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeSubscribedTrack) PublisherID() livekit.ParticipantID { + fake.publisherIDMutex.Lock() + ret, specificReturn := fake.publisherIDReturnsOnCall[len(fake.publisherIDArgsForCall)] + fake.publisherIDArgsForCall = append(fake.publisherIDArgsForCall, struct { + }{}) + stub := fake.PublisherIDStub + fakeReturns := fake.publisherIDReturns + fake.recordInvocation("PublisherID", []interface{}{}) + fake.publisherIDMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeSubscribedTrack) PublisherIDCallCount() int { + fake.publisherIDMutex.RLock() + defer fake.publisherIDMutex.RUnlock() + return len(fake.publisherIDArgsForCall) +} + +func (fake *FakeSubscribedTrack) PublisherIDCalls(stub func() livekit.ParticipantID) { + fake.publisherIDMutex.Lock() + defer fake.publisherIDMutex.Unlock() + fake.PublisherIDStub = stub +} + +func (fake *FakeSubscribedTrack) PublisherIDReturns(result1 livekit.ParticipantID) { + fake.publisherIDMutex.Lock() + defer fake.publisherIDMutex.Unlock() + fake.PublisherIDStub = nil + fake.publisherIDReturns = struct { + result1 livekit.ParticipantID + }{result1} +} + +func (fake *FakeSubscribedTrack) PublisherIDReturnsOnCall(i int, result1 livekit.ParticipantID) { + fake.publisherIDMutex.Lock() + defer fake.publisherIDMutex.Unlock() + fake.PublisherIDStub = nil + if fake.publisherIDReturnsOnCall == nil { + fake.publisherIDReturnsOnCall = make(map[int]struct { + result1 livekit.ParticipantID + }) + } + fake.publisherIDReturnsOnCall[i] = struct { + result1 livekit.ParticipantID + }{result1} +} + +func (fake *FakeSubscribedTrack) PublisherIdentity() livekit.ParticipantIdentity { + fake.publisherIdentityMutex.Lock() + ret, specificReturn := fake.publisherIdentityReturnsOnCall[len(fake.publisherIdentityArgsForCall)] + fake.publisherIdentityArgsForCall = append(fake.publisherIdentityArgsForCall, struct { + }{}) + stub := fake.PublisherIdentityStub + fakeReturns := fake.publisherIdentityReturns + fake.recordInvocation("PublisherIdentity", []interface{}{}) + fake.publisherIdentityMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeSubscribedTrack) PublisherIdentityCallCount() int { + fake.publisherIdentityMutex.RLock() + defer fake.publisherIdentityMutex.RUnlock() + return len(fake.publisherIdentityArgsForCall) +} + +func (fake *FakeSubscribedTrack) PublisherIdentityCalls(stub func() livekit.ParticipantIdentity) { + fake.publisherIdentityMutex.Lock() + defer fake.publisherIdentityMutex.Unlock() + fake.PublisherIdentityStub = stub +} + +func (fake *FakeSubscribedTrack) PublisherIdentityReturns(result1 livekit.ParticipantIdentity) { + fake.publisherIdentityMutex.Lock() + defer fake.publisherIdentityMutex.Unlock() + fake.PublisherIdentityStub = nil + fake.publisherIdentityReturns = struct { + result1 livekit.ParticipantIdentity + }{result1} +} + +func (fake *FakeSubscribedTrack) PublisherIdentityReturnsOnCall(i int, result1 livekit.ParticipantIdentity) { + fake.publisherIdentityMutex.Lock() + defer fake.publisherIdentityMutex.Unlock() + fake.PublisherIdentityStub = nil + if fake.publisherIdentityReturnsOnCall == nil { + fake.publisherIdentityReturnsOnCall = make(map[int]struct { + result1 livekit.ParticipantIdentity + }) + } + fake.publisherIdentityReturnsOnCall[i] = struct { + result1 livekit.ParticipantIdentity + }{result1} +} + +func (fake *FakeSubscribedTrack) PublisherVersion() uint32 { + fake.publisherVersionMutex.Lock() + ret, specificReturn := fake.publisherVersionReturnsOnCall[len(fake.publisherVersionArgsForCall)] + fake.publisherVersionArgsForCall = append(fake.publisherVersionArgsForCall, struct { + }{}) + stub := fake.PublisherVersionStub + fakeReturns := fake.publisherVersionReturns + fake.recordInvocation("PublisherVersion", []interface{}{}) + fake.publisherVersionMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeSubscribedTrack) PublisherVersionCallCount() int { + fake.publisherVersionMutex.RLock() + defer fake.publisherVersionMutex.RUnlock() + return len(fake.publisherVersionArgsForCall) +} + +func (fake *FakeSubscribedTrack) PublisherVersionCalls(stub func() uint32) { + fake.publisherVersionMutex.Lock() + defer fake.publisherVersionMutex.Unlock() + fake.PublisherVersionStub = stub +} + +func (fake *FakeSubscribedTrack) PublisherVersionReturns(result1 uint32) { + fake.publisherVersionMutex.Lock() + defer fake.publisherVersionMutex.Unlock() + fake.PublisherVersionStub = nil + fake.publisherVersionReturns = struct { + result1 uint32 + }{result1} +} + +func (fake *FakeSubscribedTrack) PublisherVersionReturnsOnCall(i int, result1 uint32) { + fake.publisherVersionMutex.Lock() + defer fake.publisherVersionMutex.Unlock() + fake.PublisherVersionStub = nil + if fake.publisherVersionReturnsOnCall == nil { + fake.publisherVersionReturnsOnCall = make(map[int]struct { + result1 uint32 + }) + } + fake.publisherVersionReturnsOnCall[i] = struct { + result1 uint32 + }{result1} +} + +func (fake *FakeSubscribedTrack) RTPSender() *webrtc.RTPSender { + fake.rTPSenderMutex.Lock() + ret, specificReturn := fake.rTPSenderReturnsOnCall[len(fake.rTPSenderArgsForCall)] + fake.rTPSenderArgsForCall = append(fake.rTPSenderArgsForCall, struct { + }{}) + stub := fake.RTPSenderStub + fakeReturns := fake.rTPSenderReturns + fake.recordInvocation("RTPSender", []interface{}{}) + fake.rTPSenderMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeSubscribedTrack) RTPSenderCallCount() int { + fake.rTPSenderMutex.RLock() + defer fake.rTPSenderMutex.RUnlock() + return len(fake.rTPSenderArgsForCall) +} + +func (fake *FakeSubscribedTrack) RTPSenderCalls(stub func() *webrtc.RTPSender) { + fake.rTPSenderMutex.Lock() + defer fake.rTPSenderMutex.Unlock() + fake.RTPSenderStub = stub +} + +func (fake *FakeSubscribedTrack) RTPSenderReturns(result1 *webrtc.RTPSender) { + fake.rTPSenderMutex.Lock() + defer fake.rTPSenderMutex.Unlock() + fake.RTPSenderStub = nil + fake.rTPSenderReturns = struct { + result1 *webrtc.RTPSender + }{result1} +} + +func (fake *FakeSubscribedTrack) RTPSenderReturnsOnCall(i int, result1 *webrtc.RTPSender) { + fake.rTPSenderMutex.Lock() + defer fake.rTPSenderMutex.Unlock() + fake.RTPSenderStub = nil + if fake.rTPSenderReturnsOnCall == nil { + fake.rTPSenderReturnsOnCall = make(map[int]struct { + result1 *webrtc.RTPSender + }) + } + fake.rTPSenderReturnsOnCall[i] = struct { + result1 *webrtc.RTPSender + }{result1} +} + +func (fake *FakeSubscribedTrack) SetPublisherMuted(arg1 bool) { + fake.setPublisherMutedMutex.Lock() + fake.setPublisherMutedArgsForCall = append(fake.setPublisherMutedArgsForCall, struct { + arg1 bool + }{arg1}) + stub := fake.SetPublisherMutedStub + fake.recordInvocation("SetPublisherMuted", []interface{}{arg1}) + fake.setPublisherMutedMutex.Unlock() + if stub != nil { + fake.SetPublisherMutedStub(arg1) + } +} + +func (fake *FakeSubscribedTrack) SetPublisherMutedCallCount() int { + fake.setPublisherMutedMutex.RLock() + defer fake.setPublisherMutedMutex.RUnlock() + return len(fake.setPublisherMutedArgsForCall) +} + +func (fake *FakeSubscribedTrack) SetPublisherMutedCalls(stub func(bool)) { + fake.setPublisherMutedMutex.Lock() + defer fake.setPublisherMutedMutex.Unlock() + fake.SetPublisherMutedStub = stub +} + +func (fake *FakeSubscribedTrack) SetPublisherMutedArgsForCall(i int) bool { + fake.setPublisherMutedMutex.RLock() + defer fake.setPublisherMutedMutex.RUnlock() + argsForCall := fake.setPublisherMutedArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeSubscribedTrack) Subscriber() types.LocalParticipant { + fake.subscriberMutex.Lock() + ret, specificReturn := fake.subscriberReturnsOnCall[len(fake.subscriberArgsForCall)] + fake.subscriberArgsForCall = append(fake.subscriberArgsForCall, struct { + }{}) + stub := fake.SubscriberStub + fakeReturns := fake.subscriberReturns + fake.recordInvocation("Subscriber", []interface{}{}) + fake.subscriberMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeSubscribedTrack) SubscriberCallCount() int { + fake.subscriberMutex.RLock() + defer fake.subscriberMutex.RUnlock() + return len(fake.subscriberArgsForCall) +} + +func (fake *FakeSubscribedTrack) SubscriberCalls(stub func() types.LocalParticipant) { + fake.subscriberMutex.Lock() + defer fake.subscriberMutex.Unlock() + fake.SubscriberStub = stub +} + +func (fake *FakeSubscribedTrack) SubscriberReturns(result1 types.LocalParticipant) { + fake.subscriberMutex.Lock() + defer fake.subscriberMutex.Unlock() + fake.SubscriberStub = nil + fake.subscriberReturns = struct { + result1 types.LocalParticipant + }{result1} +} + +func (fake *FakeSubscribedTrack) SubscriberReturnsOnCall(i int, result1 types.LocalParticipant) { + fake.subscriberMutex.Lock() + defer fake.subscriberMutex.Unlock() + fake.SubscriberStub = nil + if fake.subscriberReturnsOnCall == nil { + fake.subscriberReturnsOnCall = make(map[int]struct { + result1 types.LocalParticipant + }) + } + fake.subscriberReturnsOnCall[i] = struct { + result1 types.LocalParticipant + }{result1} +} + +func (fake *FakeSubscribedTrack) SubscriberID() livekit.ParticipantID { + fake.subscriberIDMutex.Lock() + ret, specificReturn := fake.subscriberIDReturnsOnCall[len(fake.subscriberIDArgsForCall)] + fake.subscriberIDArgsForCall = append(fake.subscriberIDArgsForCall, struct { + }{}) + stub := fake.SubscriberIDStub + fakeReturns := fake.subscriberIDReturns + fake.recordInvocation("SubscriberID", []interface{}{}) + fake.subscriberIDMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeSubscribedTrack) SubscriberIDCallCount() int { + fake.subscriberIDMutex.RLock() + defer fake.subscriberIDMutex.RUnlock() + return len(fake.subscriberIDArgsForCall) +} + +func (fake *FakeSubscribedTrack) SubscriberIDCalls(stub func() livekit.ParticipantID) { + fake.subscriberIDMutex.Lock() + defer fake.subscriberIDMutex.Unlock() + fake.SubscriberIDStub = stub +} + +func (fake *FakeSubscribedTrack) SubscriberIDReturns(result1 livekit.ParticipantID) { + fake.subscriberIDMutex.Lock() + defer fake.subscriberIDMutex.Unlock() + fake.SubscriberIDStub = nil + fake.subscriberIDReturns = struct { + result1 livekit.ParticipantID + }{result1} +} + +func (fake *FakeSubscribedTrack) SubscriberIDReturnsOnCall(i int, result1 livekit.ParticipantID) { + fake.subscriberIDMutex.Lock() + defer fake.subscriberIDMutex.Unlock() + fake.SubscriberIDStub = nil + if fake.subscriberIDReturnsOnCall == nil { + fake.subscriberIDReturnsOnCall = make(map[int]struct { + result1 livekit.ParticipantID + }) + } + fake.subscriberIDReturnsOnCall[i] = struct { + result1 livekit.ParticipantID + }{result1} +} + +func (fake *FakeSubscribedTrack) SubscriberIdentity() livekit.ParticipantIdentity { + fake.subscriberIdentityMutex.Lock() + ret, specificReturn := fake.subscriberIdentityReturnsOnCall[len(fake.subscriberIdentityArgsForCall)] + fake.subscriberIdentityArgsForCall = append(fake.subscriberIdentityArgsForCall, struct { + }{}) + stub := fake.SubscriberIdentityStub + fakeReturns := fake.subscriberIdentityReturns + fake.recordInvocation("SubscriberIdentity", []interface{}{}) + fake.subscriberIdentityMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeSubscribedTrack) SubscriberIdentityCallCount() int { + fake.subscriberIdentityMutex.RLock() + defer fake.subscriberIdentityMutex.RUnlock() + return len(fake.subscriberIdentityArgsForCall) +} + +func (fake *FakeSubscribedTrack) SubscriberIdentityCalls(stub func() livekit.ParticipantIdentity) { + fake.subscriberIdentityMutex.Lock() + defer fake.subscriberIdentityMutex.Unlock() + fake.SubscriberIdentityStub = stub +} + +func (fake *FakeSubscribedTrack) SubscriberIdentityReturns(result1 livekit.ParticipantIdentity) { + fake.subscriberIdentityMutex.Lock() + defer fake.subscriberIdentityMutex.Unlock() + fake.SubscriberIdentityStub = nil + fake.subscriberIdentityReturns = struct { + result1 livekit.ParticipantIdentity + }{result1} +} + +func (fake *FakeSubscribedTrack) SubscriberIdentityReturnsOnCall(i int, result1 livekit.ParticipantIdentity) { + fake.subscriberIdentityMutex.Lock() + defer fake.subscriberIdentityMutex.Unlock() + fake.SubscriberIdentityStub = nil + if fake.subscriberIdentityReturnsOnCall == nil { + fake.subscriberIdentityReturnsOnCall = make(map[int]struct { + result1 livekit.ParticipantIdentity + }) + } + fake.subscriberIdentityReturnsOnCall[i] = struct { + result1 livekit.ParticipantIdentity + }{result1} +} + +func (fake *FakeSubscribedTrack) UpdateSubscriberSettings(arg1 *livekit.UpdateTrackSettings, arg2 bool) { + fake.updateSubscriberSettingsMutex.Lock() + fake.updateSubscriberSettingsArgsForCall = append(fake.updateSubscriberSettingsArgsForCall, struct { + arg1 *livekit.UpdateTrackSettings + arg2 bool + }{arg1, arg2}) + stub := fake.UpdateSubscriberSettingsStub + fake.recordInvocation("UpdateSubscriberSettings", []interface{}{arg1, arg2}) + fake.updateSubscriberSettingsMutex.Unlock() + if stub != nil { + fake.UpdateSubscriberSettingsStub(arg1, arg2) + } +} + +func (fake *FakeSubscribedTrack) UpdateSubscriberSettingsCallCount() int { + fake.updateSubscriberSettingsMutex.RLock() + defer fake.updateSubscriberSettingsMutex.RUnlock() + return len(fake.updateSubscriberSettingsArgsForCall) +} + +func (fake *FakeSubscribedTrack) UpdateSubscriberSettingsCalls(stub func(*livekit.UpdateTrackSettings, bool)) { + fake.updateSubscriberSettingsMutex.Lock() + defer fake.updateSubscriberSettingsMutex.Unlock() + fake.UpdateSubscriberSettingsStub = stub +} + +func (fake *FakeSubscribedTrack) UpdateSubscriberSettingsArgsForCall(i int) (*livekit.UpdateTrackSettings, bool) { + fake.updateSubscriberSettingsMutex.RLock() + defer fake.updateSubscriberSettingsMutex.RUnlock() + argsForCall := fake.updateSubscriberSettingsArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeSubscribedTrack) UpdateVideoLayer() { + fake.updateVideoLayerMutex.Lock() + fake.updateVideoLayerArgsForCall = append(fake.updateVideoLayerArgsForCall, struct { + }{}) + stub := fake.UpdateVideoLayerStub + fake.recordInvocation("UpdateVideoLayer", []interface{}{}) + fake.updateVideoLayerMutex.Unlock() + if stub != nil { + fake.UpdateVideoLayerStub() + } +} + +func (fake *FakeSubscribedTrack) UpdateVideoLayerCallCount() int { + fake.updateVideoLayerMutex.RLock() + defer fake.updateVideoLayerMutex.RUnlock() + return len(fake.updateVideoLayerArgsForCall) +} + +func (fake *FakeSubscribedTrack) UpdateVideoLayerCalls(stub func()) { + fake.updateVideoLayerMutex.Lock() + defer fake.updateVideoLayerMutex.Unlock() + fake.UpdateVideoLayerStub = stub +} + +func (fake *FakeSubscribedTrack) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeSubscribedTrack) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ types.SubscribedTrack = new(FakeSubscribedTrack) diff --git a/livekit/pkg/rtc/types/typesfakes/fake_websocket_client.go b/livekit/pkg/rtc/types/typesfakes/fake_websocket_client.go new file mode 100644 index 0000000..6d5b44c --- /dev/null +++ b/livekit/pkg/rtc/types/typesfakes/fake_websocket_client.go @@ -0,0 +1,406 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package typesfakes + +import ( + "sync" + "time" + + "github.com/livekit/livekit-server/pkg/rtc/types" +) + +type FakeWebsocketClient struct { + CloseStub func() error + closeMutex sync.RWMutex + closeArgsForCall []struct { + } + closeReturns struct { + result1 error + } + closeReturnsOnCall map[int]struct { + result1 error + } + ReadMessageStub func() (int, []byte, error) + readMessageMutex sync.RWMutex + readMessageArgsForCall []struct { + } + readMessageReturns struct { + result1 int + result2 []byte + result3 error + } + readMessageReturnsOnCall map[int]struct { + result1 int + result2 []byte + result3 error + } + SetReadDeadlineStub func(time.Time) error + setReadDeadlineMutex sync.RWMutex + setReadDeadlineArgsForCall []struct { + arg1 time.Time + } + setReadDeadlineReturns struct { + result1 error + } + setReadDeadlineReturnsOnCall map[int]struct { + result1 error + } + WriteControlStub func(int, []byte, time.Time) error + writeControlMutex sync.RWMutex + writeControlArgsForCall []struct { + arg1 int + arg2 []byte + arg3 time.Time + } + writeControlReturns struct { + result1 error + } + writeControlReturnsOnCall map[int]struct { + result1 error + } + WriteMessageStub func(int, []byte) error + writeMessageMutex sync.RWMutex + writeMessageArgsForCall []struct { + arg1 int + arg2 []byte + } + writeMessageReturns struct { + result1 error + } + writeMessageReturnsOnCall map[int]struct { + result1 error + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeWebsocketClient) Close() error { + fake.closeMutex.Lock() + ret, specificReturn := fake.closeReturnsOnCall[len(fake.closeArgsForCall)] + fake.closeArgsForCall = append(fake.closeArgsForCall, struct { + }{}) + stub := fake.CloseStub + fakeReturns := fake.closeReturns + fake.recordInvocation("Close", []interface{}{}) + fake.closeMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeWebsocketClient) CloseCallCount() int { + fake.closeMutex.RLock() + defer fake.closeMutex.RUnlock() + return len(fake.closeArgsForCall) +} + +func (fake *FakeWebsocketClient) CloseCalls(stub func() error) { + fake.closeMutex.Lock() + defer fake.closeMutex.Unlock() + fake.CloseStub = stub +} + +func (fake *FakeWebsocketClient) CloseReturns(result1 error) { + fake.closeMutex.Lock() + defer fake.closeMutex.Unlock() + fake.CloseStub = nil + fake.closeReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeWebsocketClient) CloseReturnsOnCall(i int, result1 error) { + fake.closeMutex.Lock() + defer fake.closeMutex.Unlock() + fake.CloseStub = nil + if fake.closeReturnsOnCall == nil { + fake.closeReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.closeReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeWebsocketClient) ReadMessage() (int, []byte, error) { + fake.readMessageMutex.Lock() + ret, specificReturn := fake.readMessageReturnsOnCall[len(fake.readMessageArgsForCall)] + fake.readMessageArgsForCall = append(fake.readMessageArgsForCall, struct { + }{}) + stub := fake.ReadMessageStub + fakeReturns := fake.readMessageReturns + fake.recordInvocation("ReadMessage", []interface{}{}) + fake.readMessageMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1, ret.result2, ret.result3 + } + return fakeReturns.result1, fakeReturns.result2, fakeReturns.result3 +} + +func (fake *FakeWebsocketClient) ReadMessageCallCount() int { + fake.readMessageMutex.RLock() + defer fake.readMessageMutex.RUnlock() + return len(fake.readMessageArgsForCall) +} + +func (fake *FakeWebsocketClient) ReadMessageCalls(stub func() (int, []byte, error)) { + fake.readMessageMutex.Lock() + defer fake.readMessageMutex.Unlock() + fake.ReadMessageStub = stub +} + +func (fake *FakeWebsocketClient) ReadMessageReturns(result1 int, result2 []byte, result3 error) { + fake.readMessageMutex.Lock() + defer fake.readMessageMutex.Unlock() + fake.ReadMessageStub = nil + fake.readMessageReturns = struct { + result1 int + result2 []byte + result3 error + }{result1, result2, result3} +} + +func (fake *FakeWebsocketClient) ReadMessageReturnsOnCall(i int, result1 int, result2 []byte, result3 error) { + fake.readMessageMutex.Lock() + defer fake.readMessageMutex.Unlock() + fake.ReadMessageStub = nil + if fake.readMessageReturnsOnCall == nil { + fake.readMessageReturnsOnCall = make(map[int]struct { + result1 int + result2 []byte + result3 error + }) + } + fake.readMessageReturnsOnCall[i] = struct { + result1 int + result2 []byte + result3 error + }{result1, result2, result3} +} + +func (fake *FakeWebsocketClient) SetReadDeadline(arg1 time.Time) error { + fake.setReadDeadlineMutex.Lock() + ret, specificReturn := fake.setReadDeadlineReturnsOnCall[len(fake.setReadDeadlineArgsForCall)] + fake.setReadDeadlineArgsForCall = append(fake.setReadDeadlineArgsForCall, struct { + arg1 time.Time + }{arg1}) + stub := fake.SetReadDeadlineStub + fakeReturns := fake.setReadDeadlineReturns + fake.recordInvocation("SetReadDeadline", []interface{}{arg1}) + fake.setReadDeadlineMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeWebsocketClient) SetReadDeadlineCallCount() int { + fake.setReadDeadlineMutex.RLock() + defer fake.setReadDeadlineMutex.RUnlock() + return len(fake.setReadDeadlineArgsForCall) +} + +func (fake *FakeWebsocketClient) SetReadDeadlineCalls(stub func(time.Time) error) { + fake.setReadDeadlineMutex.Lock() + defer fake.setReadDeadlineMutex.Unlock() + fake.SetReadDeadlineStub = stub +} + +func (fake *FakeWebsocketClient) SetReadDeadlineArgsForCall(i int) time.Time { + fake.setReadDeadlineMutex.RLock() + defer fake.setReadDeadlineMutex.RUnlock() + argsForCall := fake.setReadDeadlineArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeWebsocketClient) SetReadDeadlineReturns(result1 error) { + fake.setReadDeadlineMutex.Lock() + defer fake.setReadDeadlineMutex.Unlock() + fake.SetReadDeadlineStub = nil + fake.setReadDeadlineReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeWebsocketClient) SetReadDeadlineReturnsOnCall(i int, result1 error) { + fake.setReadDeadlineMutex.Lock() + defer fake.setReadDeadlineMutex.Unlock() + fake.SetReadDeadlineStub = nil + if fake.setReadDeadlineReturnsOnCall == nil { + fake.setReadDeadlineReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.setReadDeadlineReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeWebsocketClient) WriteControl(arg1 int, arg2 []byte, arg3 time.Time) error { + var arg2Copy []byte + if arg2 != nil { + arg2Copy = make([]byte, len(arg2)) + copy(arg2Copy, arg2) + } + fake.writeControlMutex.Lock() + ret, specificReturn := fake.writeControlReturnsOnCall[len(fake.writeControlArgsForCall)] + fake.writeControlArgsForCall = append(fake.writeControlArgsForCall, struct { + arg1 int + arg2 []byte + arg3 time.Time + }{arg1, arg2Copy, arg3}) + stub := fake.WriteControlStub + fakeReturns := fake.writeControlReturns + fake.recordInvocation("WriteControl", []interface{}{arg1, arg2Copy, arg3}) + fake.writeControlMutex.Unlock() + if stub != nil { + return stub(arg1, arg2, arg3) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeWebsocketClient) WriteControlCallCount() int { + fake.writeControlMutex.RLock() + defer fake.writeControlMutex.RUnlock() + return len(fake.writeControlArgsForCall) +} + +func (fake *FakeWebsocketClient) WriteControlCalls(stub func(int, []byte, time.Time) error) { + fake.writeControlMutex.Lock() + defer fake.writeControlMutex.Unlock() + fake.WriteControlStub = stub +} + +func (fake *FakeWebsocketClient) WriteControlArgsForCall(i int) (int, []byte, time.Time) { + fake.writeControlMutex.RLock() + defer fake.writeControlMutex.RUnlock() + argsForCall := fake.writeControlArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeWebsocketClient) WriteControlReturns(result1 error) { + fake.writeControlMutex.Lock() + defer fake.writeControlMutex.Unlock() + fake.WriteControlStub = nil + fake.writeControlReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeWebsocketClient) WriteControlReturnsOnCall(i int, result1 error) { + fake.writeControlMutex.Lock() + defer fake.writeControlMutex.Unlock() + fake.WriteControlStub = nil + if fake.writeControlReturnsOnCall == nil { + fake.writeControlReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.writeControlReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeWebsocketClient) WriteMessage(arg1 int, arg2 []byte) error { + var arg2Copy []byte + if arg2 != nil { + arg2Copy = make([]byte, len(arg2)) + copy(arg2Copy, arg2) + } + fake.writeMessageMutex.Lock() + ret, specificReturn := fake.writeMessageReturnsOnCall[len(fake.writeMessageArgsForCall)] + fake.writeMessageArgsForCall = append(fake.writeMessageArgsForCall, struct { + arg1 int + arg2 []byte + }{arg1, arg2Copy}) + stub := fake.WriteMessageStub + fakeReturns := fake.writeMessageReturns + fake.recordInvocation("WriteMessage", []interface{}{arg1, arg2Copy}) + fake.writeMessageMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeWebsocketClient) WriteMessageCallCount() int { + fake.writeMessageMutex.RLock() + defer fake.writeMessageMutex.RUnlock() + return len(fake.writeMessageArgsForCall) +} + +func (fake *FakeWebsocketClient) WriteMessageCalls(stub func(int, []byte) error) { + fake.writeMessageMutex.Lock() + defer fake.writeMessageMutex.Unlock() + fake.WriteMessageStub = stub +} + +func (fake *FakeWebsocketClient) WriteMessageArgsForCall(i int) (int, []byte) { + fake.writeMessageMutex.RLock() + defer fake.writeMessageMutex.RUnlock() + argsForCall := fake.writeMessageArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeWebsocketClient) WriteMessageReturns(result1 error) { + fake.writeMessageMutex.Lock() + defer fake.writeMessageMutex.Unlock() + fake.WriteMessageStub = nil + fake.writeMessageReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeWebsocketClient) WriteMessageReturnsOnCall(i int, result1 error) { + fake.writeMessageMutex.Lock() + defer fake.writeMessageMutex.Unlock() + fake.WriteMessageStub = nil + if fake.writeMessageReturnsOnCall == nil { + fake.writeMessageReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.writeMessageReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeWebsocketClient) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeWebsocketClient) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ types.WebsocketClient = new(FakeWebsocketClient) diff --git a/livekit/pkg/rtc/updatatrackmanager.go b/livekit/pkg/rtc/updatatrackmanager.go new file mode 100644 index 0000000..c23c8d2 --- /dev/null +++ b/livekit/pkg/rtc/updatatrackmanager.go @@ -0,0 +1,109 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rtc + +import ( + "sync" + + "github.com/livekit/livekit-server/pkg/rtc/datatrack" + "github.com/livekit/livekit-server/pkg/rtc/types" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "golang.org/x/exp/maps" +) + +type UpDataTrackManagerParams struct { + Logger logger.Logger + Participant types.Participant +} + +type UpDataTrackManager struct { + params UpDataTrackManagerParams + + lock sync.RWMutex + dataTracks map[uint16]types.DataTrack + + onDataTrackPublished func(types.Participant, types.DataTrack) + onDataTrackUnpublished func(types.Participant, types.DataTrack) +} + +func NewUpDataTrackManager(params UpDataTrackManagerParams) *UpDataTrackManager { + return &UpDataTrackManager{ + params: params, + dataTracks: make(map[uint16]types.DataTrack), + } +} + +func (u *UpDataTrackManager) AddPublishedDataTrack(dt types.DataTrack) { + u.lock.Lock() + u.dataTracks[dt.PubHandle()] = dt + u.lock.Unlock() + + u.params.Participant.GetParticipantListener().OnDataTrackPublished(u.params.Participant, dt) +} + +func (u *UpDataTrackManager) RemovePublishedDataTrack(dt types.DataTrack) { + var found bool + pubHandle := dt.PubHandle() + u.lock.Lock() + if u.dataTracks[pubHandle] == dt { + delete(u.dataTracks, pubHandle) + found = true + } + u.lock.Unlock() + + if found { + dt.Close() + + u.params.Participant.GetParticipantListener().OnDataTrackUnpublished(u.params.Participant, dt) + } +} + +func (u *UpDataTrackManager) GetPublishedDataTracks() []types.DataTrack { + u.lock.RLock() + defer u.lock.RUnlock() + + return maps.Values(u.dataTracks) +} + +func (u *UpDataTrackManager) GetPublishedDataTrack(handle uint16) types.DataTrack { + u.lock.RLock() + defer u.lock.RUnlock() + + return u.dataTracks[handle] +} + +func (u *UpDataTrackManager) HandleReceivedDataTrackMessage(data []byte, packet *datatrack.Packet, arrivalTime int64) { + u.lock.RLock() + dt := u.dataTracks[packet.Handle] + u.lock.RUnlock() + if dt == nil { + return + } + + dt.HandlePacket(data, packet, arrivalTime) +} + +func (u *UpDataTrackManager) ToProto() []*livekit.DataTrackInfo { + u.lock.RLock() + defer u.lock.RUnlock() + + var dataTrackInfos []*livekit.DataTrackInfo + for _, dt := range u.dataTracks { + dataTrackInfos = append(dataTrackInfos, dt.ToProto()) + } + + return dataTrackInfos +} diff --git a/livekit/pkg/rtc/uptrackmanager.go b/livekit/pkg/rtc/uptrackmanager.go new file mode 100644 index 0000000..da7c961 --- /dev/null +++ b/livekit/pkg/rtc/uptrackmanager.go @@ -0,0 +1,438 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rtc + +import ( + "errors" + "sync" + + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "golang.org/x/exp/maps" + + "github.com/livekit/livekit-server/pkg/rtc/types" + "github.com/livekit/protocol/utils" +) + +var ( + ErrSubscriptionPermissionNeedsId = errors.New("either participant identity or SID needed") +) + +type UpTrackManagerParams struct { + Logger logger.Logger + VersionGenerator utils.TimedVersionGenerator +} + +// UpTrackManager manages all uptracks from a participant +type UpTrackManager struct { + // utils.TimedVersion is a atomic. To be correctly aligned also on 32bit archs + // 64it atomics need to be at the front of a struct + subscriptionPermissionVersion utils.TimedVersion + + params UpTrackManagerParams + + closed bool + + // publishedTracks that participant is publishing + publishedTracks map[livekit.TrackID]types.MediaTrack + subscriptionPermission *livekit.SubscriptionPermission + // subscriber permission for published tracks + subscriberPermissions map[livekit.ParticipantIdentity]*livekit.TrackPermission // subscriberIdentity => *livekit.TrackPermission + + lock sync.RWMutex + + // callbacks & handlers + onClose func() + onTrackUpdated func(track types.MediaTrack) +} + +func NewUpTrackManager(params UpTrackManagerParams) *UpTrackManager { + return &UpTrackManager{ + params: params, + publishedTracks: make(map[livekit.TrackID]types.MediaTrack), + } +} + +func (u *UpTrackManager) Close(isExpectedToResume bool) { + u.lock.Lock() + if u.closed { + u.lock.Unlock() + return + } + + u.closed = true + + publishedTracks := u.publishedTracks + u.publishedTracks = make(map[livekit.TrackID]types.MediaTrack) + u.lock.Unlock() + + for _, t := range publishedTracks { + t.Close(isExpectedToResume) + } + + if onClose := u.getOnUpTrackManagerClose(); onClose != nil { + onClose() + } +} + +func (u *UpTrackManager) OnUpTrackManagerClose(f func()) { + u.lock.Lock() + u.onClose = f + u.lock.Unlock() +} + +func (u *UpTrackManager) getOnUpTrackManagerClose() func() { + u.lock.RLock() + defer u.lock.RUnlock() + + return u.onClose +} + +func (u *UpTrackManager) ToProto() []*livekit.TrackInfo { + u.lock.RLock() + defer u.lock.RUnlock() + + var trackInfos []*livekit.TrackInfo + for _, t := range u.publishedTracks { + trackInfos = append(trackInfos, t.ToProto()) + } + + return trackInfos +} + +func (u *UpTrackManager) OnPublishedTrackUpdated(f func(track types.MediaTrack)) { + u.onTrackUpdated = f +} + +func (u *UpTrackManager) SetPublishedTrackMuted(trackID livekit.TrackID, muted bool) (types.MediaTrack, bool) { + changed := false + track := u.GetPublishedTrack(trackID) + if track != nil { + currentMuted := track.IsMuted() + track.SetMuted(muted) + + if currentMuted != track.IsMuted() { + changed = true + u.params.Logger.Debugw("publisher mute status changed", "trackID", trackID, "muted", track.IsMuted()) + if u.onTrackUpdated != nil { + u.onTrackUpdated(track) + } + } + } + + return track, changed +} + +func (u *UpTrackManager) GetPublishedTrack(trackID livekit.TrackID) types.MediaTrack { + u.lock.RLock() + defer u.lock.RUnlock() + + return u.getPublishedTrackLocked(trackID) +} + +func (u *UpTrackManager) GetPublishedTracks() []types.MediaTrack { + u.lock.RLock() + defer u.lock.RUnlock() + + return maps.Values(u.publishedTracks) +} + +func (u *UpTrackManager) UpdateSubscriptionPermission( + subscriptionPermission *livekit.SubscriptionPermission, + timedVersion utils.TimedVersion, + resolverBySid func(participantID livekit.ParticipantID) types.LocalParticipant, +) error { + u.lock.Lock() + if !timedVersion.IsZero() { + // it's possible for permission updates to come from another node. In that case + // they would be the authority for this participant's permissions + // we do not want to initialize subscriptionPermissionVersion too early since if another machine is the + // owner for the data, we'd prefer to use their TimedVersion + // ignore older version + if !timedVersion.After(u.subscriptionPermissionVersion) { + u.params.Logger.Debugw( + "skipping older subscription permission version", + "existingValue", logger.Proto(u.subscriptionPermission), + "existingVersion", &u.subscriptionPermissionVersion, + "requestingValue", logger.Proto(subscriptionPermission), + "requestingVersion", &timedVersion, + ) + u.lock.Unlock() + return nil + } + u.subscriptionPermissionVersion.Update(timedVersion) + } else { + // for requests coming from the current node, use local versions + u.subscriptionPermissionVersion.Update(u.params.VersionGenerator.Next()) + } + + // store as is for use when migrating + u.subscriptionPermission = subscriptionPermission + if subscriptionPermission == nil { + u.params.Logger.Debugw( + "updating subscription permission, setting to nil", + "version", u.subscriptionPermissionVersion, + ) + // possible to get a nil when migrating + u.lock.Unlock() + return nil + } + + u.params.Logger.Debugw( + "updating subscription permission", + "permissions", logger.Proto(u.subscriptionPermission), + "version", u.subscriptionPermissionVersion, + ) + if err := u.parseSubscriptionPermissionsLocked(subscriptionPermission, func(pID livekit.ParticipantID) types.LocalParticipant { + u.lock.Unlock() + var p types.LocalParticipant + if resolverBySid != nil { + p = resolverBySid(pID) + } + u.lock.Lock() + return p + }); err != nil { + // when failed, do not override previous permissions + u.params.Logger.Errorw("failed updating subscription permission", err) + u.lock.Unlock() + return err + } + u.lock.Unlock() + + u.maybeRevokeSubscriptions() + + return nil +} + +func (u *UpTrackManager) SubscriptionPermission() (*livekit.SubscriptionPermission, utils.TimedVersion) { + u.lock.RLock() + defer u.lock.RUnlock() + + if u.subscriptionPermissionVersion.IsZero() { + return nil, u.subscriptionPermissionVersion.Load() + } + + return u.subscriptionPermission, u.subscriptionPermissionVersion.Load() +} + +func (u *UpTrackManager) HasPermission(trackID livekit.TrackID, subIdentity livekit.ParticipantIdentity) bool { + u.lock.RLock() + defer u.lock.RUnlock() + + return u.hasPermissionLocked(trackID, subIdentity) +} + +func (u *UpTrackManager) UpdatePublishedAudioTrack(update *livekit.UpdateLocalAudioTrack) types.MediaTrack { + track := u.GetPublishedTrack(livekit.TrackID(update.TrackSid)) + if track != nil { + track.UpdateAudioTrack(update) + if u.onTrackUpdated != nil { + u.onTrackUpdated(track) + } + } + + return track +} + +func (u *UpTrackManager) UpdatePublishedVideoTrack(update *livekit.UpdateLocalVideoTrack) types.MediaTrack { + track := u.GetPublishedTrack(livekit.TrackID(update.TrackSid)) + if track != nil { + track.UpdateVideoTrack(update) + if u.onTrackUpdated != nil { + u.onTrackUpdated(track) + } + } + + return track +} + +func (u *UpTrackManager) AddPublishedTrack(track types.MediaTrack) { + u.lock.Lock() + if _, ok := u.publishedTracks[track.ID()]; !ok { + u.publishedTracks[track.ID()] = track + } + u.lock.Unlock() + u.params.Logger.Debugw("added published track", "trackID", track.ID(), "trackInfo", logger.Proto(track.ToProto())) + + track.AddOnClose(func(_isExpectedToResume bool) { + u.lock.Lock() + delete(u.publishedTracks, track.ID()) + // not modifying subscription permissions, will get reset on next update from participant + u.lock.Unlock() + }) +} + +func (u *UpTrackManager) RemovePublishedTrack(track types.MediaTrack, isExpectedToResume bool) { + track.Close(isExpectedToResume) + + u.lock.Lock() + delete(u.publishedTracks, track.ID()) + u.lock.Unlock() +} + +func (u *UpTrackManager) getPublishedTrackLocked(trackID livekit.TrackID) types.MediaTrack { + return u.publishedTracks[trackID] +} + +func (u *UpTrackManager) parseSubscriptionPermissionsLocked( + subscriptionPermission *livekit.SubscriptionPermission, + resolver func(participantID livekit.ParticipantID) types.LocalParticipant, +) error { + // every update overrides the existing + + // all_participants takes precedence + if subscriptionPermission.AllParticipants { + // everything is allowed, nothing else to do + u.subscriberPermissions = nil + return nil + } + + // per participant permissions + subscriberPermissions := make(map[livekit.ParticipantIdentity]*livekit.TrackPermission) + for _, trackPerms := range subscriptionPermission.TrackPermissions { + subscriberIdentity := livekit.ParticipantIdentity(trackPerms.ParticipantIdentity) + if subscriberIdentity == "" { + if trackPerms.ParticipantSid == "" { + return ErrSubscriptionPermissionNeedsId + } + + sub := resolver(livekit.ParticipantID(trackPerms.ParticipantSid)) + if sub == nil { + u.params.Logger.Warnw("could not find subscriber for permissions update", nil, "subscriberID", trackPerms.ParticipantSid) + continue + } + + subscriberIdentity = sub.Identity() + } else { + if trackPerms.ParticipantSid != "" { + sub := resolver(livekit.ParticipantID(trackPerms.ParticipantSid)) + if sub != nil && sub.Identity() != subscriberIdentity { + u.params.Logger.Errorw("participant identity mismatch", nil, "expected", subscriberIdentity, "got", sub.Identity()) + } + if sub == nil { + u.params.Logger.Warnw("could not find subscriber for permissions update", nil, "subscriberID", trackPerms.ParticipantSid) + } + } + } + + subscriberPermissions[subscriberIdentity] = trackPerms + } + + u.subscriberPermissions = subscriberPermissions + + return nil +} + +func (u *UpTrackManager) hasPermissionLocked(trackID livekit.TrackID, subscriberIdentity livekit.ParticipantIdentity) bool { + if u.subscriberPermissions == nil { + return true + } + + perms, ok := u.subscriberPermissions[subscriberIdentity] + if !ok { + return false + } + + if perms.AllTracks { + return true + } + + for _, sid := range perms.TrackSids { + if livekit.TrackID(sid) == trackID { + return true + } + } + + return false +} + +// returns a list of participants that are allowed to subscribe to the track. if nil is returned, it means everyone is +// allowed to subscribe to this track +func (u *UpTrackManager) getAllowedSubscribersLocked(trackID livekit.TrackID) []livekit.ParticipantIdentity { + if u.subscriberPermissions == nil { + return nil + } + + allowed := make([]livekit.ParticipantIdentity, 0) + for subscriberIdentity, perms := range u.subscriberPermissions { + if perms.AllTracks { + allowed = append(allowed, subscriberIdentity) + continue + } + + for _, sid := range perms.TrackSids { + if livekit.TrackID(sid) == trackID { + allowed = append(allowed, subscriberIdentity) + break + } + } + } + + return allowed +} + +func (u *UpTrackManager) maybeRevokeSubscriptions() { + u.lock.Lock() + defer u.lock.Unlock() + + for trackID, track := range u.publishedTracks { + allowed := u.getAllowedSubscribersLocked(trackID) + if allowed == nil { + // no restrictions + continue + } + + track.RevokeDisallowedSubscribers(allowed) + } +} + +func (u *UpTrackManager) DebugInfo() map[string]any { + info := map[string]any{} + publishedTrackInfo := make(map[livekit.TrackID]any) + + u.lock.RLock() + for trackID, track := range u.publishedTracks { + if mt, ok := track.(*MediaTrack); ok { + publishedTrackInfo[trackID] = mt.DebugInfo() + } else { + publishedTrackInfo[trackID] = map[string]any{ + "ID": track.ID(), + "Kind": track.Kind().String(), + "PubMuted": track.IsMuted(), + } + } + } + u.lock.RUnlock() + + info["PublishedTracks"] = publishedTrackInfo + + return info +} + +func (u *UpTrackManager) GetAudioLevel() (level float64, active bool) { + level = 0 + for _, pt := range u.GetPublishedTracks() { + if pt.Source() == livekit.TrackSource_MICROPHONE { + tl, ta := pt.GetAudioLevel() + if ta { + active = true + if tl > level { + level = tl + } + } + } + } + return +} diff --git a/livekit/pkg/rtc/uptrackmanager_test.go b/livekit/pkg/rtc/uptrackmanager_test.go new file mode 100644 index 0000000..2bee32e --- /dev/null +++ b/livekit/pkg/rtc/uptrackmanager_test.go @@ -0,0 +1,329 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rtc + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/utils" + + "github.com/livekit/livekit-server/pkg/rtc/types" + "github.com/livekit/livekit-server/pkg/rtc/types/typesfakes" +) + +var defaultUptrackManagerParams = UpTrackManagerParams{ + Logger: logger.GetLogger(), + VersionGenerator: utils.NewDefaultTimedVersionGenerator(), +} + +func TestUpdateSubscriptionPermission(t *testing.T) { + t.Run("updates subscription permission", func(t *testing.T) { + um := NewUpTrackManager(defaultUptrackManagerParams) + vg := utils.NewDefaultTimedVersionGenerator() + + tra := &typesfakes.FakeMediaTrack{} + tra.IDReturns("audio") + um.publishedTracks["audio"] = tra + + trv := &typesfakes.FakeMediaTrack{} + trv.IDReturns("video") + um.publishedTracks["video"] = trv + + // no restrictive subscription permission + subscriptionPermission := &livekit.SubscriptionPermission{ + AllParticipants: true, + } + um.UpdateSubscriptionPermission(subscriptionPermission, vg.Next(), nil) + require.Nil(t, um.subscriberPermissions) + + // nobody is allowed to subscribe + subscriptionPermission = &livekit.SubscriptionPermission{ + TrackPermissions: []*livekit.TrackPermission{}, + } + um.UpdateSubscriptionPermission(subscriptionPermission, vg.Next(), nil) + require.NotNil(t, um.subscriberPermissions) + require.Equal(t, 0, len(um.subscriberPermissions)) + + lp1 := &typesfakes.FakeLocalParticipant{} + lp1.IdentityReturns("p1") + lp2 := &typesfakes.FakeLocalParticipant{} + lp2.IdentityReturns("p2") + + sidResolver := func(sid livekit.ParticipantID) types.LocalParticipant { + if sid == "p1" { + return lp1 + } + + if sid == "p2" { + return lp2 + } + + return nil + } + + // allow all tracks for participants + perms1 := &livekit.TrackPermission{ + ParticipantSid: "p1", + AllTracks: true, + } + perms2 := &livekit.TrackPermission{ + ParticipantSid: "p2", + AllTracks: true, + } + subscriptionPermission = &livekit.SubscriptionPermission{ + TrackPermissions: []*livekit.TrackPermission{ + perms1, + perms2, + }, + } + um.UpdateSubscriptionPermission(subscriptionPermission, vg.Next(), sidResolver) + require.Equal(t, 2, len(um.subscriberPermissions)) + require.EqualValues(t, perms1, um.subscriberPermissions["p1"]) + require.EqualValues(t, perms2, um.subscriberPermissions["p2"]) + + // allow all tracks for some and restrictive for others + perms1 = &livekit.TrackPermission{ + ParticipantIdentity: "p1", + AllTracks: true, + } + perms2 = &livekit.TrackPermission{ + ParticipantIdentity: "p2", + TrackSids: []string{"audio"}, + } + perms3 := &livekit.TrackPermission{ + ParticipantIdentity: "p3", + TrackSids: []string{"video"}, + } + subscriptionPermission = &livekit.SubscriptionPermission{ + TrackPermissions: []*livekit.TrackPermission{ + perms1, + perms2, + perms3, + }, + } + um.UpdateSubscriptionPermission(subscriptionPermission, vg.Next(), nil) + require.Equal(t, 3, len(um.subscriberPermissions)) + require.EqualValues(t, perms1, um.subscriberPermissions["p1"]) + require.EqualValues(t, perms2, um.subscriberPermissions["p2"]) + require.EqualValues(t, perms3, um.subscriberPermissions["p3"]) + }) + + t.Run("updates subscription permission using both", func(t *testing.T) { + um := NewUpTrackManager(defaultUptrackManagerParams) + vg := utils.NewDefaultTimedVersionGenerator() + + tra := &typesfakes.FakeMediaTrack{} + tra.IDReturns("audio") + um.publishedTracks["audio"] = tra + + trv := &typesfakes.FakeMediaTrack{} + trv.IDReturns("video") + um.publishedTracks["video"] = trv + + lp1 := &typesfakes.FakeLocalParticipant{} + lp1.IdentityReturns("p1") + lp2 := &typesfakes.FakeLocalParticipant{} + lp2.IdentityReturns("p2") + + sidResolver := func(sid livekit.ParticipantID) types.LocalParticipant { + if sid == "p1" { + return lp1 + } + + if sid == "p2" { + return lp2 + } + + return nil + } + + // allow all tracks for participants + perms1 := &livekit.TrackPermission{ + ParticipantSid: "p1", + ParticipantIdentity: "p1", + AllTracks: true, + } + perms2 := &livekit.TrackPermission{ + ParticipantSid: "p2", + ParticipantIdentity: "p2", + AllTracks: true, + } + subscriptionPermission := &livekit.SubscriptionPermission{ + TrackPermissions: []*livekit.TrackPermission{ + perms1, + perms2, + }, + } + err := um.UpdateSubscriptionPermission(subscriptionPermission, vg.Next(), sidResolver) + require.NoError(t, err) + require.Equal(t, 2, len(um.subscriberPermissions)) + require.EqualValues(t, perms1, um.subscriberPermissions["p1"]) + require.EqualValues(t, perms2, um.subscriberPermissions["p2"]) + + // mismatched identities should fail a permission update + badSidResolver := func(sid livekit.ParticipantID) types.LocalParticipant { + if sid == "p1" { + return lp2 + } + + if sid == "p2" { + return lp1 + } + + return nil + } + + err = um.UpdateSubscriptionPermission(subscriptionPermission, vg.Next(), badSidResolver) + require.NoError(t, err) + require.Equal(t, 2, len(um.subscriberPermissions)) + require.EqualValues(t, perms1, um.subscriberPermissions["p1"]) + require.EqualValues(t, perms2, um.subscriberPermissions["p2"]) + }) + + t.Run("update versions", func(t *testing.T) { + um := NewUpTrackManager(defaultUptrackManagerParams) + vg := utils.NewDefaultTimedVersionGenerator() + + v0, v1, v2 := vg.Next(), vg.Next(), vg.Next() + + um.UpdateSubscriptionPermission(&livekit.SubscriptionPermission{}, v1, nil) + require.Equal(t, v1.Load(), um.subscriptionPermissionVersion.Load(), "first update should be applied") + + um.UpdateSubscriptionPermission(&livekit.SubscriptionPermission{}, v2, nil) + require.Equal(t, v2.Load(), um.subscriptionPermissionVersion.Load(), "ordered updates should be applied") + + um.UpdateSubscriptionPermission(&livekit.SubscriptionPermission{}, v0, nil) + require.Equal(t, v2.Load(), um.subscriptionPermissionVersion.Load(), "out of order updates should be ignored") + + um.UpdateSubscriptionPermission(&livekit.SubscriptionPermission{}, utils.TimedVersion(0), nil) + require.True(t, um.subscriptionPermissionVersion.After(v2), "zero version in updates should use next local version") + }) +} + +func TestSubscriptionPermission(t *testing.T) { + t.Run("checks subscription permission", func(t *testing.T) { + um := NewUpTrackManager(defaultUptrackManagerParams) + vg := utils.NewDefaultTimedVersionGenerator() + + tra := &typesfakes.FakeMediaTrack{} + tra.IDReturns("audio") + um.publishedTracks["audio"] = tra + + trv := &typesfakes.FakeMediaTrack{} + trv.IDReturns("video") + um.publishedTracks["video"] = trv + + // no restrictive permission + subscriptionPermission := &livekit.SubscriptionPermission{ + AllParticipants: true, + } + um.UpdateSubscriptionPermission(subscriptionPermission, vg.Next(), nil) + require.True(t, um.hasPermissionLocked("audio", "p1")) + require.True(t, um.hasPermissionLocked("audio", "p2")) + + // nobody is allowed to subscribe + subscriptionPermission = &livekit.SubscriptionPermission{ + TrackPermissions: []*livekit.TrackPermission{}, + } + um.UpdateSubscriptionPermission(subscriptionPermission, vg.Next(), nil) + require.False(t, um.hasPermissionLocked("audio", "p1")) + require.False(t, um.hasPermissionLocked("audio", "p2")) + + // allow all tracks for participants + subscriptionPermission = &livekit.SubscriptionPermission{ + TrackPermissions: []*livekit.TrackPermission{ + { + ParticipantIdentity: "p1", + AllTracks: true, + }, + { + ParticipantIdentity: "p2", + AllTracks: true, + }, + }, + } + um.UpdateSubscriptionPermission(subscriptionPermission, vg.Next(), nil) + require.True(t, um.hasPermissionLocked("audio", "p1")) + require.True(t, um.hasPermissionLocked("video", "p1")) + require.True(t, um.hasPermissionLocked("audio", "p2")) + require.True(t, um.hasPermissionLocked("video", "p2")) + + // add a new track after permissions are set + trs := &typesfakes.FakeMediaTrack{} + trs.IDReturns("screen") + um.publishedTracks["screen"] = trs + + require.True(t, um.hasPermissionLocked("audio", "p1")) + require.True(t, um.hasPermissionLocked("video", "p1")) + require.True(t, um.hasPermissionLocked("screen", "p1")) + require.True(t, um.hasPermissionLocked("audio", "p2")) + require.True(t, um.hasPermissionLocked("video", "p2")) + require.True(t, um.hasPermissionLocked("screen", "p2")) + + // allow all tracks for some and restrictive for others + subscriptionPermission = &livekit.SubscriptionPermission{ + TrackPermissions: []*livekit.TrackPermission{ + { + ParticipantIdentity: "p1", + AllTracks: true, + }, + { + ParticipantIdentity: "p2", + TrackSids: []string{"audio"}, + }, + { + ParticipantIdentity: "p3", + TrackSids: []string{"video"}, + }, + }, + } + um.UpdateSubscriptionPermission(subscriptionPermission, vg.Next(), nil) + require.True(t, um.hasPermissionLocked("audio", "p1")) + require.True(t, um.hasPermissionLocked("video", "p1")) + require.True(t, um.hasPermissionLocked("screen", "p1")) + + require.True(t, um.hasPermissionLocked("audio", "p2")) + require.False(t, um.hasPermissionLocked("video", "p2")) + require.False(t, um.hasPermissionLocked("screen", "p2")) + + require.False(t, um.hasPermissionLocked("audio", "p3")) + require.True(t, um.hasPermissionLocked("video", "p3")) + require.False(t, um.hasPermissionLocked("screen", "p3")) + + // add a new track after restrictive permissions are set + trw := &typesfakes.FakeMediaTrack{} + trw.IDReturns("watch") + um.publishedTracks["watch"] = trw + + require.True(t, um.hasPermissionLocked("audio", "p1")) + require.True(t, um.hasPermissionLocked("video", "p1")) + require.True(t, um.hasPermissionLocked("screen", "p1")) + require.True(t, um.hasPermissionLocked("watch", "p1")) + + require.True(t, um.hasPermissionLocked("audio", "p2")) + require.False(t, um.hasPermissionLocked("video", "p2")) + require.False(t, um.hasPermissionLocked("screen", "p2")) + require.False(t, um.hasPermissionLocked("watch", "p2")) + + require.False(t, um.hasPermissionLocked("audio", "p3")) + require.True(t, um.hasPermissionLocked("video", "p3")) + require.False(t, um.hasPermissionLocked("screen", "p3")) + require.False(t, um.hasPermissionLocked("watch", "p3")) + }) +} diff --git a/livekit/pkg/rtc/user_packet_deduper.go b/livekit/pkg/rtc/user_packet_deduper.go new file mode 100644 index 0000000..cd2ed40 --- /dev/null +++ b/livekit/pkg/rtc/user_packet_deduper.go @@ -0,0 +1,66 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rtc + +import ( + "sync" + + "github.com/google/uuid" + "github.com/livekit/protocol/livekit" +) + +const ( + maxSize = 100 +) + +type UserPacketDeduper struct { + lock sync.Mutex + seen map[uuid.UUID]uuid.UUID + head uuid.UUID + tail uuid.UUID +} + +func NewUserPacketDeduper() *UserPacketDeduper { + return &UserPacketDeduper{ + seen: make(map[uuid.UUID]uuid.UUID), + } +} + +func (u *UserPacketDeduper) IsDuplicate(up *livekit.UserPacket) bool { + id, err := uuid.FromBytes(up.Nonce) + if err != nil { + return false + } + + u.lock.Lock() + defer u.lock.Unlock() + + if u.head == id { + return true + } + if _, ok := u.seen[id]; ok { + return true + } + + u.seen[u.head] = id + u.head = id + + if len(u.seen) == maxSize { + tail := u.tail + u.tail = u.seen[tail] + delete(u.seen, tail) + } + return false +} diff --git a/livekit/pkg/rtc/utils.go b/livekit/pkg/rtc/utils.go new file mode 100644 index 0000000..1a584b3 --- /dev/null +++ b/livekit/pkg/rtc/utils.go @@ -0,0 +1,202 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rtc + +import ( + "errors" + "io" + "net" + "strings" + + "github.com/pion/webrtc/v4" + "google.golang.org/protobuf/proto" + + "github.com/livekit/livekit-server/pkg/sfu/mime" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" +) + +const ( + trackIdSeparator = "|" + + cMinIPTruncateLen = 8 +) + +func UnpackStreamID(packed string) (participantID livekit.ParticipantID, trackID livekit.TrackID) { + parts := strings.Split(packed, trackIdSeparator) + if len(parts) > 1 { + return livekit.ParticipantID(parts[0]), livekit.TrackID(packed[len(parts[0])+1:]) + } + return livekit.ParticipantID(packed), "" +} + +func PackStreamID(participantID livekit.ParticipantID, trackID livekit.TrackID) string { + return string(participantID) + trackIdSeparator + string(trackID) +} + +func PackSyncStreamID(participantID livekit.ParticipantID, stream string) string { + return string(participantID) + trackIdSeparator + stream +} + +func StreamFromTrackSource(source livekit.TrackSource) string { + // group camera/mic, screenshare/audio together + switch source { + case livekit.TrackSource_SCREEN_SHARE: + return "screen" + case livekit.TrackSource_SCREEN_SHARE_AUDIO: + return "screen" + case livekit.TrackSource_CAMERA: + return "camera" + case livekit.TrackSource_MICROPHONE: + return "camera" + } + return "unknown" +} + +func PackDataTrackLabel(participantID livekit.ParticipantID, trackID livekit.TrackID, label string) string { + return string(participantID) + trackIdSeparator + string(trackID) + trackIdSeparator + label +} + +func UnpackDataTrackLabel(packed string) (participantID livekit.ParticipantID, trackID livekit.TrackID, label string) { + parts := strings.Split(packed, trackIdSeparator) + if len(parts) != 3 { + return "", livekit.TrackID(packed), "" + } + participantID = livekit.ParticipantID(parts[0]) + trackID = livekit.TrackID(parts[1]) + label = parts[2] + return +} + +func ToProtoTrackKind(kind webrtc.RTPCodecType) livekit.TrackType { + switch kind { + case webrtc.RTPCodecTypeVideo: + return livekit.TrackType_VIDEO + case webrtc.RTPCodecTypeAudio: + return livekit.TrackType_AUDIO + } + panic("unsupported track direction") +} + +func IsEOF(err error) bool { + return err == io.ErrClosedPipe || err == io.EOF +} + +func Recover(l logger.Logger) any { + if l == nil { + l = logger.GetLogger() + } + r := recover() + if r != nil { + var err error + switch e := r.(type) { + case string: + err = errors.New(e) + case error: + err = e + default: + err = errors.New("unknown panic") + } + l.Errorw("recovered panic", err, "panic", r) + } + + return r +} + +// logger helpers +func LoggerWithParticipant(l logger.Logger, identity livekit.ParticipantIdentity, sid livekit.ParticipantID, isRemote bool) logger.Logger { + values := make([]any, 0, 4) + if identity != "" { + values = append(values, "participant", identity) + } + if sid != "" { + values = append(values, "pID", sid) + } + values = append(values, "remote", isRemote) + // enable sampling per participant + return l.WithValues(values...) +} + +func LoggerWithRoom(l logger.Logger, name livekit.RoomName, roomID livekit.RoomID) logger.Logger { + values := make([]any, 0, 2) + if name != "" { + values = append(values, "room", name) + } + if roomID != "" { + values = append(values, "roomID", roomID) + } + // also sample for the room + return l.WithItemSampler().WithValues(values...) +} + +func LoggerWithTrack(l logger.Logger, trackID livekit.TrackID, isRelayed bool) logger.Logger { + // sampling not required because caller already passing in participant's logger + if trackID != "" { + return l.WithValues("trackID", trackID, "relayed", isRelayed) + } + return l +} + +func LoggerWithPCTarget(l logger.Logger, target livekit.SignalTarget) logger.Logger { + return l.WithValues("transport", target) +} + +func LoggerWithCodecMime(l logger.Logger, mimeType mime.MimeType) logger.Logger { + if mimeType != mime.MimeTypeUnknown { + return l.WithValues("mime", mimeType.String()) + } + return l +} + +func MaybeTruncateIP(addr string) string { + ipAddr := net.ParseIP(addr) + if ipAddr == nil { + return "" + } + + if ipAddr.IsPrivate() || len(addr) <= cMinIPTruncateLen { + return addr + } + + return addr[:len(addr)-3] + "..." +} + +func ChunkProtoBatch[T proto.Message](batch []T, target int) [][]T { + var chunks [][]T + var start, size int + for i, m := range batch { + s := proto.Size(m) + if size+s > target { + if start < i { + chunks = append(chunks, batch[start:i]) + } + start = i + size = 0 + } + size += s + } + if start < len(batch) { + chunks = append(chunks, batch[start:]) + } + return chunks +} + +func IsRedEnabled(ti *livekit.TrackInfo) bool { + if len(ti.Codecs) != 0 && ti.Codecs[0].MimeType != "" { + return mime.IsMimeTypeStringRED(ti.Codecs[0].MimeType) + } + + return !ti.GetDisableRed() +} diff --git a/livekit/pkg/rtc/utils_test.go b/livekit/pkg/rtc/utils_test.go new file mode 100644 index 0000000..37a6f01 --- /dev/null +++ b/livekit/pkg/rtc/utils_test.go @@ -0,0 +1,75 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rtc + +import ( + "math/rand/v2" + "strings" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/proto" + + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/utils/guid" +) + +func TestPackStreamId(t *testing.T) { + packed := "PA_123abc|uuid-id" + pID, trackID := UnpackStreamID(packed) + require.Equal(t, livekit.ParticipantID("PA_123abc"), pID) + require.Equal(t, livekit.TrackID("uuid-id"), trackID) + + require.Equal(t, packed, PackStreamID(pID, trackID)) +} + +func TestPackDataTrackLabel(t *testing.T) { + pID := livekit.ParticipantID("PA_123abc") + trackID := livekit.TrackID("TR_b3da25") + label := "trackLabel" + packed := "PA_123abc|TR_b3da25|trackLabel" + require.Equal(t, packed, PackDataTrackLabel(pID, trackID, label)) + + p, tr, l := UnpackDataTrackLabel(packed) + require.Equal(t, pID, p) + require.Equal(t, trackID, tr) + require.Equal(t, label, l) +} + +func TestChunkProtoBatch(t *testing.T) { + rng := rand.New(rand.NewPCG(1, 2)) + var updates []*livekit.ParticipantInfo + for range 32 { + updates = append(updates, &livekit.ParticipantInfo{ + Sid: guid.New(guid.ParticipantPrefix), + Identity: uuid.NewString(), + Metadata: strings.Repeat("x", rng.IntN(128*1024)), + }) + } + + target := 64 * 1024 + batches := ChunkProtoBatch(updates, target) + var count int + for _, b := range batches { + var sum int + for _, m := range b { + sum += proto.Size(m) + count++ + } + require.True(t, sum < target || len(b) == 1, "batch size exceeds target") + } + require.Equal(t, len(updates), count) +} diff --git a/livekit/pkg/rtc/wrappedreceiver.go b/livekit/pkg/rtc/wrappedreceiver.go new file mode 100644 index 0000000..ad3a3b9 --- /dev/null +++ b/livekit/pkg/rtc/wrappedreceiver.go @@ -0,0 +1,594 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rtc + +import ( + "errors" + "sync" + + "github.com/pion/webrtc/v4" + "go.uber.org/atomic" + "golang.org/x/exp/maps" + "golang.org/x/exp/slices" + + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + + "github.com/livekit/livekit-server/pkg/sfu" + "github.com/livekit/livekit-server/pkg/sfu/buffer" + "github.com/livekit/livekit-server/pkg/sfu/mime" +) + +// wrapper around WebRTC receiver, overriding its ID + +type WrappedReceiverParams struct { + Receivers []*simulcastReceiver + TrackID livekit.TrackID + StreamId string + UpstreamCodecs []webrtc.RTPCodecParameters + Logger logger.Logger + DisableRed bool + IsEncrypted bool +} + +type WrappedReceiver struct { + lock sync.Mutex + + sfu.TrackReceiver + params WrappedReceiverParams + receivers []sfu.TrackReceiver + codecs []webrtc.RTPCodecParameters + onReadyCallbacks []func() +} + +func NewWrappedReceiver(params WrappedReceiverParams) *WrappedReceiver { + sfuReceivers := make([]sfu.TrackReceiver, 0, len(params.Receivers)) + for _, r := range params.Receivers { + sfuReceivers = append(sfuReceivers, r) + } + + codecs := params.UpstreamCodecs + if len(codecs) == 1 { + normalizedMimeType := mime.NormalizeMimeType(codecs[0].MimeType) + if normalizedMimeType == mime.MimeTypeRED { + // if upstream is opus/red, then add opus to match clients that don't support red + codecs = append(codecs, OpusCodecParameters) + } else if !params.DisableRed && normalizedMimeType == mime.MimeTypeOpus { + // if upstream is opus only and red enabled, add red to match clients that support red + codecs = append(codecs, RedCodecParameters) + // prefer red codec + codecs[0], codecs[1] = codecs[1], codecs[0] + } + } + + return &WrappedReceiver{ + params: params, + receivers: sfuReceivers, + codecs: codecs, + } +} + +func (r *WrappedReceiver) TrackID() livekit.TrackID { + return r.params.TrackID +} + +func (r *WrappedReceiver) StreamID() string { + return r.params.StreamId +} + +// DetermineReceiver determines the receiver of negotiated codec and returns +// +// isAvailable: returns true if given codec is a potential codec from publisher or if an existing published codec can be translated +// needsPublish: indicates if the codec is needed from publisher, some combinations can be achieved via codec translation internally, +// +// example: unencrypted opus -> RED translation and vice-versa can be done without the need for publisher to send the other codec. +func (r *WrappedReceiver) DetermineReceiver(codec webrtc.RTPCodecCapability) (isAvailable bool, needsPublish bool) { + r.lock.Lock() + + reason := "no matching receiver" + codecMimeType := mime.NormalizeMimeType(codec.MimeType) + var trackReceiver sfu.TrackReceiver + for _, receiver := range r.receivers { + receiverMimeType := receiver.Mime() + if receiverMimeType == codecMimeType { + trackReceiver = receiver + isAvailable = true + needsPublish = true + break + } + + if receiverMimeType == mime.MimeTypeRED && codecMimeType == mime.MimeTypeOpus { + // audio opus/red can match opus only + if !r.params.IsEncrypted { // cannot match encrypted source + trackReceiver = receiver.GetPrimaryReceiverForRed() + isAvailable = true + break + } else { + reason = "encrypted source" + } + } else if receiverMimeType == mime.MimeTypeOpus && codecMimeType == mime.MimeTypeRED { + if !r.params.IsEncrypted { // cannot match encrypted source + trackReceiver = receiver.GetRedReceiver() + isAvailable = true + break + } else { + reason = "encrypted source" + } + } + + } + if trackReceiver == nil { + r.lock.Unlock() + r.params.Logger.Warnw( + "can't determine receiver for codec", nil, + "codec", codec.MimeType, + "reason", reason, + ) + return + } + r.TrackReceiver = trackReceiver + + onReadyCallbacks := r.onReadyCallbacks + r.onReadyCallbacks = nil + r.lock.Unlock() + + for _, f := range onReadyCallbacks { + trackReceiver.AddOnReady(f) + } + + return +} + +func (r *WrappedReceiver) Codecs() []webrtc.RTPCodecParameters { + return slices.Clone(r.codecs) +} + +func (r *WrappedReceiver) DeleteDownTrack(participantID livekit.ParticipantID) { + r.lock.Lock() + trackReceiver := r.TrackReceiver + r.lock.Unlock() + + if trackReceiver != nil { + trackReceiver.DeleteDownTrack(participantID) + } +} + +func (r *WrappedReceiver) AddOnReady(f func()) { + r.lock.Lock() + trackReceiver := r.TrackReceiver + if trackReceiver == nil { + r.onReadyCallbacks = append(r.onReadyCallbacks, f) + r.lock.Unlock() + return + } + r.lock.Unlock() + + trackReceiver.AddOnReady(f) +} + +// -------------------------------------------- + +type DummyReceiver struct { + receiver atomic.Value + trackID livekit.TrackID + streamId string + codec webrtc.RTPCodecParameters + headerExtensions []webrtc.RTPHeaderExtensionParameter + + downTrackLock sync.Mutex + downTracks map[livekit.ParticipantID]sfu.TrackSender + onReadyCallbacks []func() + onCodecStateChange []func(webrtc.RTPCodecParameters, sfu.ReceiverCodecState) + + settingsLock sync.Mutex + maxExpectedLayerValid bool + maxExpectedLayer int32 + pausedValid bool + paused bool + + redReceiver, primaryReceiver *DummyRedReceiver +} + +func NewDummyReceiver(trackID livekit.TrackID, streamId string, codec webrtc.RTPCodecParameters, headerExtensions []webrtc.RTPHeaderExtensionParameter) *DummyReceiver { + return &DummyReceiver{ + trackID: trackID, + streamId: streamId, + codec: codec, + headerExtensions: headerExtensions, + downTracks: make(map[livekit.ParticipantID]sfu.TrackSender), + } +} + +func (d *DummyReceiver) Receiver() sfu.TrackReceiver { + r, _ := d.receiver.Load().(sfu.TrackReceiver) + return r +} + +func (d *DummyReceiver) Upgrade(receiver sfu.TrackReceiver) { + if !d.receiver.CompareAndSwap(nil, receiver) { + return + } + + d.downTrackLock.Lock() + for _, t := range d.downTracks { + receiver.AddDownTrack(t) + } + d.downTracks = make(map[livekit.ParticipantID]sfu.TrackSender) + + onReadyCallbacks := d.onReadyCallbacks + d.onReadyCallbacks = nil + + codecChange := d.onCodecStateChange + d.onCodecStateChange = nil + d.downTrackLock.Unlock() + + for _, f := range onReadyCallbacks { + receiver.AddOnReady(f) + } + + for _, f := range codecChange { + receiver.AddOnCodecStateChange(f) + } + + d.settingsLock.Lock() + maxExpectedLayerValid := d.maxExpectedLayerValid + d.maxExpectedLayerValid = false + + pausedValid := d.pausedValid + d.pausedValid = false + d.settingsLock.Unlock() + + if maxExpectedLayerValid { + receiver.SetMaxExpectedSpatialLayer(d.maxExpectedLayer) + } + + if pausedValid { + receiver.SetUpTrackPaused(d.paused) + } + + d.settingsLock.Lock() + if d.primaryReceiver != nil { + d.primaryReceiver.upgrade(receiver) + } + if d.redReceiver != nil { + d.redReceiver.upgrade(receiver) + } + d.settingsLock.Unlock() +} + +func (d *DummyReceiver) TrackID() livekit.TrackID { + return d.trackID +} + +func (d *DummyReceiver) StreamID() string { + return d.streamId +} + +func (d *DummyReceiver) Codec() webrtc.RTPCodecParameters { + if receiver := d.getReceiver(); receiver != nil { + return receiver.Codec() + } + return d.codec +} + +func (d *DummyReceiver) Mime() mime.MimeType { + if receiver := d.getReceiver(); receiver != nil { + return receiver.Mime() + } + return mime.NormalizeMimeType(d.codec.MimeType) +} + +func (d *DummyReceiver) VideoLayerMode() livekit.VideoLayer_Mode { + if receiver := d.getReceiver(); receiver != nil { + return receiver.VideoLayerMode() + } + return buffer.GetVideoLayerModeForMimeType(d.Mime(), d.TrackInfo()) +} + +func (d *DummyReceiver) HeaderExtensions() []webrtc.RTPHeaderExtensionParameter { + if receiver := d.getReceiver(); receiver != nil { + return receiver.HeaderExtensions() + } + return d.headerExtensions +} + +func (d *DummyReceiver) ReadRTP(buf []byte, layer uint8, esn uint64) (int, error) { + if receiver := d.getReceiver(); receiver != nil { + return receiver.ReadRTP(buf, layer, esn) + } + return 0, errors.New("no receiver") +} + +func (d *DummyReceiver) GetLayeredBitrate() ([]int32, sfu.Bitrates) { + if receiver := d.getReceiver(); receiver != nil { + return receiver.GetLayeredBitrate() + } + return nil, sfu.Bitrates{} +} + +func (d *DummyReceiver) GetAudioLevel() (float64, bool) { + if receiver := d.getReceiver(); receiver != nil { + return receiver.GetAudioLevel() + } + return 0, false +} + +func (d *DummyReceiver) SendPLI(layer int32, force bool) { + if receiver := d.getReceiver(); receiver != nil { + receiver.SendPLI(layer, force) + } +} + +func (d *DummyReceiver) SetUpTrackPaused(paused bool) { + d.settingsLock.Lock() + receiver := d.getReceiver() + if receiver != nil { + d.pausedValid = false + } else { + d.pausedValid = true + d.paused = paused + } + d.settingsLock.Unlock() + + if receiver != nil { + receiver.SetUpTrackPaused(paused) + } +} + +func (d *DummyReceiver) SetMaxExpectedSpatialLayer(layer int32) { + d.settingsLock.Lock() + receiver := d.getReceiver() + if receiver != nil { + d.maxExpectedLayerValid = false + } else { + d.maxExpectedLayerValid = true + d.maxExpectedLayer = layer + } + d.settingsLock.Unlock() + + if receiver != nil { + receiver.SetMaxExpectedSpatialLayer(layer) + } +} + +func (d *DummyReceiver) AddDownTrack(track sfu.TrackSender) error { + d.downTrackLock.Lock() + defer d.downTrackLock.Unlock() + + if receiver := d.getReceiver(); receiver != nil { + receiver.AddDownTrack(track) + } else { + d.downTracks[track.SubscriberID()] = track + } + return nil +} + +func (d *DummyReceiver) DeleteDownTrack(subscriberID livekit.ParticipantID) { + d.downTrackLock.Lock() + defer d.downTrackLock.Unlock() + + if receiver := d.getReceiver(); receiver != nil { + receiver.DeleteDownTrack(subscriberID) + } else { + delete(d.downTracks, subscriberID) + } +} + +func (d *DummyReceiver) GetDownTracks() []sfu.TrackSender { + d.downTrackLock.Lock() + defer d.downTrackLock.Unlock() + + if receiver := d.getReceiver(); receiver != nil { + return receiver.GetDownTracks() + } + return maps.Values(d.downTracks) +} + +func (d *DummyReceiver) DebugInfo() map[string]any { + if receiver := d.getReceiver(); receiver != nil { + return receiver.DebugInfo() + } + return nil +} + +func (d *DummyReceiver) GetTemporalLayerFpsForSpatial(spatial int32) []float32 { + if receiver := d.getReceiver(); receiver != nil { + return receiver.GetTemporalLayerFpsForSpatial(spatial) + } + return nil +} + +func (d *DummyReceiver) TrackInfo() *livekit.TrackInfo { + if receiver := d.getReceiver(); receiver != nil { + return receiver.TrackInfo() + } + return nil +} + +func (d *DummyReceiver) UpdateTrackInfo(ti *livekit.TrackInfo) { + if receiver := d.getReceiver(); receiver != nil { + receiver.UpdateTrackInfo(ti) + } +} + +func (d *DummyReceiver) IsClosed() bool { + if receiver := d.getReceiver(); receiver != nil { + return receiver.IsClosed() + } + return false +} + +func (d *DummyReceiver) GetPrimaryReceiverForRed() sfu.TrackReceiver { + d.settingsLock.Lock() + defer d.settingsLock.Unlock() + + if d.primaryReceiver == nil { + d.primaryReceiver = NewDummyRedReceiver(d, false) + if r, ok := d.receiver.Load().(sfu.TrackReceiver); ok { + d.primaryReceiver.upgrade(r) + } + } + return d.primaryReceiver +} + +func (d *DummyReceiver) GetRedReceiver() sfu.TrackReceiver { + d.settingsLock.Lock() + defer d.settingsLock.Unlock() + if d.redReceiver == nil { + d.redReceiver = NewDummyRedReceiver(d, true) + if r, ok := d.receiver.Load().(sfu.TrackReceiver); ok { + d.redReceiver.upgrade(r) + } + } + return d.redReceiver +} + +func (d *DummyReceiver) GetTrackStats() *livekit.RTPStats { + if receiver := d.getReceiver(); receiver != nil { + return receiver.GetTrackStats() + } + return nil +} + +func (d *DummyReceiver) AddOnReady(f func()) { + d.downTrackLock.Lock() + receiver := d.getReceiver() + if receiver == nil { + d.onReadyCallbacks = append(d.onReadyCallbacks, f) + } + d.downTrackLock.Unlock() + if receiver != nil { + receiver.AddOnReady(f) + } +} + +func (d *DummyReceiver) AddOnCodecStateChange(f func(codec webrtc.RTPCodecParameters, state sfu.ReceiverCodecState)) { + d.downTrackLock.Lock() + receiver := d.getReceiver() + if receiver == nil { + d.onCodecStateChange = append(d.onCodecStateChange, f) + } + d.downTrackLock.Unlock() + if receiver != nil { + receiver.AddOnCodecStateChange(f) + } +} + +func (d *DummyReceiver) CodecState() sfu.ReceiverCodecState { + if receiver := d.getReceiver(); receiver != nil { + return receiver.CodecState() + } + return sfu.ReceiverCodecStateNormal +} + +func (d *DummyReceiver) VideoSizes() []buffer.VideoSize { + if receiver := d.getReceiver(); receiver != nil { + return receiver.VideoSizes() + } + + return nil +} + +func (d *DummyReceiver) Restart(reason string) { + if receiver := d.getReceiver(); receiver != nil { + receiver.Restart(reason) + } +} + +func (d *DummyReceiver) getReceiver() sfu.TrackReceiver { + if receiver, ok := d.receiver.Load().(sfu.TrackReceiver); ok { + return receiver + } + + return nil +} + +// -------------------------------------------- + +type DummyRedReceiver struct { + *DummyReceiver + redReceiver atomic.Value // sfu.TrackReceiver + // indicates this receiver is for RED encoding receiver of primary codec OR + // primary decoding receiver of RED codec + isRedEncoding bool + + downTrackLock sync.Mutex + downTracks map[livekit.ParticipantID]sfu.TrackSender +} + +func NewDummyRedReceiver(d *DummyReceiver, isRedEncoding bool) *DummyRedReceiver { + return &DummyRedReceiver{ + DummyReceiver: d, + isRedEncoding: isRedEncoding, + downTracks: make(map[livekit.ParticipantID]sfu.TrackSender), + } +} + +func (d *DummyRedReceiver) AddDownTrack(track sfu.TrackSender) error { + d.downTrackLock.Lock() + defer d.downTrackLock.Unlock() + + if r, ok := d.redReceiver.Load().(sfu.TrackReceiver); ok { + r.AddDownTrack(track) + } else { + d.downTracks[track.SubscriberID()] = track + } + return nil +} + +func (d *DummyRedReceiver) DeleteDownTrack(subscriberID livekit.ParticipantID) { + d.downTrackLock.Lock() + defer d.downTrackLock.Unlock() + + if r, ok := d.redReceiver.Load().(sfu.TrackReceiver); ok { + r.DeleteDownTrack(subscriberID) + } else { + delete(d.downTracks, subscriberID) + } +} + +func (d *DummyRedReceiver) GetDownTracks() []sfu.TrackSender { + d.downTrackLock.Lock() + defer d.downTrackLock.Unlock() + + if r, ok := d.redReceiver.Load().(sfu.TrackReceiver); ok { + return r.GetDownTracks() + } + return maps.Values(d.downTracks) +} + +func (d *DummyRedReceiver) ReadRTP(buf []byte, layer uint8, esn uint64) (int, error) { + if r, ok := d.redReceiver.Load().(sfu.TrackReceiver); ok { + return r.ReadRTP(buf, layer, esn) + } + return 0, errors.New("no receiver") +} + +func (d *DummyRedReceiver) upgrade(receiver sfu.TrackReceiver) { + var redReceiver sfu.TrackReceiver + if d.isRedEncoding { + redReceiver = receiver.GetRedReceiver() + } else { + redReceiver = receiver.GetPrimaryReceiverForRed() + } + d.redReceiver.Store(redReceiver) + + d.downTrackLock.Lock() + for _, t := range d.downTracks { + redReceiver.AddDownTrack(t) + } + d.downTracks = make(map[livekit.ParticipantID]sfu.TrackSender) + d.downTrackLock.Unlock() +} diff --git a/livekit/pkg/service/agent_dispatch_service.go b/livekit/pkg/service/agent_dispatch_service.go new file mode 100644 index 0000000..2cf918c --- /dev/null +++ b/livekit/pkg/service/agent_dispatch_service.go @@ -0,0 +1,111 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package service + +import ( + "context" + "fmt" + + "github.com/livekit/livekit-server/pkg/routing" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/rpc" + "github.com/livekit/protocol/utils" + "github.com/livekit/protocol/utils/guid" +) + +type AgentDispatchService struct { + agentDispatchClient rpc.TypedAgentDispatchInternalClient + topicFormatter rpc.TopicFormatter + roomAllocator RoomAllocator + router routing.MessageRouter +} + +func NewAgentDispatchService( + agentDispatchClient rpc.TypedAgentDispatchInternalClient, + topicFormatter rpc.TopicFormatter, + roomAllocator RoomAllocator, + router routing.MessageRouter, +) *AgentDispatchService { + return &AgentDispatchService{ + agentDispatchClient: agentDispatchClient, + topicFormatter: topicFormatter, + roomAllocator: roomAllocator, + router: router, + } +} + +func (ag *AgentDispatchService) CreateDispatch(ctx context.Context, req *livekit.CreateAgentDispatchRequest) (*livekit.AgentDispatch, error) { + AppendLogFields(ctx, "room", req.Room, "request", logger.Proto(redactCreateAgentDispatchRequest(req))) + err := EnsureAdminPermission(ctx, livekit.RoomName(req.Room)) + if err != nil { + return nil, twirpAuthError(err) + } + + if ag.roomAllocator.AutoCreateEnabled(ctx) { + err := ag.roomAllocator.SelectRoomNode(ctx, livekit.RoomName(req.Room), "") + if err != nil { + return nil, err + } + + _, err = ag.router.CreateRoom(ctx, &livekit.CreateRoomRequest{Name: req.Room}) + if err != nil { + return nil, err + } + } + + dispatch := &livekit.AgentDispatch{ + Id: guid.New(guid.AgentDispatchPrefix), + AgentName: req.AgentName, + Room: req.Room, + Metadata: req.Metadata, + } + return ag.agentDispatchClient.CreateDispatch(ctx, ag.topicFormatter.RoomTopic(ctx, livekit.RoomName(req.Room)), dispatch) +} + +func (ag *AgentDispatchService) DeleteDispatch(ctx context.Context, req *livekit.DeleteAgentDispatchRequest) (*livekit.AgentDispatch, error) { + AppendLogFields(ctx, "room", req.Room, "request", logger.Proto(req)) + err := EnsureAdminPermission(ctx, livekit.RoomName(req.Room)) + if err != nil { + return nil, twirpAuthError(err) + } + + return ag.agentDispatchClient.DeleteDispatch(ctx, ag.topicFormatter.RoomTopic(ctx, livekit.RoomName(req.Room)), req) +} + +func (ag *AgentDispatchService) ListDispatch(ctx context.Context, req *livekit.ListAgentDispatchRequest) (*livekit.ListAgentDispatchResponse, error) { + AppendLogFields(ctx, "room", req.Room, "request", logger.Proto(req)) + err := EnsureAdminPermission(ctx, livekit.RoomName(req.Room)) + if err != nil { + return nil, twirpAuthError(err) + } + + return ag.agentDispatchClient.ListDispatch(ctx, ag.topicFormatter.RoomTopic(ctx, livekit.RoomName(req.Room)), req) +} + +func redactCreateAgentDispatchRequest(req *livekit.CreateAgentDispatchRequest) *livekit.CreateAgentDispatchRequest { + if req.Metadata == "" { + return req + } + + clone := utils.CloneProto(req) + + // replace with size of metadata to provide visibility on request size + if clone.Metadata != "" { + clone.Metadata = fmt.Sprintf("__size: %d", len(clone.Metadata)) + } + + return clone +} diff --git a/livekit/pkg/service/agentservice.go b/livekit/pkg/service/agentservice.go new file mode 100644 index 0000000..b4fc76d --- /dev/null +++ b/livekit/pkg/service/agentservice.go @@ -0,0 +1,548 @@ +// Copyright 2024 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package service + +import ( + "context" + "errors" + "math/rand" + "net/http" + "slices" + "sort" + "strconv" + "sync" + "time" + + "github.com/gorilla/websocket" + "google.golang.org/protobuf/types/known/emptypb" + + "github.com/livekit/livekit-server/pkg/agent" + "github.com/livekit/livekit-server/pkg/config" + "github.com/livekit/livekit-server/pkg/routing" + "github.com/livekit/livekit-server/pkg/rtc" + "github.com/livekit/livekit-server/pkg/rtc/types" + "github.com/livekit/livekit-server/version" + "github.com/livekit/protocol/auth" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/rpc" + "github.com/livekit/protocol/utils" + "github.com/livekit/psrpc" +) + +type AgentSocketUpgrader struct { + websocket.Upgrader +} + +func (u AgentSocketUpgrader) Upgrade( + w http.ResponseWriter, + r *http.Request, + responseHeader http.Header, +) ( + conn *websocket.Conn, + registration agent.WorkerRegistration, + ok bool, +) { + if u.CheckOrigin == nil { + // allow connections from any origin, since script may be hosted anywhere + // security is enforced by access tokens + u.CheckOrigin = func(r *http.Request) bool { + return true + } + } + + // reject non websocket requests + if !websocket.IsWebSocketUpgrade(r) { + w.WriteHeader(404) + return + } + + // require a claim + claims := GetGrants(r.Context()) + if claims == nil || claims.Video == nil || !claims.Video.Agent { + HandleError(w, r, http.StatusUnauthorized, rtc.ErrPermissionDenied) + return + } + + registration = agent.MakeWorkerRegistration() + registration.ClientIP = GetClientIP(r) + + // upgrade + conn, err := u.Upgrader.Upgrade(w, r, responseHeader) + if err != nil { + HandleError(w, r, http.StatusInternalServerError, err) + return + } + + if pv, err := strconv.Atoi(r.FormValue("protocol")); err == nil { + registration.Protocol = agent.WorkerProtocolVersion(pv) + } + + return conn, registration, true +} + +func DispatchAgentWorkerSignal(c agent.SignalConn, h agent.WorkerSignalHandler, l logger.Logger) bool { + req, _, err := c.ReadWorkerMessage() + if err != nil { + if IsWebSocketCloseError(err) { + l.Debugw("worker closed WS connection", "wsError", err) + } else { + l.Errorw("error reading from websocket", err) + } + return false + } + + if err := agent.DispatchWorkerSignal(req, h); err != nil { + l.Warnw("unable to handle worker signal", err, "req", logger.Proto(req)) + return false + } + + return true +} + +func HandshakeAgentWorker(c agent.SignalConn, serverInfo *livekit.ServerInfo, registration agent.WorkerRegistration, l logger.Logger) (r agent.WorkerRegistration, ok bool) { + wr := agent.NewWorkerRegisterer(c, serverInfo, registration) + if err := c.SetReadDeadline(wr.Deadline()); err != nil { + return + } + for !wr.Registered() { + if ok = DispatchAgentWorkerSignal(c, wr, l); !ok { + return + } + } + if err := c.SetReadDeadline(time.Time{}); err != nil { + return + } + return wr.Registration(), true +} + +type AgentService struct { + upgrader AgentSocketUpgrader + + *AgentHandler +} + +type AgentHandler struct { + agentServer rpc.AgentInternalServer + mu sync.Mutex + logger logger.Logger + + serverInfo *livekit.ServerInfo + workers map[string]*agent.Worker + jobToWorker map[livekit.JobID]*agent.Worker + keyProvider auth.KeyProvider + + namespaceWorkers map[workerKey][]*agent.Worker + roomKeyCount int + publisherKeyCount int + participantKeyCount int + namespaces []string // namespaces deprecated + agentNames []string + + roomTopic string + publisherTopic string + participantTopic string +} + +type workerKey struct { + agentName string + namespace string + jobType livekit.JobType +} + +func NewAgentService( + conf *config.Config, + currentNode routing.LocalNode, + bus psrpc.MessageBus, + keyProvider auth.KeyProvider, +) (*AgentService, error) { + s := &AgentService{} + + serverInfo := &livekit.ServerInfo{ + Edition: livekit.ServerInfo_Standard, + Version: version.Version, + Protocol: types.CurrentProtocol, + AgentProtocol: agent.CurrentProtocol, + Region: conf.Region, + NodeId: string(currentNode.NodeID()), + } + + agentServer, err := rpc.NewAgentInternalServer(s, bus) + if err != nil { + return nil, err + } + s.AgentHandler = NewAgentHandler( + agentServer, + keyProvider, + logger.GetLogger(), + serverInfo, + agent.RoomAgentTopic, + agent.PublisherAgentTopic, + agent.ParticipantAgentTopic, + ) + return s, nil +} + +func (s *AgentService) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if conn, registration, ok := s.upgrader.Upgrade(w, r, nil); ok { + s.HandleConnection(r.Context(), NewWSSignalConnection(conn), registration) + conn.Close() + } +} + +func NewAgentHandler( + agentServer rpc.AgentInternalServer, + keyProvider auth.KeyProvider, + logger logger.Logger, + serverInfo *livekit.ServerInfo, + roomTopic string, + publisherTopic string, + participantTopic string, +) *AgentHandler { + return &AgentHandler{ + agentServer: agentServer, + logger: logger.WithComponent("agents"), + workers: make(map[string]*agent.Worker), + jobToWorker: make(map[livekit.JobID]*agent.Worker), + namespaceWorkers: make(map[workerKey][]*agent.Worker), + serverInfo: serverInfo, + keyProvider: keyProvider, + roomTopic: roomTopic, + publisherTopic: publisherTopic, + participantTopic: participantTopic, + } +} + +func (h *AgentHandler) HandleConnection(ctx context.Context, conn agent.SignalConn, registration agent.WorkerRegistration) { + registration, ok := HandshakeAgentWorker(conn, h.serverInfo, registration, h.logger) + if !ok { + return + } + + apiKey := GetAPIKey(ctx) + apiSecret := h.keyProvider.GetSecret(apiKey) + + worker := agent.NewWorker(registration, apiKey, apiSecret, conn, h.logger) + h.registerWorker(worker) + + handlerWorker := &agentHandlerWorker{h, worker} + for ok := true; ok; { + ok = DispatchAgentWorkerSignal(conn, handlerWorker, worker.Logger()) + } + + h.deregisterWorker(worker) + worker.Close() +} + +func (h *AgentHandler) registerWorker(w *agent.Worker) { + h.mu.Lock() + + h.workers[w.ID] = w + + key := workerKey{w.AgentName, w.Namespace, w.JobType} + + workers := h.namespaceWorkers[key] + created := len(workers) == 0 + + if created { + nameTopic := agent.GetAgentTopic(w.AgentName, w.Namespace) + var typeTopic string + switch w.JobType { + case livekit.JobType_JT_ROOM: + typeTopic = h.roomTopic + case livekit.JobType_JT_PUBLISHER: + typeTopic = h.publisherTopic + case livekit.JobType_JT_PARTICIPANT: + typeTopic = h.participantTopic + } + + err := h.agentServer.RegisterJobRequestTopic(nameTopic, typeTopic) + if err != nil { + h.mu.Unlock() + + w.Logger().Errorw("failed to register job request topic", err) + w.Close() + return + } + + switch w.JobType { + case livekit.JobType_JT_ROOM: + h.roomKeyCount++ + case livekit.JobType_JT_PUBLISHER: + h.publisherKeyCount++ + case livekit.JobType_JT_PARTICIPANT: + h.participantKeyCount++ + } + + h.namespaces = append(h.namespaces, w.Namespace) + sort.Strings(h.namespaces) + h.agentNames = append(h.agentNames, w.AgentName) + sort.Strings(h.agentNames) + } + + h.namespaceWorkers[key] = append(workers, w) + h.mu.Unlock() + + h.logger.Infow("worker registered", + "namespace", w.Namespace, + "jobType", w.JobType, + "agentName", w.AgentName, + "workerID", w.ID, + ) + if created { + err := h.agentServer.PublishWorkerRegistered(context.Background(), agent.DefaultHandlerNamespace, &emptypb.Empty{}) + // TODO: when this happens, should we disconnect the worker so it'll retry? + if err != nil { + w.Logger().Errorw("failed to publish worker registered", err, "namespace", w.Namespace, "jobType", w.JobType, "agentName", w.AgentName) + } + } +} + +func (h *AgentHandler) deregisterWorker(w *agent.Worker) { + h.mu.Lock() + defer h.mu.Unlock() + + delete(h.workers, w.ID) + + key := workerKey{w.AgentName, w.Namespace, w.JobType} + + workers, ok := h.namespaceWorkers[key] + if !ok { + return + } + index := slices.Index(workers, w) + if index == -1 { + return + } + + if len(workers) > 1 { + h.namespaceWorkers[key] = slices.Delete(workers, index, index+1) + } else { + h.logger.Infow("last worker deregistered", + "namespace", w.Namespace, + "jobType", w.JobType, + "agentName", w.AgentName, + "workerID", w.ID, + ) + delete(h.namespaceWorkers, key) + + topic := agent.GetAgentTopic(w.AgentName, w.Namespace) + + switch w.JobType { + case livekit.JobType_JT_ROOM: + h.roomKeyCount-- + h.agentServer.DeregisterJobRequestTopic(topic, h.roomTopic) + case livekit.JobType_JT_PUBLISHER: + h.publisherKeyCount-- + h.agentServer.DeregisterJobRequestTopic(topic, h.publisherTopic) + case livekit.JobType_JT_PARTICIPANT: + h.participantKeyCount-- + h.agentServer.DeregisterJobRequestTopic(topic, h.participantTopic) + } + + // agentNames and namespaces contains repeated entries for each agentNames/namespaces combinations + if i := slices.Index(h.namespaces, w.Namespace); i != -1 { + h.namespaces = slices.Delete(h.namespaces, i, i+1) + } + if i := slices.Index(h.agentNames, w.AgentName); i != -1 { + h.agentNames = slices.Delete(h.agentNames, i, i+1) + } + } + + jobs := w.RunningJobs() + for jobID := range jobs { + h.deregisterJob(jobID) + } +} + +func (h *AgentHandler) deregisterJob(jobID livekit.JobID) { + h.agentServer.DeregisterJobTerminateTopic(string(jobID)) + + delete(h.jobToWorker, jobID) + + // TODO update dispatch state +} + +func (h *AgentHandler) JobRequest(ctx context.Context, job *livekit.Job) (*rpc.JobRequestResponse, error) { + logger := h.logger.WithUnlikelyValues( + "jobID", job.Id, + "namespace", job.Namespace, + "agentName", job.AgentName, + "jobType", job.Type.String(), + ) + if job.Room != nil { + logger = logger.WithValues("room", job.Room.Name, "roomID", job.Room.Sid) + } + if job.Participant != nil { + logger = logger.WithValues("participant", job.Participant.Identity) + } + + key := workerKey{job.AgentName, job.Namespace, job.Type} + attempted := make(map[*agent.Worker]struct{}) + for { + selected, err := h.selectWorkerWeightedByLoad(key, attempted) + if err != nil { + logger.Warnw("no worker available to handle job", err) + return nil, psrpc.NewError(psrpc.ResourceExhausted, err) + } + + logger := logger.WithValues("workerID", selected.ID) + attempted[selected] = struct{}{} + + state, err := selected.AssignJob(ctx, job) + switch state.GetStatus() { + case livekit.JobStatus_JS_RUNNING: + logger.Infow("assigned job to worker") + h.mu.Lock() + h.jobToWorker[livekit.JobID(job.Id)] = selected + h.mu.Unlock() + + err = h.agentServer.RegisterJobTerminateTopic(job.Id) + if err != nil { + logger.Errorw("failed to register JobTerminate handler", err) + } + fallthrough + case livekit.JobStatus_JS_SUCCESS: + return &rpc.JobRequestResponse{ + State: state, + }, nil + default: + retry := utils.ErrorIsOneOf(err, agent.ErrWorkerNotAvailable, agent.ErrWorkerClosed) + logger.Warnw("failed to assign job to worker", err, "retry", retry) + if !retry { + return nil, err + } + } + } +} + +func (h *AgentHandler) JobRequestAffinity(ctx context.Context, job *livekit.Job) float32 { + h.mu.Lock() + defer h.mu.Unlock() + + var affinity float32 + for _, w := range h.workers { + if w.AgentName != job.AgentName || w.Namespace != job.Namespace || w.JobType != job.Type { + continue + } + + if w.Status() == livekit.WorkerStatus_WS_AVAILABLE { + affinity += max(0, 1-w.Load()) + } + } + + return affinity +} + +func (h *AgentHandler) JobTerminate(ctx context.Context, req *rpc.JobTerminateRequest) (*rpc.JobTerminateResponse, error) { + h.mu.Lock() + w := h.jobToWorker[livekit.JobID(req.JobId)] + h.mu.Unlock() + + if w == nil { + return nil, psrpc.NewErrorf(psrpc.NotFound, "no worker for jobID") + } + + state, err := w.TerminateJob(livekit.JobID(req.JobId), req.Reason) + if err != nil { + return nil, err + } + + return &rpc.JobTerminateResponse{ + State: state, + }, nil +} + +func (h *AgentHandler) CheckEnabled(ctx context.Context, req *rpc.CheckEnabledRequest) (*rpc.CheckEnabledResponse, error) { + h.mu.Lock() + defer h.mu.Unlock() + + // This doesn't return the full agentName -> namespace mapping, which can cause some unnecessary RPC. + // namespaces are however deprecated. + return &rpc.CheckEnabledResponse{ + Namespaces: slices.Compact(slices.Clone(h.namespaces)), + AgentNames: slices.Compact(slices.Clone(h.agentNames)), + RoomEnabled: h.roomKeyCount != 0, + PublisherEnabled: h.publisherKeyCount != 0, + ParticipantEnabled: h.participantKeyCount != 0, + }, nil +} + +func (h *AgentHandler) DrainConnections(interval time.Duration) { + // jitter drain start + time.Sleep(time.Duration(rand.Int63n(int64(interval)))) + + t := time.NewTicker(interval) + defer t.Stop() + + h.mu.Lock() + defer h.mu.Unlock() + + for _, w := range h.workers { + w.Close() + <-t.C + } +} + +func (h *AgentHandler) selectWorkerWeightedByLoad(key workerKey, ignore map[*agent.Worker]struct{}) (*agent.Worker, error) { + h.mu.Lock() + defer h.mu.Unlock() + + workers, ok := h.namespaceWorkers[key] + if !ok { + return nil, errors.New("no workers available") + } + + normalizedLoads := make(map[*agent.Worker]float32) + var availableSum float32 + for _, w := range workers { + if _, ok := ignore[w]; !ok && w.Status() == livekit.WorkerStatus_WS_AVAILABLE { + normalizedLoads[w] = max(0, 1-w.Load()) + availableSum += normalizedLoads[w] + } + } + + if availableSum == 0 { + return nil, errors.New("no workers with sufficient capacity") + } + + currentSum := rand.Float32() * availableSum + for w, load := range normalizedLoads { + if currentSum -= load; currentSum <= 0 { + return w, nil + } + } + return workers[0], nil +} + +var _ agent.WorkerSignalHandler = (*agentHandlerWorker)(nil) + +type agentHandlerWorker struct { + h *AgentHandler + *agent.Worker +} + +func (w *agentHandlerWorker) HandleUpdateJob(update *livekit.UpdateJobStatus) error { + if err := w.Worker.HandleUpdateJob(update); err != nil { + return err + } + + if agent.JobStatusIsEnded(update.Status) { + w.h.mu.Lock() + w.h.deregisterJob(livekit.JobID(update.JobId)) + w.h.mu.Unlock() + } + return nil +} diff --git a/livekit/pkg/service/auth.go b/livekit/pkg/service/auth.go new file mode 100644 index 0000000..f8e6442 --- /dev/null +++ b/livekit/pkg/service/auth.go @@ -0,0 +1,238 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package service + +import ( + "context" + "errors" + "net/http" + "strings" + + "github.com/twitchtv/twirp" + + "github.com/livekit/protocol/auth" + "github.com/livekit/protocol/livekit" +) + +const ( + authorizationHeader = "Authorization" + bearerPrefix = "Bearer " + accessTokenParam = "access_token" +) + +type grantsKey struct{} + +type grantsValue struct { + claims *auth.ClaimGrants + apiKey string +} + +var ( + ErrPermissionDenied = errors.New("permissions denied") + ErrMissingAuthorization = errors.New("invalid authorization header. Must start with " + bearerPrefix) + ErrInvalidAuthorizationToken = errors.New("invalid authorization token") + ErrInvalidAPIKey = errors.New("invalid API key") +) + +// authentication middleware +type APIKeyAuthMiddleware struct { + provider auth.KeyProvider +} + +func NewAPIKeyAuthMiddleware(provider auth.KeyProvider) *APIKeyAuthMiddleware { + return &APIKeyAuthMiddleware{ + provider: provider, + } +} + +func (m *APIKeyAuthMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) { + if r.URL != nil && (r.URL.Path == "/rtc/validate" || r.URL.Path == "/rtc/v1/validate") { + w.Header().Set("Access-Control-Allow-Origin", "*") + } + + authHeader := r.Header.Get(authorizationHeader) + var authToken string + + if authHeader != "" { + if !strings.HasPrefix(authHeader, bearerPrefix) { + HandleError(w, r, http.StatusUnauthorized, ErrMissingAuthorization) + return + } + + authToken = authHeader[len(bearerPrefix):] + } else { + // attempt to find from request header + authToken = r.FormValue(accessTokenParam) + } + + if authToken != "" { + v, err := auth.ParseAPIToken(authToken) + if err != nil { + HandleError(w, r, http.StatusUnauthorized, ErrInvalidAuthorizationToken) + return + } + + secret := m.provider.GetSecret(v.APIKey()) + if secret == "" { + HandleError(w, r, http.StatusUnauthorized, errors.New("invalid API key: "+v.APIKey())) + return + } + + _, grants, err := v.Verify(secret) + if err != nil { + HandleError(w, r, http.StatusUnauthorized, errors.New("invalid token: "+authToken+", error: "+err.Error())) + return + } + + // set grants in context + ctx := r.Context() + r = r.WithContext(context.WithValue(ctx, grantsKey{}, &grantsValue{ + claims: grants, + apiKey: v.APIKey(), + })) + } + + next.ServeHTTP(w, r) +} + +func WithAPIKey(ctx context.Context, grants *auth.ClaimGrants, apiKey string) context.Context { + return context.WithValue(ctx, grantsKey{}, &grantsValue{ + claims: grants, + apiKey: apiKey, + }) +} + +func GetGrants(ctx context.Context) *auth.ClaimGrants { + val := ctx.Value(grantsKey{}) + v, ok := val.(*grantsValue) + if !ok { + return nil + } + return v.claims +} + +func GetAPIKey(ctx context.Context) string { + val := ctx.Value(grantsKey{}) + v, ok := val.(*grantsValue) + if !ok { + return "" + } + return v.apiKey +} + +func WithGrants(ctx context.Context, grants *auth.ClaimGrants, apiKey string) context.Context { + return context.WithValue(ctx, grantsKey{}, &grantsValue{ + claims: grants, + apiKey: apiKey, + }) +} + +func SetAuthorizationToken(r *http.Request, token string) { + r.Header.Set(authorizationHeader, bearerPrefix+token) +} + +func EnsureJoinPermission(ctx context.Context) (name livekit.RoomName, err error) { + claims := GetGrants(ctx) + if claims == nil || claims.Video == nil { + err = ErrPermissionDenied + return + } + + if claims.Video.RoomJoin { + name = livekit.RoomName(claims.Video.Room) + } else { + err = ErrPermissionDenied + } + return +} + +func EnsureAdminPermission(ctx context.Context, room livekit.RoomName) error { + claims := GetGrants(ctx) + if claims == nil || claims.Video == nil { + return ErrPermissionDenied + } + + if !claims.Video.RoomAdmin || room != livekit.RoomName(claims.Video.Room) { + return ErrPermissionDenied + } + + return nil +} + +func EnsureCreatePermission(ctx context.Context) error { + claims := GetGrants(ctx) + if claims == nil || claims.Video == nil || !claims.Video.RoomCreate { + return ErrPermissionDenied + } + return nil +} + +func EnsureListPermission(ctx context.Context) error { + claims := GetGrants(ctx) + if claims == nil || claims.Video == nil || !claims.Video.RoomList { + return ErrPermissionDenied + } + return nil +} + +func EnsureRecordPermission(ctx context.Context) error { + claims := GetGrants(ctx) + if claims == nil || claims.Video == nil || !claims.Video.RoomRecord { + return ErrPermissionDenied + } + return nil +} + +func EnsureIngressAdminPermission(ctx context.Context) error { + claims := GetGrants(ctx) + if claims == nil || claims.Video == nil || !claims.Video.IngressAdmin { + return ErrPermissionDenied + } + return nil +} + +func EnsureSIPAdminPermission(ctx context.Context) error { + claims := GetGrants(ctx) + if claims == nil || claims.SIP == nil || !claims.SIP.Admin { + return ErrPermissionDenied + } + return nil +} + +func EnsureSIPCallPermission(ctx context.Context) error { + claims := GetGrants(ctx) + if claims == nil || claims.SIP == nil || !claims.SIP.Call { + return ErrPermissionDenied + } + return nil +} + +func EnsureDestRoomPermission(ctx context.Context, source livekit.RoomName, destination livekit.RoomName) error { + claims := GetGrants(ctx) + if claims == nil || claims.Video == nil { + return ErrPermissionDenied + } + + if !claims.Video.RoomAdmin || source != livekit.RoomName(claims.Video.Room) || destination != livekit.RoomName(claims.Video.DestinationRoom) { + return ErrPermissionDenied + } + + return nil +} + +// wraps authentication errors around Twirp +func twirpAuthError(err error) error { + return twirp.NewError(twirp.Unauthenticated, err.Error()) +} diff --git a/livekit/pkg/service/auth_test.go b/livekit/pkg/service/auth_test.go new file mode 100644 index 0000000..00b89bc --- /dev/null +++ b/livekit/pkg/service/auth_test.go @@ -0,0 +1,74 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package service_test + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/livekit/protocol/auth" + "github.com/livekit/protocol/auth/authfakes" + + "github.com/livekit/livekit-server/pkg/service" +) + +func TestAuthMiddleware(t *testing.T) { + api := "APIabcdefg" + secret := "somesecretencodedinbase62extendto32bytes" + provider := &authfakes.FakeKeyProvider{} + provider.GetSecretReturns(secret) + + m := service.NewAPIKeyAuthMiddleware(provider) + var grants *auth.ClaimGrants + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + grants = service.GetGrants(r.Context()) + w.WriteHeader(http.StatusOK) + }) + + orig := &auth.VideoGrant{Room: "abcdefg", RoomJoin: true} + // ensure that the original claim could be retrieved + at := auth.NewAccessToken(api, secret). + AddGrant(orig) + token, err := at.ToJWT() + require.NoError(t, err) + + r := &http.Request{Header: http.Header{}} + w := httptest.NewRecorder() + service.SetAuthorizationToken(r, token) + m.ServeHTTP(w, r, handler) + + require.NotNil(t, grants) + require.EqualValues(t, orig, grants.Video) + + // no authorization == no claims + grants = nil + w = httptest.NewRecorder() + r = &http.Request{Header: http.Header{}} + m.ServeHTTP(w, r, handler) + require.Nil(t, grants) + require.Equal(t, http.StatusOK, w.Code) + + // incorrect authorization: error + grants = nil + w = httptest.NewRecorder() + r = &http.Request{Header: http.Header{}} + service.SetAuthorizationToken(r, "invalid token") + m.ServeHTTP(w, r, handler) + require.Nil(t, grants) + require.Equal(t, http.StatusUnauthorized, w.Code) +} diff --git a/livekit/pkg/service/basic_auth.go b/livekit/pkg/service/basic_auth.go new file mode 100644 index 0000000..40173c6 --- /dev/null +++ b/livekit/pkg/service/basic_auth.go @@ -0,0 +1,31 @@ +package service + +import ( + "net/http" +) + +func GenBasicAuthMiddleware(username string, password string) (func(http.ResponseWriter, *http.Request, http.HandlerFunc) ) { + return func(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) { + given_username, given_password, ok := r.BasicAuth() + unauthorized := func() { + rw.Header().Set("WWW-Authenticate", "Basic realm=\"Protected Area\"") + rw.WriteHeader(http.StatusUnauthorized) + } + if !ok { + unauthorized() + return + } + + if given_username != username { + unauthorized() + return + } + + if given_password != password { + unauthorized() + return + } + + next(rw, r) + } +} \ No newline at end of file diff --git a/livekit/pkg/service/clients.go b/livekit/pkg/service/clients.go new file mode 100644 index 0000000..d8fb337 --- /dev/null +++ b/livekit/pkg/service/clients.go @@ -0,0 +1,33 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package service + +import ( + "context" + + "google.golang.org/protobuf/types/known/emptypb" + + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/rpc" +) + +//counterfeiter:generate . IOClient +type IOClient interface { + CreateEgress(ctx context.Context, info *livekit.EgressInfo) (*emptypb.Empty, error) + GetEgress(ctx context.Context, req *rpc.GetEgressRequest) (*livekit.EgressInfo, error) + ListEgress(ctx context.Context, req *livekit.ListEgressRequest) (*livekit.ListEgressResponse, error) + CreateIngress(ctx context.Context, req *livekit.IngressInfo) (*emptypb.Empty, error) + UpdateIngressState(ctx context.Context, req *rpc.UpdateIngressStateRequest) (*emptypb.Empty, error) +} diff --git a/livekit/pkg/service/docker_test.go b/livekit/pkg/service/docker_test.go new file mode 100644 index 0000000..a3da604 --- /dev/null +++ b/livekit/pkg/service/docker_test.go @@ -0,0 +1,80 @@ +// Copyright 2024 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package service_test + +import ( + "fmt" + "log" + "net" + "os" + "testing" + + "go.uber.org/atomic" + + "github.com/ory/dockertest/v3" +) + +var Docker *dockertest.Pool + +func TestMain(m *testing.M) { + pool, err := dockertest.NewPool("") + if err != nil { + log.Fatalf("Could not construct pool: %s", err) + } + + // uses pool to try to connect to Docker + err = pool.Client.Ping() + if err != nil { + log.Fatalf("Could not connect to Docker: %s", err) + } + Docker = pool + + code := m.Run() + os.Exit(code) +} + +func waitTCPPort(t testing.TB, addr string) { + if err := Docker.Retry(func() error { + conn, err := net.Dial("tcp", addr) + if err != nil { + t.Log(err) + return err + } + _ = conn.Close() + return nil + }); err != nil { + t.Fatal(err) + } +} + +var redisLast atomic.Uint32 + +func runRedis(t testing.TB) string { + c, err := Docker.RunWithOptions(&dockertest.RunOptions{ + Name: fmt.Sprintf("lktest-redis-%d", redisLast.Inc()), + Repository: "redis", Tag: "latest", + }) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { + _ = Docker.Purge(c) + }) + addr := c.GetHostPort("6379/tcp") + waitTCPPort(t, addr) + + t.Log("Redis running on", addr) + return addr +} diff --git a/livekit/pkg/service/egress.go b/livekit/pkg/service/egress.go new file mode 100644 index 0000000..2718fb2 --- /dev/null +++ b/livekit/pkg/service/egress.go @@ -0,0 +1,349 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package service + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "reflect" + + "github.com/twitchtv/twirp" + + "github.com/livekit/protocol/egress" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/rpc" + "github.com/livekit/protocol/utils" + "github.com/livekit/protocol/utils/guid" + "github.com/livekit/psrpc" + + "github.com/livekit/livekit-server/pkg/rtc" +) + +type EgressService struct { + launcher rtc.EgressLauncher + client rpc.EgressClient + io IOClient + roomService livekit.RoomService +} + +type egressLauncher struct { + client rpc.EgressClient + io IOClient + store ServiceStore +} + +func NewEgressService( + client rpc.EgressClient, + launcher rtc.EgressLauncher, + io IOClient, + rs livekit.RoomService, +) *EgressService { + return &EgressService{ + client: client, + io: io, + roomService: rs, + launcher: launcher, + } +} + +func NewEgressLauncher(client rpc.EgressClient, io IOClient, store ServiceStore) rtc.EgressLauncher { + if client == nil { + return nil + } + return &egressLauncher{ + client: client, + io: io, + store: store, + } +} + +func (s *EgressService) StartRoomCompositeEgress(ctx context.Context, req *livekit.RoomCompositeEgressRequest) (*livekit.EgressInfo, error) { + fields := []any{ + "room", req.RoomName, + "baseUrl", req.CustomBaseUrl, + "outputType", egress.GetOutputType(req), + } + defer func() { + AppendLogFields(ctx, fields...) + }() + ei, err := s.startEgress(ctx, &rpc.StartEgressRequest{ + Request: &rpc.StartEgressRequest_RoomComposite{ + RoomComposite: req, + }, + }) + if err != nil { + return nil, err + } + fields = append(fields, "egressID", ei.EgressId) + return ei, err +} + +func (s *EgressService) StartWebEgress(ctx context.Context, req *livekit.WebEgressRequest) (*livekit.EgressInfo, error) { + fields := []any{ + "url", req.Url, + "outputType", egress.GetOutputType(req), + } + defer func() { + AppendLogFields(ctx, fields...) + }() + ei, err := s.startEgress(ctx, &rpc.StartEgressRequest{ + Request: &rpc.StartEgressRequest_Web{ + Web: req, + }, + }) + if err != nil { + return nil, err + } + fields = append(fields, "egressID", ei.EgressId) + return ei, err +} + +func (s *EgressService) StartParticipantEgress(ctx context.Context, req *livekit.ParticipantEgressRequest) (*livekit.EgressInfo, error) { + fields := []any{ + "room", req.RoomName, + "identity", req.Identity, + "outputType", egress.GetOutputType(req), + } + defer func() { + AppendLogFields(ctx, fields...) + }() + ei, err := s.startEgress(ctx, &rpc.StartEgressRequest{ + Request: &rpc.StartEgressRequest_Participant{ + Participant: req, + }, + }) + if err != nil { + return nil, err + } + fields = append(fields, "egressID", ei.EgressId) + return ei, err +} + +func (s *EgressService) StartTrackCompositeEgress(ctx context.Context, req *livekit.TrackCompositeEgressRequest) (*livekit.EgressInfo, error) { + fields := []any{ + "room", req.RoomName, + "audioTrackID", req.AudioTrackId, + "videoTrackID", req.VideoTrackId, + "outputType", egress.GetOutputType(req), + } + defer func() { + AppendLogFields(ctx, fields...) + }() + ei, err := s.startEgress(ctx, &rpc.StartEgressRequest{ + Request: &rpc.StartEgressRequest_TrackComposite{ + TrackComposite: req, + }, + }) + if err != nil { + return nil, err + } + fields = append(fields, "egressID", ei.EgressId) + return ei, err +} + +func (s *EgressService) StartTrackEgress(ctx context.Context, req *livekit.TrackEgressRequest) (*livekit.EgressInfo, error) { + fields := []any{"room", req.RoomName, "trackID", req.TrackId} + if t := reflect.TypeOf(req.Output); t != nil { + fields = append(fields, "outputType", t.String()) + } + defer func() { + AppendLogFields(ctx, fields...) + }() + ei, err := s.startEgress(ctx, &rpc.StartEgressRequest{ + Request: &rpc.StartEgressRequest_Track{ + Track: req, + }, + }) + if err != nil { + return nil, err + } + fields = append(fields, "egressID", ei.EgressId) + return ei, err +} + +func (s *EgressService) startEgress(ctx context.Context, req *rpc.StartEgressRequest) (*livekit.EgressInfo, error) { + if err := EnsureRecordPermission(ctx); err != nil { + return nil, twirpAuthError(err) + } else if s.launcher == nil { + return nil, ErrEgressNotConnected + } + + return s.launcher.StartEgress(ctx, req) +} + +func (s *egressLauncher) StartEgress(ctx context.Context, req *rpc.StartEgressRequest) (*livekit.EgressInfo, error) { + if s.client == nil { + return nil, ErrEgressNotConnected + } + + // Ensure we have an Egress ID + if req.EgressId == "" { + req.EgressId = guid.New(utils.EgressPrefix) + } + + if req.RoomId == "" { + var roomName string + switch v := req.Request.(type) { + case *rpc.StartEgressRequest_RoomComposite: + roomName = v.RoomComposite.RoomName + case *rpc.StartEgressRequest_Web: + // no room name + case *rpc.StartEgressRequest_Participant: + roomName = v.Participant.RoomName + case *rpc.StartEgressRequest_TrackComposite: + roomName = v.TrackComposite.RoomName + case *rpc.StartEgressRequest_Track: + roomName = v.Track.RoomName + } + + if roomName != "" { + room, _, err := s.store.LoadRoom(ctx, livekit.RoomName(roomName), false) + if err != nil { + return nil, err + } + req.RoomId = room.Sid + } + } + + info, err := s.client.StartEgress(ctx, "", req) + if err != nil { + return nil, err + } + + _, err = s.io.CreateEgress(ctx, info) + if err != nil { + logger.Errorw("failed to create egress", err) + } + + return info, nil +} + +type LayoutMetadata struct { + Layout string `json:"layout"` +} + +func (s *EgressService) UpdateLayout(ctx context.Context, req *livekit.UpdateLayoutRequest) (*livekit.EgressInfo, error) { + AppendLogFields(ctx, "egressID", req.EgressId, "layout", req.Layout) + if err := EnsureRecordPermission(ctx); err != nil { + return nil, twirpAuthError(err) + } + + info, err := s.io.GetEgress(ctx, &rpc.GetEgressRequest{EgressId: req.EgressId}) + if err != nil { + return nil, err + } + + metadata, err := json.Marshal(&LayoutMetadata{Layout: req.Layout}) + if err != nil { + return nil, err + } + + grants := GetGrants(ctx) + grants.Video.Room = info.RoomName + grants.Video.RoomAdmin = true + + _, err = s.roomService.UpdateParticipant(ctx, &livekit.UpdateParticipantRequest{ + Room: info.RoomName, + Identity: info.EgressId, + Metadata: string(metadata), + }) + if err != nil { + return nil, err + } + + return info, nil +} + +func (s *EgressService) UpdateStream(ctx context.Context, req *livekit.UpdateStreamRequest) (*livekit.EgressInfo, error) { + AppendLogFields(ctx, "egressID", req.EgressId, "addUrls", req.AddOutputUrls, "removeUrls", req.RemoveOutputUrls) + if err := EnsureRecordPermission(ctx); err != nil { + return nil, twirpAuthError(err) + } + + if s.client == nil { + return nil, ErrEgressNotConnected + } + + info, err := s.client.UpdateStream(ctx, req.EgressId, req) + if err != nil { + var loadErr error + info, loadErr = s.io.GetEgress(ctx, &rpc.GetEgressRequest{EgressId: req.EgressId}) + if loadErr != nil { + return nil, loadErr + } + + switch info.Status { + case livekit.EgressStatus_EGRESS_STARTING, + livekit.EgressStatus_EGRESS_ACTIVE: + return nil, err + default: + return nil, twirp.NewError(twirp.FailedPrecondition, + fmt.Sprintf("egress with status %s cannot be updated", info.Status.String())) + } + } + + return info, nil +} + +func (s *EgressService) ListEgress(ctx context.Context, req *livekit.ListEgressRequest) (*livekit.ListEgressResponse, error) { + if req.RoomName != "" { + AppendLogFields(ctx, "room", req.RoomName) + } + if err := EnsureRecordPermission(ctx); err != nil { + return nil, twirpAuthError(err) + } + return s.io.ListEgress(ctx, req) +} + +func (s *EgressService) StopEgress(ctx context.Context, req *livekit.StopEgressRequest) (info *livekit.EgressInfo, err error) { + defer func() { + if errors.Is(err, psrpc.ErrNoResponse) { + // Do not map cases where the context times out to 503 + err = psrpc.ErrRequestTimedOut + } + }() + + AppendLogFields(ctx, "egressID", req.EgressId) + if err := EnsureRecordPermission(ctx); err != nil { + return nil, twirpAuthError(err) + } + + if s.client == nil { + return nil, ErrEgressNotConnected + } + + info, err = s.client.StopEgress(ctx, req.EgressId, req) + if err != nil { + var loadErr error + info, loadErr = s.io.GetEgress(ctx, &rpc.GetEgressRequest{EgressId: req.EgressId}) + if loadErr != nil { + return nil, loadErr + } + + switch info.Status { + case livekit.EgressStatus_EGRESS_STARTING, + livekit.EgressStatus_EGRESS_ACTIVE: + return nil, err + default: + return nil, twirp.NewError(twirp.FailedPrecondition, + fmt.Sprintf("egress with status %s cannot be stopped", info.Status.String())) + } + } + + return info, nil +} diff --git a/livekit/pkg/service/errors.go b/livekit/pkg/service/errors.go new file mode 100644 index 0000000..7d07390 --- /dev/null +++ b/livekit/pkg/service/errors.go @@ -0,0 +1,52 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package service + +import ( + "github.com/livekit/psrpc" +) + +var ( + ErrEgressNotFound = psrpc.NewErrorf(psrpc.NotFound, "egress does not exist") + ErrEgressNotConnected = psrpc.NewErrorf(psrpc.Internal, "egress not connected (redis required)") + ErrIdentityEmpty = psrpc.NewErrorf(psrpc.InvalidArgument, "identity cannot be empty") + ErrParticipantSidEmpty = psrpc.NewErrorf(psrpc.InvalidArgument, "participant sid cannot be empty") + ErrIngressNotConnected = psrpc.NewErrorf(psrpc.Internal, "ingress not connected (redis required)") + ErrIngressNotFound = psrpc.NewErrorf(psrpc.NotFound, "ingress does not exist") + ErrIngressNonReusable = psrpc.NewErrorf(psrpc.InvalidArgument, "ingress is not reusable and cannot be modified") + ErrNameExceedsLimits = psrpc.NewErrorf(psrpc.InvalidArgument, "name length exceeds limits") + ErrMetadataExceedsLimits = psrpc.NewErrorf(psrpc.InvalidArgument, "metadata size exceeds limits") + ErrAttributeExceedsLimits = psrpc.NewErrorf(psrpc.InvalidArgument, "attribute size exceeds limits") + ErrNoRoomName = psrpc.NewErrorf(psrpc.InvalidArgument, "no room name") + ErrRoomNameExceedsLimits = psrpc.NewErrorf(psrpc.InvalidArgument, "room name length exceeds limits") + ErrParticipantIdentityExceedsLimits = psrpc.NewErrorf(psrpc.InvalidArgument, "participant identity length exceeds limits") + ErrDestinationSameAsSourceRoom = psrpc.NewErrorf(psrpc.InvalidArgument, "destination room cannot be the same as source room") + ErrOperationFailed = psrpc.NewErrorf(psrpc.Internal, "operation cannot be completed") + ErrParticipantNotFound = psrpc.NewErrorf(psrpc.NotFound, "participant does not exist") + ErrRoomNotFound = psrpc.NewErrorf(psrpc.NotFound, "requested room does not exist") + ErrRoomLockFailed = psrpc.NewErrorf(psrpc.Internal, "could not lock room") + ErrRoomUnlockFailed = psrpc.NewErrorf(psrpc.Internal, "could not unlock room, lock token does not match") + ErrRemoteUnmuteNoteEnabled = psrpc.NewErrorf(psrpc.FailedPrecondition, "remote unmute not enabled") + ErrTrackNotFound = psrpc.NewErrorf(psrpc.NotFound, "track is not found") + ErrWebHookMissingAPIKey = psrpc.NewErrorf(psrpc.InvalidArgument, "api_key is required to use webhooks") + ErrSIPNotConnected = psrpc.NewErrorf(psrpc.Internal, "sip not connected (redis required)") + ErrSIPTrunkNotFound = psrpc.NewErrorf(psrpc.NotFound, "requested sip trunk does not exist") + ErrSIPDispatchRuleNotFound = psrpc.NewErrorf(psrpc.NotFound, "requested sip dispatch rule does not exist") + ErrSIPParticipantNotFound = psrpc.NewErrorf(psrpc.NotFound, "requested sip participant does not exist") + ErrInvalidMessageType = psrpc.NewErrorf(psrpc.Internal, "invalid message type") + ErrNoConnectRequest = psrpc.NewErrorf(psrpc.InvalidArgument, "no connect request") + ErrNoConnectResponse = psrpc.NewErrorf(psrpc.InvalidArgument, "no connect response") + ErrDestinationIdentityRequired = psrpc.NewErrorf(psrpc.InvalidArgument, "destination identity is required") +) diff --git a/livekit/pkg/service/ingress.go b/livekit/pkg/service/ingress.go new file mode 100644 index 0000000..2b82453 --- /dev/null +++ b/livekit/pkg/service/ingress.go @@ -0,0 +1,408 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package service + +import ( + "context" + "fmt" + "net/url" + + "github.com/livekit/livekit-server/pkg/config" + "github.com/livekit/livekit-server/pkg/telemetry" + "github.com/livekit/protocol/ingress" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/rpc" + "github.com/livekit/protocol/utils" + "github.com/livekit/protocol/utils/guid" + "github.com/livekit/psrpc" +) + +type IngressLauncher interface { + LaunchPullIngress(ctx context.Context, info *livekit.IngressInfo) (*livekit.IngressInfo, error) +} + +type IngressService struct { + conf *config.IngressConfig + nodeID livekit.NodeID + bus psrpc.MessageBus + psrpcClient rpc.IngressClient + store IngressStore + io IOClient + telemetry telemetry.TelemetryService + launcher IngressLauncher +} + +func NewIngressServiceWithIngressLauncher( + conf *config.IngressConfig, + nodeID livekit.NodeID, + bus psrpc.MessageBus, + psrpcClient rpc.IngressClient, + store IngressStore, + io IOClient, + ts telemetry.TelemetryService, + launcher IngressLauncher, +) *IngressService { + + return &IngressService{ + conf: conf, + nodeID: nodeID, + bus: bus, + psrpcClient: psrpcClient, + store: store, + io: io, + telemetry: ts, + launcher: launcher, + } +} + +func NewIngressService( + conf *config.IngressConfig, + nodeID livekit.NodeID, + bus psrpc.MessageBus, + psrpcClient rpc.IngressClient, + store IngressStore, + io IOClient, + ts telemetry.TelemetryService, +) *IngressService { + s := NewIngressServiceWithIngressLauncher(conf, nodeID, bus, psrpcClient, store, io, ts, nil) + + s.launcher = s + + return s +} + +func (s *IngressService) CreateIngress(ctx context.Context, req *livekit.CreateIngressRequest) (*livekit.IngressInfo, error) { + fields := []any{ + "inputType", req.InputType, + "name", req.Name, + } + if req.RoomName != "" { + fields = append(fields, "room", req.RoomName, "identity", req.ParticipantIdentity) + } + defer func() { + AppendLogFields(ctx, fields...) + }() + + var url string + switch req.InputType { + case livekit.IngressInput_RTMP_INPUT: + url = s.conf.RTMPBaseURL + case livekit.IngressInput_WHIP_INPUT: + url = s.conf.WHIPBaseURL + case livekit.IngressInput_URL_INPUT: + default: + return nil, ingress.ErrInvalidIngressType + } + + ig, err := s.CreateIngressWithUrl(ctx, url, req) + if err != nil { + return nil, err + } + fields = append(fields, "ingressID", ig.IngressId) + + return ig, nil +} + +func (s *IngressService) CreateIngressWithUrl(ctx context.Context, urlStr string, req *livekit.CreateIngressRequest) (*livekit.IngressInfo, error) { + err := EnsureIngressAdminPermission(ctx) + if err != nil { + return nil, twirpAuthError(err) + } + if s.store == nil { + return nil, ErrIngressNotConnected + } + + if req.InputType == livekit.IngressInput_URL_INPUT { + if req.Url == "" { + return nil, ingress.ErrInvalidIngress("missing URL parameter") + } + urlObj, err := url.Parse(req.Url) + if err != nil { + return nil, psrpc.NewError(psrpc.InvalidArgument, err) + } + if urlObj.Scheme != "http" && urlObj.Scheme != "https" && urlObj.Scheme != "srt" { + return nil, ingress.ErrInvalidIngress(fmt.Sprintf("invalid url scheme %s", urlObj.Scheme)) + } + // Marshall the URL again for sanitization + urlStr = urlObj.String() + } + + var sk string + if req.InputType != livekit.IngressInput_URL_INPUT { + sk = guid.New("") + } + + info := &livekit.IngressInfo{ + IngressId: guid.New(utils.IngressPrefix), + Name: req.Name, + StreamKey: sk, + Url: urlStr, + InputType: req.InputType, + Audio: req.Audio, + Video: req.Video, + EnableTranscoding: req.EnableTranscoding, + RoomName: req.RoomName, + ParticipantIdentity: req.ParticipantIdentity, + ParticipantName: req.ParticipantName, + ParticipantMetadata: req.ParticipantMetadata, + State: &livekit.IngressState{}, + Enabled: req.Enabled, + } + + switch req.InputType { + case livekit.IngressInput_RTMP_INPUT, + livekit.IngressInput_WHIP_INPUT: + info.Reusable = true + if err := ingress.ValidateForSerialization(info); err != nil { + return nil, err + } + case livekit.IngressInput_URL_INPUT: + if err := ingress.Validate(info); err != nil { + return nil, err + } + default: + return nil, ingress.ErrInvalidIngressType + } + + updateEnableTranscoding(info) + + if req.InputType == livekit.IngressInput_URL_INPUT { + retInfo, err := s.launcher.LaunchPullIngress(ctx, info) + if retInfo != nil { + info = retInfo + } else { + info.State.Status = livekit.IngressState_ENDPOINT_ERROR + info.State.Error = err.Error() + } + if err != nil { + return info, err + } + // The Ingress instance will create the ingress object when handling the URL pull ingress + } else { + _, err = s.io.CreateIngress(ctx, info) + switch err { + case nil: + break + case ingress.ErrIngressOutOfDate: + // Error returned if the ingress was already created by the ingress service + err = nil + default: + logger.Errorw("could not create ingress object", err) + return nil, err + } + } + + return info, nil +} + +func (s *IngressService) LaunchPullIngress(ctx context.Context, info *livekit.IngressInfo) (*livekit.IngressInfo, error) { + req := &rpc.StartIngressRequest{ + Info: info, + } + + return s.psrpcClient.StartIngress(ctx, req) +} + +func updateEnableTranscoding(info *livekit.IngressInfo) { + // Set BypassTranscoding as well for backward compatibility + if info.EnableTranscoding != nil { + info.BypassTranscoding = !*info.EnableTranscoding + return + } + + switch info.InputType { + case livekit.IngressInput_WHIP_INPUT: + f := false + info.EnableTranscoding = &f + info.BypassTranscoding = true + default: + t := true + info.EnableTranscoding = &t + } +} + +func updateInfoUsingRequest(req *livekit.UpdateIngressRequest, info *livekit.IngressInfo) error { + if req.Name != "" { + info.Name = req.Name + } + if req.RoomName != "" { + info.RoomName = req.RoomName + } + if req.ParticipantIdentity != "" { + info.ParticipantIdentity = req.ParticipantIdentity + } + if req.ParticipantName != "" { + info.ParticipantName = req.ParticipantName + } + if req.EnableTranscoding != nil { + info.EnableTranscoding = req.EnableTranscoding + } + + if req.ParticipantMetadata != "" { + info.ParticipantMetadata = req.ParticipantMetadata + } + if req.Audio != nil { + info.Audio = req.Audio + } + if req.Video != nil { + info.Video = req.Video + } + + if req.Enabled != nil { + info.Enabled = req.Enabled + } + + if err := ingress.ValidateForSerialization(info); err != nil { + return err + } + + updateEnableTranscoding(info) + + return nil +} + +func (s *IngressService) UpdateIngress(ctx context.Context, req *livekit.UpdateIngressRequest) (*livekit.IngressInfo, error) { + fields := []any{ + "ingress", req.IngressId, + "name", req.Name, + } + if req.RoomName != "" { + fields = append(fields, "room", req.RoomName, "identity", req.ParticipantIdentity) + } + AppendLogFields(ctx, fields...) + err := EnsureIngressAdminPermission(ctx) + if err != nil { + return nil, twirpAuthError(err) + } + + if s.psrpcClient == nil { + return nil, ErrIngressNotConnected + } + + info, err := s.store.LoadIngress(ctx, req.IngressId) + if err != nil { + logger.Errorw("could not load ingress info", err) + return nil, err + } + + if !info.Reusable { + logger.Infow("ingress update attempted on non reusable ingress", "ingressID", info.IngressId) + return info, ErrIngressNonReusable + } + + switch info.State.Status { + case livekit.IngressState_ENDPOINT_ERROR: + info.State.Status = livekit.IngressState_ENDPOINT_INACTIVE + _, err = s.io.UpdateIngressState(ctx, &rpc.UpdateIngressStateRequest{ + IngressId: req.IngressId, + State: info.State, + }) + if err != nil { + logger.Warnw("could not store ingress state", err) + } + fallthrough + + case livekit.IngressState_ENDPOINT_INACTIVE: + err = updateInfoUsingRequest(req, info) + if err != nil { + return nil, err + } + + case livekit.IngressState_ENDPOINT_BUFFERING, + livekit.IngressState_ENDPOINT_PUBLISHING: + err := updateInfoUsingRequest(req, info) + if err != nil { + return nil, err + } + + // Do not store the returned state as the ingress service will do it + if _, err = s.psrpcClient.UpdateIngress(ctx, req.IngressId, req); err != nil { + logger.Warnw("could not update active ingress", err) + } + } + + err = s.store.UpdateIngress(ctx, info) + if err != nil { + logger.Errorw("could not update ingress info", err) + return nil, err + } + + return info, nil +} + +func (s *IngressService) ListIngress(ctx context.Context, req *livekit.ListIngressRequest) (*livekit.ListIngressResponse, error) { + AppendLogFields(ctx, "room", req.RoomName) + err := EnsureIngressAdminPermission(ctx) + if err != nil { + return nil, twirpAuthError(err) + } + if s.store == nil { + return nil, ErrIngressNotConnected + } + + var infos []*livekit.IngressInfo + if req.IngressId != "" { + info, err := s.store.LoadIngress(ctx, req.IngressId) + if err != nil { + return nil, err + } + infos = []*livekit.IngressInfo{info} + } else { + infos, err = s.store.ListIngress(ctx, livekit.RoomName(req.RoomName)) + if err != nil { + logger.Errorw("could not list ingress info", err) + return nil, err + } + } + + return &livekit.ListIngressResponse{Items: infos}, nil +} + +func (s *IngressService) DeleteIngress(ctx context.Context, req *livekit.DeleteIngressRequest) (*livekit.IngressInfo, error) { + AppendLogFields(ctx, "ingressID", req.IngressId) + if err := EnsureIngressAdminPermission(ctx); err != nil { + return nil, twirpAuthError(err) + } + + if s.psrpcClient == nil { + return nil, ErrIngressNotConnected + } + + info, err := s.store.LoadIngress(ctx, req.IngressId) + if err != nil { + return nil, err + } + + switch info.State.Status { + case livekit.IngressState_ENDPOINT_BUFFERING, + livekit.IngressState_ENDPOINT_PUBLISHING: + if _, err = s.psrpcClient.DeleteIngress(ctx, req.IngressId, req); err != nil { + logger.Warnw("could not stop active ingress", err) + } + } + + err = s.store.DeleteIngress(ctx, info) + if err != nil { + logger.Errorw("could not delete ingress info", err) + return nil, err + } + + info.State.Status = livekit.IngressState_ENDPOINT_INACTIVE + + s.telemetry.IngressDeleted(ctx, info) + + return info, nil +} diff --git a/livekit/pkg/service/interfaces.go b/livekit/pkg/service/interfaces.go new file mode 100644 index 0000000..0dc80ca --- /dev/null +++ b/livekit/pkg/service/interfaces.go @@ -0,0 +1,115 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package service + +import ( + "context" + "time" + + "github.com/livekit/protocol/livekit" +) + +//go:generate go run github.com/maxbrunsfeld/counterfeiter/v6 -generate + +// encapsulates CRUD operations for room settings +// +//counterfeiter:generate . ObjectStore +type ObjectStore interface { + ServiceStore + OSSServiceStore + + // enable locking on a specific room to prevent race + // returns a (lock uuid, error) + LockRoom(ctx context.Context, roomName livekit.RoomName, duration time.Duration) (string, error) + UnlockRoom(ctx context.Context, roomName livekit.RoomName, uid string) error + + StoreRoom(ctx context.Context, room *livekit.Room, internal *livekit.RoomInternal) error + + StoreParticipant(ctx context.Context, roomName livekit.RoomName, participant *livekit.ParticipantInfo) error + DeleteParticipant(ctx context.Context, roomName livekit.RoomName, identity livekit.ParticipantIdentity) error +} + +//counterfeiter:generate . ServiceStore +type ServiceStore interface { + LoadRoom(ctx context.Context, roomName livekit.RoomName, includeInternal bool) (*livekit.Room, *livekit.RoomInternal, error) + RoomExists(ctx context.Context, roomName livekit.RoomName) (bool, error) + + // ListRooms returns currently active rooms. if names is not nil, it'll filter and return + // only rooms that match + ListRooms(ctx context.Context, roomNames []livekit.RoomName) ([]*livekit.Room, error) + LoadParticipant(ctx context.Context, roomName livekit.RoomName, identity livekit.ParticipantIdentity) (*livekit.ParticipantInfo, error) + ListParticipants(ctx context.Context, roomName livekit.RoomName) ([]*livekit.ParticipantInfo, error) +} + +type OSSServiceStore interface { + DeleteRoom(ctx context.Context, roomName livekit.RoomName) error + HasParticipant(context.Context, livekit.RoomName, livekit.ParticipantIdentity) (bool, error) +} + +//counterfeiter:generate . EgressStore +type EgressStore interface { + StoreEgress(ctx context.Context, info *livekit.EgressInfo) error + LoadEgress(ctx context.Context, egressID string) (*livekit.EgressInfo, error) + ListEgress(ctx context.Context, roomName livekit.RoomName, active bool) ([]*livekit.EgressInfo, error) + UpdateEgress(ctx context.Context, info *livekit.EgressInfo) error +} + +//counterfeiter:generate . IngressStore +type IngressStore interface { + StoreIngress(ctx context.Context, info *livekit.IngressInfo) error + LoadIngress(ctx context.Context, ingressID string) (*livekit.IngressInfo, error) + LoadIngressFromStreamKey(ctx context.Context, streamKey string) (*livekit.IngressInfo, error) + ListIngress(ctx context.Context, roomName livekit.RoomName) ([]*livekit.IngressInfo, error) + UpdateIngress(ctx context.Context, info *livekit.IngressInfo) error + UpdateIngressState(ctx context.Context, ingressId string, state *livekit.IngressState) error + DeleteIngress(ctx context.Context, info *livekit.IngressInfo) error +} + +//counterfeiter:generate . RoomAllocator +type RoomAllocator interface { + AutoCreateEnabled(ctx context.Context) bool + SelectRoomNode(ctx context.Context, roomName livekit.RoomName, nodeID livekit.NodeID) error + CreateRoom(ctx context.Context, req *livekit.CreateRoomRequest, isExplicit bool) (*livekit.Room, *livekit.RoomInternal, bool, error) + ValidateCreateRoom(ctx context.Context, roomName livekit.RoomName) error +} + +//counterfeiter:generate . SIPStore +type SIPStore interface { + StoreSIPTrunk(ctx context.Context, info *livekit.SIPTrunkInfo) error + StoreSIPInboundTrunk(ctx context.Context, info *livekit.SIPInboundTrunkInfo) error + StoreSIPOutboundTrunk(ctx context.Context, info *livekit.SIPOutboundTrunkInfo) error + LoadSIPTrunk(ctx context.Context, sipTrunkID string) (*livekit.SIPTrunkInfo, error) + LoadSIPInboundTrunk(ctx context.Context, sipTrunkID string) (*livekit.SIPInboundTrunkInfo, error) + LoadSIPOutboundTrunk(ctx context.Context, sipTrunkID string) (*livekit.SIPOutboundTrunkInfo, error) + ListSIPTrunk(ctx context.Context, opts *livekit.ListSIPTrunkRequest) (*livekit.ListSIPTrunkResponse, error) + ListSIPInboundTrunk(ctx context.Context, opts *livekit.ListSIPInboundTrunkRequest) (*livekit.ListSIPInboundTrunkResponse, error) + ListSIPOutboundTrunk(ctx context.Context, opts *livekit.ListSIPOutboundTrunkRequest) (*livekit.ListSIPOutboundTrunkResponse, error) + DeleteSIPTrunk(ctx context.Context, sipTrunkID string) error + + StoreSIPDispatchRule(ctx context.Context, info *livekit.SIPDispatchRuleInfo) error + LoadSIPDispatchRule(ctx context.Context, sipDispatchRuleID string) (*livekit.SIPDispatchRuleInfo, error) + ListSIPDispatchRule(ctx context.Context, opts *livekit.ListSIPDispatchRuleRequest) (*livekit.ListSIPDispatchRuleResponse, error) + DeleteSIPDispatchRule(ctx context.Context, sipDispatchRuleID string) error +} + +//counterfeiter:generate . AgentStore +type AgentStore interface { + StoreAgentDispatch(ctx context.Context, dispatch *livekit.AgentDispatch) error + DeleteAgentDispatch(ctx context.Context, dispatch *livekit.AgentDispatch) error + ListAgentDispatches(ctx context.Context, roomName livekit.RoomName) ([]*livekit.AgentDispatch, error) + + StoreAgentJob(ctx context.Context, job *livekit.Job) error + DeleteAgentJob(ctx context.Context, job *livekit.Job) error +} diff --git a/livekit/pkg/service/ioservice.go b/livekit/pkg/service/ioservice.go new file mode 100644 index 0000000..32d86d9 --- /dev/null +++ b/livekit/pkg/service/ioservice.go @@ -0,0 +1,194 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package service + +import ( + "context" + + "google.golang.org/protobuf/types/known/emptypb" + + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/rpc" + "github.com/livekit/psrpc" + "github.com/livekit/psrpc/pkg/middleware/otelpsrpc" + + "github.com/livekit/livekit-server/pkg/telemetry" +) + +type IOInfoService struct { + ioServer rpc.IOInfoServer + + es EgressStore + is IngressStore + ss SIPStore + telemetry telemetry.TelemetryService + + shutdown chan struct{} +} + +func NewIOInfoService( + bus psrpc.MessageBus, + es EgressStore, + is IngressStore, + ss SIPStore, + ts telemetry.TelemetryService, +) (*IOInfoService, error) { + s := &IOInfoService{ + es: es, + is: is, + ss: ss, + telemetry: ts, + shutdown: make(chan struct{}), + } + + if bus != nil { + ioServer, err := rpc.NewIOInfoServer(s, bus, + otelpsrpc.ServerOptions(otelpsrpc.Config{}), + ) + if err != nil { + return nil, err + } + s.ioServer = ioServer + } + + return s, nil +} + +func (s *IOInfoService) Start() error { + if s.es != nil { + rs := s.es.(*RedisStore) + err := rs.Start() + if err != nil { + logger.Errorw("failed to start redis egress worker", err) + return err + } + } + + return nil +} + +func (s *IOInfoService) Stop() { + close(s.shutdown) + + if s.ioServer != nil { + s.ioServer.Shutdown() + } +} + +func (s *IOInfoService) CreateEgress(ctx context.Context, info *livekit.EgressInfo) (*emptypb.Empty, error) { + if s.es == nil { + return nil, ErrEgressNotConnected + } + + // check if egress already exists to avoid duplicate EgressStarted event + if _, err := s.es.LoadEgress(ctx, info.EgressId); err == nil { + return &emptypb.Empty{}, nil + } + + err := s.es.StoreEgress(ctx, info) + if err != nil { + logger.Errorw("could not update egress", err) + return nil, err + } + + s.telemetry.EgressStarted(ctx, info) + + return &emptypb.Empty{}, nil +} + +func (s *IOInfoService) UpdateEgress(ctx context.Context, info *livekit.EgressInfo) (*emptypb.Empty, error) { + if s.es == nil { + return nil, ErrEgressNotConnected + } + + err := s.es.UpdateEgress(ctx, info) + + switch info.Status { + case livekit.EgressStatus_EGRESS_ACTIVE, + livekit.EgressStatus_EGRESS_ENDING: + s.telemetry.EgressUpdated(ctx, info) + + case livekit.EgressStatus_EGRESS_COMPLETE, + livekit.EgressStatus_EGRESS_FAILED, + livekit.EgressStatus_EGRESS_ABORTED, + livekit.EgressStatus_EGRESS_LIMIT_REACHED: + s.telemetry.EgressEnded(ctx, info) + } + + if err != nil { + logger.Errorw("could not update egress", err) + return nil, err + } + + return &emptypb.Empty{}, nil +} + +func (s *IOInfoService) GetEgress(ctx context.Context, req *rpc.GetEgressRequest) (*livekit.EgressInfo, error) { + if s.es == nil { + return nil, ErrEgressNotConnected + } + + info, err := s.es.LoadEgress(ctx, req.EgressId) + if err != nil { + logger.Errorw("failed to load egress", err) + return nil, err + } + + return info, nil +} + +func (s *IOInfoService) ListEgress(ctx context.Context, req *livekit.ListEgressRequest) (*livekit.ListEgressResponse, error) { + if s.es == nil { + return nil, ErrEgressNotConnected + } + + if req.EgressId != "" { + info, err := s.es.LoadEgress(ctx, req.EgressId) + if err != nil { + logger.Errorw("failed to load egress", err) + return nil, err + } + + return &livekit.ListEgressResponse{Items: []*livekit.EgressInfo{info}}, nil + } + + items, err := s.es.ListEgress(ctx, livekit.RoomName(req.RoomName), req.Active) + if err != nil { + logger.Errorw("failed to list egress", err) + return nil, err + } + + return &livekit.ListEgressResponse{Items: items}, nil +} + +func (s *IOInfoService) UpdateMetrics(ctx context.Context, req *rpc.UpdateMetricsRequest) (*emptypb.Empty, error) { + logger.Infow("received egress metrics", + "egressID", req.Info.EgressId, + "avgCpu", req.AvgCpuUsage, + "maxCpu", req.MaxCpuUsage, + ) + return &emptypb.Empty{}, nil +} + +func (s *IOInfoService) UpdateSIPCallState(ctx context.Context, req *rpc.UpdateSIPCallStateRequest) (*emptypb.Empty, error) { + // TODO: placeholder + return &emptypb.Empty{}, nil +} + +func (s *IOInfoService) RecordCallContext(context.Context, *rpc.RecordCallContextRequest) (*emptypb.Empty, error) { + // TODO: placeholder + return &emptypb.Empty{}, nil +} diff --git a/livekit/pkg/service/ioservice_ingress.go b/livekit/pkg/service/ioservice_ingress.go new file mode 100644 index 0000000..b802c88 --- /dev/null +++ b/livekit/pkg/service/ioservice_ingress.go @@ -0,0 +1,116 @@ +// Copyright 2024 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package service + +import ( + "context" + "errors" + + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/rpc" + "google.golang.org/protobuf/types/known/emptypb" +) + +func (s *IOInfoService) CreateIngress(ctx context.Context, info *livekit.IngressInfo) (*emptypb.Empty, error) { + if s.is == nil { + return nil, ErrIngressNotConnected + } + + err := s.is.StoreIngress(ctx, info) + if err != nil { + return nil, err + } + + s.telemetry.IngressCreated(ctx, info) + + return &emptypb.Empty{}, nil +} + +func (s *IOInfoService) GetIngressInfo(ctx context.Context, req *rpc.GetIngressInfoRequest) (*rpc.GetIngressInfoResponse, error) { + info, err := s.loadIngressFromInfoRequest(req) + if err != nil { + return nil, err + } + + return &rpc.GetIngressInfoResponse{Info: info}, nil +} + +func (s *IOInfoService) loadIngressFromInfoRequest(req *rpc.GetIngressInfoRequest) (info *livekit.IngressInfo, err error) { + if s.is == nil { + return nil, ErrIngressNotConnected + } + + if req.IngressId != "" { + info, err = s.is.LoadIngress(context.Background(), req.IngressId) + } else if req.StreamKey != "" { + info, err = s.is.LoadIngressFromStreamKey(context.Background(), req.StreamKey) + } else { + err = errors.New("request needs to specify either IngressId or StreamKey") + } + return info, err +} + +func (s *IOInfoService) UpdateIngressState(ctx context.Context, req *rpc.UpdateIngressStateRequest) (*emptypb.Empty, error) { + if s.is == nil { + return nil, ErrIngressNotConnected + } + + info, err := s.is.LoadIngress(ctx, req.IngressId) + if err != nil { + return nil, err + } + + if err = s.is.UpdateIngressState(ctx, req.IngressId, req.State); err != nil { + logger.Errorw("could not update ingress", err) + return nil, err + } + + if info.State.Status != req.State.Status { + info.State = req.State + + switch req.State.Status { + case livekit.IngressState_ENDPOINT_ERROR, + livekit.IngressState_ENDPOINT_INACTIVE, + livekit.IngressState_ENDPOINT_COMPLETE: + s.telemetry.IngressEnded(ctx, info) + + if req.State.Error != "" { + logger.Infow("ingress failed", "error", req.State.Error, "ingressID", req.IngressId) + } else { + logger.Infow("ingress ended", "ingressID", req.IngressId) + } + + case livekit.IngressState_ENDPOINT_PUBLISHING: + s.telemetry.IngressStarted(ctx, info) + + logger.Infow("ingress started", "ingressID", req.IngressId) + + case livekit.IngressState_ENDPOINT_BUFFERING: + s.telemetry.IngressUpdated(ctx, info) + + logger.Infow("ingress buffering", "ingressID", req.IngressId) + } + } else { + // Status didn't change, send Updated event + info.State = req.State + + s.telemetry.IngressUpdated(ctx, info) + + logger.Infow("ingress state updated", "ingressID", req.IngressId, "status", info.State.Status) + } + + return &emptypb.Empty{}, nil +} diff --git a/livekit/pkg/service/ioservice_sip.go b/livekit/pkg/service/ioservice_sip.go new file mode 100644 index 0000000..0140466 --- /dev/null +++ b/livekit/pkg/service/ioservice_sip.go @@ -0,0 +1,166 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package service + +import ( + "context" + "errors" + "net/netip" + + "github.com/dennwc/iters" + "github.com/twitchtv/twirp" + + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/rpc" + "github.com/livekit/protocol/sip" +) + +// matchSIPTrunk finds a SIP Trunk definition matching the request. +// Returns nil if no rules matched or an error if there are conflicting definitions. +func (s *IOInfoService) matchSIPTrunk(ctx context.Context, trunkID string, call *rpc.SIPCall) (*livekit.SIPInboundTrunkInfo, error) { + if s.ss == nil { + return nil, ErrSIPNotConnected + } + if trunkID != "" { + // This is a best-effort optimization. Fallthrough to listing trunks if it doesn't work. + if tr, err := s.ss.LoadSIPInboundTrunk(ctx, trunkID); err == nil { + tr, err = sip.MatchTrunkIter(iters.Slice([]*livekit.SIPInboundTrunkInfo{tr}), call) + if err == nil { + return tr, nil + } + } + } + it := s.SelectSIPInboundTrunk(ctx, call.To.User) + return sip.MatchTrunkIter(it, call) +} + +func (s *IOInfoService) SelectSIPInboundTrunk(ctx context.Context, called string) iters.Iter[*livekit.SIPInboundTrunkInfo] { + it := livekit.ListPageIter(s.ss.ListSIPInboundTrunk, &livekit.ListSIPInboundTrunkRequest{ + Numbers: []string{called}, + }) + return iters.PagesAsIter(ctx, it) +} + +// matchSIPDispatchRule finds the best dispatch rule matching the request parameters. Returns an error if no rule matched. +// Trunk parameter can be nil, in which case only wildcard dispatch rules will be effective (ones without Trunk IDs). +func (s *IOInfoService) matchSIPDispatchRule(ctx context.Context, trunk *livekit.SIPInboundTrunkInfo, req *rpc.EvaluateSIPDispatchRulesRequest) (*livekit.SIPDispatchRuleInfo, error) { + if s.ss == nil { + return nil, ErrSIPNotConnected + } + var trunkID string + if trunk != nil { + trunkID = trunk.SipTrunkId + } + // Trunk can still be nil here in case none matched or were defined. + // This is still fine, but only in case we'll match exactly one wildcard dispatch rule. + it := s.SelectSIPDispatchRule(ctx, trunkID) + return sip.MatchDispatchRuleIter(trunk, it, req) +} + +func (s *IOInfoService) SelectSIPDispatchRule(ctx context.Context, trunkID string) iters.Iter[*livekit.SIPDispatchRuleInfo] { + var trunkIDs []string + if trunkID != "" { + trunkIDs = []string{trunkID} + } + it := livekit.ListPageIter(s.ss.ListSIPDispatchRule, &livekit.ListSIPDispatchRuleRequest{ + TrunkIds: trunkIDs, + }) + return iters.PagesAsIter(ctx, it) +} + +func (s *IOInfoService) EvaluateSIPDispatchRules(ctx context.Context, req *rpc.EvaluateSIPDispatchRulesRequest) (*rpc.EvaluateSIPDispatchRulesResponse, error) { + call := req.SIPCall() + log := logger.GetLogger() + log = log.WithValues("toUser", call.To.User, "fromUser", call.From.User, "src", call.SourceIp) + if call.SourceIp == "" { + log.Warnw("source address is not set", nil) + // TODO: return error in the next release + } + _, err := netip.ParseAddr(call.SourceIp) + if call.SourceIp != "" && err != nil { + log.Errorw("cannot parse source IP", err) + return nil, twirp.WrapError(twirp.NewError(twirp.InvalidArgument, err.Error()), err) + } + trunk, err := s.matchSIPTrunk(ctx, req.SipTrunkId, call) + if err != nil { + return nil, err + } + trunkID := "" + if trunk != nil { + trunkID = trunk.SipTrunkId + } + log = log.WithValues("sipTrunk", trunkID) + if trunk != nil { + log.Debugw("SIP trunk matched") + } else { + log.Debugw("No SIP trunk matched") + } + best, err := s.matchSIPDispatchRule(ctx, trunk, req) + if err != nil { + if e := (*sip.ErrNoDispatchMatched)(nil); errors.As(err, &e) { + return &rpc.EvaluateSIPDispatchRulesResponse{ + SipTrunkId: trunkID, + Result: rpc.SIPDispatchResult_DROP, + }, nil + } + return nil, err + } + log.Debugw("SIP dispatch rule matched", "sipRule", best.SipDispatchRuleId) + resp, err := sip.EvaluateDispatchRule("", trunk, best, req) + if err != nil { + return nil, err + } + resp.SipTrunkId = trunkID + return resp, err +} + +func (s *IOInfoService) GetSIPTrunkAuthentication(ctx context.Context, req *rpc.GetSIPTrunkAuthenticationRequest) (*rpc.GetSIPTrunkAuthenticationResponse, error) { + call := req.SIPCall() + log := logger.GetLogger() + log = log.WithValues("toUser", call.To.User, "fromUser", call.From.User, "src", call.SourceIp) + if call.SourceIp == "" { + log.Warnw("source address is not set", nil) + // TODO: return error in the next release + } + _, err := netip.ParseAddr(call.SourceIp) + if call.SourceIp != "" && err != nil { + log.Errorw("cannot parse source IP", err) + return nil, twirp.WrapError(twirp.NewError(twirp.InvalidArgument, err.Error()), err) + } + trunk, err := s.matchSIPTrunk(ctx, "", call) + if err != nil { + return nil, err + } + if trunk == nil { + log.Debugw("No SIP trunk matched for auth", "sipTrunk", "") + return &rpc.GetSIPTrunkAuthenticationResponse{}, nil + } + log.Debugw("SIP trunk matched for auth", "sipTrunk", trunk.SipTrunkId) + + // Create provider info for the trunk + providerInfo := &livekit.ProviderInfo{ + Id: trunk.SipTrunkId, + Name: trunk.Name, + Type: livekit.ProviderType_PROVIDER_TYPE_EXTERNAL, // External trunk + } + + return &rpc.GetSIPTrunkAuthenticationResponse{ + SipTrunkId: trunk.SipTrunkId, + Username: trunk.AuthUsername, + Password: trunk.AuthPassword, + ProviderInfo: providerInfo, + }, nil +} diff --git a/livekit/pkg/service/ioservice_sip_test.go b/livekit/pkg/service/ioservice_sip_test.go new file mode 100644 index 0000000..2a8b752 --- /dev/null +++ b/livekit/pkg/service/ioservice_sip_test.go @@ -0,0 +1,122 @@ +// Copyright 2024 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package service_test + +import ( + "context" + "github.com/dennwc/iters" + "github.com/livekit/livekit-server/pkg/service" + "github.com/livekit/psrpc" + "slices" + "testing" + + "github.com/livekit/protocol/livekit" + "github.com/stretchr/testify/require" +) + +func ioStoreDocker(t testing.TB) (*service.IOInfoService, *service.RedisStore) { + r := redisClientDocker(t) + bus := psrpc.NewRedisMessageBus(r) + rs := service.NewRedisStore(r) + io, err := service.NewIOInfoService(bus, rs, rs, rs, nil) + require.NoError(t, err) + return io, rs +} + +func TestSIPTrunkSelect(t *testing.T) { + ctx := context.Background() + s, rs := ioStoreDocker(t) + + for _, tr := range []*livekit.SIPInboundTrunkInfo{ + {SipTrunkId: "any", Numbers: nil}, + {SipTrunkId: "B", Numbers: []string{"B1", "B2"}}, + {SipTrunkId: "BC", Numbers: []string{"B1", "C1"}}, + } { + err := rs.StoreSIPInboundTrunk(ctx, tr) + require.NoError(t, err) + } + + for _, tr := range []*livekit.SIPTrunkInfo{ + {SipTrunkId: "old-any", OutboundNumber: ""}, + {SipTrunkId: "old-A", OutboundNumber: "A"}, + } { + err := rs.StoreSIPTrunk(ctx, tr) + require.NoError(t, err) + } + + for _, c := range []struct { + number string + exp []string + }{ + {"A", []string{"old-A", "old-any", "any"}}, + {"B1", []string{"B", "BC", "old-any", "any"}}, + {"B2", []string{"B", "old-any", "any"}}, + {"C1", []string{"BC", "old-any", "any"}}, + {"wrong", []string{"old-any", "any"}}, + } { + t.Run(c.number, func(t *testing.T) { + it := s.SelectSIPInboundTrunk(ctx, c.number) + defer it.Close() + list, err := iters.All(it) + require.NoError(t, err) + var ids []string + for _, v := range list { + ids = append(ids, v.SipTrunkId) + } + slices.Sort(c.exp) + slices.Sort(ids) + require.Equal(t, c.exp, ids) + }) + } +} + +func TestSIPRuleSelect(t *testing.T) { + ctx := context.Background() + s, rs := ioStoreDocker(t) + + for _, r := range []*livekit.SIPDispatchRuleInfo{ + {SipDispatchRuleId: "any", TrunkIds: nil}, + {SipDispatchRuleId: "B", TrunkIds: []string{"B1", "B2"}}, + {SipDispatchRuleId: "BC", TrunkIds: []string{"B1", "C1"}}, + } { + err := rs.StoreSIPDispatchRule(ctx, r) + require.NoError(t, err) + } + + for _, c := range []struct { + trunk string + exp []string + }{ + {"A", []string{"any"}}, + {"B1", []string{"B", "BC", "any"}}, + {"B2", []string{"B", "any"}}, + {"C1", []string{"BC", "any"}}, + {"wrong", []string{"any"}}, + } { + t.Run(c.trunk, func(t *testing.T) { + it := s.SelectSIPDispatchRule(ctx, c.trunk) + defer it.Close() + list, err := iters.All(it) + require.NoError(t, err) + var ids []string + for _, v := range list { + ids = append(ids, v.SipDispatchRuleId) + } + slices.Sort(c.exp) + slices.Sort(ids) + require.Equal(t, c.exp, ids) + }) + } +} diff --git a/livekit/pkg/service/localstore.go b/livekit/pkg/service/localstore.go new file mode 100644 index 0000000..4479857 --- /dev/null +++ b/livekit/pkg/service/localstore.go @@ -0,0 +1,296 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package service + +import ( + "context" + "sync" + "time" + + "github.com/thoas/go-funk" + + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/utils" +) + +// encapsulates CRUD operations for room settings +type LocalStore struct { + // map of roomName => room + rooms map[livekit.RoomName]*livekit.Room + roomInternal map[livekit.RoomName]*livekit.RoomInternal + // map of roomName => { identity: participant } + participants map[livekit.RoomName]map[livekit.ParticipantIdentity]*livekit.ParticipantInfo + + agentDispatches map[livekit.RoomName]map[string]*livekit.AgentDispatch + agentJobs map[livekit.RoomName]map[string]*livekit.Job + + lock sync.RWMutex + globalLock sync.Mutex +} + +func NewLocalStore() *LocalStore { + return &LocalStore{ + rooms: make(map[livekit.RoomName]*livekit.Room), + roomInternal: make(map[livekit.RoomName]*livekit.RoomInternal), + participants: make(map[livekit.RoomName]map[livekit.ParticipantIdentity]*livekit.ParticipantInfo), + agentDispatches: make(map[livekit.RoomName]map[string]*livekit.AgentDispatch), + agentJobs: make(map[livekit.RoomName]map[string]*livekit.Job), + lock: sync.RWMutex{}, + } +} + +func (s *LocalStore) StoreRoom(_ context.Context, room *livekit.Room, internal *livekit.RoomInternal) error { + if room.CreationTime == 0 { + now := time.Now() + room.CreationTime = now.Unix() + room.CreationTimeMs = now.UnixMilli() + } + roomName := livekit.RoomName(room.Name) + + s.lock.Lock() + s.rooms[roomName] = room + s.roomInternal[roomName] = internal + s.lock.Unlock() + + return nil +} + +func (s *LocalStore) LoadRoom(_ context.Context, roomName livekit.RoomName, includeInternal bool) (*livekit.Room, *livekit.RoomInternal, error) { + s.lock.RLock() + defer s.lock.RUnlock() + + room := s.rooms[roomName] + if room == nil { + return nil, nil, ErrRoomNotFound + } + + var internal *livekit.RoomInternal + if includeInternal { + internal = s.roomInternal[roomName] + } + + return room, internal, nil +} + +func (s *LocalStore) RoomExists(ctx context.Context, roomName livekit.RoomName) (bool, error) { + _, _, err := s.LoadRoom(ctx, roomName, false) + if err == ErrRoomNotFound { + return false, nil + } else if err != nil { + return false, err + } + return true, nil +} + +func (s *LocalStore) ListRooms(_ context.Context, roomNames []livekit.RoomName) ([]*livekit.Room, error) { + s.lock.RLock() + defer s.lock.RUnlock() + rooms := make([]*livekit.Room, 0, len(s.rooms)) + for _, r := range s.rooms { + if roomNames == nil || funk.Contains(roomNames, livekit.RoomName(r.Name)) { + rooms = append(rooms, r) + } + } + return rooms, nil +} + +func (s *LocalStore) DeleteRoom(ctx context.Context, roomName livekit.RoomName) error { + room, _, err := s.LoadRoom(ctx, roomName, false) + if err == ErrRoomNotFound { + return nil + } else if err != nil { + return err + } + + s.lock.Lock() + defer s.lock.Unlock() + + delete(s.participants, livekit.RoomName(room.Name)) + delete(s.rooms, livekit.RoomName(room.Name)) + delete(s.roomInternal, livekit.RoomName(room.Name)) + delete(s.agentDispatches, livekit.RoomName(room.Name)) + delete(s.agentJobs, livekit.RoomName(room.Name)) + return nil +} + +func (s *LocalStore) LockRoom(_ context.Context, _ livekit.RoomName, _ time.Duration) (string, error) { + // local rooms lock & unlock globally + s.globalLock.Lock() + return "", nil +} + +func (s *LocalStore) UnlockRoom(_ context.Context, _ livekit.RoomName, _ string) error { + s.globalLock.Unlock() + return nil +} + +func (s *LocalStore) StoreParticipant(_ context.Context, roomName livekit.RoomName, participant *livekit.ParticipantInfo) error { + s.lock.Lock() + defer s.lock.Unlock() + roomParticipants := s.participants[roomName] + if roomParticipants == nil { + roomParticipants = make(map[livekit.ParticipantIdentity]*livekit.ParticipantInfo) + s.participants[roomName] = roomParticipants + } + roomParticipants[livekit.ParticipantIdentity(participant.Identity)] = participant + return nil +} + +func (s *LocalStore) LoadParticipant(_ context.Context, roomName livekit.RoomName, identity livekit.ParticipantIdentity) (*livekit.ParticipantInfo, error) { + s.lock.RLock() + defer s.lock.RUnlock() + + roomParticipants := s.participants[roomName] + if roomParticipants == nil { + return nil, ErrParticipantNotFound + } + participant := roomParticipants[identity] + if participant == nil { + return nil, ErrParticipantNotFound + } + return participant, nil +} + +func (s *LocalStore) HasParticipant(ctx context.Context, roomName livekit.RoomName, identity livekit.ParticipantIdentity) (bool, error) { + p, err := s.LoadParticipant(ctx, roomName, identity) + return p != nil, utils.ScreenError(err, ErrParticipantNotFound) +} + +func (s *LocalStore) ListParticipants(_ context.Context, roomName livekit.RoomName) ([]*livekit.ParticipantInfo, error) { + s.lock.RLock() + defer s.lock.RUnlock() + + roomParticipants := s.participants[roomName] + if roomParticipants == nil { + // empty array + return nil, nil + } + + items := make([]*livekit.ParticipantInfo, 0, len(roomParticipants)) + for _, p := range roomParticipants { + items = append(items, p) + } + + return items, nil +} + +func (s *LocalStore) DeleteParticipant(_ context.Context, roomName livekit.RoomName, identity livekit.ParticipantIdentity) error { + s.lock.Lock() + defer s.lock.Unlock() + + roomParticipants := s.participants[roomName] + if roomParticipants != nil { + delete(roomParticipants, identity) + } + return nil +} + +func (s *LocalStore) StoreAgentDispatch(ctx context.Context, dispatch *livekit.AgentDispatch) error { + s.lock.Lock() + defer s.lock.Unlock() + + clone := utils.CloneProto(dispatch) + if clone.State != nil { + clone.State.Jobs = nil + } + + roomDispatches := s.agentDispatches[livekit.RoomName(dispatch.Room)] + if roomDispatches == nil { + roomDispatches = make(map[string]*livekit.AgentDispatch) + s.agentDispatches[livekit.RoomName(dispatch.Room)] = roomDispatches + } + + roomDispatches[clone.Id] = clone + return nil +} + +func (s *LocalStore) DeleteAgentDispatch(ctx context.Context, dispatch *livekit.AgentDispatch) error { + s.lock.Lock() + defer s.lock.Unlock() + + roomDispatches := s.agentDispatches[livekit.RoomName(dispatch.Room)] + if roomDispatches != nil { + delete(roomDispatches, dispatch.Id) + } + + return nil +} + +func (s *LocalStore) ListAgentDispatches(ctx context.Context, roomName livekit.RoomName) ([]*livekit.AgentDispatch, error) { + s.lock.Lock() + defer s.lock.Unlock() + + agentDispatches := s.agentDispatches[roomName] + if agentDispatches == nil { + return nil, nil + } + agentJobs := s.agentJobs[roomName] + + var js []*livekit.Job + for _, j := range agentJobs { + js = append(js, utils.CloneProto(j)) + } + var ds []*livekit.AgentDispatch + + m := make(map[string]*livekit.AgentDispatch) + for _, d := range agentDispatches { + clone := utils.CloneProto(d) + m[d.Id] = clone + ds = append(ds, clone) + } + + for _, j := range js { + d := m[j.DispatchId] + if d != nil { + d.State.Jobs = append(d.State.Jobs, utils.CloneProto(j)) + } + } + + return ds, nil +} + +func (s *LocalStore) StoreAgentJob(ctx context.Context, job *livekit.Job) error { + s.lock.Lock() + defer s.lock.Unlock() + + clone := utils.CloneProto(job) + clone.Room = nil + if clone.Participant != nil { + clone.Participant = &livekit.ParticipantInfo{ + Identity: clone.Participant.Identity, + } + } + + roomJobs := s.agentJobs[livekit.RoomName(job.Room.Name)] + if roomJobs == nil { + roomJobs = make(map[string]*livekit.Job) + s.agentJobs[livekit.RoomName(job.Room.Name)] = roomJobs + } + roomJobs[clone.Id] = clone + + return nil +} + +func (s *LocalStore) DeleteAgentJob(ctx context.Context, job *livekit.Job) error { + s.lock.Lock() + defer s.lock.Unlock() + + roomJobs := s.agentJobs[livekit.RoomName(job.Room.Name)] + if roomJobs != nil { + delete(roomJobs, job.Id) + } + + return nil +} diff --git a/livekit/pkg/service/redisstore.go b/livekit/pkg/service/redisstore.go new file mode 100644 index 0000000..8d9387f --- /dev/null +++ b/livekit/pkg/service/redisstore.go @@ -0,0 +1,1093 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package service + +import ( + "context" + "fmt" + "slices" + "strconv" + "strings" + "time" + + goversion "github.com/hashicorp/go-version" + "github.com/pkg/errors" + "github.com/redis/go-redis/v9" + "google.golang.org/protobuf/proto" + + "github.com/livekit/protocol/ingress" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/utils" + "github.com/livekit/protocol/utils/guid" + "github.com/livekit/psrpc" + + "github.com/livekit/livekit-server/version" +) + +const ( + VersionKey = "livekit_version" + + // RoomsKey is hash of room_name => Room proto + RoomsKey = "rooms" + RoomInternalKey = "room_internal" + + // EgressKey is a hash of egressID => egress info + EgressKey = "egress" + EndedEgressKey = "ended_egress" + RoomEgressPrefix = "egress:room:" + + // IngressKey is a hash of ingressID => ingress info + IngressKey = "ingress" + StreamKeyKey = "{ingress}_stream_key" + IngressStatePrefix = "{ingress}_state:" + RoomIngressPrefix = "room_{ingress}:" + + // RoomParticipantsPrefix is hash of participant_name => ParticipantInfo + RoomParticipantsPrefix = "room_participants:" + + // RoomLockPrefix is a simple key containing a provided lock uid + RoomLockPrefix = "room_lock:" + + // Agents + AgentDispatchPrefix = "agent_dispatch:" + AgentJobPrefix = "agent_job:" + + maxRetries = 5 +) + +type RedisStore struct { + rc redis.UniversalClient + unlockScript *redis.Script + ctx context.Context + done chan struct{} +} + +func NewRedisStore(rc redis.UniversalClient) *RedisStore { + unlockScript := `if redis.call("get", KEYS[1]) == ARGV[1] then + return redis.call("del", KEYS[1]) + else return 0 + end` + + return &RedisStore{ + ctx: context.Background(), + rc: rc, + unlockScript: redis.NewScript(unlockScript), + } +} + +func (s *RedisStore) Start() error { + if s.done != nil { + return nil + } + + s.done = make(chan struct{}, 1) + + v, err := s.rc.Get(s.ctx, VersionKey).Result() + if err != nil && err != redis.Nil { + return err + } + if v == "" { + v = "0.0.0" + } + existing, _ := goversion.NewVersion(v) + current, _ := goversion.NewVersion(version.Version) + if current.GreaterThan(existing) { + if err = s.rc.Set(s.ctx, VersionKey, version.Version, 0).Err(); err != nil { + return err + } + } + + go s.egressWorker() + return nil +} + +func (s *RedisStore) Stop() { + select { + case <-s.done: + default: + close(s.done) + } +} + +func (s *RedisStore) StoreRoom(_ context.Context, room *livekit.Room, internal *livekit.RoomInternal) error { + if room.CreationTime == 0 { + now := time.Now() + room.CreationTime = now.Unix() + room.CreationTimeMs = now.UnixMilli() + } + + roomData, err := proto.Marshal(room) + if err != nil { + return err + } + + pp := s.rc.Pipeline() + pp.HSet(s.ctx, RoomsKey, room.Name, roomData) + + var internalData []byte + if internal != nil { + internalData, err = proto.Marshal(internal) + if err != nil { + return err + } + pp.HSet(s.ctx, RoomInternalKey, room.Name, internalData) + } else { + pp.HDel(s.ctx, RoomInternalKey, room.Name) + } + + if _, err = pp.Exec(s.ctx); err != nil { + return errors.Wrap(err, "could not create room") + } + return nil +} + +func (s *RedisStore) LoadRoom(_ context.Context, roomName livekit.RoomName, includeInternal bool) (*livekit.Room, *livekit.RoomInternal, error) { + pp := s.rc.Pipeline() + pp.HGet(s.ctx, RoomsKey, string(roomName)) + if includeInternal { + pp.HGet(s.ctx, RoomInternalKey, string(roomName)) + } + + res, err := pp.Exec(s.ctx) + if err != nil && err != redis.Nil { + // if the room exists but internal does not, the pipeline will still return redis.Nil + return nil, nil, err + } + + room := &livekit.Room{} + roomData, err := res[0].(*redis.StringCmd).Result() + if err != nil { + if err == redis.Nil { + err = ErrRoomNotFound + } + return nil, nil, err + } + if err = proto.Unmarshal([]byte(roomData), room); err != nil { + return nil, nil, err + } + + var internal *livekit.RoomInternal + if includeInternal { + internalData, err := res[1].(*redis.StringCmd).Result() + if err == nil { + internal = &livekit.RoomInternal{} + if err = proto.Unmarshal([]byte(internalData), internal); err != nil { + return nil, nil, err + } + } else if err != redis.Nil { + return nil, nil, err + } + } + + return room, internal, nil +} + +func (s *RedisStore) RoomExists(ctx context.Context, roomName livekit.RoomName) (bool, error) { + _, _, err := s.LoadRoom(ctx, roomName, false) + if err == ErrRoomNotFound { + return false, nil + } else if err != nil { + return false, err + } + return true, nil +} + +func (s *RedisStore) ListRooms(_ context.Context, roomNames []livekit.RoomName) ([]*livekit.Room, error) { + var items []string + var err error + if roomNames == nil { + items, err = s.rc.HVals(s.ctx, RoomsKey).Result() + if err != nil && err != redis.Nil { + return nil, errors.Wrap(err, "could not get rooms") + } + } else { + names := livekit.IDsAsStrings(roomNames) + var results []any + results, err = s.rc.HMGet(s.ctx, RoomsKey, names...).Result() + if err != nil && err != redis.Nil { + return nil, errors.Wrap(err, "could not get rooms by names") + } + for _, r := range results { + if item, ok := r.(string); ok { + items = append(items, item) + } + } + } + + rooms := make([]*livekit.Room, 0, len(items)) + + for _, item := range items { + room := livekit.Room{} + err := proto.Unmarshal([]byte(item), &room) + if err != nil { + return nil, err + } + rooms = append(rooms, &room) + } + return rooms, nil +} + +func (s *RedisStore) DeleteRoom(ctx context.Context, roomName livekit.RoomName) error { + _, _, err := s.LoadRoom(ctx, roomName, false) + if err == ErrRoomNotFound { + return nil + } + + pp := s.rc.Pipeline() + pp.HDel(s.ctx, RoomsKey, string(roomName)) + pp.HDel(s.ctx, RoomInternalKey, string(roomName)) + pp.Del(s.ctx, RoomParticipantsPrefix+string(roomName)) + pp.Del(s.ctx, AgentDispatchPrefix+string(roomName)) + pp.Del(s.ctx, AgentJobPrefix+string(roomName)) + + _, err = pp.Exec(s.ctx) + return err +} + +func (s *RedisStore) LockRoom(_ context.Context, roomName livekit.RoomName, duration time.Duration) (string, error) { + token := guid.New("LOCK") + key := RoomLockPrefix + string(roomName) + + startTime := time.Now() + for { + locked, err := s.rc.SetNX(s.ctx, key, token, duration).Result() + if err != nil { + return "", err + } + if locked { + return token, nil + } + + // stop waiting past lock duration + if time.Since(startTime) > duration { + break + } + + time.Sleep(100 * time.Millisecond) + } + + return "", ErrRoomLockFailed +} + +func (s *RedisStore) UnlockRoom(_ context.Context, roomName livekit.RoomName, uid string) error { + key := RoomLockPrefix + string(roomName) + res, err := s.unlockScript.Run(s.ctx, s.rc, []string{key}, uid).Result() + if err != nil { + return err + } + + // uid does not match + if i, ok := res.(int64); !ok || i != 1 { + return ErrRoomUnlockFailed + } + + return nil +} + +func (s *RedisStore) StoreParticipant(_ context.Context, roomName livekit.RoomName, participant *livekit.ParticipantInfo) error { + key := RoomParticipantsPrefix + string(roomName) + + data, err := proto.Marshal(participant) + if err != nil { + return err + } + + return s.rc.HSet(s.ctx, key, participant.Identity, data).Err() +} + +func (s *RedisStore) LoadParticipant(_ context.Context, roomName livekit.RoomName, identity livekit.ParticipantIdentity) (*livekit.ParticipantInfo, error) { + key := RoomParticipantsPrefix + string(roomName) + data, err := s.rc.HGet(s.ctx, key, string(identity)).Result() + if err == redis.Nil { + return nil, ErrParticipantNotFound + } else if err != nil { + return nil, err + } + + pi := livekit.ParticipantInfo{} + if err := proto.Unmarshal([]byte(data), &pi); err != nil { + return nil, err + } + return &pi, nil +} + +func (s *RedisStore) HasParticipant(ctx context.Context, roomName livekit.RoomName, identity livekit.ParticipantIdentity) (bool, error) { + p, err := s.LoadParticipant(ctx, roomName, identity) + return p != nil, utils.ScreenError(err, ErrParticipantNotFound) +} + +func (s *RedisStore) ListParticipants(_ context.Context, roomName livekit.RoomName) ([]*livekit.ParticipantInfo, error) { + key := RoomParticipantsPrefix + string(roomName) + items, err := s.rc.HVals(s.ctx, key).Result() + if err == redis.Nil { + return nil, nil + } else if err != nil { + return nil, err + } + + participants := make([]*livekit.ParticipantInfo, 0, len(items)) + for _, item := range items { + pi := livekit.ParticipantInfo{} + if err := proto.Unmarshal([]byte(item), &pi); err != nil { + return nil, err + } + participants = append(participants, &pi) + } + return participants, nil +} + +func (s *RedisStore) DeleteParticipant(_ context.Context, roomName livekit.RoomName, identity livekit.ParticipantIdentity) error { + key := RoomParticipantsPrefix + string(roomName) + + return s.rc.HDel(s.ctx, key, string(identity)).Err() +} + +func (s *RedisStore) StoreEgress(_ context.Context, info *livekit.EgressInfo) error { + data, err := proto.Marshal(info) + if err != nil { + return err + } + + pp := s.rc.Pipeline() + pp.HSet(s.ctx, EgressKey, info.EgressId, data) + pp.SAdd(s.ctx, RoomEgressPrefix+info.RoomName, info.EgressId) + if _, err = pp.Exec(s.ctx); err != nil { + return errors.Wrap(err, "could not store egress info") + } + + return nil +} + +func (s *RedisStore) LoadEgress(_ context.Context, egressID string) (*livekit.EgressInfo, error) { + data, err := s.rc.HGet(s.ctx, EgressKey, egressID).Result() + switch err { + case nil: + info := &livekit.EgressInfo{} + err = proto.Unmarshal([]byte(data), info) + if err != nil { + return nil, err + } + return info, nil + + case redis.Nil: + return nil, ErrEgressNotFound + + default: + return nil, err + } +} + +func (s *RedisStore) ListEgress(_ context.Context, roomName livekit.RoomName, active bool) ([]*livekit.EgressInfo, error) { + var infos []*livekit.EgressInfo + + if roomName == "" { + data, err := s.rc.HGetAll(s.ctx, EgressKey).Result() + if err != nil { + if err == redis.Nil { + return nil, nil + } + return nil, err + } + + for _, d := range data { + info := &livekit.EgressInfo{} + err = proto.Unmarshal([]byte(d), info) + if err != nil { + return nil, err + } + + // if active, filter status starting, active, and ending + if !active || int32(info.Status) < int32(livekit.EgressStatus_EGRESS_COMPLETE) { + infos = append(infos, info) + } + } + } else { + egressIDs, err := s.rc.SMembers(s.ctx, RoomEgressPrefix+string(roomName)).Result() + if err != nil { + if err == redis.Nil { + return nil, nil + } + return nil, err + } + + data, _ := s.rc.HMGet(s.ctx, EgressKey, egressIDs...).Result() + for _, d := range data { + if d == nil { + continue + } + info := &livekit.EgressInfo{} + err = proto.Unmarshal([]byte(d.(string)), info) + if err != nil { + return nil, err + } + + // if active, filter status starting, active, and ending + if !active || int32(info.Status) < int32(livekit.EgressStatus_EGRESS_COMPLETE) { + infos = append(infos, info) + } + } + } + + return infos, nil +} + +func (s *RedisStore) UpdateEgress(_ context.Context, info *livekit.EgressInfo) error { + data, err := proto.Marshal(info) + if err != nil { + return err + } + + if info.EndedAt != 0 { + pp := s.rc.Pipeline() + pp.HSet(s.ctx, EgressKey, info.EgressId, data) + pp.HSet(s.ctx, EndedEgressKey, info.EgressId, egressEndedValue(info.RoomName, info.EndedAt)) + _, err = pp.Exec(s.ctx) + } else { + err = s.rc.HSet(s.ctx, EgressKey, info.EgressId, data).Err() + } + + if err != nil { + return errors.Wrap(err, "could not update egress info") + } + + return nil +} + +// Deletes egress info 24h after the egress has ended +func (s *RedisStore) egressWorker() { + ticker := time.NewTicker(time.Minute * 30) + defer ticker.Stop() + + for { + select { + case <-s.done: + return + case <-ticker.C: + err := s.CleanEndedEgress() + if err != nil { + logger.Errorw("could not clean egress info", err) + } + } + } +} + +func (s *RedisStore) CleanEndedEgress() error { + values, err := s.rc.HGetAll(s.ctx, EndedEgressKey).Result() + if err != nil && err != redis.Nil { + return err + } + + expiry := time.Now().Add(-24 * time.Hour).UnixNano() + for egressID, val := range values { + roomName, endedAt, err := parseEgressEnded(val) + if err != nil { + return err + } + + if endedAt < expiry { + pp := s.rc.Pipeline() + pp.SRem(s.ctx, RoomEgressPrefix+roomName, egressID) + pp.HDel(s.ctx, EgressKey, egressID) + // Delete the EndedEgressKey entry last so that future sweeper runs get another chance to delete dangling data is the deletion partially failed. + pp.HDel(s.ctx, EndedEgressKey, egressID) + if _, err := pp.Exec(s.ctx); err != nil { + return err + } + } + } + + return nil +} + +func egressEndedValue(roomName string, endedAt int64) string { + return fmt.Sprintf("%s|%d", roomName, endedAt) +} + +func parseEgressEnded(value string) (roomName string, endedAt int64, err error) { + s := strings.Split(value, "|") + if len(s) != 2 { + err = errors.New("invalid egressEnded value") + return + } + + roomName = s[0] + endedAt, err = strconv.ParseInt(s[1], 10, 64) + return +} + +func (s *RedisStore) StoreIngress(ctx context.Context, info *livekit.IngressInfo) error { + err := s.storeIngress(ctx, info) + if err != nil { + return err + } + + return s.storeIngressState(ctx, info.IngressId, nil) +} + +func (s *RedisStore) storeIngress(_ context.Context, info *livekit.IngressInfo) error { + if info.IngressId == "" { + return errors.New("Missing IngressId") + } + if info.StreamKey == "" && info.InputType != livekit.IngressInput_URL_INPUT { + return errors.New("Missing StreamKey") + } + + // ignore state + infoCopy := utils.CloneProto(info) + infoCopy.State = nil + + data, err := proto.Marshal(infoCopy) + if err != nil { + return err + } + + // Use a "transaction" to remove the old room association if it changed + txf := func(tx *redis.Tx) error { + var oldRoom string + + oldInfo, err := s.loadIngress(tx, info.IngressId) + switch err { + case ErrIngressNotFound: + // Ingress doesn't exist yet + case nil: + oldRoom = oldInfo.RoomName + default: + return err + } + + results, err := tx.TxPipelined(s.ctx, func(p redis.Pipeliner) error { + p.HSet(s.ctx, IngressKey, info.IngressId, data) + if info.StreamKey != "" { + p.HSet(s.ctx, StreamKeyKey, info.StreamKey, info.IngressId) + } + + if oldRoom != info.RoomName { + if oldRoom != "" { + p.SRem(s.ctx, RoomIngressPrefix+oldRoom, info.IngressId) + } + if info.RoomName != "" { + p.SAdd(s.ctx, RoomIngressPrefix+info.RoomName, info.IngressId) + } + } + + return nil + }) + + if err != nil { + return err + } + + for _, res := range results { + if err := res.Err(); err != nil { + return err + } + } + + return nil + } + + // Retry if the key has been changed. + for range maxRetries { + err := s.rc.Watch(s.ctx, txf, IngressKey) + switch err { + case redis.TxFailedErr: + // Optimistic lock lost. Retry. + continue + default: + return err + } + } + + return nil +} + +func (s *RedisStore) storeIngressState(_ context.Context, ingressId string, state *livekit.IngressState) error { + if ingressId == "" { + return errors.New("Missing IngressId") + } + + if state == nil { + state = &livekit.IngressState{} + } + + data, err := proto.Marshal(state) + if err != nil { + return err + } + + // Use a "transaction" to remove the old room association if it changed + txf := func(tx *redis.Tx) error { + var oldStartedAt int64 + var oldUpdatedAt int64 + + oldState, err := s.loadIngressState(tx, ingressId) + switch err { + case ErrIngressNotFound: + // Ingress state doesn't exist yet + case nil: + oldStartedAt = oldState.StartedAt + oldUpdatedAt = oldState.UpdatedAt + default: + return err + } + + results, err := tx.TxPipelined(s.ctx, func(p redis.Pipeliner) error { + if state.StartedAt < oldStartedAt { + // Do not overwrite the info and state of a more recent session + return ingress.ErrIngressOutOfDate + } + + if state.StartedAt == oldStartedAt && state.UpdatedAt < oldUpdatedAt { + // Do not overwrite with an old state in case RPCs were delivered out of order. + // All RPCs come from the same ingress server and should thus be on the same clock. + return nil + } + + p.Set(s.ctx, IngressStatePrefix+ingressId, data, 0) + + return nil + }) + + if err != nil { + return err + } + + for _, res := range results { + if err := res.Err(); err != nil { + return err + } + } + + return nil + } + + // Retry if the key has been changed. + for range maxRetries { + err := s.rc.Watch(s.ctx, txf, IngressStatePrefix+ingressId) + switch err { + case redis.TxFailedErr: + // Optimistic lock lost. Retry. + continue + default: + return err + } + } + + return nil +} + +func (s *RedisStore) loadIngress(c redis.Cmdable, ingressId string) (*livekit.IngressInfo, error) { + data, err := c.HGet(s.ctx, IngressKey, ingressId).Result() + switch err { + case nil: + info := &livekit.IngressInfo{} + err = proto.Unmarshal([]byte(data), info) + if err != nil { + return nil, err + } + return info, nil + + case redis.Nil: + return nil, ErrIngressNotFound + + default: + return nil, err + } +} + +func (s *RedisStore) loadIngressState(c redis.Cmdable, ingressId string) (*livekit.IngressState, error) { + data, err := c.Get(s.ctx, IngressStatePrefix+ingressId).Result() + switch err { + case nil: + state := &livekit.IngressState{} + err = proto.Unmarshal([]byte(data), state) + if err != nil { + return nil, err + } + return state, nil + + case redis.Nil: + return nil, ErrIngressNotFound + + default: + return nil, err + } +} + +func (s *RedisStore) LoadIngress(_ context.Context, ingressId string) (*livekit.IngressInfo, error) { + info, err := s.loadIngress(s.rc, ingressId) + if err != nil { + return nil, err + } + state, err := s.loadIngressState(s.rc, ingressId) + switch err { + case nil: + info.State = state + case ErrIngressNotFound: + // No state for this ingress + default: + return nil, err + } + + return info, nil +} + +func (s *RedisStore) LoadIngressFromStreamKey(_ context.Context, streamKey string) (*livekit.IngressInfo, error) { + ingressID, err := s.rc.HGet(s.ctx, StreamKeyKey, streamKey).Result() + switch err { + case nil: + return s.LoadIngress(s.ctx, ingressID) + + case redis.Nil: + return nil, ErrIngressNotFound + + default: + return nil, err + } +} + +func (s *RedisStore) ListIngress(_ context.Context, roomName livekit.RoomName) ([]*livekit.IngressInfo, error) { + var infos []*livekit.IngressInfo + + if roomName == "" { + data, err := s.rc.HGetAll(s.ctx, IngressKey).Result() + if err != nil { + if err == redis.Nil { + return nil, nil + } + return nil, err + } + + for _, d := range data { + info := &livekit.IngressInfo{} + err = proto.Unmarshal([]byte(d), info) + if err != nil { + return nil, err + } + state, err := s.loadIngressState(s.rc, info.IngressId) + switch err { + case nil: + info.State = state + case ErrIngressNotFound: + // No state for this ingress + default: + return nil, err + } + + infos = append(infos, info) + } + } else { + ingressIDs, err := s.rc.SMembers(s.ctx, RoomIngressPrefix+string(roomName)).Result() + if err != nil { + if err == redis.Nil { + return nil, nil + } + return nil, err + } + + data, _ := s.rc.HMGet(s.ctx, IngressKey, ingressIDs...).Result() + for _, d := range data { + if d == nil { + continue + } + info := &livekit.IngressInfo{} + err = proto.Unmarshal([]byte(d.(string)), info) + if err != nil { + return nil, err + } + state, err := s.loadIngressState(s.rc, info.IngressId) + switch err { + case nil: + info.State = state + case ErrIngressNotFound: + // No state for this ingress + default: + return nil, err + } + + infos = append(infos, info) + } + } + + return infos, nil +} + +func (s *RedisStore) UpdateIngress(ctx context.Context, info *livekit.IngressInfo) error { + return s.storeIngress(ctx, info) +} + +func (s *RedisStore) UpdateIngressState(ctx context.Context, ingressId string, state *livekit.IngressState) error { + return s.storeIngressState(ctx, ingressId, state) +} + +func (s *RedisStore) DeleteIngress(_ context.Context, info *livekit.IngressInfo) error { + tx := s.rc.TxPipeline() + tx.SRem(s.ctx, RoomIngressPrefix+info.RoomName, info.IngressId) + if info.StreamKey != "" { + tx.HDel(s.ctx, StreamKeyKey, info.StreamKey) + } + tx.HDel(s.ctx, IngressKey, info.IngressId) + tx.Del(s.ctx, IngressStatePrefix+info.IngressId) + if _, err := tx.Exec(s.ctx); err != nil { + return errors.Wrap(err, "could not delete ingress info") + } + + return nil +} + +func (s *RedisStore) StoreAgentDispatch(_ context.Context, dispatch *livekit.AgentDispatch) error { + di := utils.CloneProto(dispatch) + + // Do not store jobs with the dispatch + if di.State != nil { + di.State.Jobs = nil + } + + key := AgentDispatchPrefix + string(dispatch.Room) + + data, err := proto.Marshal(di) + if err != nil { + return err + } + + return s.rc.HSet(s.ctx, key, di.Id, data).Err() +} + +// This will not delete the jobs created by the dispatch +func (s *RedisStore) DeleteAgentDispatch(_ context.Context, dispatch *livekit.AgentDispatch) error { + key := AgentDispatchPrefix + string(dispatch.Room) + return s.rc.HDel(s.ctx, key, dispatch.Id).Err() +} + +func (s *RedisStore) ListAgentDispatches(_ context.Context, roomName livekit.RoomName) ([]*livekit.AgentDispatch, error) { + key := AgentDispatchPrefix + string(roomName) + dispatches, err := redisLoadAll[livekit.AgentDispatch](s.ctx, s, key) + if err != nil { + return nil, err + } + + dMap := make(map[string]*livekit.AgentDispatch) + for _, di := range dispatches { + dMap[di.Id] = di + } + + key = AgentJobPrefix + string(roomName) + jobs, err := redisLoadAll[livekit.Job](s.ctx, s, key) + if err != nil { + return nil, err + } + + // Associate job to dispatch + for _, jb := range jobs { + di := dMap[jb.DispatchId] + if di == nil { + continue + } + if di.State == nil { + di.State = &livekit.AgentDispatchState{} + } + di.State.Jobs = append(di.State.Jobs, jb) + } + + return dispatches, nil +} + +func (s *RedisStore) StoreAgentJob(_ context.Context, job *livekit.Job) error { + if job.Room == nil { + return psrpc.NewErrorf(psrpc.InvalidArgument, "job doesn't have a valid Room field") + } + + key := AgentJobPrefix + string(job.Room.Name) + + jb := utils.CloneProto(job) + + // Do not store room with the job + jb.Room = nil + + // Only store the participant identity + if jb.Participant != nil { + jb.Participant = &livekit.ParticipantInfo{ + Identity: jb.Participant.Identity, + } + } + + data, err := proto.Marshal(jb) + if err != nil { + return err + } + + return s.rc.HSet(s.ctx, key, job.Id, data).Err() +} + +func (s *RedisStore) DeleteAgentJob(_ context.Context, job *livekit.Job) error { + if job.Room == nil { + return psrpc.NewErrorf(psrpc.InvalidArgument, "job doesn't have a valid Room field") + } + + key := AgentJobPrefix + string(job.Room.Name) + return s.rc.HDel(s.ctx, key, job.Id).Err() +} + +func redisStoreOne(ctx context.Context, s *RedisStore, key, id string, p proto.Message) error { + if id == "" { + return errors.New("id is not set") + } + data, err := proto.Marshal(p) + if err != nil { + return err + } + return s.rc.HSet(s.ctx, key, id, data).Err() +} + +type protoMsg[T any] interface { + *T + proto.Message +} + +func redisLoadOne[T any, P protoMsg[T]](ctx context.Context, s *RedisStore, key, id string, notFoundErr error) (P, error) { + data, err := s.rc.HGet(s.ctx, key, id).Result() + if err == redis.Nil { + return nil, notFoundErr + } else if err != nil { + return nil, err + } + var p P = new(T) + err = proto.Unmarshal([]byte(data), p) + if err != nil { + return nil, err + } + return p, err +} + +func redisLoadAll[T any, P protoMsg[T]](ctx context.Context, s *RedisStore, key string) ([]P, error) { + data, err := s.rc.HVals(s.ctx, key).Result() + if err == redis.Nil { + return nil, nil + } else if err != nil { + return nil, err + } + + list := make([]P, 0, len(data)) + for _, d := range data { + var p P = new(T) + if err = proto.Unmarshal([]byte(d), p); err != nil { + return list, err + } + list = append(list, p) + } + + return list, nil +} + +func redisLoadBatch[T any, P protoMsg[T]](ctx context.Context, s *RedisStore, key string, ids []string, keepEmpty bool) ([]P, error) { + data, err := s.rc.HMGet(s.ctx, key, ids...).Result() + if err == redis.Nil { + if keepEmpty { + return make([]P, len(ids)), nil + } + return nil, nil + } else if err != nil { + return nil, err + } + if !keepEmpty { + list := make([]P, 0, len(data)) + for _, v := range data { + if d, ok := v.(string); ok { + var p P = new(T) + if err = proto.Unmarshal([]byte(d), p); err != nil { + return list, err + } + list = append(list, p) + } + } + return list, nil + } + // Keep zero values where ID was not found. + list := make([]P, len(ids)) + for i := range ids { + if d, ok := data[i].(string); ok { + var p P = new(T) + if err = proto.Unmarshal([]byte(d), p); err != nil { + return list, err + } + list[i] = p + } + } + return list, nil +} + +func redisIDs(ctx context.Context, s *RedisStore, key string) ([]string, error) { + list, err := s.rc.HKeys(s.ctx, key).Result() + if err == redis.Nil { + return nil, nil + } else if err != nil { + return nil, err + } + slices.Sort(list) + return list, nil +} + +type protoEntity[T any] interface { + protoMsg[T] + ID() string +} + +func redisIterPage[T any, P protoEntity[T]](ctx context.Context, s *RedisStore, key string, page *livekit.Pagination) ([]P, error) { + if page == nil { + return redisLoadAll[T, P](ctx, s, key) + } + ids, err := redisIDs(ctx, s, key) + if err != nil { + return nil, err + } + if len(ids) == 0 { + return nil, nil + } + if page.AfterId != "" { + i, ok := slices.BinarySearch(ids, page.AfterId) + if ok { + i++ + } + ids = ids[i:] + if len(ids) == 0 { + return nil, nil + } + } + limit := 1000 + if page.Limit > 0 { + limit = int(page.Limit) + } + if len(ids) > limit { + ids = ids[:limit] + } + return redisLoadBatch[T, P](ctx, s, key, ids, false) +} + +func sortProtos[T any, P protoEntity[T]](arr []P) { + slices.SortFunc(arr, func(a, b P) int { + return strings.Compare(a.ID(), b.ID()) + }) +} + +func sortPage[T any, P protoEntity[T]](items []P, page *livekit.Pagination) []P { + sortProtos(items) + if page != nil { + if limit := int(page.Limit); limit > 0 && len(items) > limit { + items = items[:limit] + } + } + return items +} diff --git a/livekit/pkg/service/redisstore_sip.go b/livekit/pkg/service/redisstore_sip.go new file mode 100644 index 0000000..ba2931a --- /dev/null +++ b/livekit/pkg/service/redisstore_sip.go @@ -0,0 +1,254 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package service + +import ( + "context" + + "github.com/livekit/protocol/livekit" +) + +const ( + SIPTrunkKey = "sip_trunk" + SIPInboundTrunkKey = "sip_inbound_trunk" + SIPOutboundTrunkKey = "sip_outbound_trunk" + SIPDispatchRuleKey = "sip_dispatch_rule" +) + +func (s *RedisStore) StoreSIPTrunk(ctx context.Context, info *livekit.SIPTrunkInfo) error { + return redisStoreOne(s.ctx, s, SIPTrunkKey, info.SipTrunkId, info) +} + +func (s *RedisStore) StoreSIPInboundTrunk(ctx context.Context, info *livekit.SIPInboundTrunkInfo) error { + return redisStoreOne(s.ctx, s, SIPInboundTrunkKey, info.SipTrunkId, info) +} + +func (s *RedisStore) StoreSIPOutboundTrunk(ctx context.Context, info *livekit.SIPOutboundTrunkInfo) error { + return redisStoreOne(s.ctx, s, SIPOutboundTrunkKey, info.SipTrunkId, info) +} + +func (s *RedisStore) loadSIPLegacyTrunk(ctx context.Context, id string) (*livekit.SIPTrunkInfo, error) { + return redisLoadOne[livekit.SIPTrunkInfo](ctx, s, SIPTrunkKey, id, ErrSIPTrunkNotFound) +} + +func (s *RedisStore) loadSIPInboundTrunk(ctx context.Context, id string) (*livekit.SIPInboundTrunkInfo, error) { + return redisLoadOne[livekit.SIPInboundTrunkInfo](ctx, s, SIPInboundTrunkKey, id, ErrSIPTrunkNotFound) +} + +func (s *RedisStore) loadSIPOutboundTrunk(ctx context.Context, id string) (*livekit.SIPOutboundTrunkInfo, error) { + return redisLoadOne[livekit.SIPOutboundTrunkInfo](ctx, s, SIPOutboundTrunkKey, id, ErrSIPTrunkNotFound) +} + +func (s *RedisStore) LoadSIPTrunk(ctx context.Context, id string) (*livekit.SIPTrunkInfo, error) { + tr, err := s.loadSIPLegacyTrunk(ctx, id) + if err == nil { + return tr, nil + } else if err != ErrSIPTrunkNotFound { + return nil, err + } + in, err := s.loadSIPInboundTrunk(ctx, id) + if err == nil { + return in.AsTrunkInfo(), nil + } else if err != ErrSIPTrunkNotFound { + return nil, err + } + out, err := s.loadSIPOutboundTrunk(ctx, id) + if err == nil { + return out.AsTrunkInfo(), nil + } else if err != ErrSIPTrunkNotFound { + return nil, err + } + return nil, ErrSIPTrunkNotFound +} + +func (s *RedisStore) LoadSIPInboundTrunk(ctx context.Context, id string) (*livekit.SIPInboundTrunkInfo, error) { + in, err := s.loadSIPInboundTrunk(ctx, id) + if err == nil { + return in, nil + } else if err != ErrSIPTrunkNotFound { + return nil, err + } + tr, err := s.loadSIPLegacyTrunk(ctx, id) + if err == nil { + return tr.AsInbound(), nil + } else if err != ErrSIPTrunkNotFound { + return nil, err + } + return nil, ErrSIPTrunkNotFound +} + +func (s *RedisStore) LoadSIPOutboundTrunk(ctx context.Context, id string) (*livekit.SIPOutboundTrunkInfo, error) { + in, err := s.loadSIPOutboundTrunk(ctx, id) + if err == nil { + return in, nil + } else if err != ErrSIPTrunkNotFound { + return nil, err + } + tr, err := s.loadSIPLegacyTrunk(ctx, id) + if err == nil { + return tr.AsOutbound(), nil + } else if err != ErrSIPTrunkNotFound { + return nil, err + } + return nil, ErrSIPTrunkNotFound +} + +func (s *RedisStore) DeleteSIPTrunk(ctx context.Context, id string) error { + err1 := s.rc.HDel(s.ctx, SIPTrunkKey, id).Err() + err2 := s.rc.HDel(s.ctx, SIPInboundTrunkKey, id).Err() + err3 := s.rc.HDel(s.ctx, SIPOutboundTrunkKey, id).Err() + if err1 != nil { + return err1 + } + if err2 != nil { + return err2 + } + if err3 != nil { + return err3 + } + return nil +} + +func (s *RedisStore) listSIPLegacyTrunk(ctx context.Context, page *livekit.Pagination) ([]*livekit.SIPTrunkInfo, error) { + return redisIterPage[livekit.SIPTrunkInfo](ctx, s, SIPTrunkKey, page) +} + +func (s *RedisStore) listSIPInboundTrunk(ctx context.Context, page *livekit.Pagination) ([]*livekit.SIPInboundTrunkInfo, error) { + return redisIterPage[livekit.SIPInboundTrunkInfo](ctx, s, SIPInboundTrunkKey, page) +} + +func (s *RedisStore) listSIPOutboundTrunk(ctx context.Context, page *livekit.Pagination) ([]*livekit.SIPOutboundTrunkInfo, error) { + return redisIterPage[livekit.SIPOutboundTrunkInfo](ctx, s, SIPOutboundTrunkKey, page) +} + +func (s *RedisStore) listSIPDispatchRule(ctx context.Context, page *livekit.Pagination) ([]*livekit.SIPDispatchRuleInfo, error) { + return redisIterPage[livekit.SIPDispatchRuleInfo](ctx, s, SIPDispatchRuleKey, page) +} + +func (s *RedisStore) ListSIPTrunk(ctx context.Context, req *livekit.ListSIPTrunkRequest) (*livekit.ListSIPTrunkResponse, error) { + var items []*livekit.SIPTrunkInfo + old, err := s.listSIPLegacyTrunk(ctx, req.Page) + if err != nil { + return nil, err + } + for _, t := range old { + v := t + if req.Filter(v) && req.Page.Filter(v) { + items = append(items, v) + } + } + in, err := s.listSIPInboundTrunk(ctx, req.Page) + if err != nil { + return nil, err + } + for _, t := range in { + v := t.AsTrunkInfo() + if req.Filter(v) && req.Page.Filter(v) { + items = append(items, v) + } + } + out, err := s.listSIPOutboundTrunk(ctx, req.Page) + if err != nil { + return nil, err + } + for _, t := range out { + v := t.AsTrunkInfo() + if req.Filter(v) && req.Page.Filter(v) { + items = append(items, v) + } + } + items = sortPage(items, req.Page) + return &livekit.ListSIPTrunkResponse{Items: items}, nil +} + +func (s *RedisStore) ListSIPInboundTrunk(ctx context.Context, req *livekit.ListSIPInboundTrunkRequest) (*livekit.ListSIPInboundTrunkResponse, error) { + var items []*livekit.SIPInboundTrunkInfo + in, err := s.listSIPInboundTrunk(ctx, req.Page) + if err != nil { + return nil, err + } + for _, t := range in { + v := t + if req.Filter(v) && req.Page.Filter(v) { + items = append(items, v) + } + } + old, err := s.listSIPLegacyTrunk(ctx, req.Page) + if err != nil { + return nil, err + } + for _, t := range old { + v := t.AsInbound() + if req.Filter(v) && req.Page.Filter(v) { + items = append(items, v) + } + } + items = sortPage(items, req.Page) + return &livekit.ListSIPInboundTrunkResponse{Items: items}, nil +} + +func (s *RedisStore) ListSIPOutboundTrunk(ctx context.Context, req *livekit.ListSIPOutboundTrunkRequest) (*livekit.ListSIPOutboundTrunkResponse, error) { + var items []*livekit.SIPOutboundTrunkInfo + out, err := s.listSIPOutboundTrunk(ctx, req.Page) + if err != nil { + return nil, err + } + for _, t := range out { + v := t + if req.Filter(v) && req.Page.Filter(v) { + items = append(items, v) + } + } + old, err := s.listSIPLegacyTrunk(ctx, req.Page) + if err != nil { + return nil, err + } + for _, t := range old { + v := t.AsOutbound() + if req.Filter(v) && req.Page.Filter(v) { + items = append(items, v) + } + } + items = sortPage(items, req.Page) + return &livekit.ListSIPOutboundTrunkResponse{Items: items}, nil +} + +func (s *RedisStore) StoreSIPDispatchRule(ctx context.Context, info *livekit.SIPDispatchRuleInfo) error { + return redisStoreOne(ctx, s, SIPDispatchRuleKey, info.SipDispatchRuleId, info) +} + +func (s *RedisStore) LoadSIPDispatchRule(ctx context.Context, sipDispatchRuleId string) (*livekit.SIPDispatchRuleInfo, error) { + return redisLoadOne[livekit.SIPDispatchRuleInfo](ctx, s, SIPDispatchRuleKey, sipDispatchRuleId, ErrSIPDispatchRuleNotFound) +} + +func (s *RedisStore) DeleteSIPDispatchRule(ctx context.Context, sipDispatchRuleId string) error { + return s.rc.HDel(s.ctx, SIPDispatchRuleKey, sipDispatchRuleId).Err() +} + +func (s *RedisStore) ListSIPDispatchRule(ctx context.Context, req *livekit.ListSIPDispatchRuleRequest) (*livekit.ListSIPDispatchRuleResponse, error) { + var items []*livekit.SIPDispatchRuleInfo + out, err := s.listSIPDispatchRule(ctx, req.Page) + if err != nil { + return nil, err + } + for _, t := range out { + v := t + if req.Filter(v) && req.Page.Filter(v) { + items = append(items, v) + } + } + items = sortPage(items, req.Page) + return &livekit.ListSIPDispatchRuleResponse{Items: items}, nil +} diff --git a/livekit/pkg/service/redisstore_sip_test.go b/livekit/pkg/service/redisstore_sip_test.go new file mode 100644 index 0000000..9ca1236 --- /dev/null +++ b/livekit/pkg/service/redisstore_sip_test.go @@ -0,0 +1,357 @@ +// Copyright 2024 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package service_test + +import ( + "context" + "fmt" + "slices" + "strings" + "testing" + + "github.com/dennwc/iters" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/utils" + "github.com/livekit/protocol/utils/guid" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/proto" + + "github.com/livekit/livekit-server/pkg/service" +) + +func TestSIPStoreDispatch(t *testing.T) { + ctx := context.Background() + rs := redisStoreDocker(t) + + id := guid.New(utils.SIPDispatchRulePrefix) + + // No dispatch rules initially. + list, err := rs.ListSIPDispatchRule(ctx, &livekit.ListSIPDispatchRuleRequest{}) + require.NoError(t, err) + require.Empty(t, list.Items) + + // Loading non-existent dispatch should return proper not found error. + got, err := rs.LoadSIPDispatchRule(ctx, id) + require.Equal(t, service.ErrSIPDispatchRuleNotFound, err) + require.Nil(t, got) + + // Creation without ID should fail. + rule := &livekit.SIPDispatchRuleInfo{ + TrunkIds: []string{"trunk"}, + Rule: &livekit.SIPDispatchRule{Rule: &livekit.SIPDispatchRule_DispatchRuleDirect{ + DispatchRuleDirect: &livekit.SIPDispatchRuleDirect{ + RoomName: "room", + Pin: "1234", + }, + }}, + } + err = rs.StoreSIPDispatchRule(ctx, rule) + require.Error(t, err) + + // Creation + rule.SipDispatchRuleId = id + err = rs.StoreSIPDispatchRule(ctx, rule) + require.NoError(t, err) + + // Loading + got, err = rs.LoadSIPDispatchRule(ctx, id) + require.NoError(t, err) + require.True(t, proto.Equal(rule, got)) + + // Listing + list, err = rs.ListSIPDispatchRule(ctx, &livekit.ListSIPDispatchRuleRequest{}) + require.NoError(t, err) + require.Len(t, list.Items, 1) + require.True(t, proto.Equal(rule, list.Items[0])) + + // Deletion. Should not return error if not exists. + err = rs.DeleteSIPDispatchRule(ctx, id) + require.NoError(t, err) + err = rs.DeleteSIPDispatchRule(ctx, id) + require.NoError(t, err) + + // Check that it's deleted. + list, err = rs.ListSIPDispatchRule(ctx, &livekit.ListSIPDispatchRuleRequest{}) + require.NoError(t, err) + require.Empty(t, list.Items) + + got, err = rs.LoadSIPDispatchRule(ctx, id) + require.Equal(t, service.ErrSIPDispatchRuleNotFound, err) + require.Nil(t, got) +} + +func TestSIPStoreTrunk(t *testing.T) { + ctx := context.Background() + rs := redisStoreDocker(t) + + oldID := guid.New(utils.SIPTrunkPrefix) + inID := guid.New(utils.SIPTrunkPrefix) + outID := guid.New(utils.SIPTrunkPrefix) + + // No trunks initially. Check legacy, inbound, outbound. + // Loading non-existent trunk should return proper not found error. + oldList, err := rs.ListSIPTrunk(ctx, &livekit.ListSIPTrunkRequest{}) + require.NoError(t, err) + require.Empty(t, oldList.Items) + + old, err := rs.LoadSIPTrunk(ctx, oldID) + require.Equal(t, service.ErrSIPTrunkNotFound, err) + require.Nil(t, old) + + inList, err := rs.ListSIPInboundTrunk(ctx, &livekit.ListSIPInboundTrunkRequest{}) + require.NoError(t, err) + require.Empty(t, inList.Items) + + in, err := rs.LoadSIPInboundTrunk(ctx, oldID) + require.Equal(t, service.ErrSIPTrunkNotFound, err) + require.Nil(t, in) + + outList, err := rs.ListSIPOutboundTrunk(ctx, &livekit.ListSIPOutboundTrunkRequest{}) + require.NoError(t, err) + require.Empty(t, outList.Items) + + out, err := rs.LoadSIPOutboundTrunk(ctx, oldID) + require.Equal(t, service.ErrSIPTrunkNotFound, err) + require.Nil(t, out) + + // Creation without ID should fail. + oldT := &livekit.SIPTrunkInfo{ + Name: "Legacy", + } + err = rs.StoreSIPTrunk(ctx, oldT) + require.Error(t, err) + + inT := &livekit.SIPInboundTrunkInfo{ + Name: "Inbound", + } + err = rs.StoreSIPInboundTrunk(ctx, inT) + require.Error(t, err) + + outT := &livekit.SIPOutboundTrunkInfo{ + Name: "Outbound", + } + err = rs.StoreSIPOutboundTrunk(ctx, outT) + require.Error(t, err) + + // Creation + oldT.SipTrunkId = oldID + err = rs.StoreSIPTrunk(ctx, oldT) + require.NoError(t, err) + + inT.SipTrunkId = inID + err = rs.StoreSIPInboundTrunk(ctx, inT) + require.NoError(t, err) + + outT.SipTrunkId = outID + err = rs.StoreSIPOutboundTrunk(ctx, outT) + require.NoError(t, err) + + // Loading (with matching kind) + oldT2, err := rs.LoadSIPTrunk(ctx, oldID) + require.NoError(t, err) + require.True(t, proto.Equal(oldT, oldT2)) + + inT2, err := rs.LoadSIPInboundTrunk(ctx, inID) + require.NoError(t, err) + require.True(t, proto.Equal(inT, inT2)) + + outT2, err := rs.LoadSIPOutboundTrunk(ctx, outID) + require.NoError(t, err) + require.True(t, proto.Equal(outT, outT2)) + + // Loading (compat) + oldT2, err = rs.LoadSIPTrunk(ctx, inID) + require.NoError(t, err) + require.True(t, proto.Equal(inT.AsTrunkInfo(), oldT2)) + + oldT2, err = rs.LoadSIPTrunk(ctx, outID) + require.NoError(t, err) + require.True(t, proto.Equal(outT.AsTrunkInfo(), oldT2)) + + inT2, err = rs.LoadSIPInboundTrunk(ctx, oldID) + require.NoError(t, err) + require.True(t, proto.Equal(oldT.AsInbound(), inT2)) + + outT2, err = rs.LoadSIPOutboundTrunk(ctx, oldID) + require.NoError(t, err) + require.True(t, proto.Equal(oldT.AsOutbound(), outT2)) + + // Listing (always shows legacy + new) + listOld, err := rs.ListSIPTrunk(ctx, &livekit.ListSIPTrunkRequest{}) + require.NoError(t, err) + require.Len(t, listOld.Items, 3) + slices.SortFunc(listOld.Items, func(a, b *livekit.SIPTrunkInfo) int { + return strings.Compare(a.Name, b.Name) + }) + require.True(t, proto.Equal(inT.AsTrunkInfo(), listOld.Items[0])) + require.True(t, proto.Equal(oldT, listOld.Items[1])) + require.True(t, proto.Equal(outT.AsTrunkInfo(), listOld.Items[2])) + + listIn, err := rs.ListSIPInboundTrunk(ctx, &livekit.ListSIPInboundTrunkRequest{}) + require.NoError(t, err) + require.Len(t, listIn.Items, 2) + slices.SortFunc(listIn.Items, func(a, b *livekit.SIPInboundTrunkInfo) int { + return strings.Compare(a.Name, b.Name) + }) + require.True(t, proto.Equal(inT, listIn.Items[0])) + require.True(t, proto.Equal(oldT.AsInbound(), listIn.Items[1])) + + listOut, err := rs.ListSIPOutboundTrunk(ctx, &livekit.ListSIPOutboundTrunkRequest{}) + require.NoError(t, err) + require.Len(t, listOut.Items, 2) + slices.SortFunc(listOut.Items, func(a, b *livekit.SIPOutboundTrunkInfo) int { + return strings.Compare(a.Name, b.Name) + }) + require.True(t, proto.Equal(oldT.AsOutbound(), listOut.Items[0])) + require.True(t, proto.Equal(outT, listOut.Items[1])) + + // Deletion. Should not return error if not exists. + err = rs.DeleteSIPTrunk(ctx, oldID) + require.NoError(t, err) + err = rs.DeleteSIPTrunk(ctx, oldID) + require.NoError(t, err) + + // Other objects are still there. + inT2, err = rs.LoadSIPInboundTrunk(ctx, inID) + require.NoError(t, err) + require.True(t, proto.Equal(inT, inT2)) + + outT2, err = rs.LoadSIPOutboundTrunk(ctx, outID) + require.NoError(t, err) + require.True(t, proto.Equal(outT, outT2)) + + // Delete the rest + err = rs.DeleteSIPTrunk(ctx, inID) + require.NoError(t, err) + err = rs.DeleteSIPTrunk(ctx, outID) + require.NoError(t, err) + + // Check everything is deleted. + oldList, err = rs.ListSIPTrunk(ctx, &livekit.ListSIPTrunkRequest{}) + require.NoError(t, err) + require.Empty(t, oldList.Items) + + inList, err = rs.ListSIPInboundTrunk(ctx, &livekit.ListSIPInboundTrunkRequest{}) + require.NoError(t, err) + require.Empty(t, inList.Items) + + outList, err = rs.ListSIPOutboundTrunk(ctx, &livekit.ListSIPOutboundTrunkRequest{}) + require.NoError(t, err) + require.Empty(t, outList.Items) + + old, err = rs.LoadSIPTrunk(ctx, oldID) + require.Equal(t, service.ErrSIPTrunkNotFound, err) + require.Nil(t, old) + + in, err = rs.LoadSIPInboundTrunk(ctx, oldID) + require.Equal(t, service.ErrSIPTrunkNotFound, err) + require.Nil(t, in) + + out, err = rs.LoadSIPOutboundTrunk(ctx, oldID) + require.Equal(t, service.ErrSIPTrunkNotFound, err) + require.Nil(t, out) +} + +func TestSIPTrunkList(t *testing.T) { + s := redisStoreDocker(t) + + testIter(t, func(ctx context.Context, id string) error { + if strings.HasSuffix(id, "0") { + return s.StoreSIPTrunk(ctx, &livekit.SIPTrunkInfo{ + SipTrunkId: id, + OutboundNumber: id, + }) + } + return s.StoreSIPInboundTrunk(ctx, &livekit.SIPInboundTrunkInfo{ + SipTrunkId: id, + Numbers: []string{id}, + }) + }, func(ctx context.Context, page *livekit.Pagination, ids []string) iters.PageIter[*livekit.SIPInboundTrunkInfo] { + return livekit.ListPageIter(s.ListSIPInboundTrunk, &livekit.ListSIPInboundTrunkRequest{ + TrunkIds: ids, Page: page, + }) + }) +} + +func TestSIPRuleList(t *testing.T) { + s := redisStoreDocker(t) + + testIter(t, func(ctx context.Context, id string) error { + return s.StoreSIPDispatchRule(ctx, &livekit.SIPDispatchRuleInfo{ + SipDispatchRuleId: id, + TrunkIds: []string{id}, + }) + }, func(ctx context.Context, page *livekit.Pagination, ids []string) iters.PageIter[*livekit.SIPDispatchRuleInfo] { + return livekit.ListPageIter(s.ListSIPDispatchRule, &livekit.ListSIPDispatchRuleRequest{ + DispatchRuleIds: ids, Page: page, + }) + }) +} + +type listItem interface { + ID() string +} + +func allIDs[T listItem](t testing.TB, it iters.PageIter[T]) []string { + defer it.Close() + got, err := iters.AllPages(context.Background(), iters.MapPage(it, func(ctx context.Context, v T) (string, error) { + return v.ID(), nil + })) + require.NoError(t, err) + return got +} + +func testIter[T listItem]( + t *testing.T, + create func(ctx context.Context, id string) error, + list func(ctx context.Context, page *livekit.Pagination, ids []string) iters.PageIter[T], +) { + ctx := context.Background() + var all []string + for i := range 250 { + id := fmt.Sprintf("%05d", i) + all = append(all, id) + err := create(ctx, id) + require.NoError(t, err) + } + + // List everything with pagination disabled (legacy) + it := list(ctx, nil, nil) + got := allIDs(t, it) + require.Equal(t, all, got) + + // List with pagination enabled + it = list(ctx, &livekit.Pagination{Limit: 10}, nil) + got = allIDs(t, it) + require.Equal(t, all, got) + + // List with pagination enabled, custom ID + it = list(ctx, &livekit.Pagination{Limit: 10, AfterId: all[55]}, nil) + got = allIDs(t, it) + require.Equal(t, all[56:], got) + + // List fixed IDs + it = list(ctx, &livekit.Pagination{Limit: 10, AfterId: all[5]}, []string{ + all[10], + all[3], + "invalid", + all[8], + }) + got = allIDs(t, it) + require.Equal(t, []string{ + all[8], + all[10], + }, got) +} diff --git a/livekit/pkg/service/redisstore_test.go b/livekit/pkg/service/redisstore_test.go new file mode 100644 index 0000000..e84a6c4 --- /dev/null +++ b/livekit/pkg/service/redisstore_test.go @@ -0,0 +1,391 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package service_test + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" + "go.uber.org/atomic" + "google.golang.org/protobuf/proto" + + "github.com/livekit/protocol/ingress" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/utils" + "github.com/livekit/protocol/utils/guid" + + "github.com/livekit/livekit-server/pkg/service" +) + +func redisStoreDocker(t testing.TB) *service.RedisStore { + return service.NewRedisStore(redisClientDocker(t)) +} + +func redisStore(t testing.TB) *service.RedisStore { + return service.NewRedisStore(redisClient(t)) +} + +func TestRoomInternal(t *testing.T) { + ctx := context.Background() + rs := redisStore(t) + + room := &livekit.Room{ + Sid: "123", + Name: "test_room", + } + internal := &livekit.RoomInternal{ + TrackEgress: &livekit.AutoTrackEgress{Filepath: "egress"}, + } + + require.NoError(t, rs.StoreRoom(ctx, room, internal)) + actualRoom, actualInternal, err := rs.LoadRoom(ctx, livekit.RoomName(room.Name), true) + require.NoError(t, err) + require.Equal(t, room.Sid, actualRoom.Sid) + require.Equal(t, internal.TrackEgress.Filepath, actualInternal.TrackEgress.Filepath) + + // remove internal + require.NoError(t, rs.StoreRoom(ctx, room, nil)) + _, actualInternal, err = rs.LoadRoom(ctx, livekit.RoomName(room.Name), true) + require.NoError(t, err) + require.Nil(t, actualInternal) + + // clean up + require.NoError(t, rs.DeleteRoom(ctx, "test_room")) +} + +func TestParticipantPersistence(t *testing.T) { + ctx := context.Background() + rs := redisStore(t) + + roomName := livekit.RoomName("room1") + _ = rs.DeleteRoom(ctx, roomName) + + p := &livekit.ParticipantInfo{ + Sid: "PA_test", + Identity: "test", + State: livekit.ParticipantInfo_ACTIVE, + Tracks: []*livekit.TrackInfo{ + { + Sid: "track1", + Type: livekit.TrackType_AUDIO, + Name: "audio", + }, + }, + } + + // create the participant + require.NoError(t, rs.StoreParticipant(ctx, roomName, p)) + + // result should match + pGet, err := rs.LoadParticipant(ctx, roomName, livekit.ParticipantIdentity(p.Identity)) + require.NoError(t, err) + require.Equal(t, p.Identity, pGet.Identity) + require.Equal(t, len(p.Tracks), len(pGet.Tracks)) + require.Equal(t, p.Tracks[0].Sid, pGet.Tracks[0].Sid) + + // list should return one participant + participants, err := rs.ListParticipants(ctx, roomName) + require.NoError(t, err) + require.Len(t, participants, 1) + + // deleting participant should return to normal + require.NoError(t, rs.DeleteParticipant(ctx, roomName, livekit.ParticipantIdentity(p.Identity))) + + participants, err = rs.ListParticipants(ctx, roomName) + require.NoError(t, err) + require.Len(t, participants, 0) + + // shouldn't be able to get it + _, err = rs.LoadParticipant(ctx, roomName, livekit.ParticipantIdentity(p.Identity)) + require.Equal(t, err, service.ErrParticipantNotFound) +} + +func TestRoomLock(t *testing.T) { + ctx := context.Background() + rs := redisStore(t) + lockInterval := 5 * time.Millisecond + roomName := livekit.RoomName("myroom") + + t.Run("normal locking", func(t *testing.T) { + token, err := rs.LockRoom(ctx, roomName, lockInterval) + require.NoError(t, err) + require.NotEmpty(t, token) + require.NoError(t, rs.UnlockRoom(ctx, roomName, token)) + }) + + t.Run("waits before acquiring lock", func(t *testing.T) { + token, err := rs.LockRoom(ctx, roomName, lockInterval) + require.NoError(t, err) + require.NotEmpty(t, token) + unlocked := atomic.NewUint32(0) + wg := sync.WaitGroup{} + + wg.Add(1) + go func() { + // attempt to lock again + defer wg.Done() + token2, err := rs.LockRoom(ctx, roomName, lockInterval) + require.NoError(t, err) + defer rs.UnlockRoom(ctx, roomName, token2) + require.Equal(t, uint32(1), unlocked.Load()) + }() + + // release after 2 ms + time.Sleep(2 * time.Millisecond) + unlocked.Store(1) + _ = rs.UnlockRoom(ctx, roomName, token) + + wg.Wait() + }) + + t.Run("lock expires", func(t *testing.T) { + token, err := rs.LockRoom(ctx, roomName, lockInterval) + require.NoError(t, err) + defer rs.UnlockRoom(ctx, roomName, token) + + time.Sleep(lockInterval + time.Millisecond) + token2, err := rs.LockRoom(ctx, roomName, lockInterval) + require.NoError(t, err) + _ = rs.UnlockRoom(ctx, roomName, token2) + }) +} + +func TestEgressStore(t *testing.T) { + ctx := context.Background() + rs := redisStore(t) + + roomName := "egress-test" + + // store egress info + info := &livekit.EgressInfo{ + EgressId: guid.New(utils.EgressPrefix), + RoomId: guid.New(utils.RoomPrefix), + RoomName: roomName, + Status: livekit.EgressStatus_EGRESS_STARTING, + Request: &livekit.EgressInfo_RoomComposite{ + RoomComposite: &livekit.RoomCompositeEgressRequest{ + RoomName: roomName, + Layout: "speaker-dark", + }, + }, + } + require.NoError(t, rs.StoreEgress(ctx, info)) + + // load + res, err := rs.LoadEgress(ctx, info.EgressId) + require.NoError(t, err) + require.Equal(t, res.EgressId, info.EgressId) + + // store another + info2 := &livekit.EgressInfo{ + EgressId: guid.New(utils.EgressPrefix), + RoomId: guid.New(utils.RoomPrefix), + RoomName: "another-egress-test", + Status: livekit.EgressStatus_EGRESS_STARTING, + Request: &livekit.EgressInfo_RoomComposite{ + RoomComposite: &livekit.RoomCompositeEgressRequest{ + RoomName: "another-egress-test", + Layout: "speaker-dark", + }, + }, + } + require.NoError(t, rs.StoreEgress(ctx, info2)) + + // update + info2.Status = livekit.EgressStatus_EGRESS_COMPLETE + info2.EndedAt = time.Now().Add(-24 * time.Hour).UnixNano() + require.NoError(t, rs.UpdateEgress(ctx, info)) + + // list + list, err := rs.ListEgress(ctx, "", false) + require.NoError(t, err) + require.Len(t, list, 2) + + // list by room + list, err = rs.ListEgress(ctx, livekit.RoomName(roomName), false) + require.NoError(t, err) + require.Len(t, list, 1) + + // update + info.Status = livekit.EgressStatus_EGRESS_COMPLETE + info.EndedAt = time.Now().Add(-24 * time.Hour).UnixNano() + require.NoError(t, rs.UpdateEgress(ctx, info)) + + // clean + require.NoError(t, rs.CleanEndedEgress()) + + // list + list, err = rs.ListEgress(ctx, livekit.RoomName(roomName), false) + require.NoError(t, err) + require.Len(t, list, 0) +} + +func TestIngressStore(t *testing.T) { + ctx := context.Background() + rs := redisStore(t) + + info := &livekit.IngressInfo{ + IngressId: "ingressId", + StreamKey: "streamKey", + State: &livekit.IngressState{ + StartedAt: 2, + }, + } + + err := rs.StoreIngress(ctx, info) + require.NoError(t, err) + + err = rs.UpdateIngressState(ctx, info.IngressId, info.State) + require.NoError(t, err) + + t.Cleanup(func() { + rs.DeleteIngress(ctx, info) + }) + + pulledInfo, err := rs.LoadIngress(ctx, "ingressId") + require.NoError(t, err) + compareIngressInfo(t, pulledInfo, info) + + infos, err := rs.ListIngress(ctx, "room") + require.NoError(t, err) + require.Equal(t, 0, len(infos)) + + info.RoomName = "room" + err = rs.UpdateIngress(ctx, info) + require.NoError(t, err) + + infos, err = rs.ListIngress(ctx, "room") + require.NoError(t, err) + + require.NoError(t, err) + require.Equal(t, 1, len(infos)) + compareIngressInfo(t, infos[0], info) + + info.RoomName = "" + err = rs.UpdateIngress(ctx, info) + require.NoError(t, err) + + infos, err = rs.ListIngress(ctx, "room") + require.NoError(t, err) + require.Equal(t, 0, len(infos)) + + info.State.StartedAt = 1 + err = rs.UpdateIngressState(ctx, info.IngressId, info.State) + require.Equal(t, ingress.ErrIngressOutOfDate, err) + + info.State.StartedAt = 3 + err = rs.UpdateIngressState(ctx, info.IngressId, info.State) + require.NoError(t, err) + + infos, err = rs.ListIngress(ctx, "") + require.NoError(t, err) + require.Equal(t, 1, len(infos)) + require.Equal(t, "", infos[0].RoomName) +} + +func TestAgentStore(t *testing.T) { + ctx := context.Background() + rs := redisStore(t) + + ad := &livekit.AgentDispatch{ + Id: "dispatch_id", + AgentName: "agent_name", + Metadata: "metadata", + Room: "room_name", + State: &livekit.AgentDispatchState{ + CreatedAt: 1, + DeletedAt: 2, + Jobs: []*livekit.Job{ + &livekit.Job{ + Id: "job_id", + DispatchId: "dispatch_id", + Type: livekit.JobType_JT_PUBLISHER, + Room: &livekit.Room{ + Name: "room_name", + }, + Participant: &livekit.ParticipantInfo{ + Identity: "identity", + Name: "name", + }, + Namespace: "ns", + Metadata: "metadata", + AgentName: "agent_name", + State: &livekit.JobState{ + Status: livekit.JobStatus_JS_RUNNING, + StartedAt: 3, + EndedAt: 4, + Error: "error", + }, + }, + }, + }, + } + + err := rs.StoreAgentDispatch(ctx, ad) + require.NoError(t, err) + + rd, err := rs.ListAgentDispatches(ctx, "not_a_room") + require.NoError(t, err) + require.Equal(t, 0, len(rd)) + + rd, err = rs.ListAgentDispatches(ctx, "room_name") + require.NoError(t, err) + require.Equal(t, 1, len(rd)) + + expected := utils.CloneProto(ad) + expected.State.Jobs = nil + require.True(t, proto.Equal(expected, rd[0])) + + err = rs.StoreAgentJob(ctx, ad.State.Jobs[0]) + require.NoError(t, err) + + rd, err = rs.ListAgentDispatches(ctx, "room_name") + require.NoError(t, err) + require.Equal(t, 1, len(rd)) + + expected = utils.CloneProto(ad) + expected.State.Jobs[0].Room = nil + expected.State.Jobs[0].Participant = &livekit.ParticipantInfo{ + Identity: "identity", + } + require.True(t, proto.Equal(expected, rd[0])) + + err = rs.DeleteAgentJob(ctx, ad.State.Jobs[0]) + require.NoError(t, err) + + rd, err = rs.ListAgentDispatches(ctx, "room_name") + require.NoError(t, err) + require.Equal(t, 1, len(rd)) + + expected = utils.CloneProto(ad) + expected.State.Jobs = nil + require.True(t, proto.Equal(expected, rd[0])) + + err = rs.DeleteAgentDispatch(ctx, ad) + require.NoError(t, err) + + rd, err = rs.ListAgentDispatches(ctx, "room_name") + require.NoError(t, err) + require.Equal(t, 0, len(rd)) +} + +func compareIngressInfo(t *testing.T, expected, v *livekit.IngressInfo) { + require.Equal(t, expected.IngressId, v.IngressId) + require.Equal(t, expected.StreamKey, v.StreamKey) + require.Equal(t, expected.RoomName, v.RoomName) +} diff --git a/livekit/pkg/service/roomallocator.go b/livekit/pkg/service/roomallocator.go new file mode 100644 index 0000000..43d0edc --- /dev/null +++ b/livekit/pkg/service/roomallocator.go @@ -0,0 +1,248 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package service + +import ( + "context" + "errors" + "time" + + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/utils" + "github.com/livekit/protocol/utils/guid" + "github.com/livekit/psrpc" + + "github.com/livekit/livekit-server/pkg/config" + "github.com/livekit/livekit-server/pkg/routing" + "github.com/livekit/livekit-server/pkg/routing/selector" +) + +type StandardRoomAllocator struct { + config *config.Config + router routing.Router + selector selector.NodeSelector + roomStore ObjectStore +} + +func NewRoomAllocator(conf *config.Config, router routing.Router, rs ObjectStore) (RoomAllocator, error) { + ns, err := selector.CreateNodeSelector(conf) + if err != nil { + return nil, err + } + + return &StandardRoomAllocator{ + config: conf, + router: router, + selector: ns, + roomStore: rs, + }, nil +} + +func (r *StandardRoomAllocator) AutoCreateEnabled(context.Context) bool { + return r.config.Room.AutoCreate +} + +// CreateRoom creates a new room from a request and allocates it to a node to handle +// it'll also monitor its state, and cleans it up when appropriate +func (r *StandardRoomAllocator) CreateRoom(ctx context.Context, req *livekit.CreateRoomRequest, isExplicit bool) (*livekit.Room, *livekit.RoomInternal, bool, error) { + token, err := r.roomStore.LockRoom(ctx, livekit.RoomName(req.Name), 5*time.Second) + if err != nil { + return nil, nil, false, err + } + defer func() { + _ = r.roomStore.UnlockRoom(ctx, livekit.RoomName(req.Name), token) + }() + + // find existing room and update it + var created bool + rm, internal, err := r.roomStore.LoadRoom(ctx, livekit.RoomName(req.Name), true) + if errors.Is(err, ErrRoomNotFound) { + created = true + now := time.Now() + rm = &livekit.Room{ + Sid: guid.New(utils.RoomPrefix), + Name: req.Name, + CreationTime: now.Unix(), + CreationTimeMs: now.UnixMilli(), + TurnPassword: utils.RandomSecret(), + } + internal = &livekit.RoomInternal{} + applyDefaultRoomConfig(rm, internal, &r.config.Room) + } else if err != nil { + return nil, nil, false, err + } + + req, err = r.applyNamedRoomConfiguration(req) + if err != nil { + return nil, nil, false, err + } + + if req.EmptyTimeout > 0 { + rm.EmptyTimeout = req.EmptyTimeout + } + if req.DepartureTimeout > 0 { + rm.DepartureTimeout = req.DepartureTimeout + } + if req.MaxParticipants > 0 { + rm.MaxParticipants = req.MaxParticipants + } + if req.Metadata != "" { + rm.Metadata = req.Metadata + } + if req.Egress != nil { + if req.Egress.Participant != nil { + internal.ParticipantEgress = req.Egress.Participant + } + if req.Egress.Tracks != nil { + internal.TrackEgress = req.Egress.Tracks + } + } + if req.Agents != nil { + internal.AgentDispatches = req.Agents + } + if req.MinPlayoutDelay > 0 || req.MaxPlayoutDelay > 0 { + internal.PlayoutDelay = &livekit.PlayoutDelay{ + Enabled: true, + Min: req.MinPlayoutDelay, + Max: req.MaxPlayoutDelay, + } + } + if req.SyncStreams { + internal.SyncStreams = true + } + + if err = r.roomStore.StoreRoom(ctx, rm, internal); err != nil { + return nil, nil, false, err + } + + return rm, internal, created, nil +} + +func (r *StandardRoomAllocator) SelectRoomNode(ctx context.Context, roomName livekit.RoomName, nodeID livekit.NodeID) error { + // check if room already assigned + existing, err := r.router.GetNodeForRoom(ctx, roomName) + if !errors.Is(err, routing.ErrNotFound) && err != nil { + return err + } + + // if already assigned and still available, keep it on that node + if err == nil && selector.IsAvailable(existing) { + // if node hosting the room is full, deny entry + if selector.LimitsReached(r.config.Limit, existing.Stats) { + return routing.ErrNodeLimitReached + } + + return nil + } + + // select a new node + if nodeID == "" { + nodes, err := r.router.ListNodes() + if err != nil { + return err + } + + node, err := r.selector.SelectNode(nodes) + if err != nil { + return err + } + + nodeID = livekit.NodeID(node.Id) + } + + logger.Infow("selected node for room", "room", roomName, "selectedNodeID", nodeID) + err = r.router.SetNodeForRoom(ctx, roomName, nodeID) + if err != nil { + return err + } + + return nil +} + +func (r *StandardRoomAllocator) ValidateCreateRoom(ctx context.Context, roomName livekit.RoomName) error { + // when auto create is disabled, we'll check to ensure it's already created + if !r.config.Room.AutoCreate { + _, _, err := r.roomStore.LoadRoom(ctx, roomName, false) + if err != nil { + return err + } + } + return nil +} + +func applyDefaultRoomConfig(room *livekit.Room, internal *livekit.RoomInternal, conf *config.RoomConfig) { + room.EmptyTimeout = conf.EmptyTimeout + room.DepartureTimeout = conf.DepartureTimeout + room.MaxParticipants = conf.MaxParticipants + for _, codec := range conf.EnabledCodecs { + room.EnabledCodecs = append(room.EnabledCodecs, &livekit.Codec{ + Mime: codec.Mime, + FmtpLine: codec.FmtpLine, + }) + } + internal.PlayoutDelay = &livekit.PlayoutDelay{ + Enabled: conf.PlayoutDelay.Enabled, + Min: uint32(conf.PlayoutDelay.Min), + Max: uint32(conf.PlayoutDelay.Max), + } + internal.SyncStreams = conf.SyncStreams +} + +func (r *StandardRoomAllocator) applyNamedRoomConfiguration(req *livekit.CreateRoomRequest) (*livekit.CreateRoomRequest, error) { + if req.RoomPreset == "" { + return req, nil + } + + conf, ok := r.config.Room.RoomConfigurations[req.RoomPreset] + if !ok { + return req, psrpc.NewErrorf(psrpc.InvalidArgument, "unknown room configuration in create room request") + } + + clone := utils.CloneProto(req) + + if clone.EmptyTimeout == 0 { + clone.EmptyTimeout = conf.EmptyTimeout + } + if clone.DepartureTimeout == 0 { + clone.DepartureTimeout = conf.DepartureTimeout + } + if clone.MaxParticipants == 0 { + clone.MaxParticipants = conf.MaxParticipants + } + if clone.Egress == nil { + clone.Egress = utils.CloneProto(conf.Egress) + } + if clone.Agents == nil { + clone.Agents = make([]*livekit.RoomAgentDispatch, 0, len(conf.Agents)) + for _, agent := range conf.Agents { + clone.Agents = append(clone.Agents, utils.CloneProto(agent)) + } + } + if clone.MinPlayoutDelay == 0 { + clone.MinPlayoutDelay = conf.MinPlayoutDelay + } + if clone.MaxPlayoutDelay == 0 { + clone.MaxPlayoutDelay = conf.MaxPlayoutDelay + } + if !clone.SyncStreams { + clone.SyncStreams = conf.SyncStreams + } + if clone.Metadata == "" { + clone.Metadata = conf.Metadata + } + + return clone, nil +} diff --git a/livekit/pkg/service/roomallocator_test.go b/livekit/pkg/service/roomallocator_test.go new file mode 100644 index 0000000..3d3dd29 --- /dev/null +++ b/livekit/pkg/service/roomallocator_test.go @@ -0,0 +1,98 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package service_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/livekit/protocol/livekit" + + "github.com/livekit/livekit-server/pkg/config" + "github.com/livekit/livekit-server/pkg/routing" + "github.com/livekit/livekit-server/pkg/routing/routingfakes" + "github.com/livekit/livekit-server/pkg/service" + "github.com/livekit/livekit-server/pkg/service/servicefakes" +) + +func TestCreateRoom(t *testing.T) { + t.Run("ensure default room settings are applied", func(t *testing.T) { + conf, err := config.NewConfig("", true, nil, nil) + require.NoError(t, err) + + node, err := routing.NewLocalNode(conf) + require.NoError(t, err) + + ra, conf := newTestRoomAllocator(t, conf, node.Clone()) + + room, _, _, err := ra.CreateRoom(context.Background(), &livekit.CreateRoomRequest{Name: "myroom"}, true) + require.NoError(t, err) + require.Equal(t, conf.Room.EmptyTimeout, room.EmptyTimeout) + require.Equal(t, conf.Room.DepartureTimeout, room.DepartureTimeout) + require.NotEmpty(t, room.EnabledCodecs) + }) +} + +func SelectRoomNode(t *testing.T) { + t.Run("reject new participants when track limit has been reached", func(t *testing.T) { + conf, err := config.NewConfig("", true, nil, nil) + require.NoError(t, err) + conf.Limit.NumTracks = 10 + + node, err := routing.NewLocalNode(conf) + require.NoError(t, err) + node.SetStats(&livekit.NodeStats{ + NumTracksIn: 100, + NumTracksOut: 100, + }) + + ra, _ := newTestRoomAllocator(t, conf, node.Clone()) + + err = ra.SelectRoomNode(context.Background(), "low-limit-room", "") + require.ErrorIs(t, err, routing.ErrNodeLimitReached) + }) + + t.Run("reject new participants when bandwidth limit has been reached", func(t *testing.T) { + conf, err := config.NewConfig("", true, nil, nil) + require.NoError(t, err) + conf.Limit.BytesPerSec = 100 + + node, err := routing.NewLocalNode(conf) + require.NoError(t, err) + node.SetStats(&livekit.NodeStats{ + BytesInPerSec: 1000, + BytesOutPerSec: 1000, + }) + + ra, _ := newTestRoomAllocator(t, conf, node.Clone()) + + err = ra.SelectRoomNode(context.Background(), "low-limit-room", "") + require.ErrorIs(t, err, routing.ErrNodeLimitReached) + }) +} + +func newTestRoomAllocator(t *testing.T, conf *config.Config, node *livekit.Node) (service.RoomAllocator, *config.Config) { + store := &servicefakes.FakeObjectStore{} + store.LoadRoomReturns(nil, nil, service.ErrRoomNotFound) + router := &routingfakes.FakeRouter{} + + router.GetNodeForRoomReturns(node, nil) + + ra, err := service.NewRoomAllocator(conf, router, store) + require.NoError(t, err) + return ra, conf +} diff --git a/livekit/pkg/service/roommanager.go b/livekit/pkg/service/roommanager.go new file mode 100644 index 0000000..df269ac --- /dev/null +++ b/livekit/pkg/service/roommanager.go @@ -0,0 +1,1163 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package service + +import ( + "context" + "crypto/hmac" + "crypto/sha1" + "encoding/base64" + "fmt" + "os" + "sync" + "time" + + "github.com/pkg/errors" + "golang.org/x/exp/maps" + + "github.com/livekit/livekit-server/pkg/agent" + "github.com/livekit/livekit-server/pkg/sfu" + sutils "github.com/livekit/livekit-server/pkg/utils" + "github.com/livekit/mediatransportutil/pkg/rtcconfig" + "github.com/livekit/protocol/auth" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/observability/roomobs" + "github.com/livekit/protocol/rpc" + "github.com/livekit/protocol/utils" + "github.com/livekit/protocol/utils/guid" + "github.com/livekit/protocol/utils/must" + "github.com/livekit/psrpc" + "github.com/livekit/psrpc/pkg/middleware" + + "github.com/livekit/livekit-server/pkg/clientconfiguration" + "github.com/livekit/livekit-server/pkg/config" + "github.com/livekit/livekit-server/pkg/routing" + "github.com/livekit/livekit-server/pkg/rtc" + "github.com/livekit/livekit-server/pkg/rtc/types" + "github.com/livekit/livekit-server/pkg/telemetry" + "github.com/livekit/livekit-server/pkg/telemetry/prometheus" + "github.com/livekit/livekit-server/version" +) + +const ( + tokenRefreshInterval = 5 * time.Minute + tokenDefaultTTL = 10 * time.Minute +) + +type iceConfigCacheKey struct { + roomName livekit.RoomName + participantIdentity livekit.ParticipantIdentity +} + +// RoomManager manages rooms and its interaction with participants. +// It's responsible for creating, deleting rooms, as well as running sessions for participants +type RoomManager struct { + lock sync.RWMutex + + config *config.Config + rtcConfig *rtc.WebRTCConfig + serverInfo *livekit.ServerInfo + currentNode routing.LocalNode + router routing.Router + roomAllocator RoomAllocator + roomManagerServer rpc.TypedRoomManagerServer + whipServer rpc.WHIPServer[livekit.NodeID] + roomStore ObjectStore + telemetry telemetry.TelemetryService + clientConfManager clientconfiguration.ClientConfigurationManager + agentClient agent.Client + agentStore AgentStore + egressLauncher rtc.EgressLauncher + versionGenerator utils.TimedVersionGenerator + turnAuthHandler *TURNAuthHandler + bus psrpc.MessageBus + + rooms map[livekit.RoomName]*rtc.Room + + roomServers utils.MultitonService[rpc.RoomTopic] + agentDispatchServers utils.MultitonService[rpc.RoomTopic] + participantServers utils.MultitonService[rpc.ParticipantTopic] + httpSignalParticipantServers utils.MultitonService[rpc.ParticipantTopic] + whipParticipantServers utils.MultitonService[rpc.ParticipantTopic] + + iceConfigCache *sutils.IceConfigCache[iceConfigCacheKey] + + forwardStats *sfu.ForwardStats + + rpc.UnimplementedParticipantServer + rpc.UnimplementedRoomServer + rpc.UnimplementedRoomManagerServer +} + +func NewLocalRoomManager( + conf *config.Config, + roomStore ObjectStore, + currentNode routing.LocalNode, + router routing.Router, + roomAllocator RoomAllocator, + telemetry telemetry.TelemetryService, + agentClient agent.Client, + agentStore AgentStore, + egressLauncher rtc.EgressLauncher, + versionGenerator utils.TimedVersionGenerator, + turnAuthHandler *TURNAuthHandler, + bus psrpc.MessageBus, + forwardStats *sfu.ForwardStats, +) (*RoomManager, error) { + rtcConf, err := rtc.NewWebRTCConfig(conf) + if err != nil { + return nil, err + } + + r := &RoomManager{ + config: conf, + rtcConfig: rtcConf, + currentNode: currentNode, + router: router, + roomAllocator: roomAllocator, + roomStore: roomStore, + telemetry: telemetry, + clientConfManager: clientconfiguration.NewStaticClientConfigurationManager(clientconfiguration.StaticConfigurations), + egressLauncher: egressLauncher, + agentClient: agentClient, + agentStore: agentStore, + versionGenerator: versionGenerator, + turnAuthHandler: turnAuthHandler, + bus: bus, + forwardStats: forwardStats, + + rooms: make(map[livekit.RoomName]*rtc.Room), + + iceConfigCache: sutils.NewIceConfigCache[iceConfigCacheKey](0), + + serverInfo: &livekit.ServerInfo{ + Edition: livekit.ServerInfo_Standard, + Version: version.Version, + Protocol: types.CurrentProtocol, + AgentProtocol: agent.CurrentProtocol, + Region: conf.Region, + NodeId: string(currentNode.NodeID()), + }, + } + + r.roomManagerServer, err = rpc.NewTypedRoomManagerServer(r, bus, rpc.WithServerLogger(logger.GetLogger()), middleware.WithServerMetrics(rpc.PSRPCMetricsObserver{}), psrpc.WithServerChannelSize(conf.PSRPC.BufferSize)) + if err != nil { + return nil, err + } + if err := r.roomManagerServer.RegisterAllNodeTopics(currentNode.NodeID()); err != nil { + return nil, err + } + + whipService, err := newWhipService(r) + if err != nil { + return nil, err + } + r.whipServer, err = rpc.NewWHIPServer[livekit.NodeID](whipService, bus, rpc.WithDefaultServerOptions(conf.PSRPC, logger.GetLogger())) + if err != nil { + return nil, err + } + if err := r.whipServer.RegisterAllCommonTopics(currentNode.NodeID()); err != nil { + return nil, err + } + + return r, nil +} + +func (r *RoomManager) GetRoom(_ context.Context, roomName livekit.RoomName) *rtc.Room { + r.lock.RLock() + defer r.lock.RUnlock() + return r.rooms[roomName] +} + +// deleteRoom completely deletes all room information, including active sessions, room store, and routing info +func (r *RoomManager) deleteRoom(ctx context.Context, roomName livekit.RoomName) error { + logger.Infow("deleting room state", "room", roomName) + r.lock.Lock() + delete(r.rooms, roomName) + r.lock.Unlock() + + var err, err2 error + wg := sync.WaitGroup{} + wg.Add(2) + // clear routing information + go func() { + defer wg.Done() + err = r.router.ClearRoomState(ctx, roomName) + }() + // also delete room from db + go func() { + defer wg.Done() + err2 = r.roomStore.DeleteRoom(ctx, roomName) + }() + + wg.Wait() + if err2 != nil { + err = err2 + } + + return err +} + +func (r *RoomManager) CloseIdleRooms() { + r.lock.RLock() + rooms := maps.Values(r.rooms) + r.lock.RUnlock() + + for _, room := range rooms { + room.CloseIfEmpty() + } +} + +func (r *RoomManager) HasParticipants() bool { + r.lock.RLock() + defer r.lock.RUnlock() + + for _, room := range r.rooms { + if len(room.GetParticipants()) != 0 { + return true + } + } + return false +} + +func (r *RoomManager) Stop() { + // disconnect all clients + r.lock.RLock() + rooms := maps.Values(r.rooms) + r.lock.RUnlock() + + for _, room := range rooms { + room.Close(types.ParticipantCloseReasonRoomManagerStop) + } + + r.roomManagerServer.Kill() + r.whipServer.Kill() + r.roomServers.Kill() + r.agentDispatchServers.Kill() + r.participantServers.Kill() + r.httpSignalParticipantServers.Kill() + r.whipParticipantServers.Kill() + + if r.rtcConfig != nil { + if r.rtcConfig.UDPMux != nil { + _ = r.rtcConfig.UDPMux.Close() + } + if r.rtcConfig.TCPMuxListener != nil { + _ = r.rtcConfig.TCPMuxListener.Close() + } + } + + r.iceConfigCache.Stop() + + if r.forwardStats != nil { + r.forwardStats.Stop() + } +} + +func (r *RoomManager) CreateRoom(ctx context.Context, req *livekit.CreateRoomRequest) (*livekit.Room, error) { + room, err := r.getOrCreateRoom(ctx, req) + if err != nil { + return nil, err + } + defer room.Release() + + return room.ToProto(), nil +} + +// StartSession starts WebRTC session when a new participant is connected, takes place on RTC node +func (r *RoomManager) StartSession( + ctx context.Context, + pi routing.ParticipantInit, + requestSource routing.MessageSource, + responseSink routing.MessageSink, + useOneShotSignallingMode bool, +) error { + sessionStartTime := time.Now() + + createRoom := pi.CreateRoom + room, err := r.getOrCreateRoom(ctx, createRoom) + if err != nil { + return err + } + defer room.Release() + + protoRoom, roomInternal := room.ToProto(), room.Internal() + + // only create the room, but don't start a participant session + if pi.Identity == "" { + return nil + } + + // should not error out, error is logged in iceServersForParticipant even if it fails + // since this is used for TURN server credentials, we don't want to fail the request even if there's no TURN for the session + apiKey, _, _ := r.getFirstKeyPair() + + participant := room.GetParticipant(pi.Identity) + if participant != nil { + // When reconnecting, it means WS has interrupted but underlying peer connection is still ok in this state, + // we'll keep the participant SID, and just swap the sink for the underlying connection + if pi.Reconnect { + if participant.IsClosed() { + // Send leave request if participant is closed, i. e. handle the case of client trying to resume crossing wires with + // server closing the participant due to some irrecoverable condition. Such a condition would have triggered + // a full reconnect when that condition occurred. + // + // It is possible that the client did not get that send request. So, send it again. + logger.Infow("cannot restart a closed participant", + "room", room.Name(), + "nodeID", r.currentNode.NodeID(), + "participant", pi.Identity, + "reason", pi.ReconnectReason, + ) + + var leave *livekit.LeaveRequest + pv := types.ProtocolVersion(pi.Client.Protocol) + if pv.SupportsRegionsInLeaveRequest() { + leave = &livekit.LeaveRequest{ + Reason: livekit.DisconnectReason_STATE_MISMATCH, + Action: livekit.LeaveRequest_RECONNECT, + } + } else { + leave = &livekit.LeaveRequest{ + CanReconnect: true, + Reason: livekit.DisconnectReason_STATE_MISMATCH, + } + } + _ = responseSink.WriteMessage(&livekit.SignalResponse{ + Message: &livekit.SignalResponse_Leave{ + Leave: leave, + }, + }) + prometheus.IncrementParticipantRtcCanceled(1) + return errors.New("could not restart closed participant") + } + + participant.GetLogger().Infow( + "resuming RTC session", + "nodeID", r.currentNode.NodeID(), + "participantInit", &pi, + "numParticipants", room.GetParticipantCount(), + ) + iceConfig := r.getIceConfig(room.Name(), participant) + if err = room.ResumeParticipant( + participant, + requestSource, + responseSink, + iceConfig, + r.iceServersForParticipant( + apiKey, + participant, + iceConfig.PreferenceSubscriber == livekit.ICECandidateType_ICT_TLS, + ), + pi.ReconnectReason, + ); err != nil { + participant.GetLogger().Warnw("could not resume participant", err) + return err + } + r.telemetry.ParticipantResumed(ctx, room.ToProto(), participant.ToProto(), r.currentNode.NodeID(), pi.ReconnectReason) + + go room.HandleSyncState(participant, pi.SyncState) + + go r.rtcSessionWorker(room, participant, requestSource) + return nil + } + + // we need to clean up the existing participant, so a new one can join + participant.GetLogger().Infow("removing duplicate participant") + room.RemoveParticipant(participant.Identity(), participant.ID(), types.ParticipantCloseReasonDuplicateIdentity) + } else if pi.Reconnect { + // send leave request if participant is trying to reconnect without keep subscribe state + // but missing from the room + var leave *livekit.LeaveRequest + pv := types.ProtocolVersion(pi.Client.Protocol) + if pv.SupportsRegionsInLeaveRequest() { + leave = &livekit.LeaveRequest{ + Reason: livekit.DisconnectReason_STATE_MISMATCH, + Action: livekit.LeaveRequest_RECONNECT, + } + } else { + leave = &livekit.LeaveRequest{ + CanReconnect: true, + Reason: livekit.DisconnectReason_STATE_MISMATCH, + } + } + _ = responseSink.WriteMessage(&livekit.SignalResponse{ + Message: &livekit.SignalResponse_Leave{ + Leave: leave, + }, + }) + prometheus.IncrementParticipantRtcCanceled(1) + return errors.New("could not restart participant") + } + + sid := livekit.ParticipantID(guid.New(utils.ParticipantPrefix)) + pLogger := rtc.LoggerWithParticipant( + rtc.LoggerWithRoom(logger.GetLogger(), room.Name(), room.ID()), + pi.Identity, + sid, + false, + ) + pLogger.Infow( + "starting RTC session", + "room", room.Name(), + "nodeID", r.currentNode.NodeID(), + "numParticipants", room.GetParticipantCount(), + "participantInit", &pi, + ) + + clientConf := r.clientConfManager.GetConfiguration(pi.Client) + + pv := types.ProtocolVersion(pi.Client.Protocol) + rtcConf := *r.rtcConfig + rtcConf.SetBufferFactory(room.GetBufferFactory()) + if pi.DisableICELite { + rtcConf.SettingEngine.SetLite(false) + } + rtcConf.UpdatePublisherConfig(pi.UseSinglePeerConnection) + + // default allow forceTCP + allowFallback := true + if r.config.RTC.AllowTCPFallback != nil { + allowFallback = *r.config.RTC.AllowTCPFallback + } + + // default do not force full reconnect on a publication error + reconnectOnPublicationError := false + if r.config.RTC.ReconnectOnPublicationError != nil { + reconnectOnPublicationError = *r.config.RTC.ReconnectOnPublicationError + } + + // default do not force full reconnect on a subscription error + reconnectOnSubscriptionError := false + if r.config.RTC.ReconnectOnSubscriptionError != nil { + reconnectOnSubscriptionError = *r.config.RTC.ReconnectOnSubscriptionError + } + + // default do not force full reconnect on a data channel error + reconnectOnDataChannelError := false + if r.config.RTC.ReconnectOnDataChannelError != nil { + reconnectOnDataChannelError = *r.config.RTC.ReconnectOnDataChannelError + } + + subscriberAllowPause := r.config.RTC.CongestionControl.AllowPause + if pi.SubscriberAllowPause != nil { + subscriberAllowPause = *pi.SubscriberAllowPause + } + + participant, err = rtc.NewParticipant(rtc.ParticipantParams{ + Identity: pi.Identity, + Name: pi.Name, + SID: sid, + Config: &rtcConf, + Sink: responseSink, + AudioConfig: r.config.Audio, + VideoConfig: r.config.Video, + LimitConfig: r.config.Limit, + ProtocolVersion: pv, + SessionStartTime: sessionStartTime, + Telemetry: r.telemetry, + Trailer: room.Trailer(), + PLIThrottleConfig: r.config.RTC.PLIThrottle, + CongestionControlConfig: r.config.RTC.CongestionControl, + PublishEnabledCodecs: protoRoom.EnabledCodecs, + SubscribeEnabledCodecs: protoRoom.EnabledCodecs, + Grants: pi.Grants, + Reconnect: pi.Reconnect, + Logger: pLogger, + Reporter: roomobs.NewNoopParticipantSessionReporter(), + ClientConf: clientConf, + ClientInfo: rtc.ClientInfo{ClientInfo: pi.Client}, + Region: pi.Region, + AdaptiveStream: pi.AdaptiveStream, + AllowTCPFallback: allowFallback, + TURNSEnabled: r.config.IsTURNSEnabled(), + ParticipantListener: room.LocalParticipantListener(), + ParticipantHelper: &roomManagerParticipantHelper{ + room: room, + codecRegressionThreshold: r.config.Video.CodecRegressionThreshold, + }, + ReconnectOnPublicationError: reconnectOnPublicationError, + ReconnectOnSubscriptionError: reconnectOnSubscriptionError, + ReconnectOnDataChannelError: reconnectOnDataChannelError, + VersionGenerator: r.versionGenerator, + SubscriberAllowPause: subscriberAllowPause, + SubscriptionLimitAudio: r.config.Limit.SubscriptionLimitAudio, + SubscriptionLimitVideo: r.config.Limit.SubscriptionLimitVideo, + PlayoutDelay: roomInternal.GetPlayoutDelay(), + SyncStreams: roomInternal.GetSyncStreams(), + ForwardStats: r.forwardStats, + MetricConfig: r.config.Metric, + UseOneShotSignallingMode: useOneShotSignallingMode, + DataChannelMaxBufferedAmount: r.config.RTC.DataChannelMaxBufferedAmount, + DatachannelSlowThreshold: r.config.RTC.DatachannelSlowThreshold, + DatachannelLossyTargetLatency: r.config.RTC.DatachannelLossyTargetLatency, + FireOnTrackBySdp: true, + UseSinglePeerConnection: pi.UseSinglePeerConnection, + EnableDataTracks: r.config.EnableDataTracks, + EnableRTPStreamRestartDetection: r.config.RTC.EnableRTPStreamRestartDetection, + }) + if err != nil { + return err + } + iceConfig := r.setIceConfig(room.Name(), participant) + + // join room + opts := rtc.ParticipantOptions{ + AutoSubscribe: pi.AutoSubscribe, + } + if pi.AutoSubscribeDataTrack != nil { + opts.AutoSubscribeDataTrack = *pi.AutoSubscribeDataTrack + } + iceServers := r.iceServersForParticipant(apiKey, participant, iceConfig.PreferenceSubscriber == livekit.ICECandidateType_ICT_TLS) + if err = room.Join(participant, requestSource, &opts, iceServers); err != nil { + pLogger.Errorw("could not join room", err) + _ = participant.Close(true, types.ParticipantCloseReasonJoinFailed, false) + return err + } + + var participantServerClosers utils.Closers + participantTopic := rpc.FormatParticipantTopic(room.Name(), participant.Identity()) + participantServer := must.Get(rpc.NewTypedParticipantServer(r, r.bus)) + participantServerClosers = append(participantServerClosers, utils.CloseFunc(r.participantServers.Replace(participantTopic, participantServer))) + if err := participantServer.RegisterAllParticipantTopics(participantTopic); err != nil { + participantServerClosers.Close() + pLogger.Errorw("could not join register participant topic", err) + _ = participant.Close(true, types.ParticipantCloseReasonMessageBusFailed, false) + return err + } + + if useOneShotSignallingMode { + whipParticipantServer := must.Get(rpc.NewTypedWHIPParticipantServer(whipParticipantService{r}, r.bus)) + participantServerClosers = append(participantServerClosers, utils.CloseFunc(r.whipParticipantServers.Replace(participantTopic, whipParticipantServer))) + if err := whipParticipantServer.RegisterAllCommonTopics(participantTopic); err != nil { + participantServerClosers.Close() + pLogger.Errorw("could not join register participant topic for rtc rest participant server", err) + _ = participant.Close(true, types.ParticipantCloseReasonMessageBusFailed, false) + return err + } + } + + if err = r.roomStore.StoreParticipant(ctx, room.Name(), participant.ToProto()); err != nil { + pLogger.Errorw("could not store participant", err) + } + + persistRoomForParticipantCount := func(proto *livekit.Room) { + if !participant.Hidden() && !room.IsClosed() { + err = r.roomStore.StoreRoom(ctx, proto, room.Internal()) + if err != nil { + logger.Errorw("could not store room", err) + } + } + } + + // update room store with new numParticipants + persistRoomForParticipantCount(room.ToProto()) + + clientMeta := &livekit.AnalyticsClientMeta{Region: r.currentNode.Region(), Node: string(r.currentNode.NodeID())} + r.telemetry.ParticipantJoined(ctx, protoRoom, participant.ToProto(), pi.Client, clientMeta, true, participant.TelemetryGuard()) + participant.AddOnClose(types.ParticipantCloseKeyNormal, func(p types.LocalParticipant) { + participantServerClosers.Close() + + if err := r.roomStore.DeleteParticipant(ctx, room.Name(), p.Identity()); err != nil { + pLogger.Errorw("could not delete participant", err) + } + + // update room store with new numParticipants + proto := room.ToProto() + persistRoomForParticipantCount(proto) + r.telemetry.ParticipantLeft(ctx, proto, p.ToProto(), true, participant.TelemetryGuard()) + }) + participant.OnClaimsChanged(func(participant types.LocalParticipant) { + pLogger.Debugw("refreshing client token after claims change") + if err := r.refreshToken(participant); err != nil { + pLogger.Errorw("could not refresh token", err) + } + }) + participant.OnICEConfigChanged(func(participant types.LocalParticipant, iceConfig *livekit.ICEConfig) { + r.iceConfigCache.Put(iceConfigCacheKey{room.Name(), participant.Identity()}, iceConfig) + }) + + for _, addTrackRequest := range pi.AddTrackRequests { + participant.AddTrack(addTrackRequest) + } + if pi.PublisherOffer != nil { + participant.HandleOffer(pi.PublisherOffer) + } + + go r.rtcSessionWorker(room, participant, requestSource) + return nil +} + +// create the actual room object, to be used on RTC node +func (r *RoomManager) getOrCreateRoom(ctx context.Context, createRoom *livekit.CreateRoomRequest) (*rtc.Room, error) { + roomName := livekit.RoomName(createRoom.Name) + + r.lock.RLock() + lastSeenRoom := r.rooms[roomName] + r.lock.RUnlock() + + if lastSeenRoom != nil && lastSeenRoom.Hold() { + return lastSeenRoom, nil + } + + // create new room, get details first + ri, internal, created, err := r.roomAllocator.CreateRoom(ctx, createRoom, true) + if err != nil { + return nil, err + } + + r.lock.Lock() + + currentRoom := r.rooms[roomName] + for currentRoom != lastSeenRoom { + r.lock.Unlock() + if currentRoom != nil && currentRoom.Hold() { + return currentRoom, nil + } + + lastSeenRoom = currentRoom + r.lock.Lock() + currentRoom = r.rooms[roomName] + } + + // construct ice servers + newRoom := rtc.NewRoom(ri, internal, *r.rtcConfig, r.config.Room, &r.config.Audio, r.serverInfo, r.telemetry, r.agentClient, r.agentStore, r.egressLauncher) + + roomTopic := rpc.FormatRoomTopic(roomName) + roomServer := must.Get(rpc.NewTypedRoomServer(r, r.bus)) + killRoomServer := r.roomServers.Replace(roomTopic, roomServer) + if err := roomServer.RegisterAllRoomTopics(roomTopic); err != nil { + killRoomServer() + r.lock.Unlock() + return nil, err + } + agentDispatchServer := must.Get(rpc.NewTypedAgentDispatchInternalServer(r, r.bus)) + killDispServer := r.agentDispatchServers.Replace(roomTopic, agentDispatchServer) + if err := agentDispatchServer.RegisterAllRoomTopics(roomTopic); err != nil { + killRoomServer() + killDispServer() + r.lock.Unlock() + return nil, err + } + + newRoom.OnClose(func() { + killRoomServer() + killDispServer() + + roomInfo := newRoom.ToProto() + r.telemetry.RoomEnded(ctx, roomInfo) + prometheus.RoomEnded(time.Unix(roomInfo.CreationTime, 0)) + if err := r.deleteRoom(ctx, roomName); err != nil { + newRoom.Logger().Errorw("could not delete room", err) + } + + newRoom.Logger().Infow("room closed") + }) + + newRoom.OnRoomUpdated(func() { + if err := r.roomStore.StoreRoom(ctx, newRoom.ToProto(), newRoom.Internal()); err != nil { + newRoom.Logger().Errorw("could not handle metadata update", err) + } + }) + + newRoom.OnParticipantChanged(func(p types.Participant) { + if !p.IsDisconnected() { + if err := r.roomStore.StoreParticipant(ctx, roomName, p.ToProto()); err != nil { + newRoom.Logger().Errorw("could not handle participant change", err) + } + } + }) + + r.rooms[roomName] = newRoom + + r.lock.Unlock() + + newRoom.Hold() + + r.telemetry.RoomStarted(ctx, newRoom.ToProto()) + prometheus.RoomStarted() + + if created && createRoom.GetEgress().GetRoom() != nil { + // ensure room name matches + createRoom.Egress.Room.RoomName = createRoom.Name + _, err = r.egressLauncher.StartEgress(ctx, &rpc.StartEgressRequest{ + Request: &rpc.StartEgressRequest_RoomComposite{ + RoomComposite: createRoom.Egress.Room, + }, + RoomId: ri.Sid, + }) + if err != nil { + newRoom.Release() + return nil, err + } + } + + return newRoom, nil +} + +// manages an RTC session for a participant, runs on the RTC node +func (r *RoomManager) rtcSessionWorker(room *rtc.Room, participant types.LocalParticipant, requestSource routing.MessageSource) { + pLogger := participant.GetLogger() + defer func() { + pLogger.Debugw("RTC session finishing", "connID", requestSource.ConnectionID()) + requestSource.Close() + }() + + defer func() { + if r := rtc.Recover(pLogger); r != nil { + os.Exit(1) + } + }() + + // send first refresh for cases when client token is close to expiring + _ = r.refreshToken(participant) + tokenTicker := time.NewTicker(tokenRefreshInterval) + defer tokenTicker.Stop() + for { + select { + case <-participant.Disconnected(): + return + + case <-tokenTicker.C: + // refresh token with the first API Key/secret pair + if err := r.refreshToken(participant); err != nil { + pLogger.Errorw("could not refresh token", err, "connID", requestSource.ConnectionID()) + } + + case obj := <-requestSource.ReadChan(): + if obj == nil { + if room.GetParticipantRequestSource(participant.Identity()) == requestSource { + participant.HandleSignalSourceClose() + } + return + } + + if err := participant.HandleSignalMessage(obj); err != nil { + // more specific errors are already logged + // treat errors returned as fatal + return + } + } + } +} + +type participantReq interface { + GetRoom() string + GetIdentity() string +} + +func (r *RoomManager) roomAndParticipantForReq(ctx context.Context, req participantReq) (*rtc.Room, types.LocalParticipant, error) { + room := r.GetRoom(ctx, livekit.RoomName(req.GetRoom())) + if room == nil { + return nil, nil, ErrRoomNotFound + } + + participant := room.GetParticipant(livekit.ParticipantIdentity(req.GetIdentity())) + if participant == nil { + return nil, nil, ErrParticipantNotFound + } + + return room, participant, nil +} + +func (r *RoomManager) RemoveParticipant(ctx context.Context, req *livekit.RoomParticipantIdentity) (*livekit.RemoveParticipantResponse, error) { + room, participant, err := r.roomAndParticipantForReq(ctx, req) + if err != nil { + return nil, err + } + + participant.GetLogger().Infow("removing participant") + room.RemoveParticipant(livekit.ParticipantIdentity(req.Identity), "", types.ParticipantCloseReasonServiceRequestRemoveParticipant) + return &livekit.RemoveParticipantResponse{}, nil +} + +func (r *RoomManager) MutePublishedTrack(ctx context.Context, req *livekit.MuteRoomTrackRequest) (*livekit.MuteRoomTrackResponse, error) { + _, participant, err := r.roomAndParticipantForReq(ctx, req) + if err != nil { + return nil, err + } + + participant.GetLogger().Debugw("setting track muted", + "trackID", req.TrackSid, "muted", req.Muted) + if !req.Muted && !r.config.Room.EnableRemoteUnmute { + participant.GetLogger().Errorw("cannot unmute track, remote unmute is disabled", nil) + return nil, ErrRemoteUnmuteNoteEnabled + } + track := participant.SetTrackMuted(&livekit.MuteTrackRequest{ + Sid: req.TrackSid, + Muted: req.Muted, + }, true) + return &livekit.MuteRoomTrackResponse{Track: track}, nil +} + +func (r *RoomManager) UpdateParticipant(ctx context.Context, req *livekit.UpdateParticipantRequest) (*livekit.ParticipantInfo, error) { + _, participant, err := r.roomAndParticipantForReq(ctx, req) + if err != nil { + return nil, err + } + + if err = participant.UpdateMetadata(&livekit.UpdateParticipantMetadata{ + Name: req.Name, + Metadata: req.Metadata, + Attributes: req.Attributes, + }, true); err != nil { + return nil, err + } + + if req.Permission != nil { + participant.GetLogger().Debugw( + "updating participant permission", + "permission", req.Permission, + ) + + participant.SetPermission(req.Permission) + } + + return participant.ToProto(), nil +} + +func (r *RoomManager) ForwardParticipant(ctx context.Context, req *livekit.ForwardParticipantRequest) (*livekit.ForwardParticipantResponse, error) { + return nil, errors.New("not implemented") +} + +func (r *RoomManager) MoveParticipant(ctx context.Context, req *livekit.MoveParticipantRequest) (*livekit.MoveParticipantResponse, error) { + return nil, errors.New("not implemented") +} + +func (r *RoomManager) PerformRpc(ctx context.Context, req *livekit.PerformRpcRequest) (*livekit.PerformRpcResponse, error) { + room := r.GetRoom(ctx, livekit.RoomName(req.GetRoom())) + if room == nil { + return nil, ErrRoomNotFound + } + + participant := room.GetParticipant(livekit.ParticipantIdentity(req.GetDestinationIdentity())) + if participant == nil { + return nil, ErrParticipantNotFound + } + + resultChan := make(chan string, 1) + errorChan := make(chan error, 1) + + participant.PerformRpc(req, resultChan, errorChan) + + select { + case result := <-resultChan: + return &livekit.PerformRpcResponse{Payload: result}, nil + case err := <-errorChan: + return nil, err + } +} + +func (r *RoomManager) DeleteRoom(ctx context.Context, req *livekit.DeleteRoomRequest) (*livekit.DeleteRoomResponse, error) { + room := r.GetRoom(ctx, livekit.RoomName(req.Room)) + if room == nil { + // special case of a non-RTC room e.g. room created but no participants joined + logger.Debugw("Deleting non-rtc room, loading from roomstore") + err := r.roomStore.DeleteRoom(ctx, livekit.RoomName(req.Room)) + if err != nil { + logger.Debugw("Error deleting non-rtc room", "err", err) + return nil, err + } + } else { + room.Logger().Infow("deleting room") + room.Close(types.ParticipantCloseReasonServiceRequestDeleteRoom) + } + return &livekit.DeleteRoomResponse{}, nil +} + +func (r *RoomManager) UpdateSubscriptions(ctx context.Context, req *livekit.UpdateSubscriptionsRequest) (*livekit.UpdateSubscriptionsResponse, error) { + room, participant, err := r.roomAndParticipantForReq(ctx, req) + if err != nil { + return nil, err + } + + participant.GetLogger().Debugw("updating participant subscriptions") + room.UpdateSubscriptions( + participant, + livekit.StringsAsIDs[livekit.TrackID](req.TrackSids), + req.ParticipantTracks, + req.Subscribe, + ) + return &livekit.UpdateSubscriptionsResponse{}, nil +} + +func (r *RoomManager) SendData(ctx context.Context, req *livekit.SendDataRequest) (*livekit.SendDataResponse, error) { + room := r.GetRoom(ctx, livekit.RoomName(req.Room)) + if room == nil { + return nil, ErrRoomNotFound + } + + room.Logger().Debugw("api send data", "size", len(req.Data)) + room.SendDataPacket(&livekit.DataPacket{ + Kind: req.Kind, + DestinationIdentities: req.DestinationIdentities, + Value: &livekit.DataPacket_User{ + User: &livekit.UserPacket{ + Payload: req.Data, + DestinationSids: req.DestinationSids, + DestinationIdentities: req.DestinationIdentities, + Topic: req.Topic, + Nonce: req.Nonce, + }, + }, + }, req.Kind) + return &livekit.SendDataResponse{}, nil +} + +func (r *RoomManager) UpdateRoomMetadata(ctx context.Context, req *livekit.UpdateRoomMetadataRequest) (*livekit.Room, error) { + room := r.GetRoom(ctx, livekit.RoomName(req.Room)) + if room == nil { + return nil, ErrRoomNotFound + } + + room.Logger().Debugw("updating room") + done := room.SetMetadata(req.Metadata) + // wait till the update is applied + <-done + return room.ToProto(), nil +} + +func (r *RoomManager) ListDispatch(ctx context.Context, req *livekit.ListAgentDispatchRequest) (*livekit.ListAgentDispatchResponse, error) { + room := r.GetRoom(ctx, livekit.RoomName(req.Room)) + if room == nil { + return nil, ErrRoomNotFound + } + + disp, err := room.GetAgentDispatches(req.DispatchId) + if err != nil { + return nil, err + } + + ret := &livekit.ListAgentDispatchResponse{ + AgentDispatches: disp, + } + + return ret, nil +} + +func (r *RoomManager) CreateDispatch(ctx context.Context, req *livekit.AgentDispatch) (*livekit.AgentDispatch, error) { + room := r.GetRoom(ctx, livekit.RoomName(req.Room)) + if room == nil { + return nil, ErrRoomNotFound + } + + disp, err := room.AddAgentDispatch(req) + if err != nil { + return nil, err + } + + return disp, nil +} + +func (r *RoomManager) DeleteDispatch(ctx context.Context, req *livekit.DeleteAgentDispatchRequest) (*livekit.AgentDispatch, error) { + room := r.GetRoom(ctx, livekit.RoomName(req.Room)) + if room == nil { + return nil, ErrRoomNotFound + } + + disp, err := room.DeleteAgentDispatch(req.DispatchId) + if err != nil { + return nil, err + } + + return disp, nil +} + +func (r *RoomManager) iceServersForParticipant(apiKey string, participant types.LocalParticipant, tlsOnly bool) []*livekit.ICEServer { + var iceServers []*livekit.ICEServer + rtcConf := r.config.RTC + + if tlsOnly && r.config.TURN.TLSPort == 0 { + logger.Warnw("tls only enabled but no turn tls config", nil) + tlsOnly = false + } + + hasSTUN := false + if r.config.TURN.Enabled { + var urls []string + if r.config.TURN.UDPPort > 0 && !tlsOnly { + // UDP TURN is used as STUN + hasSTUN = true + urls = append(urls, fmt.Sprintf("turn:%s:%d?transport=udp", r.config.RTC.NodeIP, r.config.TURN.UDPPort)) + } + if r.config.TURN.TLSPort > 0 { + urls = append(urls, fmt.Sprintf("turns:%s:443?transport=tcp", r.config.TURN.Domain)) + } + if len(urls) > 0 { + username := r.turnAuthHandler.CreateUsername(apiKey, participant.ID()) + password, err := r.turnAuthHandler.CreatePassword(apiKey, participant.ID()) + if err != nil { + participant.GetLogger().Warnw("could not create turn password", err) + hasSTUN = false + } else { + logger.Infow("created TURN password", "username", username, "password", password) + iceServers = append(iceServers, &livekit.ICEServer{ + Urls: urls, + Username: username, + Credential: password, + }) + } + } + } + + if len(rtcConf.TURNServers) > 0 { + hasSTUN = true + for _, s := range r.config.RTC.TURNServers { + scheme := "turn" + transport := "tcp" + switch s.Protocol { + case "tls": + scheme = "turns" + case "udp": + transport = "udp" + } + + var username, credential string + if s.Secret != "" { + // Generate dynamic credentials using TURN static auth secrets + ttl := s.TTL + if ttl == 0 { + ttl = 14400 // Default 4 hours + } + + expiry := time.Now().Add(time.Duration(ttl) * time.Second).Unix() + participantID := string(participant.ID()) + username = fmt.Sprintf("%d:%s", expiry, participantID) + + // HMAC-SHA1 signature + h := hmac.New(sha1.New, []byte(s.Secret)) + h.Write([]byte(username)) + credential = base64.StdEncoding.EncodeToString(h.Sum(nil)) + } else { + // Use static credentials + username = s.Username + credential = s.Credential + } + + is := &livekit.ICEServer{ + Urls: []string{ + fmt.Sprintf("%s:%s:%d?transport=%s", scheme, s.Host, s.Port, transport), + }, + Username: username, + Credential: credential, + } + iceServers = append(iceServers, is) + } + } + + if len(rtcConf.STUNServers) > 0 { + hasSTUN = true + iceServers = append(iceServers, iceServerForStunServers(r.config.RTC.STUNServers)) + } + + if !hasSTUN { + iceServers = append(iceServers, iceServerForStunServers(rtcconfig.DefaultStunServers)) + } + return iceServers +} + +func (r *RoomManager) refreshToken(participant types.LocalParticipant) error { + key, secret, err := r.getFirstKeyPair() + if err != nil { + return err + } + + grants := participant.ClaimGrants() + token := auth.NewAccessToken(key, secret) + token.SetName(grants.Name). + SetIdentity(string(participant.Identity())). + SetKind(grants.GetParticipantKind()). + SetValidFor(tokenDefaultTTL). + SetMetadata(grants.Metadata). + SetAttributes(grants.Attributes). + SetVideoGrant(grants.Video). + SetRoomConfig(grants.GetRoomConfiguration()). + SetRoomPreset(grants.RoomPreset) + jwt, err := token.ToJWT() + if err == nil { + err = participant.SendRefreshToken(jwt) + } + if err != nil { + return err + } + + return nil +} + +func (r *RoomManager) setIceConfig(roomName livekit.RoomName, participant types.LocalParticipant) *livekit.ICEConfig { + iceConfig := r.getIceConfig(roomName, participant) + participant.SetICEConfig(iceConfig) + return iceConfig +} + +func (r *RoomManager) getIceConfig(roomName livekit.RoomName, participant types.LocalParticipant) *livekit.ICEConfig { + return r.iceConfigCache.Get(iceConfigCacheKey{roomName, participant.Identity()}) +} + +func (r *RoomManager) getFirstKeyPair() (string, string, error) { + for key, secret := range r.config.Keys { + return key, secret, nil + } + return "", "", errors.New("no API keys configured") +} + +// ------------------------------------ + +func iceServerForStunServers(servers []string) *livekit.ICEServer { + iceServer := &livekit.ICEServer{} + for _, stunServer := range servers { + iceServer.Urls = append(iceServer.Urls, fmt.Sprintf("stun:%s", stunServer)) + } + return iceServer +} + +// ------------------------------------ + +type roomManagerParticipantHelper struct { + room *rtc.Room + codecRegressionThreshold int +} + +func (h *roomManagerParticipantHelper) GetParticipantInfo(pID livekit.ParticipantID) *livekit.ParticipantInfo { + if p := h.room.GetParticipantByID(pID); p != nil { + return p.ToProto() + } + return nil +} + +func (h *roomManagerParticipantHelper) GetRegionSettings(ip string) *livekit.RegionSettings { + return nil +} + +func (h *roomManagerParticipantHelper) GetSubscriberForwarderState(lp types.LocalParticipant) (map[livekit.TrackID]*livekit.RTPForwarderState, error) { + return nil, nil +} + +func (h *roomManagerParticipantHelper) ResolveMediaTrack(lp types.LocalParticipant, trackID livekit.TrackID) types.MediaResolverResult { + return h.room.ResolveMediaTrackForSubscriber(lp, trackID) +} + +func (h *roomManagerParticipantHelper) ResolveDataTrack(lp types.LocalParticipant, trackID livekit.TrackID) types.DataResolverResult { + return h.room.ResolveDataTrackForSubscriber(lp, trackID) +} + +func (h *roomManagerParticipantHelper) ShouldRegressCodec() bool { + return h.codecRegressionThreshold == 0 || h.room.GetParticipantCount() < h.codecRegressionThreshold +} + +func (h *roomManagerParticipantHelper) GetCachedReliableDataMessage(seqs map[livekit.ParticipantID]uint32) []*types.DataMessageCache { + return h.room.GetCachedReliableDataMessage(seqs) +} diff --git a/livekit/pkg/service/roommanager_service.go b/livekit/pkg/service/roommanager_service.go new file mode 100644 index 0000000..0b5ce6d --- /dev/null +++ b/livekit/pkg/service/roommanager_service.go @@ -0,0 +1,337 @@ +package service + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/pion/webrtc/v4" + "golang.org/x/sync/errgroup" + "google.golang.org/protobuf/types/known/emptypb" + + "github.com/livekit/livekit-server/pkg/routing" + "github.com/livekit/livekit-server/pkg/rtc/types" + "github.com/livekit/livekit-server/pkg/telemetry/prometheus" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/rpc" + "github.com/livekit/psrpc" +) + +const ( + whipSessionNotifyInterval = 10 * time.Second +) + +type whipService struct { + *RoomManager + + ingressRpcCli rpc.IngressHandlerClient + + rpc.UnimplementedWHIPServer +} + +func newWhipService(rm *RoomManager) (*whipService, error) { + cli, err := rpc.NewIngressHandlerClient(rm.bus, rpc.WithDefaultClientOptions(logger.GetLogger())) + if err != nil { + return nil, err + } + return &whipService{ + RoomManager: rm, + ingressRpcCli: cli, + }, nil +} + +func (s whipService) Create(ctx context.Context, req *rpc.WHIPCreateRequest) (*rpc.WHIPCreateResponse, error) { + pi, err := routing.ParticipantInitFromStartSession(req.StartSession, s.RoomManager.currentNode.Region()) + if err != nil { + logger.Errorw("whip service: could not create participant init", err) + return nil, err + } + + prometheus.IncrementParticipantRtcInit(1) + + if err = s.RoomManager.StartSession( + ctx, + *pi, + routing.NewNullMessageSource(livekit.ConnectionID(req.StartSession.ConnectionId)), // no requestSource + routing.NewNullMessageSink(livekit.ConnectionID(req.StartSession.ConnectionId)), // no responseSink + true, // useOneShotSignallingMode + ); err != nil { + logger.Errorw("whip service: could not start session", err) + return nil, err + } + + room := s.RoomManager.GetRoom(ctx, livekit.RoomName(req.StartSession.RoomName)) + if room == nil { + logger.Errorw("whip service: could not find room", nil, "room", req.StartSession.RoomName) + return nil, ErrRoomNotFound + } + + lp := room.GetParticipant(pi.Identity) + if lp == nil { + room.Logger().Errorw("whip service: could not find local participant", nil, "participant", pi.Identity) + return nil, ErrParticipantNotFound + } + + if err := lp.HandleOffer(&livekit.SessionDescription{ + Type: webrtc.SDPTypeOffer.String(), + Sdp: req.OfferSdp, + Id: 0, + }); err != nil { + lp.GetLogger().Errorw("whip service: could not handle offer", err) + return nil, err + } + + // wait for subscriptions to resolve + // NOTE: this is outside the WHIP spec, but added as a convenience for clients doing + // one-shot signalling (i. e. send an offer and get an answer once) to publish and subscribe to + // well-known tracks (i. e. remote participant identity and track names are well known) + eg, _ := errgroup.WithContext(ctx) + for publisherIdentity, trackList := range req.SubscribedParticipantTracks { + for _, trackName := range trackList.TrackNames { + eg.Go(func() error { + for { + if lp.IsTrackNameSubscribed(livekit.ParticipantIdentity(publisherIdentity), trackName) { + return nil + } + time.Sleep(50 * time.Millisecond) + } + }) + } + } + err = eg.Wait() + if err != nil { + lp.GetLogger().Errorw("whip service: could not subscribe to tracks", err) + return nil, err + } + + answer, _, err := lp.GetAnswer() + if err != nil { + lp.GetLogger().Errorw("whip service: could not get answer", err) + return nil, err + } + + iceSessionID, err := lp.GetPublisherICESessionUfrag() + if err != nil { + lp.GetLogger().Errorw("whip service: could not get ICE session ID", err) + return nil, err + } + + if req.FromIngress { + aliveCtx, cancel := context.WithCancel(context.Background()) + + lp.AddOnClose(types.ParticipantCloseKeyWHIP, func(lp types.LocalParticipant) { + cancel() + + go func() { + lp.GetLogger().Debugw("whip service: notify participant closed") + + video, audio := getMediaStateForParticipant(lp) + + _, err := s.ingressRpcCli.WHIPRTCConnectionNotify(context.Background(), string(lp.ID()), &rpc.WHIPRTCConnectionNotifyRequest{ + ParticipantId: string(lp.ID()), + Closed: true, + Audio: audio, + Video: video, + }, psrpc.WithRequestTimeout(rpc.DefaultPSRPCConfig.Timeout)) + if err != nil { + lp.GetLogger().Warnw("whip service: could not notify ingress of participant closed", err) + } + }() + }) + go func() { + if err := s.notifySession(aliveCtx, lp); err != nil { + cancel() + } + }() + } + + var iceServers []*livekit.ICEServer + apiKey, _, err := s.RoomManager.getFirstKeyPair() + if err == nil { + iceServers = s.RoomManager.iceServersForParticipant( + apiKey, + lp, + false, + ) + } + return &rpc.WHIPCreateResponse{ + AnswerSdp: answer.SDP, + ParticipantId: string(lp.ID()), + IceServers: iceServers, + IceSessionId: iceSessionID, + }, nil +} + +func (s whipService) notifySession(ctx context.Context, participant types.Participant) error { + ticker := time.NewTicker(whipSessionNotifyInterval) + defer ticker.Stop() + + err := s.sendConnectionNotify(ctx, participant) + if err != nil { + if errors.Is(err, context.Canceled) { + return nil + } + } + + for { + select { + case <-ticker.C: + err := s.sendConnectionNotify(ctx, participant) + if err != nil { + if errors.Is(err, context.Canceled) { + return nil + } + } + + case <-ctx.Done(): + return nil + } + } +} + +func (s whipService) sendConnectionNotify(ctx context.Context, participant types.Participant) error { + video, audio := getMediaStateForParticipant(participant) + + _, err := s.ingressRpcCli.WHIPRTCConnectionNotify(ctx, string(participant.ID()), &rpc.WHIPRTCConnectionNotifyRequest{ + ParticipantId: string(participant.ID()), + Video: video, + Audio: audio, + }, psrpc.WithRequestTimeout(rpc.DefaultPSRPCConfig.Timeout)) + + return err +} + +func getMediaStateForParticipant(participant types.Participant) (*livekit.InputVideoState, *livekit.InputAudioState) { + pParticipant := participant.ToProto() + + var video *livekit.InputVideoState + var audio *livekit.InputAudioState + + for _, v := range pParticipant.Tracks { + if v == nil { + continue + } + + if v.Type != livekit.TrackType_VIDEO { + continue + } + + video = &livekit.InputVideoState{} + + video.MimeType = v.MimeType + video.Height = v.Height + video.Width = v.Width + + break + } + + for _, a := range pParticipant.Tracks { + if a == nil { + continue + } + + if a.Type != livekit.TrackType_AUDIO { + continue + } + + audio = &livekit.InputAudioState{} + + audio.MimeType = a.MimeType + audio.Channels = 1 + if a.Stereo { + audio.Channels = 2 + } + + break + } + + return video, audio +} + +// ------------------------------------------- + +type whipParticipantService struct { + *RoomManager +} + +func (r whipParticipantService) ICETrickle(ctx context.Context, req *rpc.WHIPParticipantICETrickleRequest) (*emptypb.Empty, error) { + room := r.RoomManager.GetRoom(ctx, livekit.RoomName(req.Room)) + if room == nil { + return nil, ErrRoomNotFound + } + + lp := room.GetParticipantByID(livekit.ParticipantID(req.ParticipantId)) + if lp == nil { + return nil, ErrParticipantNotFound + } + + iceSessionID, err := lp.GetPublisherICESessionUfrag() + if err != nil { + lp.GetLogger().Warnw("whipParticipant service ice-trickle: could not get ICE session ufrag", err) + return nil, psrpc.NewError(psrpc.Internal, err) + } + + if req.IceSessionId != iceSessionID { + return nil, psrpc.NewError( + psrpc.FailedPrecondition, + fmt.Errorf("ice session mismatch, expected: %s, got: %s", iceSessionID, req.IceSessionId), + ) + } + + if err := lp.HandleICETrickleSDPFragment(req.SdpFragment); err != nil { + return nil, psrpc.NewError(psrpc.InvalidArgument, err) + } + + return &emptypb.Empty{}, nil +} + +func (r whipParticipantService) ICERestart(ctx context.Context, req *rpc.WHIPParticipantICERestartRequest) (*rpc.WHIPParticipantICERestartResponse, error) { + room := r.RoomManager.GetRoom(ctx, livekit.RoomName(req.Room)) + if room == nil { + return nil, ErrRoomNotFound + } + + lp := room.GetParticipantByID(livekit.ParticipantID(req.ParticipantId)) + if lp == nil { + return nil, ErrParticipantNotFound + } + + sdpFragment, err := lp.HandleICERestartSDPFragment(req.SdpFragment) + if err != nil { + return nil, psrpc.NewError(psrpc.InvalidArgument, err) + } + + iceSessionID, err := lp.GetPublisherICESessionUfrag() + if err != nil { + lp.GetLogger().Warnw("whipParticipant service ice-restart: could not get ICE session ufrag", err) + return nil, psrpc.NewError(psrpc.Internal, err) + } + + return &rpc.WHIPParticipantICERestartResponse{ + IceSessionId: iceSessionID, + SdpFragment: sdpFragment, + }, nil +} + +func (r whipParticipantService) DeleteSession(ctx context.Context, req *rpc.WHIPParticipantDeleteSessionRequest) (*emptypb.Empty, error) { + room := r.RoomManager.GetRoom(ctx, livekit.RoomName(req.Room)) + if room == nil { + return nil, ErrRoomNotFound + } + + lp := room.GetParticipantByID(livekit.ParticipantID(req.ParticipantId)) + if lp != nil { + lp.AddOnClose(types.ParticipantCloseKeyWHIP, nil) + room.RemoveParticipant( + lp.Identity(), + lp.ID(), + types.ParticipantCloseReasonClientRequestLeave, + ) + } + + return &emptypb.Empty{}, nil +} + +// -------------------------------- diff --git a/livekit/pkg/service/roomservice.go b/livekit/pkg/service/roomservice.go new file mode 100644 index 0000000..84fbe79 --- /dev/null +++ b/livekit/pkg/service/roomservice.go @@ -0,0 +1,377 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package service + +import ( + "context" + "fmt" + "strconv" + + "github.com/twitchtv/twirp" + + "github.com/livekit/livekit-server/pkg/config" + "github.com/livekit/livekit-server/pkg/routing" + "github.com/livekit/livekit-server/pkg/rtc" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/rpc" +) + +type RoomService struct { + limitConf config.LimitConfig + apiConf config.APIConfig + router routing.MessageRouter + roomAllocator RoomAllocator + roomStore ServiceStore + egressLauncher rtc.EgressLauncher + topicFormatter rpc.TopicFormatter + roomClient rpc.TypedRoomClient + participantClient rpc.TypedParticipantClient + + rpc.UnimplementedRoomServer + rpc.UnimplementedParticipantServer +} + +func NewRoomService( + limitConf config.LimitConfig, + apiConf config.APIConfig, + router routing.MessageRouter, + roomAllocator RoomAllocator, + serviceStore ServiceStore, + egressLauncher rtc.EgressLauncher, + topicFormatter rpc.TopicFormatter, + roomClient rpc.TypedRoomClient, + participantClient rpc.TypedParticipantClient, +) (svc *RoomService, err error) { + svc = &RoomService{ + limitConf: limitConf, + apiConf: apiConf, + router: router, + roomAllocator: roomAllocator, + roomStore: serviceStore, + egressLauncher: egressLauncher, + topicFormatter: topicFormatter, + roomClient: roomClient, + participantClient: participantClient, + } + return +} + +func (s *RoomService) CreateRoom(ctx context.Context, req *livekit.CreateRoomRequest) (*livekit.Room, error) { + RecordRequest(ctx, req) + + AppendLogFields(ctx, "room", req.Name, "request", logger.Proto(req)) + if err := EnsureCreatePermission(ctx); err != nil { + return nil, twirpAuthError(err) + } else if req.Egress != nil && s.egressLauncher == nil { + return nil, ErrEgressNotConnected + } + + if !s.limitConf.CheckRoomNameLength(req.Name) { + return nil, fmt.Errorf("%w: max length %d", ErrRoomNameExceedsLimits, s.limitConf.MaxRoomNameLength) + } + + err := s.roomAllocator.SelectRoomNode(ctx, livekit.RoomName(req.Name), livekit.NodeID(req.NodeId)) + if err != nil { + return nil, err + } + + room, err := s.router.CreateRoom(ctx, req) + RecordResponse(ctx, room) + return room, err +} + +func (s *RoomService) ListRooms(ctx context.Context, req *livekit.ListRoomsRequest) (*livekit.ListRoomsResponse, error) { + RecordRequest(ctx, req) + + AppendLogFields(ctx, "room", req.Names) + err := EnsureListPermission(ctx) + if err != nil { + return nil, twirpAuthError(err) + } + + var names []livekit.RoomName + if len(req.Names) > 0 { + names = livekit.StringsAsIDs[livekit.RoomName](req.Names) + } + rooms, err := s.roomStore.ListRooms(ctx, names) + if err != nil { + // TODO: translate error codes to Twirp + return nil, err + } + + res := &livekit.ListRoomsResponse{ + Rooms: rooms, + } + RecordResponse(ctx, res) + return res, nil +} + +func (s *RoomService) DeleteRoom(ctx context.Context, req *livekit.DeleteRoomRequest) (*livekit.DeleteRoomResponse, error) { + RecordRequest(ctx, req) + + AppendLogFields(ctx, "room", req.Room) + if err := EnsureCreatePermission(ctx); err != nil { + return nil, twirpAuthError(err) + } + + exists, err := s.roomStore.RoomExists(ctx, livekit.RoomName(req.Room)) + if err != nil { + return nil, err + } else if !exists { + return nil, ErrRoomNotFound + } + + // ensure at least one node is available to handle the request + room, err := s.router.CreateRoom(ctx, &livekit.CreateRoomRequest{Name: req.Room}) + if err != nil { + return nil, err + } + + _, err = s.roomClient.DeleteRoom(ctx, s.topicFormatter.RoomTopic(ctx, livekit.RoomName(req.Room)), req) + if err != nil { + return nil, err + } + + if os, ok := s.roomStore.(OSSServiceStore); ok { + err = os.DeleteRoom(ctx, livekit.RoomName(req.Room)) + } + res := &livekit.DeleteRoomResponse{} + RecordResponse(ctx, room) + return res, err +} + +func (s *RoomService) ListParticipants(ctx context.Context, req *livekit.ListParticipantsRequest) (*livekit.ListParticipantsResponse, error) { + RecordRequest(ctx, req) + + AppendLogFields(ctx, "room", req.Room) + if err := EnsureAdminPermission(ctx, livekit.RoomName(req.Room)); err != nil { + return nil, twirpAuthError(err) + } + + participants, err := s.roomStore.ListParticipants(ctx, livekit.RoomName(req.Room)) + if err != nil { + return nil, err + } + + res := &livekit.ListParticipantsResponse{ + Participants: participants, + } + RecordResponse(ctx, res) + return res, nil +} + +func (s *RoomService) GetParticipant(ctx context.Context, req *livekit.RoomParticipantIdentity) (*livekit.ParticipantInfo, error) { + RecordRequest(ctx, req) + + AppendLogFields(ctx, "room", req.Room, "participant", req.Identity) + if err := EnsureAdminPermission(ctx, livekit.RoomName(req.Room)); err != nil { + return nil, twirpAuthError(err) + } + + participant, err := s.roomStore.LoadParticipant(ctx, livekit.RoomName(req.Room), livekit.ParticipantIdentity(req.Identity)) + if err != nil { + return nil, err + } + + RecordResponse(ctx, participant) + return participant, nil +} + +func (s *RoomService) RemoveParticipant(ctx context.Context, req *livekit.RoomParticipantIdentity) (*livekit.RemoveParticipantResponse, error) { + RecordRequest(ctx, req) + + AppendLogFields(ctx, "room", req.Room, "participant", req.Identity) + + if err := EnsureAdminPermission(ctx, livekit.RoomName(req.Room)); err != nil { + return nil, twirpAuthError(err) + } + + if _, err := s.roomStore.LoadParticipant(ctx, livekit.RoomName(req.Room), livekit.ParticipantIdentity(req.Identity)); err == ErrParticipantNotFound { + return nil, twirp.NotFoundError("participant not found") + } + + res, err := s.participantClient.RemoveParticipant(ctx, s.topicFormatter.ParticipantTopic(ctx, livekit.RoomName(req.Room), livekit.ParticipantIdentity(req.Identity)), req) + RecordResponse(ctx, res) + return res, err +} + +func (s *RoomService) MutePublishedTrack(ctx context.Context, req *livekit.MuteRoomTrackRequest) (*livekit.MuteRoomTrackResponse, error) { + RecordRequest(ctx, req) + + AppendLogFields(ctx, "room", req.Room, "participant", req.Identity, "trackID", req.TrackSid, "muted", req.Muted) + if err := EnsureAdminPermission(ctx, livekit.RoomName(req.Room)); err != nil { + return nil, twirpAuthError(err) + } + + res, err := s.participantClient.MutePublishedTrack(ctx, s.topicFormatter.ParticipantTopic(ctx, livekit.RoomName(req.Room), livekit.ParticipantIdentity(req.Identity)), req) + RecordResponse(ctx, res) + return res, err +} + +func (s *RoomService) UpdateParticipant(ctx context.Context, req *livekit.UpdateParticipantRequest) (*livekit.ParticipantInfo, error) { + RecordRequest(ctx, req) + + AppendLogFields(ctx, "room", req.Room, "participant", req.Identity) + + if !s.limitConf.CheckParticipantNameLength(req.Name) { + return nil, twirp.InvalidArgumentError(ErrNameExceedsLimits.Error(), strconv.Itoa(s.limitConf.MaxParticipantNameLength)) + } + + if !s.limitConf.CheckMetadataSize(req.Metadata) { + return nil, twirp.InvalidArgumentError(ErrMetadataExceedsLimits.Error(), strconv.Itoa(int(s.limitConf.MaxMetadataSize))) + } + + if !s.limitConf.CheckAttributesSize(req.Attributes) { + return nil, twirp.InvalidArgumentError(ErrAttributeExceedsLimits.Error(), strconv.Itoa(int(s.limitConf.MaxAttributesSize))) + } + + if err := EnsureAdminPermission(ctx, livekit.RoomName(req.Room)); err != nil { + return nil, twirpAuthError(err) + } + + if os, ok := s.roomStore.(OSSServiceStore); ok { + found, err := os.HasParticipant(ctx, livekit.RoomName(req.Room), livekit.ParticipantIdentity(req.Identity)) + if err != nil { + return nil, err + } else if !found { + return nil, ErrParticipantNotFound + } + } + + res, err := s.participantClient.UpdateParticipant(ctx, s.topicFormatter.ParticipantTopic(ctx, livekit.RoomName(req.Room), livekit.ParticipantIdentity(req.Identity)), req) + RecordResponse(ctx, res) + return res, err +} + +func (s *RoomService) UpdateSubscriptions(ctx context.Context, req *livekit.UpdateSubscriptionsRequest) (*livekit.UpdateSubscriptionsResponse, error) { + RecordRequest(ctx, req) + + trackSIDs := append(make([]string, 0), req.TrackSids...) + for _, pt := range req.ParticipantTracks { + trackSIDs = append(trackSIDs, pt.TrackSids...) + } + AppendLogFields(ctx, "room", req.Room, "participant", req.Identity, "trackID", trackSIDs) + + if err := EnsureAdminPermission(ctx, livekit.RoomName(req.Room)); err != nil { + return nil, twirpAuthError(err) + } + + res, err := s.participantClient.UpdateSubscriptions(ctx, s.topicFormatter.ParticipantTopic(ctx, livekit.RoomName(req.Room), livekit.ParticipantIdentity(req.Identity)), req) + RecordResponse(ctx, res) + return res, err +} + +func (s *RoomService) SendData(ctx context.Context, req *livekit.SendDataRequest) (*livekit.SendDataResponse, error) { + RecordRequest(ctx, req) + + roomName := livekit.RoomName(req.Room) + AppendLogFields(ctx, "room", roomName, "size", len(req.Data)) + if err := EnsureAdminPermission(ctx, roomName); err != nil { + return nil, twirpAuthError(err) + } + + // nonce is either absent or 128-bit UUID + if len(req.Nonce) != 0 && len(req.Nonce) != 16 { + return nil, twirp.NewError(twirp.InvalidArgument, fmt.Sprintf("nonce should be 16-bytes or not present, got: %d bytes", len(req.Nonce))) + } + + res, err := s.roomClient.SendData(ctx, s.topicFormatter.RoomTopic(ctx, livekit.RoomName(req.Room)), req) + RecordResponse(ctx, res) + return res, err +} + +func (s *RoomService) UpdateRoomMetadata(ctx context.Context, req *livekit.UpdateRoomMetadataRequest) (*livekit.Room, error) { + RecordRequest(ctx, req) + + AppendLogFields(ctx, "room", req.Room, "size", len(req.Metadata)) + maxMetadataSize := int(s.limitConf.MaxMetadataSize) + if maxMetadataSize > 0 && len(req.Metadata) > maxMetadataSize { + return nil, twirp.InvalidArgumentError(ErrMetadataExceedsLimits.Error(), strconv.Itoa(maxMetadataSize)) + } + + if err := EnsureAdminPermission(ctx, livekit.RoomName(req.Room)); err != nil { + return nil, twirpAuthError(err) + } + + exists, err := s.roomStore.RoomExists(ctx, livekit.RoomName(req.Room)) + if err != nil { + return nil, err + } else if !exists { + return nil, ErrRoomNotFound + } + + room, err := s.roomClient.UpdateRoomMetadata(ctx, s.topicFormatter.RoomTopic(ctx, livekit.RoomName(req.Room)), req) + if err != nil { + return nil, err + } + + RecordResponse(ctx, room) + return room, nil +} + +func (s *RoomService) ForwardParticipant(ctx context.Context, req *livekit.ForwardParticipantRequest) (*livekit.ForwardParticipantResponse, error) { + RecordRequest(ctx, req) + + roomName := livekit.RoomName(req.Room) + AppendLogFields(ctx, "room", roomName, "participant", req.Identity) + if err := EnsureDestRoomPermission(ctx, roomName, livekit.RoomName(req.DestinationRoom)); err != nil { + return nil, twirpAuthError(err) + } + + if req.Room == req.DestinationRoom { + return nil, twirp.InvalidArgumentError(ErrDestinationSameAsSourceRoom.Error(), "") + } + + res, err := s.participantClient.ForwardParticipant(ctx, s.topicFormatter.ParticipantTopic(ctx, livekit.RoomName(req.Room), livekit.ParticipantIdentity(req.Identity)), req) + RecordResponse(ctx, res) + return res, err +} + +func (s *RoomService) MoveParticipant(ctx context.Context, req *livekit.MoveParticipantRequest) (*livekit.MoveParticipantResponse, error) { + RecordRequest(ctx, req) + + roomName := livekit.RoomName(req.Room) + AppendLogFields(ctx, "room", roomName, "participant", req.Identity) + if err := EnsureDestRoomPermission(ctx, roomName, livekit.RoomName(req.DestinationRoom)); err != nil { + return nil, twirpAuthError(err) + } + + if req.Room == req.DestinationRoom { + return nil, twirp.InvalidArgumentError(ErrDestinationSameAsSourceRoom.Error(), "") + } + + res, err := s.participantClient.MoveParticipant(ctx, s.topicFormatter.ParticipantTopic(ctx, livekit.RoomName(req.Room), livekit.ParticipantIdentity(req.Identity)), req) + RecordResponse(ctx, res) + return res, err +} + +func (s *RoomService) PerformRpc(ctx context.Context, req *livekit.PerformRpcRequest) (*livekit.PerformRpcResponse, error) { + RecordRequest(ctx, req) + + roomName := livekit.RoomName(req.Room) + AppendLogFields(ctx, "room", roomName, "participant", req.DestinationIdentity) + + if err := EnsureAdminPermission(ctx, roomName); err != nil { + return nil, twirpAuthError(err) + } + if req.DestinationIdentity == "" { + return nil, ErrDestinationIdentityRequired + } + + res, err := s.participantClient.PerformRpc(ctx, s.topicFormatter.ParticipantTopic(ctx, roomName, livekit.ParticipantIdentity(req.DestinationIdentity)), req) + RecordResponse(ctx, res) + return res, err +} diff --git a/livekit/pkg/service/roomservice_test.go b/livekit/pkg/service/roomservice_test.go new file mode 100644 index 0000000..d34d510 --- /dev/null +++ b/livekit/pkg/service/roomservice_test.go @@ -0,0 +1,138 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package service_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "github.com/twitchtv/twirp" + + "github.com/livekit/protocol/auth" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/rpc" + "github.com/livekit/protocol/rpc/rpcfakes" + + "github.com/livekit/livekit-server/pkg/config" + "github.com/livekit/livekit-server/pkg/routing/routingfakes" + "github.com/livekit/livekit-server/pkg/service" + "github.com/livekit/livekit-server/pkg/service/servicefakes" +) + +func TestDeleteRoom(t *testing.T) { + t.Run("missing permissions", func(t *testing.T) { + svc := newTestRoomService(config.LimitConfig{}) + grant := &auth.ClaimGrants{ + Video: &auth.VideoGrant{}, + } + ctx := service.WithGrants(context.Background(), grant, "") + _, err := svc.DeleteRoom(ctx, &livekit.DeleteRoomRequest{ + Room: "testroom", + }) + require.Error(t, err) + }) +} + +func TestMetaDataLimits(t *testing.T) { + t.Run("metadata exceed limits", func(t *testing.T) { + svc := newTestRoomService(config.LimitConfig{MaxMetadataSize: 5}) + grant := &auth.ClaimGrants{ + Video: &auth.VideoGrant{}, + } + ctx := service.WithGrants(context.Background(), grant, "") + _, err := svc.UpdateParticipant(ctx, &livekit.UpdateParticipantRequest{ + Room: "testroom", + Identity: "123", + Metadata: "abcdefg", + }) + terr, ok := err.(twirp.Error) + require.True(t, ok) + require.Equal(t, twirp.InvalidArgument, terr.Code()) + + _, err = svc.UpdateRoomMetadata(ctx, &livekit.UpdateRoomMetadataRequest{ + Room: "testroom", + Metadata: "abcdefg", + }) + terr, ok = err.(twirp.Error) + require.True(t, ok) + require.Equal(t, twirp.InvalidArgument, terr.Code()) + }) + + notExceedsLimitsSvc := map[string]*TestRoomService{ + "metadata exceeds limits": newTestRoomService(config.LimitConfig{MaxMetadataSize: 5}), + "metadata no limits": newTestRoomService(config.LimitConfig{}), // no limits + } + + for n, s := range notExceedsLimitsSvc { + svc := s + t.Run(n, func(t *testing.T) { + grant := &auth.ClaimGrants{ + Video: &auth.VideoGrant{}, + } + ctx := service.WithGrants(context.Background(), grant, "") + _, err := svc.UpdateParticipant(ctx, &livekit.UpdateParticipantRequest{ + Room: "testroom", + Identity: "123", + Metadata: "abc", + }) + terr, ok := err.(twirp.Error) + require.True(t, ok) + require.NotEqual(t, twirp.InvalidArgument, terr.Code()) + + _, err = svc.UpdateRoomMetadata(ctx, &livekit.UpdateRoomMetadataRequest{ + Room: "testroom", + Metadata: "abc", + }) + terr, ok = err.(twirp.Error) + require.True(t, ok) + require.NotEqual(t, twirp.InvalidArgument, terr.Code()) + }) + + } +} + +func newTestRoomService(limitConf config.LimitConfig) *TestRoomService { + router := &routingfakes.FakeRouter{} + allocator := &servicefakes.FakeRoomAllocator{} + store := &servicefakes.FakeServiceStore{} + svc, err := service.NewRoomService( + limitConf, + config.APIConfig{ExecutionTimeout: 2}, + router, + allocator, + store, + nil, + rpc.NewTopicFormatter(), + &rpcfakes.FakeTypedRoomClient{}, + &rpcfakes.FakeTypedParticipantClient{}, + ) + if err != nil { + panic(err) + } + return &TestRoomService{ + RoomService: *svc, + router: router, + allocator: allocator, + store: store, + } +} + +type TestRoomService struct { + service.RoomService + router *routingfakes.FakeRouter + allocator *servicefakes.FakeRoomAllocator + store *servicefakes.FakeServiceStore +} diff --git a/livekit/pkg/service/rtcservice.go b/livekit/pkg/service/rtcservice.go new file mode 100644 index 0000000..03fc139 --- /dev/null +++ b/livekit/pkg/service/rtcservice.go @@ -0,0 +1,664 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package service + +import ( + "bytes" + "compress/gzip" + "context" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "math/rand" + "net/http" + "os" + "strconv" + "sync" + "time" + + "github.com/gorilla/websocket" + "go.uber.org/atomic" + "golang.org/x/exp/maps" + "google.golang.org/protobuf/proto" + + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/psrpc" + + "github.com/livekit/livekit-server/pkg/config" + "github.com/livekit/livekit-server/pkg/routing" + "github.com/livekit/livekit-server/pkg/rtc" + "github.com/livekit/livekit-server/pkg/telemetry" + "github.com/livekit/livekit-server/pkg/telemetry/prometheus" + "github.com/livekit/livekit-server/pkg/utils" +) + +type RTCService struct { + router routing.MessageRouter + roomAllocator RoomAllocator + upgrader websocket.Upgrader + config *config.Config + isDev bool + limits config.LimitConfig + telemetry telemetry.TelemetryService + + mu sync.Mutex + connections map[*websocket.Conn]struct{} +} + +func NewRTCService( + conf *config.Config, + ra RoomAllocator, + router routing.MessageRouter, + telemetry telemetry.TelemetryService, +) *RTCService { + s := &RTCService{ + router: router, + roomAllocator: ra, + config: conf, + isDev: conf.Development, + limits: conf.Limit, + telemetry: telemetry, + connections: map[*websocket.Conn]struct{}{}, + } + + s.upgrader = websocket.Upgrader{ + EnableCompression: true, + + // allow connections from any origin, since script may be hosted anywhere + // security is enforced by access tokens + CheckOrigin: func(r *http.Request) bool { + return true + }, + } + + return s +} + +func (s *RTCService) SetupRoutes(mux *http.ServeMux) { + mux.HandleFunc("/rtc", s.v0) + mux.HandleFunc("/rtc/validate", s.v0Validate) + mux.HandleFunc("/rtc/v1", s.v1) + mux.HandleFunc("/rtc/v1/validate", s.v1Validate) +} + +func (s *RTCService) v0Validate(w http.ResponseWriter, r *http.Request) { + lgr := utils.GetLogger(r.Context()) + _, _, code, err := s.validateInternal(lgr, r, false, true) + if err != nil { + HandleError(w, r, code, err) + return + } + _, _ = w.Write([]byte("success")) +} + +func (s *RTCService) v1Validate(w http.ResponseWriter, r *http.Request) { + lgr := utils.GetLogger(r.Context()) + _, _, code, err := s.validateInternal(lgr, r, true, true) + if err != nil { + HandleError(w, r, code, err) + return + } + _, _ = w.Write([]byte("success")) +} + +func decodeAttributes(str string) (map[string]string, error) { + data, err := base64.URLEncoding.DecodeString(str) + if err != nil { + return nil, err + } + var attrs map[string]string + if err := json.Unmarshal(data, &attrs); err != nil { + return nil, err + } + return attrs, nil +} + +var gzipReaderPool = sync.Pool{ + New: func() any { return &gzip.Reader{} }, +} + +func (s *RTCService) validateInternal( + lgr logger.Logger, + r *http.Request, + needsJoinRequest bool, + strict bool, +) (livekit.RoomName, routing.ParticipantInit, int, error) { + var params ValidateConnectRequestParams + useSinglePeerConnection := false + joinRequest := &livekit.JoinRequest{} + + wrappedJoinRequestBase64 := r.FormValue("join_request") + if wrappedJoinRequestBase64 == "" { + if needsJoinRequest { + return "", routing.ParticipantInit{}, http.StatusBadRequest, errors.New("join_request is required") + } + + params.publish = r.FormValue("publish") + + attributesStrParam := r.FormValue("attributes") + if attributesStrParam != "" { + attrs, err := decodeAttributes(attributesStrParam) + if err != nil { + if strict { + return "", routing.ParticipantInit{}, http.StatusBadRequest, errors.New("cannot decode attributes") + } + lgr.Debugw("failed to decode attributes", "error", err) + // attrs will be empty here, so just proceed + } + params.attributes = attrs + } + } else { + useSinglePeerConnection = true + if wrappedProtoBytes, err := base64.URLEncoding.DecodeString(wrappedJoinRequestBase64); err != nil { + return "", routing.ParticipantInit{}, http.StatusBadRequest, errors.New("cannot base64 decode wrapped join request") + } else { + wrappedJoinRequest := &livekit.WrappedJoinRequest{} + if err := proto.Unmarshal(wrappedProtoBytes, wrappedJoinRequest); err != nil { + return "", routing.ParticipantInit{}, http.StatusBadRequest, errors.New("cannot unmarshal wrapped join request") + } + + switch wrappedJoinRequest.Compression { + case livekit.WrappedJoinRequest_NONE: + if err := proto.Unmarshal(wrappedJoinRequest.JoinRequest, joinRequest); err != nil { + return "", routing.ParticipantInit{}, http.StatusBadRequest, errors.New("cannot unmarshal join request") + } + + case livekit.WrappedJoinRequest_GZIP: + reader := gzipReaderPool.Get().(*gzip.Reader) + defer gzipReaderPool.Put(reader) + reader.Reset(bytes.NewReader(wrappedJoinRequest.JoinRequest)) + protoBytes, err := io.ReadAll(reader) + if err != nil { + return "", routing.ParticipantInit{}, http.StatusBadRequest, errors.New("cannot read decompressed join request") + } + + if err := proto.Unmarshal(protoBytes, joinRequest); err != nil { + return "", routing.ParticipantInit{}, http.StatusBadRequest, errors.New("cannot unmarshal join request") + } + } + + params.metadata = joinRequest.Metadata + params.attributes = joinRequest.ParticipantAttributes + } + } + + res, code, err := ValidateConnectRequest( + lgr, + r, + s.limits, + params, + s.router, + s.roomAllocator, + ) + if err != nil { + return res.roomName, routing.ParticipantInit{}, code, err + } + + pi := routing.ParticipantInit{ + Identity: livekit.ParticipantIdentity(res.grants.Identity), + Name: livekit.ParticipantName(res.grants.Name), + Grants: res.grants, + Region: res.region, + CreateRoom: res.createRoomRequest, + UseSinglePeerConnection: useSinglePeerConnection, + } + + if wrappedJoinRequestBase64 == "" { + pi.Reconnect = boolValue(r.FormValue("reconnect")) + pi.Client = ParseClientInfo(r) + + pi.AutoSubscribe = true + if autoSubscribeParam := r.FormValue("auto_subscribe"); autoSubscribeParam != "" { + pi.AutoSubscribe = boolValue(autoSubscribeParam) + } + + if autoSubscribeDataTrackParam := r.FormValue("auto_subscribe_data_track"); autoSubscribeDataTrackParam != "" { + autoSubscribeDataTrack := boolValue(autoSubscribeDataTrackParam) + pi.AutoSubscribeDataTrack = &autoSubscribeDataTrack + } + + pi.AdaptiveStream = boolValue(r.FormValue("adaptive_stream")) + pi.DisableICELite = boolValue(r.FormValue("disable_ice_lite")) + + reconnectReason, _ := strconv.Atoi(r.FormValue("reconnect_reason")) // 0 means unknown reason + pi.ReconnectReason = livekit.ReconnectReason(reconnectReason) + + if pi.Reconnect { + pi.ID = livekit.ParticipantID(r.FormValue("sid")) + } + + if subscriberAllowPauseParam := r.FormValue("subscriber_allow_pause"); subscriberAllowPauseParam != "" { + subscriberAllowPause := boolValue(subscriberAllowPauseParam) + pi.SubscriberAllowPause = &subscriberAllowPause + } + } else { + lgr.Debugw("processing join request", "joinRequest", logger.Proto(joinRequest)) + + AugmentClientInfo(joinRequest.ClientInfo, r) + pi.Client = joinRequest.ClientInfo + + pi.AutoSubscribe = joinRequest.GetConnectionSettings().GetAutoSubscribe() + + autoSubscribeDataTrack := joinRequest.GetConnectionSettings().GetAutoSubscribeDataTrack() + pi.AutoSubscribeDataTrack = &autoSubscribeDataTrack + + pi.AdaptiveStream = joinRequest.GetConnectionSettings().GetAdaptiveStream() + pi.DisableICELite = joinRequest.GetConnectionSettings().GetDisableIceLite() + + subscriberAllowPause := joinRequest.GetConnectionSettings().GetSubscriberAllowPause() + pi.SubscriberAllowPause = &subscriberAllowPause + + pi.AddTrackRequests = joinRequest.AddTrackRequests + pi.PublisherOffer = joinRequest.PublisherOffer + + pi.Reconnect = joinRequest.Reconnect + pi.ReconnectReason = joinRequest.ReconnectReason + pi.ID = livekit.ParticipantID(joinRequest.ParticipantSid) + } + + return res.roomName, pi, code, err +} + +func (s *RTCService) v0(w http.ResponseWriter, r *http.Request) { + s.serve(w, r, false) +} + +func (s *RTCService) v1(w http.ResponseWriter, r *http.Request) { + s.serve(w, r, true) +} + +func (s *RTCService) serve(w http.ResponseWriter, r *http.Request, needsJoinRequest bool) { + // reject non websocket requests + if !websocket.IsWebSocketUpgrade(r) { + w.WriteHeader(404) + return + } + + var ( + roomName livekit.RoomName + roomID livekit.RoomID + participantIdentity livekit.ParticipantIdentity + pID livekit.ParticipantID + loggerResolved bool + + pi routing.ParticipantInit + code int + err error + ) + + pLogger, loggerResolver := utils.GetLogger(r.Context()).WithDeferredValues() + + getLoggerFields := func() []any { + return []any{ + "room", roomName, + "roomID", roomID, + "participant", participantIdentity, + "pID", pID, + } + } + + resolveLogger := func(force bool) { + if loggerResolved { + return + } + + if force || (roomName != "" && roomID != "" && participantIdentity != "" && pID != "") { + loggerResolved = true + loggerResolver.Resolve(getLoggerFields()...) + } + } + + resetLogger := func() { + loggerResolver.Reset() + + roomName = "" + roomID = "" + participantIdentity = "" + pID = "" + loggerResolved = false + } + + roomName, pi, code, err = s.validateInternal(pLogger, r, needsJoinRequest, false) + if err != nil { + HandleError(w, r, code, err) + return + } + + participantIdentity = pi.Identity + if pi.ID != "" { + pID = pi.ID + } + + // give it a few attempts to start session + var cr connectionResult + var initialResponse *livekit.SignalResponse + for attempt := 0; attempt < s.config.SignalRelay.ConnectAttempts; attempt++ { + connectionTimeout := 3 * time.Second * time.Duration(attempt+1) + ctx := utils.ContextWithAttempt(r.Context(), attempt) + cr, initialResponse, err = s.startConnection(ctx, roomName, pi, connectionTimeout) + if err == nil || errors.Is(err, context.Canceled) { + break + } + } + + if err != nil { + prometheus.IncrementParticipantJoinFail(1) + status := http.StatusInternalServerError + var psrpcErr psrpc.Error + if errors.As(err, &psrpcErr) { + status = psrpcErr.ToHttp() + } + HandleError(w, r, status, err, getLoggerFields()...) + return + } + + prometheus.IncrementParticipantJoin(1) + + pLogger = pLogger.WithValues("connID", cr.ConnectionID) + if !pi.Reconnect && initialResponse.GetJoin() != nil { + joinRoomID := livekit.RoomID(initialResponse.GetJoin().GetRoom().GetSid()) + if joinRoomID != "" { + roomID = joinRoomID + } + + pi.ID = livekit.ParticipantID(initialResponse.GetJoin().GetParticipant().GetSid()) + pID = pi.ID + + resolveLogger(false) + } + + signalStats := telemetry.NewBytesSignalStats(r.Context(), s.telemetry) + if join := initialResponse.GetJoin(); join != nil { + signalStats.ResolveRoom(join.GetRoom()) + signalStats.ResolveParticipant(join.GetParticipant()) + } + if pi.Reconnect && pi.ID != "" { + signalStats.ResolveParticipant(&livekit.ParticipantInfo{ + Sid: string(pi.ID), + Identity: string(pi.Identity), + }) + } + + closedByClient := atomic.NewBool(false) + done := make(chan struct{}) + // function exits when websocket terminates, it'll close the event reading off of request sink and response source as well + defer func() { + pLogger.Debugw("finishing WS connection", "closedByClient", closedByClient.Load()) + cr.ResponseSource.Close() + cr.RequestSink.Close() + close(done) + + signalStats.Stop() + }() + + // upgrade only once the basics are good to go + conn, err := s.upgrader.Upgrade(w, r, nil) + if err != nil { + HandleError(w, r, http.StatusInternalServerError, err, getLoggerFields()...) + return + } + + s.mu.Lock() + s.connections[conn] = struct{}{} + s.mu.Unlock() + + defer func() { + s.mu.Lock() + delete(s.connections, conn) + s.mu.Unlock() + }() + + // websocket established + sigConn := NewWSSignalConnection(conn) + pLogger.Debugw("sending initial response", "response", logger.Proto(initialResponse)) + count, err := sigConn.WriteResponse(initialResponse) + if err != nil { + resolveLogger(true) + pLogger.Warnw("could not write initial response", err) + return + } + signalStats.AddBytes(uint64(count), true) + + pLogger.Debugw( + "new client WS connected", + "reconnect", pi.Reconnect, + "reconnectReason", pi.ReconnectReason, + "adaptiveStream", pi.AdaptiveStream, + "selectedNodeID", cr.NodeID, + "nodeSelectionReason", cr.NodeSelectionReason, + ) + + // handle responses + go func() { + defer func() { + // when the source is terminated, this means Participant.Close had been called and RTC connection is done + // we would terminate the signal connection as well + closeMsg := websocket.FormatCloseMessage(websocket.CloseNormalClosure, "") + _ = conn.WriteControl(websocket.CloseMessage, closeMsg, time.Now().Add(time.Second)) + _ = conn.Close() + }() + defer func() { + if r := rtc.Recover(pLogger); r != nil { + os.Exit(1) + } + }() + for { + select { + case <-done: + return + case msg := <-cr.ResponseSource.ReadChan(): + if msg == nil { + resolveLogger(true) + pLogger.Debugw("nothing to read from response source") + return + } + res, ok := msg.(*livekit.SignalResponse) + if !ok { + pLogger.Errorw( + "unexpected message type", nil, + "type", fmt.Sprintf("%T", msg), + ) + continue + } + + switch m := res.Message.(type) { + case *livekit.SignalResponse_Offer: + pLogger.Debugw("sending offer", "offer", m) + + case *livekit.SignalResponse_Answer: + pLogger.Debugw("sending answer", "answer", m) + + case *livekit.SignalResponse_Join: + pLogger.Debugw("sending join", "join", m) + signalStats.ResolveRoom(m.Join.GetRoom()) + signalStats.ResolveParticipant(m.Join.GetParticipant()) + + case *livekit.SignalResponse_RoomUpdate: + updateRoomID := livekit.RoomID(m.RoomUpdate.GetRoom().GetSid()) + if updateRoomID != "" { + roomID = updateRoomID + resolveLogger(false) + } + pLogger.Debugw("sending room update", "roomUpdate", m) + signalStats.ResolveRoom(m.RoomUpdate.GetRoom()) + + case *livekit.SignalResponse_Update: + pLogger.Debugw("sending participant update", "participantUpdate", m) + + case *livekit.SignalResponse_RoomMoved: + resetLogger() + signalStats.Reset() + + roomName = livekit.RoomName(m.RoomMoved.GetRoom().GetName()) + moveRoomID := livekit.RoomID(m.RoomMoved.GetRoom().GetSid()) + if moveRoomID != "" { + roomID = moveRoomID + } + participantIdentity = livekit.ParticipantIdentity(m.RoomMoved.GetParticipant().GetIdentity()) + pID = livekit.ParticipantID(m.RoomMoved.GetParticipant().GetSid()) + resolveLogger(false) + + signalStats.ResolveRoom(m.RoomMoved.GetRoom()) + signalStats.ResolveParticipant(m.RoomMoved.GetParticipant()) + pLogger.Debugw("sending room moved", "roomMoved", m) + + default: + pLogger.Debugw("sending signal response", "response", m) + } + + if count, err := sigConn.WriteResponse(res); err != nil { + pLogger.Warnw("error writing to websocket", err) + return + } else { + signalStats.AddBytes(uint64(count), true) + } + } + } + }() + + // handle incoming requests from websocket + for { + req, count, err := sigConn.ReadRequest() + if err != nil { + if IsWebSocketCloseError(err) { + closedByClient.Store(true) + } else { + pLogger.Errorw("error reading from websocket", err) + } + return + } + signalStats.AddBytes(uint64(count), false) + + switch m := req.Message.(type) { + case *livekit.SignalRequest_Ping: + count, perr := sigConn.WriteResponse(&livekit.SignalResponse{ + Message: &livekit.SignalResponse_Pong{ + // + // Although this field is int64, some clients (like JS) cause overflow if nanosecond granularity is used. + // So. use UnixMillis(). + // + Pong: time.Now().UnixMilli(), + }, + }) + if perr == nil { + signalStats.AddBytes(uint64(count), true) + } + case *livekit.SignalRequest_PingReq: + count, perr := sigConn.WriteResponse(&livekit.SignalResponse{ + Message: &livekit.SignalResponse_PongResp{ + PongResp: &livekit.Pong{ + LastPingTimestamp: m.PingReq.Timestamp, + Timestamp: time.Now().UnixMilli(), + }, + }, + }) + if perr == nil { + signalStats.AddBytes(uint64(count), true) + } + } + + switch m := req.Message.(type) { + case *livekit.SignalRequest_Offer: + pLogger.Debugw("received offer", "offer", m) + case *livekit.SignalRequest_Answer: + pLogger.Debugw("received answer", "answer", m) + default: + pLogger.Debugw("received signal request", "request", m) + } + + if err := cr.RequestSink.WriteMessage(req); err != nil { + pLogger.Warnw("error writing to request sink", err) + return + } + } +} + +func (s *RTCService) DrainConnections(interval time.Duration) { + s.mu.Lock() + conns := maps.Clone(s.connections) + s.mu.Unlock() + + // jitter drain start + time.Sleep(time.Duration(rand.Int63n(int64(interval)))) + + t := time.NewTicker(interval) + defer t.Stop() + + for c := range conns { + _ = c.Close() + <-t.C + } +} + +type connectionResult struct { + routing.StartParticipantSignalResults + Room *livekit.Room +} + +func (s *RTCService) startConnection( + ctx context.Context, + roomName livekit.RoomName, + pi routing.ParticipantInit, + timeout time.Duration, +) (connectionResult, *livekit.SignalResponse, error) { + var cr connectionResult + var err error + + if err := s.roomAllocator.SelectRoomNode(ctx, roomName, ""); err != nil { + return cr, nil, err + } + + // this needs to be started first *before* using router functions on this node + cr.StartParticipantSignalResults, err = s.router.StartParticipantSignal(ctx, roomName, pi) + if err != nil { + return cr, nil, err + } + + // wait for the first message before upgrading to websocket. If no one is + // responding to our connection attempt, we should terminate the connection + // instead of waiting forever on the WebSocket + initialResponse, err := readInitialResponse(cr.ResponseSource, timeout) + if err != nil { + // close the connection to avoid leaking + cr.RequestSink.Close() + cr.ResponseSource.Close() + return cr, nil, err + } + + return cr, initialResponse, nil +} + +func readInitialResponse(source routing.MessageSource, timeout time.Duration) (*livekit.SignalResponse, error) { + responseTimer := time.NewTimer(timeout) + defer responseTimer.Stop() + for { + select { + case <-responseTimer.C: + return nil, errors.New("timed out while waiting for signal response") + case msg := <-source.ReadChan(): + if msg == nil { + return nil, errors.New("connection closed by media") + } + res, ok := msg.(*livekit.SignalResponse) + if !ok { + return nil, fmt.Errorf("unexpected message type: %T", msg) + } + return res, nil + } + } +} diff --git a/livekit/pkg/service/server.go b/livekit/pkg/service/server.go new file mode 100644 index 0000000..11c99d2 --- /dev/null +++ b/livekit/pkg/service/server.go @@ -0,0 +1,409 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package service + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net" + "net/http" + _ "net/http/pprof" + "runtime" + "runtime/pprof" + "strconv" + "time" + + "github.com/pion/turn/v4" + "github.com/prometheus/client_golang/prometheus/promhttp" + "github.com/rs/cors" + "github.com/twitchtv/twirp" + "github.com/urfave/negroni/v3" + "go.uber.org/atomic" + "golang.org/x/sync/errgroup" + + "github.com/livekit/protocol/auth" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/utils/xtwirp" + + "github.com/livekit/livekit-server/pkg/config" + "github.com/livekit/livekit-server/pkg/routing" + "github.com/livekit/livekit-server/version" +) + +type LivekitServer struct { + config *config.Config + ioService *IOInfoService + rtcService *RTCService + whipService *WHIPService + agentService *AgentService + httpServer *http.Server + promServer *http.Server + router routing.Router + roomManager *RoomManager + signalServer *SignalServer + turnServer *turn.Server + currentNode routing.LocalNode + running atomic.Bool + doneChan chan struct{} + closedChan chan struct{} +} + +func NewLivekitServer(conf *config.Config, + roomService livekit.RoomService, + agentDispatchService *AgentDispatchService, + egressService *EgressService, + ingressService *IngressService, + sipService *SIPService, + ioService *IOInfoService, + rtcService *RTCService, + whipService *WHIPService, + agentService *AgentService, + keyProvider auth.KeyProvider, + router routing.Router, + roomManager *RoomManager, + signalServer *SignalServer, + turnServer *turn.Server, + currentNode routing.LocalNode, +) (s *LivekitServer, err error) { + s = &LivekitServer{ + config: conf, + ioService: ioService, + rtcService: rtcService, + whipService: whipService, + agentService: agentService, + router: router, + roomManager: roomManager, + signalServer: signalServer, + // turn server starts automatically + turnServer: turnServer, + currentNode: currentNode, + closedChan: make(chan struct{}), + } + + middlewares := []negroni.Handler{ + // always first + negroni.NewRecovery(), + // CORS is allowed, we rely on token authentication to prevent improper use + cors.New(cors.Options{ + AllowOriginFunc: func(origin string) bool { + return true + }, + AllowedMethods: []string{"OPTIONS", "HEAD", "GET", "POST", "PATCH", "DELETE"}, + AllowedHeaders: []string{"*"}, + ExposedHeaders: []string{"*"}, + // allow preflight to be cached for a day + MaxAge: 86400, + }), + negroni.HandlerFunc(RemoveDoubleSlashes), + } + if keyProvider != nil { + middlewares = append(middlewares, NewAPIKeyAuthMiddleware(keyProvider)) + } + + serverOptions := []any{ + twirp.WithServerHooks(twirp.ChainHooks( + TwirpLogger(), + TwirpRequestStatusReporter(), + )), + } + for _, opt := range xtwirp.DefaultServerOptions() { + serverOptions = append(serverOptions, opt) + } + roomServer := livekit.NewRoomServiceServer(roomService, serverOptions...) + agentDispatchServer := livekit.NewAgentDispatchServiceServer(agentDispatchService, serverOptions...) + egressServer := livekit.NewEgressServer(egressService, serverOptions...) + ingressServer := livekit.NewIngressServer(ingressService, serverOptions...) + sipServer := livekit.NewSIPServer(sipService, serverOptions...) + + mux := http.NewServeMux() + if conf.Development { + // pprof handlers are registered onto DefaultServeMux + mux = http.DefaultServeMux + mux.HandleFunc("/debug/goroutine", s.debugGoroutines) + mux.HandleFunc("/debug/rooms", s.debugInfo) + } + + xtwirp.RegisterServer(mux, roomServer) + xtwirp.RegisterServer(mux, agentDispatchServer) + xtwirp.RegisterServer(mux, egressServer) + xtwirp.RegisterServer(mux, ingressServer) + xtwirp.RegisterServer(mux, sipServer) + rtcService.SetupRoutes(mux) + whipService.SetupRoutes(mux) + mux.Handle("/agent", agentService) + mux.HandleFunc("/", s.defaultHandler) + + s.httpServer = &http.Server{ + Handler: configureMiddlewares(mux, middlewares...), + } + + if conf.PrometheusPort > 0 { + logger.Warnw("prometheus_port is deprecated, please switch prometheus.port instead", nil) + conf.Prometheus.Port = conf.PrometheusPort + } + + if conf.Prometheus.Port > 0 { + promHandler := promhttp.Handler() + if conf.Prometheus.Username != "" && conf.Prometheus.Password != "" { + protectedHandler := negroni.New() + protectedHandler.Use(negroni.HandlerFunc(GenBasicAuthMiddleware(conf.Prometheus.Username, conf.Prometheus.Password))) + protectedHandler.UseHandler(promHandler) + promHandler = protectedHandler + } + s.promServer = &http.Server{ + Handler: promHandler, + } + } + + if err = router.RemoveDeadNodes(); err != nil { + return + } + + return +} + +func (s *LivekitServer) Node() *livekit.Node { + return s.currentNode.Clone() +} + +func (s *LivekitServer) HTTPPort() int { + return int(s.config.Port) +} + +func (s *LivekitServer) IsRunning() bool { + return s.running.Load() +} + +func (s *LivekitServer) Start() error { + if s.running.Load() { + return errors.New("already running") + } + s.doneChan = make(chan struct{}) + + if err := s.router.RegisterNode(); err != nil { + return err + } + defer func() { + if err := s.router.UnregisterNode(); err != nil { + logger.Errorw("could not unregister node", err) + } + }() + + if err := s.router.Start(); err != nil { + return err + } + + if err := s.ioService.Start(); err != nil { + return err + } + + addresses := s.config.BindAddresses + if addresses == nil { + addresses = []string{""} + } + + // ensure we could listen + listeners := make([]net.Listener, 0) + promListeners := make([]net.Listener, 0) + for _, addr := range addresses { + ln, err := net.Listen("tcp", net.JoinHostPort(addr, strconv.Itoa(int(s.config.Port)))) + if err != nil { + return err + } + listeners = append(listeners, ln) + + if s.promServer != nil { + ln, err = net.Listen("tcp", net.JoinHostPort(addr, strconv.Itoa(int(s.config.Prometheus.Port)))) + if err != nil { + return err + } + promListeners = append(promListeners, ln) + } + } + + values := []any{ + "portHttp", s.config.Port, + "nodeID", s.currentNode.NodeID(), + "nodeIP", s.currentNode.NodeIP(), + "version", version.Version, + } + if s.config.BindAddresses != nil { + values = append(values, "bindAddresses", s.config.BindAddresses) + } + if s.config.RTC.TCPPort != 0 { + values = append(values, "rtc.portTCP", s.config.RTC.TCPPort) + } + if !s.config.RTC.ForceTCP && s.config.RTC.UDPPort.Valid() { + values = append(values, "rtc.portUDP", s.config.RTC.UDPPort) + } else { + values = append(values, + "rtc.portICERange", []uint32{s.config.RTC.ICEPortRangeStart, s.config.RTC.ICEPortRangeEnd}, + ) + } + if s.config.Prometheus.Port != 0 { + values = append(values, "portPrometheus", s.config.Prometheus.Port) + } + if s.config.Region != "" { + values = append(values, "region", s.config.Region) + } + logger.Infow("starting LiveKit server", values...) + if runtime.GOOS == "windows" { + logger.Infow("Windows detected, capacity management is unavailable") + } + + for _, promLn := range promListeners { + go s.promServer.Serve(promLn) + } + + if err := s.signalServer.Start(); err != nil { + return err + } + + httpGroup := &errgroup.Group{} + for _, ln := range listeners { + l := ln + httpGroup.Go(func() error { + return s.httpServer.Serve(l) + }) + } + go func() { + if err := httpGroup.Wait(); err != http.ErrServerClosed { + logger.Errorw("could not start server", err) + s.Stop(true) + } + }() + + go s.backgroundWorker() + + // give time for Serve goroutine to start + time.Sleep(100 * time.Millisecond) + + s.running.Store(true) + + <-s.doneChan + + // wait for shutdown + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + _ = s.httpServer.Shutdown(ctx) + + if s.turnServer != nil { + _ = s.turnServer.Close() + } + + s.roomManager.Stop() + s.signalServer.Stop() + s.ioService.Stop() + + close(s.closedChan) + return nil +} + +func (s *LivekitServer) Stop(force bool) { + // wait for all participants to exit + s.router.Drain() + partTicker := time.NewTicker(5 * time.Second) + waitingForParticipants := !force && s.roomManager.HasParticipants() + for waitingForParticipants { + <-partTicker.C + logger.Infow("waiting for participants to exit") + waitingForParticipants = s.roomManager.HasParticipants() + } + partTicker.Stop() + + if !s.running.Swap(false) { + return + } + + s.router.Stop() + close(s.doneChan) + + // wait for fully closed + <-s.closedChan +} + +func (s *LivekitServer) RoomManager() *RoomManager { + return s.roomManager +} + +func (s *LivekitServer) debugGoroutines(w http.ResponseWriter, _ *http.Request) { + _ = pprof.Lookup("goroutine").WriteTo(w, 2) +} + +func (s *LivekitServer) debugInfo(w http.ResponseWriter, _ *http.Request) { + s.roomManager.lock.RLock() + info := make([]map[string]any, 0, len(s.roomManager.rooms)) + for _, room := range s.roomManager.rooms { + info = append(info, room.DebugInfo()) + } + s.roomManager.lock.RUnlock() + + b, err := json.Marshal(info) + if err != nil { + w.WriteHeader(400) + _, _ = w.Write([]byte(err.Error())) + } else { + _, _ = w.Write(b) + } +} + +func (s *LivekitServer) defaultHandler(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/" { + s.healthCheck(w, r) + } else { + http.NotFound(w, r) + } +} + +func (s *LivekitServer) healthCheck(w http.ResponseWriter, _ *http.Request) { + var updatedAt time.Time + if s.Node().Stats != nil { + updatedAt = time.Unix(s.Node().Stats.UpdatedAt, 0) + } + if time.Since(updatedAt) > 4*time.Second { + w.WriteHeader(http.StatusNotAcceptable) + _, _ = w.Write([]byte(fmt.Sprintf("Not Ready\nNode Updated At %s", updatedAt))) + return + } + + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("OK")) +} + +// worker to perform periodic tasks per node +func (s *LivekitServer) backgroundWorker() { + roomTicker := time.NewTicker(1 * time.Second) + defer roomTicker.Stop() + for { + select { + case <-s.doneChan: + return + case <-roomTicker.C: + s.roomManager.CloseIdleRooms() + } + } +} + +func configureMiddlewares(handler http.Handler, middlewares ...negroni.Handler) *negroni.Negroni { + n := negroni.New() + for _, m := range middlewares { + n.Use(m) + } + n.UseHandler(handler) + return n +} diff --git a/livekit/pkg/service/servicefakes/fake_agent_store.go b/livekit/pkg/service/servicefakes/fake_agent_store.go new file mode 100644 index 0000000..6727597 --- /dev/null +++ b/livekit/pkg/service/servicefakes/fake_agent_store.go @@ -0,0 +1,414 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package servicefakes + +import ( + "context" + "sync" + + "github.com/livekit/livekit-server/pkg/service" + "github.com/livekit/protocol/livekit" +) + +type FakeAgentStore struct { + DeleteAgentDispatchStub func(context.Context, *livekit.AgentDispatch) error + deleteAgentDispatchMutex sync.RWMutex + deleteAgentDispatchArgsForCall []struct { + arg1 context.Context + arg2 *livekit.AgentDispatch + } + deleteAgentDispatchReturns struct { + result1 error + } + deleteAgentDispatchReturnsOnCall map[int]struct { + result1 error + } + DeleteAgentJobStub func(context.Context, *livekit.Job) error + deleteAgentJobMutex sync.RWMutex + deleteAgentJobArgsForCall []struct { + arg1 context.Context + arg2 *livekit.Job + } + deleteAgentJobReturns struct { + result1 error + } + deleteAgentJobReturnsOnCall map[int]struct { + result1 error + } + ListAgentDispatchesStub func(context.Context, livekit.RoomName) ([]*livekit.AgentDispatch, error) + listAgentDispatchesMutex sync.RWMutex + listAgentDispatchesArgsForCall []struct { + arg1 context.Context + arg2 livekit.RoomName + } + listAgentDispatchesReturns struct { + result1 []*livekit.AgentDispatch + result2 error + } + listAgentDispatchesReturnsOnCall map[int]struct { + result1 []*livekit.AgentDispatch + result2 error + } + StoreAgentDispatchStub func(context.Context, *livekit.AgentDispatch) error + storeAgentDispatchMutex sync.RWMutex + storeAgentDispatchArgsForCall []struct { + arg1 context.Context + arg2 *livekit.AgentDispatch + } + storeAgentDispatchReturns struct { + result1 error + } + storeAgentDispatchReturnsOnCall map[int]struct { + result1 error + } + StoreAgentJobStub func(context.Context, *livekit.Job) error + storeAgentJobMutex sync.RWMutex + storeAgentJobArgsForCall []struct { + arg1 context.Context + arg2 *livekit.Job + } + storeAgentJobReturns struct { + result1 error + } + storeAgentJobReturnsOnCall map[int]struct { + result1 error + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeAgentStore) DeleteAgentDispatch(arg1 context.Context, arg2 *livekit.AgentDispatch) error { + fake.deleteAgentDispatchMutex.Lock() + ret, specificReturn := fake.deleteAgentDispatchReturnsOnCall[len(fake.deleteAgentDispatchArgsForCall)] + fake.deleteAgentDispatchArgsForCall = append(fake.deleteAgentDispatchArgsForCall, struct { + arg1 context.Context + arg2 *livekit.AgentDispatch + }{arg1, arg2}) + stub := fake.DeleteAgentDispatchStub + fakeReturns := fake.deleteAgentDispatchReturns + fake.recordInvocation("DeleteAgentDispatch", []interface{}{arg1, arg2}) + fake.deleteAgentDispatchMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeAgentStore) DeleteAgentDispatchCallCount() int { + fake.deleteAgentDispatchMutex.RLock() + defer fake.deleteAgentDispatchMutex.RUnlock() + return len(fake.deleteAgentDispatchArgsForCall) +} + +func (fake *FakeAgentStore) DeleteAgentDispatchCalls(stub func(context.Context, *livekit.AgentDispatch) error) { + fake.deleteAgentDispatchMutex.Lock() + defer fake.deleteAgentDispatchMutex.Unlock() + fake.DeleteAgentDispatchStub = stub +} + +func (fake *FakeAgentStore) DeleteAgentDispatchArgsForCall(i int) (context.Context, *livekit.AgentDispatch) { + fake.deleteAgentDispatchMutex.RLock() + defer fake.deleteAgentDispatchMutex.RUnlock() + argsForCall := fake.deleteAgentDispatchArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeAgentStore) DeleteAgentDispatchReturns(result1 error) { + fake.deleteAgentDispatchMutex.Lock() + defer fake.deleteAgentDispatchMutex.Unlock() + fake.DeleteAgentDispatchStub = nil + fake.deleteAgentDispatchReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeAgentStore) DeleteAgentDispatchReturnsOnCall(i int, result1 error) { + fake.deleteAgentDispatchMutex.Lock() + defer fake.deleteAgentDispatchMutex.Unlock() + fake.DeleteAgentDispatchStub = nil + if fake.deleteAgentDispatchReturnsOnCall == nil { + fake.deleteAgentDispatchReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.deleteAgentDispatchReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeAgentStore) DeleteAgentJob(arg1 context.Context, arg2 *livekit.Job) error { + fake.deleteAgentJobMutex.Lock() + ret, specificReturn := fake.deleteAgentJobReturnsOnCall[len(fake.deleteAgentJobArgsForCall)] + fake.deleteAgentJobArgsForCall = append(fake.deleteAgentJobArgsForCall, struct { + arg1 context.Context + arg2 *livekit.Job + }{arg1, arg2}) + stub := fake.DeleteAgentJobStub + fakeReturns := fake.deleteAgentJobReturns + fake.recordInvocation("DeleteAgentJob", []interface{}{arg1, arg2}) + fake.deleteAgentJobMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeAgentStore) DeleteAgentJobCallCount() int { + fake.deleteAgentJobMutex.RLock() + defer fake.deleteAgentJobMutex.RUnlock() + return len(fake.deleteAgentJobArgsForCall) +} + +func (fake *FakeAgentStore) DeleteAgentJobCalls(stub func(context.Context, *livekit.Job) error) { + fake.deleteAgentJobMutex.Lock() + defer fake.deleteAgentJobMutex.Unlock() + fake.DeleteAgentJobStub = stub +} + +func (fake *FakeAgentStore) DeleteAgentJobArgsForCall(i int) (context.Context, *livekit.Job) { + fake.deleteAgentJobMutex.RLock() + defer fake.deleteAgentJobMutex.RUnlock() + argsForCall := fake.deleteAgentJobArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeAgentStore) DeleteAgentJobReturns(result1 error) { + fake.deleteAgentJobMutex.Lock() + defer fake.deleteAgentJobMutex.Unlock() + fake.DeleteAgentJobStub = nil + fake.deleteAgentJobReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeAgentStore) DeleteAgentJobReturnsOnCall(i int, result1 error) { + fake.deleteAgentJobMutex.Lock() + defer fake.deleteAgentJobMutex.Unlock() + fake.DeleteAgentJobStub = nil + if fake.deleteAgentJobReturnsOnCall == nil { + fake.deleteAgentJobReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.deleteAgentJobReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeAgentStore) ListAgentDispatches(arg1 context.Context, arg2 livekit.RoomName) ([]*livekit.AgentDispatch, error) { + fake.listAgentDispatchesMutex.Lock() + ret, specificReturn := fake.listAgentDispatchesReturnsOnCall[len(fake.listAgentDispatchesArgsForCall)] + fake.listAgentDispatchesArgsForCall = append(fake.listAgentDispatchesArgsForCall, struct { + arg1 context.Context + arg2 livekit.RoomName + }{arg1, arg2}) + stub := fake.ListAgentDispatchesStub + fakeReturns := fake.listAgentDispatchesReturns + fake.recordInvocation("ListAgentDispatches", []interface{}{arg1, arg2}) + fake.listAgentDispatchesMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeAgentStore) ListAgentDispatchesCallCount() int { + fake.listAgentDispatchesMutex.RLock() + defer fake.listAgentDispatchesMutex.RUnlock() + return len(fake.listAgentDispatchesArgsForCall) +} + +func (fake *FakeAgentStore) ListAgentDispatchesCalls(stub func(context.Context, livekit.RoomName) ([]*livekit.AgentDispatch, error)) { + fake.listAgentDispatchesMutex.Lock() + defer fake.listAgentDispatchesMutex.Unlock() + fake.ListAgentDispatchesStub = stub +} + +func (fake *FakeAgentStore) ListAgentDispatchesArgsForCall(i int) (context.Context, livekit.RoomName) { + fake.listAgentDispatchesMutex.RLock() + defer fake.listAgentDispatchesMutex.RUnlock() + argsForCall := fake.listAgentDispatchesArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeAgentStore) ListAgentDispatchesReturns(result1 []*livekit.AgentDispatch, result2 error) { + fake.listAgentDispatchesMutex.Lock() + defer fake.listAgentDispatchesMutex.Unlock() + fake.ListAgentDispatchesStub = nil + fake.listAgentDispatchesReturns = struct { + result1 []*livekit.AgentDispatch + result2 error + }{result1, result2} +} + +func (fake *FakeAgentStore) ListAgentDispatchesReturnsOnCall(i int, result1 []*livekit.AgentDispatch, result2 error) { + fake.listAgentDispatchesMutex.Lock() + defer fake.listAgentDispatchesMutex.Unlock() + fake.ListAgentDispatchesStub = nil + if fake.listAgentDispatchesReturnsOnCall == nil { + fake.listAgentDispatchesReturnsOnCall = make(map[int]struct { + result1 []*livekit.AgentDispatch + result2 error + }) + } + fake.listAgentDispatchesReturnsOnCall[i] = struct { + result1 []*livekit.AgentDispatch + result2 error + }{result1, result2} +} + +func (fake *FakeAgentStore) StoreAgentDispatch(arg1 context.Context, arg2 *livekit.AgentDispatch) error { + fake.storeAgentDispatchMutex.Lock() + ret, specificReturn := fake.storeAgentDispatchReturnsOnCall[len(fake.storeAgentDispatchArgsForCall)] + fake.storeAgentDispatchArgsForCall = append(fake.storeAgentDispatchArgsForCall, struct { + arg1 context.Context + arg2 *livekit.AgentDispatch + }{arg1, arg2}) + stub := fake.StoreAgentDispatchStub + fakeReturns := fake.storeAgentDispatchReturns + fake.recordInvocation("StoreAgentDispatch", []interface{}{arg1, arg2}) + fake.storeAgentDispatchMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeAgentStore) StoreAgentDispatchCallCount() int { + fake.storeAgentDispatchMutex.RLock() + defer fake.storeAgentDispatchMutex.RUnlock() + return len(fake.storeAgentDispatchArgsForCall) +} + +func (fake *FakeAgentStore) StoreAgentDispatchCalls(stub func(context.Context, *livekit.AgentDispatch) error) { + fake.storeAgentDispatchMutex.Lock() + defer fake.storeAgentDispatchMutex.Unlock() + fake.StoreAgentDispatchStub = stub +} + +func (fake *FakeAgentStore) StoreAgentDispatchArgsForCall(i int) (context.Context, *livekit.AgentDispatch) { + fake.storeAgentDispatchMutex.RLock() + defer fake.storeAgentDispatchMutex.RUnlock() + argsForCall := fake.storeAgentDispatchArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeAgentStore) StoreAgentDispatchReturns(result1 error) { + fake.storeAgentDispatchMutex.Lock() + defer fake.storeAgentDispatchMutex.Unlock() + fake.StoreAgentDispatchStub = nil + fake.storeAgentDispatchReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeAgentStore) StoreAgentDispatchReturnsOnCall(i int, result1 error) { + fake.storeAgentDispatchMutex.Lock() + defer fake.storeAgentDispatchMutex.Unlock() + fake.StoreAgentDispatchStub = nil + if fake.storeAgentDispatchReturnsOnCall == nil { + fake.storeAgentDispatchReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.storeAgentDispatchReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeAgentStore) StoreAgentJob(arg1 context.Context, arg2 *livekit.Job) error { + fake.storeAgentJobMutex.Lock() + ret, specificReturn := fake.storeAgentJobReturnsOnCall[len(fake.storeAgentJobArgsForCall)] + fake.storeAgentJobArgsForCall = append(fake.storeAgentJobArgsForCall, struct { + arg1 context.Context + arg2 *livekit.Job + }{arg1, arg2}) + stub := fake.StoreAgentJobStub + fakeReturns := fake.storeAgentJobReturns + fake.recordInvocation("StoreAgentJob", []interface{}{arg1, arg2}) + fake.storeAgentJobMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeAgentStore) StoreAgentJobCallCount() int { + fake.storeAgentJobMutex.RLock() + defer fake.storeAgentJobMutex.RUnlock() + return len(fake.storeAgentJobArgsForCall) +} + +func (fake *FakeAgentStore) StoreAgentJobCalls(stub func(context.Context, *livekit.Job) error) { + fake.storeAgentJobMutex.Lock() + defer fake.storeAgentJobMutex.Unlock() + fake.StoreAgentJobStub = stub +} + +func (fake *FakeAgentStore) StoreAgentJobArgsForCall(i int) (context.Context, *livekit.Job) { + fake.storeAgentJobMutex.RLock() + defer fake.storeAgentJobMutex.RUnlock() + argsForCall := fake.storeAgentJobArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeAgentStore) StoreAgentJobReturns(result1 error) { + fake.storeAgentJobMutex.Lock() + defer fake.storeAgentJobMutex.Unlock() + fake.StoreAgentJobStub = nil + fake.storeAgentJobReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeAgentStore) StoreAgentJobReturnsOnCall(i int, result1 error) { + fake.storeAgentJobMutex.Lock() + defer fake.storeAgentJobMutex.Unlock() + fake.StoreAgentJobStub = nil + if fake.storeAgentJobReturnsOnCall == nil { + fake.storeAgentJobReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.storeAgentJobReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeAgentStore) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeAgentStore) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ service.AgentStore = new(FakeAgentStore) diff --git a/livekit/pkg/service/servicefakes/fake_egress_store.go b/livekit/pkg/service/servicefakes/fake_egress_store.go new file mode 100644 index 0000000..b642eb9 --- /dev/null +++ b/livekit/pkg/service/servicefakes/fake_egress_store.go @@ -0,0 +1,347 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package servicefakes + +import ( + "context" + "sync" + + "github.com/livekit/livekit-server/pkg/service" + "github.com/livekit/protocol/livekit" +) + +type FakeEgressStore struct { + ListEgressStub func(context.Context, livekit.RoomName, bool) ([]*livekit.EgressInfo, error) + listEgressMutex sync.RWMutex + listEgressArgsForCall []struct { + arg1 context.Context + arg2 livekit.RoomName + arg3 bool + } + listEgressReturns struct { + result1 []*livekit.EgressInfo + result2 error + } + listEgressReturnsOnCall map[int]struct { + result1 []*livekit.EgressInfo + result2 error + } + LoadEgressStub func(context.Context, string) (*livekit.EgressInfo, error) + loadEgressMutex sync.RWMutex + loadEgressArgsForCall []struct { + arg1 context.Context + arg2 string + } + loadEgressReturns struct { + result1 *livekit.EgressInfo + result2 error + } + loadEgressReturnsOnCall map[int]struct { + result1 *livekit.EgressInfo + result2 error + } + StoreEgressStub func(context.Context, *livekit.EgressInfo) error + storeEgressMutex sync.RWMutex + storeEgressArgsForCall []struct { + arg1 context.Context + arg2 *livekit.EgressInfo + } + storeEgressReturns struct { + result1 error + } + storeEgressReturnsOnCall map[int]struct { + result1 error + } + UpdateEgressStub func(context.Context, *livekit.EgressInfo) error + updateEgressMutex sync.RWMutex + updateEgressArgsForCall []struct { + arg1 context.Context + arg2 *livekit.EgressInfo + } + updateEgressReturns struct { + result1 error + } + updateEgressReturnsOnCall map[int]struct { + result1 error + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeEgressStore) ListEgress(arg1 context.Context, arg2 livekit.RoomName, arg3 bool) ([]*livekit.EgressInfo, error) { + fake.listEgressMutex.Lock() + ret, specificReturn := fake.listEgressReturnsOnCall[len(fake.listEgressArgsForCall)] + fake.listEgressArgsForCall = append(fake.listEgressArgsForCall, struct { + arg1 context.Context + arg2 livekit.RoomName + arg3 bool + }{arg1, arg2, arg3}) + stub := fake.ListEgressStub + fakeReturns := fake.listEgressReturns + fake.recordInvocation("ListEgress", []interface{}{arg1, arg2, arg3}) + fake.listEgressMutex.Unlock() + if stub != nil { + return stub(arg1, arg2, arg3) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeEgressStore) ListEgressCallCount() int { + fake.listEgressMutex.RLock() + defer fake.listEgressMutex.RUnlock() + return len(fake.listEgressArgsForCall) +} + +func (fake *FakeEgressStore) ListEgressCalls(stub func(context.Context, livekit.RoomName, bool) ([]*livekit.EgressInfo, error)) { + fake.listEgressMutex.Lock() + defer fake.listEgressMutex.Unlock() + fake.ListEgressStub = stub +} + +func (fake *FakeEgressStore) ListEgressArgsForCall(i int) (context.Context, livekit.RoomName, bool) { + fake.listEgressMutex.RLock() + defer fake.listEgressMutex.RUnlock() + argsForCall := fake.listEgressArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeEgressStore) ListEgressReturns(result1 []*livekit.EgressInfo, result2 error) { + fake.listEgressMutex.Lock() + defer fake.listEgressMutex.Unlock() + fake.ListEgressStub = nil + fake.listEgressReturns = struct { + result1 []*livekit.EgressInfo + result2 error + }{result1, result2} +} + +func (fake *FakeEgressStore) ListEgressReturnsOnCall(i int, result1 []*livekit.EgressInfo, result2 error) { + fake.listEgressMutex.Lock() + defer fake.listEgressMutex.Unlock() + fake.ListEgressStub = nil + if fake.listEgressReturnsOnCall == nil { + fake.listEgressReturnsOnCall = make(map[int]struct { + result1 []*livekit.EgressInfo + result2 error + }) + } + fake.listEgressReturnsOnCall[i] = struct { + result1 []*livekit.EgressInfo + result2 error + }{result1, result2} +} + +func (fake *FakeEgressStore) LoadEgress(arg1 context.Context, arg2 string) (*livekit.EgressInfo, error) { + fake.loadEgressMutex.Lock() + ret, specificReturn := fake.loadEgressReturnsOnCall[len(fake.loadEgressArgsForCall)] + fake.loadEgressArgsForCall = append(fake.loadEgressArgsForCall, struct { + arg1 context.Context + arg2 string + }{arg1, arg2}) + stub := fake.LoadEgressStub + fakeReturns := fake.loadEgressReturns + fake.recordInvocation("LoadEgress", []interface{}{arg1, arg2}) + fake.loadEgressMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeEgressStore) LoadEgressCallCount() int { + fake.loadEgressMutex.RLock() + defer fake.loadEgressMutex.RUnlock() + return len(fake.loadEgressArgsForCall) +} + +func (fake *FakeEgressStore) LoadEgressCalls(stub func(context.Context, string) (*livekit.EgressInfo, error)) { + fake.loadEgressMutex.Lock() + defer fake.loadEgressMutex.Unlock() + fake.LoadEgressStub = stub +} + +func (fake *FakeEgressStore) LoadEgressArgsForCall(i int) (context.Context, string) { + fake.loadEgressMutex.RLock() + defer fake.loadEgressMutex.RUnlock() + argsForCall := fake.loadEgressArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeEgressStore) LoadEgressReturns(result1 *livekit.EgressInfo, result2 error) { + fake.loadEgressMutex.Lock() + defer fake.loadEgressMutex.Unlock() + fake.LoadEgressStub = nil + fake.loadEgressReturns = struct { + result1 *livekit.EgressInfo + result2 error + }{result1, result2} +} + +func (fake *FakeEgressStore) LoadEgressReturnsOnCall(i int, result1 *livekit.EgressInfo, result2 error) { + fake.loadEgressMutex.Lock() + defer fake.loadEgressMutex.Unlock() + fake.LoadEgressStub = nil + if fake.loadEgressReturnsOnCall == nil { + fake.loadEgressReturnsOnCall = make(map[int]struct { + result1 *livekit.EgressInfo + result2 error + }) + } + fake.loadEgressReturnsOnCall[i] = struct { + result1 *livekit.EgressInfo + result2 error + }{result1, result2} +} + +func (fake *FakeEgressStore) StoreEgress(arg1 context.Context, arg2 *livekit.EgressInfo) error { + fake.storeEgressMutex.Lock() + ret, specificReturn := fake.storeEgressReturnsOnCall[len(fake.storeEgressArgsForCall)] + fake.storeEgressArgsForCall = append(fake.storeEgressArgsForCall, struct { + arg1 context.Context + arg2 *livekit.EgressInfo + }{arg1, arg2}) + stub := fake.StoreEgressStub + fakeReturns := fake.storeEgressReturns + fake.recordInvocation("StoreEgress", []interface{}{arg1, arg2}) + fake.storeEgressMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeEgressStore) StoreEgressCallCount() int { + fake.storeEgressMutex.RLock() + defer fake.storeEgressMutex.RUnlock() + return len(fake.storeEgressArgsForCall) +} + +func (fake *FakeEgressStore) StoreEgressCalls(stub func(context.Context, *livekit.EgressInfo) error) { + fake.storeEgressMutex.Lock() + defer fake.storeEgressMutex.Unlock() + fake.StoreEgressStub = stub +} + +func (fake *FakeEgressStore) StoreEgressArgsForCall(i int) (context.Context, *livekit.EgressInfo) { + fake.storeEgressMutex.RLock() + defer fake.storeEgressMutex.RUnlock() + argsForCall := fake.storeEgressArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeEgressStore) StoreEgressReturns(result1 error) { + fake.storeEgressMutex.Lock() + defer fake.storeEgressMutex.Unlock() + fake.StoreEgressStub = nil + fake.storeEgressReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeEgressStore) StoreEgressReturnsOnCall(i int, result1 error) { + fake.storeEgressMutex.Lock() + defer fake.storeEgressMutex.Unlock() + fake.StoreEgressStub = nil + if fake.storeEgressReturnsOnCall == nil { + fake.storeEgressReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.storeEgressReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeEgressStore) UpdateEgress(arg1 context.Context, arg2 *livekit.EgressInfo) error { + fake.updateEgressMutex.Lock() + ret, specificReturn := fake.updateEgressReturnsOnCall[len(fake.updateEgressArgsForCall)] + fake.updateEgressArgsForCall = append(fake.updateEgressArgsForCall, struct { + arg1 context.Context + arg2 *livekit.EgressInfo + }{arg1, arg2}) + stub := fake.UpdateEgressStub + fakeReturns := fake.updateEgressReturns + fake.recordInvocation("UpdateEgress", []interface{}{arg1, arg2}) + fake.updateEgressMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeEgressStore) UpdateEgressCallCount() int { + fake.updateEgressMutex.RLock() + defer fake.updateEgressMutex.RUnlock() + return len(fake.updateEgressArgsForCall) +} + +func (fake *FakeEgressStore) UpdateEgressCalls(stub func(context.Context, *livekit.EgressInfo) error) { + fake.updateEgressMutex.Lock() + defer fake.updateEgressMutex.Unlock() + fake.UpdateEgressStub = stub +} + +func (fake *FakeEgressStore) UpdateEgressArgsForCall(i int) (context.Context, *livekit.EgressInfo) { + fake.updateEgressMutex.RLock() + defer fake.updateEgressMutex.RUnlock() + argsForCall := fake.updateEgressArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeEgressStore) UpdateEgressReturns(result1 error) { + fake.updateEgressMutex.Lock() + defer fake.updateEgressMutex.Unlock() + fake.UpdateEgressStub = nil + fake.updateEgressReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeEgressStore) UpdateEgressReturnsOnCall(i int, result1 error) { + fake.updateEgressMutex.Lock() + defer fake.updateEgressMutex.Unlock() + fake.UpdateEgressStub = nil + if fake.updateEgressReturnsOnCall == nil { + fake.updateEgressReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.updateEgressReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeEgressStore) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeEgressStore) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ service.EgressStore = new(FakeEgressStore) diff --git a/livekit/pkg/service/servicefakes/fake_ingress_store.go b/livekit/pkg/service/servicefakes/fake_ingress_store.go new file mode 100644 index 0000000..f959ad4 --- /dev/null +++ b/livekit/pkg/service/servicefakes/fake_ingress_store.go @@ -0,0 +1,574 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package servicefakes + +import ( + "context" + "sync" + + "github.com/livekit/livekit-server/pkg/service" + "github.com/livekit/protocol/livekit" +) + +type FakeIngressStore struct { + DeleteIngressStub func(context.Context, *livekit.IngressInfo) error + deleteIngressMutex sync.RWMutex + deleteIngressArgsForCall []struct { + arg1 context.Context + arg2 *livekit.IngressInfo + } + deleteIngressReturns struct { + result1 error + } + deleteIngressReturnsOnCall map[int]struct { + result1 error + } + ListIngressStub func(context.Context, livekit.RoomName) ([]*livekit.IngressInfo, error) + listIngressMutex sync.RWMutex + listIngressArgsForCall []struct { + arg1 context.Context + arg2 livekit.RoomName + } + listIngressReturns struct { + result1 []*livekit.IngressInfo + result2 error + } + listIngressReturnsOnCall map[int]struct { + result1 []*livekit.IngressInfo + result2 error + } + LoadIngressStub func(context.Context, string) (*livekit.IngressInfo, error) + loadIngressMutex sync.RWMutex + loadIngressArgsForCall []struct { + arg1 context.Context + arg2 string + } + loadIngressReturns struct { + result1 *livekit.IngressInfo + result2 error + } + loadIngressReturnsOnCall map[int]struct { + result1 *livekit.IngressInfo + result2 error + } + LoadIngressFromStreamKeyStub func(context.Context, string) (*livekit.IngressInfo, error) + loadIngressFromStreamKeyMutex sync.RWMutex + loadIngressFromStreamKeyArgsForCall []struct { + arg1 context.Context + arg2 string + } + loadIngressFromStreamKeyReturns struct { + result1 *livekit.IngressInfo + result2 error + } + loadIngressFromStreamKeyReturnsOnCall map[int]struct { + result1 *livekit.IngressInfo + result2 error + } + StoreIngressStub func(context.Context, *livekit.IngressInfo) error + storeIngressMutex sync.RWMutex + storeIngressArgsForCall []struct { + arg1 context.Context + arg2 *livekit.IngressInfo + } + storeIngressReturns struct { + result1 error + } + storeIngressReturnsOnCall map[int]struct { + result1 error + } + UpdateIngressStub func(context.Context, *livekit.IngressInfo) error + updateIngressMutex sync.RWMutex + updateIngressArgsForCall []struct { + arg1 context.Context + arg2 *livekit.IngressInfo + } + updateIngressReturns struct { + result1 error + } + updateIngressReturnsOnCall map[int]struct { + result1 error + } + UpdateIngressStateStub func(context.Context, string, *livekit.IngressState) error + updateIngressStateMutex sync.RWMutex + updateIngressStateArgsForCall []struct { + arg1 context.Context + arg2 string + arg3 *livekit.IngressState + } + updateIngressStateReturns struct { + result1 error + } + updateIngressStateReturnsOnCall map[int]struct { + result1 error + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeIngressStore) DeleteIngress(arg1 context.Context, arg2 *livekit.IngressInfo) error { + fake.deleteIngressMutex.Lock() + ret, specificReturn := fake.deleteIngressReturnsOnCall[len(fake.deleteIngressArgsForCall)] + fake.deleteIngressArgsForCall = append(fake.deleteIngressArgsForCall, struct { + arg1 context.Context + arg2 *livekit.IngressInfo + }{arg1, arg2}) + stub := fake.DeleteIngressStub + fakeReturns := fake.deleteIngressReturns + fake.recordInvocation("DeleteIngress", []interface{}{arg1, arg2}) + fake.deleteIngressMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeIngressStore) DeleteIngressCallCount() int { + fake.deleteIngressMutex.RLock() + defer fake.deleteIngressMutex.RUnlock() + return len(fake.deleteIngressArgsForCall) +} + +func (fake *FakeIngressStore) DeleteIngressCalls(stub func(context.Context, *livekit.IngressInfo) error) { + fake.deleteIngressMutex.Lock() + defer fake.deleteIngressMutex.Unlock() + fake.DeleteIngressStub = stub +} + +func (fake *FakeIngressStore) DeleteIngressArgsForCall(i int) (context.Context, *livekit.IngressInfo) { + fake.deleteIngressMutex.RLock() + defer fake.deleteIngressMutex.RUnlock() + argsForCall := fake.deleteIngressArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeIngressStore) DeleteIngressReturns(result1 error) { + fake.deleteIngressMutex.Lock() + defer fake.deleteIngressMutex.Unlock() + fake.DeleteIngressStub = nil + fake.deleteIngressReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeIngressStore) DeleteIngressReturnsOnCall(i int, result1 error) { + fake.deleteIngressMutex.Lock() + defer fake.deleteIngressMutex.Unlock() + fake.DeleteIngressStub = nil + if fake.deleteIngressReturnsOnCall == nil { + fake.deleteIngressReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.deleteIngressReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeIngressStore) ListIngress(arg1 context.Context, arg2 livekit.RoomName) ([]*livekit.IngressInfo, error) { + fake.listIngressMutex.Lock() + ret, specificReturn := fake.listIngressReturnsOnCall[len(fake.listIngressArgsForCall)] + fake.listIngressArgsForCall = append(fake.listIngressArgsForCall, struct { + arg1 context.Context + arg2 livekit.RoomName + }{arg1, arg2}) + stub := fake.ListIngressStub + fakeReturns := fake.listIngressReturns + fake.recordInvocation("ListIngress", []interface{}{arg1, arg2}) + fake.listIngressMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeIngressStore) ListIngressCallCount() int { + fake.listIngressMutex.RLock() + defer fake.listIngressMutex.RUnlock() + return len(fake.listIngressArgsForCall) +} + +func (fake *FakeIngressStore) ListIngressCalls(stub func(context.Context, livekit.RoomName) ([]*livekit.IngressInfo, error)) { + fake.listIngressMutex.Lock() + defer fake.listIngressMutex.Unlock() + fake.ListIngressStub = stub +} + +func (fake *FakeIngressStore) ListIngressArgsForCall(i int) (context.Context, livekit.RoomName) { + fake.listIngressMutex.RLock() + defer fake.listIngressMutex.RUnlock() + argsForCall := fake.listIngressArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeIngressStore) ListIngressReturns(result1 []*livekit.IngressInfo, result2 error) { + fake.listIngressMutex.Lock() + defer fake.listIngressMutex.Unlock() + fake.ListIngressStub = nil + fake.listIngressReturns = struct { + result1 []*livekit.IngressInfo + result2 error + }{result1, result2} +} + +func (fake *FakeIngressStore) ListIngressReturnsOnCall(i int, result1 []*livekit.IngressInfo, result2 error) { + fake.listIngressMutex.Lock() + defer fake.listIngressMutex.Unlock() + fake.ListIngressStub = nil + if fake.listIngressReturnsOnCall == nil { + fake.listIngressReturnsOnCall = make(map[int]struct { + result1 []*livekit.IngressInfo + result2 error + }) + } + fake.listIngressReturnsOnCall[i] = struct { + result1 []*livekit.IngressInfo + result2 error + }{result1, result2} +} + +func (fake *FakeIngressStore) LoadIngress(arg1 context.Context, arg2 string) (*livekit.IngressInfo, error) { + fake.loadIngressMutex.Lock() + ret, specificReturn := fake.loadIngressReturnsOnCall[len(fake.loadIngressArgsForCall)] + fake.loadIngressArgsForCall = append(fake.loadIngressArgsForCall, struct { + arg1 context.Context + arg2 string + }{arg1, arg2}) + stub := fake.LoadIngressStub + fakeReturns := fake.loadIngressReturns + fake.recordInvocation("LoadIngress", []interface{}{arg1, arg2}) + fake.loadIngressMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeIngressStore) LoadIngressCallCount() int { + fake.loadIngressMutex.RLock() + defer fake.loadIngressMutex.RUnlock() + return len(fake.loadIngressArgsForCall) +} + +func (fake *FakeIngressStore) LoadIngressCalls(stub func(context.Context, string) (*livekit.IngressInfo, error)) { + fake.loadIngressMutex.Lock() + defer fake.loadIngressMutex.Unlock() + fake.LoadIngressStub = stub +} + +func (fake *FakeIngressStore) LoadIngressArgsForCall(i int) (context.Context, string) { + fake.loadIngressMutex.RLock() + defer fake.loadIngressMutex.RUnlock() + argsForCall := fake.loadIngressArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeIngressStore) LoadIngressReturns(result1 *livekit.IngressInfo, result2 error) { + fake.loadIngressMutex.Lock() + defer fake.loadIngressMutex.Unlock() + fake.LoadIngressStub = nil + fake.loadIngressReturns = struct { + result1 *livekit.IngressInfo + result2 error + }{result1, result2} +} + +func (fake *FakeIngressStore) LoadIngressReturnsOnCall(i int, result1 *livekit.IngressInfo, result2 error) { + fake.loadIngressMutex.Lock() + defer fake.loadIngressMutex.Unlock() + fake.LoadIngressStub = nil + if fake.loadIngressReturnsOnCall == nil { + fake.loadIngressReturnsOnCall = make(map[int]struct { + result1 *livekit.IngressInfo + result2 error + }) + } + fake.loadIngressReturnsOnCall[i] = struct { + result1 *livekit.IngressInfo + result2 error + }{result1, result2} +} + +func (fake *FakeIngressStore) LoadIngressFromStreamKey(arg1 context.Context, arg2 string) (*livekit.IngressInfo, error) { + fake.loadIngressFromStreamKeyMutex.Lock() + ret, specificReturn := fake.loadIngressFromStreamKeyReturnsOnCall[len(fake.loadIngressFromStreamKeyArgsForCall)] + fake.loadIngressFromStreamKeyArgsForCall = append(fake.loadIngressFromStreamKeyArgsForCall, struct { + arg1 context.Context + arg2 string + }{arg1, arg2}) + stub := fake.LoadIngressFromStreamKeyStub + fakeReturns := fake.loadIngressFromStreamKeyReturns + fake.recordInvocation("LoadIngressFromStreamKey", []interface{}{arg1, arg2}) + fake.loadIngressFromStreamKeyMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeIngressStore) LoadIngressFromStreamKeyCallCount() int { + fake.loadIngressFromStreamKeyMutex.RLock() + defer fake.loadIngressFromStreamKeyMutex.RUnlock() + return len(fake.loadIngressFromStreamKeyArgsForCall) +} + +func (fake *FakeIngressStore) LoadIngressFromStreamKeyCalls(stub func(context.Context, string) (*livekit.IngressInfo, error)) { + fake.loadIngressFromStreamKeyMutex.Lock() + defer fake.loadIngressFromStreamKeyMutex.Unlock() + fake.LoadIngressFromStreamKeyStub = stub +} + +func (fake *FakeIngressStore) LoadIngressFromStreamKeyArgsForCall(i int) (context.Context, string) { + fake.loadIngressFromStreamKeyMutex.RLock() + defer fake.loadIngressFromStreamKeyMutex.RUnlock() + argsForCall := fake.loadIngressFromStreamKeyArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeIngressStore) LoadIngressFromStreamKeyReturns(result1 *livekit.IngressInfo, result2 error) { + fake.loadIngressFromStreamKeyMutex.Lock() + defer fake.loadIngressFromStreamKeyMutex.Unlock() + fake.LoadIngressFromStreamKeyStub = nil + fake.loadIngressFromStreamKeyReturns = struct { + result1 *livekit.IngressInfo + result2 error + }{result1, result2} +} + +func (fake *FakeIngressStore) LoadIngressFromStreamKeyReturnsOnCall(i int, result1 *livekit.IngressInfo, result2 error) { + fake.loadIngressFromStreamKeyMutex.Lock() + defer fake.loadIngressFromStreamKeyMutex.Unlock() + fake.LoadIngressFromStreamKeyStub = nil + if fake.loadIngressFromStreamKeyReturnsOnCall == nil { + fake.loadIngressFromStreamKeyReturnsOnCall = make(map[int]struct { + result1 *livekit.IngressInfo + result2 error + }) + } + fake.loadIngressFromStreamKeyReturnsOnCall[i] = struct { + result1 *livekit.IngressInfo + result2 error + }{result1, result2} +} + +func (fake *FakeIngressStore) StoreIngress(arg1 context.Context, arg2 *livekit.IngressInfo) error { + fake.storeIngressMutex.Lock() + ret, specificReturn := fake.storeIngressReturnsOnCall[len(fake.storeIngressArgsForCall)] + fake.storeIngressArgsForCall = append(fake.storeIngressArgsForCall, struct { + arg1 context.Context + arg2 *livekit.IngressInfo + }{arg1, arg2}) + stub := fake.StoreIngressStub + fakeReturns := fake.storeIngressReturns + fake.recordInvocation("StoreIngress", []interface{}{arg1, arg2}) + fake.storeIngressMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeIngressStore) StoreIngressCallCount() int { + fake.storeIngressMutex.RLock() + defer fake.storeIngressMutex.RUnlock() + return len(fake.storeIngressArgsForCall) +} + +func (fake *FakeIngressStore) StoreIngressCalls(stub func(context.Context, *livekit.IngressInfo) error) { + fake.storeIngressMutex.Lock() + defer fake.storeIngressMutex.Unlock() + fake.StoreIngressStub = stub +} + +func (fake *FakeIngressStore) StoreIngressArgsForCall(i int) (context.Context, *livekit.IngressInfo) { + fake.storeIngressMutex.RLock() + defer fake.storeIngressMutex.RUnlock() + argsForCall := fake.storeIngressArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeIngressStore) StoreIngressReturns(result1 error) { + fake.storeIngressMutex.Lock() + defer fake.storeIngressMutex.Unlock() + fake.StoreIngressStub = nil + fake.storeIngressReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeIngressStore) StoreIngressReturnsOnCall(i int, result1 error) { + fake.storeIngressMutex.Lock() + defer fake.storeIngressMutex.Unlock() + fake.StoreIngressStub = nil + if fake.storeIngressReturnsOnCall == nil { + fake.storeIngressReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.storeIngressReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeIngressStore) UpdateIngress(arg1 context.Context, arg2 *livekit.IngressInfo) error { + fake.updateIngressMutex.Lock() + ret, specificReturn := fake.updateIngressReturnsOnCall[len(fake.updateIngressArgsForCall)] + fake.updateIngressArgsForCall = append(fake.updateIngressArgsForCall, struct { + arg1 context.Context + arg2 *livekit.IngressInfo + }{arg1, arg2}) + stub := fake.UpdateIngressStub + fakeReturns := fake.updateIngressReturns + fake.recordInvocation("UpdateIngress", []interface{}{arg1, arg2}) + fake.updateIngressMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeIngressStore) UpdateIngressCallCount() int { + fake.updateIngressMutex.RLock() + defer fake.updateIngressMutex.RUnlock() + return len(fake.updateIngressArgsForCall) +} + +func (fake *FakeIngressStore) UpdateIngressCalls(stub func(context.Context, *livekit.IngressInfo) error) { + fake.updateIngressMutex.Lock() + defer fake.updateIngressMutex.Unlock() + fake.UpdateIngressStub = stub +} + +func (fake *FakeIngressStore) UpdateIngressArgsForCall(i int) (context.Context, *livekit.IngressInfo) { + fake.updateIngressMutex.RLock() + defer fake.updateIngressMutex.RUnlock() + argsForCall := fake.updateIngressArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeIngressStore) UpdateIngressReturns(result1 error) { + fake.updateIngressMutex.Lock() + defer fake.updateIngressMutex.Unlock() + fake.UpdateIngressStub = nil + fake.updateIngressReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeIngressStore) UpdateIngressReturnsOnCall(i int, result1 error) { + fake.updateIngressMutex.Lock() + defer fake.updateIngressMutex.Unlock() + fake.UpdateIngressStub = nil + if fake.updateIngressReturnsOnCall == nil { + fake.updateIngressReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.updateIngressReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeIngressStore) UpdateIngressState(arg1 context.Context, arg2 string, arg3 *livekit.IngressState) error { + fake.updateIngressStateMutex.Lock() + ret, specificReturn := fake.updateIngressStateReturnsOnCall[len(fake.updateIngressStateArgsForCall)] + fake.updateIngressStateArgsForCall = append(fake.updateIngressStateArgsForCall, struct { + arg1 context.Context + arg2 string + arg3 *livekit.IngressState + }{arg1, arg2, arg3}) + stub := fake.UpdateIngressStateStub + fakeReturns := fake.updateIngressStateReturns + fake.recordInvocation("UpdateIngressState", []interface{}{arg1, arg2, arg3}) + fake.updateIngressStateMutex.Unlock() + if stub != nil { + return stub(arg1, arg2, arg3) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeIngressStore) UpdateIngressStateCallCount() int { + fake.updateIngressStateMutex.RLock() + defer fake.updateIngressStateMutex.RUnlock() + return len(fake.updateIngressStateArgsForCall) +} + +func (fake *FakeIngressStore) UpdateIngressStateCalls(stub func(context.Context, string, *livekit.IngressState) error) { + fake.updateIngressStateMutex.Lock() + defer fake.updateIngressStateMutex.Unlock() + fake.UpdateIngressStateStub = stub +} + +func (fake *FakeIngressStore) UpdateIngressStateArgsForCall(i int) (context.Context, string, *livekit.IngressState) { + fake.updateIngressStateMutex.RLock() + defer fake.updateIngressStateMutex.RUnlock() + argsForCall := fake.updateIngressStateArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeIngressStore) UpdateIngressStateReturns(result1 error) { + fake.updateIngressStateMutex.Lock() + defer fake.updateIngressStateMutex.Unlock() + fake.UpdateIngressStateStub = nil + fake.updateIngressStateReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeIngressStore) UpdateIngressStateReturnsOnCall(i int, result1 error) { + fake.updateIngressStateMutex.Lock() + defer fake.updateIngressStateMutex.Unlock() + fake.UpdateIngressStateStub = nil + if fake.updateIngressStateReturnsOnCall == nil { + fake.updateIngressStateReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.updateIngressStateReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeIngressStore) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeIngressStore) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ service.IngressStore = new(FakeIngressStore) diff --git a/livekit/pkg/service/servicefakes/fake_ioclient.go b/livekit/pkg/service/servicefakes/fake_ioclient.go new file mode 100644 index 0000000..0cdda41 --- /dev/null +++ b/livekit/pkg/service/servicefakes/fake_ioclient.go @@ -0,0 +1,436 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package servicefakes + +import ( + "context" + "sync" + + "github.com/livekit/livekit-server/pkg/service" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/rpc" + "google.golang.org/protobuf/types/known/emptypb" +) + +type FakeIOClient struct { + CreateEgressStub func(context.Context, *livekit.EgressInfo) (*emptypb.Empty, error) + createEgressMutex sync.RWMutex + createEgressArgsForCall []struct { + arg1 context.Context + arg2 *livekit.EgressInfo + } + createEgressReturns struct { + result1 *emptypb.Empty + result2 error + } + createEgressReturnsOnCall map[int]struct { + result1 *emptypb.Empty + result2 error + } + CreateIngressStub func(context.Context, *livekit.IngressInfo) (*emptypb.Empty, error) + createIngressMutex sync.RWMutex + createIngressArgsForCall []struct { + arg1 context.Context + arg2 *livekit.IngressInfo + } + createIngressReturns struct { + result1 *emptypb.Empty + result2 error + } + createIngressReturnsOnCall map[int]struct { + result1 *emptypb.Empty + result2 error + } + GetEgressStub func(context.Context, *rpc.GetEgressRequest) (*livekit.EgressInfo, error) + getEgressMutex sync.RWMutex + getEgressArgsForCall []struct { + arg1 context.Context + arg2 *rpc.GetEgressRequest + } + getEgressReturns struct { + result1 *livekit.EgressInfo + result2 error + } + getEgressReturnsOnCall map[int]struct { + result1 *livekit.EgressInfo + result2 error + } + ListEgressStub func(context.Context, *livekit.ListEgressRequest) (*livekit.ListEgressResponse, error) + listEgressMutex sync.RWMutex + listEgressArgsForCall []struct { + arg1 context.Context + arg2 *livekit.ListEgressRequest + } + listEgressReturns struct { + result1 *livekit.ListEgressResponse + result2 error + } + listEgressReturnsOnCall map[int]struct { + result1 *livekit.ListEgressResponse + result2 error + } + UpdateIngressStateStub func(context.Context, *rpc.UpdateIngressStateRequest) (*emptypb.Empty, error) + updateIngressStateMutex sync.RWMutex + updateIngressStateArgsForCall []struct { + arg1 context.Context + arg2 *rpc.UpdateIngressStateRequest + } + updateIngressStateReturns struct { + result1 *emptypb.Empty + result2 error + } + updateIngressStateReturnsOnCall map[int]struct { + result1 *emptypb.Empty + result2 error + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeIOClient) CreateEgress(arg1 context.Context, arg2 *livekit.EgressInfo) (*emptypb.Empty, error) { + fake.createEgressMutex.Lock() + ret, specificReturn := fake.createEgressReturnsOnCall[len(fake.createEgressArgsForCall)] + fake.createEgressArgsForCall = append(fake.createEgressArgsForCall, struct { + arg1 context.Context + arg2 *livekit.EgressInfo + }{arg1, arg2}) + stub := fake.CreateEgressStub + fakeReturns := fake.createEgressReturns + fake.recordInvocation("CreateEgress", []interface{}{arg1, arg2}) + fake.createEgressMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeIOClient) CreateEgressCallCount() int { + fake.createEgressMutex.RLock() + defer fake.createEgressMutex.RUnlock() + return len(fake.createEgressArgsForCall) +} + +func (fake *FakeIOClient) CreateEgressCalls(stub func(context.Context, *livekit.EgressInfo) (*emptypb.Empty, error)) { + fake.createEgressMutex.Lock() + defer fake.createEgressMutex.Unlock() + fake.CreateEgressStub = stub +} + +func (fake *FakeIOClient) CreateEgressArgsForCall(i int) (context.Context, *livekit.EgressInfo) { + fake.createEgressMutex.RLock() + defer fake.createEgressMutex.RUnlock() + argsForCall := fake.createEgressArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeIOClient) CreateEgressReturns(result1 *emptypb.Empty, result2 error) { + fake.createEgressMutex.Lock() + defer fake.createEgressMutex.Unlock() + fake.CreateEgressStub = nil + fake.createEgressReturns = struct { + result1 *emptypb.Empty + result2 error + }{result1, result2} +} + +func (fake *FakeIOClient) CreateEgressReturnsOnCall(i int, result1 *emptypb.Empty, result2 error) { + fake.createEgressMutex.Lock() + defer fake.createEgressMutex.Unlock() + fake.CreateEgressStub = nil + if fake.createEgressReturnsOnCall == nil { + fake.createEgressReturnsOnCall = make(map[int]struct { + result1 *emptypb.Empty + result2 error + }) + } + fake.createEgressReturnsOnCall[i] = struct { + result1 *emptypb.Empty + result2 error + }{result1, result2} +} + +func (fake *FakeIOClient) CreateIngress(arg1 context.Context, arg2 *livekit.IngressInfo) (*emptypb.Empty, error) { + fake.createIngressMutex.Lock() + ret, specificReturn := fake.createIngressReturnsOnCall[len(fake.createIngressArgsForCall)] + fake.createIngressArgsForCall = append(fake.createIngressArgsForCall, struct { + arg1 context.Context + arg2 *livekit.IngressInfo + }{arg1, arg2}) + stub := fake.CreateIngressStub + fakeReturns := fake.createIngressReturns + fake.recordInvocation("CreateIngress", []interface{}{arg1, arg2}) + fake.createIngressMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeIOClient) CreateIngressCallCount() int { + fake.createIngressMutex.RLock() + defer fake.createIngressMutex.RUnlock() + return len(fake.createIngressArgsForCall) +} + +func (fake *FakeIOClient) CreateIngressCalls(stub func(context.Context, *livekit.IngressInfo) (*emptypb.Empty, error)) { + fake.createIngressMutex.Lock() + defer fake.createIngressMutex.Unlock() + fake.CreateIngressStub = stub +} + +func (fake *FakeIOClient) CreateIngressArgsForCall(i int) (context.Context, *livekit.IngressInfo) { + fake.createIngressMutex.RLock() + defer fake.createIngressMutex.RUnlock() + argsForCall := fake.createIngressArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeIOClient) CreateIngressReturns(result1 *emptypb.Empty, result2 error) { + fake.createIngressMutex.Lock() + defer fake.createIngressMutex.Unlock() + fake.CreateIngressStub = nil + fake.createIngressReturns = struct { + result1 *emptypb.Empty + result2 error + }{result1, result2} +} + +func (fake *FakeIOClient) CreateIngressReturnsOnCall(i int, result1 *emptypb.Empty, result2 error) { + fake.createIngressMutex.Lock() + defer fake.createIngressMutex.Unlock() + fake.CreateIngressStub = nil + if fake.createIngressReturnsOnCall == nil { + fake.createIngressReturnsOnCall = make(map[int]struct { + result1 *emptypb.Empty + result2 error + }) + } + fake.createIngressReturnsOnCall[i] = struct { + result1 *emptypb.Empty + result2 error + }{result1, result2} +} + +func (fake *FakeIOClient) GetEgress(arg1 context.Context, arg2 *rpc.GetEgressRequest) (*livekit.EgressInfo, error) { + fake.getEgressMutex.Lock() + ret, specificReturn := fake.getEgressReturnsOnCall[len(fake.getEgressArgsForCall)] + fake.getEgressArgsForCall = append(fake.getEgressArgsForCall, struct { + arg1 context.Context + arg2 *rpc.GetEgressRequest + }{arg1, arg2}) + stub := fake.GetEgressStub + fakeReturns := fake.getEgressReturns + fake.recordInvocation("GetEgress", []interface{}{arg1, arg2}) + fake.getEgressMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeIOClient) GetEgressCallCount() int { + fake.getEgressMutex.RLock() + defer fake.getEgressMutex.RUnlock() + return len(fake.getEgressArgsForCall) +} + +func (fake *FakeIOClient) GetEgressCalls(stub func(context.Context, *rpc.GetEgressRequest) (*livekit.EgressInfo, error)) { + fake.getEgressMutex.Lock() + defer fake.getEgressMutex.Unlock() + fake.GetEgressStub = stub +} + +func (fake *FakeIOClient) GetEgressArgsForCall(i int) (context.Context, *rpc.GetEgressRequest) { + fake.getEgressMutex.RLock() + defer fake.getEgressMutex.RUnlock() + argsForCall := fake.getEgressArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeIOClient) GetEgressReturns(result1 *livekit.EgressInfo, result2 error) { + fake.getEgressMutex.Lock() + defer fake.getEgressMutex.Unlock() + fake.GetEgressStub = nil + fake.getEgressReturns = struct { + result1 *livekit.EgressInfo + result2 error + }{result1, result2} +} + +func (fake *FakeIOClient) GetEgressReturnsOnCall(i int, result1 *livekit.EgressInfo, result2 error) { + fake.getEgressMutex.Lock() + defer fake.getEgressMutex.Unlock() + fake.GetEgressStub = nil + if fake.getEgressReturnsOnCall == nil { + fake.getEgressReturnsOnCall = make(map[int]struct { + result1 *livekit.EgressInfo + result2 error + }) + } + fake.getEgressReturnsOnCall[i] = struct { + result1 *livekit.EgressInfo + result2 error + }{result1, result2} +} + +func (fake *FakeIOClient) ListEgress(arg1 context.Context, arg2 *livekit.ListEgressRequest) (*livekit.ListEgressResponse, error) { + fake.listEgressMutex.Lock() + ret, specificReturn := fake.listEgressReturnsOnCall[len(fake.listEgressArgsForCall)] + fake.listEgressArgsForCall = append(fake.listEgressArgsForCall, struct { + arg1 context.Context + arg2 *livekit.ListEgressRequest + }{arg1, arg2}) + stub := fake.ListEgressStub + fakeReturns := fake.listEgressReturns + fake.recordInvocation("ListEgress", []interface{}{arg1, arg2}) + fake.listEgressMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeIOClient) ListEgressCallCount() int { + fake.listEgressMutex.RLock() + defer fake.listEgressMutex.RUnlock() + return len(fake.listEgressArgsForCall) +} + +func (fake *FakeIOClient) ListEgressCalls(stub func(context.Context, *livekit.ListEgressRequest) (*livekit.ListEgressResponse, error)) { + fake.listEgressMutex.Lock() + defer fake.listEgressMutex.Unlock() + fake.ListEgressStub = stub +} + +func (fake *FakeIOClient) ListEgressArgsForCall(i int) (context.Context, *livekit.ListEgressRequest) { + fake.listEgressMutex.RLock() + defer fake.listEgressMutex.RUnlock() + argsForCall := fake.listEgressArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeIOClient) ListEgressReturns(result1 *livekit.ListEgressResponse, result2 error) { + fake.listEgressMutex.Lock() + defer fake.listEgressMutex.Unlock() + fake.ListEgressStub = nil + fake.listEgressReturns = struct { + result1 *livekit.ListEgressResponse + result2 error + }{result1, result2} +} + +func (fake *FakeIOClient) ListEgressReturnsOnCall(i int, result1 *livekit.ListEgressResponse, result2 error) { + fake.listEgressMutex.Lock() + defer fake.listEgressMutex.Unlock() + fake.ListEgressStub = nil + if fake.listEgressReturnsOnCall == nil { + fake.listEgressReturnsOnCall = make(map[int]struct { + result1 *livekit.ListEgressResponse + result2 error + }) + } + fake.listEgressReturnsOnCall[i] = struct { + result1 *livekit.ListEgressResponse + result2 error + }{result1, result2} +} + +func (fake *FakeIOClient) UpdateIngressState(arg1 context.Context, arg2 *rpc.UpdateIngressStateRequest) (*emptypb.Empty, error) { + fake.updateIngressStateMutex.Lock() + ret, specificReturn := fake.updateIngressStateReturnsOnCall[len(fake.updateIngressStateArgsForCall)] + fake.updateIngressStateArgsForCall = append(fake.updateIngressStateArgsForCall, struct { + arg1 context.Context + arg2 *rpc.UpdateIngressStateRequest + }{arg1, arg2}) + stub := fake.UpdateIngressStateStub + fakeReturns := fake.updateIngressStateReturns + fake.recordInvocation("UpdateIngressState", []interface{}{arg1, arg2}) + fake.updateIngressStateMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeIOClient) UpdateIngressStateCallCount() int { + fake.updateIngressStateMutex.RLock() + defer fake.updateIngressStateMutex.RUnlock() + return len(fake.updateIngressStateArgsForCall) +} + +func (fake *FakeIOClient) UpdateIngressStateCalls(stub func(context.Context, *rpc.UpdateIngressStateRequest) (*emptypb.Empty, error)) { + fake.updateIngressStateMutex.Lock() + defer fake.updateIngressStateMutex.Unlock() + fake.UpdateIngressStateStub = stub +} + +func (fake *FakeIOClient) UpdateIngressStateArgsForCall(i int) (context.Context, *rpc.UpdateIngressStateRequest) { + fake.updateIngressStateMutex.RLock() + defer fake.updateIngressStateMutex.RUnlock() + argsForCall := fake.updateIngressStateArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeIOClient) UpdateIngressStateReturns(result1 *emptypb.Empty, result2 error) { + fake.updateIngressStateMutex.Lock() + defer fake.updateIngressStateMutex.Unlock() + fake.UpdateIngressStateStub = nil + fake.updateIngressStateReturns = struct { + result1 *emptypb.Empty + result2 error + }{result1, result2} +} + +func (fake *FakeIOClient) UpdateIngressStateReturnsOnCall(i int, result1 *emptypb.Empty, result2 error) { + fake.updateIngressStateMutex.Lock() + defer fake.updateIngressStateMutex.Unlock() + fake.UpdateIngressStateStub = nil + if fake.updateIngressStateReturnsOnCall == nil { + fake.updateIngressStateReturnsOnCall = make(map[int]struct { + result1 *emptypb.Empty + result2 error + }) + } + fake.updateIngressStateReturnsOnCall[i] = struct { + result1 *emptypb.Empty + result2 error + }{result1, result2} +} + +func (fake *FakeIOClient) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeIOClient) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ service.IOClient = new(FakeIOClient) diff --git a/livekit/pkg/service/servicefakes/fake_object_store.go b/livekit/pkg/service/servicefakes/fake_object_store.go new file mode 100644 index 0000000..24d4273 --- /dev/null +++ b/livekit/pkg/service/servicefakes/fake_object_store.go @@ -0,0 +1,989 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package servicefakes + +import ( + "context" + "sync" + "time" + + "github.com/livekit/livekit-server/pkg/service" + "github.com/livekit/protocol/livekit" +) + +type FakeObjectStore struct { + DeleteParticipantStub func(context.Context, livekit.RoomName, livekit.ParticipantIdentity) error + deleteParticipantMutex sync.RWMutex + deleteParticipantArgsForCall []struct { + arg1 context.Context + arg2 livekit.RoomName + arg3 livekit.ParticipantIdentity + } + deleteParticipantReturns struct { + result1 error + } + deleteParticipantReturnsOnCall map[int]struct { + result1 error + } + DeleteRoomStub func(context.Context, livekit.RoomName) error + deleteRoomMutex sync.RWMutex + deleteRoomArgsForCall []struct { + arg1 context.Context + arg2 livekit.RoomName + } + deleteRoomReturns struct { + result1 error + } + deleteRoomReturnsOnCall map[int]struct { + result1 error + } + HasParticipantStub func(context.Context, livekit.RoomName, livekit.ParticipantIdentity) (bool, error) + hasParticipantMutex sync.RWMutex + hasParticipantArgsForCall []struct { + arg1 context.Context + arg2 livekit.RoomName + arg3 livekit.ParticipantIdentity + } + hasParticipantReturns struct { + result1 bool + result2 error + } + hasParticipantReturnsOnCall map[int]struct { + result1 bool + result2 error + } + ListParticipantsStub func(context.Context, livekit.RoomName) ([]*livekit.ParticipantInfo, error) + listParticipantsMutex sync.RWMutex + listParticipantsArgsForCall []struct { + arg1 context.Context + arg2 livekit.RoomName + } + listParticipantsReturns struct { + result1 []*livekit.ParticipantInfo + result2 error + } + listParticipantsReturnsOnCall map[int]struct { + result1 []*livekit.ParticipantInfo + result2 error + } + ListRoomsStub func(context.Context, []livekit.RoomName) ([]*livekit.Room, error) + listRoomsMutex sync.RWMutex + listRoomsArgsForCall []struct { + arg1 context.Context + arg2 []livekit.RoomName + } + listRoomsReturns struct { + result1 []*livekit.Room + result2 error + } + listRoomsReturnsOnCall map[int]struct { + result1 []*livekit.Room + result2 error + } + LoadParticipantStub func(context.Context, livekit.RoomName, livekit.ParticipantIdentity) (*livekit.ParticipantInfo, error) + loadParticipantMutex sync.RWMutex + loadParticipantArgsForCall []struct { + arg1 context.Context + arg2 livekit.RoomName + arg3 livekit.ParticipantIdentity + } + loadParticipantReturns struct { + result1 *livekit.ParticipantInfo + result2 error + } + loadParticipantReturnsOnCall map[int]struct { + result1 *livekit.ParticipantInfo + result2 error + } + LoadRoomStub func(context.Context, livekit.RoomName, bool) (*livekit.Room, *livekit.RoomInternal, error) + loadRoomMutex sync.RWMutex + loadRoomArgsForCall []struct { + arg1 context.Context + arg2 livekit.RoomName + arg3 bool + } + loadRoomReturns struct { + result1 *livekit.Room + result2 *livekit.RoomInternal + result3 error + } + loadRoomReturnsOnCall map[int]struct { + result1 *livekit.Room + result2 *livekit.RoomInternal + result3 error + } + LockRoomStub func(context.Context, livekit.RoomName, time.Duration) (string, error) + lockRoomMutex sync.RWMutex + lockRoomArgsForCall []struct { + arg1 context.Context + arg2 livekit.RoomName + arg3 time.Duration + } + lockRoomReturns struct { + result1 string + result2 error + } + lockRoomReturnsOnCall map[int]struct { + result1 string + result2 error + } + RoomExistsStub func(context.Context, livekit.RoomName) (bool, error) + roomExistsMutex sync.RWMutex + roomExistsArgsForCall []struct { + arg1 context.Context + arg2 livekit.RoomName + } + roomExistsReturns struct { + result1 bool + result2 error + } + roomExistsReturnsOnCall map[int]struct { + result1 bool + result2 error + } + StoreParticipantStub func(context.Context, livekit.RoomName, *livekit.ParticipantInfo) error + storeParticipantMutex sync.RWMutex + storeParticipantArgsForCall []struct { + arg1 context.Context + arg2 livekit.RoomName + arg3 *livekit.ParticipantInfo + } + storeParticipantReturns struct { + result1 error + } + storeParticipantReturnsOnCall map[int]struct { + result1 error + } + StoreRoomStub func(context.Context, *livekit.Room, *livekit.RoomInternal) error + storeRoomMutex sync.RWMutex + storeRoomArgsForCall []struct { + arg1 context.Context + arg2 *livekit.Room + arg3 *livekit.RoomInternal + } + storeRoomReturns struct { + result1 error + } + storeRoomReturnsOnCall map[int]struct { + result1 error + } + UnlockRoomStub func(context.Context, livekit.RoomName, string) error + unlockRoomMutex sync.RWMutex + unlockRoomArgsForCall []struct { + arg1 context.Context + arg2 livekit.RoomName + arg3 string + } + unlockRoomReturns struct { + result1 error + } + unlockRoomReturnsOnCall map[int]struct { + result1 error + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeObjectStore) DeleteParticipant(arg1 context.Context, arg2 livekit.RoomName, arg3 livekit.ParticipantIdentity) error { + fake.deleteParticipantMutex.Lock() + ret, specificReturn := fake.deleteParticipantReturnsOnCall[len(fake.deleteParticipantArgsForCall)] + fake.deleteParticipantArgsForCall = append(fake.deleteParticipantArgsForCall, struct { + arg1 context.Context + arg2 livekit.RoomName + arg3 livekit.ParticipantIdentity + }{arg1, arg2, arg3}) + stub := fake.DeleteParticipantStub + fakeReturns := fake.deleteParticipantReturns + fake.recordInvocation("DeleteParticipant", []interface{}{arg1, arg2, arg3}) + fake.deleteParticipantMutex.Unlock() + if stub != nil { + return stub(arg1, arg2, arg3) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeObjectStore) DeleteParticipantCallCount() int { + fake.deleteParticipantMutex.RLock() + defer fake.deleteParticipantMutex.RUnlock() + return len(fake.deleteParticipantArgsForCall) +} + +func (fake *FakeObjectStore) DeleteParticipantCalls(stub func(context.Context, livekit.RoomName, livekit.ParticipantIdentity) error) { + fake.deleteParticipantMutex.Lock() + defer fake.deleteParticipantMutex.Unlock() + fake.DeleteParticipantStub = stub +} + +func (fake *FakeObjectStore) DeleteParticipantArgsForCall(i int) (context.Context, livekit.RoomName, livekit.ParticipantIdentity) { + fake.deleteParticipantMutex.RLock() + defer fake.deleteParticipantMutex.RUnlock() + argsForCall := fake.deleteParticipantArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeObjectStore) DeleteParticipantReturns(result1 error) { + fake.deleteParticipantMutex.Lock() + defer fake.deleteParticipantMutex.Unlock() + fake.DeleteParticipantStub = nil + fake.deleteParticipantReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeObjectStore) DeleteParticipantReturnsOnCall(i int, result1 error) { + fake.deleteParticipantMutex.Lock() + defer fake.deleteParticipantMutex.Unlock() + fake.DeleteParticipantStub = nil + if fake.deleteParticipantReturnsOnCall == nil { + fake.deleteParticipantReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.deleteParticipantReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeObjectStore) DeleteRoom(arg1 context.Context, arg2 livekit.RoomName) error { + fake.deleteRoomMutex.Lock() + ret, specificReturn := fake.deleteRoomReturnsOnCall[len(fake.deleteRoomArgsForCall)] + fake.deleteRoomArgsForCall = append(fake.deleteRoomArgsForCall, struct { + arg1 context.Context + arg2 livekit.RoomName + }{arg1, arg2}) + stub := fake.DeleteRoomStub + fakeReturns := fake.deleteRoomReturns + fake.recordInvocation("DeleteRoom", []interface{}{arg1, arg2}) + fake.deleteRoomMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeObjectStore) DeleteRoomCallCount() int { + fake.deleteRoomMutex.RLock() + defer fake.deleteRoomMutex.RUnlock() + return len(fake.deleteRoomArgsForCall) +} + +func (fake *FakeObjectStore) DeleteRoomCalls(stub func(context.Context, livekit.RoomName) error) { + fake.deleteRoomMutex.Lock() + defer fake.deleteRoomMutex.Unlock() + fake.DeleteRoomStub = stub +} + +func (fake *FakeObjectStore) DeleteRoomArgsForCall(i int) (context.Context, livekit.RoomName) { + fake.deleteRoomMutex.RLock() + defer fake.deleteRoomMutex.RUnlock() + argsForCall := fake.deleteRoomArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeObjectStore) DeleteRoomReturns(result1 error) { + fake.deleteRoomMutex.Lock() + defer fake.deleteRoomMutex.Unlock() + fake.DeleteRoomStub = nil + fake.deleteRoomReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeObjectStore) DeleteRoomReturnsOnCall(i int, result1 error) { + fake.deleteRoomMutex.Lock() + defer fake.deleteRoomMutex.Unlock() + fake.DeleteRoomStub = nil + if fake.deleteRoomReturnsOnCall == nil { + fake.deleteRoomReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.deleteRoomReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeObjectStore) HasParticipant(arg1 context.Context, arg2 livekit.RoomName, arg3 livekit.ParticipantIdentity) (bool, error) { + fake.hasParticipantMutex.Lock() + ret, specificReturn := fake.hasParticipantReturnsOnCall[len(fake.hasParticipantArgsForCall)] + fake.hasParticipantArgsForCall = append(fake.hasParticipantArgsForCall, struct { + arg1 context.Context + arg2 livekit.RoomName + arg3 livekit.ParticipantIdentity + }{arg1, arg2, arg3}) + stub := fake.HasParticipantStub + fakeReturns := fake.hasParticipantReturns + fake.recordInvocation("HasParticipant", []interface{}{arg1, arg2, arg3}) + fake.hasParticipantMutex.Unlock() + if stub != nil { + return stub(arg1, arg2, arg3) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeObjectStore) HasParticipantCallCount() int { + fake.hasParticipantMutex.RLock() + defer fake.hasParticipantMutex.RUnlock() + return len(fake.hasParticipantArgsForCall) +} + +func (fake *FakeObjectStore) HasParticipantCalls(stub func(context.Context, livekit.RoomName, livekit.ParticipantIdentity) (bool, error)) { + fake.hasParticipantMutex.Lock() + defer fake.hasParticipantMutex.Unlock() + fake.HasParticipantStub = stub +} + +func (fake *FakeObjectStore) HasParticipantArgsForCall(i int) (context.Context, livekit.RoomName, livekit.ParticipantIdentity) { + fake.hasParticipantMutex.RLock() + defer fake.hasParticipantMutex.RUnlock() + argsForCall := fake.hasParticipantArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeObjectStore) HasParticipantReturns(result1 bool, result2 error) { + fake.hasParticipantMutex.Lock() + defer fake.hasParticipantMutex.Unlock() + fake.HasParticipantStub = nil + fake.hasParticipantReturns = struct { + result1 bool + result2 error + }{result1, result2} +} + +func (fake *FakeObjectStore) HasParticipantReturnsOnCall(i int, result1 bool, result2 error) { + fake.hasParticipantMutex.Lock() + defer fake.hasParticipantMutex.Unlock() + fake.HasParticipantStub = nil + if fake.hasParticipantReturnsOnCall == nil { + fake.hasParticipantReturnsOnCall = make(map[int]struct { + result1 bool + result2 error + }) + } + fake.hasParticipantReturnsOnCall[i] = struct { + result1 bool + result2 error + }{result1, result2} +} + +func (fake *FakeObjectStore) ListParticipants(arg1 context.Context, arg2 livekit.RoomName) ([]*livekit.ParticipantInfo, error) { + fake.listParticipantsMutex.Lock() + ret, specificReturn := fake.listParticipantsReturnsOnCall[len(fake.listParticipantsArgsForCall)] + fake.listParticipantsArgsForCall = append(fake.listParticipantsArgsForCall, struct { + arg1 context.Context + arg2 livekit.RoomName + }{arg1, arg2}) + stub := fake.ListParticipantsStub + fakeReturns := fake.listParticipantsReturns + fake.recordInvocation("ListParticipants", []interface{}{arg1, arg2}) + fake.listParticipantsMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeObjectStore) ListParticipantsCallCount() int { + fake.listParticipantsMutex.RLock() + defer fake.listParticipantsMutex.RUnlock() + return len(fake.listParticipantsArgsForCall) +} + +func (fake *FakeObjectStore) ListParticipantsCalls(stub func(context.Context, livekit.RoomName) ([]*livekit.ParticipantInfo, error)) { + fake.listParticipantsMutex.Lock() + defer fake.listParticipantsMutex.Unlock() + fake.ListParticipantsStub = stub +} + +func (fake *FakeObjectStore) ListParticipantsArgsForCall(i int) (context.Context, livekit.RoomName) { + fake.listParticipantsMutex.RLock() + defer fake.listParticipantsMutex.RUnlock() + argsForCall := fake.listParticipantsArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeObjectStore) ListParticipantsReturns(result1 []*livekit.ParticipantInfo, result2 error) { + fake.listParticipantsMutex.Lock() + defer fake.listParticipantsMutex.Unlock() + fake.ListParticipantsStub = nil + fake.listParticipantsReturns = struct { + result1 []*livekit.ParticipantInfo + result2 error + }{result1, result2} +} + +func (fake *FakeObjectStore) ListParticipantsReturnsOnCall(i int, result1 []*livekit.ParticipantInfo, result2 error) { + fake.listParticipantsMutex.Lock() + defer fake.listParticipantsMutex.Unlock() + fake.ListParticipantsStub = nil + if fake.listParticipantsReturnsOnCall == nil { + fake.listParticipantsReturnsOnCall = make(map[int]struct { + result1 []*livekit.ParticipantInfo + result2 error + }) + } + fake.listParticipantsReturnsOnCall[i] = struct { + result1 []*livekit.ParticipantInfo + result2 error + }{result1, result2} +} + +func (fake *FakeObjectStore) ListRooms(arg1 context.Context, arg2 []livekit.RoomName) ([]*livekit.Room, error) { + var arg2Copy []livekit.RoomName + if arg2 != nil { + arg2Copy = make([]livekit.RoomName, len(arg2)) + copy(arg2Copy, arg2) + } + fake.listRoomsMutex.Lock() + ret, specificReturn := fake.listRoomsReturnsOnCall[len(fake.listRoomsArgsForCall)] + fake.listRoomsArgsForCall = append(fake.listRoomsArgsForCall, struct { + arg1 context.Context + arg2 []livekit.RoomName + }{arg1, arg2Copy}) + stub := fake.ListRoomsStub + fakeReturns := fake.listRoomsReturns + fake.recordInvocation("ListRooms", []interface{}{arg1, arg2Copy}) + fake.listRoomsMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeObjectStore) ListRoomsCallCount() int { + fake.listRoomsMutex.RLock() + defer fake.listRoomsMutex.RUnlock() + return len(fake.listRoomsArgsForCall) +} + +func (fake *FakeObjectStore) ListRoomsCalls(stub func(context.Context, []livekit.RoomName) ([]*livekit.Room, error)) { + fake.listRoomsMutex.Lock() + defer fake.listRoomsMutex.Unlock() + fake.ListRoomsStub = stub +} + +func (fake *FakeObjectStore) ListRoomsArgsForCall(i int) (context.Context, []livekit.RoomName) { + fake.listRoomsMutex.RLock() + defer fake.listRoomsMutex.RUnlock() + argsForCall := fake.listRoomsArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeObjectStore) ListRoomsReturns(result1 []*livekit.Room, result2 error) { + fake.listRoomsMutex.Lock() + defer fake.listRoomsMutex.Unlock() + fake.ListRoomsStub = nil + fake.listRoomsReturns = struct { + result1 []*livekit.Room + result2 error + }{result1, result2} +} + +func (fake *FakeObjectStore) ListRoomsReturnsOnCall(i int, result1 []*livekit.Room, result2 error) { + fake.listRoomsMutex.Lock() + defer fake.listRoomsMutex.Unlock() + fake.ListRoomsStub = nil + if fake.listRoomsReturnsOnCall == nil { + fake.listRoomsReturnsOnCall = make(map[int]struct { + result1 []*livekit.Room + result2 error + }) + } + fake.listRoomsReturnsOnCall[i] = struct { + result1 []*livekit.Room + result2 error + }{result1, result2} +} + +func (fake *FakeObjectStore) LoadParticipant(arg1 context.Context, arg2 livekit.RoomName, arg3 livekit.ParticipantIdentity) (*livekit.ParticipantInfo, error) { + fake.loadParticipantMutex.Lock() + ret, specificReturn := fake.loadParticipantReturnsOnCall[len(fake.loadParticipantArgsForCall)] + fake.loadParticipantArgsForCall = append(fake.loadParticipantArgsForCall, struct { + arg1 context.Context + arg2 livekit.RoomName + arg3 livekit.ParticipantIdentity + }{arg1, arg2, arg3}) + stub := fake.LoadParticipantStub + fakeReturns := fake.loadParticipantReturns + fake.recordInvocation("LoadParticipant", []interface{}{arg1, arg2, arg3}) + fake.loadParticipantMutex.Unlock() + if stub != nil { + return stub(arg1, arg2, arg3) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeObjectStore) LoadParticipantCallCount() int { + fake.loadParticipantMutex.RLock() + defer fake.loadParticipantMutex.RUnlock() + return len(fake.loadParticipantArgsForCall) +} + +func (fake *FakeObjectStore) LoadParticipantCalls(stub func(context.Context, livekit.RoomName, livekit.ParticipantIdentity) (*livekit.ParticipantInfo, error)) { + fake.loadParticipantMutex.Lock() + defer fake.loadParticipantMutex.Unlock() + fake.LoadParticipantStub = stub +} + +func (fake *FakeObjectStore) LoadParticipantArgsForCall(i int) (context.Context, livekit.RoomName, livekit.ParticipantIdentity) { + fake.loadParticipantMutex.RLock() + defer fake.loadParticipantMutex.RUnlock() + argsForCall := fake.loadParticipantArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeObjectStore) LoadParticipantReturns(result1 *livekit.ParticipantInfo, result2 error) { + fake.loadParticipantMutex.Lock() + defer fake.loadParticipantMutex.Unlock() + fake.LoadParticipantStub = nil + fake.loadParticipantReturns = struct { + result1 *livekit.ParticipantInfo + result2 error + }{result1, result2} +} + +func (fake *FakeObjectStore) LoadParticipantReturnsOnCall(i int, result1 *livekit.ParticipantInfo, result2 error) { + fake.loadParticipantMutex.Lock() + defer fake.loadParticipantMutex.Unlock() + fake.LoadParticipantStub = nil + if fake.loadParticipantReturnsOnCall == nil { + fake.loadParticipantReturnsOnCall = make(map[int]struct { + result1 *livekit.ParticipantInfo + result2 error + }) + } + fake.loadParticipantReturnsOnCall[i] = struct { + result1 *livekit.ParticipantInfo + result2 error + }{result1, result2} +} + +func (fake *FakeObjectStore) LoadRoom(arg1 context.Context, arg2 livekit.RoomName, arg3 bool) (*livekit.Room, *livekit.RoomInternal, error) { + fake.loadRoomMutex.Lock() + ret, specificReturn := fake.loadRoomReturnsOnCall[len(fake.loadRoomArgsForCall)] + fake.loadRoomArgsForCall = append(fake.loadRoomArgsForCall, struct { + arg1 context.Context + arg2 livekit.RoomName + arg3 bool + }{arg1, arg2, arg3}) + stub := fake.LoadRoomStub + fakeReturns := fake.loadRoomReturns + fake.recordInvocation("LoadRoom", []interface{}{arg1, arg2, arg3}) + fake.loadRoomMutex.Unlock() + if stub != nil { + return stub(arg1, arg2, arg3) + } + if specificReturn { + return ret.result1, ret.result2, ret.result3 + } + return fakeReturns.result1, fakeReturns.result2, fakeReturns.result3 +} + +func (fake *FakeObjectStore) LoadRoomCallCount() int { + fake.loadRoomMutex.RLock() + defer fake.loadRoomMutex.RUnlock() + return len(fake.loadRoomArgsForCall) +} + +func (fake *FakeObjectStore) LoadRoomCalls(stub func(context.Context, livekit.RoomName, bool) (*livekit.Room, *livekit.RoomInternal, error)) { + fake.loadRoomMutex.Lock() + defer fake.loadRoomMutex.Unlock() + fake.LoadRoomStub = stub +} + +func (fake *FakeObjectStore) LoadRoomArgsForCall(i int) (context.Context, livekit.RoomName, bool) { + fake.loadRoomMutex.RLock() + defer fake.loadRoomMutex.RUnlock() + argsForCall := fake.loadRoomArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeObjectStore) LoadRoomReturns(result1 *livekit.Room, result2 *livekit.RoomInternal, result3 error) { + fake.loadRoomMutex.Lock() + defer fake.loadRoomMutex.Unlock() + fake.LoadRoomStub = nil + fake.loadRoomReturns = struct { + result1 *livekit.Room + result2 *livekit.RoomInternal + result3 error + }{result1, result2, result3} +} + +func (fake *FakeObjectStore) LoadRoomReturnsOnCall(i int, result1 *livekit.Room, result2 *livekit.RoomInternal, result3 error) { + fake.loadRoomMutex.Lock() + defer fake.loadRoomMutex.Unlock() + fake.LoadRoomStub = nil + if fake.loadRoomReturnsOnCall == nil { + fake.loadRoomReturnsOnCall = make(map[int]struct { + result1 *livekit.Room + result2 *livekit.RoomInternal + result3 error + }) + } + fake.loadRoomReturnsOnCall[i] = struct { + result1 *livekit.Room + result2 *livekit.RoomInternal + result3 error + }{result1, result2, result3} +} + +func (fake *FakeObjectStore) LockRoom(arg1 context.Context, arg2 livekit.RoomName, arg3 time.Duration) (string, error) { + fake.lockRoomMutex.Lock() + ret, specificReturn := fake.lockRoomReturnsOnCall[len(fake.lockRoomArgsForCall)] + fake.lockRoomArgsForCall = append(fake.lockRoomArgsForCall, struct { + arg1 context.Context + arg2 livekit.RoomName + arg3 time.Duration + }{arg1, arg2, arg3}) + stub := fake.LockRoomStub + fakeReturns := fake.lockRoomReturns + fake.recordInvocation("LockRoom", []interface{}{arg1, arg2, arg3}) + fake.lockRoomMutex.Unlock() + if stub != nil { + return stub(arg1, arg2, arg3) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeObjectStore) LockRoomCallCount() int { + fake.lockRoomMutex.RLock() + defer fake.lockRoomMutex.RUnlock() + return len(fake.lockRoomArgsForCall) +} + +func (fake *FakeObjectStore) LockRoomCalls(stub func(context.Context, livekit.RoomName, time.Duration) (string, error)) { + fake.lockRoomMutex.Lock() + defer fake.lockRoomMutex.Unlock() + fake.LockRoomStub = stub +} + +func (fake *FakeObjectStore) LockRoomArgsForCall(i int) (context.Context, livekit.RoomName, time.Duration) { + fake.lockRoomMutex.RLock() + defer fake.lockRoomMutex.RUnlock() + argsForCall := fake.lockRoomArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeObjectStore) LockRoomReturns(result1 string, result2 error) { + fake.lockRoomMutex.Lock() + defer fake.lockRoomMutex.Unlock() + fake.LockRoomStub = nil + fake.lockRoomReturns = struct { + result1 string + result2 error + }{result1, result2} +} + +func (fake *FakeObjectStore) LockRoomReturnsOnCall(i int, result1 string, result2 error) { + fake.lockRoomMutex.Lock() + defer fake.lockRoomMutex.Unlock() + fake.LockRoomStub = nil + if fake.lockRoomReturnsOnCall == nil { + fake.lockRoomReturnsOnCall = make(map[int]struct { + result1 string + result2 error + }) + } + fake.lockRoomReturnsOnCall[i] = struct { + result1 string + result2 error + }{result1, result2} +} + +func (fake *FakeObjectStore) RoomExists(arg1 context.Context, arg2 livekit.RoomName) (bool, error) { + fake.roomExistsMutex.Lock() + ret, specificReturn := fake.roomExistsReturnsOnCall[len(fake.roomExistsArgsForCall)] + fake.roomExistsArgsForCall = append(fake.roomExistsArgsForCall, struct { + arg1 context.Context + arg2 livekit.RoomName + }{arg1, arg2}) + stub := fake.RoomExistsStub + fakeReturns := fake.roomExistsReturns + fake.recordInvocation("RoomExists", []interface{}{arg1, arg2}) + fake.roomExistsMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeObjectStore) RoomExistsCallCount() int { + fake.roomExistsMutex.RLock() + defer fake.roomExistsMutex.RUnlock() + return len(fake.roomExistsArgsForCall) +} + +func (fake *FakeObjectStore) RoomExistsCalls(stub func(context.Context, livekit.RoomName) (bool, error)) { + fake.roomExistsMutex.Lock() + defer fake.roomExistsMutex.Unlock() + fake.RoomExistsStub = stub +} + +func (fake *FakeObjectStore) RoomExistsArgsForCall(i int) (context.Context, livekit.RoomName) { + fake.roomExistsMutex.RLock() + defer fake.roomExistsMutex.RUnlock() + argsForCall := fake.roomExistsArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeObjectStore) RoomExistsReturns(result1 bool, result2 error) { + fake.roomExistsMutex.Lock() + defer fake.roomExistsMutex.Unlock() + fake.RoomExistsStub = nil + fake.roomExistsReturns = struct { + result1 bool + result2 error + }{result1, result2} +} + +func (fake *FakeObjectStore) RoomExistsReturnsOnCall(i int, result1 bool, result2 error) { + fake.roomExistsMutex.Lock() + defer fake.roomExistsMutex.Unlock() + fake.RoomExistsStub = nil + if fake.roomExistsReturnsOnCall == nil { + fake.roomExistsReturnsOnCall = make(map[int]struct { + result1 bool + result2 error + }) + } + fake.roomExistsReturnsOnCall[i] = struct { + result1 bool + result2 error + }{result1, result2} +} + +func (fake *FakeObjectStore) StoreParticipant(arg1 context.Context, arg2 livekit.RoomName, arg3 *livekit.ParticipantInfo) error { + fake.storeParticipantMutex.Lock() + ret, specificReturn := fake.storeParticipantReturnsOnCall[len(fake.storeParticipantArgsForCall)] + fake.storeParticipantArgsForCall = append(fake.storeParticipantArgsForCall, struct { + arg1 context.Context + arg2 livekit.RoomName + arg3 *livekit.ParticipantInfo + }{arg1, arg2, arg3}) + stub := fake.StoreParticipantStub + fakeReturns := fake.storeParticipantReturns + fake.recordInvocation("StoreParticipant", []interface{}{arg1, arg2, arg3}) + fake.storeParticipantMutex.Unlock() + if stub != nil { + return stub(arg1, arg2, arg3) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeObjectStore) StoreParticipantCallCount() int { + fake.storeParticipantMutex.RLock() + defer fake.storeParticipantMutex.RUnlock() + return len(fake.storeParticipantArgsForCall) +} + +func (fake *FakeObjectStore) StoreParticipantCalls(stub func(context.Context, livekit.RoomName, *livekit.ParticipantInfo) error) { + fake.storeParticipantMutex.Lock() + defer fake.storeParticipantMutex.Unlock() + fake.StoreParticipantStub = stub +} + +func (fake *FakeObjectStore) StoreParticipantArgsForCall(i int) (context.Context, livekit.RoomName, *livekit.ParticipantInfo) { + fake.storeParticipantMutex.RLock() + defer fake.storeParticipantMutex.RUnlock() + argsForCall := fake.storeParticipantArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeObjectStore) StoreParticipantReturns(result1 error) { + fake.storeParticipantMutex.Lock() + defer fake.storeParticipantMutex.Unlock() + fake.StoreParticipantStub = nil + fake.storeParticipantReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeObjectStore) StoreParticipantReturnsOnCall(i int, result1 error) { + fake.storeParticipantMutex.Lock() + defer fake.storeParticipantMutex.Unlock() + fake.StoreParticipantStub = nil + if fake.storeParticipantReturnsOnCall == nil { + fake.storeParticipantReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.storeParticipantReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeObjectStore) StoreRoom(arg1 context.Context, arg2 *livekit.Room, arg3 *livekit.RoomInternal) error { + fake.storeRoomMutex.Lock() + ret, specificReturn := fake.storeRoomReturnsOnCall[len(fake.storeRoomArgsForCall)] + fake.storeRoomArgsForCall = append(fake.storeRoomArgsForCall, struct { + arg1 context.Context + arg2 *livekit.Room + arg3 *livekit.RoomInternal + }{arg1, arg2, arg3}) + stub := fake.StoreRoomStub + fakeReturns := fake.storeRoomReturns + fake.recordInvocation("StoreRoom", []interface{}{arg1, arg2, arg3}) + fake.storeRoomMutex.Unlock() + if stub != nil { + return stub(arg1, arg2, arg3) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeObjectStore) StoreRoomCallCount() int { + fake.storeRoomMutex.RLock() + defer fake.storeRoomMutex.RUnlock() + return len(fake.storeRoomArgsForCall) +} + +func (fake *FakeObjectStore) StoreRoomCalls(stub func(context.Context, *livekit.Room, *livekit.RoomInternal) error) { + fake.storeRoomMutex.Lock() + defer fake.storeRoomMutex.Unlock() + fake.StoreRoomStub = stub +} + +func (fake *FakeObjectStore) StoreRoomArgsForCall(i int) (context.Context, *livekit.Room, *livekit.RoomInternal) { + fake.storeRoomMutex.RLock() + defer fake.storeRoomMutex.RUnlock() + argsForCall := fake.storeRoomArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeObjectStore) StoreRoomReturns(result1 error) { + fake.storeRoomMutex.Lock() + defer fake.storeRoomMutex.Unlock() + fake.StoreRoomStub = nil + fake.storeRoomReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeObjectStore) StoreRoomReturnsOnCall(i int, result1 error) { + fake.storeRoomMutex.Lock() + defer fake.storeRoomMutex.Unlock() + fake.StoreRoomStub = nil + if fake.storeRoomReturnsOnCall == nil { + fake.storeRoomReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.storeRoomReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeObjectStore) UnlockRoom(arg1 context.Context, arg2 livekit.RoomName, arg3 string) error { + fake.unlockRoomMutex.Lock() + ret, specificReturn := fake.unlockRoomReturnsOnCall[len(fake.unlockRoomArgsForCall)] + fake.unlockRoomArgsForCall = append(fake.unlockRoomArgsForCall, struct { + arg1 context.Context + arg2 livekit.RoomName + arg3 string + }{arg1, arg2, arg3}) + stub := fake.UnlockRoomStub + fakeReturns := fake.unlockRoomReturns + fake.recordInvocation("UnlockRoom", []interface{}{arg1, arg2, arg3}) + fake.unlockRoomMutex.Unlock() + if stub != nil { + return stub(arg1, arg2, arg3) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeObjectStore) UnlockRoomCallCount() int { + fake.unlockRoomMutex.RLock() + defer fake.unlockRoomMutex.RUnlock() + return len(fake.unlockRoomArgsForCall) +} + +func (fake *FakeObjectStore) UnlockRoomCalls(stub func(context.Context, livekit.RoomName, string) error) { + fake.unlockRoomMutex.Lock() + defer fake.unlockRoomMutex.Unlock() + fake.UnlockRoomStub = stub +} + +func (fake *FakeObjectStore) UnlockRoomArgsForCall(i int) (context.Context, livekit.RoomName, string) { + fake.unlockRoomMutex.RLock() + defer fake.unlockRoomMutex.RUnlock() + argsForCall := fake.unlockRoomArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeObjectStore) UnlockRoomReturns(result1 error) { + fake.unlockRoomMutex.Lock() + defer fake.unlockRoomMutex.Unlock() + fake.UnlockRoomStub = nil + fake.unlockRoomReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeObjectStore) UnlockRoomReturnsOnCall(i int, result1 error) { + fake.unlockRoomMutex.Lock() + defer fake.unlockRoomMutex.Unlock() + fake.UnlockRoomStub = nil + if fake.unlockRoomReturnsOnCall == nil { + fake.unlockRoomReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.unlockRoomReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeObjectStore) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeObjectStore) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ service.ObjectStore = new(FakeObjectStore) diff --git a/livekit/pkg/service/servicefakes/fake_room_allocator.go b/livekit/pkg/service/servicefakes/fake_room_allocator.go new file mode 100644 index 0000000..b15ae56 --- /dev/null +++ b/livekit/pkg/service/servicefakes/fake_room_allocator.go @@ -0,0 +1,352 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package servicefakes + +import ( + "context" + "sync" + + "github.com/livekit/livekit-server/pkg/service" + "github.com/livekit/protocol/livekit" +) + +type FakeRoomAllocator struct { + AutoCreateEnabledStub func(context.Context) bool + autoCreateEnabledMutex sync.RWMutex + autoCreateEnabledArgsForCall []struct { + arg1 context.Context + } + autoCreateEnabledReturns struct { + result1 bool + } + autoCreateEnabledReturnsOnCall map[int]struct { + result1 bool + } + CreateRoomStub func(context.Context, *livekit.CreateRoomRequest, bool) (*livekit.Room, *livekit.RoomInternal, bool, error) + createRoomMutex sync.RWMutex + createRoomArgsForCall []struct { + arg1 context.Context + arg2 *livekit.CreateRoomRequest + arg3 bool + } + createRoomReturns struct { + result1 *livekit.Room + result2 *livekit.RoomInternal + result3 bool + result4 error + } + createRoomReturnsOnCall map[int]struct { + result1 *livekit.Room + result2 *livekit.RoomInternal + result3 bool + result4 error + } + SelectRoomNodeStub func(context.Context, livekit.RoomName, livekit.NodeID) error + selectRoomNodeMutex sync.RWMutex + selectRoomNodeArgsForCall []struct { + arg1 context.Context + arg2 livekit.RoomName + arg3 livekit.NodeID + } + selectRoomNodeReturns struct { + result1 error + } + selectRoomNodeReturnsOnCall map[int]struct { + result1 error + } + ValidateCreateRoomStub func(context.Context, livekit.RoomName) error + validateCreateRoomMutex sync.RWMutex + validateCreateRoomArgsForCall []struct { + arg1 context.Context + arg2 livekit.RoomName + } + validateCreateRoomReturns struct { + result1 error + } + validateCreateRoomReturnsOnCall map[int]struct { + result1 error + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeRoomAllocator) AutoCreateEnabled(arg1 context.Context) bool { + fake.autoCreateEnabledMutex.Lock() + ret, specificReturn := fake.autoCreateEnabledReturnsOnCall[len(fake.autoCreateEnabledArgsForCall)] + fake.autoCreateEnabledArgsForCall = append(fake.autoCreateEnabledArgsForCall, struct { + arg1 context.Context + }{arg1}) + stub := fake.AutoCreateEnabledStub + fakeReturns := fake.autoCreateEnabledReturns + fake.recordInvocation("AutoCreateEnabled", []interface{}{arg1}) + fake.autoCreateEnabledMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeRoomAllocator) AutoCreateEnabledCallCount() int { + fake.autoCreateEnabledMutex.RLock() + defer fake.autoCreateEnabledMutex.RUnlock() + return len(fake.autoCreateEnabledArgsForCall) +} + +func (fake *FakeRoomAllocator) AutoCreateEnabledCalls(stub func(context.Context) bool) { + fake.autoCreateEnabledMutex.Lock() + defer fake.autoCreateEnabledMutex.Unlock() + fake.AutoCreateEnabledStub = stub +} + +func (fake *FakeRoomAllocator) AutoCreateEnabledArgsForCall(i int) context.Context { + fake.autoCreateEnabledMutex.RLock() + defer fake.autoCreateEnabledMutex.RUnlock() + argsForCall := fake.autoCreateEnabledArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeRoomAllocator) AutoCreateEnabledReturns(result1 bool) { + fake.autoCreateEnabledMutex.Lock() + defer fake.autoCreateEnabledMutex.Unlock() + fake.AutoCreateEnabledStub = nil + fake.autoCreateEnabledReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeRoomAllocator) AutoCreateEnabledReturnsOnCall(i int, result1 bool) { + fake.autoCreateEnabledMutex.Lock() + defer fake.autoCreateEnabledMutex.Unlock() + fake.AutoCreateEnabledStub = nil + if fake.autoCreateEnabledReturnsOnCall == nil { + fake.autoCreateEnabledReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + fake.autoCreateEnabledReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + +func (fake *FakeRoomAllocator) CreateRoom(arg1 context.Context, arg2 *livekit.CreateRoomRequest, arg3 bool) (*livekit.Room, *livekit.RoomInternal, bool, error) { + fake.createRoomMutex.Lock() + ret, specificReturn := fake.createRoomReturnsOnCall[len(fake.createRoomArgsForCall)] + fake.createRoomArgsForCall = append(fake.createRoomArgsForCall, struct { + arg1 context.Context + arg2 *livekit.CreateRoomRequest + arg3 bool + }{arg1, arg2, arg3}) + stub := fake.CreateRoomStub + fakeReturns := fake.createRoomReturns + fake.recordInvocation("CreateRoom", []interface{}{arg1, arg2, arg3}) + fake.createRoomMutex.Unlock() + if stub != nil { + return stub(arg1, arg2, arg3) + } + if specificReturn { + return ret.result1, ret.result2, ret.result3, ret.result4 + } + return fakeReturns.result1, fakeReturns.result2, fakeReturns.result3, fakeReturns.result4 +} + +func (fake *FakeRoomAllocator) CreateRoomCallCount() int { + fake.createRoomMutex.RLock() + defer fake.createRoomMutex.RUnlock() + return len(fake.createRoomArgsForCall) +} + +func (fake *FakeRoomAllocator) CreateRoomCalls(stub func(context.Context, *livekit.CreateRoomRequest, bool) (*livekit.Room, *livekit.RoomInternal, bool, error)) { + fake.createRoomMutex.Lock() + defer fake.createRoomMutex.Unlock() + fake.CreateRoomStub = stub +} + +func (fake *FakeRoomAllocator) CreateRoomArgsForCall(i int) (context.Context, *livekit.CreateRoomRequest, bool) { + fake.createRoomMutex.RLock() + defer fake.createRoomMutex.RUnlock() + argsForCall := fake.createRoomArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeRoomAllocator) CreateRoomReturns(result1 *livekit.Room, result2 *livekit.RoomInternal, result3 bool, result4 error) { + fake.createRoomMutex.Lock() + defer fake.createRoomMutex.Unlock() + fake.CreateRoomStub = nil + fake.createRoomReturns = struct { + result1 *livekit.Room + result2 *livekit.RoomInternal + result3 bool + result4 error + }{result1, result2, result3, result4} +} + +func (fake *FakeRoomAllocator) CreateRoomReturnsOnCall(i int, result1 *livekit.Room, result2 *livekit.RoomInternal, result3 bool, result4 error) { + fake.createRoomMutex.Lock() + defer fake.createRoomMutex.Unlock() + fake.CreateRoomStub = nil + if fake.createRoomReturnsOnCall == nil { + fake.createRoomReturnsOnCall = make(map[int]struct { + result1 *livekit.Room + result2 *livekit.RoomInternal + result3 bool + result4 error + }) + } + fake.createRoomReturnsOnCall[i] = struct { + result1 *livekit.Room + result2 *livekit.RoomInternal + result3 bool + result4 error + }{result1, result2, result3, result4} +} + +func (fake *FakeRoomAllocator) SelectRoomNode(arg1 context.Context, arg2 livekit.RoomName, arg3 livekit.NodeID) error { + fake.selectRoomNodeMutex.Lock() + ret, specificReturn := fake.selectRoomNodeReturnsOnCall[len(fake.selectRoomNodeArgsForCall)] + fake.selectRoomNodeArgsForCall = append(fake.selectRoomNodeArgsForCall, struct { + arg1 context.Context + arg2 livekit.RoomName + arg3 livekit.NodeID + }{arg1, arg2, arg3}) + stub := fake.SelectRoomNodeStub + fakeReturns := fake.selectRoomNodeReturns + fake.recordInvocation("SelectRoomNode", []interface{}{arg1, arg2, arg3}) + fake.selectRoomNodeMutex.Unlock() + if stub != nil { + return stub(arg1, arg2, arg3) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeRoomAllocator) SelectRoomNodeCallCount() int { + fake.selectRoomNodeMutex.RLock() + defer fake.selectRoomNodeMutex.RUnlock() + return len(fake.selectRoomNodeArgsForCall) +} + +func (fake *FakeRoomAllocator) SelectRoomNodeCalls(stub func(context.Context, livekit.RoomName, livekit.NodeID) error) { + fake.selectRoomNodeMutex.Lock() + defer fake.selectRoomNodeMutex.Unlock() + fake.SelectRoomNodeStub = stub +} + +func (fake *FakeRoomAllocator) SelectRoomNodeArgsForCall(i int) (context.Context, livekit.RoomName, livekit.NodeID) { + fake.selectRoomNodeMutex.RLock() + defer fake.selectRoomNodeMutex.RUnlock() + argsForCall := fake.selectRoomNodeArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeRoomAllocator) SelectRoomNodeReturns(result1 error) { + fake.selectRoomNodeMutex.Lock() + defer fake.selectRoomNodeMutex.Unlock() + fake.SelectRoomNodeStub = nil + fake.selectRoomNodeReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeRoomAllocator) SelectRoomNodeReturnsOnCall(i int, result1 error) { + fake.selectRoomNodeMutex.Lock() + defer fake.selectRoomNodeMutex.Unlock() + fake.SelectRoomNodeStub = nil + if fake.selectRoomNodeReturnsOnCall == nil { + fake.selectRoomNodeReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.selectRoomNodeReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeRoomAllocator) ValidateCreateRoom(arg1 context.Context, arg2 livekit.RoomName) error { + fake.validateCreateRoomMutex.Lock() + ret, specificReturn := fake.validateCreateRoomReturnsOnCall[len(fake.validateCreateRoomArgsForCall)] + fake.validateCreateRoomArgsForCall = append(fake.validateCreateRoomArgsForCall, struct { + arg1 context.Context + arg2 livekit.RoomName + }{arg1, arg2}) + stub := fake.ValidateCreateRoomStub + fakeReturns := fake.validateCreateRoomReturns + fake.recordInvocation("ValidateCreateRoom", []interface{}{arg1, arg2}) + fake.validateCreateRoomMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeRoomAllocator) ValidateCreateRoomCallCount() int { + fake.validateCreateRoomMutex.RLock() + defer fake.validateCreateRoomMutex.RUnlock() + return len(fake.validateCreateRoomArgsForCall) +} + +func (fake *FakeRoomAllocator) ValidateCreateRoomCalls(stub func(context.Context, livekit.RoomName) error) { + fake.validateCreateRoomMutex.Lock() + defer fake.validateCreateRoomMutex.Unlock() + fake.ValidateCreateRoomStub = stub +} + +func (fake *FakeRoomAllocator) ValidateCreateRoomArgsForCall(i int) (context.Context, livekit.RoomName) { + fake.validateCreateRoomMutex.RLock() + defer fake.validateCreateRoomMutex.RUnlock() + argsForCall := fake.validateCreateRoomArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeRoomAllocator) ValidateCreateRoomReturns(result1 error) { + fake.validateCreateRoomMutex.Lock() + defer fake.validateCreateRoomMutex.Unlock() + fake.ValidateCreateRoomStub = nil + fake.validateCreateRoomReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeRoomAllocator) ValidateCreateRoomReturnsOnCall(i int, result1 error) { + fake.validateCreateRoomMutex.Lock() + defer fake.validateCreateRoomMutex.Unlock() + fake.ValidateCreateRoomStub = nil + if fake.validateCreateRoomReturnsOnCall == nil { + fake.validateCreateRoomReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.validateCreateRoomReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeRoomAllocator) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeRoomAllocator) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ service.RoomAllocator = new(FakeRoomAllocator) diff --git a/livekit/pkg/service/servicefakes/fake_service_store.go b/livekit/pkg/service/servicefakes/fake_service_store.go new file mode 100644 index 0000000..e4808aa --- /dev/null +++ b/livekit/pkg/service/servicefakes/fake_service_store.go @@ -0,0 +1,448 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package servicefakes + +import ( + "context" + "sync" + + "github.com/livekit/livekit-server/pkg/service" + "github.com/livekit/protocol/livekit" +) + +type FakeServiceStore struct { + ListParticipantsStub func(context.Context, livekit.RoomName) ([]*livekit.ParticipantInfo, error) + listParticipantsMutex sync.RWMutex + listParticipantsArgsForCall []struct { + arg1 context.Context + arg2 livekit.RoomName + } + listParticipantsReturns struct { + result1 []*livekit.ParticipantInfo + result2 error + } + listParticipantsReturnsOnCall map[int]struct { + result1 []*livekit.ParticipantInfo + result2 error + } + ListRoomsStub func(context.Context, []livekit.RoomName) ([]*livekit.Room, error) + listRoomsMutex sync.RWMutex + listRoomsArgsForCall []struct { + arg1 context.Context + arg2 []livekit.RoomName + } + listRoomsReturns struct { + result1 []*livekit.Room + result2 error + } + listRoomsReturnsOnCall map[int]struct { + result1 []*livekit.Room + result2 error + } + LoadParticipantStub func(context.Context, livekit.RoomName, livekit.ParticipantIdentity) (*livekit.ParticipantInfo, error) + loadParticipantMutex sync.RWMutex + loadParticipantArgsForCall []struct { + arg1 context.Context + arg2 livekit.RoomName + arg3 livekit.ParticipantIdentity + } + loadParticipantReturns struct { + result1 *livekit.ParticipantInfo + result2 error + } + loadParticipantReturnsOnCall map[int]struct { + result1 *livekit.ParticipantInfo + result2 error + } + LoadRoomStub func(context.Context, livekit.RoomName, bool) (*livekit.Room, *livekit.RoomInternal, error) + loadRoomMutex sync.RWMutex + loadRoomArgsForCall []struct { + arg1 context.Context + arg2 livekit.RoomName + arg3 bool + } + loadRoomReturns struct { + result1 *livekit.Room + result2 *livekit.RoomInternal + result3 error + } + loadRoomReturnsOnCall map[int]struct { + result1 *livekit.Room + result2 *livekit.RoomInternal + result3 error + } + RoomExistsStub func(context.Context, livekit.RoomName) (bool, error) + roomExistsMutex sync.RWMutex + roomExistsArgsForCall []struct { + arg1 context.Context + arg2 livekit.RoomName + } + roomExistsReturns struct { + result1 bool + result2 error + } + roomExistsReturnsOnCall map[int]struct { + result1 bool + result2 error + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeServiceStore) ListParticipants(arg1 context.Context, arg2 livekit.RoomName) ([]*livekit.ParticipantInfo, error) { + fake.listParticipantsMutex.Lock() + ret, specificReturn := fake.listParticipantsReturnsOnCall[len(fake.listParticipantsArgsForCall)] + fake.listParticipantsArgsForCall = append(fake.listParticipantsArgsForCall, struct { + arg1 context.Context + arg2 livekit.RoomName + }{arg1, arg2}) + stub := fake.ListParticipantsStub + fakeReturns := fake.listParticipantsReturns + fake.recordInvocation("ListParticipants", []interface{}{arg1, arg2}) + fake.listParticipantsMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeServiceStore) ListParticipantsCallCount() int { + fake.listParticipantsMutex.RLock() + defer fake.listParticipantsMutex.RUnlock() + return len(fake.listParticipantsArgsForCall) +} + +func (fake *FakeServiceStore) ListParticipantsCalls(stub func(context.Context, livekit.RoomName) ([]*livekit.ParticipantInfo, error)) { + fake.listParticipantsMutex.Lock() + defer fake.listParticipantsMutex.Unlock() + fake.ListParticipantsStub = stub +} + +func (fake *FakeServiceStore) ListParticipantsArgsForCall(i int) (context.Context, livekit.RoomName) { + fake.listParticipantsMutex.RLock() + defer fake.listParticipantsMutex.RUnlock() + argsForCall := fake.listParticipantsArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeServiceStore) ListParticipantsReturns(result1 []*livekit.ParticipantInfo, result2 error) { + fake.listParticipantsMutex.Lock() + defer fake.listParticipantsMutex.Unlock() + fake.ListParticipantsStub = nil + fake.listParticipantsReturns = struct { + result1 []*livekit.ParticipantInfo + result2 error + }{result1, result2} +} + +func (fake *FakeServiceStore) ListParticipantsReturnsOnCall(i int, result1 []*livekit.ParticipantInfo, result2 error) { + fake.listParticipantsMutex.Lock() + defer fake.listParticipantsMutex.Unlock() + fake.ListParticipantsStub = nil + if fake.listParticipantsReturnsOnCall == nil { + fake.listParticipantsReturnsOnCall = make(map[int]struct { + result1 []*livekit.ParticipantInfo + result2 error + }) + } + fake.listParticipantsReturnsOnCall[i] = struct { + result1 []*livekit.ParticipantInfo + result2 error + }{result1, result2} +} + +func (fake *FakeServiceStore) ListRooms(arg1 context.Context, arg2 []livekit.RoomName) ([]*livekit.Room, error) { + var arg2Copy []livekit.RoomName + if arg2 != nil { + arg2Copy = make([]livekit.RoomName, len(arg2)) + copy(arg2Copy, arg2) + } + fake.listRoomsMutex.Lock() + ret, specificReturn := fake.listRoomsReturnsOnCall[len(fake.listRoomsArgsForCall)] + fake.listRoomsArgsForCall = append(fake.listRoomsArgsForCall, struct { + arg1 context.Context + arg2 []livekit.RoomName + }{arg1, arg2Copy}) + stub := fake.ListRoomsStub + fakeReturns := fake.listRoomsReturns + fake.recordInvocation("ListRooms", []interface{}{arg1, arg2Copy}) + fake.listRoomsMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeServiceStore) ListRoomsCallCount() int { + fake.listRoomsMutex.RLock() + defer fake.listRoomsMutex.RUnlock() + return len(fake.listRoomsArgsForCall) +} + +func (fake *FakeServiceStore) ListRoomsCalls(stub func(context.Context, []livekit.RoomName) ([]*livekit.Room, error)) { + fake.listRoomsMutex.Lock() + defer fake.listRoomsMutex.Unlock() + fake.ListRoomsStub = stub +} + +func (fake *FakeServiceStore) ListRoomsArgsForCall(i int) (context.Context, []livekit.RoomName) { + fake.listRoomsMutex.RLock() + defer fake.listRoomsMutex.RUnlock() + argsForCall := fake.listRoomsArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeServiceStore) ListRoomsReturns(result1 []*livekit.Room, result2 error) { + fake.listRoomsMutex.Lock() + defer fake.listRoomsMutex.Unlock() + fake.ListRoomsStub = nil + fake.listRoomsReturns = struct { + result1 []*livekit.Room + result2 error + }{result1, result2} +} + +func (fake *FakeServiceStore) ListRoomsReturnsOnCall(i int, result1 []*livekit.Room, result2 error) { + fake.listRoomsMutex.Lock() + defer fake.listRoomsMutex.Unlock() + fake.ListRoomsStub = nil + if fake.listRoomsReturnsOnCall == nil { + fake.listRoomsReturnsOnCall = make(map[int]struct { + result1 []*livekit.Room + result2 error + }) + } + fake.listRoomsReturnsOnCall[i] = struct { + result1 []*livekit.Room + result2 error + }{result1, result2} +} + +func (fake *FakeServiceStore) LoadParticipant(arg1 context.Context, arg2 livekit.RoomName, arg3 livekit.ParticipantIdentity) (*livekit.ParticipantInfo, error) { + fake.loadParticipantMutex.Lock() + ret, specificReturn := fake.loadParticipantReturnsOnCall[len(fake.loadParticipantArgsForCall)] + fake.loadParticipantArgsForCall = append(fake.loadParticipantArgsForCall, struct { + arg1 context.Context + arg2 livekit.RoomName + arg3 livekit.ParticipantIdentity + }{arg1, arg2, arg3}) + stub := fake.LoadParticipantStub + fakeReturns := fake.loadParticipantReturns + fake.recordInvocation("LoadParticipant", []interface{}{arg1, arg2, arg3}) + fake.loadParticipantMutex.Unlock() + if stub != nil { + return stub(arg1, arg2, arg3) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeServiceStore) LoadParticipantCallCount() int { + fake.loadParticipantMutex.RLock() + defer fake.loadParticipantMutex.RUnlock() + return len(fake.loadParticipantArgsForCall) +} + +func (fake *FakeServiceStore) LoadParticipantCalls(stub func(context.Context, livekit.RoomName, livekit.ParticipantIdentity) (*livekit.ParticipantInfo, error)) { + fake.loadParticipantMutex.Lock() + defer fake.loadParticipantMutex.Unlock() + fake.LoadParticipantStub = stub +} + +func (fake *FakeServiceStore) LoadParticipantArgsForCall(i int) (context.Context, livekit.RoomName, livekit.ParticipantIdentity) { + fake.loadParticipantMutex.RLock() + defer fake.loadParticipantMutex.RUnlock() + argsForCall := fake.loadParticipantArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeServiceStore) LoadParticipantReturns(result1 *livekit.ParticipantInfo, result2 error) { + fake.loadParticipantMutex.Lock() + defer fake.loadParticipantMutex.Unlock() + fake.LoadParticipantStub = nil + fake.loadParticipantReturns = struct { + result1 *livekit.ParticipantInfo + result2 error + }{result1, result2} +} + +func (fake *FakeServiceStore) LoadParticipantReturnsOnCall(i int, result1 *livekit.ParticipantInfo, result2 error) { + fake.loadParticipantMutex.Lock() + defer fake.loadParticipantMutex.Unlock() + fake.LoadParticipantStub = nil + if fake.loadParticipantReturnsOnCall == nil { + fake.loadParticipantReturnsOnCall = make(map[int]struct { + result1 *livekit.ParticipantInfo + result2 error + }) + } + fake.loadParticipantReturnsOnCall[i] = struct { + result1 *livekit.ParticipantInfo + result2 error + }{result1, result2} +} + +func (fake *FakeServiceStore) LoadRoom(arg1 context.Context, arg2 livekit.RoomName, arg3 bool) (*livekit.Room, *livekit.RoomInternal, error) { + fake.loadRoomMutex.Lock() + ret, specificReturn := fake.loadRoomReturnsOnCall[len(fake.loadRoomArgsForCall)] + fake.loadRoomArgsForCall = append(fake.loadRoomArgsForCall, struct { + arg1 context.Context + arg2 livekit.RoomName + arg3 bool + }{arg1, arg2, arg3}) + stub := fake.LoadRoomStub + fakeReturns := fake.loadRoomReturns + fake.recordInvocation("LoadRoom", []interface{}{arg1, arg2, arg3}) + fake.loadRoomMutex.Unlock() + if stub != nil { + return stub(arg1, arg2, arg3) + } + if specificReturn { + return ret.result1, ret.result2, ret.result3 + } + return fakeReturns.result1, fakeReturns.result2, fakeReturns.result3 +} + +func (fake *FakeServiceStore) LoadRoomCallCount() int { + fake.loadRoomMutex.RLock() + defer fake.loadRoomMutex.RUnlock() + return len(fake.loadRoomArgsForCall) +} + +func (fake *FakeServiceStore) LoadRoomCalls(stub func(context.Context, livekit.RoomName, bool) (*livekit.Room, *livekit.RoomInternal, error)) { + fake.loadRoomMutex.Lock() + defer fake.loadRoomMutex.Unlock() + fake.LoadRoomStub = stub +} + +func (fake *FakeServiceStore) LoadRoomArgsForCall(i int) (context.Context, livekit.RoomName, bool) { + fake.loadRoomMutex.RLock() + defer fake.loadRoomMutex.RUnlock() + argsForCall := fake.loadRoomArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeServiceStore) LoadRoomReturns(result1 *livekit.Room, result2 *livekit.RoomInternal, result3 error) { + fake.loadRoomMutex.Lock() + defer fake.loadRoomMutex.Unlock() + fake.LoadRoomStub = nil + fake.loadRoomReturns = struct { + result1 *livekit.Room + result2 *livekit.RoomInternal + result3 error + }{result1, result2, result3} +} + +func (fake *FakeServiceStore) LoadRoomReturnsOnCall(i int, result1 *livekit.Room, result2 *livekit.RoomInternal, result3 error) { + fake.loadRoomMutex.Lock() + defer fake.loadRoomMutex.Unlock() + fake.LoadRoomStub = nil + if fake.loadRoomReturnsOnCall == nil { + fake.loadRoomReturnsOnCall = make(map[int]struct { + result1 *livekit.Room + result2 *livekit.RoomInternal + result3 error + }) + } + fake.loadRoomReturnsOnCall[i] = struct { + result1 *livekit.Room + result2 *livekit.RoomInternal + result3 error + }{result1, result2, result3} +} + +func (fake *FakeServiceStore) RoomExists(arg1 context.Context, arg2 livekit.RoomName) (bool, error) { + fake.roomExistsMutex.Lock() + ret, specificReturn := fake.roomExistsReturnsOnCall[len(fake.roomExistsArgsForCall)] + fake.roomExistsArgsForCall = append(fake.roomExistsArgsForCall, struct { + arg1 context.Context + arg2 livekit.RoomName + }{arg1, arg2}) + stub := fake.RoomExistsStub + fakeReturns := fake.roomExistsReturns + fake.recordInvocation("RoomExists", []interface{}{arg1, arg2}) + fake.roomExistsMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeServiceStore) RoomExistsCallCount() int { + fake.roomExistsMutex.RLock() + defer fake.roomExistsMutex.RUnlock() + return len(fake.roomExistsArgsForCall) +} + +func (fake *FakeServiceStore) RoomExistsCalls(stub func(context.Context, livekit.RoomName) (bool, error)) { + fake.roomExistsMutex.Lock() + defer fake.roomExistsMutex.Unlock() + fake.RoomExistsStub = stub +} + +func (fake *FakeServiceStore) RoomExistsArgsForCall(i int) (context.Context, livekit.RoomName) { + fake.roomExistsMutex.RLock() + defer fake.roomExistsMutex.RUnlock() + argsForCall := fake.roomExistsArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeServiceStore) RoomExistsReturns(result1 bool, result2 error) { + fake.roomExistsMutex.Lock() + defer fake.roomExistsMutex.Unlock() + fake.RoomExistsStub = nil + fake.roomExistsReturns = struct { + result1 bool + result2 error + }{result1, result2} +} + +func (fake *FakeServiceStore) RoomExistsReturnsOnCall(i int, result1 bool, result2 error) { + fake.roomExistsMutex.Lock() + defer fake.roomExistsMutex.Unlock() + fake.RoomExistsStub = nil + if fake.roomExistsReturnsOnCall == nil { + fake.roomExistsReturnsOnCall = make(map[int]struct { + result1 bool + result2 error + }) + } + fake.roomExistsReturnsOnCall[i] = struct { + result1 bool + result2 error + }{result1, result2} +} + +func (fake *FakeServiceStore) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeServiceStore) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ service.ServiceStore = new(FakeServiceStore) diff --git a/livekit/pkg/service/servicefakes/fake_session_handler.go b/livekit/pkg/service/servicefakes/fake_session_handler.go new file mode 100644 index 0000000..e112e67 --- /dev/null +++ b/livekit/pkg/service/servicefakes/fake_session_handler.go @@ -0,0 +1,193 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package servicefakes + +import ( + "context" + "sync" + + "github.com/livekit/livekit-server/pkg/routing" + "github.com/livekit/livekit-server/pkg/service" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" +) + +type FakeSessionHandler struct { + HandleSessionStub func(context.Context, routing.ParticipantInit, livekit.ConnectionID, routing.MessageSource, routing.MessageSink) error + handleSessionMutex sync.RWMutex + handleSessionArgsForCall []struct { + arg1 context.Context + arg2 routing.ParticipantInit + arg3 livekit.ConnectionID + arg4 routing.MessageSource + arg5 routing.MessageSink + } + handleSessionReturns struct { + result1 error + } + handleSessionReturnsOnCall map[int]struct { + result1 error + } + LoggerStub func(context.Context) logger.Logger + loggerMutex sync.RWMutex + loggerArgsForCall []struct { + arg1 context.Context + } + loggerReturns struct { + result1 logger.Logger + } + loggerReturnsOnCall map[int]struct { + result1 logger.Logger + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeSessionHandler) HandleSession(arg1 context.Context, arg2 routing.ParticipantInit, arg3 livekit.ConnectionID, arg4 routing.MessageSource, arg5 routing.MessageSink) error { + fake.handleSessionMutex.Lock() + ret, specificReturn := fake.handleSessionReturnsOnCall[len(fake.handleSessionArgsForCall)] + fake.handleSessionArgsForCall = append(fake.handleSessionArgsForCall, struct { + arg1 context.Context + arg2 routing.ParticipantInit + arg3 livekit.ConnectionID + arg4 routing.MessageSource + arg5 routing.MessageSink + }{arg1, arg2, arg3, arg4, arg5}) + stub := fake.HandleSessionStub + fakeReturns := fake.handleSessionReturns + fake.recordInvocation("HandleSession", []interface{}{arg1, arg2, arg3, arg4, arg5}) + fake.handleSessionMutex.Unlock() + if stub != nil { + return stub(arg1, arg2, arg3, arg4, arg5) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeSessionHandler) HandleSessionCallCount() int { + fake.handleSessionMutex.RLock() + defer fake.handleSessionMutex.RUnlock() + return len(fake.handleSessionArgsForCall) +} + +func (fake *FakeSessionHandler) HandleSessionCalls(stub func(context.Context, routing.ParticipantInit, livekit.ConnectionID, routing.MessageSource, routing.MessageSink) error) { + fake.handleSessionMutex.Lock() + defer fake.handleSessionMutex.Unlock() + fake.HandleSessionStub = stub +} + +func (fake *FakeSessionHandler) HandleSessionArgsForCall(i int) (context.Context, routing.ParticipantInit, livekit.ConnectionID, routing.MessageSource, routing.MessageSink) { + fake.handleSessionMutex.RLock() + defer fake.handleSessionMutex.RUnlock() + argsForCall := fake.handleSessionArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3, argsForCall.arg4, argsForCall.arg5 +} + +func (fake *FakeSessionHandler) HandleSessionReturns(result1 error) { + fake.handleSessionMutex.Lock() + defer fake.handleSessionMutex.Unlock() + fake.HandleSessionStub = nil + fake.handleSessionReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeSessionHandler) HandleSessionReturnsOnCall(i int, result1 error) { + fake.handleSessionMutex.Lock() + defer fake.handleSessionMutex.Unlock() + fake.HandleSessionStub = nil + if fake.handleSessionReturnsOnCall == nil { + fake.handleSessionReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.handleSessionReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeSessionHandler) Logger(arg1 context.Context) logger.Logger { + fake.loggerMutex.Lock() + ret, specificReturn := fake.loggerReturnsOnCall[len(fake.loggerArgsForCall)] + fake.loggerArgsForCall = append(fake.loggerArgsForCall, struct { + arg1 context.Context + }{arg1}) + stub := fake.LoggerStub + fakeReturns := fake.loggerReturns + fake.recordInvocation("Logger", []interface{}{arg1}) + fake.loggerMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeSessionHandler) LoggerCallCount() int { + fake.loggerMutex.RLock() + defer fake.loggerMutex.RUnlock() + return len(fake.loggerArgsForCall) +} + +func (fake *FakeSessionHandler) LoggerCalls(stub func(context.Context) logger.Logger) { + fake.loggerMutex.Lock() + defer fake.loggerMutex.Unlock() + fake.LoggerStub = stub +} + +func (fake *FakeSessionHandler) LoggerArgsForCall(i int) context.Context { + fake.loggerMutex.RLock() + defer fake.loggerMutex.RUnlock() + argsForCall := fake.loggerArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeSessionHandler) LoggerReturns(result1 logger.Logger) { + fake.loggerMutex.Lock() + defer fake.loggerMutex.Unlock() + fake.LoggerStub = nil + fake.loggerReturns = struct { + result1 logger.Logger + }{result1} +} + +func (fake *FakeSessionHandler) LoggerReturnsOnCall(i int, result1 logger.Logger) { + fake.loggerMutex.Lock() + defer fake.loggerMutex.Unlock() + fake.LoggerStub = nil + if fake.loggerReturnsOnCall == nil { + fake.loggerReturnsOnCall = make(map[int]struct { + result1 logger.Logger + }) + } + fake.loggerReturnsOnCall[i] = struct { + result1 logger.Logger + }{result1} +} + +func (fake *FakeSessionHandler) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeSessionHandler) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ service.SessionHandler = new(FakeSessionHandler) diff --git a/livekit/pkg/service/servicefakes/fake_sipstore.go b/livekit/pkg/service/servicefakes/fake_sipstore.go new file mode 100644 index 0000000..ba88c87 --- /dev/null +++ b/livekit/pkg/service/servicefakes/fake_sipstore.go @@ -0,0 +1,1115 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package servicefakes + +import ( + "context" + "sync" + + "github.com/livekit/livekit-server/pkg/service" + "github.com/livekit/protocol/livekit" +) + +type FakeSIPStore struct { + DeleteSIPDispatchRuleStub func(context.Context, string) error + deleteSIPDispatchRuleMutex sync.RWMutex + deleteSIPDispatchRuleArgsForCall []struct { + arg1 context.Context + arg2 string + } + deleteSIPDispatchRuleReturns struct { + result1 error + } + deleteSIPDispatchRuleReturnsOnCall map[int]struct { + result1 error + } + DeleteSIPTrunkStub func(context.Context, string) error + deleteSIPTrunkMutex sync.RWMutex + deleteSIPTrunkArgsForCall []struct { + arg1 context.Context + arg2 string + } + deleteSIPTrunkReturns struct { + result1 error + } + deleteSIPTrunkReturnsOnCall map[int]struct { + result1 error + } + ListSIPDispatchRuleStub func(context.Context, *livekit.ListSIPDispatchRuleRequest) (*livekit.ListSIPDispatchRuleResponse, error) + listSIPDispatchRuleMutex sync.RWMutex + listSIPDispatchRuleArgsForCall []struct { + arg1 context.Context + arg2 *livekit.ListSIPDispatchRuleRequest + } + listSIPDispatchRuleReturns struct { + result1 *livekit.ListSIPDispatchRuleResponse + result2 error + } + listSIPDispatchRuleReturnsOnCall map[int]struct { + result1 *livekit.ListSIPDispatchRuleResponse + result2 error + } + ListSIPInboundTrunkStub func(context.Context, *livekit.ListSIPInboundTrunkRequest) (*livekit.ListSIPInboundTrunkResponse, error) + listSIPInboundTrunkMutex sync.RWMutex + listSIPInboundTrunkArgsForCall []struct { + arg1 context.Context + arg2 *livekit.ListSIPInboundTrunkRequest + } + listSIPInboundTrunkReturns struct { + result1 *livekit.ListSIPInboundTrunkResponse + result2 error + } + listSIPInboundTrunkReturnsOnCall map[int]struct { + result1 *livekit.ListSIPInboundTrunkResponse + result2 error + } + ListSIPOutboundTrunkStub func(context.Context, *livekit.ListSIPOutboundTrunkRequest) (*livekit.ListSIPOutboundTrunkResponse, error) + listSIPOutboundTrunkMutex sync.RWMutex + listSIPOutboundTrunkArgsForCall []struct { + arg1 context.Context + arg2 *livekit.ListSIPOutboundTrunkRequest + } + listSIPOutboundTrunkReturns struct { + result1 *livekit.ListSIPOutboundTrunkResponse + result2 error + } + listSIPOutboundTrunkReturnsOnCall map[int]struct { + result1 *livekit.ListSIPOutboundTrunkResponse + result2 error + } + ListSIPTrunkStub func(context.Context, *livekit.ListSIPTrunkRequest) (*livekit.ListSIPTrunkResponse, error) + listSIPTrunkMutex sync.RWMutex + listSIPTrunkArgsForCall []struct { + arg1 context.Context + arg2 *livekit.ListSIPTrunkRequest + } + listSIPTrunkReturns struct { + result1 *livekit.ListSIPTrunkResponse + result2 error + } + listSIPTrunkReturnsOnCall map[int]struct { + result1 *livekit.ListSIPTrunkResponse + result2 error + } + LoadSIPDispatchRuleStub func(context.Context, string) (*livekit.SIPDispatchRuleInfo, error) + loadSIPDispatchRuleMutex sync.RWMutex + loadSIPDispatchRuleArgsForCall []struct { + arg1 context.Context + arg2 string + } + loadSIPDispatchRuleReturns struct { + result1 *livekit.SIPDispatchRuleInfo + result2 error + } + loadSIPDispatchRuleReturnsOnCall map[int]struct { + result1 *livekit.SIPDispatchRuleInfo + result2 error + } + LoadSIPInboundTrunkStub func(context.Context, string) (*livekit.SIPInboundTrunkInfo, error) + loadSIPInboundTrunkMutex sync.RWMutex + loadSIPInboundTrunkArgsForCall []struct { + arg1 context.Context + arg2 string + } + loadSIPInboundTrunkReturns struct { + result1 *livekit.SIPInboundTrunkInfo + result2 error + } + loadSIPInboundTrunkReturnsOnCall map[int]struct { + result1 *livekit.SIPInboundTrunkInfo + result2 error + } + LoadSIPOutboundTrunkStub func(context.Context, string) (*livekit.SIPOutboundTrunkInfo, error) + loadSIPOutboundTrunkMutex sync.RWMutex + loadSIPOutboundTrunkArgsForCall []struct { + arg1 context.Context + arg2 string + } + loadSIPOutboundTrunkReturns struct { + result1 *livekit.SIPOutboundTrunkInfo + result2 error + } + loadSIPOutboundTrunkReturnsOnCall map[int]struct { + result1 *livekit.SIPOutboundTrunkInfo + result2 error + } + LoadSIPTrunkStub func(context.Context, string) (*livekit.SIPTrunkInfo, error) + loadSIPTrunkMutex sync.RWMutex + loadSIPTrunkArgsForCall []struct { + arg1 context.Context + arg2 string + } + loadSIPTrunkReturns struct { + result1 *livekit.SIPTrunkInfo + result2 error + } + loadSIPTrunkReturnsOnCall map[int]struct { + result1 *livekit.SIPTrunkInfo + result2 error + } + StoreSIPDispatchRuleStub func(context.Context, *livekit.SIPDispatchRuleInfo) error + storeSIPDispatchRuleMutex sync.RWMutex + storeSIPDispatchRuleArgsForCall []struct { + arg1 context.Context + arg2 *livekit.SIPDispatchRuleInfo + } + storeSIPDispatchRuleReturns struct { + result1 error + } + storeSIPDispatchRuleReturnsOnCall map[int]struct { + result1 error + } + StoreSIPInboundTrunkStub func(context.Context, *livekit.SIPInboundTrunkInfo) error + storeSIPInboundTrunkMutex sync.RWMutex + storeSIPInboundTrunkArgsForCall []struct { + arg1 context.Context + arg2 *livekit.SIPInboundTrunkInfo + } + storeSIPInboundTrunkReturns struct { + result1 error + } + storeSIPInboundTrunkReturnsOnCall map[int]struct { + result1 error + } + StoreSIPOutboundTrunkStub func(context.Context, *livekit.SIPOutboundTrunkInfo) error + storeSIPOutboundTrunkMutex sync.RWMutex + storeSIPOutboundTrunkArgsForCall []struct { + arg1 context.Context + arg2 *livekit.SIPOutboundTrunkInfo + } + storeSIPOutboundTrunkReturns struct { + result1 error + } + storeSIPOutboundTrunkReturnsOnCall map[int]struct { + result1 error + } + StoreSIPTrunkStub func(context.Context, *livekit.SIPTrunkInfo) error + storeSIPTrunkMutex sync.RWMutex + storeSIPTrunkArgsForCall []struct { + arg1 context.Context + arg2 *livekit.SIPTrunkInfo + } + storeSIPTrunkReturns struct { + result1 error + } + storeSIPTrunkReturnsOnCall map[int]struct { + result1 error + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeSIPStore) DeleteSIPDispatchRule(arg1 context.Context, arg2 string) error { + fake.deleteSIPDispatchRuleMutex.Lock() + ret, specificReturn := fake.deleteSIPDispatchRuleReturnsOnCall[len(fake.deleteSIPDispatchRuleArgsForCall)] + fake.deleteSIPDispatchRuleArgsForCall = append(fake.deleteSIPDispatchRuleArgsForCall, struct { + arg1 context.Context + arg2 string + }{arg1, arg2}) + stub := fake.DeleteSIPDispatchRuleStub + fakeReturns := fake.deleteSIPDispatchRuleReturns + fake.recordInvocation("DeleteSIPDispatchRule", []interface{}{arg1, arg2}) + fake.deleteSIPDispatchRuleMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeSIPStore) DeleteSIPDispatchRuleCallCount() int { + fake.deleteSIPDispatchRuleMutex.RLock() + defer fake.deleteSIPDispatchRuleMutex.RUnlock() + return len(fake.deleteSIPDispatchRuleArgsForCall) +} + +func (fake *FakeSIPStore) DeleteSIPDispatchRuleCalls(stub func(context.Context, string) error) { + fake.deleteSIPDispatchRuleMutex.Lock() + defer fake.deleteSIPDispatchRuleMutex.Unlock() + fake.DeleteSIPDispatchRuleStub = stub +} + +func (fake *FakeSIPStore) DeleteSIPDispatchRuleArgsForCall(i int) (context.Context, string) { + fake.deleteSIPDispatchRuleMutex.RLock() + defer fake.deleteSIPDispatchRuleMutex.RUnlock() + argsForCall := fake.deleteSIPDispatchRuleArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeSIPStore) DeleteSIPDispatchRuleReturns(result1 error) { + fake.deleteSIPDispatchRuleMutex.Lock() + defer fake.deleteSIPDispatchRuleMutex.Unlock() + fake.DeleteSIPDispatchRuleStub = nil + fake.deleteSIPDispatchRuleReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeSIPStore) DeleteSIPDispatchRuleReturnsOnCall(i int, result1 error) { + fake.deleteSIPDispatchRuleMutex.Lock() + defer fake.deleteSIPDispatchRuleMutex.Unlock() + fake.DeleteSIPDispatchRuleStub = nil + if fake.deleteSIPDispatchRuleReturnsOnCall == nil { + fake.deleteSIPDispatchRuleReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.deleteSIPDispatchRuleReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeSIPStore) DeleteSIPTrunk(arg1 context.Context, arg2 string) error { + fake.deleteSIPTrunkMutex.Lock() + ret, specificReturn := fake.deleteSIPTrunkReturnsOnCall[len(fake.deleteSIPTrunkArgsForCall)] + fake.deleteSIPTrunkArgsForCall = append(fake.deleteSIPTrunkArgsForCall, struct { + arg1 context.Context + arg2 string + }{arg1, arg2}) + stub := fake.DeleteSIPTrunkStub + fakeReturns := fake.deleteSIPTrunkReturns + fake.recordInvocation("DeleteSIPTrunk", []interface{}{arg1, arg2}) + fake.deleteSIPTrunkMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeSIPStore) DeleteSIPTrunkCallCount() int { + fake.deleteSIPTrunkMutex.RLock() + defer fake.deleteSIPTrunkMutex.RUnlock() + return len(fake.deleteSIPTrunkArgsForCall) +} + +func (fake *FakeSIPStore) DeleteSIPTrunkCalls(stub func(context.Context, string) error) { + fake.deleteSIPTrunkMutex.Lock() + defer fake.deleteSIPTrunkMutex.Unlock() + fake.DeleteSIPTrunkStub = stub +} + +func (fake *FakeSIPStore) DeleteSIPTrunkArgsForCall(i int) (context.Context, string) { + fake.deleteSIPTrunkMutex.RLock() + defer fake.deleteSIPTrunkMutex.RUnlock() + argsForCall := fake.deleteSIPTrunkArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeSIPStore) DeleteSIPTrunkReturns(result1 error) { + fake.deleteSIPTrunkMutex.Lock() + defer fake.deleteSIPTrunkMutex.Unlock() + fake.DeleteSIPTrunkStub = nil + fake.deleteSIPTrunkReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeSIPStore) DeleteSIPTrunkReturnsOnCall(i int, result1 error) { + fake.deleteSIPTrunkMutex.Lock() + defer fake.deleteSIPTrunkMutex.Unlock() + fake.DeleteSIPTrunkStub = nil + if fake.deleteSIPTrunkReturnsOnCall == nil { + fake.deleteSIPTrunkReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.deleteSIPTrunkReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeSIPStore) ListSIPDispatchRule(arg1 context.Context, arg2 *livekit.ListSIPDispatchRuleRequest) (*livekit.ListSIPDispatchRuleResponse, error) { + fake.listSIPDispatchRuleMutex.Lock() + ret, specificReturn := fake.listSIPDispatchRuleReturnsOnCall[len(fake.listSIPDispatchRuleArgsForCall)] + fake.listSIPDispatchRuleArgsForCall = append(fake.listSIPDispatchRuleArgsForCall, struct { + arg1 context.Context + arg2 *livekit.ListSIPDispatchRuleRequest + }{arg1, arg2}) + stub := fake.ListSIPDispatchRuleStub + fakeReturns := fake.listSIPDispatchRuleReturns + fake.recordInvocation("ListSIPDispatchRule", []interface{}{arg1, arg2}) + fake.listSIPDispatchRuleMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeSIPStore) ListSIPDispatchRuleCallCount() int { + fake.listSIPDispatchRuleMutex.RLock() + defer fake.listSIPDispatchRuleMutex.RUnlock() + return len(fake.listSIPDispatchRuleArgsForCall) +} + +func (fake *FakeSIPStore) ListSIPDispatchRuleCalls(stub func(context.Context, *livekit.ListSIPDispatchRuleRequest) (*livekit.ListSIPDispatchRuleResponse, error)) { + fake.listSIPDispatchRuleMutex.Lock() + defer fake.listSIPDispatchRuleMutex.Unlock() + fake.ListSIPDispatchRuleStub = stub +} + +func (fake *FakeSIPStore) ListSIPDispatchRuleArgsForCall(i int) (context.Context, *livekit.ListSIPDispatchRuleRequest) { + fake.listSIPDispatchRuleMutex.RLock() + defer fake.listSIPDispatchRuleMutex.RUnlock() + argsForCall := fake.listSIPDispatchRuleArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeSIPStore) ListSIPDispatchRuleReturns(result1 *livekit.ListSIPDispatchRuleResponse, result2 error) { + fake.listSIPDispatchRuleMutex.Lock() + defer fake.listSIPDispatchRuleMutex.Unlock() + fake.ListSIPDispatchRuleStub = nil + fake.listSIPDispatchRuleReturns = struct { + result1 *livekit.ListSIPDispatchRuleResponse + result2 error + }{result1, result2} +} + +func (fake *FakeSIPStore) ListSIPDispatchRuleReturnsOnCall(i int, result1 *livekit.ListSIPDispatchRuleResponse, result2 error) { + fake.listSIPDispatchRuleMutex.Lock() + defer fake.listSIPDispatchRuleMutex.Unlock() + fake.ListSIPDispatchRuleStub = nil + if fake.listSIPDispatchRuleReturnsOnCall == nil { + fake.listSIPDispatchRuleReturnsOnCall = make(map[int]struct { + result1 *livekit.ListSIPDispatchRuleResponse + result2 error + }) + } + fake.listSIPDispatchRuleReturnsOnCall[i] = struct { + result1 *livekit.ListSIPDispatchRuleResponse + result2 error + }{result1, result2} +} + +func (fake *FakeSIPStore) ListSIPInboundTrunk(arg1 context.Context, arg2 *livekit.ListSIPInboundTrunkRequest) (*livekit.ListSIPInboundTrunkResponse, error) { + fake.listSIPInboundTrunkMutex.Lock() + ret, specificReturn := fake.listSIPInboundTrunkReturnsOnCall[len(fake.listSIPInboundTrunkArgsForCall)] + fake.listSIPInboundTrunkArgsForCall = append(fake.listSIPInboundTrunkArgsForCall, struct { + arg1 context.Context + arg2 *livekit.ListSIPInboundTrunkRequest + }{arg1, arg2}) + stub := fake.ListSIPInboundTrunkStub + fakeReturns := fake.listSIPInboundTrunkReturns + fake.recordInvocation("ListSIPInboundTrunk", []interface{}{arg1, arg2}) + fake.listSIPInboundTrunkMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeSIPStore) ListSIPInboundTrunkCallCount() int { + fake.listSIPInboundTrunkMutex.RLock() + defer fake.listSIPInboundTrunkMutex.RUnlock() + return len(fake.listSIPInboundTrunkArgsForCall) +} + +func (fake *FakeSIPStore) ListSIPInboundTrunkCalls(stub func(context.Context, *livekit.ListSIPInboundTrunkRequest) (*livekit.ListSIPInboundTrunkResponse, error)) { + fake.listSIPInboundTrunkMutex.Lock() + defer fake.listSIPInboundTrunkMutex.Unlock() + fake.ListSIPInboundTrunkStub = stub +} + +func (fake *FakeSIPStore) ListSIPInboundTrunkArgsForCall(i int) (context.Context, *livekit.ListSIPInboundTrunkRequest) { + fake.listSIPInboundTrunkMutex.RLock() + defer fake.listSIPInboundTrunkMutex.RUnlock() + argsForCall := fake.listSIPInboundTrunkArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeSIPStore) ListSIPInboundTrunkReturns(result1 *livekit.ListSIPInboundTrunkResponse, result2 error) { + fake.listSIPInboundTrunkMutex.Lock() + defer fake.listSIPInboundTrunkMutex.Unlock() + fake.ListSIPInboundTrunkStub = nil + fake.listSIPInboundTrunkReturns = struct { + result1 *livekit.ListSIPInboundTrunkResponse + result2 error + }{result1, result2} +} + +func (fake *FakeSIPStore) ListSIPInboundTrunkReturnsOnCall(i int, result1 *livekit.ListSIPInboundTrunkResponse, result2 error) { + fake.listSIPInboundTrunkMutex.Lock() + defer fake.listSIPInboundTrunkMutex.Unlock() + fake.ListSIPInboundTrunkStub = nil + if fake.listSIPInboundTrunkReturnsOnCall == nil { + fake.listSIPInboundTrunkReturnsOnCall = make(map[int]struct { + result1 *livekit.ListSIPInboundTrunkResponse + result2 error + }) + } + fake.listSIPInboundTrunkReturnsOnCall[i] = struct { + result1 *livekit.ListSIPInboundTrunkResponse + result2 error + }{result1, result2} +} + +func (fake *FakeSIPStore) ListSIPOutboundTrunk(arg1 context.Context, arg2 *livekit.ListSIPOutboundTrunkRequest) (*livekit.ListSIPOutboundTrunkResponse, error) { + fake.listSIPOutboundTrunkMutex.Lock() + ret, specificReturn := fake.listSIPOutboundTrunkReturnsOnCall[len(fake.listSIPOutboundTrunkArgsForCall)] + fake.listSIPOutboundTrunkArgsForCall = append(fake.listSIPOutboundTrunkArgsForCall, struct { + arg1 context.Context + arg2 *livekit.ListSIPOutboundTrunkRequest + }{arg1, arg2}) + stub := fake.ListSIPOutboundTrunkStub + fakeReturns := fake.listSIPOutboundTrunkReturns + fake.recordInvocation("ListSIPOutboundTrunk", []interface{}{arg1, arg2}) + fake.listSIPOutboundTrunkMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeSIPStore) ListSIPOutboundTrunkCallCount() int { + fake.listSIPOutboundTrunkMutex.RLock() + defer fake.listSIPOutboundTrunkMutex.RUnlock() + return len(fake.listSIPOutboundTrunkArgsForCall) +} + +func (fake *FakeSIPStore) ListSIPOutboundTrunkCalls(stub func(context.Context, *livekit.ListSIPOutboundTrunkRequest) (*livekit.ListSIPOutboundTrunkResponse, error)) { + fake.listSIPOutboundTrunkMutex.Lock() + defer fake.listSIPOutboundTrunkMutex.Unlock() + fake.ListSIPOutboundTrunkStub = stub +} + +func (fake *FakeSIPStore) ListSIPOutboundTrunkArgsForCall(i int) (context.Context, *livekit.ListSIPOutboundTrunkRequest) { + fake.listSIPOutboundTrunkMutex.RLock() + defer fake.listSIPOutboundTrunkMutex.RUnlock() + argsForCall := fake.listSIPOutboundTrunkArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeSIPStore) ListSIPOutboundTrunkReturns(result1 *livekit.ListSIPOutboundTrunkResponse, result2 error) { + fake.listSIPOutboundTrunkMutex.Lock() + defer fake.listSIPOutboundTrunkMutex.Unlock() + fake.ListSIPOutboundTrunkStub = nil + fake.listSIPOutboundTrunkReturns = struct { + result1 *livekit.ListSIPOutboundTrunkResponse + result2 error + }{result1, result2} +} + +func (fake *FakeSIPStore) ListSIPOutboundTrunkReturnsOnCall(i int, result1 *livekit.ListSIPOutboundTrunkResponse, result2 error) { + fake.listSIPOutboundTrunkMutex.Lock() + defer fake.listSIPOutboundTrunkMutex.Unlock() + fake.ListSIPOutboundTrunkStub = nil + if fake.listSIPOutboundTrunkReturnsOnCall == nil { + fake.listSIPOutboundTrunkReturnsOnCall = make(map[int]struct { + result1 *livekit.ListSIPOutboundTrunkResponse + result2 error + }) + } + fake.listSIPOutboundTrunkReturnsOnCall[i] = struct { + result1 *livekit.ListSIPOutboundTrunkResponse + result2 error + }{result1, result2} +} + +func (fake *FakeSIPStore) ListSIPTrunk(arg1 context.Context, arg2 *livekit.ListSIPTrunkRequest) (*livekit.ListSIPTrunkResponse, error) { + fake.listSIPTrunkMutex.Lock() + ret, specificReturn := fake.listSIPTrunkReturnsOnCall[len(fake.listSIPTrunkArgsForCall)] + fake.listSIPTrunkArgsForCall = append(fake.listSIPTrunkArgsForCall, struct { + arg1 context.Context + arg2 *livekit.ListSIPTrunkRequest + }{arg1, arg2}) + stub := fake.ListSIPTrunkStub + fakeReturns := fake.listSIPTrunkReturns + fake.recordInvocation("ListSIPTrunk", []interface{}{arg1, arg2}) + fake.listSIPTrunkMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeSIPStore) ListSIPTrunkCallCount() int { + fake.listSIPTrunkMutex.RLock() + defer fake.listSIPTrunkMutex.RUnlock() + return len(fake.listSIPTrunkArgsForCall) +} + +func (fake *FakeSIPStore) ListSIPTrunkCalls(stub func(context.Context, *livekit.ListSIPTrunkRequest) (*livekit.ListSIPTrunkResponse, error)) { + fake.listSIPTrunkMutex.Lock() + defer fake.listSIPTrunkMutex.Unlock() + fake.ListSIPTrunkStub = stub +} + +func (fake *FakeSIPStore) ListSIPTrunkArgsForCall(i int) (context.Context, *livekit.ListSIPTrunkRequest) { + fake.listSIPTrunkMutex.RLock() + defer fake.listSIPTrunkMutex.RUnlock() + argsForCall := fake.listSIPTrunkArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeSIPStore) ListSIPTrunkReturns(result1 *livekit.ListSIPTrunkResponse, result2 error) { + fake.listSIPTrunkMutex.Lock() + defer fake.listSIPTrunkMutex.Unlock() + fake.ListSIPTrunkStub = nil + fake.listSIPTrunkReturns = struct { + result1 *livekit.ListSIPTrunkResponse + result2 error + }{result1, result2} +} + +func (fake *FakeSIPStore) ListSIPTrunkReturnsOnCall(i int, result1 *livekit.ListSIPTrunkResponse, result2 error) { + fake.listSIPTrunkMutex.Lock() + defer fake.listSIPTrunkMutex.Unlock() + fake.ListSIPTrunkStub = nil + if fake.listSIPTrunkReturnsOnCall == nil { + fake.listSIPTrunkReturnsOnCall = make(map[int]struct { + result1 *livekit.ListSIPTrunkResponse + result2 error + }) + } + fake.listSIPTrunkReturnsOnCall[i] = struct { + result1 *livekit.ListSIPTrunkResponse + result2 error + }{result1, result2} +} + +func (fake *FakeSIPStore) LoadSIPDispatchRule(arg1 context.Context, arg2 string) (*livekit.SIPDispatchRuleInfo, error) { + fake.loadSIPDispatchRuleMutex.Lock() + ret, specificReturn := fake.loadSIPDispatchRuleReturnsOnCall[len(fake.loadSIPDispatchRuleArgsForCall)] + fake.loadSIPDispatchRuleArgsForCall = append(fake.loadSIPDispatchRuleArgsForCall, struct { + arg1 context.Context + arg2 string + }{arg1, arg2}) + stub := fake.LoadSIPDispatchRuleStub + fakeReturns := fake.loadSIPDispatchRuleReturns + fake.recordInvocation("LoadSIPDispatchRule", []interface{}{arg1, arg2}) + fake.loadSIPDispatchRuleMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeSIPStore) LoadSIPDispatchRuleCallCount() int { + fake.loadSIPDispatchRuleMutex.RLock() + defer fake.loadSIPDispatchRuleMutex.RUnlock() + return len(fake.loadSIPDispatchRuleArgsForCall) +} + +func (fake *FakeSIPStore) LoadSIPDispatchRuleCalls(stub func(context.Context, string) (*livekit.SIPDispatchRuleInfo, error)) { + fake.loadSIPDispatchRuleMutex.Lock() + defer fake.loadSIPDispatchRuleMutex.Unlock() + fake.LoadSIPDispatchRuleStub = stub +} + +func (fake *FakeSIPStore) LoadSIPDispatchRuleArgsForCall(i int) (context.Context, string) { + fake.loadSIPDispatchRuleMutex.RLock() + defer fake.loadSIPDispatchRuleMutex.RUnlock() + argsForCall := fake.loadSIPDispatchRuleArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeSIPStore) LoadSIPDispatchRuleReturns(result1 *livekit.SIPDispatchRuleInfo, result2 error) { + fake.loadSIPDispatchRuleMutex.Lock() + defer fake.loadSIPDispatchRuleMutex.Unlock() + fake.LoadSIPDispatchRuleStub = nil + fake.loadSIPDispatchRuleReturns = struct { + result1 *livekit.SIPDispatchRuleInfo + result2 error + }{result1, result2} +} + +func (fake *FakeSIPStore) LoadSIPDispatchRuleReturnsOnCall(i int, result1 *livekit.SIPDispatchRuleInfo, result2 error) { + fake.loadSIPDispatchRuleMutex.Lock() + defer fake.loadSIPDispatchRuleMutex.Unlock() + fake.LoadSIPDispatchRuleStub = nil + if fake.loadSIPDispatchRuleReturnsOnCall == nil { + fake.loadSIPDispatchRuleReturnsOnCall = make(map[int]struct { + result1 *livekit.SIPDispatchRuleInfo + result2 error + }) + } + fake.loadSIPDispatchRuleReturnsOnCall[i] = struct { + result1 *livekit.SIPDispatchRuleInfo + result2 error + }{result1, result2} +} + +func (fake *FakeSIPStore) LoadSIPInboundTrunk(arg1 context.Context, arg2 string) (*livekit.SIPInboundTrunkInfo, error) { + fake.loadSIPInboundTrunkMutex.Lock() + ret, specificReturn := fake.loadSIPInboundTrunkReturnsOnCall[len(fake.loadSIPInboundTrunkArgsForCall)] + fake.loadSIPInboundTrunkArgsForCall = append(fake.loadSIPInboundTrunkArgsForCall, struct { + arg1 context.Context + arg2 string + }{arg1, arg2}) + stub := fake.LoadSIPInboundTrunkStub + fakeReturns := fake.loadSIPInboundTrunkReturns + fake.recordInvocation("LoadSIPInboundTrunk", []interface{}{arg1, arg2}) + fake.loadSIPInboundTrunkMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeSIPStore) LoadSIPInboundTrunkCallCount() int { + fake.loadSIPInboundTrunkMutex.RLock() + defer fake.loadSIPInboundTrunkMutex.RUnlock() + return len(fake.loadSIPInboundTrunkArgsForCall) +} + +func (fake *FakeSIPStore) LoadSIPInboundTrunkCalls(stub func(context.Context, string) (*livekit.SIPInboundTrunkInfo, error)) { + fake.loadSIPInboundTrunkMutex.Lock() + defer fake.loadSIPInboundTrunkMutex.Unlock() + fake.LoadSIPInboundTrunkStub = stub +} + +func (fake *FakeSIPStore) LoadSIPInboundTrunkArgsForCall(i int) (context.Context, string) { + fake.loadSIPInboundTrunkMutex.RLock() + defer fake.loadSIPInboundTrunkMutex.RUnlock() + argsForCall := fake.loadSIPInboundTrunkArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeSIPStore) LoadSIPInboundTrunkReturns(result1 *livekit.SIPInboundTrunkInfo, result2 error) { + fake.loadSIPInboundTrunkMutex.Lock() + defer fake.loadSIPInboundTrunkMutex.Unlock() + fake.LoadSIPInboundTrunkStub = nil + fake.loadSIPInboundTrunkReturns = struct { + result1 *livekit.SIPInboundTrunkInfo + result2 error + }{result1, result2} +} + +func (fake *FakeSIPStore) LoadSIPInboundTrunkReturnsOnCall(i int, result1 *livekit.SIPInboundTrunkInfo, result2 error) { + fake.loadSIPInboundTrunkMutex.Lock() + defer fake.loadSIPInboundTrunkMutex.Unlock() + fake.LoadSIPInboundTrunkStub = nil + if fake.loadSIPInboundTrunkReturnsOnCall == nil { + fake.loadSIPInboundTrunkReturnsOnCall = make(map[int]struct { + result1 *livekit.SIPInboundTrunkInfo + result2 error + }) + } + fake.loadSIPInboundTrunkReturnsOnCall[i] = struct { + result1 *livekit.SIPInboundTrunkInfo + result2 error + }{result1, result2} +} + +func (fake *FakeSIPStore) LoadSIPOutboundTrunk(arg1 context.Context, arg2 string) (*livekit.SIPOutboundTrunkInfo, error) { + fake.loadSIPOutboundTrunkMutex.Lock() + ret, specificReturn := fake.loadSIPOutboundTrunkReturnsOnCall[len(fake.loadSIPOutboundTrunkArgsForCall)] + fake.loadSIPOutboundTrunkArgsForCall = append(fake.loadSIPOutboundTrunkArgsForCall, struct { + arg1 context.Context + arg2 string + }{arg1, arg2}) + stub := fake.LoadSIPOutboundTrunkStub + fakeReturns := fake.loadSIPOutboundTrunkReturns + fake.recordInvocation("LoadSIPOutboundTrunk", []interface{}{arg1, arg2}) + fake.loadSIPOutboundTrunkMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeSIPStore) LoadSIPOutboundTrunkCallCount() int { + fake.loadSIPOutboundTrunkMutex.RLock() + defer fake.loadSIPOutboundTrunkMutex.RUnlock() + return len(fake.loadSIPOutboundTrunkArgsForCall) +} + +func (fake *FakeSIPStore) LoadSIPOutboundTrunkCalls(stub func(context.Context, string) (*livekit.SIPOutboundTrunkInfo, error)) { + fake.loadSIPOutboundTrunkMutex.Lock() + defer fake.loadSIPOutboundTrunkMutex.Unlock() + fake.LoadSIPOutboundTrunkStub = stub +} + +func (fake *FakeSIPStore) LoadSIPOutboundTrunkArgsForCall(i int) (context.Context, string) { + fake.loadSIPOutboundTrunkMutex.RLock() + defer fake.loadSIPOutboundTrunkMutex.RUnlock() + argsForCall := fake.loadSIPOutboundTrunkArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeSIPStore) LoadSIPOutboundTrunkReturns(result1 *livekit.SIPOutboundTrunkInfo, result2 error) { + fake.loadSIPOutboundTrunkMutex.Lock() + defer fake.loadSIPOutboundTrunkMutex.Unlock() + fake.LoadSIPOutboundTrunkStub = nil + fake.loadSIPOutboundTrunkReturns = struct { + result1 *livekit.SIPOutboundTrunkInfo + result2 error + }{result1, result2} +} + +func (fake *FakeSIPStore) LoadSIPOutboundTrunkReturnsOnCall(i int, result1 *livekit.SIPOutboundTrunkInfo, result2 error) { + fake.loadSIPOutboundTrunkMutex.Lock() + defer fake.loadSIPOutboundTrunkMutex.Unlock() + fake.LoadSIPOutboundTrunkStub = nil + if fake.loadSIPOutboundTrunkReturnsOnCall == nil { + fake.loadSIPOutboundTrunkReturnsOnCall = make(map[int]struct { + result1 *livekit.SIPOutboundTrunkInfo + result2 error + }) + } + fake.loadSIPOutboundTrunkReturnsOnCall[i] = struct { + result1 *livekit.SIPOutboundTrunkInfo + result2 error + }{result1, result2} +} + +func (fake *FakeSIPStore) LoadSIPTrunk(arg1 context.Context, arg2 string) (*livekit.SIPTrunkInfo, error) { + fake.loadSIPTrunkMutex.Lock() + ret, specificReturn := fake.loadSIPTrunkReturnsOnCall[len(fake.loadSIPTrunkArgsForCall)] + fake.loadSIPTrunkArgsForCall = append(fake.loadSIPTrunkArgsForCall, struct { + arg1 context.Context + arg2 string + }{arg1, arg2}) + stub := fake.LoadSIPTrunkStub + fakeReturns := fake.loadSIPTrunkReturns + fake.recordInvocation("LoadSIPTrunk", []interface{}{arg1, arg2}) + fake.loadSIPTrunkMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeSIPStore) LoadSIPTrunkCallCount() int { + fake.loadSIPTrunkMutex.RLock() + defer fake.loadSIPTrunkMutex.RUnlock() + return len(fake.loadSIPTrunkArgsForCall) +} + +func (fake *FakeSIPStore) LoadSIPTrunkCalls(stub func(context.Context, string) (*livekit.SIPTrunkInfo, error)) { + fake.loadSIPTrunkMutex.Lock() + defer fake.loadSIPTrunkMutex.Unlock() + fake.LoadSIPTrunkStub = stub +} + +func (fake *FakeSIPStore) LoadSIPTrunkArgsForCall(i int) (context.Context, string) { + fake.loadSIPTrunkMutex.RLock() + defer fake.loadSIPTrunkMutex.RUnlock() + argsForCall := fake.loadSIPTrunkArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeSIPStore) LoadSIPTrunkReturns(result1 *livekit.SIPTrunkInfo, result2 error) { + fake.loadSIPTrunkMutex.Lock() + defer fake.loadSIPTrunkMutex.Unlock() + fake.LoadSIPTrunkStub = nil + fake.loadSIPTrunkReturns = struct { + result1 *livekit.SIPTrunkInfo + result2 error + }{result1, result2} +} + +func (fake *FakeSIPStore) LoadSIPTrunkReturnsOnCall(i int, result1 *livekit.SIPTrunkInfo, result2 error) { + fake.loadSIPTrunkMutex.Lock() + defer fake.loadSIPTrunkMutex.Unlock() + fake.LoadSIPTrunkStub = nil + if fake.loadSIPTrunkReturnsOnCall == nil { + fake.loadSIPTrunkReturnsOnCall = make(map[int]struct { + result1 *livekit.SIPTrunkInfo + result2 error + }) + } + fake.loadSIPTrunkReturnsOnCall[i] = struct { + result1 *livekit.SIPTrunkInfo + result2 error + }{result1, result2} +} + +func (fake *FakeSIPStore) StoreSIPDispatchRule(arg1 context.Context, arg2 *livekit.SIPDispatchRuleInfo) error { + fake.storeSIPDispatchRuleMutex.Lock() + ret, specificReturn := fake.storeSIPDispatchRuleReturnsOnCall[len(fake.storeSIPDispatchRuleArgsForCall)] + fake.storeSIPDispatchRuleArgsForCall = append(fake.storeSIPDispatchRuleArgsForCall, struct { + arg1 context.Context + arg2 *livekit.SIPDispatchRuleInfo + }{arg1, arg2}) + stub := fake.StoreSIPDispatchRuleStub + fakeReturns := fake.storeSIPDispatchRuleReturns + fake.recordInvocation("StoreSIPDispatchRule", []interface{}{arg1, arg2}) + fake.storeSIPDispatchRuleMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeSIPStore) StoreSIPDispatchRuleCallCount() int { + fake.storeSIPDispatchRuleMutex.RLock() + defer fake.storeSIPDispatchRuleMutex.RUnlock() + return len(fake.storeSIPDispatchRuleArgsForCall) +} + +func (fake *FakeSIPStore) StoreSIPDispatchRuleCalls(stub func(context.Context, *livekit.SIPDispatchRuleInfo) error) { + fake.storeSIPDispatchRuleMutex.Lock() + defer fake.storeSIPDispatchRuleMutex.Unlock() + fake.StoreSIPDispatchRuleStub = stub +} + +func (fake *FakeSIPStore) StoreSIPDispatchRuleArgsForCall(i int) (context.Context, *livekit.SIPDispatchRuleInfo) { + fake.storeSIPDispatchRuleMutex.RLock() + defer fake.storeSIPDispatchRuleMutex.RUnlock() + argsForCall := fake.storeSIPDispatchRuleArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeSIPStore) StoreSIPDispatchRuleReturns(result1 error) { + fake.storeSIPDispatchRuleMutex.Lock() + defer fake.storeSIPDispatchRuleMutex.Unlock() + fake.StoreSIPDispatchRuleStub = nil + fake.storeSIPDispatchRuleReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeSIPStore) StoreSIPDispatchRuleReturnsOnCall(i int, result1 error) { + fake.storeSIPDispatchRuleMutex.Lock() + defer fake.storeSIPDispatchRuleMutex.Unlock() + fake.StoreSIPDispatchRuleStub = nil + if fake.storeSIPDispatchRuleReturnsOnCall == nil { + fake.storeSIPDispatchRuleReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.storeSIPDispatchRuleReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeSIPStore) StoreSIPInboundTrunk(arg1 context.Context, arg2 *livekit.SIPInboundTrunkInfo) error { + fake.storeSIPInboundTrunkMutex.Lock() + ret, specificReturn := fake.storeSIPInboundTrunkReturnsOnCall[len(fake.storeSIPInboundTrunkArgsForCall)] + fake.storeSIPInboundTrunkArgsForCall = append(fake.storeSIPInboundTrunkArgsForCall, struct { + arg1 context.Context + arg2 *livekit.SIPInboundTrunkInfo + }{arg1, arg2}) + stub := fake.StoreSIPInboundTrunkStub + fakeReturns := fake.storeSIPInboundTrunkReturns + fake.recordInvocation("StoreSIPInboundTrunk", []interface{}{arg1, arg2}) + fake.storeSIPInboundTrunkMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeSIPStore) StoreSIPInboundTrunkCallCount() int { + fake.storeSIPInboundTrunkMutex.RLock() + defer fake.storeSIPInboundTrunkMutex.RUnlock() + return len(fake.storeSIPInboundTrunkArgsForCall) +} + +func (fake *FakeSIPStore) StoreSIPInboundTrunkCalls(stub func(context.Context, *livekit.SIPInboundTrunkInfo) error) { + fake.storeSIPInboundTrunkMutex.Lock() + defer fake.storeSIPInboundTrunkMutex.Unlock() + fake.StoreSIPInboundTrunkStub = stub +} + +func (fake *FakeSIPStore) StoreSIPInboundTrunkArgsForCall(i int) (context.Context, *livekit.SIPInboundTrunkInfo) { + fake.storeSIPInboundTrunkMutex.RLock() + defer fake.storeSIPInboundTrunkMutex.RUnlock() + argsForCall := fake.storeSIPInboundTrunkArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeSIPStore) StoreSIPInboundTrunkReturns(result1 error) { + fake.storeSIPInboundTrunkMutex.Lock() + defer fake.storeSIPInboundTrunkMutex.Unlock() + fake.StoreSIPInboundTrunkStub = nil + fake.storeSIPInboundTrunkReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeSIPStore) StoreSIPInboundTrunkReturnsOnCall(i int, result1 error) { + fake.storeSIPInboundTrunkMutex.Lock() + defer fake.storeSIPInboundTrunkMutex.Unlock() + fake.StoreSIPInboundTrunkStub = nil + if fake.storeSIPInboundTrunkReturnsOnCall == nil { + fake.storeSIPInboundTrunkReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.storeSIPInboundTrunkReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeSIPStore) StoreSIPOutboundTrunk(arg1 context.Context, arg2 *livekit.SIPOutboundTrunkInfo) error { + fake.storeSIPOutboundTrunkMutex.Lock() + ret, specificReturn := fake.storeSIPOutboundTrunkReturnsOnCall[len(fake.storeSIPOutboundTrunkArgsForCall)] + fake.storeSIPOutboundTrunkArgsForCall = append(fake.storeSIPOutboundTrunkArgsForCall, struct { + arg1 context.Context + arg2 *livekit.SIPOutboundTrunkInfo + }{arg1, arg2}) + stub := fake.StoreSIPOutboundTrunkStub + fakeReturns := fake.storeSIPOutboundTrunkReturns + fake.recordInvocation("StoreSIPOutboundTrunk", []interface{}{arg1, arg2}) + fake.storeSIPOutboundTrunkMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeSIPStore) StoreSIPOutboundTrunkCallCount() int { + fake.storeSIPOutboundTrunkMutex.RLock() + defer fake.storeSIPOutboundTrunkMutex.RUnlock() + return len(fake.storeSIPOutboundTrunkArgsForCall) +} + +func (fake *FakeSIPStore) StoreSIPOutboundTrunkCalls(stub func(context.Context, *livekit.SIPOutboundTrunkInfo) error) { + fake.storeSIPOutboundTrunkMutex.Lock() + defer fake.storeSIPOutboundTrunkMutex.Unlock() + fake.StoreSIPOutboundTrunkStub = stub +} + +func (fake *FakeSIPStore) StoreSIPOutboundTrunkArgsForCall(i int) (context.Context, *livekit.SIPOutboundTrunkInfo) { + fake.storeSIPOutboundTrunkMutex.RLock() + defer fake.storeSIPOutboundTrunkMutex.RUnlock() + argsForCall := fake.storeSIPOutboundTrunkArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeSIPStore) StoreSIPOutboundTrunkReturns(result1 error) { + fake.storeSIPOutboundTrunkMutex.Lock() + defer fake.storeSIPOutboundTrunkMutex.Unlock() + fake.StoreSIPOutboundTrunkStub = nil + fake.storeSIPOutboundTrunkReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeSIPStore) StoreSIPOutboundTrunkReturnsOnCall(i int, result1 error) { + fake.storeSIPOutboundTrunkMutex.Lock() + defer fake.storeSIPOutboundTrunkMutex.Unlock() + fake.StoreSIPOutboundTrunkStub = nil + if fake.storeSIPOutboundTrunkReturnsOnCall == nil { + fake.storeSIPOutboundTrunkReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.storeSIPOutboundTrunkReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeSIPStore) StoreSIPTrunk(arg1 context.Context, arg2 *livekit.SIPTrunkInfo) error { + fake.storeSIPTrunkMutex.Lock() + ret, specificReturn := fake.storeSIPTrunkReturnsOnCall[len(fake.storeSIPTrunkArgsForCall)] + fake.storeSIPTrunkArgsForCall = append(fake.storeSIPTrunkArgsForCall, struct { + arg1 context.Context + arg2 *livekit.SIPTrunkInfo + }{arg1, arg2}) + stub := fake.StoreSIPTrunkStub + fakeReturns := fake.storeSIPTrunkReturns + fake.recordInvocation("StoreSIPTrunk", []interface{}{arg1, arg2}) + fake.storeSIPTrunkMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeSIPStore) StoreSIPTrunkCallCount() int { + fake.storeSIPTrunkMutex.RLock() + defer fake.storeSIPTrunkMutex.RUnlock() + return len(fake.storeSIPTrunkArgsForCall) +} + +func (fake *FakeSIPStore) StoreSIPTrunkCalls(stub func(context.Context, *livekit.SIPTrunkInfo) error) { + fake.storeSIPTrunkMutex.Lock() + defer fake.storeSIPTrunkMutex.Unlock() + fake.StoreSIPTrunkStub = stub +} + +func (fake *FakeSIPStore) StoreSIPTrunkArgsForCall(i int) (context.Context, *livekit.SIPTrunkInfo) { + fake.storeSIPTrunkMutex.RLock() + defer fake.storeSIPTrunkMutex.RUnlock() + argsForCall := fake.storeSIPTrunkArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeSIPStore) StoreSIPTrunkReturns(result1 error) { + fake.storeSIPTrunkMutex.Lock() + defer fake.storeSIPTrunkMutex.Unlock() + fake.StoreSIPTrunkStub = nil + fake.storeSIPTrunkReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeSIPStore) StoreSIPTrunkReturnsOnCall(i int, result1 error) { + fake.storeSIPTrunkMutex.Lock() + defer fake.storeSIPTrunkMutex.Unlock() + fake.StoreSIPTrunkStub = nil + if fake.storeSIPTrunkReturnsOnCall == nil { + fake.storeSIPTrunkReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.storeSIPTrunkReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeSIPStore) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeSIPStore) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ service.SIPStore = new(FakeSIPStore) diff --git a/livekit/pkg/service/signal.go b/livekit/pkg/service/signal.go new file mode 100644 index 0000000..648fec8 --- /dev/null +++ b/livekit/pkg/service/signal.go @@ -0,0 +1,213 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package service + +import ( + "context" + + "github.com/pkg/errors" + "google.golang.org/protobuf/proto" + + "github.com/livekit/livekit-server/pkg/config" + "github.com/livekit/livekit-server/pkg/routing" + "github.com/livekit/livekit-server/pkg/telemetry/prometheus" + "github.com/livekit/livekit-server/pkg/utils" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/rpc" + "github.com/livekit/psrpc" + "github.com/livekit/psrpc/pkg/metadata" + "github.com/livekit/psrpc/pkg/middleware" +) + +//counterfeiter:generate . SessionHandler +type SessionHandler interface { + Logger(ctx context.Context) logger.Logger + + HandleSession( + ctx context.Context, + pi routing.ParticipantInit, + connectionID livekit.ConnectionID, + requestSource routing.MessageSource, + responseSink routing.MessageSink, + ) error +} + +type SignalServer struct { + server rpc.TypedSignalServer + nodeID livekit.NodeID +} + +func NewSignalServer( + nodeID livekit.NodeID, + region string, + bus psrpc.MessageBus, + config config.SignalRelayConfig, + sessionHandler SessionHandler, +) (*SignalServer, error) { + s, err := rpc.NewTypedSignalServer( + nodeID, + &signalService{region, sessionHandler, config}, + bus, + middleware.WithServerMetrics(rpc.PSRPCMetricsObserver{}), + psrpc.WithServerChannelSize(config.StreamBufferSize), + ) + if err != nil { + return nil, err + } + return &SignalServer{s, nodeID}, nil +} + +func NewDefaultSignalServer( + currentNode routing.LocalNode, + bus psrpc.MessageBus, + config config.SignalRelayConfig, + router routing.Router, + roomManager *RoomManager, +) (r *SignalServer, err error) { + return NewSignalServer(currentNode.NodeID(), currentNode.Region(), bus, config, &defaultSessionHandler{currentNode, router, roomManager}) +} + +type defaultSessionHandler struct { + currentNode routing.LocalNode + router routing.Router + roomManager *RoomManager +} + +func (s *defaultSessionHandler) Logger(ctx context.Context) logger.Logger { + return utils.GetLogger(ctx) +} + +func (s *defaultSessionHandler) HandleSession( + ctx context.Context, + pi routing.ParticipantInit, + connectionID livekit.ConnectionID, + requestSource routing.MessageSource, + responseSink routing.MessageSink, +) error { + prometheus.IncrementParticipantRtcInit(1) + + rtcNode, err := s.router.GetNodeForRoom(ctx, livekit.RoomName(pi.CreateRoom.Name)) + if err != nil { + return err + } + + if livekit.NodeID(rtcNode.Id) != s.currentNode.NodeID() { + err = routing.ErrIncorrectRTCNode + logger.Errorw("called participant on incorrect node", err, + "rtcNode", rtcNode, + ) + return err + } + + return s.roomManager.StartSession(ctx, pi, requestSource, responseSink, false) +} + +func (s *SignalServer) Start() error { + logger.Debugw("starting relay signal server", "topic", s.nodeID) + return s.server.RegisterAllNodeTopics(s.nodeID) +} + +func (r *SignalServer) Stop() { + r.server.Kill() +} + +type signalService struct { + region string + sessionHandler SessionHandler + config config.SignalRelayConfig +} + +func (r *signalService) RelaySignal(stream psrpc.ServerStream[*rpc.RelaySignalResponse, *rpc.RelaySignalRequest]) (err error) { + req, ok := <-stream.Channel() + if !ok { + return nil + } + + ss := req.StartSession + if ss == nil { + return errors.New("expected start session message") + } + + pi, err := routing.ParticipantInitFromStartSession(ss, r.region) + if err != nil { + return errors.Wrap(err, "failed to read participant from session") + } + + l := r.sessionHandler.Logger(stream.Context()).WithValues( + "room", ss.RoomName, + "participant", ss.Identity, + "connID", ss.ConnectionId, + ) + + stream.Hijack() + sink := routing.NewSignalMessageSink(routing.SignalSinkParams[*rpc.RelaySignalResponse, *rpc.RelaySignalRequest]{ + Logger: l, + Stream: stream, + Config: r.config, + Writer: signalResponseMessageWriter{}, + ConnectionID: livekit.ConnectionID(ss.ConnectionId), + }) + reqChan := routing.NewDefaultMessageChannel(livekit.ConnectionID(ss.ConnectionId)) + + go func() { + err := routing.CopySignalStreamToMessageChannel[*rpc.RelaySignalResponse, *rpc.RelaySignalRequest]( + stream, + reqChan, + signalRequestMessageReader{}, + r.config, + prometheus.RecordSignalRequestSuccess, + prometheus.RecordSignalRequestFailure, + ) + l.Debugw("signal stream closed", "error", err) + + reqChan.Close() + }() + + // copy the context to prevent a race between the session handler closing + // and the delivery of any parting messages from the client. take care to + // copy the incoming rpc headers to avoid dropping any session vars. + ctx := metadata.NewContextWithIncomingHeader(context.Background(), metadata.IncomingHeader(stream.Context())) + err = r.sessionHandler.HandleSession(ctx, *pi, livekit.ConnectionID(ss.ConnectionId), reqChan, sink) + if err != nil { + sink.Close() + l.Errorw("could not handle new participant", err) + } + return +} + +type signalResponseMessageWriter struct{} + +func (e signalResponseMessageWriter) Write(seq uint64, close bool, msgs []proto.Message) *rpc.RelaySignalResponse { + r := &rpc.RelaySignalResponse{ + Seq: seq, + Responses: make([]*livekit.SignalResponse, 0, len(msgs)), + Close: close, + } + for _, m := range msgs { + r.Responses = append(r.Responses, m.(*livekit.SignalResponse)) + } + return r +} + +type signalRequestMessageReader struct{} + +func (e signalRequestMessageReader) Read(rm *rpc.RelaySignalRequest) ([]proto.Message, error) { + msgs := make([]proto.Message, 0, len(rm.Requests)) + for _, m := range rm.Requests { + msgs = append(msgs, m) + } + return msgs, nil +} diff --git a/livekit/pkg/service/signal_test.go b/livekit/pkg/service/signal_test.go new file mode 100644 index 0000000..81755cd --- /dev/null +++ b/livekit/pkg/service/signal_test.go @@ -0,0 +1,156 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package service_test + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/proto" + + "github.com/livekit/livekit-server/pkg/config" + "github.com/livekit/livekit-server/pkg/routing" + "github.com/livekit/livekit-server/pkg/service" + "github.com/livekit/livekit-server/pkg/service/servicefakes" + "github.com/livekit/livekit-server/pkg/telemetry/prometheus" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/psrpc" +) + +func init() { + prometheus.Init("node", livekit.NodeType_CONTROLLER) +} + +func TestSignal(t *testing.T) { + cfg := config.SignalRelayConfig{ + RetryTimeout: 30 * time.Second, + MinRetryInterval: 500 * time.Millisecond, + MaxRetryInterval: 5 * time.Second, + StreamBufferSize: 1000, + } + + t.Run("messages are delivered", func(t *testing.T) { + bus := psrpc.NewLocalMessageBus() + + reqMessageIn := &livekit.SignalRequest{ + Message: &livekit.SignalRequest_Ping{Ping: 123}, + } + resMessageIn := &livekit.SignalResponse{ + Message: &livekit.SignalResponse_Pong{Pong: 321}, + } + + var reqMessageOut proto.Message + var resErr error + done := make(chan struct{}) + + client, err := routing.NewSignalClient(livekit.NodeID("node0"), bus, cfg) + require.NoError(t, err) + + handler := &servicefakes.FakeSessionHandler{ + LoggerStub: func(context.Context) logger.Logger { return logger.GetLogger() }, + HandleSessionStub: func( + ctx context.Context, + pi routing.ParticipantInit, + connectionID livekit.ConnectionID, + requestSource routing.MessageSource, + responseSink routing.MessageSink, + ) error { + go func() { + reqMessageOut = <-requestSource.ReadChan() + resErr = responseSink.WriteMessage(resMessageIn) + responseSink.Close() + close(done) + }() + return nil + }, + } + server, err := service.NewSignalServer(livekit.NodeID("node1"), "region", bus, cfg, handler) + require.NoError(t, err) + + err = server.Start() + require.NoError(t, err) + + _, reqSink, resSource, err := client.StartParticipantSignal( + context.Background(), + livekit.RoomName("room1"), + routing.ParticipantInit{}, + livekit.NodeID("node1"), + ) + require.NoError(t, err) + + err = reqSink.WriteMessage(reqMessageIn) + require.NoError(t, err) + + <-done + require.True(t, proto.Equal(reqMessageIn, reqMessageOut), "req message should match %s %s", protojson.Format(reqMessageIn), protojson.Format(reqMessageOut)) + require.NoError(t, resErr) + + resMessageOut := <-resSource.ReadChan() + require.True(t, proto.Equal(resMessageIn, resMessageOut), "res message should match %s %s", protojson.Format(resMessageIn), protojson.Format(resMessageOut)) + }) + + t.Run("messages are delivered when session handler fails", func(t *testing.T) { + bus := psrpc.NewLocalMessageBus() + + resMessageIn := &livekit.SignalResponse{ + Message: &livekit.SignalResponse_Pong{Pong: 321}, + } + + var resErr error + done := make(chan struct{}) + + client, err := routing.NewSignalClient(livekit.NodeID("node0"), bus, cfg) + require.NoError(t, err) + + handler := &servicefakes.FakeSessionHandler{ + LoggerStub: func(context.Context) logger.Logger { return logger.GetLogger() }, + HandleSessionStub: func( + ctx context.Context, + pi routing.ParticipantInit, + connectionID livekit.ConnectionID, + requestSource routing.MessageSource, + responseSink routing.MessageSink, + ) error { + defer close(done) + resErr = responseSink.WriteMessage(resMessageIn) + return errors.New("start session failed") + }, + } + server, err := service.NewSignalServer(livekit.NodeID("node1"), "region", bus, cfg, handler) + require.NoError(t, err) + + err = server.Start() + require.NoError(t, err) + + _, _, resSource, err := client.StartParticipantSignal( + context.Background(), + livekit.RoomName("room1"), + routing.ParticipantInit{}, + livekit.NodeID("node1"), + ) + require.NoError(t, err) + + <-done + require.NoError(t, resErr) + + resMessageOut := <-resSource.ReadChan() + require.True(t, proto.Equal(resMessageIn, resMessageOut), "res message should match %s %s", protojson.Format(resMessageIn), protojson.Format(resMessageOut)) + }) +} diff --git a/livekit/pkg/service/sip.go b/livekit/pkg/service/sip.go new file mode 100644 index 0000000..6f68345 --- /dev/null +++ b/livekit/pkg/service/sip.go @@ -0,0 +1,730 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package service + +import ( + "context" + "errors" + "time" + + "github.com/dennwc/iters" + "github.com/twitchtv/twirp" + "google.golang.org/protobuf/types/known/emptypb" + + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/rpc" + "github.com/livekit/protocol/sip" + "github.com/livekit/protocol/utils" + "github.com/livekit/protocol/utils/guid" + "github.com/livekit/psrpc" + + "github.com/livekit/livekit-server/pkg/config" + "github.com/livekit/livekit-server/pkg/telemetry" +) + +type SIPService struct { + conf *config.SIPConfig + nodeID livekit.NodeID + bus psrpc.MessageBus + psrpcClient rpc.SIPClient + store SIPStore + roomService livekit.RoomService +} + +func NewSIPService( + conf *config.SIPConfig, + nodeID livekit.NodeID, + bus psrpc.MessageBus, + psrpcClient rpc.SIPClient, + store SIPStore, + rs livekit.RoomService, + ts telemetry.TelemetryService, +) *SIPService { + return &SIPService{ + conf: conf, + nodeID: nodeID, + bus: bus, + psrpcClient: psrpcClient, + store: store, + roomService: rs, + } +} + +func (s *SIPService) CreateSIPTrunk(ctx context.Context, req *livekit.CreateSIPTrunkRequest) (*livekit.SIPTrunkInfo, error) { + if err := EnsureSIPAdminPermission(ctx); err != nil { + return nil, twirpAuthError(err) + } + if s.store == nil { + return nil, ErrSIPNotConnected + } + if len(req.InboundNumbersRegex) != 0 { + return nil, twirp.NewError(twirp.InvalidArgument, "Trunks with InboundNumbersRegex are deprecated. Use InboundNumbers instead.") + } + + // Keep ID empty, so that validation can print "" instead of a non-existent ID in the error. + info := &livekit.SIPTrunkInfo{ + InboundAddresses: req.InboundAddresses, + OutboundAddress: req.OutboundAddress, + OutboundNumber: req.OutboundNumber, + InboundNumbers: req.InboundNumbers, + InboundUsername: req.InboundUsername, + InboundPassword: req.InboundPassword, + OutboundUsername: req.OutboundUsername, + OutboundPassword: req.OutboundPassword, + Name: req.Name, + Metadata: req.Metadata, + } + if err := info.Validate(); err != nil { + return nil, err + } + + // Validate all trunks including the new one first. + it, err := ListSIPInboundTrunk(ctx, s.store, &livekit.ListSIPInboundTrunkRequest{}, info.AsInbound()) + if err != nil { + return nil, err + } + defer it.Close() + if err = sip.ValidateTrunksIter(it); err != nil { + return nil, err + } + + // Now we can generate ID and store. + info.SipTrunkId = guid.New(utils.SIPTrunkPrefix) + if err := s.store.StoreSIPTrunk(ctx, info); err != nil { + return nil, err + } + return info, nil +} + +func (s *SIPService) CreateSIPInboundTrunk(ctx context.Context, req *livekit.CreateSIPInboundTrunkRequest) (*livekit.SIPInboundTrunkInfo, error) { + if err := EnsureSIPAdminPermission(ctx); err != nil { + return nil, twirpAuthError(err) + } + if s.store == nil { + return nil, ErrSIPNotConnected + } + if err := req.Validate(); err != nil { + return nil, twirp.WrapError(twirp.NewError(twirp.InvalidArgument, err.Error()), err) + } + + info := req.Trunk + if info.SipTrunkId != "" { + return nil, twirp.NewError(twirp.InvalidArgument, "trunk ID must be empty") + } + AppendLogFields(ctx, "trunk", logger.Proto(info)) + + // Keep ID empty still, so that validation can print "" instead of a non-existent ID in the error. + + // Validate all trunks including the new one first. + it, err := ListSIPInboundTrunk(ctx, s.store, &livekit.ListSIPInboundTrunkRequest{ + Numbers: req.GetTrunk().GetNumbers(), + }, info) + if err != nil { + return nil, err + } + defer it.Close() + if err = sip.ValidateTrunksIter(it); err != nil { + return nil, err + } + + // Now we can generate ID and store. + info.SipTrunkId = guid.New(utils.SIPTrunkPrefix) + if err := s.store.StoreSIPInboundTrunk(ctx, info); err != nil { + return nil, err + } + return info, nil +} + +func (s *SIPService) CreateSIPOutboundTrunk(ctx context.Context, req *livekit.CreateSIPOutboundTrunkRequest) (*livekit.SIPOutboundTrunkInfo, error) { + if err := EnsureSIPAdminPermission(ctx); err != nil { + return nil, twirpAuthError(err) + } + if s.store == nil { + return nil, ErrSIPNotConnected + } + if err := req.Validate(); err != nil { + return nil, twirp.WrapError(twirp.NewError(twirp.InvalidArgument, err.Error()), err) + } + + info := req.Trunk + if info.SipTrunkId != "" { + return nil, twirp.NewError(twirp.InvalidArgument, "trunk ID must be empty") + } + AppendLogFields(ctx, "trunk", logger.Proto(info)) + + // No additional validation needed for outbound. + info.SipTrunkId = guid.New(utils.SIPTrunkPrefix) + if err := s.store.StoreSIPOutboundTrunk(ctx, info); err != nil { + return nil, err + } + return info, nil +} + +func (s *SIPService) UpdateSIPInboundTrunk(ctx context.Context, req *livekit.UpdateSIPInboundTrunkRequest) (*livekit.SIPInboundTrunkInfo, error) { + if err := EnsureSIPAdminPermission(ctx); err != nil { + return nil, twirpAuthError(err) + } + if s.store == nil { + return nil, ErrSIPNotConnected + } + if err := req.Validate(); err != nil { + return nil, err + } + + AppendLogFields(ctx, + "request", logger.Proto(req), + "trunkID", req.SipTrunkId, + ) + + // Validate all trunks including the new one first. + info, err := s.store.LoadSIPInboundTrunk(ctx, req.SipTrunkId) + if err != nil { + return nil, err + } + switch a := req.Action.(type) { + default: + return nil, errors.New("missing or unsupported action") + case livekit.UpdateSIPInboundTrunkRequestAction: + info, err = a.Apply(info) + if err != nil { + return nil, err + } + } + + it, err := ListSIPInboundTrunk(ctx, s.store, &livekit.ListSIPInboundTrunkRequest{ + Numbers: info.Numbers, + }) + if err != nil { + return nil, err + } + defer it.Close() + if err = sip.ValidateTrunksIter(it, sip.WithTrunkReplace(func(t *livekit.SIPInboundTrunkInfo) *livekit.SIPInboundTrunkInfo { + if req.SipTrunkId == t.SipTrunkId { + return info // updated one + } + return t + })); err != nil { + return nil, err + } + if err := s.store.StoreSIPInboundTrunk(ctx, info); err != nil { + return nil, err + } + return info, nil +} + +func (s *SIPService) UpdateSIPOutboundTrunk(ctx context.Context, req *livekit.UpdateSIPOutboundTrunkRequest) (*livekit.SIPOutboundTrunkInfo, error) { + if err := EnsureSIPAdminPermission(ctx); err != nil { + return nil, twirpAuthError(err) + } + if s.store == nil { + return nil, ErrSIPNotConnected + } + if err := req.Validate(); err != nil { + return nil, err + } + + AppendLogFields(ctx, + "request", logger.Proto(req), + "trunkID", req.SipTrunkId, + ) + + info, err := s.store.LoadSIPOutboundTrunk(ctx, req.SipTrunkId) + if err != nil { + return nil, err + } + switch a := req.Action.(type) { + default: + return nil, errors.New("missing or unsupported action") + case livekit.UpdateSIPOutboundTrunkRequestAction: + info, err = a.Apply(info) + if err != nil { + return nil, err + } + } + // No additional validation needed for outbound. + if err := s.store.StoreSIPOutboundTrunk(ctx, info); err != nil { + return nil, err + } + return info, nil +} + +func (s *SIPService) GetSIPInboundTrunk(ctx context.Context, req *livekit.GetSIPInboundTrunkRequest) (*livekit.GetSIPInboundTrunkResponse, error) { + if err := EnsureSIPAdminPermission(ctx); err != nil { + return nil, twirpAuthError(err) + } + if s.store == nil { + return nil, ErrSIPNotConnected + } + if req.SipTrunkId == "" { + return nil, twirp.NewError(twirp.InvalidArgument, "trunk ID is required") + } + AppendLogFields(ctx, "trunkID", req.SipTrunkId) + + trunk, err := s.store.LoadSIPInboundTrunk(ctx, req.SipTrunkId) + if err != nil { + return nil, err + } + + return &livekit.GetSIPInboundTrunkResponse{Trunk: trunk}, nil +} + +func (s *SIPService) GetSIPOutboundTrunk(ctx context.Context, req *livekit.GetSIPOutboundTrunkRequest) (*livekit.GetSIPOutboundTrunkResponse, error) { + if err := EnsureSIPAdminPermission(ctx); err != nil { + return nil, twirpAuthError(err) + } + if s.store == nil { + return nil, ErrSIPNotConnected + } + if req.SipTrunkId == "" { + return nil, twirp.NewError(twirp.InvalidArgument, "trunk ID is required") + } + AppendLogFields(ctx, "trunkID", req.SipTrunkId) + + trunk, err := s.store.LoadSIPOutboundTrunk(ctx, req.SipTrunkId) + if err != nil { + return nil, err + } + + return &livekit.GetSIPOutboundTrunkResponse{Trunk: trunk}, nil +} + +// deprecated: ListSIPTrunk will be removed in the future +func (s *SIPService) ListSIPTrunk(ctx context.Context, req *livekit.ListSIPTrunkRequest) (*livekit.ListSIPTrunkResponse, error) { + if err := EnsureSIPAdminPermission(ctx); err != nil { + return nil, twirpAuthError(err) + } + if s.store == nil { + return nil, ErrSIPNotConnected + } + it := livekit.ListPageIter(s.store.ListSIPTrunk, req) + defer it.Close() + + items, err := iters.AllPages(ctx, it) + if err != nil { + return nil, err + } + return &livekit.ListSIPTrunkResponse{Items: items}, nil +} + +func ListSIPInboundTrunk(ctx context.Context, s SIPStore, req *livekit.ListSIPInboundTrunkRequest, add ...*livekit.SIPInboundTrunkInfo) (iters.Iter[*livekit.SIPInboundTrunkInfo], error) { + if s == nil { + return nil, ErrSIPNotConnected + } + pages := livekit.ListPageIter(s.ListSIPInboundTrunk, req) + it := iters.PagesAsIter(ctx, pages) + if len(add) != 0 { + it = iters.MultiIter(true, it, iters.Slice(add)) + } + return it, nil +} + +func ListSIPOutboundTrunk(ctx context.Context, s SIPStore, req *livekit.ListSIPOutboundTrunkRequest, add ...*livekit.SIPOutboundTrunkInfo) (iters.Iter[*livekit.SIPOutboundTrunkInfo], error) { + if s == nil { + return nil, ErrSIPNotConnected + } + pages := livekit.ListPageIter(s.ListSIPOutboundTrunk, req) + it := iters.PagesAsIter(ctx, pages) + if len(add) != 0 { + it = iters.MultiIter(true, it, iters.Slice(add)) + } + return it, nil +} + +func (s *SIPService) ListSIPInboundTrunk(ctx context.Context, req *livekit.ListSIPInboundTrunkRequest) (*livekit.ListSIPInboundTrunkResponse, error) { + if err := EnsureSIPAdminPermission(ctx); err != nil { + return nil, twirpAuthError(err) + } + if s.store == nil { + return nil, ErrSIPNotConnected + } + it, err := ListSIPInboundTrunk(ctx, s.store, req) + if err != nil { + return nil, err + } + defer it.Close() + + items, err := iters.All(it) + if err != nil { + return nil, err + } + return &livekit.ListSIPInboundTrunkResponse{Items: items}, nil +} + +func (s *SIPService) ListSIPOutboundTrunk(ctx context.Context, req *livekit.ListSIPOutboundTrunkRequest) (*livekit.ListSIPOutboundTrunkResponse, error) { + if err := EnsureSIPAdminPermission(ctx); err != nil { + return nil, twirpAuthError(err) + } + if s.store == nil { + return nil, ErrSIPNotConnected + } + it, err := ListSIPOutboundTrunk(ctx, s.store, req) + if err != nil { + return nil, err + } + defer it.Close() + + items, err := iters.All(it) + if err != nil { + return nil, err + } + return &livekit.ListSIPOutboundTrunkResponse{Items: items}, nil +} + +func (s *SIPService) DeleteSIPTrunk(ctx context.Context, req *livekit.DeleteSIPTrunkRequest) (*livekit.SIPTrunkInfo, error) { + if err := EnsureSIPAdminPermission(ctx); err != nil { + return nil, twirpAuthError(err) + } + if s.store == nil { + return nil, ErrSIPNotConnected + } + if req.SipTrunkId == "" { + return nil, twirp.NewError(twirp.InvalidArgument, "trunk ID is required") + } + + AppendLogFields(ctx, "trunkID", req.SipTrunkId) + if err := s.store.DeleteSIPTrunk(ctx, req.SipTrunkId); err != nil { + return nil, err + } + + return &livekit.SIPTrunkInfo{SipTrunkId: req.SipTrunkId}, nil +} + +func (s *SIPService) CreateSIPDispatchRule(ctx context.Context, req *livekit.CreateSIPDispatchRuleRequest) (*livekit.SIPDispatchRuleInfo, error) { + if err := EnsureSIPAdminPermission(ctx); err != nil { + return nil, twirpAuthError(err) + } + if s.store == nil { + return nil, ErrSIPNotConnected + } + if err := req.Validate(); err != nil { + return nil, twirp.WrapError(twirp.NewError(twirp.InvalidArgument, err.Error()), err) + } + + AppendLogFields(ctx, + "request", logger.Proto(req), + "trunkID", req.TrunkIds, + ) + // Keep ID empty, so that validation can print "" instead of a non-existent ID in the error. + info := req.DispatchRuleInfo() + info.SipDispatchRuleId = "" + + // Validate all rules including the new one first. + it, err := ListSIPDispatchRule(ctx, s.store, &livekit.ListSIPDispatchRuleRequest{ + TrunkIds: req.TrunkIds, + }, info) + if err != nil { + return nil, err + } + defer it.Close() + if _, err = sip.ValidateDispatchRulesIter(it); err != nil { + return nil, err + } + + // Now we can generate ID and store. + info.SipDispatchRuleId = guid.New(utils.SIPDispatchRulePrefix) + if err := s.store.StoreSIPDispatchRule(ctx, info); err != nil { + return nil, err + } + return info, nil +} + +func (s *SIPService) UpdateSIPDispatchRule(ctx context.Context, req *livekit.UpdateSIPDispatchRuleRequest) (*livekit.SIPDispatchRuleInfo, error) { + if err := EnsureSIPAdminPermission(ctx); err != nil { + return nil, twirpAuthError(err) + } + if s.store == nil { + return nil, ErrSIPNotConnected + } + if err := req.Validate(); err != nil { + return nil, err + } + + AppendLogFields(ctx, + "request", logger.Proto(req), + "ruleID", req.SipDispatchRuleId, + ) + + // Validate all trunks including the new one first. + info, err := s.store.LoadSIPDispatchRule(ctx, req.SipDispatchRuleId) + if err != nil { + return nil, err + } + switch a := req.Action.(type) { + default: + return nil, errors.New("missing or unsupported action") + case livekit.UpdateSIPDispatchRuleRequestAction: + info, err = a.Apply(info) + if err != nil { + return nil, err + } + } + + it, err := ListSIPDispatchRule(ctx, s.store, &livekit.ListSIPDispatchRuleRequest{ + TrunkIds: info.TrunkIds, + }) + if err != nil { + return nil, err + } + defer it.Close() + if _, err = sip.ValidateDispatchRulesIter(it, sip.WithDispatchRuleReplace(func(t *livekit.SIPDispatchRuleInfo) *livekit.SIPDispatchRuleInfo { + if req.SipDispatchRuleId == t.SipDispatchRuleId { + return info // updated one + } + return t + })); err != nil { + return nil, err + } + + if err := s.store.StoreSIPDispatchRule(ctx, info); err != nil { + return nil, err + } + return info, nil +} + +func ListSIPDispatchRule(ctx context.Context, s SIPStore, req *livekit.ListSIPDispatchRuleRequest, add ...*livekit.SIPDispatchRuleInfo) (iters.Iter[*livekit.SIPDispatchRuleInfo], error) { + if s == nil { + return nil, ErrSIPNotConnected + } + pages := livekit.ListPageIter(s.ListSIPDispatchRule, req) + it := iters.PagesAsIter(ctx, pages) + if len(add) != 0 { + it = iters.MultiIter(true, it, iters.Slice(add)) + } + return it, nil +} + +func (s *SIPService) ListSIPDispatchRule(ctx context.Context, req *livekit.ListSIPDispatchRuleRequest) (*livekit.ListSIPDispatchRuleResponse, error) { + if err := EnsureSIPAdminPermission(ctx); err != nil { + return nil, twirpAuthError(err) + } + if s.store == nil { + return nil, ErrSIPNotConnected + } + it, err := ListSIPDispatchRule(ctx, s.store, req) + if err != nil { + return nil, err + } + defer it.Close() + + items, err := iters.All(it) + if err != nil { + return nil, err + } + return &livekit.ListSIPDispatchRuleResponse{Items: items}, nil +} + +func (s *SIPService) DeleteSIPDispatchRule(ctx context.Context, req *livekit.DeleteSIPDispatchRuleRequest) (*livekit.SIPDispatchRuleInfo, error) { + if err := EnsureSIPAdminPermission(ctx); err != nil { + return nil, twirpAuthError(err) + } + if s.store == nil { + return nil, ErrSIPNotConnected + } + if req.SipDispatchRuleId == "" { + return nil, twirp.NewError(twirp.InvalidArgument, "dispatch rule ID is required") + } + + info, err := s.store.LoadSIPDispatchRule(ctx, req.SipDispatchRuleId) + if err != nil { + return nil, err + } + + if err = s.store.DeleteSIPDispatchRule(ctx, info.SipDispatchRuleId); err != nil { + return nil, err + } + + return info, nil +} + +func (s *SIPService) CreateSIPParticipant(ctx context.Context, req *livekit.CreateSIPParticipantRequest) (*livekit.SIPParticipantInfo, error) { + unlikelyLogger := logger.GetLogger().WithUnlikelyValues( + "room", req.RoomName, + "sipTrunk", req.SipTrunkId, + "toUser", req.SipCallTo, + "participant", req.ParticipantIdentity, + ) + AppendLogFields(ctx, + "room", req.RoomName, + "participant", req.ParticipantIdentity, + "toUser", req.SipCallTo, + "trunkID", req.SipTrunkId, + ) + ireq, err := s.CreateSIPParticipantRequest(ctx, req, "", "", "", "") + if err != nil { + unlikelyLogger.Errorw("cannot create sip participant request", err) + return nil, err + } + unlikelyLogger = unlikelyLogger.WithValues( + "callID", ireq.SipCallId, + "fromUser", ireq.Number, + "toHost", ireq.Address, + ) + AppendLogFields(ctx, + "callID", ireq.SipCallId, + "fromUser", ireq.Number, + "toHost", ireq.Address, + ) + + // CreateSIPParticipant will wait for LiveKit Participant to be created and that can take some time. + // Thus, we must set a higher deadline for it, if it's not set already. + timeout := 30 * time.Second + if req.WaitUntilAnswered { + timeout = 80 * time.Second + } + if deadline, ok := ctx.Deadline(); ok { + timeout = time.Until(deadline) + } else { + var cancel func() + ctx, cancel = context.WithTimeout(ctx, timeout) + defer cancel() + } + resp, err := s.psrpcClient.CreateSIPParticipant(ctx, "", ireq, psrpc.WithRequestTimeout(timeout)) + if err != nil { + unlikelyLogger.Errorw("cannot create sip participant", err) + return nil, err + } + return &livekit.SIPParticipantInfo{ + ParticipantId: resp.ParticipantId, + ParticipantIdentity: resp.ParticipantIdentity, + RoomName: req.RoomName, + SipCallId: ireq.SipCallId, + }, nil +} + +func (s *SIPService) CreateSIPParticipantRequest(ctx context.Context, req *livekit.CreateSIPParticipantRequest, projectID, host, wsUrl, token string) (*rpc.InternalCreateSIPParticipantRequest, error) { + if err := EnsureSIPCallPermission(ctx); err != nil { + return nil, twirpAuthError(err) + } + if s.store == nil { + return nil, ErrSIPNotConnected + } + if err := req.Validate(); err != nil { + return nil, twirp.WrapError(twirp.NewError(twirp.InvalidArgument, err.Error()), err) + } + callID := sip.NewCallID() + log := logger.GetLogger().WithUnlikelyValues( + "callID", callID, + "room", req.RoomName, + "sipTrunk", req.SipTrunkId, + "toUser", req.SipCallTo, + ) + if projectID != "" { + log = log.WithValues("projectID", projectID) + } + + var trunk *livekit.SIPOutboundTrunkInfo + if req.SipTrunkId != "" { + var err error + trunk, err = s.store.LoadSIPOutboundTrunk(ctx, req.SipTrunkId) + if err != nil { + log.Errorw("cannot get trunk to update sip participant", err) + return nil, err + } + } + return rpc.NewCreateSIPParticipantRequest(projectID, callID, host, wsUrl, token, req, trunk) +} + +func (s *SIPService) TransferSIPParticipant(ctx context.Context, req *livekit.TransferSIPParticipantRequest) (*emptypb.Empty, error) { + log := logger.GetLogger().WithUnlikelyValues( + "room", req.RoomName, + "participant", req.ParticipantIdentity, + "transferTo", req.TransferTo, + "playDialtone", req.PlayDialtone, + ) + AppendLogFields(ctx, + "room", req.RoomName, + "participant", req.ParticipantIdentity, + "transferTo", req.TransferTo, + "playDialtone", req.PlayDialtone, + ) + + ireq, err := s.transferSIPParticipantRequest(ctx, req) + if err != nil { + log.Errorw("cannot create transfer sip participant request", err) + return nil, err + } + + // by default we set the timeout to be 30 seconds. + // this timeout covers: + // - a network failure between this process and the LiveKit SIP bridge + // - the SIP transfer target not returning 200 OK fast enough. + // WARN: any timeout/cancellation of a SIP transfer risks leaving + // either the SIP bridge, or the SIP REFER exchange, in a "unknown" state. + timeout := 30 * time.Second + if req.RingingTimeout != nil { + timeout = req.RingingTimeout.AsDuration() + } + + // it's also possible the ctx has a Deadline. + // in that case we want to use that deadline, + // or our timeout, whichover is soonest. + if deadline, ok := ctx.Deadline(); ok { + timeout = min(timeout, time.Until(deadline)) + } else { + var cancel func() + ctx, cancel = context.WithTimeout(ctx, timeout) + defer cancel() + } + _, err = s.psrpcClient.TransferSIPParticipant(ctx, ireq.SipCallId, ireq, psrpc.WithRequestTimeout(timeout)) + if err != nil { + log.Errorw("cannot transfer sip participant", err) + return nil, err + } + + return &emptypb.Empty{}, nil +} + +func (s *SIPService) transferSIPParticipantRequest(ctx context.Context, req *livekit.TransferSIPParticipantRequest) (*rpc.InternalTransferSIPParticipantRequest, error) { + if req.RoomName == "" { + return nil, psrpc.NewErrorf(psrpc.InvalidArgument, "Missing room name") + } + + if req.ParticipantIdentity == "" { + return nil, psrpc.NewErrorf(psrpc.InvalidArgument, "Missing participant identity") + } + + if err := EnsureSIPCallPermission(ctx); err != nil { + return nil, twirpAuthError(err) + } + if err := EnsureAdminPermission(ctx, livekit.RoomName(req.RoomName)); err != nil { + return nil, twirpAuthError(err) + } + if err := req.Validate(); err != nil { + return nil, err + } + + resp, err := s.roomService.GetParticipant(ctx, &livekit.RoomParticipantIdentity{ + Room: req.RoomName, + Identity: req.ParticipantIdentity, + }) + + if err != nil { + return nil, err + } + + callID, ok := resp.Attributes[livekit.AttrSIPCallID] + if !ok { + return nil, psrpc.NewErrorf(psrpc.InvalidArgument, "no SIP session associated with participant") + } + + return &rpc.InternalTransferSIPParticipantRequest{ + SipCallId: callID, + TransferTo: req.TransferTo, + PlayDialtone: req.PlayDialtone, + Headers: req.Headers, + RingingTimeout: req.RingingTimeout, + }, nil +} diff --git a/livekit/pkg/service/turn.go b/livekit/pkg/service/turn.go new file mode 100644 index 0000000..2e26b1a --- /dev/null +++ b/livekit/pkg/service/turn.go @@ -0,0 +1,205 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package service + +import ( + "crypto/sha256" + "crypto/tls" + "fmt" + "net" + "strconv" + "strings" + + "github.com/jxskiss/base62" + "github.com/pion/turn/v4" + "github.com/pkg/errors" + + "github.com/livekit/protocol/auth" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/logger/pionlogger" + + "github.com/livekit/livekit-server/pkg/config" + "github.com/livekit/livekit-server/pkg/telemetry" + "github.com/livekit/livekit-server/pkg/telemetry/prometheus" +) + +const ( + LivekitRealm = "livekit" + + allocateRetries = 50 + turnMinPort = 1024 + turnMaxPort = 30000 +) + +func NewTurnServer(conf *config.Config, authHandler turn.AuthHandler, standalone bool) (*turn.Server, error) { + turnConf := conf.TURN + if !turnConf.Enabled { + return nil, nil + } + + if turnConf.TLSPort <= 0 && turnConf.UDPPort <= 0 { + return nil, errors.New("invalid TURN ports") + } + + serverConfig := turn.ServerConfig{ + Realm: LivekitRealm, + AuthHandler: authHandler, + LoggerFactory: pionlogger.NewLoggerFactory(logger.GetLogger()), + } + var relayAddrGen turn.RelayAddressGenerator = &turn.RelayAddressGeneratorPortRange{ + RelayAddress: net.ParseIP(conf.RTC.NodeIP), + Address: "0.0.0.0", + MinPort: turnConf.RelayPortRangeStart, + MaxPort: turnConf.RelayPortRangeEnd, + MaxRetries: allocateRetries, + } + if standalone { + relayAddrGen = telemetry.NewRelayAddressGenerator(relayAddrGen) + } + var logValues []any + + logValues = append(logValues, "turn.relay_range_start", turnConf.RelayPortRangeStart) + logValues = append(logValues, "turn.relay_range_end", turnConf.RelayPortRangeEnd) + + if turnConf.TLSPort > 0 { + if turnConf.Domain == "" { + return nil, errors.New("TURN domain required") + } + + if !IsValidDomain(turnConf.Domain) { + return nil, errors.New("TURN domain is not correct") + } + + if !turnConf.ExternalTLS { + cert, err := tls.LoadX509KeyPair(turnConf.CertFile, turnConf.KeyFile) + if err != nil { + return nil, errors.Wrap(err, "TURN tls cert required") + } + + tlsListener, err := tls.Listen("tcp4", "0.0.0.0:"+strconv.Itoa(turnConf.TLSPort), + &tls.Config{ + MinVersion: tls.VersionTLS12, + Certificates: []tls.Certificate{cert}, + }) + if err != nil { + return nil, errors.Wrap(err, "could not listen on TURN TCP port") + } + if standalone { + tlsListener = telemetry.NewListener(tlsListener) + } + + listenerConfig := turn.ListenerConfig{ + Listener: tlsListener, + RelayAddressGenerator: relayAddrGen, + } + serverConfig.ListenerConfigs = append(serverConfig.ListenerConfigs, listenerConfig) + } else { + tcpListener, err := net.Listen("tcp4", "0.0.0.0:"+strconv.Itoa(turnConf.TLSPort)) + if err != nil { + return nil, errors.Wrap(err, "could not listen on TURN TCP port") + } + if standalone { + tcpListener = telemetry.NewListener(tcpListener) + } + + listenerConfig := turn.ListenerConfig{ + Listener: tcpListener, + RelayAddressGenerator: relayAddrGen, + } + serverConfig.ListenerConfigs = append(serverConfig.ListenerConfigs, listenerConfig) + } + logValues = append(logValues, "turn.portTLS", turnConf.TLSPort, "turn.externalTLS", turnConf.ExternalTLS) + } + + if turnConf.UDPPort > 0 { + udpListener, err := net.ListenPacket("udp4", "0.0.0.0:"+strconv.Itoa(turnConf.UDPPort)) + if err != nil { + return nil, errors.Wrap(err, "could not listen on TURN UDP port") + } + + if standalone { + udpListener = telemetry.NewPacketConn(udpListener, prometheus.Incoming) + } + + packetConfig := turn.PacketConnConfig{ + PacketConn: udpListener, + RelayAddressGenerator: relayAddrGen, + } + serverConfig.PacketConnConfigs = append(serverConfig.PacketConnConfigs, packetConfig) + logValues = append(logValues, "turn.portUDP", turnConf.UDPPort) + } + + logger.Infow("Starting TURN server", logValues...) + return turn.NewServer(serverConfig) +} + +func getTURNAuthHandlerFunc(handler *TURNAuthHandler) turn.AuthHandler { + return handler.HandleAuth +} + +type TURNAuthHandler struct { + keyProvider auth.KeyProvider +} + +func NewTURNAuthHandler(keyProvider auth.KeyProvider) *TURNAuthHandler { + return &TURNAuthHandler{ + keyProvider: keyProvider, + } +} + +func (h *TURNAuthHandler) CreateUsername(apiKey string, pID livekit.ParticipantID) string { + return base62.EncodeToString([]byte(fmt.Sprintf("%s|%s", apiKey, pID))) +} + +func (h *TURNAuthHandler) ParseUsername(username string) (apiKey string, pID livekit.ParticipantID, err error) { + decoded, err := base62.DecodeString(username) + if err != nil { + return "", "", err + } + parts := strings.Split(string(decoded), "|") + if len(parts) != 2 { + return "", "", errors.New("invalid username") + } + + return parts[0], livekit.ParticipantID(parts[1]), nil +} + +func (h *TURNAuthHandler) CreatePassword(apiKey string, pID livekit.ParticipantID) (string, error) { + secret := h.keyProvider.GetSecret(apiKey) + if secret == "" { + return "", ErrInvalidAPIKey + } + keyInput := fmt.Sprintf("%s|%s", secret, pID) + sum := sha256.Sum256([]byte(keyInput)) + return base62.EncodeToString(sum[:]), nil +} + +func (h *TURNAuthHandler) HandleAuth(username, realm string, srcAddr net.Addr) (key []byte, ok bool) { + decoded, err := base62.DecodeString(username) + if err != nil { + return nil, false + } + parts := strings.Split(string(decoded), "|") + if len(parts) != 2 { + return nil, false + } + password, err := h.CreatePassword(parts[0], livekit.ParticipantID(parts[1])) + if err != nil { + logger.Warnw("could not create TURN password", err, "username", username) + return nil, false + } + return turn.GenerateAuthKey(username, LivekitRealm, password), true +} diff --git a/livekit/pkg/service/twirp.go b/livekit/pkg/service/twirp.go new file mode 100644 index 0000000..e41d13c --- /dev/null +++ b/livekit/pkg/service/twirp.go @@ -0,0 +1,431 @@ +/* + * Copyright 2022 LiveKit, Inc + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package service + +import ( + "context" + "strconv" + "sync" + "time" + + "github.com/twitchtv/twirp" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/timestamppb" + + "github.com/livekit/livekit-server/pkg/telemetry" + "github.com/livekit/livekit-server/pkg/telemetry/prometheus" + "github.com/livekit/livekit-server/pkg/utils" + "github.com/livekit/protocol/livekit" +) + +type twirpRequestFields struct { + service string + method string + error twirp.Error +} + +// -------------------------------------------------------------------------- + +type twirpLoggerKey struct{} + +// logging handling inspired by https://github.com/bakins/twirpzap +// License: Apache-2.0 +func TwirpLogger() *twirp.ServerHooks { + loggerPool := &sync.Pool{ + New: func() any { + return &twirpLogger{ + fieldsOrig: make([]any, 0, 30), + } + }, + } + return &twirp.ServerHooks{ + RequestReceived: func(ctx context.Context) (context.Context, error) { + return loggerRequestReceived(ctx, loggerPool) + }, + RequestRouted: loggerRequestRouted, + Error: loggerErrorReceived, + ResponseSent: func(ctx context.Context) { + loggerResponseSent(ctx, loggerPool) + }, + } +} + +type twirpLogger struct { + twirpRequestFields + + fieldsOrig []any + fields []any + startedAt time.Time + deadline time.Time +} + +func (t *twirpLogger) reset() { + t.fields = t.fieldsOrig + t.error = nil + t.startedAt = time.Time{} + t.deadline = time.Time{} +} + +func AppendLogFields(ctx context.Context, fields ...any) { + r, ok := ctx.Value(twirpLoggerKey{}).(*twirpLogger) + if !ok || r == nil { + return + } + + r.fields = append(r.fields, fields...) +} + +func loggerRequestReceived(ctx context.Context, twirpLoggerPool *sync.Pool) (context.Context, error) { + r := twirpLoggerPool.Get().(*twirpLogger) + r.startedAt = time.Now() + if deadline, ok := ctx.Deadline(); ok { + r.deadline = deadline + } + r.fields = r.fieldsOrig + r.error = nil + + if svc, ok := twirp.ServiceName(ctx); ok { + r.service = svc + r.fields = append(r.fields, "service", svc) + } + + return context.WithValue(ctx, twirpLoggerKey{}, r), nil +} + +func loggerRequestRouted(ctx context.Context) (context.Context, error) { + if meth, ok := twirp.MethodName(ctx); ok { + l, ok := ctx.Value(twirpLoggerKey{}).(*twirpLogger) + if !ok || l == nil { + return ctx, nil + } + l.method = meth + l.fields = append(l.fields, "method", meth) + } + + return ctx, nil +} + +func loggerResponseSent(ctx context.Context, twirpLoggerPool *sync.Pool) { + r, ok := ctx.Value(twirpLoggerKey{}).(*twirpLogger) + if !ok || r == nil { + return + } + + r.fields = append(r.fields, "duration", time.Since(r.startedAt)) + if !r.deadline.IsZero() { + r.fields = append(r.fields, "requestedTimeout", r.deadline.Sub(r.startedAt)) + } + if deadline, ok := ctx.Deadline(); ok { + r.fields = append(r.fields, "modifiedTimeout", deadline.Sub(r.startedAt)) + } + + if status, ok := twirp.StatusCode(ctx); ok { + r.fields = append(r.fields, "status", status) + } + if r.error != nil { + r.fields = append(r.fields, "error", r.error.Msg()) + r.fields = append(r.fields, "code", r.error.Code()) + } + + serviceMethod := "API " + r.service + "." + r.method + utils.GetLogger(ctx).WithComponent(utils.ComponentAPI).Infow(serviceMethod, r.fields...) + + // reset fields and return to pool + r.reset() + twirpLoggerPool.Put(r) +} + +func loggerErrorReceived(ctx context.Context, e twirp.Error) context.Context { + r, ok := ctx.Value(twirpLoggerKey{}).(*twirpLogger) + if !ok || r == nil { + return ctx + } + + r.error = e + return ctx +} + +// -------------------------------------------------------------------------- + +type statusReporterKey struct{} + +func TwirpRequestStatusReporter() *twirp.ServerHooks { + return &twirp.ServerHooks{ + RequestReceived: statusReporterRequestReceived, + RequestRouted: statusReporterRequestRouted, + Error: statusReporterErrorReceived, + ResponseSent: statusReporterResponseSent, + } +} + +func statusReporterRequestReceived(ctx context.Context) (context.Context, error) { + r := &twirpRequestFields{} + + if svc, ok := twirp.ServiceName(ctx); ok { + r.service = svc + } + + return context.WithValue(ctx, statusReporterKey{}, r), nil +} + +func statusReporterRequestRouted(ctx context.Context) (context.Context, error) { + if meth, ok := twirp.MethodName(ctx); ok { + l, ok := ctx.Value(statusReporterKey{}).(*twirpRequestFields) + if !ok || l == nil { + return ctx, nil + } + l.method = meth + } + + return ctx, nil +} + +func statusReporterResponseSent(ctx context.Context) { + r, ok := ctx.Value(statusReporterKey{}).(*twirpRequestFields) + if !ok || r == nil { + return + } + + var statusFamily string + if statusCode, ok := twirp.StatusCode(ctx); ok { + if status, err := strconv.Atoi(statusCode); err == nil { + switch { + case status >= 400 && status <= 499: + statusFamily = "4xx" + case status >= 500 && status <= 599: + statusFamily = "5xx" + default: + statusFamily = statusCode + } + } + } + + var code twirp.ErrorCode + if r.error != nil { + code = r.error.Code() + } + + prometheus.RecordTwirpRequestStatus(r.service, r.method, statusFamily, code) +} + +func statusReporterErrorReceived(ctx context.Context, e twirp.Error) context.Context { + r, ok := ctx.Value(statusReporterKey{}).(*twirpRequestFields) + if !ok || r == nil { + return ctx + } + + r.error = e + return ctx +} + +// -------------------------------------------------------------------------- + +type twirpTelemetryKey struct{} + +func TwirpTelemetry( + nodeID livekit.NodeID, + getProjectID func(ctx context.Context) string, + telemetry telemetry.TelemetryService, +) *twirp.ServerHooks { + return &twirp.ServerHooks{ + RequestReceived: telemetryRequestReceived, + Error: telemetryErrorReceived, + ResponseSent: func(ctx context.Context) { + telemetryResponseSent(ctx, nodeID, getProjectID, telemetry) + }, + RequestRouted: telemetryRequestRouted, + } +} + +func RecordRequest(ctx context.Context, request proto.Message) { + if request == nil { + return + } + + a, ok := ctx.Value(twirpTelemetryKey{}).(*livekit.APICallInfo) + if !ok || a == nil { + return + } + + // capture request and extract common fields to top level as appropriate + switch msg := request.(type) { + case *livekit.CreateRoomRequest: + a.Request = &livekit.APICallRequest{ + Message: &livekit.APICallRequest_CreateRoomRequest{ + CreateRoomRequest: msg, + }, + } + a.RoomName = msg.GetName() + + case *livekit.ListRoomsRequest: + a.Request = &livekit.APICallRequest{ + Message: &livekit.APICallRequest_ListRoomsRequest{ + ListRoomsRequest: msg, + }, + } + + case *livekit.DeleteRoomRequest: + a.Request = &livekit.APICallRequest{ + Message: &livekit.APICallRequest_DeleteRoomRequest{ + DeleteRoomRequest: msg, + }, + } + a.RoomName = msg.GetRoom() + + case *livekit.ListParticipantsRequest: + a.Request = &livekit.APICallRequest{ + Message: &livekit.APICallRequest_ListParticipantsRequest{ + ListParticipantsRequest: msg, + }, + } + a.RoomName = msg.GetRoom() + + case *livekit.RoomParticipantIdentity: + a.Request = &livekit.APICallRequest{ + Message: &livekit.APICallRequest_RoomParticipantIdentity{ + RoomParticipantIdentity: msg, + }, + } + a.RoomName = msg.GetRoom() + a.ParticipantIdentity = msg.GetIdentity() + + case *livekit.MuteRoomTrackRequest: + a.Request = &livekit.APICallRequest{ + Message: &livekit.APICallRequest_MuteRoomTrackRequest{ + MuteRoomTrackRequest: msg, + }, + } + a.RoomName = msg.GetRoom() + a.ParticipantIdentity = msg.GetIdentity() + a.TrackId = msg.GetTrackSid() + + case *livekit.UpdateParticipantRequest: + a.Request = &livekit.APICallRequest{ + Message: &livekit.APICallRequest_UpdateParticipantRequest{ + UpdateParticipantRequest: msg, + }, + } + a.RoomName = msg.GetRoom() + a.ParticipantIdentity = msg.GetIdentity() + + case *livekit.UpdateSubscriptionsRequest: + a.Request = &livekit.APICallRequest{ + Message: &livekit.APICallRequest_UpdateSubscriptionsRequest{ + UpdateSubscriptionsRequest: msg, + }, + } + a.RoomName = msg.GetRoom() + a.ParticipantIdentity = msg.GetIdentity() + + case *livekit.SendDataRequest: + a.Request = &livekit.APICallRequest{ + Message: &livekit.APICallRequest_SendDataRequest{ + SendDataRequest: msg, + }, + } + a.RoomName = msg.GetRoom() + + case *livekit.UpdateRoomMetadataRequest: + a.Request = &livekit.APICallRequest{ + Message: &livekit.APICallRequest_UpdateRoomMetadataRequest{ + UpdateRoomMetadataRequest: msg, + }, + } + } +} + +func RecordResponse(ctx context.Context, response proto.Message) { + if response == nil { + return + } + + a, ok := ctx.Value(twirpTelemetryKey{}).(*livekit.APICallInfo) + if !ok || a == nil { + return + } + + // extract common fields to top level as appropriate + switch msg := response.(type) { + case *livekit.Room: + a.RoomId = msg.GetSid() + + case *livekit.ParticipantInfo: + a.ParticipantId = msg.GetSid() + } +} + +func telemetryRequestReceived(ctx context.Context) (context.Context, error) { + a := &livekit.APICallInfo{} + a.StartedAt = timestamppb.Now() + + if svc, ok := twirp.ServiceName(ctx); ok { + a.Service = svc + } + + return context.WithValue(ctx, twirpTelemetryKey{}, a), nil +} + +func telemetryRequestRouted(ctx context.Context) (context.Context, error) { + if meth, ok := twirp.MethodName(ctx); ok { + a, ok := ctx.Value(twirpTelemetryKey{}).(*livekit.APICallInfo) + if !ok || a == nil { + return ctx, nil + } + a.Method = meth + } + + return ctx, nil +} + +func telemetryResponseSent( + ctx context.Context, + nodeID livekit.NodeID, + getProjectID func(ctx context.Context) string, + telemetry telemetry.TelemetryService, +) { + a, ok := ctx.Value(twirpTelemetryKey{}).(*livekit.APICallInfo) + if !ok || a == nil { + return + } + + if getProjectID != nil { + a.ProjectId = getProjectID(ctx) + } + a.NodeId = string(nodeID) + if statusCode, ok := twirp.StatusCode(ctx); ok { + if status, err := strconv.Atoi(statusCode); err == nil { + a.Status = int32(status) + } + } + a.DurationNs = time.Since(a.StartedAt.AsTime()).Nanoseconds() + if telemetry != nil { + telemetry.APICall(ctx, a) + } +} + +func telemetryErrorReceived(ctx context.Context, e twirp.Error) context.Context { + a, ok := ctx.Value(twirpTelemetryKey{}).(*livekit.APICallInfo) + if !ok || a == nil { + return ctx + } + + a.TwirpErrorCode = string(e.Code()) + a.TwirpErrorMessage = e.Msg() + return ctx +} + +// -------------------------------------------------------------------------- diff --git a/livekit/pkg/service/twirp_test.go b/livekit/pkg/service/twirp_test.go new file mode 100644 index 0000000..103baad --- /dev/null +++ b/livekit/pkg/service/twirp_test.go @@ -0,0 +1,34 @@ +/* + * Copyright 2024 LiveKit, Inc + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package service + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/require" + "github.com/twitchtv/twirp" +) + +func TestConvertErrToTwirp(t *testing.T) { + t.Run("handles not found", func(t *testing.T) { + err := ErrRoomNotFound + var tErr twirp.Error + require.True(t, errors.As(err, &tErr)) + require.Equal(t, twirp.NotFound, tErr.Code()) + }) +} diff --git a/livekit/pkg/service/utils.go b/livekit/pkg/service/utils.go new file mode 100644 index 0000000..a54625a --- /dev/null +++ b/livekit/pkg/service/utils.go @@ -0,0 +1,355 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package service + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net" + "net/http" + "regexp" + "strconv" + "strings" + "sync" + + "github.com/ua-parser/uap-go/uaparser" + "gopkg.in/yaml.v3" + + "github.com/livekit/livekit-server/pkg/config" + "github.com/livekit/livekit-server/pkg/routing" + "github.com/livekit/livekit-server/pkg/routing/selector" + "github.com/livekit/livekit-server/pkg/rtc" + "github.com/livekit/livekit-server/pkg/utils" + "github.com/livekit/protocol/auth" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" +) + +func handleError(w http.ResponseWriter, r *http.Request, status int, err error, keysAndValues ...any) { + keysAndValues = append(keysAndValues, "status", status) + if r != nil && r.URL != nil { + keysAndValues = append(keysAndValues, "method", r.Method, "path", r.URL.Path) + } + if !errors.Is(err, context.Canceled) && !errors.Is(r.Context().Err(), context.Canceled) { + utils.GetLogger(r.Context()).WithCallDepth(1).Warnw("error handling request", err, keysAndValues...) + } + w.WriteHeader(status) +} + +func HandleError(w http.ResponseWriter, r *http.Request, status int, err error, keysAndValues ...any) { + handleError(w, r, status, err, keysAndValues...) + _, _ = w.Write([]byte(err.Error())) +} + +func HandleErrorJson(w http.ResponseWriter, r *http.Request, status int, err error, keysAndValues ...any) { + handleError(w, r, status, err, keysAndValues...) + json.NewEncoder(w).Encode(struct { + Error string `json:"error"` + }{ + Error: err.Error(), + }) + w.Header().Add("Content-type", "application/json") +} + +func boolValue(s string) bool { + return s == "1" || s == "true" +} + +func RemoveDoubleSlashes(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) { + if strings.HasPrefix(r.URL.Path, "//") { + r.URL.Path = r.URL.Path[1:] + } + next(w, r) +} + +func IsValidDomain(domain string) bool { + domainRegexp := regexp.MustCompile(`^(?i)[a-z0-9-]+(\.[a-z0-9-]+)+\.?$`) + return domainRegexp.MatchString(domain) +} + +func GetClientIP(r *http.Request) string { + // CF proxy typically is first thing the user reaches + if ip := r.Header.Get("CF-Connecting-IP"); ip != "" { + return ip + } + if ip := r.Header.Get("X-Forwarded-For"); ip != "" { + return ip + } + if ip := r.Header.Get("X-Real-IP"); ip != "" { + return ip + } + ip, _, _ := net.SplitHostPort(r.RemoteAddr) + return ip +} + +func SetRoomConfiguration(createRequest *livekit.CreateRoomRequest, conf *livekit.RoomConfiguration) { + if conf == nil { + return + } + createRequest.Agents = conf.Agents + createRequest.Egress = conf.Egress + createRequest.EmptyTimeout = conf.EmptyTimeout + createRequest.DepartureTimeout = conf.DepartureTimeout + createRequest.MaxParticipants = conf.MaxParticipants + createRequest.MinPlayoutDelay = conf.MinPlayoutDelay + createRequest.MaxPlayoutDelay = conf.MaxPlayoutDelay + createRequest.SyncStreams = conf.SyncStreams + createRequest.Metadata = conf.Metadata +} + +func ParseClientInfo(r *http.Request) *livekit.ClientInfo { + values := r.Form + ci := &livekit.ClientInfo{} + if pv, err := strconv.Atoi(values.Get("protocol")); err == nil { + ci.Protocol = int32(pv) + } + sdkString := values.Get("sdk") + switch sdkString { + case "js": + ci.Sdk = livekit.ClientInfo_JS + case "ios", "swift": + ci.Sdk = livekit.ClientInfo_SWIFT + case "android": + ci.Sdk = livekit.ClientInfo_ANDROID + case "flutter": + ci.Sdk = livekit.ClientInfo_FLUTTER + case "go": + ci.Sdk = livekit.ClientInfo_GO + case "unity": + ci.Sdk = livekit.ClientInfo_UNITY + case "reactnative": + ci.Sdk = livekit.ClientInfo_REACT_NATIVE + case "rust": + ci.Sdk = livekit.ClientInfo_RUST + case "python": + ci.Sdk = livekit.ClientInfo_PYTHON + case "cpp": + ci.Sdk = livekit.ClientInfo_CPP + case "unityweb": + ci.Sdk = livekit.ClientInfo_UNITY_WEB + case "node": + ci.Sdk = livekit.ClientInfo_NODE + } + + ci.Version = values.Get("version") + ci.Os = values.Get("os") + ci.OsVersion = values.Get("os_version") + ci.Browser = values.Get("browser") + ci.BrowserVersion = values.Get("browser_version") + ci.DeviceModel = values.Get("device_model") + ci.Network = values.Get("network") + + AugmentClientInfo(ci, r) + + return ci +} + +var ( + userAgentParserCache *uaparser.Parser + userAgentParserInit sync.Once +) + +func createUserAgentParserWithCustomRules() (*uaparser.Parser, error) { + defaultYaml := uaparser.DefinitionYaml + + rules := make(map[string]any) + err := yaml.Unmarshal(defaultYaml, rules) + if err != nil { + return nil, err + } + + rules["user_agent_parsers"] = append(rules["user_agent_parsers"].([]any), map[string]any{ + "regex": "OBS-Studio\\/([0-9\\.]+)", + "family_replacement": "OBS Studio", + "v1_replacement": "$1", + }) + + customYaml, err := yaml.Marshal(rules) + if err != nil { + return nil, err + } + + return uaparser.NewFromBytes([]byte(customYaml)) +} + +func getUserAgentParser() *uaparser.Parser { + userAgentParserInit.Do(func() { + if parser, err := createUserAgentParserWithCustomRules(); err != nil { + logger.Warnw("could not create user agent parser with custom rules, using default", err) + userAgentParserCache = uaparser.NewFromSaved() + } else { + userAgentParserCache = parser + } + }) + return userAgentParserCache +} + +func AugmentClientInfo(ci *livekit.ClientInfo, req *http.Request) { + // get real address (forwarded http header) - check Cloudflare headers first, fall back to X-Forwarded-For + ci.Address = GetClientIP(req) + + // attempt to parse types for SDKs that support browser as a platform + if ci.Sdk == livekit.ClientInfo_JS || + ci.Sdk == livekit.ClientInfo_REACT_NATIVE || + ci.Sdk == livekit.ClientInfo_FLUTTER || + ci.Sdk == livekit.ClientInfo_UNITY || + ci.Sdk == livekit.ClientInfo_UNKNOWN { + client := getUserAgentParser().Parse(req.UserAgent()) + if ci.Browser == "" { + ci.Browser = client.UserAgent.Family + ci.BrowserVersion = client.UserAgent.ToVersionString() + } + if ci.Os == "" { + ci.Os = client.Os.Family + ci.OsVersion = client.Os.ToVersionString() + } + if ci.DeviceModel == "" { + model := client.Device.Family + if model != "" && client.Device.Model != "" && model != client.Device.Model { + model += " " + client.Device.Model + } + + ci.DeviceModel = model + } + } +} + +type ValidateConnectRequestParams struct { + roomName livekit.RoomName + publish string + metadata string + attributes map[string]string +} + +type ValidateConnectRequestResult struct { + roomName livekit.RoomName + grants *auth.ClaimGrants + region string + createRoomRequest *livekit.CreateRoomRequest +} + +func ValidateConnectRequest( + lgr logger.Logger, + r *http.Request, + limitConfig config.LimitConfig, + params ValidateConnectRequestParams, + router routing.MessageRouter, + roomAllocator RoomAllocator, +) (ValidateConnectRequestResult, int, error) { + var res ValidateConnectRequestResult + + // require a claim + claims := GetGrants(r.Context()) + if claims == nil || claims.Video == nil { + return res, http.StatusUnauthorized, rtc.ErrPermissionDenied + } + + roomNameInToken, err := EnsureJoinPermission(r.Context()) + if err != nil { + return res, http.StatusUnauthorized, err + } + + if claims.Identity == "" { + return res, http.StatusBadRequest, ErrIdentityEmpty + } + if !limitConfig.CheckParticipantIdentityLength(claims.Identity) { + return res, http.StatusBadRequest, fmt.Errorf("%w: max length %d", ErrParticipantIdentityExceedsLimits, limitConfig.MaxParticipantIdentityLength) + } + + if claims.RoomConfig != nil { + if err := claims.RoomConfig.CheckCredentials(); err != nil { + lgr.Warnw("credentials found in token", nil) + // TODO(dz): in a future version, we'll reject these connections + } + } + + res.roomName = params.roomName + if roomNameInToken != "" { + res.roomName = roomNameInToken + } + if res.roomName == "" { + return res, http.StatusBadRequest, ErrNoRoomName + } + if !limitConfig.CheckRoomNameLength(string(res.roomName)) { + return res, http.StatusBadRequest, fmt.Errorf("%w: max length %d", ErrRoomNameExceedsLimits, limitConfig.MaxRoomNameLength) + } + + // this is new connection for existing participant - with publish only permissions + if params.publish != "" { + // Make sure grant has GetCanPublish set, + if !claims.Video.GetCanPublish() { + return res, http.StatusUnauthorized, rtc.ErrPermissionDenied + } + // Make sure by default subscribe is off + claims.Video.SetCanSubscribe(false) + claims.Identity += "#" + params.publish + } + + // room allocator validations + err = roomAllocator.ValidateCreateRoom(r.Context(), res.roomName) + if err != nil { + if errors.Is(err, ErrRoomNotFound) { + return res, http.StatusNotFound, err + } else { + return res, http.StatusInternalServerError, err + } + } + + if router, ok := router.(routing.Router); ok { + res.region = router.GetRegion() + if foundNode, err := router.GetNodeForRoom(r.Context(), res.roomName); err == nil { + if selector.LimitsReached(limitConfig, foundNode.Stats) { + return res, http.StatusServiceUnavailable, rtc.ErrLimitExceeded + } + } + } + + createRequest := &livekit.CreateRoomRequest{ + Name: string(res.roomName), + RoomPreset: claims.RoomPreset, + } + SetRoomConfiguration(createRequest, claims.GetRoomConfiguration()) + res.createRoomRequest = createRequest + + if len(params.metadata) != 0 { + // Make sure grant has GetCanUpdateOwnMetadata set + if !claims.Video.GetCanUpdateOwnMetadata() { + return res, http.StatusUnauthorized, rtc.ErrPermissionDenied + } + claims.Metadata = params.metadata + } + + // Add extra attributes to the participant + if len(params.attributes) != 0 { + // Make sure grant has GetCanUpdateOwnMetadata set + if !claims.Video.GetCanUpdateOwnMetadata() { + return res, http.StatusUnauthorized, rtc.ErrPermissionDenied + } + if claims.Attributes == nil { + claims.Attributes = make(map[string]string, len(params.attributes)) + } + for k, v := range params.attributes { + if v == "" { + continue // do not allow deleting existing attributes + } + claims.Attributes[k] = v + } + } + + res.grants = claims + return res, http.StatusOK, nil +} diff --git a/livekit/pkg/service/utils_test.go b/livekit/pkg/service/utils_test.go new file mode 100644 index 0000000..5d8cdca --- /dev/null +++ b/livekit/pkg/service/utils_test.go @@ -0,0 +1,77 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package service_test + +import ( + "context" + "testing" + "time" + + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/require" + + "github.com/livekit/livekit-server/pkg/service" +) + +func redisClientDocker(t testing.TB) *redis.Client { + addr := runRedis(t) + cli := redis.NewClient(&redis.Options{ + Addr: addr, + }) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if err := cli.Ping(ctx).Err(); err != nil { + _ = cli.Close() + t.Fatal(err) + } + t.Cleanup(func() { + _ = cli.Close() + }) + return cli +} + +func redisClient(t testing.TB) *redis.Client { + cli := redis.NewClient(&redis.Options{ + Addr: "localhost:6379", + }) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + err := cli.Ping(ctx).Err() + if err == nil { + t.Cleanup(func() { + _ = cli.Close() + }) + return cli + } + _ = cli.Close() + t.Logf("local redis not available: %v", err) + + t.Logf("starting redis in docker") + return redisClientDocker(t) +} + +func TestIsValidDomain(t *testing.T) { + list := map[string]bool{ + "turn.myhost.com": true, + "turn.google.com": true, + "https://host.com": false, + "turn://host.com": false, + } + for key, result := range list { + service.IsValidDomain(key) + require.Equal(t, service.IsValidDomain(key), result) + } +} diff --git a/livekit/pkg/service/whipservice.go b/livekit/pkg/service/whipservice.go new file mode 100644 index 0000000..65b607c --- /dev/null +++ b/livekit/pkg/service/whipservice.go @@ -0,0 +1,535 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package service + +import ( + "encoding/json" + "errors" + "fmt" + "io/ioutil" + "net/http" + "net/url" + "strings" + + "github.com/pion/webrtc/v4" + "github.com/tomnomnom/linkheader" + + "github.com/livekit/livekit-server/pkg/config" + "github.com/livekit/livekit-server/pkg/routing" + "github.com/livekit/livekit-server/pkg/rtc" + "github.com/livekit/livekit-server/pkg/rtc/types" + sutils "github.com/livekit/livekit-server/pkg/utils" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/rpc" + "github.com/livekit/protocol/utils/guid" + "github.com/livekit/psrpc" +) + +const ( + cParticipantPath = "/whip/v1" + cParticipantIDPath = "/whip/v1/{participant_id}" +) + +type WHIPService struct { + http.Handler + + config *config.Config + router routing.Router + roomAllocator RoomAllocator + client rpc.WHIPClient[livekit.NodeID] + topicFormatter rpc.TopicFormatter + participantClient rpc.TypedWHIPParticipantClient +} + +func NewWHIPService( + config *config.Config, + router routing.Router, + roomAllocator RoomAllocator, + clientParams rpc.ClientParams, + topicFormatter rpc.TopicFormatter, + participantClient rpc.TypedWHIPParticipantClient, +) (*WHIPService, error) { + client, err := rpc.NewWHIPClient[livekit.NodeID](clientParams.Args()) + if err != nil { + return nil, err + } + + return &WHIPService{ + config: config, + router: router, + roomAllocator: roomAllocator, + client: client, + topicFormatter: topicFormatter, + participantClient: participantClient, + }, nil +} + +func (s *WHIPService) SetupRoutes(mux *http.ServeMux) { + mux.HandleFunc("GET "+cParticipantPath, s.handleGet) + mux.HandleFunc("OPTIONS "+cParticipantPath, s.handleOptions) + mux.HandleFunc("POST "+cParticipantPath, s.handleCreate) + mux.HandleFunc("GET "+cParticipantIDPath, s.handleParticipantGet) + mux.HandleFunc("PATCH "+cParticipantIDPath, s.handleParticipantPatch) + mux.HandleFunc("DELETE "+cParticipantIDPath, s.handleParticipantDelete) +} + +func (s *WHIPService) handleGet(w http.ResponseWriter, r *http.Request) { + // https:/www.rfc-editor.org/rfc/rfc9725.html#name-http-usage + w.WriteHeader(http.StatusNoContent) +} + +func (s *WHIPService) handleOptions(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Headers", "*") + w.Header().Set("Access-Control-Allow-Methods", "PATCH, OPTIONS, GET, POST, DELETE") + w.Header().Set("Access-Control-Expose-Headers", "*") + + w.WriteHeader(http.StatusOK) + + // According to https://www.rfc-editor.org/rfc/rfc9725.html#name-stun-turn-server-configurat, + // ICE servers can be returned in OPTIONS response, but not recommended. + // + // Supporting that here is tricky. This would have to get region settings like the + // session CREATE POST request and send a request to get ICE servers from a + // region + media node that is selected. The issue is that a subsequent POST, + // although unlikely, may end up in a different region. Media node in one region and + // TURN in another region, although shuttling media across regions, should still work. + // But, as this is not a recommended way, not supporting it. +} + +type createRequest struct { + RoomName livekit.RoomName + ParticipantInit routing.ParticipantInit + ClientIP string + OfferSDP string + SubscribedParticipantTrackNames map[string][]string + FromIngress bool +} + +func (s *WHIPService) validateCreate(r *http.Request) (*createRequest, int, error) { + claims := GetGrants(r.Context()) + if claims == nil || claims.Video == nil { + return nil, http.StatusUnauthorized, rtc.ErrPermissionDenied + } + + roomName, err := EnsureJoinPermission(r.Context()) + if err != nil { + return nil, http.StatusUnauthorized, err + } + if roomName == "" { + return nil, http.StatusUnauthorized, errors.New("room name cannot be empty") + } + if !s.config.Limit.CheckRoomNameLength(string(roomName)) { + return nil, http.StatusBadRequest, fmt.Errorf("%w: max length %d", ErrRoomNameExceedsLimits, s.config.Limit.MaxRoomNameLength) + } + + if claims.Identity == "" { + return nil, http.StatusBadRequest, ErrIdentityEmpty + } + if !s.config.Limit.CheckParticipantIdentityLength(claims.Identity) { + return nil, http.StatusBadRequest, fmt.Errorf("%w: max length %d", ErrParticipantIdentityExceedsLimits, s.config.Limit.MaxParticipantIdentityLength) + } + + var clientInfo struct { + ClientIP string `json:"clientIp"` + SubscribedParticipantTrackNames map[string][]string `json:"subscribedParticipantTrackNames"` + } + clientInfoHeader := r.Header.Get("X-LiveKit-ClientInfo") + if clientInfoHeader != "" { + if err := json.NewDecoder(strings.NewReader(clientInfoHeader)).Decode(&clientInfo); err != nil { + return nil, http.StatusBadRequest, fmt.Errorf("malformed json in client info header: %s", err) + } + } + + fromIngress := r.Header.Get("X-Livekit-Ingress") + + offerSDPBytes, err := ioutil.ReadAll(r.Body) + if err != nil { + return nil, http.StatusBadRequest, fmt.Errorf("body does not have SDP offer: %s", err) + } + if len(offerSDPBytes) == 0 { + return nil, http.StatusBadRequest, errors.New("body does not have SDP offer") + } + offerSDP := string(offerSDPBytes) + sd := &webrtc.SessionDescription{ + Type: webrtc.SDPTypeOffer, + SDP: offerSDP, + } + _, err = sd.Unmarshal() + if err != nil { + return nil, http.StatusBadRequest, fmt.Errorf("malformed SDP offer: %s", err) + } + + ci := ParseClientInfo(r) + if ci.Protocol == 0 { + // if no client info available (which will be mostly the case with WHIP clients), at least set protocol + ci.Protocol = types.CurrentProtocol + } + + pi := routing.ParticipantInit{ + Identity: livekit.ParticipantIdentity(claims.Identity), + Name: livekit.ParticipantName(claims.Name), + AutoSubscribe: true, + Client: ci, + Grants: claims, + CreateRoom: &livekit.CreateRoomRequest{ + Name: string(roomName), + RoomPreset: claims.RoomPreset, + }, + AdaptiveStream: false, + DisableICELite: true, + } + SetRoomConfiguration(pi.CreateRoom, claims.GetRoomConfiguration()) + + return &createRequest{ + roomName, + pi, + clientInfo.ClientIP, + offerSDP, + clientInfo.SubscribedParticipantTrackNames, + fromIngress != "", + }, http.StatusOK, nil +} + +func (s *WHIPService) handleCreate(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Content-type") != "application/sdp" { + s.handleError("Create", w, r, http.StatusBadRequest, fmt.Errorf("unsupported content-type: %s", r.Header.Get("Content-type"))) + return + } + + w.Header().Add("Content-type", "application/sdp") + + req, status, err := s.validateCreate(r) + if err != nil { + s.handleError("Create", w, r, status, err) + return + } + + if err := s.roomAllocator.SelectRoomNode(r.Context(), req.RoomName, ""); err != nil { + s.handleError("Create", w, r, http.StatusInternalServerError, err) + return + } + + rtcNode, err := s.router.GetNodeForRoom(r.Context(), req.RoomName) + if err != nil { + s.handleError("Create", w, r, http.StatusInternalServerError, err) + return + } + + connID := livekit.ConnectionID(guid.New("CO_")) + starSession, err := req.ParticipantInit.ToStartSession(req.RoomName, connID) + if err != nil { + s.handleError("Create", w, r, http.StatusInternalServerError, err) + return + } + + subscribedParticipantTracks := map[string]*rpc.WHIPCreateRequest_TrackList{} + for identity, trackNames := range req.SubscribedParticipantTrackNames { + subscribedParticipantTracks[identity] = &rpc.WHIPCreateRequest_TrackList{ + TrackNames: trackNames, + } + } + + res, err := s.client.Create(r.Context(), livekit.NodeID(rtcNode.Id), &rpc.WHIPCreateRequest{ + OfferSdp: req.OfferSDP, + StartSession: starSession, + SubscribedParticipantTracks: subscribedParticipantTracks, + FromIngress: req.FromIngress, + }) + if err != nil { + s.handleError("Create", w, r, http.StatusServiceUnavailable, err) + return + } + + // created resource sent in Location header: + // https://www.rfc-editor.org/rfc/rfc9725.html#name-ingest-session-setup + // using relative location + w.Header().Add("Location", fmt.Sprintf("%s/%s", cParticipantPath, res.ParticipantId)) + + // ICE servers as Link header(s): + // https://www.rfc-editor.org/rfc/rfc9725.html#name-stun-turn-server-configurat + var iceServerLinks []*linkheader.Link + for _, iceServer := range res.IceServers { + for _, iceURL := range iceServer.Urls { + iceServerLink := &linkheader.Link{ + URL: url.PathEscape(iceURL), + Rel: "ice-server", + Params: map[string]string{}, + } + if iceServer.Username != "" { + iceServerLink.Params["username"] = iceServer.Username + } + if iceServer.Credential != "" { + iceServerLink.Params["credential"] = iceServer.Credential + } + + iceServerLinks = append(iceServerLinks, iceServerLink) + } + } + for _, iceServerLink := range iceServerLinks { + w.Header().Add("Link", iceServerLink.String()) + } + + // To support ICE Trickle/Restart, HTTP PATCH should have an ETag + // send ICE session ID (ICE ufrag is used as ID) in ETag header + // https://www.rfc-editor.org/rfc/rfc9725.html#name-http-patch-request-usage + if res.IceSessionId != "" { + w.Header().Add("ETag", res.IceSessionId) + } + + // 201 Status Created + w.WriteHeader(http.StatusCreated) + + // SDP answer in the response body + w.Write([]byte(res.AnswerSdp)) + + sutils.GetLogger(r.Context()).Infow( + "API WHIP.Create", + "connID", connID, + "participant", req.ParticipantInit.Identity, + "room", req.RoomName, + "status", http.StatusCreated, + "response", logger.Proto(res), + ) +} + +func (s *WHIPService) handleParticipantGet(w http.ResponseWriter, r *http.Request) { + // https:/www.rfc-editor.org/rfc/rfc9725.html#name-http-usage + w.WriteHeader(http.StatusNoContent) +} + +func (s *WHIPService) iceTrickle( + w http.ResponseWriter, + r *http.Request, + roomName livekit.RoomName, + participantIdentity livekit.ParticipantIdentity, + pID livekit.ParticipantID, + iceSessionID string, + sdpFragment string, +) { + _, err := s.participantClient.ICETrickle( + r.Context(), + s.topicFormatter.ParticipantTopic(r.Context(), roomName, participantIdentity), + &rpc.WHIPParticipantICETrickleRequest{ + Room: string(roomName), + ParticipantIdentity: string(participantIdentity), + ParticipantId: string(pID), + IceSessionId: iceSessionID, + SdpFragment: sdpFragment, + }, + ) + if err != nil { + var pe psrpc.Error + if errors.As(err, &pe) { + switch pe.Code() { + case psrpc.NotFound: + s.handleError("Patch", w, r, http.StatusNotFound, errors.New(pe.Error())) + + case psrpc.InvalidArgument: + switch pe.Error() { + case rtc.ErrInvalidSDPFragment.Error(), rtc.ErrMidMismatch.Error(), rtc.ErrICECredentialMismatch.Error(): + s.handleError("Patch", w, r, http.StatusBadRequest, errors.New(pe.Error())) + default: + s.handleError("Patch", w, r, http.StatusInternalServerError, errors.New(pe.Error())) + } + default: + s.handleError("Patch", w, r, http.StatusInternalServerError, errors.New(pe.Error())) + } + } else { + s.handleError("Patch", w, r, http.StatusInternalServerError, nil) + } + return + } + sutils.GetLogger(r.Context()).Infow( + "API WHIP.Patch", + "method", "ice-trickle", + "room", roomName, + "participant", participantIdentity, + "pID", pID, + "sdpFragment", sdpFragment, + "status", http.StatusNoContent, + ) + w.WriteHeader(http.StatusNoContent) +} + +func (s *WHIPService) iceRestart( + w http.ResponseWriter, + r *http.Request, + roomName livekit.RoomName, + participantIdentity livekit.ParticipantIdentity, + pID livekit.ParticipantID, + sdpFragment string, +) { + res, err := s.participantClient.ICERestart( + r.Context(), + s.topicFormatter.ParticipantTopic(r.Context(), roomName, participantIdentity), + &rpc.WHIPParticipantICERestartRequest{ + Room: string(roomName), + ParticipantIdentity: string(participantIdentity), + ParticipantId: string(pID), + SdpFragment: sdpFragment, + }, + ) + if err != nil { + var pe psrpc.Error + if errors.As(err, &pe) { + switch pe.Code() { + case psrpc.NotFound: + s.handleError("Patch", w, r, http.StatusNotFound, errors.New(pe.Error())) + + case psrpc.InvalidArgument: + switch pe.Error() { + case rtc.ErrInvalidSDPFragment.Error(): + s.handleError("Patch", w, r, http.StatusBadRequest, errors.New(pe.Error())) + default: + s.handleError("Patch", w, r, http.StatusInternalServerError, errors.New(pe.Error())) + } + default: + s.handleError("Patch", w, r, http.StatusInternalServerError, errors.New(pe.Error())) + } + } else { + s.handleError("Patch", w, r, http.StatusInternalServerError, nil) + } + return + } + sutils.GetLogger(r.Context()).Infow( + "API WHIP.Patch", + "method", "ice-restart", + "room", roomName, + "participant", participantIdentity, + "pID", pID, + "sdpFragment", sdpFragment, + "status", http.StatusNoContent, + "res", logger.Proto(res), + ) + if res.IceSessionId != "" { + w.Header().Add("ETag", res.IceSessionId) + } + w.WriteHeader(http.StatusOK) + w.Write([]byte(res.SdpFragment)) +} + +func (s *WHIPService) handleParticipantPatch(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Content-type") != "application/trickle-ice-sdpfrag" { + s.handleError("Patch", w, r, http.StatusBadRequest, fmt.Errorf("unsupported content-type: %s", r.Header.Get("Content-type"))) + return + } + + w.Header().Add("Content-type", "application/trickle-ice-sdpfrag") + + // https://www.rfc-editor.org/rfc/rfc9725.html#name-http-patch-request-usage + ifMatch := r.Header.Get("If-Match") + if ifMatch == "" { + s.handleError("Patch", w, r, http.StatusPreconditionRequired, errors.New("missing entity tag")) + return + } + + claims := GetGrants(r.Context()) + if claims == nil || claims.Video == nil { + s.handleError("Patch", w, r, http.StatusUnauthorized, rtc.ErrPermissionDenied) + return + } + + roomName, err := EnsureJoinPermission(r.Context()) + if err != nil { + s.handleError("Patch", w, r, http.StatusUnauthorized, err) + return + } + if roomName == "" { + s.handleError("Patch", w, r, http.StatusUnauthorized, errors.New("room name cannot be empty")) + return + } + if claims.Identity == "" { + s.handleError("Patch", w, r, http.StatusUnauthorized, errors.New("participant identity cannot be empty")) + return + } + pID := livekit.ParticipantID(r.PathValue("participant_id")) + if pID == "" { + s.handleError("Patch", w, r, http.StatusBadRequest, errors.New("participant ID cannot be empty")) + return + } + + sdpFragmentBytes, err := ioutil.ReadAll(r.Body) + if err != nil { + s.handleError("Patch", w, r, http.StatusBadRequest, fmt.Errorf("body does not have SDP fragment: %s", err)) + return + } + sdpFragment := string(sdpFragmentBytes) + + if ifMatch == "*" { + s.iceRestart(w, r, roomName, livekit.ParticipantIdentity(claims.Identity), pID, sdpFragment) + } else { + s.iceTrickle(w, r, roomName, livekit.ParticipantIdentity(claims.Identity), pID, ifMatch, sdpFragment) + } +} + +func (s *WHIPService) handleParticipantDelete(w http.ResponseWriter, r *http.Request) { + claims := GetGrants(r.Context()) + if claims == nil || claims.Video == nil { + s.handleError("Delete", w, r, http.StatusUnauthorized, rtc.ErrPermissionDenied) + return + } + + roomName, err := EnsureJoinPermission(r.Context()) + if err != nil { + s.handleError("Delete", w, r, http.StatusUnauthorized, err) + return + } + if roomName == "" { + s.handleError("Delete", w, r, http.StatusUnauthorized, errors.New("room name cannot be empty")) + return + } + if claims.Identity == "" { + s.handleError("Delete", w, r, http.StatusUnauthorized, errors.New("participant identity cannot be empty")) + return + } + + _, err = s.participantClient.DeleteSession( + r.Context(), + s.topicFormatter.ParticipantTopic(r.Context(), roomName, livekit.ParticipantIdentity(claims.Identity)), + &rpc.WHIPParticipantDeleteSessionRequest{ + Room: string(roomName), + ParticipantIdentity: claims.Identity, + ParticipantId: r.PathValue("participant_id"), + }, + ) + if err != nil { + s.handleError("Delete", w, r, http.StatusNotFound, err) + return + } + + sutils.GetLogger(r.Context()).Infow( + "API WHIP.Delete", + "participant", claims.Identity, + "pID", r.PathValue("participant_id"), + "room", roomName, + "status", http.StatusOK, + ) + w.WriteHeader(http.StatusOK) +} + +func (s *WHIPService) handleError(method string, w http.ResponseWriter, r *http.Request, status int, err error) { + sutils.GetLogger(r.Context()).Warnw( + fmt.Sprintf("API WHIP.%s", method), err, + "status", status, + ) + w.WriteHeader(status) + json.NewEncoder(w).Encode(struct { + Error string `json:"error"` + }{ + Error: err.Error(), + }) +} diff --git a/livekit/pkg/service/wire.go b/livekit/pkg/service/wire.go new file mode 100644 index 0000000..bb42c98 --- /dev/null +++ b/livekit/pkg/service/wire.go @@ -0,0 +1,280 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build wireinject +// +build wireinject + +package service + +import ( + "fmt" + "os" + + "github.com/google/wire" + "github.com/pion/turn/v4" + "github.com/pkg/errors" + "github.com/redis/go-redis/v9" + "gopkg.in/yaml.v3" + + "github.com/livekit/protocol/auth" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + redisLiveKit "github.com/livekit/protocol/redis" + "github.com/livekit/protocol/rpc" + "github.com/livekit/protocol/utils" + "github.com/livekit/protocol/webhook" + "github.com/livekit/psrpc" + "github.com/livekit/psrpc/pkg/middleware/otelpsrpc" + + "github.com/livekit/livekit-server/pkg/agent" + "github.com/livekit/livekit-server/pkg/config" + "github.com/livekit/livekit-server/pkg/routing" + "github.com/livekit/livekit-server/pkg/sfu" + "github.com/livekit/livekit-server/pkg/telemetry" +) + +func InitializeServer(conf *config.Config, currentNode routing.LocalNode) (*LivekitServer, error) { + wire.Build( + getNodeID, + createRedisClient, + createStore, + wire.Bind(new(ServiceStore), new(ObjectStore)), + createKeyProvider, + createWebhookNotifier, + createForwardStats, + getNodeStatsConfig, + routing.CreateRouter, + getLimitConf, + config.DefaultAPIConfig, + wire.Bind(new(routing.MessageRouter), new(routing.Router)), + wire.Bind(new(livekit.RoomService), new(*RoomService)), + telemetry.NewAnalyticsService, + telemetry.NewTelemetryService, + getMessageBus, + NewIOInfoService, + wire.Bind(new(IOClient), new(*IOInfoService)), + rpc.NewEgressClient, + rpc.NewIngressClient, + getEgressStore, + NewEgressLauncher, + NewEgressService, + getIngressStore, + getIngressConfig, + NewIngressService, + rpc.NewSIPClientWithParams, + getSIPStore, + getSIPConfig, + NewSIPService, + NewRoomAllocator, + NewRoomService, + NewRTCService, + NewWHIPService, + NewAgentService, + NewAgentDispatchService, + getAgentConfig, + agent.NewAgentClient, + getAgentStore, + getSignalRelayConfig, + NewDefaultSignalServer, + routing.NewSignalClient, + getRoomConfig, + routing.NewRoomManagerClient, + rpc.NewKeepalivePubSub, + getPSRPCConfig, + getPSRPCClientParams, + rpc.NewTopicFormatter, + rpc.NewTypedRoomClient, + rpc.NewTypedParticipantClient, + rpc.NewTypedWHIPParticipantClient, + rpc.NewTypedAgentDispatchInternalClient, + NewLocalRoomManager, + NewTURNAuthHandler, + getTURNAuthHandlerFunc, + newInProcessTurnServer, + utils.NewDefaultTimedVersionGenerator, + NewLivekitServer, + ) + return &LivekitServer{}, nil +} + +func InitializeRouter(conf *config.Config, currentNode routing.LocalNode) (routing.Router, error) { + wire.Build( + createRedisClient, + getNodeID, + getMessageBus, + getSignalRelayConfig, + getPSRPCConfig, + getPSRPCClientParams, + routing.NewSignalClient, + getRoomConfig, + routing.NewRoomManagerClient, + rpc.NewKeepalivePubSub, + getNodeStatsConfig, + routing.CreateRouter, + ) + + return nil, nil +} + +func getNodeID(currentNode routing.LocalNode) livekit.NodeID { + return currentNode.NodeID() +} + +func createKeyProvider(conf *config.Config) (auth.KeyProvider, error) { + // prefer keyfile if set + if conf.KeyFile != "" { + var otherFilter os.FileMode = 0007 + if st, err := os.Stat(conf.KeyFile); err != nil { + return nil, err + } else if st.Mode().Perm()&otherFilter != 0000 { + return nil, fmt.Errorf("key file others permissions must be set to 0") + } + f, err := os.Open(conf.KeyFile) + if err != nil { + return nil, err + } + defer func() { + _ = f.Close() + }() + decoder := yaml.NewDecoder(f) + if err = decoder.Decode(conf.Keys); err != nil { + return nil, err + } + } + + if len(conf.Keys) == 0 { + return nil, errors.New("one of key-file or keys must be provided in order to support a secure installation") + } + + return auth.NewFileBasedKeyProviderFromMap(conf.Keys), nil +} + +func createWebhookNotifier(conf *config.Config, provider auth.KeyProvider) (webhook.QueuedNotifier, error) { + wc := conf.WebHook + + secret := provider.GetSecret(wc.APIKey) + if secret == "" && len(wc.URLs) > 0 { + return nil, ErrWebHookMissingAPIKey + } + + return webhook.NewDefaultNotifier(wc, provider) +} + +func createRedisClient(conf *config.Config) (redis.UniversalClient, error) { + if !conf.Redis.IsConfigured() { + return nil, nil + } + return redisLiveKit.GetRedisClient(&conf.Redis) +} + +func createStore(rc redis.UniversalClient) ObjectStore { + if rc != nil { + return NewRedisStore(rc) + } + return NewLocalStore() +} + +func getMessageBus(rc redis.UniversalClient) psrpc.MessageBus { + if rc == nil { + return psrpc.NewLocalMessageBus() + } + return psrpc.NewRedisMessageBus(rc) +} + +func getEgressStore(s ObjectStore) EgressStore { + switch store := s.(type) { + case *RedisStore: + return store + default: + return nil + } +} + +func getIngressStore(s ObjectStore) IngressStore { + switch store := s.(type) { + case *RedisStore: + return store + default: + return nil + } +} + +func getAgentStore(s ObjectStore) AgentStore { + switch store := s.(type) { + case *RedisStore: + return store + case *LocalStore: + return store + default: + return nil + } +} + +func getIngressConfig(conf *config.Config) *config.IngressConfig { + return &conf.Ingress +} + +func getSIPStore(s ObjectStore) SIPStore { + switch store := s.(type) { + case *RedisStore: + return store + default: + return nil + } +} + +func getSIPConfig(conf *config.Config) *config.SIPConfig { + return &conf.SIP +} + +func getLimitConf(config *config.Config) config.LimitConfig { + return config.Limit +} + +func getRoomConfig(config *config.Config) config.RoomConfig { + return config.Room +} + +func getSignalRelayConfig(config *config.Config) config.SignalRelayConfig { + return config.SignalRelay +} + +func getPSRPCConfig(config *config.Config) rpc.PSRPCConfig { + return config.PSRPC +} + +func getPSRPCClientParams(config rpc.PSRPCConfig, bus psrpc.MessageBus) rpc.ClientParams { + return rpc.NewClientParams(config, bus, logger.GetLogger(), rpc.PSRPCMetricsObserver{}, + otelpsrpc.ClientOptions(otelpsrpc.Config{}), + ) +} + +func createForwardStats(conf *config.Config) *sfu.ForwardStats { + if conf.RTC.ForwardStats.SummaryInterval == 0 || conf.RTC.ForwardStats.ReportInterval == 0 || conf.RTC.ForwardStats.ReportWindow == 0 { + return nil + } + return sfu.NewForwardStats(conf.RTC.ForwardStats.SummaryInterval, conf.RTC.ForwardStats.ReportInterval, conf.RTC.ForwardStats.ReportWindow) +} + +func newInProcessTurnServer(conf *config.Config, authHandler turn.AuthHandler) (*turn.Server, error) { + return NewTurnServer(conf, authHandler, false) +} + +func getNodeStatsConfig(config *config.Config) config.NodeStatsConfig { + return config.NodeStats +} + +func getAgentConfig(config *config.Config) agent.Config { + return config.Agents +} diff --git a/livekit/pkg/service/wire_gen.go b/livekit/pkg/service/wire_gen.go new file mode 100644 index 0000000..cb00cb0 --- /dev/null +++ b/livekit/pkg/service/wire_gen.go @@ -0,0 +1,343 @@ +// Code generated by Wire. DO NOT EDIT. + +//go:generate go run -mod=mod github.com/google/wire/cmd/wire +//go:build !wireinject +// +build !wireinject + +package service + +import ( + "fmt" + "github.com/livekit/livekit-server/pkg/agent" + "github.com/livekit/livekit-server/pkg/config" + "github.com/livekit/livekit-server/pkg/routing" + "github.com/livekit/livekit-server/pkg/sfu" + "github.com/livekit/livekit-server/pkg/telemetry" + "github.com/livekit/protocol/auth" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + redis2 "github.com/livekit/protocol/redis" + "github.com/livekit/protocol/rpc" + "github.com/livekit/protocol/utils" + "github.com/livekit/protocol/webhook" + "github.com/livekit/psrpc" + "github.com/livekit/psrpc/pkg/middleware/otelpsrpc" + "github.com/pion/turn/v4" + "github.com/pkg/errors" + "github.com/redis/go-redis/v9" + "gopkg.in/yaml.v3" + "os" +) + +import ( + _ "net/http/pprof" +) + +// Injectors from wire.go: + +func InitializeServer(conf *config.Config, currentNode routing.LocalNode) (*LivekitServer, error) { + limitConfig := getLimitConf(conf) + apiConfig := config.DefaultAPIConfig() + universalClient, err := createRedisClient(conf) + if err != nil { + return nil, err + } + nodeID := getNodeID(currentNode) + messageBus := getMessageBus(universalClient) + signalRelayConfig := getSignalRelayConfig(conf) + signalClient, err := routing.NewSignalClient(nodeID, messageBus, signalRelayConfig) + if err != nil { + return nil, err + } + psrpcConfig := getPSRPCConfig(conf) + clientParams := getPSRPCClientParams(psrpcConfig, messageBus) + roomConfig := getRoomConfig(conf) + roomManagerClient, err := routing.NewRoomManagerClient(clientParams, roomConfig) + if err != nil { + return nil, err + } + keepalivePubSub, err := rpc.NewKeepalivePubSub(clientParams) + if err != nil { + return nil, err + } + nodeStatsConfig := getNodeStatsConfig(conf) + router := routing.CreateRouter(universalClient, currentNode, signalClient, roomManagerClient, keepalivePubSub, nodeStatsConfig) + objectStore := createStore(universalClient) + roomAllocator, err := NewRoomAllocator(conf, router, objectStore) + if err != nil { + return nil, err + } + egressClient, err := rpc.NewEgressClient(clientParams) + if err != nil { + return nil, err + } + egressStore := getEgressStore(objectStore) + ingressStore := getIngressStore(objectStore) + sipStore := getSIPStore(objectStore) + keyProvider, err := createKeyProvider(conf) + if err != nil { + return nil, err + } + queuedNotifier, err := createWebhookNotifier(conf, keyProvider) + if err != nil { + return nil, err + } + analyticsService := telemetry.NewAnalyticsService(conf, currentNode) + telemetryService := telemetry.NewTelemetryService(queuedNotifier, analyticsService) + ioInfoService, err := NewIOInfoService(messageBus, egressStore, ingressStore, sipStore, telemetryService) + if err != nil { + return nil, err + } + rtcEgressLauncher := NewEgressLauncher(egressClient, ioInfoService, objectStore) + topicFormatter := rpc.NewTopicFormatter() + roomClient, err := rpc.NewTypedRoomClient(clientParams) + if err != nil { + return nil, err + } + participantClient, err := rpc.NewTypedParticipantClient(clientParams) + if err != nil { + return nil, err + } + roomService, err := NewRoomService(limitConfig, apiConfig, router, roomAllocator, objectStore, rtcEgressLauncher, topicFormatter, roomClient, participantClient) + if err != nil { + return nil, err + } + agentDispatchInternalClient, err := rpc.NewTypedAgentDispatchInternalClient(clientParams) + if err != nil { + return nil, err + } + agentDispatchService := NewAgentDispatchService(agentDispatchInternalClient, topicFormatter, roomAllocator, router) + egressService := NewEgressService(egressClient, rtcEgressLauncher, ioInfoService, roomService) + ingressConfig := getIngressConfig(conf) + ingressClient, err := rpc.NewIngressClient(clientParams) + if err != nil { + return nil, err + } + ingressService := NewIngressService(ingressConfig, nodeID, messageBus, ingressClient, ingressStore, ioInfoService, telemetryService) + sipConfig := getSIPConfig(conf) + sipClient, err := rpc.NewSIPClientWithParams(clientParams) + if err != nil { + return nil, err + } + sipService := NewSIPService(sipConfig, nodeID, messageBus, sipClient, sipStore, roomService, telemetryService) + rtcService := NewRTCService(conf, roomAllocator, router, telemetryService) + whipParticipantClient, err := rpc.NewTypedWHIPParticipantClient(clientParams) + if err != nil { + return nil, err + } + serviceWHIPService, err := NewWHIPService(conf, router, roomAllocator, clientParams, topicFormatter, whipParticipantClient) + if err != nil { + return nil, err + } + agentService, err := NewAgentService(conf, currentNode, messageBus, keyProvider) + if err != nil { + return nil, err + } + agentConfig := getAgentConfig(conf) + client, err := agent.NewAgentClient(messageBus, agentConfig) + if err != nil { + return nil, err + } + agentStore := getAgentStore(objectStore) + timedVersionGenerator := utils.NewDefaultTimedVersionGenerator() + turnAuthHandler := NewTURNAuthHandler(keyProvider) + forwardStats := createForwardStats(conf) + roomManager, err := NewLocalRoomManager(conf, objectStore, currentNode, router, roomAllocator, telemetryService, client, agentStore, rtcEgressLauncher, timedVersionGenerator, turnAuthHandler, messageBus, forwardStats) + if err != nil { + return nil, err + } + signalServer, err := NewDefaultSignalServer(currentNode, messageBus, signalRelayConfig, router, roomManager) + if err != nil { + return nil, err + } + authHandler := getTURNAuthHandlerFunc(turnAuthHandler) + server, err := newInProcessTurnServer(conf, authHandler) + if err != nil { + return nil, err + } + livekitServer, err := NewLivekitServer(conf, roomService, agentDispatchService, egressService, ingressService, sipService, ioInfoService, rtcService, serviceWHIPService, agentService, keyProvider, router, roomManager, signalServer, server, currentNode) + if err != nil { + return nil, err + } + return livekitServer, nil +} + +func InitializeRouter(conf *config.Config, currentNode routing.LocalNode) (routing.Router, error) { + universalClient, err := createRedisClient(conf) + if err != nil { + return nil, err + } + nodeID := getNodeID(currentNode) + messageBus := getMessageBus(universalClient) + signalRelayConfig := getSignalRelayConfig(conf) + signalClient, err := routing.NewSignalClient(nodeID, messageBus, signalRelayConfig) + if err != nil { + return nil, err + } + psrpcConfig := getPSRPCConfig(conf) + clientParams := getPSRPCClientParams(psrpcConfig, messageBus) + roomConfig := getRoomConfig(conf) + roomManagerClient, err := routing.NewRoomManagerClient(clientParams, roomConfig) + if err != nil { + return nil, err + } + keepalivePubSub, err := rpc.NewKeepalivePubSub(clientParams) + if err != nil { + return nil, err + } + nodeStatsConfig := getNodeStatsConfig(conf) + router := routing.CreateRouter(universalClient, currentNode, signalClient, roomManagerClient, keepalivePubSub, nodeStatsConfig) + return router, nil +} + +// wire.go: + +func getNodeID(currentNode routing.LocalNode) livekit.NodeID { + return currentNode.NodeID() +} + +func createKeyProvider(conf *config.Config) (auth.KeyProvider, error) { + + if conf.KeyFile != "" { + var otherFilter os.FileMode = 0007 + if st, err := os.Stat(conf.KeyFile); err != nil { + return nil, err + } else if st.Mode().Perm()&otherFilter != 0000 { + return nil, fmt.Errorf("key file others permissions must be set to 0") + } + f, err := os.Open(conf.KeyFile) + if err != nil { + return nil, err + } + defer func() { + _ = f.Close() + }() + decoder := yaml.NewDecoder(f) + if err = decoder.Decode(conf.Keys); err != nil { + return nil, err + } + } + + if len(conf.Keys) == 0 { + return nil, errors.New("one of key-file or keys must be provided in order to support a secure installation") + } + + return auth.NewFileBasedKeyProviderFromMap(conf.Keys), nil +} + +func createWebhookNotifier(conf *config.Config, provider auth.KeyProvider) (webhook.QueuedNotifier, error) { + wc := conf.WebHook + + secret := provider.GetSecret(wc.APIKey) + if secret == "" && len(wc.URLs) > 0 { + return nil, ErrWebHookMissingAPIKey + } + + return webhook.NewDefaultNotifier(wc, provider) +} + +func createRedisClient(conf *config.Config) (redis.UniversalClient, error) { + if !conf.Redis.IsConfigured() { + return nil, nil + } + return redis2.GetRedisClient(&conf.Redis) +} + +func createStore(rc redis.UniversalClient) ObjectStore { + if rc != nil { + return NewRedisStore(rc) + } + return NewLocalStore() +} + +func getMessageBus(rc redis.UniversalClient) psrpc.MessageBus { + if rc == nil { + return psrpc.NewLocalMessageBus() + } + return psrpc.NewRedisMessageBus(rc) +} + +func getEgressStore(s ObjectStore) EgressStore { + switch store := s.(type) { + case *RedisStore: + return store + default: + return nil + } +} + +func getIngressStore(s ObjectStore) IngressStore { + switch store := s.(type) { + case *RedisStore: + return store + default: + return nil + } +} + +func getAgentStore(s ObjectStore) AgentStore { + switch store := s.(type) { + case *RedisStore: + return store + case *LocalStore: + return store + default: + return nil + } +} + +func getIngressConfig(conf *config.Config) *config.IngressConfig { + return &conf.Ingress +} + +func getSIPStore(s ObjectStore) SIPStore { + switch store := s.(type) { + case *RedisStore: + return store + default: + return nil + } +} + +func getSIPConfig(conf *config.Config) *config.SIPConfig { + return &conf.SIP +} + +func getLimitConf(config2 *config.Config) config.LimitConfig { + return config2.Limit +} + +func getRoomConfig(config2 *config.Config) config.RoomConfig { + return config2.Room +} + +func getSignalRelayConfig(config2 *config.Config) config.SignalRelayConfig { + return config2.SignalRelay +} + +func getPSRPCConfig(config2 *config.Config) rpc.PSRPCConfig { + return config2.PSRPC +} + +func getPSRPCClientParams(config2 rpc.PSRPCConfig, bus psrpc.MessageBus) rpc.ClientParams { + return rpc.NewClientParams(config2, bus, logger.GetLogger(), rpc.PSRPCMetricsObserver{}, otelpsrpc.ClientOptions(otelpsrpc.Config{})) +} + +func createForwardStats(conf *config.Config) *sfu.ForwardStats { + if conf.RTC.ForwardStats.SummaryInterval == 0 || conf.RTC.ForwardStats.ReportInterval == 0 || conf.RTC.ForwardStats.ReportWindow == 0 { + return nil + } + return sfu.NewForwardStats(conf.RTC.ForwardStats.SummaryInterval, conf.RTC.ForwardStats.ReportInterval, conf.RTC.ForwardStats.ReportWindow) +} + +func newInProcessTurnServer(conf *config.Config, authHandler turn.AuthHandler) (*turn.Server, error) { + return NewTurnServer(conf, authHandler, false) +} + +func getNodeStatsConfig(config2 *config.Config) config.NodeStatsConfig { + return config2.NodeStats +} + +func getAgentConfig(config2 *config.Config) agent.Config { + return config2.Agents +} diff --git a/livekit/pkg/service/wsprotocol.go b/livekit/pkg/service/wsprotocol.go new file mode 100644 index 0000000..3752753 --- /dev/null +++ b/livekit/pkg/service/wsprotocol.go @@ -0,0 +1,195 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package service + +import ( + "errors" + "io" + "strings" + "sync" + "time" + + "github.com/gorilla/websocket" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/proto" + + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + + "github.com/livekit/livekit-server/pkg/rtc/types" +) + +const ( + pingFrequency = 10 * time.Second + pingTimeout = 2 * time.Second +) + +type WSSignalConnection struct { + conn types.WebsocketClient + mu sync.Mutex + useJSON bool +} + +func NewWSSignalConnection(conn types.WebsocketClient) *WSSignalConnection { + wsc := &WSSignalConnection{ + conn: conn, + mu: sync.Mutex{}, + useJSON: false, + } + go wsc.pingWorker() + return wsc +} + +func (c *WSSignalConnection) Close() error { + return c.conn.Close() +} + +func (c *WSSignalConnection) SetReadDeadline(deadline time.Time) error { + return c.conn.SetReadDeadline(deadline) +} + +func (c *WSSignalConnection) ReadRequest() (*livekit.SignalRequest, int, error) { + // handle special messages and pass on the rest + messageType, payload, err := c.conn.ReadMessage() + if err != nil { + return nil, 0, err + } + + msg := &livekit.SignalRequest{} + switch messageType { + case websocket.BinaryMessage: + if c.useJSON { + c.mu.Lock() + // switch to protobuf if client supports it + c.useJSON = false + c.mu.Unlock() + } + // protobuf encoded + err := proto.Unmarshal(payload, msg) + return msg, len(payload), err + case websocket.TextMessage: + c.mu.Lock() + // json encoded, also write back JSON + c.useJSON = true + c.mu.Unlock() + err := protojson.Unmarshal(payload, msg) + return msg, len(payload), err + default: + logger.Debugw("unsupported message", "message", messageType) + return nil, len(payload), nil + } +} + +func (c *WSSignalConnection) ReadWorkerMessage() (*livekit.WorkerMessage, int, error) { + // handle special messages and pass on the rest + messageType, payload, err := c.conn.ReadMessage() + if err != nil { + return nil, 0, err + } + + msg := &livekit.WorkerMessage{} + switch messageType { + case websocket.BinaryMessage: + if c.useJSON { + c.mu.Lock() + // switch to protobuf if client supports it + c.useJSON = false + c.mu.Unlock() + } + // protobuf encoded + err := proto.Unmarshal(payload, msg) + return msg, len(payload), err + case websocket.TextMessage: + c.mu.Lock() + // json encoded, also write back JSON + c.useJSON = true + c.mu.Unlock() + err := protojson.Unmarshal(payload, msg) + return msg, len(payload), err + default: + logger.Debugw("unsupported message", "message", messageType) + return nil, len(payload), nil + } +} + +func (c *WSSignalConnection) WriteResponse(msg *livekit.SignalResponse) (int, error) { + var msgType int + var payload []byte + var err error + + c.mu.Lock() + defer c.mu.Unlock() + + if c.useJSON { + msgType = websocket.TextMessage + payload, err = protojson.Marshal(msg) + } else { + msgType = websocket.BinaryMessage + payload, err = proto.Marshal(msg) + } + if err != nil { + return 0, err + } + + return len(payload), c.conn.WriteMessage(msgType, payload) +} + +func (c *WSSignalConnection) WriteServerMessage(msg *livekit.ServerMessage) (int, error) { + var msgType int + var payload []byte + var err error + + c.mu.Lock() + defer c.mu.Unlock() + + if c.useJSON { + msgType = websocket.TextMessage + payload, err = protojson.Marshal(msg) + } else { + msgType = websocket.BinaryMessage + payload, err = proto.Marshal(msg) + } + if err != nil { + return 0, err + } + + return len(payload), c.conn.WriteMessage(msgType, payload) +} + +func (c *WSSignalConnection) pingWorker() { + ticker := time.NewTicker(pingFrequency) + defer ticker.Stop() + + for range ticker.C { + err := c.conn.WriteControl(websocket.PingMessage, []byte(""), time.Now().Add(pingTimeout)) + if err != nil { + return + } + } +} + +// IsWebSocketCloseError checks that error is normal/expected closure +func IsWebSocketCloseError(err error) bool { + return errors.Is(err, io.EOF) || + strings.HasSuffix(err.Error(), "use of closed network connection") || + strings.HasSuffix(err.Error(), "connection reset by peer") || + websocket.IsCloseError( + err, + websocket.CloseAbnormalClosure, + websocket.CloseGoingAway, + websocket.CloseNormalClosure, + websocket.CloseNoStatusReceived, + ) +} diff --git a/livekit/pkg/sfu/NOTICE b/livekit/pkg/sfu/NOTICE new file mode 100644 index 0000000..3f39f9c --- /dev/null +++ b/livekit/pkg/sfu/NOTICE @@ -0,0 +1,16 @@ +Portions of this package originated from ion-sfu: https://github.com/pion/ion-sfu. + +MIT License + +Copyright (c) 2019 + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. +---------------------------------------------------------------------------------------- diff --git a/livekit/pkg/sfu/audio/audiolevel.go b/livekit/pkg/sfu/audio/audiolevel.go new file mode 100644 index 0000000..5f56079 --- /dev/null +++ b/livekit/pkg/sfu/audio/audiolevel.go @@ -0,0 +1,181 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package audio + +import ( + "math" + "sync" +) + +const ( + silentAudioLevel = 127 + negInv20 = -1.0 / 20 +) + +// -------------------------------------- + +type AudioLevelConfig struct { + // minimum level to be considered active, 0-127, where 0 is loudest + ActiveLevel uint8 `yaml:"active_level,omitempty"` + // percentile to measure, a participant is considered active if it has exceeded the ActiveLevel more than + // MinPercentile% of the time + MinPercentile uint8 `yaml:"min_percentile,omitempty"` + // interval to update clients, in ms + UpdateInterval uint32 `yaml:"update_interval,omitempty"` + // smoothing for audioLevel values sent to the client. + // audioLevel will be an average of `smooth_intervals`, 0 to disable + SmoothIntervals uint32 `yaml:"smooth_intervals,omitempty"` +} + +var ( + DefaultAudioLevelConfig = AudioLevelConfig{ + ActiveLevel: 35, // -35dBov + MinPercentile: 40, + UpdateInterval: 400, + SmoothIntervals: 2, + } +) + +// -------------------------------------- + +type AudioLevelParams struct { + Config AudioLevelConfig + ClockRate uint32 +} + +// keeps track of audio level for a participant +type AudioLevel struct { + params AudioLevelParams + // min duration within an observe duration window to be considered active + minActiveDuration uint32 + smoothFactor float64 + activeThreshold float64 + + lock sync.Mutex + smoothedLevel float64 + + loudestObservedLevel uint8 + activeDuration uint32 // ms + observedDuration uint32 // ms + lastObservedAt int64 + + highestRTPTimestamp uint32 + highestRTPTimestampInitialized bool +} + +func NewAudioLevel(params AudioLevelParams) *AudioLevel { + l := &AudioLevel{ + params: params, + minActiveDuration: uint32(params.Config.MinPercentile) * params.Config.UpdateInterval / 100, + smoothFactor: 1, + activeThreshold: ConvertAudioLevel(float64(params.Config.ActiveLevel)), + loudestObservedLevel: silentAudioLevel, + } + + if l.params.Config.SmoothIntervals > 0 { + // exponential moving average (EMA), same center of mass with simple moving average (SMA) + l.smoothFactor = float64(2) / (float64(l.params.Config.SmoothIntervals + 1)) + } + + return l +} + +// Observes a new frame +func (l *AudioLevel) Observe(level uint8, durationMs uint32, arrivalTime int64) { + l.lock.Lock() + defer l.lock.Unlock() + + l.observeLocked(level, durationMs, arrivalTime) +} + +func (l *AudioLevel) observeLocked(level uint8, durationMs uint32, arrivalTime int64) { + l.lastObservedAt = arrivalTime + + l.observedDuration += durationMs + + if level <= l.params.Config.ActiveLevel { + l.activeDuration += durationMs + if l.loudestObservedLevel > level { + l.loudestObservedLevel = level + } + } + + if l.observedDuration >= l.params.Config.UpdateInterval { + smoothedLevel := float64(0.0) + // compute and reset + if l.activeDuration >= l.minActiveDuration { + // adjust loudest observed level by how much of the window was active. + // Weight will be 0 if active the entire duration + // > 0 if active for longer than observe duration + // < 0 if active for less than observe duration + activityWeight := 20 * math.Log10(float64(l.activeDuration)/float64(l.params.Config.UpdateInterval)) + adjustedLevel := float64(l.loudestObservedLevel) - activityWeight + linearLevel := ConvertAudioLevel(adjustedLevel) + + // exponential smoothing to dampen transients + smoothedLevel = l.smoothedLevel + (linearLevel-l.smoothedLevel)*l.smoothFactor + } + l.resetLocked(smoothedLevel) + } +} + +func (l *AudioLevel) ObserveWithRTPTimestamp(level uint8, ts uint32, arrivalTime int64) { + l.lock.Lock() + defer l.lock.Unlock() + + if !l.highestRTPTimestampInitialized { + l.highestRTPTimestampInitialized = true + l.highestRTPTimestamp = ts + } + + if (ts - l.highestRTPTimestamp) < (1 << 31) { + durationMs := (ts - l.highestRTPTimestamp) * 1e3 / l.params.ClockRate + l.observeLocked(level, durationMs, arrivalTime) + + l.highestRTPTimestamp = ts + } +} + +// returns current smoothed audio level +func (l *AudioLevel) GetLevel(now int64) (float64, bool) { + l.lock.Lock() + defer l.lock.Unlock() + + l.resetIfStaleLocked(now) + + return l.smoothedLevel, l.smoothedLevel >= l.activeThreshold +} + +func (l *AudioLevel) resetIfStaleLocked(arrivalTime int64) { + if (arrivalTime-l.lastObservedAt)/1e6 < int64(2*l.params.Config.UpdateInterval) { + return + } + + l.resetLocked(0.0) +} + +func (l *AudioLevel) resetLocked(smoothedLevel float64) { + l.smoothedLevel = smoothedLevel + l.loudestObservedLevel = silentAudioLevel + l.activeDuration = 0 + l.observedDuration = 0 +} + +// --------------------------------------------------- + +// convert decibel back to linear +func ConvertAudioLevel(level float64) float64 { + return math.Pow(10, level*negInv20) +} diff --git a/livekit/pkg/sfu/audio/audiolevel_test.go b/livekit/pkg/sfu/audio/audiolevel_test.go new file mode 100644 index 0000000..81afd85 --- /dev/null +++ b/livekit/pkg/sfu/audio/audiolevel_test.go @@ -0,0 +1,157 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package audio + +import ( + "math/rand" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +const ( + samplesPerBatch = 25 + defaultActiveLevel = 30 + // requires two noisy samples to count + defaultPercentile = 10 + defaultObserveDuration = 500 // ms +) + +func TestAudioLevel(t *testing.T) { + t.Run("initially to return not noisy, within a few samples", func(t *testing.T) { + clock := time.Now() + a := createAudioLevel(defaultActiveLevel, defaultPercentile, defaultObserveDuration) + + _, noisy := a.GetLevel(clock.UnixNano()) + require.False(t, noisy) + + observeSamples(a, 28, 5, clock) + clock = clock.Add(5 * 20 * time.Millisecond) + + _, noisy = a.GetLevel(clock.UnixNano()) + require.False(t, noisy) + }) + + t.Run("not noisy when all samples are below threshold", func(t *testing.T) { + clock := time.Now() + a := createAudioLevel(defaultActiveLevel, defaultPercentile, defaultObserveDuration) + + observeSamples(a, 35, 100, clock) + clock = clock.Add(100 * 20 * time.Millisecond) + + _, noisy := a.GetLevel(clock.UnixNano()) + require.False(t, noisy) + }) + + t.Run("not noisy when less than percentile samples are above threshold", func(t *testing.T) { + clock := time.Now() + a := createAudioLevel(defaultActiveLevel, defaultPercentile, defaultObserveDuration) + + observeSamples(a, 35, samplesPerBatch-2, clock) + clock = clock.Add((samplesPerBatch - 2) * 20 * time.Millisecond) + observeSamples(a, 25, 1, clock) + clock = clock.Add(20 * time.Millisecond) + observeSamples(a, 35, 1, clock) + clock = clock.Add(20 * time.Millisecond) + + _, noisy := a.GetLevel(clock.UnixNano()) + require.False(t, noisy) + }) + + t.Run("noisy when higher than percentile samples are above threshold", func(t *testing.T) { + clock := time.Now() + a := createAudioLevel(defaultActiveLevel, defaultPercentile, defaultObserveDuration) + + observeSamples(a, 35, samplesPerBatch-16, clock) + clock = clock.Add((samplesPerBatch - 16) * 20 * time.Millisecond) + observeSamples(a, 25, 8, clock) + clock = clock.Add(8 * 20 * time.Millisecond) + observeSamples(a, 29, 8, clock) + clock = clock.Add(8 * 20 * time.Millisecond) + + level, noisy := a.GetLevel(clock.UnixNano()) + require.True(t, noisy) + require.Greater(t, level, ConvertAudioLevel(float64(defaultActiveLevel))) + require.Less(t, level, ConvertAudioLevel(float64(25))) + }) + + t.Run("not noisy when samples are stale", func(t *testing.T) { + clock := time.Now() + a := createAudioLevel(defaultActiveLevel, defaultPercentile, defaultObserveDuration) + + observeSamples(a, 25, 100, clock) + clock = clock.Add(100 * 20 * time.Millisecond) + level, noisy := a.GetLevel(clock.UnixNano()) + require.True(t, noisy) + require.Greater(t, level, ConvertAudioLevel(float64(defaultActiveLevel))) + require.Less(t, level, ConvertAudioLevel(float64(20))) + + // let enough time pass to make the samples stale + clock = clock.Add(1500 * time.Millisecond) + level, noisy = a.GetLevel(clock.UnixNano()) + require.Equal(t, float64(0.0), level) + require.False(t, noisy) + }) + + t.Run("not noisy when samples are stale - with RTP timestamp", func(t *testing.T) { + clock := time.Now() + a := createAudioLevel(defaultActiveLevel, defaultPercentile, defaultObserveDuration) + + observeSamplesWithRTPTimestamp(a, 25, 100, clock) + clock = clock.Add(100 * 20 * time.Millisecond) + level, noisy := a.GetLevel(clock.UnixNano()) + require.True(t, noisy) + require.Greater(t, level, ConvertAudioLevel(float64(defaultActiveLevel))) + require.Less(t, level, ConvertAudioLevel(float64(20))) + + // let enough time pass to make the samples stale + clock = clock.Add(1500 * time.Millisecond) + level, noisy = a.GetLevel(clock.UnixNano()) + require.Equal(t, float64(0.0), level) + require.False(t, noisy) + }) +} + +func createAudioLevel(activeLevel uint8, minPercentile uint8, observeDuration uint32) *AudioLevel { + return NewAudioLevel(AudioLevelParams{ + Config: AudioLevelConfig{ + ActiveLevel: activeLevel, + MinPercentile: minPercentile, + UpdateInterval: observeDuration, + }, + ClockRate: 48000, + }) +} + +func observeSamples(a *AudioLevel, level uint8, count int, baseTime time.Time) { + for i := range count { + a.Observe(level, 20, baseTime.Add(time.Duration(i*20)*time.Millisecond).UnixNano()) + } +} + +func observeSamplesWithRTPTimestamp(a *AudioLevel, level uint8, count int, baseTime time.Time) { + sampleTS := uint32(rand.Intn(1 << 20)) + sampleTime := baseTime + for i := range count { + if (i % 5) == 0 { + // out-of-order sample + a.ObserveWithRTPTimestamp(level, sampleTS-1920, sampleTime.UnixNano()) + } + a.ObserveWithRTPTimestamp(level, sampleTS, sampleTime.UnixNano()) + sampleTS += 960 // 20 ms at 48 kHz + sampleTime = sampleTime.Add(20 * time.Millisecond) + } +} diff --git a/livekit/pkg/sfu/buffer/buffer.go b/livekit/pkg/sfu/buffer/buffer.go new file mode 100644 index 0000000..e904cb3 --- /dev/null +++ b/livekit/pkg/sfu/buffer/buffer.go @@ -0,0 +1,457 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package buffer + +import ( + "encoding/binary" + "errors" + "io" + + "github.com/pion/rtcp" + "github.com/pion/rtp" + "github.com/pion/webrtc/v4" + + sutils "github.com/livekit/livekit-server/pkg/utils" + "github.com/livekit/mediatransportutil/pkg/bucket" + "github.com/livekit/mediatransportutil/pkg/twcc" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/utils/mono" +) + +const ( + rtcpReceiverReportDelta = 1e9 + + InitPacketBufferSizeVideo = 300 + InitPacketBufferSizeAudio = 70 +) + +var ( + errInvalidCodec = errors.New("invalid codec") +) + +var _ BufferProvider = (*Buffer)(nil) + +type pendingPacket struct { + arrivalTime int64 + packet []byte +} + +// Buffer contains all packets +type Buffer struct { + *BufferBase + + pPackets []pendingPacket + lastReportAt int64 + isBound bool + + twcc *twcc.Responder + twccExtID uint8 + + enableAudioLossProxying bool + lastFractionLostToReport uint8 // Last fraction lost from subscribers, should report to publisher; Audio only + + lastPacketRead int + + // callbacks + onClose func() + onRtcpFeedback func([]rtcp.Packet) + onFinalRtpStats func(*livekit.RTPStats) + onNotifyRTX func(uint32, uint32, string) + + primaryBufferForRTX *Buffer + rtxPktBuf []byte +} + +func NewBuffer(ssrc uint32, maxVideoPkts, maxAudioPkts int) *Buffer { + b := &Buffer{} + b.BufferBase = NewBufferBase(BufferBaseParams{ + SSRC: ssrc, + MaxVideoPkts: maxVideoPkts, + MaxAudioPkts: maxAudioPkts, + LoggerComponents: []string{sutils.ComponentPub, sutils.ComponentSFU}, + SendPLI: b.sendPLI, + IsReportingEnabled: true, + }) + return b +} + +func (b *Buffer) SetTWCCAndExtID(twcc *twcc.Responder, extID uint8) { + b.Lock() + defer b.Unlock() + + b.twcc = twcc + b.twccExtID = extID +} + +func (b *Buffer) SetAudioLossProxying(enable bool) { + b.Lock() + defer b.Unlock() + + b.enableAudioLossProxying = enable +} + +func (b *Buffer) Bind(params webrtc.RTPParameters, codec webrtc.RTPCodecCapability, bitrates int) error { + b.Lock() + defer b.Unlock() + if b.isBound { + return nil + } + + if err := b.BufferBase.BindLocked(params, codec, bitrates); err != nil { + return err + } + + b.lastReportAt = mono.UnixNano() + + if len(b.pPackets) != 0 { + b.logger.Debugw("releasing queued packets on bind", "count", len(b.pPackets)) + } + for _, pp := range b.pPackets { + b.calc(pp.packet, nil, pp.arrivalTime, true, false) + } + b.pPackets = nil + + b.isBound = true + + return nil +} + +// Write adds an RTP Packet, ordering is not guaranteed, newer packets may arrive later +// +//go:noinline +func (b *Buffer) Write(pkt []byte) (n int, err error) { + var rtpPacket rtp.Packet + err = rtpPacket.Unmarshal(pkt) + if err != nil { + return + } + + b.Lock() + if b.BufferBase.IsClosed() { + b.Unlock() + err = io.EOF + return + } + + now := mono.UnixNano() + if b.twcc != nil && b.twccExtID != 0 { + if ext := rtpPacket.GetExtension(b.twccExtID); ext != nil { + b.twcc.Push(rtpPacket.SSRC, binary.BigEndian.Uint16(ext[0:2]), now, rtpPacket.Marker) + } + } + + // libwebrtc will use 0 ssrc for probing, don't push the packet to pending queue to avoid memory increasing since + // the Bind will not be called to consume the pending packets. More details in https://github.com/pion/webrtc/pull/2816 + if rtpPacket.SSRC == 0 { + b.Unlock() + return + } + + // handle RTX packet + if pb := b.primaryBufferForRTX; pb != nil { + b.Unlock() + + // skip padding only packets + if rtpPacket.Padding && len(rtpPacket.Payload) == 0 { + return + } + + pb.writeRTX(&rtpPacket, now) + return + } + + if !b.isBound { + packet := make([]byte, len(pkt)) + copy(packet, pkt) + + if len(b.pPackets) == 0 { + b.logger.Debugw("received first packet") + } + + startIdx := 0 + overflow := len(b.pPackets) - max(b.BufferBase.MaxVideoPkts(), b.BufferBase.MaxAudioPkts()) + if overflow > 0 { + startIdx = overflow + } + b.pPackets = append(b.pPackets[startIdx:], pendingPacket{ + packet: packet, + arrivalTime: now, + }) + + b.BufferBase.NotifyRead() + b.Unlock() + return + } + + b.calc(pkt, &rtpPacket, now, false, false) + b.Unlock() + return +} + +func (b *Buffer) SetPrimaryBufferForRTX(primaryBuffer *Buffer) { + b.Lock() + b.primaryBufferForRTX = primaryBuffer + pkts := b.pPackets + b.pPackets = nil + b.Unlock() + + for _, pp := range pkts { + var rtpPacket rtp.Packet + err := rtpPacket.Unmarshal(pp.packet) + if err != nil { + continue + } + if rtpPacket.Padding && len(rtpPacket.Payload) == 0 { + continue + } + primaryBuffer.writeRTX(&rtpPacket, pp.arrivalTime) + } +} + +func (b *Buffer) NotifyRTX(ssrc uint32, repairSSRC uint32, rsid string) { + if onNotifyRTX := b.getOnNotifyRTX(); onNotifyRTX != nil { + onNotifyRTX(ssrc, repairSSRC, rsid) + } +} + +func (b *Buffer) writeRTX(rtxPkt *rtp.Packet, arrivalTime int64) { + b.Lock() + defer b.Unlock() + if !b.isBound { + return + } + + if rtxPkt.PayloadType != b.rtxPayloadType { + b.logger.Debugw("unexpected rtx payload type", "expected", b.rtxPayloadType, "actual", rtxPkt.PayloadType) + return + } + + if b.rtxPktBuf == nil { + b.rtxPktBuf = make([]byte, bucket.RTPMaxPktSize) + } + + if len(rtxPkt.Payload) < 2 { + b.logger.Warnw("rtx payload too short", nil, "size", len(rtxPkt.Payload)) + return + } + + repairedPkt := *rtxPkt + repairedPkt.PayloadType = b.payloadType + repairedPkt.SequenceNumber = binary.BigEndian.Uint16(rtxPkt.Payload[:2]) + repairedPkt.SSRC = b.BufferBase.SSRC() + repairedPkt.Payload = rtxPkt.Payload[2:] + n, err := repairedPkt.MarshalTo(b.rtxPktBuf) + if err != nil { + b.logger.Errorw("could not marshal repaired packet", err, "ssrc", b.BufferBase.SSRC(), "sn", repairedPkt.SequenceNumber) + return + } + + b.calc(b.rtxPktBuf[:n], &repairedPkt, arrivalTime, false, true) +} + +func (b *Buffer) Read(buff []byte) (n int, err error) { + b.Lock() + for { + if b.BufferBase.IsClosed() { + b.Unlock() + return 0, io.EOF + } + + if b.pPackets != nil && len(b.pPackets) > b.lastPacketRead { + if len(buff) < len(b.pPackets[b.lastPacketRead].packet) { + b.Unlock() + return 0, bucket.ErrBufferTooSmall + } + + n = copy(buff, b.pPackets[b.lastPacketRead].packet) + b.lastPacketRead++ + b.Unlock() + return + } + b.BufferBase.WaitRead() + } +} + +func (b *Buffer) Close() error { + stats, err := b.BufferBase.CloseWithReason("close") + if err != nil { + return err + } + + if stats != nil { + if cb := b.getOnFinalRtpStats(); cb != nil { + cb(stats) + } + } + + if cb := b.getOnClose(); cb != nil { + cb() + } + + return nil +} + +func (b *Buffer) OnClose(fn func()) { + b.Lock() + b.onClose = fn + b.Unlock() +} + +func (b *Buffer) getOnClose() func() { + b.RLock() + defer b.RUnlock() + + return b.onClose +} + +func (b *Buffer) sendPLI() { + ssrc := b.BufferBase.SSRC() + if ssrc == 0 { + return + } + + b.logger.Debugw("send pli", "mediaSSRC", ssrc) + pli := []rtcp.Packet{ + &rtcp.PictureLossIndication{ + SenderSSRC: ssrc, + MediaSSRC: ssrc, + }, + } + + if cb := b.getOnRtcpFeedback(); cb != nil { + cb(pli) + } +} + +func (b *Buffer) calc(rawPkt []byte, rtpPacket *rtp.Packet, arrivalTime int64, isBuffered bool, isRTX bool) { + b.BufferBase.HandleIncomingPacketLocked( + rawPkt, + rtpPacket, + arrivalTime, + isBuffered, + isRTX, + nil, + 0, + ) + + b.doNACKs() + + b.doReports(arrivalTime) +} + +func (b *Buffer) doNACKs() { + if r := b.buildNACKPacket(); r != nil { + if cb := b.onRtcpFeedback; cb != nil { + cb(r) + } + } +} + +func (b *Buffer) buildNACKPacket() []rtcp.Packet { + if nacks := b.BufferBase.GetNACKPairsLocked(); len(nacks) > 0 { + ssrc := b.BufferBase.SSRC() + pkts := []rtcp.Packet{&rtcp.TransportLayerNack{ + SenderSSRC: ssrc, + MediaSSRC: ssrc, + Nacks: nacks, + }} + return pkts + } + return nil +} + +func (b *Buffer) doReports(arrivalTime int64) { + if arrivalTime-b.lastReportAt < rtcpReceiverReportDelta { + return + } + b.lastReportAt = arrivalTime + + // RTCP reports + pkts := b.getRTCP() + if pkts != nil { + if cb := b.onRtcpFeedback; cb != nil { + cb(pkts) + } + } +} + +func (b *Buffer) getRTCP() []rtcp.Packet { + var pkts []rtcp.Packet + + rr := b.buildReceptionReport() + if rr != nil { + pkts = append(pkts, &rtcp.ReceiverReport{ + SSRC: b.BufferBase.SSRC(), + Reports: []rtcp.ReceptionReport{*rr}, + }) + } + + return pkts +} + +func (b *Buffer) buildReceptionReport() *rtcp.ReceptionReport { + proxyLoss := b.lastFractionLostToReport + if b.codecType == webrtc.RTPCodecTypeAudio && !b.enableAudioLossProxying { + proxyLoss = 0 + } + + return b.BufferBase.GetRtcpReceptionReportLocked(proxyLoss) +} + +func (b *Buffer) SetLastFractionLostReport(lost uint8) { + b.Lock() + defer b.Unlock() + + b.lastFractionLostToReport = lost +} + +func (b *Buffer) OnRtcpFeedback(fn func(fb []rtcp.Packet)) { + b.Lock() + b.onRtcpFeedback = fn + b.Unlock() +} + +func (b *Buffer) getOnRtcpFeedback() func(fb []rtcp.Packet) { + b.RLock() + defer b.RUnlock() + + return b.onRtcpFeedback +} + +func (b *Buffer) OnFinalRtpStats(fn func(*livekit.RTPStats)) { + b.Lock() + b.onFinalRtpStats = fn + b.Unlock() +} + +func (b *Buffer) getOnFinalRtpStats() func(*livekit.RTPStats) { + b.RLock() + defer b.RUnlock() + + return b.onFinalRtpStats +} + +func (b *Buffer) OnNotifyRTX(fn func(ssrc uint32, repairSSRC uint32, rsid string)) { + b.Lock() + b.onNotifyRTX = fn + b.Unlock() +} + +func (b *Buffer) getOnNotifyRTX() func(ssrc uint32, repairSSRC uint32, rsid string) { + b.RLock() + defer b.RUnlock() + + return b.onNotifyRTX +} diff --git a/livekit/pkg/sfu/buffer/buffer_base.go b/livekit/pkg/sfu/buffer/buffer_base.go new file mode 100644 index 0000000..ec4cf80 --- /dev/null +++ b/livekit/pkg/sfu/buffer/buffer_base.go @@ -0,0 +1,1551 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package buffer + +import ( + "errors" + "fmt" + "io" + "strings" + "sync" + "time" + + "github.com/gammazero/deque" + "github.com/pion/rtcp" + "github.com/pion/rtp" + "github.com/pion/rtp/codecs" + "github.com/pion/sdp/v3" + "github.com/pion/webrtc/v4" + "go.uber.org/atomic" + + "github.com/livekit/livekit-server/pkg/sfu/audio" + "github.com/livekit/livekit-server/pkg/sfu/mime" + act "github.com/livekit/livekit-server/pkg/sfu/rtpextension/abscapturetime" + dd "github.com/livekit/livekit-server/pkg/sfu/rtpextension/dependencydescriptor" + "github.com/livekit/livekit-server/pkg/sfu/rtpstats" + "github.com/livekit/livekit-server/pkg/sfu/utils" + "github.com/livekit/mediatransportutil/pkg/bucket" + "github.com/livekit/mediatransportutil/pkg/nack" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/utils/mono" +) + +var ( + ExtPacketFactory = &sync.Pool{ + New: func() any { + return &ExtPacket{} + }, + } +) + +func ReleaseExtPacket(extPkt *ExtPacket) { + if extPkt == nil { + return + } + + ReleaseExtDependencyDescriptor(extPkt.DependencyDescriptor) + + *extPkt = ExtPacket{} + ExtPacketFactory.Put(extPkt) +} + +// -------------------------------------- + +type ExtPacket struct { + VideoLayer + Arrival int64 + ExtSequenceNumber uint64 + ExtTimestamp uint64 + Packet *rtp.Packet + Payload any + IsKeyFrame bool + RawPacket []byte + DependencyDescriptor *ExtDependencyDescriptor + AbsCaptureTimeExt *act.AbsCaptureTime + IsOutOfOrder bool + IsBuffered bool +} + +// VideoSize represents video resolution +type VideoSize struct { + Width uint32 + Height uint32 +} + +type BufferProvider interface { + SetLogger(lgr logger.Logger) + SetAudioLevelConfig(audioLevelConfig audio.AudioLevelConfig) + SetStreamRestartDetection(enable bool) + SetPLIThrottle(duration int64) + SetRTT(rtt uint32) + SetPaused(paused bool) + + SendPLI(force bool) + + ReadExtended(buf []byte) (*ExtPacket, error) + GetPacket(buf []byte, esn uint64) (int, error) + + GetAudioLevel() (float64, bool) + GetTemporalLayerFpsForSpatial(layer int32) []float32 + GetStats() *livekit.RTPStats + GetDeltaStats() *StreamStatsWithLayers + GetDeltaStatsLite() *rtpstats.RTPDeltaInfoLite + GetLastSenderReportTime() time.Time + GetNACKPairs() []rtcp.NackPair + + SetSenderReportData(srData *livekit.RTCPSenderReportState) + GetSenderReportData() *livekit.RTCPSenderReportState + + OnRtcpSenderReport(fn func()) + OnFpsChanged(fn func()) + OnVideoSizeChanged(fn func([]VideoSize)) + OnCodecChange(fn func(webrtc.RTPCodecParameters)) + OnStreamRestart(fn func(string)) + + StartKeyFrameSeeder() + StopKeyFrameSeeder() + + HandleIncomingPacket( + rawPkt []byte, + rtpPacket *rtp.Packet, + arrivalTime int64, + isBuffered bool, + isRTX bool, + skippedSeqs []uint16, + oobSequenceNumber uint16, + ) (uint64, error) + + RestartStream(reason string) + + CloseWithReason(reason string) (*livekit.RTPStats, error) +} + +const ( + bucketCapCheckInterval = 1e9 +) + +type BufferBaseParams struct { + SSRC uint32 + MaxVideoPkts int + MaxAudioPkts int + LoggerComponents []string + SendPLI func() + IsReportingEnabled bool + IsOOBSequenceNumber bool + IsDDRestartEnabled bool +} + +type BufferBase struct { + sync.RWMutex + + params BufferBaseParams + + readCond *sync.Cond + + bucket *bucket.Bucket[uint64, uint16] + lastBucketCapCheckAt int64 + + nacker *nack.NackQueue + rtpStatsLite *rtpstats.RTPStatsReceiverLite + liteStatsSnapshotId uint32 + + extPackets deque.Deque[*ExtPacket] + + codecType webrtc.RTPCodecType + closeOnce sync.Once + clockRate uint32 + mime mime.MimeType + + rtpParameters webrtc.RTPParameters + payloadType uint8 + rtxPayloadType uint8 + + snRangeMap *utils.RangeMap[uint64, uint64] + + audioLevelConfig audio.AudioLevelConfig + audioLevel *audio.AudioLevel + audioLevelExtID uint8 + + enableStreamRestartDetection bool + + pliThrottle int64 + + rtpStats *rtpstats.RTPStatsReceiver + ppsSnapshotId uint32 + rrSnapshotId uint32 + deltaStatsSnapshotId uint32 + + // callbacks + onRtcpSenderReport func() + onFpsChanged func() + onVideoSizeChanged func([]VideoSize) + onCodecChange func(webrtc.RTPCodecParameters) + onStreamRestart func(string) + + // video size tracking for multiple spatial layers + currentVideoSize [DefaultMaxLayerSpatial + 1]VideoSize + + logger logger.Logger + + // dependency descriptor + ddExtID uint8 + ddParser *DependencyDescriptorParser + + isPaused bool + frameRateCalculator [DefaultMaxLayerSpatial + 1]FrameRateCalculator + frameRateCalculated bool + + packetNotFoundCount atomic.Uint32 + packetTooOldCount atomic.Uint32 + extPacketTooMuchCount atomic.Uint32 + + absCaptureTimeExtID uint8 + + keyFrameSeederGeneration atomic.Int32 + + isRestartPending bool + + isClosed atomic.Bool +} + +func NewBufferBase(params BufferBaseParams) *BufferBase { + l := logger.GetLogger() // will be reset with correct context via SetLogger + for _, component := range params.LoggerComponents { + l = l.WithComponent(component) + } + l = l.WithValues("ssrc", params.SSRC) + + b := &BufferBase{ + params: params, + lastBucketCapCheckAt: mono.UnixNano(), + snRangeMap: utils.NewRangeMap[uint64, uint64](100), + pliThrottle: int64(500 * time.Millisecond), + logger: l, + } + b.readCond = sync.NewCond(&b.RWMutex) + b.extPackets.SetBaseCap(128) + return b +} + +func (b *BufferBase) SSRC() uint32 { + return b.params.SSRC +} + +func (b *BufferBase) MaxVideoPkts() int { + return b.params.MaxVideoPkts +} + +func (b *BufferBase) MaxAudioPkts() int { + return b.params.MaxAudioPkts +} + +func (b *BufferBase) SetLogger(lgr logger.Logger) { + b.Lock() + defer b.Unlock() + + for _, component := range b.params.LoggerComponents { + lgr = lgr.WithComponent(component) + } + lgr = lgr.WithValues("ssrc", b.params.SSRC) + b.logger = lgr + + if b.rtpStats != nil { + b.rtpStats.SetLogger(b.logger) + } + + if b.rtpStatsLite != nil { + b.rtpStatsLite.SetLogger(b.logger) + } +} + +func (b *BufferBase) Bind(rtpParameters webrtc.RTPParameters, codec webrtc.RTPCodecCapability, bitrate int) error { + b.Lock() + defer b.Unlock() + + return b.BindLocked(rtpParameters, codec, bitrate) +} + +func (b *BufferBase) BindLocked(rtpParameters webrtc.RTPParameters, codec webrtc.RTPCodecCapability, bitrate int) error { + b.logger.Debugw("binding track") + if codec.ClockRate == 0 { + b.logger.Warnw("invalid codec", nil, "rtpParameters", rtpParameters, "codec", codec, "bitrate", bitrate) + return errInvalidCodec + } + + b.setupRTPStats(codec.ClockRate) + + b.clockRate = codec.ClockRate + b.mime = mime.NormalizeMimeType(codec.MimeType) + b.rtpParameters = rtpParameters + for _, codecParameter := range rtpParameters.Codecs { + if mime.IsMimeTypeStringEqual(codecParameter.MimeType, codec.MimeType) { + b.payloadType = uint8(codecParameter.PayloadType) + break + } + } + + if b.payloadType == 0 && !mime.IsMimeTypeStringEqual(codec.MimeType, webrtc.MimeTypePCMU) { + b.logger.Warnw( + "could not find payload type for codec", nil, + "codec", codec.MimeType, + "rtpParameters", rtpParameters, + ) + b.payloadType = uint8(rtpParameters.Codecs[0].PayloadType) + } + + // find RTX payload type + for _, codec := range rtpParameters.Codecs { + if mime.IsMimeTypeStringRTX(codec.MimeType) && strings.Contains(codec.SDPFmtpLine, fmt.Sprintf("apt=%d", b.payloadType)) { + b.rtxPayloadType = uint8(codec.PayloadType) + break + } + } + + for _, ext := range rtpParameters.HeaderExtensions { + switch ext.URI { + case dd.ExtensionURI: + if b.ddExtID != 0 { + b.logger.Warnw( + "multiple dependency descriptor extensions found", nil, + "id", ext.ID, + "previous", b.ddExtID, + ) + continue + } + b.ddExtID = uint8(ext.ID) + b.createDDParserAndFrameRateCalculator() + + case sdp.AudioLevelURI: + b.audioLevelExtID = uint8(ext.ID) + b.audioLevel = audio.NewAudioLevel(audio.AudioLevelParams{ + Config: b.audioLevelConfig, + ClockRate: b.clockRate, + }) + + case act.AbsCaptureTimeURI: + b.absCaptureTimeExtID = uint8(ext.ID) + } + } + + switch { + case mime.IsMimeTypeAudio(b.mime): + b.codecType = webrtc.RTPCodecTypeAudio + b.bucket = bucket.NewBucket[uint64, uint16]( + InitPacketBufferSizeAudio, + bucket.RTPMaxPktSize, + bucket.RTPSeqNumOffset, + ) + + case mime.IsMimeTypeVideo(b.mime): + b.codecType = webrtc.RTPCodecTypeVideo + b.bucket = bucket.NewBucket[uint64, uint16]( + InitPacketBufferSizeVideo, + bucket.RTPMaxPktSize, + bucket.RTPSeqNumOffset, + ) + + if b.frameRateCalculator[0] == nil { + b.createFrameRateCalculator() + } + + if bitrate > 0 { + pps := bitrate / 8 / 1200 + for pps > b.bucket.Capacity() { + if b.bucket.Grow() >= b.params.MaxVideoPkts { + break + } + } + } + + default: + b.codecType = webrtc.RTPCodecType(0) + } + + for _, fb := range codec.RTCPFeedback { + switch fb.Type { + case webrtc.TypeRTCPFBGoogREMB: + b.logger.Debugw("Setting feedback", "type", webrtc.TypeRTCPFBGoogREMB) + b.logger.Debugw("REMB not supported, RTCP feedback will not be generated") + + case webrtc.TypeRTCPFBNACK: + // pion uses a single mediaengine to manage negotiated codecs of peerconnection, that means we can't have different + // codec settings at track level for same codec type, so enable nack for all audio receivers but don't create nack queue + // for red codec. + if b.mime == mime.MimeTypeRED { + break + } + + b.logger.Debugw("Setting feedback", "type", webrtc.TypeRTCPFBNACK) + b.nacker = nack.NewNACKQueue(nack.NackQueueParamsDefault) + } + } + + if b.nacker == nil && b.params.IsOOBSequenceNumber { + b.nacker = nack.NewNACKQueue(nack.NackQueueParamsDefault) + } + + b.StartKeyFrameSeeder() + + return nil +} + +func (b *BufferBase) CloseWithReason(reason string) (stats *livekit.RTPStats, err error) { + b.closeOnce.Do(func() { + b.isClosed.Store(true) + + b.StopKeyFrameSeeder() + + b.Lock() + stats, _ = b.stopRTPStats(reason) + b.readCond.Broadcast() + b.Unlock() + + go b.flushExtPackets() + }) + return +} + +func (b *BufferBase) IsClosed() bool { + return b.isClosed.Load() +} + +func (b *BufferBase) SetPaused(paused bool) { + b.Lock() + defer b.Unlock() + + b.isPaused = paused +} + +func (b *BufferBase) SetAudioLevelConfig(audioLevelConfig audio.AudioLevelConfig) { + b.Lock() + defer b.Unlock() + + b.audioLevelConfig = audioLevelConfig +} + +func (b *BufferBase) SetStreamRestartDetection(enable bool) { + b.Lock() + defer b.Unlock() + + b.enableStreamRestartDetection = enable +} + +func (b *BufferBase) setupRTPStats(clockRate uint32) { + b.rtpStats = rtpstats.NewRTPStatsReceiver(rtpstats.RTPStatsParams{ + ClockRate: clockRate, + Logger: b.logger, + }) + b.ppsSnapshotId = b.rtpStats.NewSnapshotId() + if b.params.IsReportingEnabled { + b.rrSnapshotId = b.rtpStats.NewSnapshotId() + b.deltaStatsSnapshotId = b.rtpStats.NewSnapshotId() + } + + if b.params.IsOOBSequenceNumber { + b.rtpStatsLite = rtpstats.NewRTPStatsReceiverLite(rtpstats.RTPStatsParams{ + ClockRate: clockRate, + Logger: b.logger, + }) + b.liteStatsSnapshotId = b.rtpStatsLite.NewSnapshotLiteId() + } +} + +func (b *BufferBase) stopRTPStats(reason string) (stats *livekit.RTPStats, statsLite *livekit.RTPStats) { + if b.rtpStats != nil { + b.rtpStats.Stop() + stats = b.rtpStats.ToProto() + } + if b.rtpStatsLite != nil { + b.rtpStatsLite.Stop() + statsLite = b.rtpStatsLite.ToProto() + } + + b.logger.Debugw( + "rtp stats", + "direction", "upstream", + "stats", b.rtpStats, + "statsLite", b.rtpStatsLite, + "reason", reason, + ) + return +} + +func (b *BufferBase) RestartStream(reason string) { + b.Lock() + defer b.Unlock() + + b.restartStreamLocked(reason, false) + b.readCond.Broadcast() +} + +func (b *BufferBase) restartStreamLocked(reason string, isDetected bool) { + b.logger.Infow("stream restart", "reason", reason) + + // stop + b.StopKeyFrameSeeder() + b.stopRTPStats("stream-restart") + b.flushExtPacketsLocked() + + // restart + b.snRangeMap = utils.NewRangeMap[uint64, uint64](100) + b.setupRTPStats(b.clockRate) + + b.bucket.ResyncOnNextPacket() + b.lastBucketCapCheckAt = mono.UnixNano() + + if b.nacker != nil { + b.nacker = nack.NewNACKQueue(nack.NackQueueParamsDefault) + } + + if b.audioLevel != nil { + b.audioLevel = audio.NewAudioLevel(audio.AudioLevelParams{ + Config: b.audioLevelConfig, + ClockRate: b.clockRate, + }) + } + + if b.ddExtID != 0 { + b.createDDParserAndFrameRateCalculator() + } + + b.frameRateCalculated = false + if b.frameRateCalculator[0] == nil { + b.createFrameRateCalculator() + } + + b.StartKeyFrameSeeder() + + b.isRestartPending = true + + if f := b.onStreamRestart; f != nil && isDetected { + go f(reason) + } +} + +func (b *BufferBase) createDDParserAndFrameRateCalculator() { + if mime.IsMimeTypeSVCCapable(b.mime) || b.mime == mime.MimeTypeVP8 { + frc := NewFrameRateCalculatorDD(b.clockRate, b.logger) + for i := range b.frameRateCalculator { + b.frameRateCalculator[i] = frc.GetFrameRateCalculatorForSpatial(int32(i)) + } + b.ddParser = NewDependencyDescriptorParser( + b.ddExtID, + b.logger, + func(spatial, temporal int32) { + frc.SetMaxLayer(spatial, temporal) + }, + b.params.IsDDRestartEnabled, + ) + } +} + +func (b *BufferBase) createFrameRateCalculator() { + switch b.mime { + case mime.MimeTypeVP8: + b.frameRateCalculator[0] = NewFrameRateCalculatorVP8(b.clockRate, b.logger) + + case mime.MimeTypeVP9: + frc := NewFrameRateCalculatorVP9(b.clockRate, b.logger) + for i := range b.frameRateCalculator { + b.frameRateCalculator[i] = frc.GetFrameRateCalculatorForSpatial(int32(i)) + } + + case mime.MimeTypeH265: + b.frameRateCalculator[0] = NewFrameRateCalculatorH26x(b.clockRate, b.logger) + } +} + +func (b *BufferBase) ReadExtended(buf []byte) (*ExtPacket, error) { + b.Lock() + for { + if b.isClosed.Load() { + b.Unlock() + return nil, io.EOF + } + + if b.isRestartPending { + b.isRestartPending = false + b.Unlock() + return nil, nil + } + + if b.extPackets.Len() > 0 { + ep := b.extPackets.PopFront() + patched := b.patchExtPacket(ep, buf) + if patched == nil { + ReleaseExtPacket(ep) + continue + } + + b.Unlock() + return patched, nil + } + + b.readCond.Wait() + } +} + +func (b *BufferBase) SetPLIThrottle(duration int64) { + b.Lock() + defer b.Unlock() + + b.pliThrottle = duration +} + +func (b *BufferBase) SendPLI(force bool) { + b.RLock() + if b.codecType != webrtc.RTPCodecTypeVideo { + b.RUnlock() + return + } + + rtpStats := b.rtpStats + pliThrottle := b.pliThrottle + b.RUnlock() + + if (rtpStats == nil && !force) || !rtpStats.CheckAndUpdatePli(pliThrottle, force) { + return + } + + if b.params.SendPLI != nil { + b.params.SendPLI() + } +} + +func (b *BufferBase) SetRTT(rtt uint32) { + b.Lock() + defer b.Unlock() + + if rtt == 0 { + return + } + + if b.nacker != nil { + b.nacker.SetRTT(rtt) + } + + if b.rtpStats != nil { + b.rtpStats.UpdateRtt(rtt) + } +} + +func (b *BufferBase) WaitRead() { + b.readCond.Wait() +} + +func (b *BufferBase) NotifyRead() { + b.readCond.Broadcast() +} + +func (b *BufferBase) HandleIncomingPacket( + rawPkt []byte, + rtpPacket *rtp.Packet, + arrivalTime int64, + isBuffered bool, + isRTX bool, + skippedSeqs []uint16, + oobSequenceNumber uint16, +) (uint64, error) { + b.Lock() + defer b.Unlock() + + if b.isClosed.Load() { + return 0, io.EOF + } + + return b.HandleIncomingPacketLocked( + rawPkt, + rtpPacket, + arrivalTime, + isBuffered, + isRTX, + skippedSeqs, + oobSequenceNumber, + ) +} + +func (b *BufferBase) HandleIncomingPacketLocked( + rawPkt []byte, + rtpPacket *rtp.Packet, + arrivalTime int64, + isBuffered bool, + isRTX bool, + skippedSeqs []uint16, + oobSequenceNumber uint16, +) (uint64, error) { + if rtpPacket == nil { + rtpPacket = &rtp.Packet{} + if err := rtpPacket.Unmarshal(rawPkt); err != nil { + b.logger.Errorw("could not unmarshal RTP packet", err) + return 0, err + } + } + + b.processAudioSsrcLevelHeaderExtension(rtpPacket, arrivalTime) + + if len(skippedSeqs) > 0 { + skippedRtpPkt := rtp.Packet{ + Header: rtpPacket.Header, + } + skippedRtpPkt.Marker = false + // Use the current highest timestamp to prevent the case of old sequence number and newer timestamp. + // It is possible that the skipped packet is older. An example sequence + // - Packet 10, skipped 6, 7, 9 -> Packet 8 is unknown at this point + // - Packet 11, skipped 8 -> this would cause sequence number be older, but using timestamp from Packet 11 will make time stamp diff +ve + skippedRtpPkt.Timestamp = b.rtpStats.HighestTimestamp() + for _, sn := range skippedSeqs { + skippedRtpPkt.SequenceNumber = sn + flowState := b.rtpStats.Update( + arrivalTime, + skippedRtpPkt.Header.SequenceNumber, + skippedRtpPkt.Header.Timestamp, + skippedRtpPkt.Header.Marker, + skippedRtpPkt.Header.MarshalSize(), + len(skippedRtpPkt.Payload), + int(skippedRtpPkt.PaddingSize), + ) + if flowState.UnhandledReason == rtpstats.RTPFlowUnhandledReasonNone && !flowState.IsOutOfOrder { + if err := b.snRangeMap.ExcludeRange(flowState.ExtSequenceNumber, flowState.ExtSequenceNumber+1); err != nil { + b.logger.Errorw( + "could not exclude range", err, + "sequenceNumber", sn, + "extSequenceNumber", flowState.ExtSequenceNumber, + "rtpStats", b.rtpStats, + "rtpStatsLite", b.rtpStatsLite, + "snRangeMap", b.snRangeMap, + "skipped", skippedSeqs, + ) + } + } + } + } + + // do not start on an RTX packet + if isRTX && !b.rtpStats.IsActive() { + return 0, errors.New("cannot start on rtx packet") + } + + flowState := b.rtpStats.Update( + arrivalTime, + rtpPacket.Header.SequenceNumber, + rtpPacket.Header.Timestamp, + rtpPacket.Header.Marker, + rtpPacket.Header.MarshalSize(), + len(rtpPacket.Payload), + int(rtpPacket.PaddingSize), + ) + switch flowState.UnhandledReason { + case rtpstats.RTPFlowUnhandledReasonNone: + case rtpstats.RTPFlowUnhandledReasonRestart: + if !b.enableStreamRestartDetection { + return 0, fmt.Errorf("unhandled reason: %s", flowState.UnhandledReason.String()) + } + + b.restartStreamLocked("discontinuity", true) + + flowState = b.rtpStats.Update( + arrivalTime, + rtpPacket.Header.SequenceNumber, + rtpPacket.Header.Timestamp, + rtpPacket.Header.Marker, + rtpPacket.Header.MarshalSize(), + len(rtpPacket.Payload), + int(rtpPacket.PaddingSize), + ) + default: + return 0, fmt.Errorf("unhandled reason: %s", flowState.UnhandledReason.String()) + } + + if b.params.IsOOBSequenceNumber { + b.updateOOBNACKState(oobSequenceNumber, arrivalTime, len(rawPkt)) + } else { + b.updateNACKState(rtpPacket.SequenceNumber, flowState) + } + + if len(rtpPacket.Payload) == 0 && (!flowState.IsOutOfOrder || flowState.IsDuplicate) { + // drop padding only in-order or duplicate packet + if !flowState.IsOutOfOrder { + // in-order packet - increment sequence number offset for subsequent packets + // Example: + // 40 - regular packet - pass through as sequence number 40 + // 41 - missing packet - don't know what it is, could be padding or not + // 42 - padding only packet - in-order - drop - increment sequence number offset to 1 - + // range[0, 42] = 0 offset + // 41 - arrives out of order - get offset 0 from cache - passed through as sequence number 41 + // 43 - regular packet - offset = 1 (running offset) - passes through as sequence number 42 + // 44 - padding only - in order - drop - increment sequence number offset to 2 + // range[0, 42] = 0 offset, range[43, 44] = 1 offset + // 43 - regular packet - out of order + duplicate - offset = 1 from cache - + // adjusted sequence number is 42, will be dropped by RTX buffer AddPacket method as duplicate + // 45 - regular packet - offset = 2 (running offset) - passed through with adjusted sequence number as 43 + // 44 - padding only - out-of-order + duplicate - dropped as duplicate + // + if err := b.snRangeMap.ExcludeRange(flowState.ExtSequenceNumber, flowState.ExtSequenceNumber+1); err != nil { + b.logger.Errorw( + "could not exclude range", err, + "sn", rtpPacket.SequenceNumber, + "esn", flowState.ExtSequenceNumber, + "rtpStats", b.rtpStats, + "snRangeMap", b.snRangeMap, + ) + } + } + return 0, errors.New("padding only packet") + } + + if !flowState.IsOutOfOrder && rtpPacket.PayloadType != b.payloadType && b.codecType == webrtc.RTPCodecTypeVideo { + b.logger.Infow("possible codec change", "oldPT", b.payloadType, "receivedPT", rtpPacket.PayloadType) + b.handleCodecChange(rtpPacket.PayloadType) + } + + // add to RTX buffer using sequence number after accounting for dropped padding only packets + snAdjustment, err := b.snRangeMap.GetValue(flowState.ExtSequenceNumber) + if err != nil { + b.logger.Errorw( + "could not get sequence number adjustment", err, + "sequenceNumber", rtpPacket.SequenceNumber, + "extSequenceNumber", flowState.ExtSequenceNumber, + "timestamp", rtpPacket.Timestamp, + "extTimestamp", flowState.ExtTimestamp, + "payloadSize", len(rtpPacket.Payload), + "paddingSize", rtpPacket.PaddingSize, + "rtpStats", b.rtpStats, + "rtpStatsLite", b.rtpStatsLite, + "snRangeMap", b.snRangeMap, + ) + return 0, err + } + + flowState.ExtSequenceNumber -= snAdjustment + rtpPacket.Header.SequenceNumber = uint16(flowState.ExtSequenceNumber) + if _, err = b.bucket.AddPacketWithSequenceNumber(rawPkt, flowState.ExtSequenceNumber); err != nil { + if !flowState.IsDuplicate { + if errors.Is(err, bucket.ErrPacketTooOld) { + packetTooOldCount := b.packetTooOldCount.Inc() + if (packetTooOldCount-1)%100 == 0 { + b.logger.Warnw( + "could not add packet to bucket", err, + "count", packetTooOldCount, + "flowState", &flowState, + "snAdjustment", snAdjustment, + "incomingSequenceNumber", flowState.ExtSequenceNumber+snAdjustment, + "rtpStats", b.rtpStats, + "rtpStatsLite", b.rtpStatsLite, + "snRangeMap", b.snRangeMap, + "skipped", skippedSeqs, + ) + } + } else if err != bucket.ErrRTXPacket { + b.logger.Warnw( + "could not add packet to bucket", err, + "flowState", &flowState, + "snAdjustment", snAdjustment, + "incomingSequenceNumber", flowState.ExtSequenceNumber+snAdjustment, + "rtpStats", b.rtpStats, + "rtpStatsLite", b.rtpStatsLite, + "snRangeMap", b.snRangeMap, + "skipped", skippedSeqs, + ) + } + } + return 0, err + } + + ep := b.getExtPacket(rtpPacket, arrivalTime, isBuffered, flowState) + if ep == nil { + return 0, errors.New("could not get ext packet") + } + b.extPackets.PushBack(ep) + b.readCond.Broadcast() + + if b.extPackets.Len() > b.bucket.Capacity() { + if (b.extPacketTooMuchCount.Inc()-1)%100 == 0 { + b.logger.Warnw("too much ext packets", nil, "count", b.extPackets.Len()) + } + } + + b.maybeGrowBucket(arrivalTime) + + return ep.ExtSequenceNumber, nil +} + +func (b *BufferBase) updateNACKState(sequenceNumber uint16, flowState rtpstats.RTPFlowState) { + if b.nacker == nil { + return + } + + b.nacker.Remove(sequenceNumber) + + for lost := flowState.LossStartInclusive; lost != flowState.LossEndExclusive; lost++ { + b.nacker.Push(uint16(lost)) + } +} + +func (b *BufferBase) updateOOBNACKState(sequenceNumber uint16, arrivalTime int64, size int) { + if b.nacker == nil || !b.params.IsOOBSequenceNumber { + return + } + + fsLite := b.rtpStatsLite.Update(arrivalTime, size, sequenceNumber) + if fsLite.IsNotHandled { + return + } + + b.nacker.Remove(sequenceNumber) + + for lost := fsLite.LossStartInclusive; lost != fsLite.LossEndExclusive; lost++ { + b.nacker.Push(uint16(lost)) + } +} + +func (b *BufferBase) processAudioSsrcLevelHeaderExtension(p *rtp.Packet, arrivalTime int64) { + if b.audioLevelExtID == 0 { + return + } + + if e := p.GetExtension(b.audioLevelExtID); e != nil { + ext := rtp.AudioLevelExtension{} + if err := ext.Unmarshal(e); err == nil { + b.audioLevel.ObserveWithRTPTimestamp(ext.Level, p.Timestamp, arrivalTime) + } + } +} + +func (b *BufferBase) handleCodecChange(newPT uint8) { + var ( + codecFound, rtxFound bool + rtxPt uint8 + newCodec webrtc.RTPCodecParameters + ) + for _, codec := range b.rtpParameters.Codecs { + if !codecFound && uint8(codec.PayloadType) == newPT { + newCodec = codec + codecFound = true + } + + if mime.IsMimeTypeStringRTX(codec.MimeType) && strings.Contains(codec.SDPFmtpLine, fmt.Sprintf("apt=%d", newPT)) { + rtxFound = true + rtxPt = uint8(codec.PayloadType) + } + + if codecFound && rtxFound { + break + } + } + if !codecFound { + b.logger.Errorw( + "could not find codec for new payload type", nil, + "pt", newPT, + "rtpParameters", b.rtpParameters, + ) + return + } + + b.logger.Infow( + "codec changed", + "oldPayload", b.payloadType, "newPayload", newPT, + "oldRtxPayload", b.rtxPayloadType, "newRtxPayload", rtxPt, + "oldMime", b.mime, "newMime", newCodec.MimeType, + ) + b.payloadType = newPT + b.rtxPayloadType = rtxPt + b.mime = mime.NormalizeMimeType(newCodec.MimeType) + + if f := b.onCodecChange; f != nil { + go f(newCodec) + } +} + +func (b *BufferBase) getExtPacket( + rtpPacket *rtp.Packet, + arrivalTime int64, + isBuffered bool, + flowState rtpstats.RTPFlowState, +) *ExtPacket { + ep := ExtPacketFactory.Get().(*ExtPacket) + *ep = ExtPacket{ + Arrival: arrivalTime, + ExtSequenceNumber: flowState.ExtSequenceNumber, + ExtTimestamp: flowState.ExtTimestamp, + Packet: rtpPacket, + VideoLayer: VideoLayer{ + Spatial: InvalidLayerSpatial, + Temporal: InvalidLayerTemporal, + }, + IsOutOfOrder: flowState.IsOutOfOrder, + IsBuffered: isBuffered, + } + + if len(ep.Packet.Payload) == 0 { + // padding only packet, nothing else to do + return ep + } + + if err := b.processVideoPacket(ep); err != nil { + ReleaseExtPacket(ep) + return nil + } + + if b.absCaptureTimeExtID != 0 { + extData := rtpPacket.GetExtension(b.absCaptureTimeExtID) + + var actExt act.AbsCaptureTime + if err := actExt.Unmarshal(extData); err == nil { + ep.AbsCaptureTimeExt = &actExt + } + } + + return ep +} + +func (b *BufferBase) processVideoPacket(ep *ExtPacket) error { + if b.codecType != webrtc.RTPCodecTypeVideo { + return nil + } + + ep.Temporal = 0 + var videoSize []VideoSize + if b.ddParser != nil { + ddVal, videoLayer, err := b.ddParser.Parse(ep.Packet) + if err != nil { + if errors.Is(err, ErrDDExtentionNotFound) { + if b.mime == mime.MimeTypeVP8 || b.mime == mime.MimeTypeVP9 { + b.logger.Infow("dd extension not found, disable dd parser") + b.ddParser = nil + b.createFrameRateCalculator() + } + } else { + return err + } + } else if ddVal != nil { + ep.DependencyDescriptor = ddVal + ep.VideoLayer = videoLayer + videoSize = ExtractDependencyDescriptorVideoSize(ddVal.Descriptor) + // DD-TODO : notify active decode target change if changed. + } + } + + switch b.mime { + case mime.MimeTypeVP8: + vp8Packet := VP8{} + if err := vp8Packet.Unmarshal(ep.Packet.Payload); err != nil { + b.logger.Warnw("could not unmarshal VP8 packet", err) + return err + } + ep.IsKeyFrame = vp8Packet.IsKeyFrame + if ep.DependencyDescriptor == nil { + ep.Temporal = int32(vp8Packet.TID) + + if ep.IsKeyFrame { + if sz := ExtractVP8VideoSize(&vp8Packet, ep.Packet.Payload); sz.Width > 0 && sz.Height > 0 { + videoSize = append(videoSize, sz) + } + } + } else { + // vp8 with DependencyDescriptor enabled, use the TID from the descriptor + vp8Packet.TID = uint8(ep.Temporal) + } + ep.Payload = vp8Packet + ep.Spatial = InvalidLayerSpatial // vp8 don't have spatial scalability, reset to invalid + + case mime.MimeTypeVP9: + if ep.DependencyDescriptor == nil { + var vp9Packet codecs.VP9Packet + _, err := vp9Packet.Unmarshal(ep.Packet.Payload) + if err != nil { + b.logger.Warnw("could not unmarshal VP9 packet", err) + return err + } + ep.VideoLayer = VideoLayer{ + Spatial: int32(vp9Packet.SID), + Temporal: int32(vp9Packet.TID), + } + ep.Payload = vp9Packet + ep.IsKeyFrame = IsVP9KeyFrame(&vp9Packet, ep.Packet.Payload) + + if ep.IsKeyFrame { + for i := 0; i < len(vp9Packet.Width); i++ { + videoSize = append(videoSize, VideoSize{ + Width: uint32(vp9Packet.Width[i]), + Height: uint32(vp9Packet.Height[i]), + }) + } + } + } else { + ep.IsKeyFrame = IsVP9KeyFrame(nil, ep.Packet.Payload) + } + + case mime.MimeTypeH264: + ep.IsKeyFrame = IsH264KeyFrame(ep.Packet.Payload) + ep.Spatial = InvalidLayerSpatial // h.264 don't have spatial scalability, reset to invalid + + // Check H264 key frame video size + if ep.IsKeyFrame { + if sz := ExtractH264VideoSize(ep.Packet.Payload); sz.Width > 0 && sz.Height > 0 { + videoSize = append(videoSize, sz) + } + } + + case mime.MimeTypeAV1: + ep.IsKeyFrame = IsAV1KeyFrame(ep.Packet.Payload) + + case mime.MimeTypeH265: + ep.IsKeyFrame = IsH265KeyFrame(ep.Packet.Payload) + if ep.DependencyDescriptor == nil { + if len(ep.Packet.Payload) < 2 { + b.logger.Warnw("invalid H265 packet", nil, "payloadLen", len(ep.Packet.Payload)) + return errors.New("invalid H265 packet") + } + ep.VideoLayer = VideoLayer{ + Temporal: int32(ep.Packet.Payload[1]&0x07) - 1, + } + ep.Spatial = InvalidLayerSpatial + + if ep.IsKeyFrame { + if sz := ExtractH265VideoSize(ep.Packet.Payload); sz.Width > 0 && sz.Height > 0 { + videoSize = append(videoSize, sz) + } + } + } + } + + if ep.IsKeyFrame { + if b.rtpStats != nil { + b.rtpStats.UpdateKeyFrame(1) + } + } + + if len(videoSize) > 0 { + b.checkVideoSizeChange(videoSize) + } + + b.doFpsCalc(ep) + + return nil +} + +func (b *BufferBase) patchExtPacket(ep *ExtPacket, buf []byte) *ExtPacket { + n, err := b.getPacketLocked(buf, ep.ExtSequenceNumber) + if err != nil { + packetNotFoundCount := b.packetNotFoundCount.Inc() + if (packetNotFoundCount-1)%20 == 0 { + b.logger.Warnw( + "could not get packet from bucket", err, + "sn", ep.Packet.SequenceNumber, + "headSN", b.bucket.HeadSequenceNumber(), + "count", packetNotFoundCount, + "rtpStats", b.rtpStats, + "rtpStatsLite", b.rtpStatsLite, + "snRangeMap", b.snRangeMap, + ) + } + return nil + } + ep.RawPacket = buf[:n] + + // patch RTP packet to point payload to new buffer + pkt := *ep.Packet + payloadStart := ep.Packet.Header.MarshalSize() + payloadEnd := payloadStart + len(ep.Packet.Payload) + if payloadEnd > n { + b.logger.Warnw("unexpected marshal size", nil, "max", n, "need", payloadEnd) + return nil + } + pkt.Payload = buf[payloadStart:payloadEnd] + ep.Packet = &pkt + + return ep +} + +func (b *BufferBase) flushExtPackets() { + b.Lock() + defer b.Unlock() + b.flushExtPacketsLocked() +} + +func (b *BufferBase) flushExtPacketsLocked() { + for b.extPackets.Len() > 0 { + ep := b.extPackets.PopFront() + ReleaseExtPacket(ep) + } + b.extPackets.Clear() +} + +func (b *BufferBase) maybeGrowBucket(now int64) { + if now-b.lastBucketCapCheckAt < bucketCapCheckInterval { + return + } + + // check and allocate in a go routine, away from the forwarding path + go func() { + b.Lock() + defer b.Unlock() + + b.lastBucketCapCheckAt = now + + cap := b.bucket.Capacity() + maxPkts := b.params.MaxVideoPkts + if b.codecType == webrtc.RTPCodecTypeAudio { + maxPkts = b.params.MaxAudioPkts + } + if cap >= maxPkts { + return + } + + oldCap := cap + if deltaInfo := b.rtpStats.DeltaInfo(b.ppsSnapshotId); deltaInfo != nil { + duration := deltaInfo.EndTime.Sub(deltaInfo.StartTime) + if duration < 500*time.Millisecond { + return + } + + pps := int(time.Duration(deltaInfo.Packets) * time.Second / duration) + for pps > cap && cap < maxPkts { + cap = b.bucket.Grow() + } + if cap > oldCap { + b.logger.Infow( + "grow bucket", + "from", oldCap, + "to", cap, + "pps", pps, + "deltaInfo", deltaInfo, + "rtpStats", b.rtpStats, + ) + } + } + }() +} + +func (b *BufferBase) doFpsCalc(ep *ExtPacket) { + if b.isPaused || b.frameRateCalculated || len(ep.Packet.Payload) == 0 { + return + } + + spatial := ep.Spatial + if spatial < 0 || int(spatial) >= len(b.frameRateCalculator) { + spatial = 0 + } + if fr := b.frameRateCalculator[spatial]; fr != nil { + if fr.RecvPacket(ep) { + complete := true + for _, fr2 := range b.frameRateCalculator { + if fr2 != nil && !fr2.Completed() { + complete = false + break + } + } + if complete { + b.frameRateCalculated = true + if f := b.onFpsChanged; f != nil { + go f() + } + } + } + } +} + +func (b *BufferBase) SetSenderReportData(srData *livekit.RTCPSenderReportState) { + b.RLock() + didSet := false + if b.rtpStats != nil { + didSet = b.rtpStats.SetRtcpSenderReportData(srData) + } + b.RUnlock() + + if didSet { + if cb := b.getOnRtcpSenderReport(); cb != nil { + cb() + } + } +} + +func (b *BufferBase) GetSenderReportData() *livekit.RTCPSenderReportState { + b.RLock() + defer b.RUnlock() + + if b.rtpStats != nil { + return b.rtpStats.GetRtcpSenderReportData() + } + + return nil +} + +func (b *BufferBase) GetPacket(buff []byte, esn uint64) (int, error) { + b.Lock() + defer b.Unlock() + + return b.getPacketLocked(buff, esn) +} + +func (b *BufferBase) getPacketLocked(buff []byte, esn uint64) (int, error) { + if b.isClosed.Load() { + return 0, io.EOF + } + return b.bucket.GetPacket(buff, esn) +} + +func (b *BufferBase) GetStats() *livekit.RTPStats { + b.RLock() + defer b.RUnlock() + + if b.rtpStats == nil { + return nil + } + + return b.rtpStats.ToProto() +} + +func (b *BufferBase) GetDeltaStats() *StreamStatsWithLayers { + b.RLock() + defer b.RUnlock() + + if b.rtpStats == nil { + return nil + } + + deltaStats := b.rtpStats.DeltaInfo(b.deltaStatsSnapshotId) + if deltaStats == nil { + return nil + } + + return &StreamStatsWithLayers{ + RTPStats: deltaStats, + Layers: map[int32]*rtpstats.RTPDeltaInfo{ + 0: deltaStats, + }, + } +} + +func (b *BufferBase) GetDeltaStatsLite() *rtpstats.RTPDeltaInfoLite { + b.RLock() + defer b.RUnlock() + + if b.rtpStatsLite == nil { + return nil + } + + return b.rtpStatsLite.DeltaInfoLite(b.liteStatsSnapshotId) +} + +func (b *BufferBase) GetLastSenderReportTime() time.Time { + b.RLock() + defer b.RUnlock() + + if b.rtpStats == nil { + return time.Time{} + } + + return b.rtpStats.LastSenderReportTime() +} + +func (b *BufferBase) GetAudioLevel() (float64, bool) { + b.RLock() + defer b.RUnlock() + + if b.audioLevel == nil { + return 0, false + } + + return b.audioLevel.GetLevel(mono.UnixNano()) +} + +func (b *BufferBase) OnRtcpSenderReport(fn func()) { + b.Lock() + b.onRtcpSenderReport = fn + b.Unlock() +} + +func (b *BufferBase) getOnRtcpSenderReport() func() { + b.RLock() + defer b.RUnlock() + + return b.onRtcpSenderReport +} + +func (b *BufferBase) OnFpsChanged(f func()) { + b.Lock() + b.onFpsChanged = f + b.Unlock() +} + +func (b *BufferBase) OnVideoSizeChanged(fn func([]VideoSize)) { + b.Lock() + b.onVideoSizeChanged = fn + b.Unlock() +} + +func (b *BufferBase) OnCodecChange(fn func(webrtc.RTPCodecParameters)) { + b.Lock() + b.onCodecChange = fn + b.Unlock() +} + +func (b *BufferBase) OnStreamRestart(fn func(string)) { + b.Lock() + b.onStreamRestart = fn + b.Unlock() +} + +// checkVideoSizeChange checks if video size has changed for a specific spatial layer and fires callback +func (b *BufferBase) checkVideoSizeChange(videoSizes []VideoSize) { + if len(videoSizes) > len(b.currentVideoSize) { + b.logger.Warnw( + "video size index out of range", nil, + "newSize", videoSizes, + "currentVideoSize", b.currentVideoSize, + ) + return + } + + if len(videoSizes) < len(b.currentVideoSize) { + videoSizes = append(videoSizes, make([]VideoSize, len(b.currentVideoSize)-len(videoSizes))...) + } + + changed := false + for i, sz := range videoSizes { + if b.currentVideoSize[i].Width != sz.Width || b.currentVideoSize[i].Height != sz.Height { + changed = true + break + } + } + + if changed { + b.logger.Debugw("video size changed", "from", b.currentVideoSize, "to", videoSizes) + copy(b.currentVideoSize[:], videoSizes[:]) + if b.onVideoSizeChanged != nil { + go b.onVideoSizeChanged(videoSizes) + } + } +} + +func (b *BufferBase) GetTemporalLayerFpsForSpatial(layer int32) []float32 { + b.RLock() + defer b.RUnlock() + + if int(layer) >= len(b.frameRateCalculator) { + return nil + } + + if fc := b.frameRateCalculator[layer]; fc != nil { + return fc.GetFrameRate() + } + return nil +} + +func (b *BufferBase) StartKeyFrameSeeder() { + if b.codecType == webrtc.RTPCodecTypeVideo { + go b.seedKeyFrame(b.keyFrameSeederGeneration.Inc()) + } +} + +func (b *BufferBase) StopKeyFrameSeeder() { + b.keyFrameSeederGeneration.Inc() +} + +func (b *BufferBase) seedKeyFrame(keyFrameSeederGeneration int32) { + // a key frame is needed especially when using Dependency Descriptor + // to get the DD structure which is used in parsing subsequent packets, + // till then packets are dropped which results in stream tracker not + // getting any data which means it does not declare layer start. + // + // send gratuitous PLIs for some time or until a key frame is seen to + // get the engine rolling + b.logger.Debugw("starting key frame seeder", "generation", keyFrameSeederGeneration) + timer := time.NewTimer(30 * time.Second) + defer timer.Stop() + + ticker := time.NewTicker(time.Second) + defer ticker.Stop() + + initialCount := uint32(0) + b.RLock() + rtpStats := b.rtpStats + b.RUnlock() + if rtpStats == nil { + b.logger.Debugw("cannot do key frame seeding without stats", "generation", keyFrameSeederGeneration) + return + } + initialCount, _ = rtpStats.KeyFrame() + + for { + if b.isClosed.Load() || b.keyFrameSeederGeneration.Load() != keyFrameSeederGeneration { + b.logger.Debugw( + "stopping key frame seeder: stopped", + "generation", keyFrameSeederGeneration, + "currentGeneration", b.keyFrameSeederGeneration.Load(), + ) + return + } + + select { + case <-timer.C: + b.logger.Debugw("stopping key frame seeder: timeout", "generation", keyFrameSeederGeneration) + return + + case <-ticker.C: + cnt, last := rtpStats.KeyFrame() + if cnt > initialCount { + b.logger.Debugw( + "stopping key frame seeder: received key frame", + "generation", keyFrameSeederGeneration, + "keyFrameCountInitial", initialCount, + "keyFrameCount", cnt, + "lastKeyFrame", last, + ) + return + } + + b.SendPLI(false) + } + } +} + +func (b *BufferBase) GetNACKPairs() []rtcp.NackPair { + b.RLock() + defer b.RUnlock() + + return b.GetNACKPairsLocked() +} + +func (b *BufferBase) GetNACKPairsLocked() []rtcp.NackPair { + if b.nacker == nil { + return nil + } + + pairs, numSeqNumsNacked := b.nacker.Pairs() + if !b.params.IsOOBSequenceNumber { + if b.rtpStats != nil { + b.rtpStats.UpdateNack(uint32(numSeqNumsNacked)) + } + } else { + if b.rtpStatsLite != nil { + b.rtpStatsLite.UpdateNack(uint32(numSeqNumsNacked)) + } + } + + return pairs +} + +func (b *BufferBase) GetRtcpReceptionReportLocked(proxyLoss uint8) *rtcp.ReceptionReport { + if b.rtpStats == nil { + return nil + } + + return b.rtpStats.GetRtcpReceptionReport(b.params.SSRC, proxyLoss, b.rrSnapshotId) +} + +// --------------------------------------------------------------- diff --git a/livekit/pkg/sfu/buffer/buffer_test.go b/livekit/pkg/sfu/buffer/buffer_test.go new file mode 100644 index 0000000..f0f3076 --- /dev/null +++ b/livekit/pkg/sfu/buffer/buffer_test.go @@ -0,0 +1,479 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package buffer + +import ( + "fmt" + "math" + "sync" + "testing" + "time" + + "github.com/pion/rtcp" + "github.com/pion/rtp" + "github.com/pion/webrtc/v4" + "github.com/stretchr/testify/require" + + "github.com/livekit/mediatransportutil/pkg/nack" +) + +var h265Codec = webrtc.RTPCodecParameters{ + RTPCodecCapability: webrtc.RTPCodecCapability{ + MimeType: "video/h265", + ClockRate: 90000, + RTCPFeedback: []webrtc.RTCPFeedback{{ + Type: "nack", + }}, + }, + PayloadType: 116, +} + +var vp8Codec = webrtc.RTPCodecParameters{ + RTPCodecCapability: webrtc.RTPCodecCapability{ + MimeType: "video/vp8", + ClockRate: 90000, + RTCPFeedback: []webrtc.RTCPFeedback{{ + Type: "nack", + }}, + }, + PayloadType: 96, +} + +var opusCodec = webrtc.RTPCodecParameters{ + RTPCodecCapability: webrtc.RTPCodecCapability{ + MimeType: "audio/opus", + ClockRate: 48000, + }, + PayloadType: 111, +} + +func TestNack(t *testing.T) { + t.Run("nack normal", func(t *testing.T) { + buff := NewBuffer(123, 1, 1) + buff.codecType = webrtc.RTPCodecTypeVideo + require.NotNil(t, buff) + var wg sync.WaitGroup + // 5 tries + wg.Add(5) + buff.OnRtcpFeedback(func(fb []rtcp.Packet) { + for _, pkt := range fb { + switch p := pkt.(type) { + case *rtcp.TransportLayerNack: + if p.Nacks[0].PacketList()[0] == 1 && p.MediaSSRC == 123 { + wg.Done() + } + } + } + }) + buff.Bind(webrtc.RTPParameters{ + HeaderExtensions: nil, + Codecs: []webrtc.RTPCodecParameters{vp8Codec}, + }, vp8Codec.RTPCodecCapability, 0) + rtt := uint32(20) + buff.nacker.SetRTT(rtt) + for i := range 15 { + if i == 1 { + continue + } + if i < 14 { + time.Sleep(time.Duration(float64(rtt)*math.Pow(nack.NackQueueParamsDefault.BackoffFactor, float64(i))+10) * time.Millisecond) + } else { + time.Sleep(500 * time.Millisecond) // even a long wait should not exceed max retries + } + pkt := rtp.Packet{ + Header: rtp.Header{ + Version: 2, + PayloadType: 96, + SequenceNumber: uint16(i), + Timestamp: uint32(i), + SSRC: 123, + }, + Payload: []byte{0xff, 0xff, 0xff, 0xfd, 0xb4, 0x9f, 0x94, 0x1}, + } + b, err := pkt.Marshal() + require.NoError(t, err) + _, err = buff.Write(b) + require.NoError(t, err) + } + wg.Wait() + + }) + + t.Run("nack with seq wrap", func(t *testing.T) { + buff := NewBuffer(123, 1, 1) + buff.codecType = webrtc.RTPCodecTypeVideo + require.NotNil(t, buff) + var wg sync.WaitGroup + expects := map[uint16]int{ + 65534: 0, + 65535: 0, + 0: 0, + 1: 0, + } + wg.Add(5 * len(expects)) // retry 5 times + buff.OnRtcpFeedback(func(fb []rtcp.Packet) { + for _, pkt := range fb { + switch p := pkt.(type) { + case *rtcp.TransportLayerNack: + if p.MediaSSRC == 123 { + for _, v := range p.Nacks { + v.Range(func(seq uint16) bool { + if _, ok := expects[seq]; ok { + wg.Done() + } else { + require.Fail(t, "unexpected nack seq ", seq) + } + return true + }) + } + } + } + } + }) + buff.Bind(webrtc.RTPParameters{ + HeaderExtensions: nil, + Codecs: []webrtc.RTPCodecParameters{vp8Codec}, + }, vp8Codec.RTPCodecCapability, 0) + rtt := uint32(30) + buff.nacker.SetRTT(rtt) + for i := range 15 { + if i > 0 && i < 5 { + continue + } + if i < 14 { + time.Sleep(time.Duration(float64(rtt)*math.Pow(nack.NackQueueParamsDefault.BackoffFactor, float64(i))+10) * time.Millisecond) + } else { + time.Sleep(500 * time.Millisecond) // even a long wait should not exceed max retries + } + pkt := rtp.Packet{ + Header: rtp.Header{ + Version: 2, + PayloadType: 96, + SequenceNumber: uint16(i + 65533), + Timestamp: uint32(i), + SSRC: 123, + }, + Payload: []byte{0xff, 0xff, 0xff, 0xfd, 0xb4, 0x9f, 0x94, 0x1}, + } + b, err := pkt.Marshal() + require.NoError(t, err) + _, err = buff.Write(b) + require.NoError(t, err) + } + wg.Wait() + + }) +} + +func TestNewBuffer(t *testing.T) { + tests := []struct { + name string + }{ + { + name: "Must not be nil and add packets in sequence", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var TestPackets = []*rtp.Packet{ + { + Header: rtp.Header{ + Version: 2, + PayloadType: 96, + SequenceNumber: 65533, + SSRC: 123, + }, + }, + { + Header: rtp.Header{ + Version: 2, + PayloadType: 96, + SequenceNumber: 65534, + SSRC: 123, + }, + Payload: []byte{1}, + }, + { + Header: rtp.Header{ + Version: 2, + PayloadType: 96, + SequenceNumber: 2, + SSRC: 123, + }, + }, + { + Header: rtp.Header{ + Version: 2, + PayloadType: 96, + SequenceNumber: 65535, + SSRC: 123, + }, + }, + } + buff := NewBuffer(123, 1, 1) + buff.codecType = webrtc.RTPCodecTypeVideo + require.NotNil(t, buff) + buff.OnRtcpFeedback(func(_ []rtcp.Packet) {}) + buff.Bind(webrtc.RTPParameters{ + HeaderExtensions: nil, + Codecs: []webrtc.RTPCodecParameters{vp8Codec}, + }, vp8Codec.RTPCodecCapability, 0) + + for _, p := range TestPackets { + buf, _ := p.Marshal() + _, _ = buff.Write(buf) + } + require.Equal(t, uint16(2), buff.rtpStats.HighestSequenceNumber()) + require.Equal(t, uint64(65536+2), buff.rtpStats.ExtendedHighestSequenceNumber()) + }) + } +} + +func TestFractionLostReport(t *testing.T) { + buff := NewBuffer(123, 1, 1) + require.NotNil(t, buff) + + var wg sync.WaitGroup + + // with loss proxying + wg.Add(1) + buff.SetAudioLossProxying(true) + buff.SetLastFractionLostReport(55) + buff.OnRtcpFeedback(func(fb []rtcp.Packet) { + for _, pkt := range fb { + switch p := pkt.(type) { + case *rtcp.ReceiverReport: + for _, v := range p.Reports { + require.EqualValues(t, 55, v.FractionLost) + } + wg.Done() + } + } + }) + buff.Bind(webrtc.RTPParameters{ + HeaderExtensions: nil, + Codecs: []webrtc.RTPCodecParameters{opusCodec}, + }, opusCodec.RTPCodecCapability, 0) + for i := range 15 { + pkt := rtp.Packet{ + Header: rtp.Header{ + Version: 2, + PayloadType: 111, + SequenceNumber: uint16(i), + Timestamp: uint32(i), + SSRC: 123, + }, + Payload: []byte{0xff, 0xff, 0xff, 0xfd, 0xb4, 0x9f, 0x94, 0x1}, + } + b, err := pkt.Marshal() + require.NoError(t, err) + if i == 1 { + time.Sleep(1 * time.Second) + } + _, err = buff.Write(b) + require.NoError(t, err) + } + wg.Wait() + + wg.Add(1) + buff.SetAudioLossProxying(false) + buff.OnRtcpFeedback(func(fb []rtcp.Packet) { + for _, pkt := range fb { + switch p := pkt.(type) { + case *rtcp.ReceiverReport: + for _, v := range p.Reports { + require.EqualValues(t, 0, v.FractionLost) + } + wg.Done() + } + } + }) + buff.Bind(webrtc.RTPParameters{ + HeaderExtensions: nil, + Codecs: []webrtc.RTPCodecParameters{opusCodec}, + }, opusCodec.RTPCodecCapability, 0) + for i := range 15 { + pkt := rtp.Packet{ + Header: rtp.Header{ + Version: 2, + PayloadType: 111, + SequenceNumber: uint16(i), + Timestamp: uint32(i), + SSRC: 123, + }, + Payload: []byte{0xff, 0xff, 0xff, 0xfd, 0xb4, 0x9f, 0x94, 0x1}, + } + b, err := pkt.Marshal() + require.NoError(t, err) + if i == 1 { + time.Sleep(1 * time.Second) + } + _, err = buff.Write(b) + require.NoError(t, err) + } + wg.Wait() +} + +func TestCodecChange(t *testing.T) { + // codec change before bind + buff := NewBuffer(123, 1, 1) + require.NotNil(t, buff) + changedCodec := make(chan webrtc.RTPCodecParameters, 1) + buff.OnCodecChange(func(rp webrtc.RTPCodecParameters) { + select { + case changedCodec <- rp: + default: + t.Fatalf("codec change not consumed") + } + }) + buff.OnStreamRestart(func(reason string) { + require.Equal(t, "codec-change", reason) + + // read once to clear pending restart + var buf [1500]byte + extPkt, err := buff.ReadExtended(buf[:]) + require.NoError(t, err) + require.Nil(t, extPkt) + }) + + h265Pkt := rtp.Packet{ + Header: rtp.Header{ + Version: 2, + PayloadType: 116, + SequenceNumber: 1, + Timestamp: 1, + SSRC: 123, + }, + Payload: []byte{0xff, 0xff, 0xff, 0xfd, 0xb4, 0x9f, 0x94, 0x1}, + } + buf, err := h265Pkt.Marshal() + require.NoError(t, err) + _, err = buff.Write(buf) + require.NoError(t, err) + + select { + case <-changedCodec: + t.Fatalf("unexpected codec change") + case <-time.After(100 * time.Millisecond): + } + + // Bind sets up VP8 as expected codec, + // packet written to the buffer above before Bind is H.265, + // that should trigger a codec change and a stream restart + // when the queued packets from Write before Bind get flushed + buff.Bind( + webrtc.RTPParameters{ + HeaderExtensions: nil, + Codecs: []webrtc.RTPCodecParameters{vp8Codec, h265Codec}, + }, + vp8Codec.RTPCodecCapability, + 0, + ) + + select { + case c := <-changedCodec: + require.Equal(t, h265Codec, c) + case <-time.After(1 * time.Second): + t.Fatalf("expected codec change") + } + + // second codec change - writing VP8 packet after Bind should trigger another codec change + vp8Pkt := rtp.Packet{ + Header: rtp.Header{ + Version: 2, + PayloadType: 96, + SequenceNumber: 3, + Timestamp: 3, + SSRC: 123, + }, + Payload: []byte{0xff, 0xff, 0xff, 0xfd, 0xb4, 0x9f, 0x94, 0x1}, + } + buf, err = vp8Pkt.Marshal() + require.NoError(t, err) + _, err = buff.Write(buf) + require.NoError(t, err) + + select { + case c := <-changedCodec: + require.Equal(t, vp8Codec, c) + case <-time.After(1 * time.Second): + t.Fatalf("expected codec change") + } + fmt.Printf("done second codec change\n") // REMOVE + + // out of order pkts can't cause codec change + // rewrite the VP8 packet to start the sequence after a stream restart + _, err = buff.Write(buf) + require.NoError(t, err) + + h265Pkt.SequenceNumber = 2 + h265Pkt.Timestamp = 2 + buf, err = h265Pkt.Marshal() + require.NoError(t, err) + _, err = buff.Write(buf) + require.NoError(t, err) + select { + case <-changedCodec: + t.Fatalf("unexpected codec change") + case <-time.After(100 * time.Millisecond): + } + + // unknown codec should not cause change even if it is in order + h265Pkt.SequenceNumber = 4 + h265Pkt.Timestamp = 4 + h265Pkt.PayloadType = 117 + buf, err = h265Pkt.Marshal() + require.NoError(t, err) + _, err = buff.Write(buf) + require.NoError(t, err) + select { + case <-changedCodec: + t.Fatalf("unexpected codec change") + case <-time.After(100 * time.Millisecond): + } + + // an in-order packet should change codec again + h265Pkt.SequenceNumber = 5 + h265Pkt.Timestamp = 5 + h265Pkt.PayloadType = 116 + buf, err = h265Pkt.Marshal() + require.NoError(t, err) + _, err = buff.Write(buf) + require.NoError(t, err) + select { + case c := <-changedCodec: + require.Equal(t, h265Codec, c) + case <-time.After(1 * time.Second): + t.Fatalf("expected codec change") + } +} + +func BenchmarkMemcpu(b *testing.B) { + buf := make([]byte, 1500*1500*10) + buf2 := make([]byte, 1500*1500*20) + + for b.Loop() { + copy(buf2, buf) + } +} + +func BenchmarkExtPacketFactory(b *testing.B) { + + for b.Loop() { + extPkt := ExtPacketFactory.Get().(*ExtPacket) + *extPkt = ExtPacket{} + ExtPacketFactory.Put(extPkt) + } +} diff --git a/livekit/pkg/sfu/buffer/datastats.go b/livekit/pkg/sfu/buffer/datastats.go new file mode 100644 index 0000000..795f8a0 --- /dev/null +++ b/livekit/pkg/sfu/buffer/datastats.go @@ -0,0 +1,102 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package buffer + +import ( + "sync" + "time" + + "google.golang.org/protobuf/types/known/timestamppb" + + "github.com/livekit/protocol/livekit" +) + +type DataStatsParam struct { + WindowDuration time.Duration +} + +type DataStats struct { + params DataStatsParam + lock sync.RWMutex + totalBytes int64 + startTime time.Time + endTime time.Time + windowStart int64 + windowBytes int64 +} + +func NewDataStats(params DataStatsParam) *DataStats { + return &DataStats{ + params: params, + startTime: time.Now(), + windowStart: time.Now().UnixNano(), + } +} + +func (s *DataStats) Update(bytes int, time int64) { + s.lock.Lock() + defer s.lock.Unlock() + s.totalBytes += int64(bytes) + + if s.params.WindowDuration > 0 && time-s.windowStart > s.params.WindowDuration.Nanoseconds() { + s.windowBytes = 0 + s.windowStart = time + } + s.windowBytes += int64(bytes) +} + +func (s *DataStats) ToProtoActive() *livekit.RTPStats { + if s.params.WindowDuration == 0 { + return &livekit.RTPStats{} + } + s.lock.RLock() + defer s.lock.RUnlock() + now := time.Now().UnixNano() + duration := now - s.windowStart + if duration > s.params.WindowDuration.Nanoseconds() { + return &livekit.RTPStats{} + } + + return &livekit.RTPStats{ + StartTime: timestamppb.New(time.Unix(s.windowStart/1e9, s.windowStart%1e9)), + EndTime: timestamppb.New(time.Unix(0, now)), + Duration: float64(duration / 1e9), + Bytes: uint64(s.windowBytes), + Bitrate: float64(s.windowBytes) * 8 / float64(duration) / 1e9, + } +} + +func (s *DataStats) Stop() { + s.lock.Lock() + s.endTime = time.Now() + s.lock.Unlock() +} + +func (s *DataStats) ToProtoAggregateOnly() *livekit.RTPStats { + s.lock.RLock() + defer s.lock.RUnlock() + + end := s.endTime + if end.IsZero() { + end = time.Now() + } + return &livekit.RTPStats{ + StartTime: timestamppb.New(s.startTime), + EndTime: timestamppb.New(end), + Duration: end.Sub(s.startTime).Seconds(), + Bytes: uint64(s.totalBytes), + Bitrate: float64(s.totalBytes) * 8 / end.Sub(s.startTime).Seconds(), + } +} diff --git a/livekit/pkg/sfu/buffer/datastats_test.go b/livekit/pkg/sfu/buffer/datastats_test.go new file mode 100644 index 0000000..5803fe9 --- /dev/null +++ b/livekit/pkg/sfu/buffer/datastats_test.go @@ -0,0 +1,53 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package buffer + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/proto" + + "github.com/livekit/protocol/livekit" +) + +func TestDataStats(t *testing.T) { + stats := NewDataStats(DataStatsParam{WindowDuration: time.Second}) + + time.Sleep(time.Millisecond) + r := stats.ToProtoAggregateOnly() + require.Equal(t, r.StartTime.AsTime().UnixNano(), stats.startTime.UnixNano()) + require.NotZero(t, r.EndTime) + require.NotZero(t, r.Duration) + r.StartTime = nil + r.EndTime = nil + r.Duration = 0 + require.True(t, proto.Equal(r, &livekit.RTPStats{})) + + stats.Update(100, time.Now().UnixNano()) + r = stats.ToProtoActive() + require.EqualValues(t, 100, r.Bytes) + require.NotZero(t, r.Bitrate) + + // wait for window duration + time.Sleep(time.Second) + r = stats.ToProtoActive() + require.True(t, proto.Equal(r, &livekit.RTPStats{})) + stats.Stop() + r = stats.ToProtoAggregateOnly() + require.EqualValues(t, 100, r.Bytes) + require.NotZero(t, r.Bitrate) +} diff --git a/livekit/pkg/sfu/buffer/dependencydescriptorparser.go b/livekit/pkg/sfu/buffer/dependencydescriptorparser.go new file mode 100644 index 0000000..e4934e1 --- /dev/null +++ b/livekit/pkg/sfu/buffer/dependencydescriptorparser.go @@ -0,0 +1,333 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package buffer + +import ( + "fmt" + "sort" + "sync" + "time" + + "github.com/pion/rtp" + "go.uber.org/atomic" + + dd "github.com/livekit/livekit-server/pkg/sfu/rtpextension/dependencydescriptor" + "github.com/livekit/mediatransportutil/pkg/utils" + "github.com/livekit/protocol/logger" +) + +var ( + ExtDependencyDescriptorFactory = &sync.Pool{ + New: func() any { + return &ExtDependencyDescriptor{} + }, + } +) + +// -------------------------------------- + +const ( + ddRestartThreshold = 30 * time.Second + + // frame integrity check 2 seconds for L3T3 30fps video + integrityCheckFrame = 180 + integrityCheckPkt = 1024 +) + +var ( + ErrFrameEarlierThanKeyFrame = fmt.Errorf("frame is earlier than current keyframe") + ErrDDStructureAttachedToNonFirstPacket = fmt.Errorf("dependency descriptor structure is attached to non-first packet of a frame") + ErrDDExtentionNotFound = fmt.Errorf("dependency descriptor extension not found") +) + +type DependencyDescriptorParser struct { + structure *dd.FrameDependencyStructure + ddExtID uint8 + logger logger.Logger + onMaxLayerChanged func(int32, int32) + decodeTargets []DependencyDescriptorDecodeTarget + + seqWrapAround *utils.WrapAround[uint16, uint64] + frameWrapAround *utils.WrapAround[uint16, uint64] + structureExtFrameNum uint64 + activeDecodeTargetsExtSeq uint64 + activeDecodeTargetsMask uint32 + frameChecker *FrameIntegrityChecker + + ddNotFoundCount atomic.Uint32 + + // restart detection + restartGeneration int + enableRestart bool + lastPacketAt time.Time +} + +func NewDependencyDescriptorParser(ddExtID uint8, logger logger.Logger, onMaxLayerChanged func(int32, int32), enableRestart bool) *DependencyDescriptorParser { + return &DependencyDescriptorParser{ + ddExtID: ddExtID, + logger: logger, + onMaxLayerChanged: onMaxLayerChanged, + seqWrapAround: utils.NewWrapAround[uint16, uint64](utils.WrapAroundParams{IsRestartAllowed: false}), + frameWrapAround: utils.NewWrapAround[uint16, uint64](utils.WrapAroundParams{IsRestartAllowed: false}), + frameChecker: NewFrameIntegrityChecker(integrityCheckFrame, integrityCheckPkt), + enableRestart: enableRestart, + } +} + +type ExtDependencyDescriptor struct { + Descriptor *dd.DependencyDescriptor + + DecodeTargets []DependencyDescriptorDecodeTarget + StructureUpdated bool + ActiveDecodeTargetsUpdated bool + Integrity bool + ExtFrameNum uint64 + // the frame number of the keyframe which the current frame depends on + ExtKeyFrameNum uint64 + + // increase when the stream restarts, clear and reinitialize all dd state includes + // attached structure, frame chain, decode target. + RestartGeneration int +} + +func (r *DependencyDescriptorParser) Parse(pkt *rtp.Packet) (*ExtDependencyDescriptor, VideoLayer, error) { + var videoLayer VideoLayer + ddBuf := pkt.GetExtension(r.ddExtID) + if ddBuf == nil { + ddNotFoundCount := r.ddNotFoundCount.Inc() + if ddNotFoundCount%100 == 0 { + r.logger.Warnw("dependency descriptor extension is not present", nil, "seq", pkt.SequenceNumber, "count", ddNotFoundCount) + } + return nil, videoLayer, ErrDDExtentionNotFound + } + + var ddVal dd.DependencyDescriptor + ext := &dd.DependencyDescriptorExtension{ + Descriptor: &ddVal, + Structure: r.structure, + } + _, err := ext.Unmarshal(ddBuf) + if err != nil { + if err != dd.ErrDDReaderNoStructure && err != dd.ErrDDReaderInvalidTemplateIndex { + r.logger.Infow("failed to parse generic dependency descriptor", err, "payload", pkt.PayloadType, "ddbufLen", len(ddBuf)) + } + return nil, videoLayer, err + } + + var restart bool + if r.enableRestart { + if !r.lastPacketAt.IsZero() && time.Since(r.lastPacketAt) > ddRestartThreshold { + r.restart() + restart = true + r.logger.Debugw( + "dependency descriptor parser restart stream", + "generation", r.restartGeneration, + "lastPacketAt", r.lastPacketAt, + "sinceLast", time.Since(r.lastPacketAt), + "frameWrapAround", r.frameWrapAround, + ) + } + r.lastPacketAt = time.Now() + } + + extSeq := r.seqWrapAround.Update(pkt.SequenceNumber).ExtendedVal + + if ddVal.FrameDependencies != nil { + videoLayer.Spatial, videoLayer.Temporal = int32(ddVal.FrameDependencies.SpatialId), int32(ddVal.FrameDependencies.TemporalId) + } + + // assume the packet is in-order when stream restarting + unwrapped := r.frameWrapAround.UpdateWithOrderKnown(ddVal.FrameNumber, restart) + extFN := unwrapped.ExtendedVal + + if extFN < r.structureExtFrameNum { + r.logger.Debugw( + "drop frame which is earlier than current structure", + "fn", ddVal.FrameNumber, + "extFN", extFN, + "structureExtFrameNum", r.structureExtFrameNum, + "unwrappedFN", unwrapped, + "frameWrapAround", r.frameWrapAround, + ) + return nil, videoLayer, ErrFrameEarlierThanKeyFrame + } + + r.frameChecker.AddPacket(extSeq, extFN, &ddVal) + + extDD := ExtDependencyDescriptorFactory.Get().(*ExtDependencyDescriptor) + *extDD = ExtDependencyDescriptor{ + Descriptor: &ddVal, + ExtFrameNum: extFN, + Integrity: r.frameChecker.FrameIntegrity(extFN), + RestartGeneration: r.restartGeneration, + } + + if ddVal.AttachedStructure != nil { + if !ddVal.FirstPacketInFrame { + r.logger.Warnw( + "attached structure is not the first packet in frame", nil, + "sn", pkt.SequenceNumber, + "extSeq", extSeq, + "fn", ddVal.FrameNumber, + "extFN", extFN, + ) + ReleaseExtDependencyDescriptor(extDD) + return nil, videoLayer, ErrDDStructureAttachedToNonFirstPacket + } + + if r.structure == nil || ddVal.AttachedStructure.StructureId != r.structure.StructureId { + r.logger.Debugw( + "structure updated", + "structureID", ddVal.AttachedStructure.StructureId, + "sn", pkt.SequenceNumber, + "extSeq", extSeq, + "fn", ddVal.FrameNumber, + "extFN", extFN, + "descriptor", ddVal.String(), + "unwrappedFN", unwrapped, + "frameWrapAround", r.frameWrapAround, + ) + } + r.structure = ddVal.AttachedStructure + r.decodeTargets = ProcessFrameDependencyStructure(ddVal.AttachedStructure) + if extFN > unwrapped.PreExtendedHighest && extFN-unwrapped.PreExtendedHighest > 1000 { + r.logger.Debugw( + "large frame number jump on structure updating", + "fn", ddVal.FrameNumber, + "extFN", extFN, + "preExtendedHighest", unwrapped.PreExtendedHighest, + "structureExtFrameNum", r.structureExtFrameNum, + "unwrappedFN", unwrapped, + "frameWrapAround", r.frameWrapAround, + ) + } + r.structureExtFrameNum = extFN + extDD.StructureUpdated = true + extDD.ActiveDecodeTargetsUpdated = true + // The dependency descriptor reader will always set ActiveDecodeTargetsBitmask for TemplateDependencyStructure is present, + // so don't need to notify max layer change here. + } + + if mask := ddVal.ActiveDecodeTargetsBitmask; mask != nil && extSeq > r.activeDecodeTargetsExtSeq { + r.activeDecodeTargetsExtSeq = extSeq + if *mask != r.activeDecodeTargetsMask { + r.activeDecodeTargetsMask = *mask + extDD.ActiveDecodeTargetsUpdated = true + var maxSpatial, maxTemporal int32 + for _, dt := range r.decodeTargets { + if *mask&(1< base +} + +func (f *Factory) GetOrNew(packetType packetio.BufferPacketType, ssrc uint32) io.ReadWriteCloser { + f.Lock() + defer f.Unlock() + switch packetType { + case packetio.RTCPBufferPacket: + if reader, ok := f.rtcpReaders[ssrc]; ok { + return reader + } + reader := NewRTCPReader(ssrc) + f.rtcpReaders[ssrc] = reader + reader.OnClose(func() { + f.Lock() + delete(f.rtcpReaders, ssrc) + f.Unlock() + }) + return reader + case packetio.RTPBufferPacket: + if reader, ok := f.rtpBuffers[ssrc]; ok { + return reader + } + buffer := NewBuffer(ssrc, f.trackingPacketsVideo, f.trackingPacketsAudio) + f.rtpBuffers[ssrc] = buffer + for repair, base := range f.rtxPair { + if repair == ssrc { + baseBuffer, ok := f.rtpBuffers[base] + if ok { + buffer.SetPrimaryBufferForRTX(baseBuffer) + } + break + } else if base == ssrc { + repairBuffer, ok := f.rtpBuffers[repair] + if ok { + repairBuffer.SetPrimaryBufferForRTX(buffer) + } + break + } + } + buffer.OnClose(func() { + f.Lock() + delete(f.rtpBuffers, ssrc) + delete(f.rtxPair, ssrc) + f.Unlock() + }) + return buffer + } + return nil +} + +func (f *Factory) GetBufferPair(ssrc uint32) (*Buffer, *RTCPReader) { + f.RLock() + defer f.RUnlock() + return f.rtpBuffers[ssrc], f.rtcpReaders[ssrc] +} + +func (f *Factory) GetBuffer(ssrc uint32) *Buffer { + f.RLock() + defer f.RUnlock() + return f.rtpBuffers[ssrc] +} + +func (f *Factory) GetRTCPReader(ssrc uint32) *RTCPReader { + f.RLock() + defer f.RUnlock() + return f.rtcpReaders[ssrc] +} + +func (f *Factory) SetRTXPair(repair, base uint32, rsid string) { + f.Lock() + repairBuffer, baseBuffer := f.rtpBuffers[repair], f.rtpBuffers[base] + if repairBuffer == nil || baseBuffer == nil { + f.rtxPair[repair] = base + } + f.Unlock() + if repairBuffer != nil && baseBuffer != nil { + repairBuffer.SetPrimaryBufferForRTX(baseBuffer) + if rsid != "" { + baseBuffer.NotifyRTX(base, repair, rsid) + } + } +} diff --git a/livekit/pkg/sfu/buffer/fps.go b/livekit/pkg/sfu/buffer/fps.go new file mode 100644 index 0000000..5622198 --- /dev/null +++ b/livekit/pkg/sfu/buffer/fps.go @@ -0,0 +1,732 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package buffer + +import ( + "container/list" + + "github.com/pion/rtp/codecs" + + "github.com/livekit/protocol/logger" +) + +var minFramesForCalculation = [...]int{8, 15, 40, 60} + +type frameInfo struct { + startSeq uint16 + endSeq uint16 + ts uint32 + fn uint16 + spatial int32 + temporal int32 + frameDiff []int +} + +type FrameRateCalculator interface { + RecvPacket(ep *ExtPacket) bool + GetFrameRate() []float32 + Completed() bool +} + +// ----------------------------- + +// FrameRateCalculator based on PictureID in VPx +type frameRateCalculatorVPx struct { + frameRates [DefaultMaxLayerTemporal + 1]float32 + clockRate uint32 + logger logger.Logger + firstFrames [DefaultMaxLayerTemporal + 1]*frameInfo + secondFrames [DefaultMaxLayerTemporal + 1]*frameInfo + fnReceived [64]*frameInfo + baseFrame *frameInfo + completed bool +} + +func newFrameRateCalculatorVPx(clockRate uint32, logger logger.Logger) *frameRateCalculatorVPx { + return &frameRateCalculatorVPx{ + clockRate: clockRate, + logger: logger, + } +} + +func (f *frameRateCalculatorVPx) Completed() bool { + return f.completed +} + +func (f *frameRateCalculatorVPx) RecvPacket(ep *ExtPacket, fn uint16) bool { + if f.completed { + return true + } + + if ep.Temporal >= int32(len(f.frameRates)) { + f.logger.Warnw("invalid temporal layer", nil, "temporal", ep.Temporal) + return false + } + + temporal := ep.Temporal + if temporal < 0 { + temporal = 0 + } + + if f.baseFrame == nil { + f.baseFrame = &frameInfo{ts: ep.Packet.Timestamp, fn: fn} + f.fnReceived[0] = f.baseFrame + f.firstFrames[temporal] = f.baseFrame + return false + } + + baseDiff := fn - f.baseFrame.fn + if baseDiff == 0 || baseDiff > 0x4000 { + return false + } + + if baseDiff >= uint16(len(f.fnReceived)) { + // frame number is not continuous, reset + f.reset() + + return false + } + + if f.fnReceived[baseDiff] != nil { + return false + } + + fi := &frameInfo{ + ts: ep.Packet.Timestamp, + fn: fn, + temporal: temporal, + } + f.fnReceived[baseDiff] = fi + + firstFrame := f.firstFrames[temporal] + secondFrame := f.secondFrames[temporal] + if firstFrame == nil { + f.firstFrames[temporal] = fi + firstFrame = fi + } else { + if (secondFrame == nil || secondFrame.fn < fn) && fn != firstFrame.fn && (fn-firstFrame.fn) < 0x4000 { + f.secondFrames[temporal] = fi + } + } + + return f.calc() +} + +func (f *frameRateCalculatorVPx) calc() bool { + var rateCounter int + for currentTemporal := int32(0); currentTemporal <= DefaultMaxLayerTemporal; currentTemporal++ { + if f.frameRates[currentTemporal] > 0 { + rateCounter++ + continue + } + + ff := f.firstFrames[currentTemporal] + sf := f.secondFrames[currentTemporal] + // lower temporal layer has been calculated, but higher layer has not received any frames, it should not exist + if rateCounter > 0 && ff == nil { + rateCounter++ + continue + } + if ff == nil || sf == nil { + continue + } + + var frameCount int + lastTs := ff.ts + for j := ff.fn - f.baseFrame.fn + 1; j < sf.fn-f.baseFrame.fn+1; j++ { + if f := f.fnReceived[j]; f == nil { + break + } else if f.temporal <= currentTemporal { + frameCount++ + lastTs = f.ts + } + } + if frameCount >= minFramesForCalculation[currentTemporal] { + f.frameRates[currentTemporal] = float32(f.clockRate) / float32(lastTs-ff.ts) * float32(frameCount) + rateCounter++ + } + } + + if rateCounter == len(f.frameRates) { + f.completed = true + + // normalize frame rates, Microsoft Edge use 3 temporal layers for vp8 but the middle layer has chance to + // get a very low frame rate, so we need to normalize the frame rate(use fixed ration 1:2 of highest layer for that layer) + if f.frameRates[2] > 0 && f.frameRates[2] > f.frameRates[1]*3 { + f.frameRates[1] = f.frameRates[2] / 2 + } + f.reset() + return true + } + return false +} + +func (f *frameRateCalculatorVPx) reset() { + for i := range f.firstFrames { + f.firstFrames[i] = nil + f.secondFrames[i] = nil + } + + for i := range f.fnReceived { + f.fnReceived[i] = nil + } + f.baseFrame = nil +} + +func (f *frameRateCalculatorVPx) GetFrameRate() []float32 { + return f.frameRates[:] +} + +// ----------------------------- + +// FrameRateCalculator based on PictureID in VP8 +type FrameRateCalculatorVP8 struct { + *frameRateCalculatorVPx + logger logger.Logger +} + +func NewFrameRateCalculatorVP8(clockRate uint32, logger logger.Logger) *FrameRateCalculatorVP8 { + return &FrameRateCalculatorVP8{ + frameRateCalculatorVPx: newFrameRateCalculatorVPx(clockRate, logger), + logger: logger, + } +} + +func (f *FrameRateCalculatorVP8) RecvPacket(ep *ExtPacket) bool { + if f.frameRateCalculatorVPx.Completed() { + return true + } + + vp8, ok := ep.Payload.(VP8) + if !ok { + f.logger.Debugw("no vp8 payload", "sn", ep.Packet.SequenceNumber) + return false + } + success := f.frameRateCalculatorVPx.RecvPacket(ep, vp8.PictureID) + + if f.frameRateCalculatorVPx.Completed() { + f.logger.Debugw("frame rate calculated", "rate", f.frameRateCalculatorVPx.GetFrameRate()) + } + + return success +} + +// ----------------------------- + +// FrameRateCalculator based on PictureID in VP9 +type FrameRateCalculatorVP9 struct { + logger logger.Logger + completed bool + + // VP9-TODO - this is assuming three spatial layers. As `completed` marker relies on all layers being finished, have to assume this. FIX. + // Maybe look at number of layers in livekit.TrackInfo and declare completed once advertised layers are measured + frameRateCalculatorsVPx [DefaultMaxLayerSpatial + 1]*frameRateCalculatorVPx +} + +func NewFrameRateCalculatorVP9(clockRate uint32, logger logger.Logger) *FrameRateCalculatorVP9 { + f := &FrameRateCalculatorVP9{ + logger: logger, + } + + for i := range f.frameRateCalculatorsVPx { + f.frameRateCalculatorsVPx[i] = newFrameRateCalculatorVPx(clockRate, logger) + } + + return f +} + +func (f *FrameRateCalculatorVP9) Completed() bool { + return f.completed +} + +func (f *FrameRateCalculatorVP9) RecvPacket(ep *ExtPacket) bool { + if f.completed { + return true + } + + vp9, ok := ep.Payload.(codecs.VP9Packet) + if !ok { + f.logger.Debugw("no vp9 payload", "sn", ep.Packet.SequenceNumber) + return false + } + + if ep.Spatial < 0 || ep.Spatial >= int32(len(f.frameRateCalculatorsVPx)) || f.frameRateCalculatorsVPx[ep.Spatial] == nil { + f.logger.Debugw("invalid spatial layer", "sn", ep.Packet.SequenceNumber, "spatial", ep.Spatial) + return false + } + + success := f.frameRateCalculatorsVPx[ep.Spatial].RecvPacket(ep, vp9.PictureID) + + completed := true + for _, frc := range f.frameRateCalculatorsVPx { + if !frc.Completed() { + completed = false + break + } + } + + if completed { + f.completed = true + + var frameRates [DefaultMaxLayerSpatial + 1][]float32 + for i := range f.frameRateCalculatorsVPx { + frameRates[i] = f.frameRateCalculatorsVPx[i].GetFrameRate() + } + f.logger.Debugw("frame rate calculated", "rate", frameRates) + } + + return success +} + +func (f *FrameRateCalculatorVP9) GetFrameRateForSpatial(spatial int32) []float32 { + if spatial < 0 || spatial >= int32(len(f.frameRateCalculatorsVPx)) || f.frameRateCalculatorsVPx[spatial] == nil { + return nil + } + return f.frameRateCalculatorsVPx[spatial].GetFrameRate() +} + +func (f *FrameRateCalculatorVP9) GetFrameRateCalculatorForSpatial(spatial int32) *FrameRateCalculatorForVP9Layer { + return &FrameRateCalculatorForVP9Layer{ + FrameRateCalculatorVP9: f, + spatial: spatial, + } +} + +// ----------------------------- + +type FrameRateCalculatorForVP9Layer struct { + *FrameRateCalculatorVP9 + spatial int32 +} + +func (f *FrameRateCalculatorForVP9Layer) GetFrameRate() []float32 { + return f.FrameRateCalculatorVP9.GetFrameRateForSpatial(f.spatial) +} + +// ----------------------------------------------- + +// FrameRateCalculator based on Dependency descriptor +type FrameRateCalculatorDD struct { + frameRates [DefaultMaxLayerSpatial + 1][DefaultMaxLayerTemporal + 1]float32 + clockRate uint32 + logger logger.Logger + firstFrames [DefaultMaxLayerSpatial + 1][DefaultMaxLayerTemporal + 1]*frameInfo + secondFrames [DefaultMaxLayerSpatial + 1][DefaultMaxLayerTemporal + 1]*frameInfo + fnReceived [256]*frameInfo + baseFrame *frameInfo + completed bool + + // frames for each decode target + targetFrames [DefaultMaxLayerSpatial + 1][DefaultMaxLayerTemporal + 1]list.List + + maxSpatial, maxTemporal int32 +} + +func NewFrameRateCalculatorDD(clockRate uint32, logger logger.Logger) *FrameRateCalculatorDD { + return &FrameRateCalculatorDD{ + clockRate: clockRate, + logger: logger, + maxSpatial: DefaultMaxLayerSpatial, + maxTemporal: DefaultMaxLayerTemporal, + } +} + +func (f *FrameRateCalculatorDD) Completed() bool { + return f.completed +} + +func (f *FrameRateCalculatorDD) SetMaxLayer(spatial, temporal int32) { + f.maxSpatial, f.maxTemporal = spatial, temporal +} + +func (f *FrameRateCalculatorDD) RecvPacket(ep *ExtPacket) bool { + if f.completed { + return true + } + + if ep.DependencyDescriptor == nil { + f.logger.Debugw("dependency descriptor is nil") + return false + } + + spatial := ep.Spatial + // non-SVC codec will set spatial to -1 + if spatial < 0 { + spatial = 0 + } + temporal := ep.Temporal + if temporal < 0 || temporal > DefaultMaxLayerTemporal || spatial > DefaultMaxLayerSpatial { + f.logger.Warnw("invalid spatial or temporal", nil, "spatial", spatial, "temporal", temporal, "sn", ep.Packet.SequenceNumber) + return false + } + + fn := ep.DependencyDescriptor.Descriptor.FrameNumber + if f.baseFrame == nil { + f.baseFrame = &frameInfo{ts: ep.Packet.Timestamp, fn: fn} + f.fnReceived[0] = f.baseFrame + f.firstFrames[spatial][temporal] = f.baseFrame + f.secondFrames[spatial][temporal] = f.baseFrame + return false + } + + baseDiff := fn - f.baseFrame.fn + if baseDiff == 0 || baseDiff > 0x8000 { + return false + } + + if baseDiff >= uint16(len(f.fnReceived)) { + // frame number is not continuous, reset + f.baseFrame = nil + for i := range f.firstFrames { + for j := range f.firstFrames[i] { + f.firstFrames[i][j] = nil + f.secondFrames[i][j] = nil + f.targetFrames[i][j].Init() + } + } + for i := range f.fnReceived { + f.fnReceived[i] = nil + } + return false + } + + if f.fnReceived[baseDiff] != nil { + return false + } + + fi := &frameInfo{ + ts: ep.Packet.Timestamp, + fn: fn, + temporal: temporal, + spatial: spatial, + frameDiff: ep.DependencyDescriptor.Descriptor.FrameDependencies.FrameDiffs, + } + f.fnReceived[baseDiff] = fi + + if f.firstFrames[spatial][temporal] == nil { + f.firstFrames[spatial][temporal] = fi + f.secondFrames[spatial][temporal] = fi + return false + } + + chain := &f.targetFrames[spatial][temporal] + if chain.Len() == 0 { + chain.PushBack(fn) + } + for _, fdiff := range ep.DependencyDescriptor.Descriptor.FrameDependencies.FrameDiffs { + dependFrame := fn - uint16(fdiff) + // frame too old, ignore + if dependFrame-f.secondFrames[spatial][temporal].fn > 0x8000 { + continue + } + + insertFrame: + for e := chain.Back(); e != nil; e = e.Prev() { + val := e.Value.(uint16) + switch { + case val == dependFrame: + break insertFrame + case sn16LT(val, dependFrame): + chain.InsertAfter(dependFrame, e) + break insertFrame + default: + if e == chain.Front() { + chain.PushFront(dependFrame) + break insertFrame + } + } + } + } + return f.calc() +} + +func (f *FrameRateCalculatorDD) calc() bool { + var rateCounter int + for currentSpatial := int32(0); currentSpatial <= f.maxSpatial; currentSpatial++ { + var currentSpatialRateCounter int + for currentTemporal := int32(0); currentTemporal <= f.maxTemporal; currentTemporal++ { + if f.frameRates[currentSpatial][currentTemporal] > 0 { + rateCounter++ + currentSpatialRateCounter++ + continue + } + + firstFrame := f.firstFrames[currentSpatial][currentTemporal] + // lower temporal layer has been calculated, but higher layer has not received any frames, it should not exist + if currentSpatialRateCounter > 0 && firstFrame == nil { + currentSpatialRateCounter++ + rateCounter++ + continue + } + + chain := &f.targetFrames[currentSpatial][currentTemporal] + + // find last decodable frame (no dependency frame is lost) + var lastFrame *frameInfo + for e := chain.Front(); e != nil; e = e.Next() { + diff := e.Value.(uint16) - f.baseFrame.fn + if diff >= uint16(len(f.fnReceived)) { + continue + } + + fi := f.fnReceived[diff] + if fi == nil { + break + } else { + lastFrame = fi + if firstFrame == nil && fi.spatial == currentSpatial && fi.temporal == currentTemporal { + firstFrame = fi + } + } + } + + if lastFrame != nil && lastFrame.fn > f.secondFrames[currentSpatial][currentTemporal].fn { + f.secondFrames[currentSpatial][currentTemporal] = lastFrame + } else { + continue + } + + frameCount := 0 + for i := firstFrame.fn - f.baseFrame.fn; i <= lastFrame.fn-f.baseFrame.fn; i++ { + fi := f.fnReceived[i] + if fi == nil { + continue + } + if fi.spatial == currentSpatial && fi.temporal <= currentTemporal { + frameCount++ + } + } + + if frameCount >= minFramesForCalculation[currentTemporal] && lastFrame.ts > firstFrame.ts { + f.frameRates[currentSpatial][currentTemporal] = float32(f.clockRate) / float32(lastFrame.ts-firstFrame.ts) * float32(frameCount) + rateCounter++ + } + } + } + + if rateCounter == int(f.maxSpatial+1)*int(f.maxTemporal+1) { + f.completed = true + f.close() + + f.logger.Debugw("frame rate calculated", "rate", f.frameRates) + return true + } + return false +} + +func (f *FrameRateCalculatorDD) GetFrameRateForSpatial(spatial int32) []float32 { + if spatial < 0 || spatial >= int32(len(f.frameRates)) { + return nil + } + return f.frameRates[spatial][:] +} + +func (f *FrameRateCalculatorDD) close() { + f.baseFrame = nil + for i := range f.firstFrames { + for j := range f.firstFrames[i] { + f.firstFrames[i][j] = nil + f.secondFrames[i][j] = nil + } + } + + for i := range f.fnReceived { + f.fnReceived[i] = nil + } + for i := range f.targetFrames { + for j := range f.targetFrames[i] { + f.targetFrames[i][j].Init() + } + } +} + +func (f *FrameRateCalculatorDD) GetFrameRateCalculatorForSpatial(spatial int32) *FrameRateCalculatorForDDLayer { + return &FrameRateCalculatorForDDLayer{ + FrameRateCalculatorDD: f, + spatial: spatial, + } +} + +// ----------------------------------------------- + +type FrameRateCalculatorForDDLayer struct { + *FrameRateCalculatorDD + spatial int32 +} + +func (f *FrameRateCalculatorForDDLayer) GetFrameRate() []float32 { + return f.FrameRateCalculatorDD.GetFrameRateForSpatial(f.spatial) +} + +// ----------------------------------------------- + +type FrameRateCalculatorH26x struct { + frameRates [DefaultMaxLayerTemporal + 1]float32 + clockRate uint32 + logger logger.Logger + fnReceived *list.List + baseFrame *frameInfo + completed bool +} + +func NewFrameRateCalculatorH26x(clockRate uint32, logger logger.Logger) *FrameRateCalculatorH26x { + return &FrameRateCalculatorH26x{ + clockRate: clockRate, + logger: logger, + } +} + +func (f *FrameRateCalculatorH26x) Completed() bool { + return f.completed +} + +func (f *FrameRateCalculatorH26x) RecvPacket(ep *ExtPacket) bool { + if f.completed { + return true + } + + if ep.Temporal >= int32(len(f.frameRates)) { + f.logger.Warnw("invalid temporal layer", nil, "temporal", ep.Temporal) + return false + } + + temporal := ep.Temporal + if temporal < 0 { + temporal = 0 + } + + if f.baseFrame == nil { + f.baseFrame = &frameInfo{ + startSeq: ep.Packet.SequenceNumber, + endSeq: ep.Packet.SequenceNumber, + ts: ep.Packet.Timestamp, + temporal: temporal, + } + f.fnReceived = list.New() + f.fnReceived.PushBack(f.baseFrame) + return false + } + + if sn16LTOrEqual(ep.Packet.SequenceNumber, f.baseFrame.startSeq) { + return false + } + +insertFrame: + for e := f.fnReceived.Back(); e != nil; e = e.Prev() { + frame := e.Value.(*frameInfo) + switch { + case frame.ts == ep.Packet.Timestamp: + if sn16LT(frame.endSeq, ep.Packet.SequenceNumber) { + frame.endSeq = ep.Packet.SequenceNumber + } + if sn16LT(ep.Packet.SequenceNumber, frame.startSeq) { + frame.startSeq = ep.Packet.SequenceNumber + } + break insertFrame + case sn32LT(frame.ts, ep.Packet.Timestamp): + f.fnReceived.InsertAfter(&frameInfo{ + startSeq: ep.Packet.SequenceNumber, + endSeq: ep.Packet.SequenceNumber, + ts: ep.Packet.Timestamp, + temporal: temporal, + }, e) + break insertFrame + default: + if e == f.fnReceived.Front() { + f.fnReceived.PushFront(&frameInfo{ + startSeq: ep.Packet.SequenceNumber, + endSeq: ep.Packet.SequenceNumber, + ts: ep.Packet.Timestamp, + temporal: temporal, + }) + break insertFrame + } + } + } + + return f.calc() +} + +func (f *FrameRateCalculatorH26x) calc() bool { + frameCounts := make([]int, DefaultMaxLayerTemporal+1) + var totalFrameCount int + var tsDuration int + cur := f.fnReceived.Front() + for { + next := cur.Next() + if next == nil { + break + } + ff := cur.Value.(*frameInfo) + nf := next.Value.(*frameInfo) + if nf.startSeq-ff.endSeq == 1 { + totalFrameCount++ + tsDuration += int(nf.ts - ff.ts) + for i := int(nf.temporal); i < len(frameCounts); i++ { + frameCounts[i]++ + } + } else { + // reset to find continuous frames + totalFrameCount = 0 + for i := range frameCounts { + frameCounts[i] = 0 + } + tsDuration = 0 + } + + // received enough continuous frames, calculate fps + if totalFrameCount >= minFramesForCalculation[DefaultMaxLayerTemporal] { + for currentTemporal := int32(0); currentTemporal <= DefaultMaxLayerTemporal; currentTemporal++ { + count := frameCounts[currentTemporal] + if currentTemporal > 0 && count == frameCounts[currentTemporal-1] { + // no frames for this temporal layer + f.frameRates[currentTemporal] = 0 + } else { + f.frameRates[currentTemporal] = float32(f.clockRate) / float32(tsDuration) * float32(count) + } + } + f.logger.Debugw("fps changed", "fps", f.GetFrameRate()) + f.completed = true + f.reset() + return true + } + + cur = next + } + + return false +} + +func (f *FrameRateCalculatorH26x) reset() { + f.fnReceived.Init() + f.baseFrame = nil +} + +func (f *FrameRateCalculatorH26x) GetFrameRate() []float32 { + return f.frameRates[:] +} + +// ----------------------------------------------- +func sn16LT(a, b uint16) bool { + return a-b > 0x8000 +} + +func sn16LTOrEqual(a, b uint16) bool { + return a == b || a-b > 0x8000 +} + +func sn32LT(a, b uint32) bool { + return a-b > 0x80000000 +} diff --git a/livekit/pkg/sfu/buffer/fps_test.go b/livekit/pkg/sfu/buffer/fps_test.go new file mode 100644 index 0000000..e9071eb --- /dev/null +++ b/livekit/pkg/sfu/buffer/fps_test.go @@ -0,0 +1,437 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package buffer + +import ( + "testing" + + "github.com/pion/rtp" + "github.com/stretchr/testify/require" + + dd "github.com/livekit/livekit-server/pkg/sfu/rtpextension/dependencydescriptor" + "github.com/livekit/protocol/logger" +) + +type testFrameInfo struct { + header rtp.Header + framenumber uint16 + spatial int + temporal int + frameDiff []int +} + +func (f *testFrameInfo) toVP8() *ExtPacket { + return &ExtPacket{ + Packet: &rtp.Packet{Header: f.header}, + Payload: VP8{ + PictureID: f.framenumber, + }, + VideoLayer: VideoLayer{Spatial: InvalidLayerSpatial, Temporal: int32(f.temporal)}, + } +} + +func (f *testFrameInfo) toDD() *ExtPacket { + return &ExtPacket{ + Packet: &rtp.Packet{Header: f.header}, + DependencyDescriptor: &ExtDependencyDescriptor{ + Descriptor: &dd.DependencyDescriptor{ + FrameNumber: f.framenumber, + FrameDependencies: &dd.FrameDependencyTemplate{ + FrameDiffs: f.frameDiff, + }, + }, + }, + VideoLayer: VideoLayer{Spatial: int32(f.spatial), Temporal: int32(f.temporal)}, + } +} + +func (f *testFrameInfo) toH26x() *ExtPacket { + return &ExtPacket{ + Packet: &rtp.Packet{Header: f.header}, + VideoLayer: VideoLayer{Spatial: InvalidLayerSpatial, Temporal: int32(f.temporal)}, + } +} + +func createFrames(startFrameNumber uint16, startTs uint32, startSeq uint16, totalFramesPerSpatial int, fps [][]float32, spatialDependency bool) [][]*testFrameInfo { + spatials := len(fps) + temporals := len(fps[0]) + frames := make([][]*testFrameInfo, spatials) + for s := 0; s < spatials; s++ { + frames[s] = make([]*testFrameInfo, 0, totalFramesPerSpatial) + } + fn := startFrameNumber + + nextTs := make([][]uint32, spatials) + tsStep := make([][]uint32, spatials) + for i := range spatials { + nextTs[i] = make([]uint32, temporals) + tsStep[i] = make([]uint32, temporals) + for j := 0; j < temporals; j++ { + nextTs[i][j] = startTs + tsStep[i][j] = uint32(90000 / fps[i][j]) + } + } + + currentTs := make([]uint32, spatials) + for i := range spatials { + currentTs[i] = startTs + } + for range totalFramesPerSpatial { + for s := range spatials { + frame := &testFrameInfo{ + header: rtp.Header{Timestamp: currentTs[s], SequenceNumber: startSeq}, + framenumber: fn, + spatial: s, + } + for t := 0; t < temporals; t++ { + if currentTs[s] >= nextTs[s][t] { + frame.temporal = t + for nt := t; nt < temporals; nt++ { + nextTs[s][nt] += tsStep[s][nt] + } + break + } + } + currentTs[s] += tsStep[s][temporals-1] + frames[s] = append(frames[s], frame) + fn++ + startSeq++ + + for fidx := len(frames[s]) - 1; fidx >= 0; fidx-- { + cf := frames[s][fidx] + if cf.header.Timestamp-frame.header.Timestamp > 0x80000000 { + frame.frameDiff = append(frame.frameDiff, int(frame.framenumber-cf.framenumber)) + break + } + } + + if spatialDependency && frame.spatial > 0 { + for fidx := len(frames[frame.spatial-1]) - 1; fidx >= 0; fidx-- { + cf := frames[frame.spatial-1][fidx] + if cf.header.Timestamp == frame.header.Timestamp { + frame.frameDiff = append(frame.frameDiff, int(frame.framenumber-cf.framenumber)) + break + } + } + } + } + } + + return frames +} + +func verifyFps(t *testing.T, expect, got []float32) { + require.Equal(t, len(expect), len(got)) + for i := range expect { + require.GreaterOrEqual(t, got[i], expect[i]*0.9, "expect %v, got %v", expect, got) + require.LessOrEqual(t, got[i], expect[i]*1.1, "expect %v, got %v", expect, got) + } +} + +type testcase struct { + startTs uint32 + startSeq uint16 + startFrameNumber uint16 + fps [][]float32 + spatialDependency bool +} + +func TestFpsVP8(t *testing.T) { + cases := map[string]testcase{ + "normal": { + startTs: 12345678, + startFrameNumber: 100, + fps: [][]float32{{5, 10, 15}, {5, 10, 15}, {7.5, 15, 30}}, + }, + "frame number and timestamp wrap": { + startTs: (uint32(1) << 31) - 10, + startFrameNumber: (uint16(1) << 15) - 10, + fps: [][]float32{{5, 10, 15}, {5, 10, 15}, {7.5, 15, 30}}, + }, + "2 temporal layers": { + startTs: 12345678, + startFrameNumber: 100, + fps: [][]float32{{7.5, 15}, {7.5, 15}, {15, 30}}, + }, + } + + for name, c := range cases { + testCase := c + t.Run(name, func(t *testing.T) { + fps := testCase.fps + frames := make([][]*testFrameInfo, 0) + vp8calcs := make([]*FrameRateCalculatorVP8, len(fps)) + for i := range vp8calcs { + vp8calcs[i] = NewFrameRateCalculatorVP8(90000, logger.GetLogger()) + frames = append(frames, createFrames(c.startFrameNumber, c.startTs, 10, 200, [][]float32{fps[i]}, false)[0]) + } + + var frameratesGot bool + for s, fs := range frames { + for _, f := range fs { + if vp8calcs[s].RecvPacket(f.toVP8()) { + frameratesGot = true + for _, calc := range vp8calcs { + if !calc.Completed() { + frameratesGot = false + break + } + } + } + } + } + require.True(t, frameratesGot) + for i, calc := range vp8calcs { + fpsExpected := fps[i] + fpsGot := calc.GetFrameRate() + verifyFps(t, fpsExpected, fpsGot[:len(fpsExpected)]) + } + }) + } + t.Run("packet lost and duplicate", func(t *testing.T) { + fps := [][]float32{{7.5, 15}, {7.5, 15}, {15, 30}} + frames := make([][]*testFrameInfo, 0) + vp8calcs := make([]*FrameRateCalculatorVP8, len(fps)) + for i := range vp8calcs { + vp8calcs[i] = NewFrameRateCalculatorVP8(90000, logger.GetLogger()) + frames = append(frames, createFrames(100, 12345678, 10, 300, [][]float32{fps[i]}, false)[0]) + for j := 5; j < 130; j++ { + if j%2 == 0 { + frames[i][j] = frames[i][j-1] + } + } + } + + var frameratesGot bool + for s, fs := range frames { + for _, f := range fs { + if vp8calcs[s].RecvPacket(f.toVP8()) { + frameratesGot = true + for _, calc := range vp8calcs { + if !calc.Completed() { + frameratesGot = false + break + } + } + } + } + } + require.True(t, frameratesGot) + for i, calc := range vp8calcs { + fpsExpected := fps[i] + fpsGot := calc.GetFrameRate() + verifyFps(t, fpsExpected, fpsGot[:len(fpsExpected)]) + } + }) +} + +func TestFpsDD(t *testing.T) { + cases := map[string]testcase{ + "normal": { + startTs: 12345678, + startFrameNumber: 100, + fps: [][]float32{{5.1, 10.1, 16}, {5.1, 10.1, 16}, {8, 15, 30.1}}, + spatialDependency: true, + }, + "frame number and timestamp wrap": { + startTs: (uint32(1) << 31) - 10, + startFrameNumber: (uint16(1) << 15) - 10, + fps: [][]float32{{7.5, 15, 30}, {7.5, 15, 30}, {7.5, 15, 30}}, + spatialDependency: true, + }, + "vp8": { + startTs: 12345678, + startFrameNumber: 100, + fps: [][]float32{{7.5, 15}, {7.5, 15}, {15, 30}}, + spatialDependency: false, + }, + } + + for name, c := range cases { + testCase := c + t.Run(name, func(t *testing.T) { + fps := testCase.fps + frames := createFrames(c.startFrameNumber, c.startTs, 10, 500, fps, testCase.spatialDependency) + ddcalc := NewFrameRateCalculatorDD(90000, logger.GetLogger()) + ddcalc.SetMaxLayer(int32(len(fps)-1), int32(len(fps[0])-1)) + ddcalcs := make([]FrameRateCalculator, len(fps)) + for i := range fps { + ddcalcs[i] = ddcalc.GetFrameRateCalculatorForSpatial(int32(i)) + } + + var frameratesGot bool + for s, fs := range frames { + for _, f := range fs { + if ddcalcs[s].RecvPacket(f.toDD()) { + frameratesGot = true + for _, calc := range ddcalcs { + if !calc.Completed() { + frameratesGot = false + break + } + } + } + } + } + require.True(t, frameratesGot) + for i, calc := range ddcalcs { + fpsExpected := fps[i] + fpsGot := calc.GetFrameRate() + verifyFps(t, fpsExpected, fpsGot[:len(fpsExpected)]) + } + }) + } + + t.Run("packet lost and duplicate", func(t *testing.T) { + fps := [][]float32{{7.5, 15, 30}, {7.5, 15, 30}, {7.5, 15, 30}} + frames := createFrames(100, 12345678, 10, 500, fps, true) + ddcalc := NewFrameRateCalculatorDD(90000, logger.GetLogger()) + ddcalc.SetMaxLayer(int32(len(fps)-1), int32(len(fps[0])-1)) + ddcalcs := make([]FrameRateCalculator, len(fps)) + for i := range fps { + ddcalcs[i] = ddcalc.GetFrameRateCalculatorForSpatial(int32(i)) + for j := 5; j < 130; j++ { + if j%2 == 0 { + frames[i][j] = frames[i][j-1] + } + } + } + + var frameratesGot bool + for s, fs := range frames { + for _, f := range fs { + if ddcalcs[s].RecvPacket(f.toDD()) { + frameratesGot = true + for _, calc := range ddcalcs { + if !calc.Completed() { + frameratesGot = false + break + } + } + } + } + } + require.True(t, frameratesGot) + for i, calc := range ddcalcs { + fpsExpected := fps[i] + fpsGot := calc.GetFrameRate() + verifyFps(t, fpsExpected, fpsGot[:len(fpsExpected)]) + } + }) +} + +func TestFpsH26x(t *testing.T) { + cases := map[string]testcase{ + "normal": { + startTs: 12345678, + startSeq: 100, + startFrameNumber: 100, + fps: [][]float32{{5, 10, 15}, {5, 10, 15}, {7.5, 15, 30}}, + }, + "frame number and timestamp wrap": { + startTs: (uint32(1) << 31) - 10, + startSeq: (uint16(1) << 15) - 10, + startFrameNumber: (uint16(1) << 15) - 10, + fps: [][]float32{{5, 10, 15}, {5, 10, 15}, {7.5, 15, 30}}, + }, + "2 temporal layers": { + startTs: 12345678, + startFrameNumber: 100, + fps: [][]float32{{7.5, 15}, {7.5, 15}, {15, 30}}, + }, + } + + for name, c := range cases { + testCase := c + t.Run(name, func(t *testing.T) { + fps := testCase.fps + frames := make([][]*testFrameInfo, 0) + h26xcalcs := make([]*FrameRateCalculatorH26x, len(fps)) + for i := range h26xcalcs { + h26xcalcs[i] = NewFrameRateCalculatorH26x(90000, logger.GetLogger()) + frames = append(frames, createFrames(c.startFrameNumber, c.startTs, c.startSeq, 200, [][]float32{fps[i]}, false)[0]) + } + + var frameratesGot bool + for s, fs := range frames { + for _, f := range fs { + if h26xcalcs[s].RecvPacket(f.toH26x()) { + frameratesGot = true + for _, calc := range h26xcalcs { + if !calc.Completed() { + frameratesGot = false + break + } + } + } + } + } + require.True(t, frameratesGot) + for i, calc := range h26xcalcs { + fpsExpected := fps[i] + fpsGot := calc.GetFrameRate() + verifyFps(t, fpsExpected, fpsGot[:len(fpsExpected)]) + } + }) + } + + t.Run("packet lost and duplicate", func(t *testing.T) { + fps := [][]float32{{7.5, 15, 30}, {7.5, 15, 30}, {7.5, 15, 30}} + frames := make([][]*testFrameInfo, 0, len(fps)) + h26xcalcs := make([]FrameRateCalculator, len(fps)) + for i := range fps { + frames = append(frames, createFrames(100, 12345678, 10, 500, [][]float32{fps[i]}, false)[0]) + h26xcalcs[i] = NewFrameRateCalculatorH26x(90000, logger.GetLogger()) + for j := 5; j < 130; j++ { + if j%2 == 0 { + frames[i][j] = frames[i][j-1] + } + } + for j := 130; j < 230; j++ { + if j%3 == 0 { + frames[i][j] = nil + } + } + for j := 230; j < 330; j++ { + if j%2 == 0 { + frames[i][j], frames[i][j-1] = frames[i][j-1], frames[i][j] + } + } + } + var frameratesGot bool + for s, fs := range frames { + for _, f := range fs { + if f == nil { + continue + } + if h26xcalcs[s].RecvPacket(f.toH26x()) { + frameratesGot = true + for _, calc := range h26xcalcs { + if !calc.Completed() { + frameratesGot = false + break + } + } + } + } + } + require.True(t, frameratesGot) + for i, calc := range h26xcalcs { + fpsExpected := fps[i] + fpsGot := calc.GetFrameRate() + verifyFps(t, fpsExpected, fpsGot[:len(fpsExpected)]) + } + }) +} diff --git a/livekit/pkg/sfu/buffer/frameintegrity.go b/livekit/pkg/sfu/buffer/frameintegrity.go new file mode 100644 index 0000000..935263b --- /dev/null +++ b/livekit/pkg/sfu/buffer/frameintegrity.go @@ -0,0 +1,225 @@ +// Copyright 2024 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package buffer + +import ( + dd "github.com/livekit/livekit-server/pkg/sfu/rtpextension/dependencydescriptor" +) + +type FrameEntity struct { + startSeq *uint64 + endSeq *uint64 + integrity bool + + pktHistory *PacketHistory +} + +func (fe *FrameEntity) AddPacket(extSeq uint64, ddVal *dd.DependencyDescriptor) { + // duplicate packet + if fe.integrity { + return + } + + if fe.startSeq == nil && ddVal.FirstPacketInFrame { + fe.startSeq = &extSeq + } + if fe.endSeq == nil && ddVal.LastPacketInFrame { + fe.endSeq = &extSeq + } + + if fe.startSeq != nil && fe.endSeq != nil { + if fe.pktHistory.PacketsConsecutive(*fe.startSeq, *fe.endSeq) { + fe.integrity = true + } + } +} + +func (fe *FrameEntity) Reset() { + fe.integrity = false + fe.startSeq, fe.endSeq = nil, nil +} + +func (fe *FrameEntity) Integrity() bool { + return fe.integrity +} + +// ------------------------------ + +type PacketHistory struct { + base uint64 + last uint64 + bits []uint64 + packetCount int + inited bool +} + +func NewPacketHistory(packetCount int) *PacketHistory { + packetCount = (packetCount + 63) / 64 * 64 + return &PacketHistory{ + bits: make([]uint64, packetCount/64), + packetCount: packetCount, + } +} + +func (ph *PacketHistory) AddPacket(extSeq uint64) { + if !ph.inited { + ph.inited = true + ph.base = uint64(extSeq) + // set base to extSeq-100 to avoid out-of-order packets belongs to first frame to be dropped + if ph.base > 100 { + ph.base -= 100 + } else { + ph.base = 0 + } + ph.last = uint64(extSeq) + ph.set(extSeq, true) + return + } + + if extSeq <= ph.base { + // too old + return + } + + if extSeq <= ph.last { + if ph.last-extSeq < uint64(ph.packetCount) { + ph.set(extSeq, true) + } + return + } + + for i := ph.last + 1; i < extSeq; i++ { + ph.set(i, false) + } + + ph.set(extSeq, true) + ph.last = extSeq +} + +func (ph *PacketHistory) getPos(seq uint64) (index, offset int) { + idx := (seq - ph.base) % uint64(ph.packetCount) + return int(idx >> 6), int(idx % 64) +} + +func (ph *PacketHistory) set(seq uint64, received bool) { + idx, offset := ph.getPos(seq) + if !received { + ph.bits[idx] &= ^(1 << offset) + } else { + ph.bits[idx] |= 1 << (offset) + } +} + +func (ph *PacketHistory) PacketsConsecutive(start, end uint64) bool { + if start > end { + return false + } + + if end-start >= uint64(ph.packetCount) { + return false + } + + startIndex, startOffset := ph.getPos(start) + endIndex, endOffset := ph.getPos(end) + + if startIndex == endIndex && end-start <= 64 { + testBits := uint64((1<<(endOffset-startOffset+1))-1) << startOffset + return ph.bits[startIndex]&testBits == testBits + } + + if (ph.bits[startIndex]>>(startOffset))+1 != 1<<(64-startOffset) { + return false + } + + for i := startIndex + 1; i != endIndex; i++ { + if i == len(ph.bits) { + i = 0 + if i == endIndex { + break + } + } + if ph.bits[i]+1 != 0 { + return false + } + } + + testBits := uint64((1 << (endOffset + 1)) - 1) + return ph.bits[endIndex]&testBits == testBits +} + +// ------------------------------ + +type FrameIntegrityChecker struct { + frameCount int + frames []FrameEntity + base uint64 + last uint64 + + pktHistory *PacketHistory + inited bool +} + +func NewFrameIntegrityChecker(frameCount, packetCount int) *FrameIntegrityChecker { + fc := &FrameIntegrityChecker{ + frames: make([]FrameEntity, frameCount), + pktHistory: NewPacketHistory(packetCount), + frameCount: frameCount, + } + + for i := range fc.frames { + fc.frames[i].pktHistory = fc.pktHistory + fc.frames[i].Reset() + } + return fc +} + +func (fc *FrameIntegrityChecker) AddPacket(extSeq uint64, extFrameNum uint64, ddVal *dd.DependencyDescriptor) { + fc.pktHistory.AddPacket(extSeq) + + if !fc.inited { + fc.inited = true + fc.base = extFrameNum + fc.last = extFrameNum + } + + if extFrameNum < fc.base { + // frame too old + return + } + + if extFrameNum <= fc.last { + if fc.last-extFrameNum >= uint64(fc.frameCount) { + // frame too old + return + } + fc.frames[int(extFrameNum-fc.base)%fc.frameCount].AddPacket(extSeq, ddVal) + return + } + + // reset missing frames + for i := fc.last + 1; i <= extFrameNum; i++ { + fc.frames[int(i-fc.base)%fc.frameCount].Reset() + } + fc.frames[int(extFrameNum-fc.base)%fc.frameCount].AddPacket(extSeq, ddVal) + fc.last = extFrameNum +} + +func (fc *FrameIntegrityChecker) FrameIntegrity(extFrameNum uint64) bool { + if extFrameNum < fc.base || extFrameNum > fc.last || fc.last-extFrameNum >= uint64(fc.frameCount) { + return false + } + + return fc.frames[int(extFrameNum-fc.base)%fc.frameCount].Integrity() +} diff --git a/livekit/pkg/sfu/buffer/frameintegrity_test.go b/livekit/pkg/sfu/buffer/frameintegrity_test.go new file mode 100644 index 0000000..2815cf5 --- /dev/null +++ b/livekit/pkg/sfu/buffer/frameintegrity_test.go @@ -0,0 +1,86 @@ +// Copyright 2024 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package buffer + +import ( + "math/rand" + "testing" + + "github.com/stretchr/testify/require" + + dd "github.com/livekit/livekit-server/pkg/sfu/rtpextension/dependencydescriptor" +) + +func TestFrameIntegrityChecker(t *testing.T) { + fc := NewFrameIntegrityChecker(100, 1000) + + // first frame out of order + fc.AddPacket(10, 10, &dd.DependencyDescriptor{}) + require.False(t, fc.FrameIntegrity(10)) + fc.AddPacket(9, 10, &dd.DependencyDescriptor{FirstPacketInFrame: true}) + require.False(t, fc.FrameIntegrity(10)) + fc.AddPacket(11, 10, &dd.DependencyDescriptor{LastPacketInFrame: true}) + require.True(t, fc.FrameIntegrity(10)) + + // single packet frame + fc.AddPacket(100, 100, &dd.DependencyDescriptor{FirstPacketInFrame: true, LastPacketInFrame: true}) + require.True(t, fc.FrameIntegrity(100)) + require.False(t, fc.FrameIntegrity(101)) + require.False(t, fc.FrameIntegrity(99)) + + // frame too old than first frame + fc.AddPacket(99, 99, &dd.DependencyDescriptor{FirstPacketInFrame: true, LastPacketInFrame: true}) + + // multiple packet frame, out of order + fc.AddPacket(2001, 2001, &dd.DependencyDescriptor{}) + require.False(t, fc.FrameIntegrity(2001)) + require.False(t, fc.FrameIntegrity(1999)) + // out of frame count(100) + require.False(t, fc.FrameIntegrity(100)) + require.False(t, fc.FrameIntegrity(1900)) + + fc.AddPacket(2000, 2001, &dd.DependencyDescriptor{FirstPacketInFrame: true}) + require.False(t, fc.FrameIntegrity(2001)) + fc.AddPacket(2002, 2001, &dd.DependencyDescriptor{LastPacketInFrame: true}) + require.True(t, fc.FrameIntegrity(2001)) + // duplicate packet + fc.AddPacket(2001, 2001, &dd.DependencyDescriptor{}) + require.True(t, fc.FrameIntegrity(2001)) + + // frame too old + fc.AddPacket(900, 1900, &dd.DependencyDescriptor{FirstPacketInFrame: true, LastPacketInFrame: true}) + require.False(t, fc.FrameIntegrity(1900)) + + for frame := uint64(2002); frame < 2102; frame++ { + // large frame (1000 packets) out of order / retransmitted + firstFrame := uint64(3000 + (frame-2002)*1000) + lastFrame := uint64(3999 + (frame-2002)*1000) + frames := make([]uint64, 0, lastFrame-firstFrame+1) + for i := firstFrame; i <= lastFrame; i++ { + frames = append(frames, i) + } + require.False(t, fc.FrameIntegrity(frame)) + rand.Seed(int64(frame)) + rand.Shuffle(len(frames), func(i, j int) { frames[i], frames[j] = frames[j], frames[i] }) + for i, f := range frames { + fc.AddPacket(f, frame, &dd.DependencyDescriptor{ + FirstPacketInFrame: f == firstFrame, + LastPacketInFrame: f == lastFrame, + }) + require.Equal(t, i == len(frames)-1, fc.FrameIntegrity(frame), i) + } + require.True(t, fc.FrameIntegrity(frame)) + } +} diff --git a/livekit/pkg/sfu/buffer/h26xhelper.go b/livekit/pkg/sfu/buffer/h26xhelper.go new file mode 100644 index 0000000..f00be0c --- /dev/null +++ b/livekit/pkg/sfu/buffer/h26xhelper.go @@ -0,0 +1,527 @@ +package buffer + +import ( + "errors" + "fmt" +) + +// SPSInfo holds parsed SPS parameters +type SPSInfo struct { + ChromaFormatIDC uint + PicWidthInLumaSamples uint + PicHeightInLumaSamples uint + ConformanceWindowFlag bool + ConfWinLeftOffset uint + ConfWinRightOffset uint + ConfWinTopOffset uint + ConfWinBottomOffset uint + CodedWidth, CodedHeight uint // Raw coded resolution + DisplayWidth, DisplayHeight uint // Resolution after conformance window cropping +} + +// -------- BitReader -------- +type BitReader struct { + data []byte + pos int // bit position +} + +func NewBitReader(data []byte) *BitReader { + return &BitReader{data: data} +} + +func (br *BitReader) left() int { + return len(br.data)*8 - br.pos +} + +func (br *BitReader) ReadBits(n int) (uint, error) { + if n < 0 || br.left() < n { + return 0, errors.New("not enough bits") + } + var v uint + for range n { + bytePos := br.pos / 8 + bitPos := 7 - (br.pos % 8) + bit := (br.data[bytePos] >> bitPos) & 1 + v = (v << 1) | uint(bit) + br.pos++ + } + return v, nil +} + +func (br *BitReader) ReadFlag() (bool, error) { + b, err := br.ReadBits(1) + return b == 1, err +} + +func (br *BitReader) ReadUE() (uint, error) { + // Unsigned Exp-Golomb + zeros := 0 + for { + bit, err := br.ReadBits(1) + if err != nil { + return 0, err + } + if bit == 0 { + zeros++ + continue + } + break // hit the stop bit '1' + } + if zeros == 0 { + return 0, nil + } + info, err := br.ReadBits(zeros) + if err != nil { + return 0, err + } + return (1<= 4 && b[0] == 0x00 && b[1] == 0x00 && b[2] == 0x00 && b[3] == 0x01 { + return b[4:] + } + if len(b) >= 3 && b[0] == 0x00 && b[1] == 0x00 && b[2] == 0x01 { + return b[3:] + } + return b +} + +// removeEmulationPreventionBytes removes 0x03 after 0x0000 +func removeEmulationPreventionBytes(data []byte) []byte { + out := make([]byte, 0, len(data)) + for i := range data { + if i > 1 && data[i] == 0x03 && data[i-1] == 0x00 && data[i-2] == 0x00 { + continue + } + out = append(out, data[i]) + } + return out +} + +// parseH265SPS parses a full H.265 SPS NAL unit +func parseH265SPS(nal []byte) (*SPSInfo, error) { + // Optional start code + nal = stripStartCode(nal) + + // Remove emulation prevention bytes across the NAL + rbsp := removeEmulationPreventionBytes(nal) + + br := NewBitReader(rbsp) + + // ---- NAL header (16 bits): forbidden_zero_bit(1), nal_unit_type(6), nuh_layer_id(6), nuh_temporal_id_plus1(3) + if _, err := br.ReadBits(1); err != nil { // forbidden_zero_bit + return nil, err + } + nalUnitType, err := br.ReadBits(6) + if err != nil { + return nil, err + } + if _, err = br.ReadBits(6); err != nil { // nuh_layer_id + return nil, err + } + if _, err = br.ReadBits(3); err != nil { // nuh_temporal_id_plus1 + return nil, err + } + // 33 = SPS + if nalUnitType != 33 { + return nil, fmt.Errorf("not an HEVC SPS NAL (type=%d)", nalUnitType) + } + + // ---- sps_video_parameter_set_id u(4), sps_max_sub_layers_minus1 u(3), sps_temporal_id_nesting_flag u(1) + if _, err = br.ReadBits(4); err != nil { + return nil, err + } + maxSubLayersMinus1, err := br.ReadBits(3) + if err != nil { + return nil, err + } + if _, err = br.ReadBits(1); err != nil { + return nil, err + } + + // ---- profile_tier_level(1, max_sub_layers_minus1) + // general_profile_space u(2), general_tier_flag u(1), general_profile_idc u(5) + if _, err = br.ReadBits(2 + 1 + 5); err != nil { + return nil, err + } + // general_profile_compatibility_flags u(32) + if _, err = br.ReadBits(32); err != nil { + return nil, err + } + // general_constraint_indicator_flags u(48) + if _, err = br.ReadBits(16); err != nil { + return nil, err + } + if _, err = br.ReadBits(32); err != nil { + return nil, err + } + // general_level_idc u(8) + if _, err = br.ReadBits(8); err != nil { + return nil, err + } + + subLayerProfilePresentFlag := make([]bool, maxSubLayersMinus1) + subLayerLevelPresentFlag := make([]bool, maxSubLayersMinus1) + for i := range maxSubLayersMinus1 { + f1, err := br.ReadFlag() + if err != nil { + return nil, err + } + f2, err := br.ReadFlag() + if err != nil { + return nil, err + } + subLayerProfilePresentFlag[i] = f1 + subLayerLevelPresentFlag[i] = f2 + } + if maxSubLayersMinus1 > 0 { + // reserved_zero_2bits for i = maxSubLayersMinus1 .. 7 + for i := maxSubLayersMinus1; i < 8; i++ { + if _, err := br.ReadBits(2); err != nil { + return nil, err + } + } + } + for i := range maxSubLayersMinus1 { + if subLayerProfilePresentFlag[i] { + if _, err = br.ReadBits(2 + 1 + 5); err != nil { + return nil, err + } + if _, err = br.ReadBits(32); err != nil { + return nil, err + } + if _, err = br.ReadBits(48); err != nil { + return nil, err + } + } + if subLayerLevelPresentFlag[i] { + if _, err = br.ReadBits(8); err != nil { + return nil, err + } + } + } + + // ---- Now the core SPS fields we need + _, err = br.ReadUE() // sps_seq_parameter_set_id + if err != nil { + return nil, err + } + + chromaFormatIDC, err := br.ReadUE() + if err != nil { + return nil, err + } + if chromaFormatIDC == 3 { + // separate_colour_plane_flag u(1) + if _, err := br.ReadFlag(); err != nil { + return nil, err + } + } + + picW, err := br.ReadUE() // pic_width_in_luma_samples + if err != nil { + return nil, err + } + picH, err := br.ReadUE() // pic_height_in_luma_samples + if err != nil { + return nil, err + } + + confFlag, err := br.ReadFlag() + if err != nil { + return nil, err + } + var l, r, t, b uint + if confFlag { + if l, err = br.ReadUE(); err != nil { + return nil, err + } + if r, err = br.ReadUE(); err != nil { + return nil, err + } + if t, err = br.ReadUE(); err != nil { + return nil, err + } + if b, err = br.ReadUE(); err != nil { + return nil, err + } + } + + // crop unit size depends on chroma_format_idc + subWidthC, subHeightC := getSubWidthC(chromaFormatIDC), getSubHeightC(chromaFormatIDC) + + info := &SPSInfo{ + ChromaFormatIDC: chromaFormatIDC, + PicWidthInLumaSamples: picW, + PicHeightInLumaSamples: picH, + ConformanceWindowFlag: confFlag, + ConfWinLeftOffset: l, + ConfWinRightOffset: r, + ConfWinTopOffset: t, + ConfWinBottomOffset: b, + CodedWidth: picW, + CodedHeight: picH, + } + + if confFlag { + w := int(picW) - int(l+r)*int(subWidthC) + h := int(picH) - int(t+b)*int(subHeightC) + if w < 0 { + w = 0 + } + if h < 0 { + h = 0 + } + info.DisplayWidth = uint(w) + info.DisplayHeight = uint(h) + } else { + info.DisplayWidth = picW + info.DisplayHeight = picH + } + + return info, nil +} + +func getSubWidthC(chromaFormatIDC uint) uint { + if chromaFormatIDC == 1 || chromaFormatIDC == 2 { + return 2 + } + return 1 +} + +func getSubHeightC(chromaFormatIDC uint) uint { + if chromaFormatIDC == 1 { + return 2 + } + return 1 +} + +func ExtractH265VideoSize(payload []byte) VideoSize { + if len(payload) < 2 { + return VideoSize{} + } + nalType := (payload[0] >> 1) & 0x3F + + var spsNalu []byte + switch nalType { + case 33: // SPS + spsNalu = payload + case 48: // Aggregation Packet (AP) + // skip 2-byte header + i := 2 + for i+2 <= len(payload) { + nalSize := int(payload[i])<<8 | int(payload[i+1]) + i += 2 + if i+nalSize > len(payload) { + break + } + nalUnit := payload[i : i+nalSize] + nt := (nalUnit[0] >> 1) & 0x3F + if nt == 33 { + spsNalu = nalUnit + break + } + i += nalSize + } + } + + if len(spsNalu) > 0 { + info, err := parseH265SPS(spsNalu) + if err != nil { + return VideoSize{} + } + return VideoSize{Width: uint32(info.DisplayWidth), Height: uint32(info.DisplayHeight)} + } + + return VideoSize{} +} + +// ------------------------- H264 ------------------------- + +// parseH264SPS parses a full H.264 SPS NAL unit into SPSInfo +func parseH264SPS(nal []byte) (*SPSInfo, error) { + if len(nal) < 1 { + return nil, errors.New("empty SPS NAL") + } + nal = stripStartCode(nal) + nalType := nal[0] & 0x1F + if nalType != 7 { + return nil, fmt.Errorf("not an SPS NAL (type=%d)", nalType) + } + + rbsp := removeEmulationPreventionBytes(nal[1:]) // skip NAL header + br := NewBitReader(rbsp) + + profileIDC, _ := br.ReadBits(8) + _, _ = br.ReadBits(8) // constraint flags + _, _ = br.ReadBits(8) // level_idc + _, _ = br.ReadUE() // seq_parameter_set_id + + chromaFormatIDC := uint(1) + if profileIDC == 100 || profileIDC == 110 || profileIDC == 122 || profileIDC == 244 || + profileIDC == 44 || profileIDC == 83 || profileIDC == 86 || profileIDC == 118 || profileIDC == 128 { + chromaFormatIDC, _ = br.ReadUE() + if chromaFormatIDC == 3 { + br.ReadFlag() // separate_colour_plane_flag + } + br.ReadUE() // bit_depth_luma_minus8 + br.ReadUE() // bit_depth_chroma_minus8 + br.ReadFlag() // qpprime_y_zero_transform_bypass_flag + if v, _ := br.ReadFlag(); v { // seq_scaling_matrix_present_flag + for range 8 { + br.ReadFlag() + } + } + } + + br.ReadUE() // log2_max_frame_num_minus4 + pocType, _ := br.ReadUE() + if pocType == 0 { + br.ReadUE() + } else if pocType == 1 { + br.ReadFlag() + br.ReadSE() + br.ReadSE() + cnt, _ := br.ReadUE() + for range cnt { + br.ReadSE() + } + } + + br.ReadUE() // max_num_ref_frames + br.ReadFlag() // gaps_in_frame_num_value_allowed_flag + + wMbs, _ := br.ReadUE() + hMapUnits, _ := br.ReadUE() + frameMbsOnly, _ := br.ReadFlag() + if !frameMbsOnly { + br.ReadFlag() // mb_adaptive_frame_field_flag + } + br.ReadFlag() // direct_8x8_inference_flag + + var cropLeft, cropRight, cropTop, cropBottom uint + if frameCropping, _ := br.ReadFlag(); frameCropping { + cropLeft, _ = br.ReadUE() + cropRight, _ = br.ReadUE() + cropTop, _ = br.ReadUE() + cropBottom, _ = br.ReadUE() + } + + width := (wMbs + 1) * 16 + height := (hMapUnits + 1) * 16 + if !frameMbsOnly { + height *= 2 + } + + subWidthC := getSubWidthC(chromaFormatIDC) + subHeightC := getSubHeightC(chromaFormatIDC) + cropUnitX := subWidthC + cropUnitY := subHeightC + if chromaFormatIDC == 0 { + cropUnitX = 1 + if !frameMbsOnly { + cropUnitY = 2 + } else { + cropUnitY = 1 + } + } else if !frameMbsOnly { + cropUnitY *= 2 + } + + info := &SPSInfo{ + ChromaFormatIDC: chromaFormatIDC, + PicWidthInLumaSamples: width, + PicHeightInLumaSamples: height, + ConformanceWindowFlag: cropLeft+cropRight+cropTop+cropBottom > 0, + ConfWinLeftOffset: cropLeft, + ConfWinRightOffset: cropRight, + ConfWinTopOffset: cropTop, + ConfWinBottomOffset: cropBottom, + CodedWidth: width, + CodedHeight: height, + DisplayWidth: width - (cropLeft+cropRight)*cropUnitX, + DisplayHeight: height - (cropTop+cropBottom)*cropUnitY, + } + + return info, nil +} + +// ExtractH264VideoSize extracts resolution from H.264 RTP payload +func ExtractH264VideoSize(payload []byte) VideoSize { + if len(payload) < 1 { + return VideoSize{} + } + + parseNAL := func(nal []byte) VideoSize { + info, err := parseH264SPS(nal) + if err != nil { + return VideoSize{} + } + return VideoSize{Width: uint32(info.DisplayWidth), Height: uint32(info.DisplayHeight)} + } + + nalType := payload[0] & 0x1F + + switch nalType { + case 7: // SPS NAL + return parseNAL(payload) + + case 28: // FU-A + if len(payload) < 2 { + return VideoSize{} + } + start := (payload[1] & 0x80) != 0 + if !start { + return VideoSize{} + } + nalHeader := (payload[0] & 0xE0) | (payload[1] & 0x1F) + sps := append([]byte{nalHeader}, payload[2:]...) + return parseNAL(sps) + + case 24, 25, 26, 27: // STAP-A/B, MTAP16, MTAP24 + offset := 1 + if nalType == 25 { // STAP-B has 16-bit DON + offset += 2 + } else if nalType == 26 { // MTAP16 + offset += 3 + } else if nalType == 27 { // MTAP24 + offset += 4 + } + + for offset+2 <= len(payload) { + naluSize := int(payload[offset])<<8 | int(payload[offset+1]) + offset += 2 + if offset+naluSize > len(payload) { + break + } + nalu := payload[offset : offset+naluSize] + if nalu[0]&0x1F == 7 { // SPS + return parseNAL(nalu) + } + offset += naluSize + } + return VideoSize{} + + default: + return VideoSize{} + } +} diff --git a/livekit/pkg/sfu/buffer/h26xhelper_test.go b/livekit/pkg/sfu/buffer/h26xhelper_test.go new file mode 100644 index 0000000..424d6bb --- /dev/null +++ b/livekit/pkg/sfu/buffer/h26xhelper_test.go @@ -0,0 +1,40 @@ +package buffer + +import ( + "encoding/base64" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestExtractH26xVideoSize(t *testing.T) { + type testcase struct { + payload string + width uint32 + height uint32 + isH264 bool + } + + testcases := []testcase{ + {"eAAOZ0LAH4xoBQBboB4RCNQABGjOPIA=", 1280, 720, true}, + {"eAAPZ0LAFoxoCgL3lgHhEI1AAARozjyA", 640, 360, true}, + {"eAAOZ0LADIxoFBl54B4RCNQABGjOPIA=", 320, 180, true}, + {"YAEAGkABDAP//wFgAAADALAAAAMAAAMAXQAAGwJAAC9CAQMBYAAAAwCwAAADAAADAF0AAKACgIAtFiBu5FIy5+E9C+ob1SmoCAgIH8IBBAAHRAHAcvBbJA==", 1280, 720, false}, + {"YAEAGkABDAP//wFgAAADALAAAAMAAAMAPwAAGwJAADBCAQMBYAAAAwCwAAADAAADAD8AAKAFAgFx8uIG7kUjLn4T0L6hvVKagICAgfwgEEAAB0QBwHLwWyQ=", 640, 360, false}, + {"QgEDAWAAAAMAsAAAAwAAAwA8AACgCggMHz4gM7kUhi5+E9C+ob1Q/qoI9VQT6qoK9VVBfqqqDPVVVKagICAgfwgEEA==", 320, 180, false}, + } + + for _, tc := range testcases { + payload, err := base64.StdEncoding.DecodeString(tc.payload) + require.NoError(t, err) + + var sz VideoSize + if tc.isH264 { + sz = ExtractH264VideoSize(payload) + } else { + sz = ExtractH265VideoSize(payload) + } + require.Equal(t, tc.width, sz.Width) + require.Equal(t, tc.height, sz.Height) + } +} diff --git a/livekit/pkg/sfu/buffer/helpers.go b/livekit/pkg/sfu/buffer/helpers.go new file mode 100644 index 0000000..e2ec36a --- /dev/null +++ b/livekit/pkg/sfu/buffer/helpers.go @@ -0,0 +1,476 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package buffer + +import ( + "encoding/binary" + "errors" + + "github.com/pion/rtp/codecs" + + "github.com/livekit/protocol/logger" +) + +var ( + errShortPacket = errors.New("packet is not large enough") + errNilPacket = errors.New("invalid nil packet") + errInvalidPacket = errors.New("invalid packet") +) + +// VP8 is a helper to get temporal data from VP8 packet header +/* + VP8 Payload Descriptor + 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 + +-+-+-+-+-+-+-+-+ +-+-+-+-+-+-+-+-+ + |X|R|N|S|R| PID | (REQUIRED) |X|R|N|S|R| PID | (REQUIRED) + +-+-+-+-+-+-+-+-+ +-+-+-+-+-+-+-+-+ + X: |I|L|T|K| RSV | (OPTIONAL) X: |I|L|T|K| RSV | (OPTIONAL) + +-+-+-+-+-+-+-+-+ +-+-+-+-+-+-+-+-+ + I: |M| PictureID | (OPTIONAL) I: |M| PictureID | (OPTIONAL) + +-+-+-+-+-+-+-+-+ +-+-+-+-+-+-+-+-+ + L: | TL0PICIDX | (OPTIONAL) | PictureID | + +-+-+-+-+-+-+-+-+ +-+-+-+-+-+-+-+-+ + T/K:|TID|Y| KEYIDX | (OPTIONAL) L: | TL0PICIDX | (OPTIONAL) + +-+-+-+-+-+-+-+-+ +-+-+-+-+-+-+-+-+ + T/K:|TID|Y| KEYIDX | (OPTIONAL) + +-+-+-+-+-+-+-+-+ +*/ +type VP8 struct { + FirstByte byte + S bool + + I bool + M bool + PictureID uint16 /* 7 or 15 bits, picture ID */ + + L bool + TL0PICIDX uint8 /* 8 bits temporal level zero index */ + + // Optional Header If either of the T or K bits are set to 1, + // the TID/Y/KEYIDX extension field MUST be present. + T bool + TID uint8 /* 2 bits temporal layer idx */ + Y bool + + K bool + KEYIDX uint8 /* 5 bits of key frame idx */ + + HeaderSize int + + // IsKeyFrame is a helper to detect if current packet is a keyframe + IsKeyFrame bool +} + +// Unmarshal parses the passed byte slice and stores the result in the VP8 this method is called upon +func (v *VP8) Unmarshal(payload []byte) error { + if payload == nil { + return errNilPacket + } + + payloadLen := len(payload) + if payloadLen < 1 { + return errShortPacket + } + + idx := 0 + v.FirstByte = payload[idx] + v.S = payload[idx]&0x10 > 0 + // Check for extended bit control + if payload[idx]&0x80 > 0 { + idx++ + if payloadLen < idx+1 { + return errShortPacket + } + v.I = payload[idx]&0x80 > 0 + v.L = payload[idx]&0x40 > 0 + v.T = payload[idx]&0x20 > 0 + v.K = payload[idx]&0x10 > 0 + if v.L && !v.T { + return errInvalidPacket + } + + if v.I { + idx++ + if payloadLen < idx+1 { + return errShortPacket + } + pid := payload[idx] & 0x7f + // if m is 1, then Picture ID is 15 bits + v.M = payload[idx]&0x80 > 0 + if v.M { + idx++ + if payloadLen < idx+1 { + return errShortPacket + } + v.PictureID = binary.BigEndian.Uint16([]byte{pid, payload[idx]}) + } else { + v.PictureID = uint16(pid) + } + } + + if v.L { + idx++ + if payloadLen < idx+1 { + return errShortPacket + } + v.TL0PICIDX = payload[idx] + } + + if v.T || v.K { + idx++ + if payloadLen < idx+1 { + return errShortPacket + } + + if v.T { + v.TID = (payload[idx] & 0xc0) >> 6 + v.Y = (payload[idx] & 0x20) > 0 + } + + if v.K { + v.KEYIDX = payload[idx] & 0x1f + } + } + idx++ + if payloadLen < idx+1 { + return errShortPacket + } + + // Check is packet is a keyframe by looking at P bit in vp8 payload + v.IsKeyFrame = payload[idx]&0x01 == 0 && v.S + } else { + idx++ + if payloadLen < idx+1 { + return errShortPacket + } + // Check is packet is a keyframe by looking at P bit in vp8 payload + v.IsKeyFrame = payload[idx]&0x01 == 0 && v.S + } + v.HeaderSize = idx + return nil +} + +func (v *VP8) Marshal() ([]byte, error) { + buf := make([]byte, v.HeaderSize) + n, err := v.MarshalTo(buf) + if err != nil { + return nil, err + } + return buf[:n], err +} + +func (v *VP8) MarshalTo(buf []byte) (int, error) { + if len(buf) < v.HeaderSize { + return 0, errShortPacket + } + + idx := 0 + buf[idx] = v.FirstByte + if v.I || v.L || v.T || v.K { + buf[idx] |= 0x80 // X bit + idx++ + + xpos := idx + xval := byte(0) + + idx++ + if v.I { + xval |= (1 << 7) + if v.M { + buf[idx] = 0x80 | byte((v.PictureID>>8)&0x7f) + buf[idx+1] = byte(v.PictureID & 0xff) + idx += 2 + } else { + buf[idx] = byte(v.PictureID) + idx++ + } + } + + if v.L { + xval |= (1 << 6) + buf[idx] = v.TL0PICIDX + idx++ + } + + if v.T || v.K { + buf[idx] = 0 + if v.T { + xval |= (1 << 5) + buf[idx] = v.TID << 6 + if v.Y { + buf[idx] |= (1 << 5) + } + } + + if v.K { + xval |= (1 << 4) + buf[idx] |= v.KEYIDX & 0x1f + } + idx++ + } + + buf[xpos] = xval + } else { + buf[idx] &^= 0x80 // X bit + idx++ + } + + return idx, nil +} + +// ------------------------------------- + +func VPxPictureIdSizeDiff(mBit1 bool, mBit2 bool) int { + if mBit1 == mBit2 { + return 0 + } + + if mBit1 { + return 1 + } + + return -1 +} + +// ------------------------------------- + +// IsH264KeyFrame detects if h264 payload is a keyframe +// this code was taken from https://github.com/jech/galene/blob/codecs/rtpconn/rtpreader.go#L45 +// all credits belongs to Juliusz Chroboczek @jech and the awesome Galene SFU +func IsH264KeyFrame(payload []byte) bool { + if len(payload) < 1 { + return false + } + nalu := payload[0] & 0x1F + if nalu == 0 { + // reserved + return false + } else if nalu <= 23 { + // simple NALU + return nalu == 7 + } else if nalu == 24 || nalu == 25 || nalu == 26 || nalu == 27 { + // STAP-A, STAP-B, MTAP16 or MTAP24 + i := 1 + if nalu == 25 || nalu == 26 || nalu == 27 { + // skip DON + i += 2 + } + for i < len(payload) { + if i+2 > len(payload) { + return false + } + length := uint16(payload[i])<<8 | + uint16(payload[i+1]) + i += 2 + if i+int(length) > len(payload) { + return false + } + offset := 0 + if nalu == 26 { + offset = 3 + } else if nalu == 27 { + offset = 4 + } + if offset >= int(length) { + return false + } + n := payload[i+offset] & 0x1F + if n == 7 { + return true + } else if n >= 24 { + // is this legal? + logger.Debugw("Non-simple NALU within a STAP") + } + i += int(length) + } + if i == len(payload) { + return false + } + return false + } else if nalu == 28 || nalu == 29 { + // FU-A or FU-B + if len(payload) < 2 { + return false + } + if (payload[1] & 0x80) == 0 { + // not a starting fragment + return false + } + return payload[1]&0x1F == 7 + } + return false +} + +// ------------------------------------- + +// IsVP9KeyFrame detects if vp9 payload is a keyframe +// taken from https://github.com/jech/galene/blob/master/codecs/codecs.go +// all credits belongs to Juliusz Chroboczek @jech and the awesome Galene SFU +func IsVP9KeyFrame(vp9 *codecs.VP9Packet, payload []byte) bool { + if vp9 == nil { + vp9 = &codecs.VP9Packet{} + _, err := vp9.Unmarshal(payload) + if err != nil || len(vp9.Payload) < 1 { + return false + } + } + + if !vp9.B { + return false + } + + if (vp9.Payload[0] & 0xc0) != 0x80 { + return false + } + + profile := (vp9.Payload[0] >> 4) & 0x3 + if profile != 3 { + return (vp9.Payload[0] & 0xC) == 0 + } + return (vp9.Payload[0] & 0x6) == 0 +} + +// ------------------------------------- + +// IsAV1KeyFrame detects if av1 payload is a keyframe +// taken from https://github.com/jech/galene/blob/master/codecs/codecs.go +// all credits belongs to Juliusz Chroboczek @jech and the awesome Galene SFU +func IsAV1KeyFrame(payload []byte) bool { + if len(payload) < 2 { + return false + } + // Z=0, N=1 + if (payload[0] & 0x88) != 0x08 { + return false + } + w := (payload[0] & 0x30) >> 4 + + getObu := func(data []byte, last bool) ([]byte, int, bool) { + if last { + return data, len(data), false + } + offset := 0 + length := 0 + for { + if len(data) <= offset { + return nil, offset, offset > 0 + } + l := data[offset] + length |= int(l&0x7f) << (offset * 7) + offset++ + if (l & 0x80) == 0 { + break + } + } + if len(data) < offset+length { + return data[offset:], len(data), true + } + return data[offset : offset+length], + offset + length, false + } + offset := 1 + i := 0 + for { + obu, length, truncated := + getObu(payload[offset:], int(w) == i+1) + if len(obu) < 1 { + return false + } + tpe := (obu[0] & 0x38) >> 3 + switch i { + case 0: + // OBU_SEQUENCE_HEADER + if tpe != 1 { + return false + } + default: + // OBU_FRAME_HEADER or OBU_FRAME + if tpe == 3 || tpe == 6 { + if len(obu) < 2 { + return false + } + // show_existing_frame == 0 + if (obu[1] & 0x80) != 0 { + return false + } + // frame_type == KEY_FRAME + return (obu[1] & 0x60) == 0 + } + } + if truncated || i >= int(w) { + // the first frame header is in a second + // packet, give up. + return false + } + offset += length + i++ + } +} + +func IsH265KeyFrame(payload []byte) (kf bool) { + if len(payload) < 2 { + return false + } + naluType := (payload[0] & 0x7E) >> 1 + switch { + case naluType == 33 || naluType == 34: + return true + case naluType == 48: // AP + idx := 2 + for idx < len(payload)-2 { + // TODO: check the DONL field (controlled by sprop-max-don-diff) + size := binary.BigEndian.Uint16(payload[idx:]) + idx += 2 + if idx >= len(payload) { + return false + } + naluType = (payload[idx] & 0x7E) >> 1 + if naluType == 33 || naluType == 34 { + return true + } + idx += int(size) + } + return false + + case naluType == 49: // FU + if len(payload) < 3 { + return false + } + naluType = (payload[2] & 0x7E) >> 1 + return naluType == 33 || naluType == 34 + default: + return false + } +} + +// ExtractVP8VideoSize extracts video resolution from VP8 key frame +func ExtractVP8VideoSize(vp8Packet *VP8, payload []byte) VideoSize { + if !vp8Packet.IsKeyFrame || len(payload) < vp8Packet.HeaderSize+10 { + return VideoSize{} + } + + vp8Payload := payload[vp8Packet.HeaderSize:] + + // Check for VP8 start code + if len(vp8Payload) < 10 || vp8Payload[3] != 0x9D || vp8Payload[4] != 0x01 || vp8Payload[5] != 0x2A { + return VideoSize{} + } + + // Read width and height from bytes 6-9 + width := uint32(vp8Payload[6]) | (uint32(vp8Payload[7]) << 8) + height := uint32(vp8Payload[8]) | (uint32(vp8Payload[9]) << 8) + + return VideoSize{width & 0x3FFF, height & 0x3FFF} +} diff --git a/livekit/pkg/sfu/buffer/helpers_test.go b/livekit/pkg/sfu/buffer/helpers_test.go new file mode 100644 index 0000000..378bfbe --- /dev/null +++ b/livekit/pkg/sfu/buffer/helpers_test.go @@ -0,0 +1,110 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package buffer + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestVP8Helper_Unmarshal(t *testing.T) { + type args struct { + payload []byte + } + tests := []struct { + name string + args args + wantErr bool + checkTemporal bool + temporalSupport bool + checkKeyFrame bool + keyFrame bool + checkPictureID bool + pictureID uint16 + checkTlzIdx bool + tlzIdx uint8 + checkTempID bool + temporalID uint8 + }{ + { + name: "Empty or nil payload must return error", + args: args{payload: []byte{}}, + wantErr: true, + }, + { + name: "Temporal must be supported by setting T bit to 1", + args: args{payload: []byte{0xff, 0x20, 0x1, 0x2, 0x3, 0x4}}, + checkTemporal: true, + temporalSupport: true, + }, + { + name: "Picture must be ID 7 bits by setting M bit to 0 and present by I bit set to 1", + args: args{payload: []byte{0xff, 0xff, 0x11, 0x2, 0x3, 0x4}}, + checkPictureID: true, + pictureID: 17, + }, + { + name: "Picture ID must be 15 bits by setting M bit to 1 and present by I bit set to 1", + args: args{payload: []byte{0xff, 0xff, 0x92, 0x67, 0x3, 0x4, 0x5}}, + checkPictureID: true, + pictureID: 4711, + }, + { + name: "Temporal level zero index must be present if L set to 1", + args: args{payload: []byte{0xff, 0xff, 0xff, 0xfd, 0xb4, 0x4, 0x5}}, + checkTlzIdx: true, + tlzIdx: 180, + }, + { + name: "Temporal index must be present and used if T bit set to 1", + args: args{payload: []byte{0xff, 0xff, 0xff, 0xfd, 0xb4, 0x9f, 0x5, 0x6}}, + checkTempID: true, + temporalID: 2, + }, + { + name: "Check if packet is a keyframe by looking at P bit set to 0", + args: args{payload: []byte{0xff, 0xff, 0xff, 0xfd, 0xb4, 0x9f, 0x94, 0x1}}, + checkKeyFrame: true, + keyFrame: true, + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + p := &VP8{} + if err := p.Unmarshal(tt.args.payload); (err != nil) != tt.wantErr { + t.Errorf("Unmarshal() error = %v, wantErr %v", err, tt.wantErr) + } + if tt.checkTemporal { + require.Equal(t, tt.temporalSupport, p.T) + } + if tt.checkKeyFrame { + require.Equal(t, tt.keyFrame, p.IsKeyFrame) + } + if tt.checkPictureID { + require.Equal(t, tt.pictureID, p.PictureID) + } + if tt.checkTlzIdx { + require.Equal(t, tt.tlzIdx, p.TL0PICIDX) + } + if tt.checkTempID { + require.Equal(t, tt.temporalID, p.TID) + } + }) + } +} + +// ------------------------------------------ diff --git a/livekit/pkg/sfu/buffer/rtcpreader.go b/livekit/pkg/sfu/buffer/rtcpreader.go new file mode 100644 index 0000000..32bffc7 --- /dev/null +++ b/livekit/pkg/sfu/buffer/rtcpreader.go @@ -0,0 +1,63 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package buffer + +import ( + "io" + + "go.uber.org/atomic" +) + +type RTCPReader struct { + ssrc uint32 + closed atomic.Bool + onPacket atomic.Value // func([]byte) + onClose func() +} + +func NewRTCPReader(ssrc uint32) *RTCPReader { + return &RTCPReader{ssrc: ssrc} +} + +//go:noinline +func (r *RTCPReader) Write(p []byte) (n int, err error) { + if r.closed.Load() { + err = io.EOF + return + } + if f, ok := r.onPacket.Load().(func([]byte)); ok && f != nil { + f(p) + } + return +} + +func (r *RTCPReader) OnClose(fn func()) { + r.onClose = fn +} + +func (r *RTCPReader) Close() error { + if r.closed.Swap(true) { + return nil + } + + r.onClose() + return nil +} + +func (r *RTCPReader) OnPacket(f func([]byte)) { + r.onPacket.Store(f) +} + +func (r *RTCPReader) Read(_ []byte) (n int, err error) { return } diff --git a/livekit/pkg/sfu/buffer/streamstats.go b/livekit/pkg/sfu/buffer/streamstats.go new file mode 100644 index 0000000..e45cb2e --- /dev/null +++ b/livekit/pkg/sfu/buffer/streamstats.go @@ -0,0 +1,24 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package buffer + +import "github.com/livekit/livekit-server/pkg/sfu/rtpstats" + +type StreamStatsWithLayers struct { + RTPStats *rtpstats.RTPDeltaInfo + Layers map[int32]*rtpstats.RTPDeltaInfo + + RTPStatsRemoteView *rtpstats.RTPDeltaInfo +} diff --git a/livekit/pkg/sfu/buffer/videolayer.go b/livekit/pkg/sfu/buffer/videolayer.go new file mode 100644 index 0000000..761cc1c --- /dev/null +++ b/livekit/pkg/sfu/buffer/videolayer.go @@ -0,0 +1,58 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package buffer + +import "fmt" + +const ( + InvalidLayerSpatial = int32(-1) + InvalidLayerTemporal = int32(-1) + + DefaultMaxLayerSpatial = int32(2) + DefaultMaxLayerTemporal = int32(3) +) + +var ( + InvalidLayer = VideoLayer{ + Spatial: InvalidLayerSpatial, + Temporal: InvalidLayerTemporal, + } + + DefaultMaxLayer = VideoLayer{ + Spatial: DefaultMaxLayerSpatial, + Temporal: DefaultMaxLayerTemporal, + } +) + +type VideoLayer struct { + Spatial int32 + Temporal int32 +} + +func (v VideoLayer) String() string { + return fmt.Sprintf("VideoLayer{s: %d, t: %d}", v.Spatial, v.Temporal) +} + +func (v VideoLayer) GreaterThan(v2 VideoLayer) bool { + return v.Spatial > v2.Spatial || (v.Spatial == v2.Spatial && v.Temporal > v2.Temporal) +} + +func (v VideoLayer) SpatialGreaterThanOrEqual(v2 VideoLayer) bool { + return v.Spatial >= v2.Spatial +} + +func (v VideoLayer) IsValid() bool { + return v.Spatial != InvalidLayerSpatial && v.Temporal != InvalidLayerTemporal +} diff --git a/livekit/pkg/sfu/buffer/videolayerutils.go b/livekit/pkg/sfu/buffer/videolayerutils.go new file mode 100644 index 0000000..2754b59 --- /dev/null +++ b/livekit/pkg/sfu/buffer/videolayerutils.go @@ -0,0 +1,520 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package buffer + +import ( + "slices" + + "github.com/livekit/livekit-server/pkg/sfu/mime" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" +) + +const ( + quarterResolutionQ = "q" + halfResolutionH = "h" + fullResolutionF = "f" + + quarterResolution2 = "2" + halfResolution1 = "1" + fullResolution0 = "0" +) + +type VideoLayersRid [DefaultMaxLayerSpatial + 1]string + +var ( + videoLayersRidQHF = VideoLayersRid{quarterResolutionQ, halfResolutionH, fullResolutionF} + videoLayersRid210 = VideoLayersRid{quarterResolution2, halfResolution1, fullResolution0} + DefaultVideoLayersRid = videoLayersRidQHF +) + +func LayerPresenceFromTrackInfo(mimeType mime.MimeType, trackInfo *livekit.TrackInfo) *[livekit.VideoQuality_HIGH + 1]bool { + if trackInfo == nil { + return nil + } + + layers := GetVideoLayersForMimeType(mimeType, trackInfo) + if len(layers) == 0 { + return nil + } + + var layerPresence [livekit.VideoQuality_HIGH + 1]bool + for _, layer := range layers { + // WARNING: comparing protobuf enum + if layer.Quality <= livekit.VideoQuality_HIGH { + layerPresence[layer.Quality] = true + } else { + logger.Warnw("unexpected quality in track info", nil, "trackID", trackInfo.Sid, "trackInfo", logger.Proto(trackInfo)) + } + } + + return &layerPresence +} + +func RidToSpatialLayer(mimeType mime.MimeType, rid string, trackInfo *livekit.TrackInfo, ridSpace VideoLayersRid) int32 { + lp := LayerPresenceFromTrackInfo(mimeType, trackInfo) + if lp == nil { + switch rid { + case quarterResolutionQ: + return 0 + case halfResolutionH: + return 1 + case fullResolutionF: + return 2 + default: + return 0 + } + } + + switch rid { + case ridSpace[0]: + switch { + case lp[livekit.VideoQuality_LOW] && lp[livekit.VideoQuality_MEDIUM] && lp[livekit.VideoQuality_HIGH]: + fallthrough + case lp[livekit.VideoQuality_LOW] && lp[livekit.VideoQuality_MEDIUM]: + fallthrough + case lp[livekit.VideoQuality_LOW] && lp[livekit.VideoQuality_HIGH]: + fallthrough + case lp[livekit.VideoQuality_MEDIUM] && lp[livekit.VideoQuality_HIGH]: + return 0 + + default: + // only one quality published, could be any + return 0 + } + + case ridSpace[1]: + switch { + case lp[livekit.VideoQuality_LOW] && lp[livekit.VideoQuality_MEDIUM] && lp[livekit.VideoQuality_HIGH]: + fallthrough + case lp[livekit.VideoQuality_LOW] && lp[livekit.VideoQuality_MEDIUM]: + fallthrough + case lp[livekit.VideoQuality_LOW] && lp[livekit.VideoQuality_HIGH]: + fallthrough + case lp[livekit.VideoQuality_MEDIUM] && lp[livekit.VideoQuality_HIGH]: + return 1 + + default: + // only one quality published, could be any + return 0 + } + + case ridSpace[2]: + switch { + case lp[livekit.VideoQuality_LOW] && lp[livekit.VideoQuality_MEDIUM] && lp[livekit.VideoQuality_HIGH]: + return 2 + + case lp[livekit.VideoQuality_LOW] && lp[livekit.VideoQuality_MEDIUM]: + logger.Warnw("unexpected rid with only two qualities, low and medium", nil, "trackID", trackInfo.Sid, "trackInfo", logger.Proto(trackInfo), "rid", ridSpace[2]) + return 1 + case lp[livekit.VideoQuality_LOW] && lp[livekit.VideoQuality_HIGH]: + logger.Warnw("unexpected rid with only two qualities, low and high", nil, "trackID", trackInfo.Sid, "trackInfo", logger.Proto(trackInfo), "rid", ridSpace[2]) + return 1 + case lp[livekit.VideoQuality_MEDIUM] && lp[livekit.VideoQuality_HIGH]: + logger.Warnw("unexpected rid with only two qualities, medium and high", nil, "trackID", trackInfo.Sid, "trackInfo", logger.Proto(trackInfo), "rid", ridSpace[2]) + return 1 + + default: + // only one quality published, could be any + return 0 + } + + default: + // no rid, should be single layer + return 0 + } +} + +func SpatialLayerToRid(mimeType mime.MimeType, layer int32, trackInfo *livekit.TrackInfo, ridSpace VideoLayersRid) string { + lp := LayerPresenceFromTrackInfo(mimeType, trackInfo) + if lp == nil { + switch layer { + case 0: + return quarterResolutionQ + case 1: + return halfResolutionH + case 2: + return fullResolutionF + default: + return quarterResolutionQ + } + } + + switch layer { + case 0: + switch { + case lp[livekit.VideoQuality_LOW] && lp[livekit.VideoQuality_MEDIUM] && lp[livekit.VideoQuality_HIGH]: + fallthrough + case lp[livekit.VideoQuality_LOW] && lp[livekit.VideoQuality_MEDIUM]: + fallthrough + case lp[livekit.VideoQuality_LOW] && lp[livekit.VideoQuality_HIGH]: + fallthrough + case lp[livekit.VideoQuality_MEDIUM] && lp[livekit.VideoQuality_HIGH]: + return ridSpace[0] + + default: + return ridSpace[0] + } + + case 1: + switch { + case lp[livekit.VideoQuality_LOW] && lp[livekit.VideoQuality_MEDIUM] && lp[livekit.VideoQuality_HIGH]: + fallthrough + case lp[livekit.VideoQuality_LOW] && lp[livekit.VideoQuality_MEDIUM]: + fallthrough + case lp[livekit.VideoQuality_LOW] && lp[livekit.VideoQuality_HIGH]: + fallthrough + case lp[livekit.VideoQuality_MEDIUM] && lp[livekit.VideoQuality_HIGH]: + return ridSpace[1] + + default: + return ridSpace[0] + } + + case 2: + switch { + case lp[livekit.VideoQuality_LOW] && lp[livekit.VideoQuality_MEDIUM] && lp[livekit.VideoQuality_HIGH]: + return ridSpace[2] + + case lp[livekit.VideoQuality_LOW] && lp[livekit.VideoQuality_MEDIUM]: + logger.Warnw("unexpected layer 2 with only two qualities, low and medium", nil, "trackID", trackInfo.Sid, "trackInfo", logger.Proto(trackInfo)) + return ridSpace[1] + case lp[livekit.VideoQuality_LOW] && lp[livekit.VideoQuality_HIGH]: + logger.Warnw("unexpected layer 2 with only two qualities, low and high", nil, "trackID", trackInfo.Sid, "trackInfo", logger.Proto(trackInfo)) + return ridSpace[1] + case lp[livekit.VideoQuality_MEDIUM] && lp[livekit.VideoQuality_HIGH]: + logger.Warnw("unexpected layer 2 with only two qualities, medium and high", nil, "trackID", trackInfo.Sid, "trackInfo", logger.Proto(trackInfo)) + return ridSpace[1] + + default: + return ridSpace[0] + } + + default: + return ridSpace[0] + } +} + +func VideoQualityToRid(mimeType mime.MimeType, quality livekit.VideoQuality, trackInfo *livekit.TrackInfo, ridSpace VideoLayersRid) string { + return SpatialLayerToRid(mimeType, VideoQualityToSpatialLayer(mimeType, quality, trackInfo), trackInfo, ridSpace) +} + +func SpatialLayerToVideoQuality(mimeType mime.MimeType, layer int32, trackInfo *livekit.TrackInfo) livekit.VideoQuality { + lp := LayerPresenceFromTrackInfo(mimeType, trackInfo) + if lp == nil { + switch layer { + case 0: + return livekit.VideoQuality_LOW + case 1: + return livekit.VideoQuality_MEDIUM + case 2: + return livekit.VideoQuality_HIGH + default: + return livekit.VideoQuality_OFF + } + } + + switch layer { + case 0: + switch { + case lp[livekit.VideoQuality_LOW] && lp[livekit.VideoQuality_MEDIUM] && lp[livekit.VideoQuality_HIGH]: + fallthrough + case lp[livekit.VideoQuality_LOW] && lp[livekit.VideoQuality_MEDIUM]: + fallthrough + case lp[livekit.VideoQuality_LOW] && lp[livekit.VideoQuality_HIGH]: + fallthrough + case lp[livekit.VideoQuality_LOW]: + return livekit.VideoQuality_LOW + + case lp[livekit.VideoQuality_MEDIUM] && lp[livekit.VideoQuality_HIGH]: + fallthrough + case lp[livekit.VideoQuality_MEDIUM]: + return livekit.VideoQuality_MEDIUM + + default: + return livekit.VideoQuality_HIGH + } + + case 1: + switch { + case lp[livekit.VideoQuality_LOW] && lp[livekit.VideoQuality_MEDIUM] && lp[livekit.VideoQuality_HIGH]: + fallthrough + case lp[livekit.VideoQuality_LOW] && lp[livekit.VideoQuality_MEDIUM]: + return livekit.VideoQuality_MEDIUM + + case lp[livekit.VideoQuality_LOW] && lp[livekit.VideoQuality_HIGH]: + fallthrough + case lp[livekit.VideoQuality_MEDIUM] && lp[livekit.VideoQuality_HIGH]: + return livekit.VideoQuality_HIGH + + default: + logger.Errorw("invalid layer", nil, "trackID", trackInfo.Sid, "layer", layer, "trackInfo", logger.Proto(trackInfo)) + return livekit.VideoQuality_HIGH + } + + case 2: + switch { + case lp[livekit.VideoQuality_LOW] && lp[livekit.VideoQuality_MEDIUM] && lp[livekit.VideoQuality_HIGH]: + return livekit.VideoQuality_HIGH + + default: + logger.Errorw("invalid layer", nil, "trackID", trackInfo.Sid, "layer", layer, "trackInfo", logger.Proto(trackInfo)) + return livekit.VideoQuality_HIGH + } + } + + return livekit.VideoQuality_OFF +} + +func VideoQualityToSpatialLayer(mimeType mime.MimeType, quality livekit.VideoQuality, trackInfo *livekit.TrackInfo) int32 { + lp := LayerPresenceFromTrackInfo(mimeType, trackInfo) + if lp == nil { + switch quality { + case livekit.VideoQuality_LOW: + return 0 + case livekit.VideoQuality_MEDIUM: + return 1 + case livekit.VideoQuality_HIGH: + return 2 + default: + return InvalidLayerSpatial + } + } + + switch quality { + case livekit.VideoQuality_LOW: + switch { + case lp[livekit.VideoQuality_LOW] && lp[livekit.VideoQuality_MEDIUM] && lp[livekit.VideoQuality_HIGH]: + fallthrough + case lp[livekit.VideoQuality_LOW] && lp[livekit.VideoQuality_MEDIUM]: + fallthrough + case lp[livekit.VideoQuality_LOW] && lp[livekit.VideoQuality_HIGH]: + fallthrough + case lp[livekit.VideoQuality_MEDIUM] && lp[livekit.VideoQuality_HIGH]: + fallthrough + default: // only one quality published, could be any + return 0 + } + + case livekit.VideoQuality_MEDIUM: + switch { + case lp[livekit.VideoQuality_LOW] && lp[livekit.VideoQuality_MEDIUM] && lp[livekit.VideoQuality_HIGH]: + fallthrough + case lp[livekit.VideoQuality_LOW] && lp[livekit.VideoQuality_MEDIUM]: + fallthrough + case lp[livekit.VideoQuality_LOW] && lp[livekit.VideoQuality_HIGH]: + return 1 + + case lp[livekit.VideoQuality_MEDIUM] && lp[livekit.VideoQuality_HIGH]: + return 0 + + default: // only one quality published, could be any + return 0 + } + + case livekit.VideoQuality_HIGH: + switch { + case lp[livekit.VideoQuality_LOW] && lp[livekit.VideoQuality_MEDIUM] && lp[livekit.VideoQuality_HIGH]: + return 2 + + case lp[livekit.VideoQuality_LOW] && lp[livekit.VideoQuality_MEDIUM]: + fallthrough + case lp[livekit.VideoQuality_LOW] && lp[livekit.VideoQuality_HIGH]: + fallthrough + case lp[livekit.VideoQuality_MEDIUM] && lp[livekit.VideoQuality_HIGH]: + return 1 + + default: // only one quality published, could be any + return 0 + } + } + + return InvalidLayerSpatial +} + +func GetVideoLayerModeForMimeType(mimeType mime.MimeType, ti *livekit.TrackInfo) livekit.VideoLayer_Mode { + if ti != nil { + for _, codec := range ti.Codecs { + if mime.NormalizeMimeType(codec.MimeType) == mimeType { + return codec.VideoLayerMode + } + } + } + + return livekit.VideoLayer_MODE_UNUSED +} + +func GetVideoLayersForMimeType(mimeType mime.MimeType, ti *livekit.TrackInfo) []*livekit.VideoLayer { + var layers []*livekit.VideoLayer + if ti != nil { + for _, codec := range ti.Codecs { + if mime.NormalizeMimeType(codec.MimeType) == mimeType { + layers = codec.Layers + break + } + } + if len(layers) == 0 { + layers = ti.Layers + } + } + return layers +} + +func GetSpatialLayerForRid(mimeType mime.MimeType, rid string, ti *livekit.TrackInfo) int32 { + if ti == nil { + return InvalidLayerSpatial + } + + if rid == "" { + // single layer without RID + return 0 + } + + layers := GetVideoLayersForMimeType(mimeType, ti) + for _, layer := range layers { + if layer.Rid == rid { + return layer.SpatialLayer + } + } + + if len(layers) != 0 { + // RID present in codec, but may not be specified via signalling + // (happens with older browsers setting a rid for SVC codecs) + hasRid := false + for _, layer := range layers { + if layer.Rid != "" { + hasRid = true + break + } + } + if !hasRid { + return 0 + } + } + + // SIMULCAST-CODEC-TODO - ideally should return invalid, but there are + // VP9 publishers using rid = f, if there are only two layers + // in TrackInfo, that will be q;h and f will become invalid. + // + // Actually, there should be no rids for VP9 in SDP and hence + // the above check should take effect. However, as simulcast + // codec/back up codec does not update rids from SDP, + // the default rids are used when vp9 (primary codec) + // is published. Due to that the above check gets bypassed. + // + // The full proper sequence would be + // 1. For primary codec using SVC, there will be no rids. + // The above check should take effect and it should + // return 0 even if some publisher uses a rid like `f`. + // 2. When secondary codec is published, rids for the codec + // corresponding to the back up codec mime type should + // be updated in `TrackInfo`. This is a bit tricky + // for a couple of cases + // a. Browsers like Firefox use a different CID everytime. + // So, it cannot be matched between `AddTrack` and SDP. + // One option is to look for a published track with + // back up codec and apply it there. But, that becomes + // a challenge if there are multiple published tracks + // with pending back up codec. + // b. The back up codec publish SDP will have the full + // codec list. It should be okay to assume that the + // codec that will be published is the back up codec, + // but just something to be aware of. + // 3. Use of this function with proper mime so that proper + // codec section can be looked up in `TrackInfo`. + // return InvalidLayerSpatial + logger.Infow( + "invalid layer for rid, returning default", + "trackID", ti.Sid, + "rid", rid, + "mimeType", mimeType, + "trackInfo", logger.Proto(ti), + ) + return 0 +} + +func GetSpatialLayerForVideoQuality(mimeType mime.MimeType, quality livekit.VideoQuality, ti *livekit.TrackInfo) int32 { + if ti == nil || quality == livekit.VideoQuality_OFF { + return InvalidLayerSpatial + } + + layers := GetVideoLayersForMimeType(mimeType, ti) + for _, layer := range layers { + if layer.Quality == quality { + return layer.SpatialLayer + } + } + + if len(layers) == 0 { + // single layer + return 0 + } + + // requested quality is higher than available layers, return the highest available layer + return VideoQualityToSpatialLayer(mimeType, quality, ti) +} + +func GetVideoQualityForSpatialLayer(mimeType mime.MimeType, spatialLayer int32, ti *livekit.TrackInfo) livekit.VideoQuality { + if spatialLayer == InvalidLayerSpatial || ti == nil { + return livekit.VideoQuality_OFF + } + + layers := GetVideoLayersForMimeType(mimeType, ti) + for _, layer := range layers { + if layer.SpatialLayer == spatialLayer { + return layer.Quality + } + } + + return livekit.VideoQuality_OFF +} + +func isVideoLayersRidKnown(rids VideoLayersRid, knownRids VideoLayersRid) bool { + for _, rid := range rids { + if rid == "" { + continue + } + + if !slices.Contains(knownRids[:], rid) { + return false + } + } + + return true +} + +func NormalizeVideoLayersRid(rids VideoLayersRid) VideoLayersRid { + out := rids + + normalize := func(knownRids VideoLayersRid) { + idx := 0 + for _, known := range knownRids { + if slices.Contains(rids[:], known) { + out[idx] = known + idx++ + } + } + } + + if isVideoLayersRidKnown(rids, videoLayersRidQHF) { + normalize(videoLayersRidQHF) + } + + if isVideoLayersRidKnown(rids, videoLayersRid210) { + normalize(videoLayersRid210) + } + + return out +} diff --git a/livekit/pkg/sfu/buffer/videolayerutils_test.go b/livekit/pkg/sfu/buffer/videolayerutils_test.go new file mode 100644 index 0000000..a7c5ad2 --- /dev/null +++ b/livekit/pkg/sfu/buffer/videolayerutils_test.go @@ -0,0 +1,892 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package buffer + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/livekit/livekit-server/pkg/sfu/mime" + "github.com/livekit/protocol/livekit" +) + +func TestRidConversion(t *testing.T) { + type RidAndLayer struct { + rid string + layer int32 + } + tests := []struct { + name string + trackInfo *livekit.TrackInfo + mimeType mime.MimeType + ridToLayer map[string]RidAndLayer + }{ + { + "no track info", + nil, + mime.MimeTypeVP8, + map[string]RidAndLayer{ + "": {rid: quarterResolutionQ, layer: 0}, + quarterResolutionQ: {rid: quarterResolutionQ, layer: 0}, + halfResolutionH: {rid: halfResolutionH, layer: 1}, + fullResolutionF: {rid: fullResolutionF, layer: 2}, + }, + }, + { + "no layers", + &livekit.TrackInfo{}, + mime.MimeTypeVP8, + map[string]RidAndLayer{ + "": {rid: quarterResolutionQ, layer: 0}, + quarterResolutionQ: {rid: quarterResolutionQ, layer: 0}, + halfResolutionH: {rid: halfResolutionH, layer: 1}, + fullResolutionF: {rid: fullResolutionF, layer: 2}, + }, + }, + { + "single layer, low", + &livekit.TrackInfo{ + Codecs: []*livekit.SimulcastCodecInfo{ + { + MimeType: mime.MimeTypeVP8.String(), + Layers: []*livekit.VideoLayer{ + {Quality: livekit.VideoQuality_LOW}, + }, + }, + }, + }, + mime.MimeTypeVP8, + map[string]RidAndLayer{ + "": {rid: quarterResolutionQ, layer: 0}, + quarterResolutionQ: {rid: quarterResolutionQ, layer: 0}, + halfResolutionH: {rid: quarterResolutionQ, layer: 0}, + fullResolutionF: {rid: quarterResolutionQ, layer: 0}, + }, + }, + { + "single layer, medium", + &livekit.TrackInfo{ + Codecs: []*livekit.SimulcastCodecInfo{ + { + MimeType: mime.MimeTypeVP8.String(), + Layers: []*livekit.VideoLayer{ + {Quality: livekit.VideoQuality_MEDIUM}, + }, + }, + }, + }, + mime.MimeTypeVP8, + map[string]RidAndLayer{ + "": {rid: quarterResolutionQ, layer: 0}, + quarterResolutionQ: {rid: quarterResolutionQ, layer: 0}, + halfResolutionH: {rid: quarterResolutionQ, layer: 0}, + fullResolutionF: {rid: quarterResolutionQ, layer: 0}, + }, + }, + { + "single layer, high", + &livekit.TrackInfo{ + Codecs: []*livekit.SimulcastCodecInfo{ + { + MimeType: mime.MimeTypeVP8.String(), + Layers: []*livekit.VideoLayer{ + {Quality: livekit.VideoQuality_MEDIUM}, + }, + }, + }, + }, + mime.MimeTypeVP8, + map[string]RidAndLayer{ + "": {rid: quarterResolutionQ, layer: 0}, + quarterResolutionQ: {rid: quarterResolutionQ, layer: 0}, + halfResolutionH: {rid: quarterResolutionQ, layer: 0}, + fullResolutionF: {rid: quarterResolutionQ, layer: 0}, + }, + }, + { + "two layers, low and medium", + &livekit.TrackInfo{ + Codecs: []*livekit.SimulcastCodecInfo{ + { + MimeType: mime.MimeTypeVP8.String(), + Layers: []*livekit.VideoLayer{ + {Quality: livekit.VideoQuality_LOW}, + {Quality: livekit.VideoQuality_MEDIUM}, + }, + }, + }, + }, + mime.MimeTypeVP8, + map[string]RidAndLayer{ + "": {rid: quarterResolutionQ, layer: 0}, + quarterResolutionQ: {rid: quarterResolutionQ, layer: 0}, + halfResolutionH: {rid: halfResolutionH, layer: 1}, + fullResolutionF: {rid: halfResolutionH, layer: 1}, + }, + }, + { + "two layers, low and high", + &livekit.TrackInfo{ + Codecs: []*livekit.SimulcastCodecInfo{ + { + MimeType: mime.MimeTypeVP8.String(), + Layers: []*livekit.VideoLayer{ + {Quality: livekit.VideoQuality_LOW}, + {Quality: livekit.VideoQuality_HIGH}, + }, + }, + }, + }, + mime.MimeTypeVP8, + map[string]RidAndLayer{ + "": {rid: quarterResolutionQ, layer: 0}, + quarterResolutionQ: {rid: quarterResolutionQ, layer: 0}, + halfResolutionH: {rid: halfResolutionH, layer: 1}, + fullResolutionF: {rid: halfResolutionH, layer: 1}, + }, + }, + { + "two layers, medium and high", + &livekit.TrackInfo{ + Codecs: []*livekit.SimulcastCodecInfo{ + { + MimeType: mime.MimeTypeVP8.String(), + Layers: []*livekit.VideoLayer{ + {Quality: livekit.VideoQuality_MEDIUM}, + {Quality: livekit.VideoQuality_HIGH}, + }, + }, + }, + }, + mime.MimeTypeVP8, + map[string]RidAndLayer{ + "": {rid: quarterResolutionQ, layer: 0}, + quarterResolutionQ: {rid: quarterResolutionQ, layer: 0}, + halfResolutionH: {rid: halfResolutionH, layer: 1}, + fullResolutionF: {rid: halfResolutionH, layer: 1}, + }, + }, + { + "three layers", + &livekit.TrackInfo{ + Codecs: []*livekit.SimulcastCodecInfo{ + { + MimeType: mime.MimeTypeVP8.String(), + Layers: []*livekit.VideoLayer{ + {Quality: livekit.VideoQuality_LOW}, + {Quality: livekit.VideoQuality_MEDIUM}, + {Quality: livekit.VideoQuality_HIGH}, + }, + }, + }, + }, + mime.MimeTypeVP8, + map[string]RidAndLayer{ + "": {rid: quarterResolutionQ, layer: 0}, + quarterResolutionQ: {rid: quarterResolutionQ, layer: 0}, + halfResolutionH: {rid: halfResolutionH, layer: 1}, + fullResolutionF: {rid: fullResolutionF, layer: 2}, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for testRid, expectedResult := range test.ridToLayer { + actualLayer := RidToSpatialLayer(test.mimeType, testRid, test.trackInfo, DefaultVideoLayersRid) + require.Equal(t, expectedResult.layer, actualLayer) + + actualRid := SpatialLayerToRid(test.mimeType, actualLayer, test.trackInfo, DefaultVideoLayersRid) + require.Equal(t, expectedResult.rid, actualRid) + } + }) + } +} + +func TestQualityConversion(t *testing.T) { + type QualityAndLayer struct { + quality livekit.VideoQuality + layer int32 + } + tests := []struct { + name string + trackInfo *livekit.TrackInfo + mimeType mime.MimeType + qualityToLayer map[livekit.VideoQuality]QualityAndLayer + }{ + { + "no track info", + nil, + mime.MimeTypeVP8, + map[livekit.VideoQuality]QualityAndLayer{ + livekit.VideoQuality_LOW: {quality: livekit.VideoQuality_LOW, layer: 0}, + livekit.VideoQuality_MEDIUM: {quality: livekit.VideoQuality_MEDIUM, layer: 1}, + livekit.VideoQuality_HIGH: {quality: livekit.VideoQuality_HIGH, layer: 2}, + }, + }, + { + "no layers", + &livekit.TrackInfo{}, + mime.MimeTypeVP8, + map[livekit.VideoQuality]QualityAndLayer{ + livekit.VideoQuality_LOW: {quality: livekit.VideoQuality_LOW, layer: 0}, + livekit.VideoQuality_MEDIUM: {quality: livekit.VideoQuality_MEDIUM, layer: 1}, + livekit.VideoQuality_HIGH: {quality: livekit.VideoQuality_HIGH, layer: 2}, + }, + }, + { + "single layer, low", + &livekit.TrackInfo{ + Codecs: []*livekit.SimulcastCodecInfo{ + { + MimeType: mime.MimeTypeVP8.String(), + Layers: []*livekit.VideoLayer{ + {Quality: livekit.VideoQuality_LOW}, + }, + }, + }, + }, + mime.MimeTypeVP8, + map[livekit.VideoQuality]QualityAndLayer{ + livekit.VideoQuality_LOW: {quality: livekit.VideoQuality_LOW, layer: 0}, + livekit.VideoQuality_MEDIUM: {quality: livekit.VideoQuality_LOW, layer: 0}, + livekit.VideoQuality_HIGH: {quality: livekit.VideoQuality_LOW, layer: 0}, + }, + }, + { + "single layer, medium", + &livekit.TrackInfo{ + Codecs: []*livekit.SimulcastCodecInfo{ + { + MimeType: mime.MimeTypeVP8.String(), + Layers: []*livekit.VideoLayer{ + {Quality: livekit.VideoQuality_MEDIUM}, + }, + }, + }, + }, + mime.MimeTypeVP8, + map[livekit.VideoQuality]QualityAndLayer{ + livekit.VideoQuality_LOW: {quality: livekit.VideoQuality_MEDIUM, layer: 0}, + livekit.VideoQuality_MEDIUM: {quality: livekit.VideoQuality_MEDIUM, layer: 0}, + livekit.VideoQuality_HIGH: {quality: livekit.VideoQuality_MEDIUM, layer: 0}, + }, + }, + { + "single layer, high", + &livekit.TrackInfo{ + Codecs: []*livekit.SimulcastCodecInfo{ + { + MimeType: mime.MimeTypeVP8.String(), + Layers: []*livekit.VideoLayer{ + {Quality: livekit.VideoQuality_HIGH}, + }, + }, + }, + }, + mime.MimeTypeVP8, + map[livekit.VideoQuality]QualityAndLayer{ + livekit.VideoQuality_LOW: {quality: livekit.VideoQuality_HIGH, layer: 0}, + livekit.VideoQuality_MEDIUM: {quality: livekit.VideoQuality_HIGH, layer: 0}, + livekit.VideoQuality_HIGH: {quality: livekit.VideoQuality_HIGH, layer: 0}, + }, + }, + { + "two layers, low and medium", + &livekit.TrackInfo{ + Codecs: []*livekit.SimulcastCodecInfo{ + { + MimeType: mime.MimeTypeVP8.String(), + Layers: []*livekit.VideoLayer{ + {Quality: livekit.VideoQuality_LOW}, + {Quality: livekit.VideoQuality_MEDIUM}, + }, + }, + }, + }, + mime.MimeTypeVP8, + map[livekit.VideoQuality]QualityAndLayer{ + livekit.VideoQuality_LOW: {quality: livekit.VideoQuality_LOW, layer: 0}, + livekit.VideoQuality_MEDIUM: {quality: livekit.VideoQuality_MEDIUM, layer: 1}, + livekit.VideoQuality_HIGH: {quality: livekit.VideoQuality_MEDIUM, layer: 1}, + }, + }, + { + "two layers, low and high", + &livekit.TrackInfo{ + Codecs: []*livekit.SimulcastCodecInfo{ + { + MimeType: mime.MimeTypeVP8.String(), + Layers: []*livekit.VideoLayer{ + {Quality: livekit.VideoQuality_LOW}, + {Quality: livekit.VideoQuality_HIGH}, + }, + }, + }, + }, + mime.MimeTypeVP8, + map[livekit.VideoQuality]QualityAndLayer{ + livekit.VideoQuality_LOW: {quality: livekit.VideoQuality_LOW, layer: 0}, + livekit.VideoQuality_MEDIUM: {quality: livekit.VideoQuality_HIGH, layer: 1}, + livekit.VideoQuality_HIGH: {quality: livekit.VideoQuality_HIGH, layer: 1}, + }, + }, + { + "two layers, medium and high", + &livekit.TrackInfo{ + Codecs: []*livekit.SimulcastCodecInfo{ + { + MimeType: mime.MimeTypeVP8.String(), + Layers: []*livekit.VideoLayer{ + {Quality: livekit.VideoQuality_MEDIUM}, + {Quality: livekit.VideoQuality_HIGH}, + }, + }, + }, + }, + mime.MimeTypeVP8, + map[livekit.VideoQuality]QualityAndLayer{ + livekit.VideoQuality_LOW: {quality: livekit.VideoQuality_MEDIUM, layer: 0}, + livekit.VideoQuality_MEDIUM: {quality: livekit.VideoQuality_MEDIUM, layer: 0}, + livekit.VideoQuality_HIGH: {quality: livekit.VideoQuality_HIGH, layer: 1}, + }, + }, + { + "three layers", + &livekit.TrackInfo{ + Codecs: []*livekit.SimulcastCodecInfo{ + { + MimeType: mime.MimeTypeVP8.String(), + Layers: []*livekit.VideoLayer{ + {Quality: livekit.VideoQuality_LOW}, + {Quality: livekit.VideoQuality_MEDIUM}, + {Quality: livekit.VideoQuality_HIGH}, + }, + }, + }, + }, + mime.MimeTypeVP8, + map[livekit.VideoQuality]QualityAndLayer{ + livekit.VideoQuality_LOW: {quality: livekit.VideoQuality_LOW, layer: 0}, + livekit.VideoQuality_MEDIUM: {quality: livekit.VideoQuality_MEDIUM, layer: 1}, + livekit.VideoQuality_HIGH: {quality: livekit.VideoQuality_HIGH, layer: 2}, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for testQuality, expectedResult := range test.qualityToLayer { + actualLayer := VideoQualityToSpatialLayer(test.mimeType, testQuality, test.trackInfo) + require.Equal(t, expectedResult.layer, actualLayer) + + actualQuality := SpatialLayerToVideoQuality(test.mimeType, actualLayer, test.trackInfo) + require.Equal(t, expectedResult.quality, actualQuality) + } + }) + } +} + +func TestVideoQualityToRidConversion(t *testing.T) { + tests := []struct { + name string + trackInfo *livekit.TrackInfo + mimeTye mime.MimeType + qualityToRid map[livekit.VideoQuality]string + }{ + { + "no track info", + nil, + mime.MimeTypeVP8, + map[livekit.VideoQuality]string{ + livekit.VideoQuality_LOW: quarterResolutionQ, + livekit.VideoQuality_MEDIUM: halfResolutionH, + livekit.VideoQuality_HIGH: fullResolutionF, + }, + }, + { + "no layers", + &livekit.TrackInfo{}, + mime.MimeTypeVP8, + map[livekit.VideoQuality]string{ + livekit.VideoQuality_LOW: quarterResolutionQ, + livekit.VideoQuality_MEDIUM: halfResolutionH, + livekit.VideoQuality_HIGH: fullResolutionF, + }, + }, + { + "single layer, low", + &livekit.TrackInfo{ + Codecs: []*livekit.SimulcastCodecInfo{ + { + MimeType: mime.MimeTypeVP8.String(), + Layers: []*livekit.VideoLayer{ + {Quality: livekit.VideoQuality_LOW}, + }, + }, + }, + }, + mime.MimeTypeVP8, + map[livekit.VideoQuality]string{ + livekit.VideoQuality_LOW: quarterResolutionQ, + livekit.VideoQuality_MEDIUM: quarterResolutionQ, + livekit.VideoQuality_HIGH: quarterResolutionQ, + }, + }, + { + "single layer, medium", + &livekit.TrackInfo{ + Codecs: []*livekit.SimulcastCodecInfo{ + { + MimeType: mime.MimeTypeVP8.String(), + Layers: []*livekit.VideoLayer{ + {Quality: livekit.VideoQuality_MEDIUM}, + }, + }, + }, + }, + mime.MimeTypeVP8, + map[livekit.VideoQuality]string{ + livekit.VideoQuality_LOW: quarterResolutionQ, + livekit.VideoQuality_MEDIUM: quarterResolutionQ, + livekit.VideoQuality_HIGH: quarterResolutionQ, + }, + }, + { + "single layer, high", + &livekit.TrackInfo{ + Codecs: []*livekit.SimulcastCodecInfo{ + { + MimeType: mime.MimeTypeVP8.String(), + Layers: []*livekit.VideoLayer{ + {Quality: livekit.VideoQuality_HIGH}, + }, + }, + }, + }, + mime.MimeTypeVP8, + map[livekit.VideoQuality]string{ + livekit.VideoQuality_LOW: quarterResolutionQ, + livekit.VideoQuality_MEDIUM: quarterResolutionQ, + livekit.VideoQuality_HIGH: quarterResolutionQ, + }, + }, + { + "two layers, low and medium", + &livekit.TrackInfo{ + Codecs: []*livekit.SimulcastCodecInfo{ + { + MimeType: mime.MimeTypeVP8.String(), + Layers: []*livekit.VideoLayer{ + {Quality: livekit.VideoQuality_LOW}, + {Quality: livekit.VideoQuality_MEDIUM}, + }, + }, + }, + }, + mime.MimeTypeVP8, + map[livekit.VideoQuality]string{ + livekit.VideoQuality_LOW: quarterResolutionQ, + livekit.VideoQuality_MEDIUM: halfResolutionH, + livekit.VideoQuality_HIGH: halfResolutionH, + }, + }, + { + "two layers, low and high", + &livekit.TrackInfo{ + Codecs: []*livekit.SimulcastCodecInfo{ + { + MimeType: mime.MimeTypeVP8.String(), + Layers: []*livekit.VideoLayer{ + {Quality: livekit.VideoQuality_LOW}, + {Quality: livekit.VideoQuality_HIGH}, + }, + }, + }, + }, + mime.MimeTypeVP8, + map[livekit.VideoQuality]string{ + livekit.VideoQuality_LOW: quarterResolutionQ, + livekit.VideoQuality_MEDIUM: halfResolutionH, + livekit.VideoQuality_HIGH: halfResolutionH, + }, + }, + { + "two layers, medium and high", + &livekit.TrackInfo{ + Codecs: []*livekit.SimulcastCodecInfo{ + { + MimeType: mime.MimeTypeVP8.String(), + Layers: []*livekit.VideoLayer{ + {Quality: livekit.VideoQuality_MEDIUM}, + {Quality: livekit.VideoQuality_HIGH}, + }, + }, + }, + }, + mime.MimeTypeVP8, + map[livekit.VideoQuality]string{ + livekit.VideoQuality_LOW: quarterResolutionQ, + livekit.VideoQuality_MEDIUM: quarterResolutionQ, + livekit.VideoQuality_HIGH: halfResolutionH, + }, + }, + { + "three layers", + &livekit.TrackInfo{ + Codecs: []*livekit.SimulcastCodecInfo{ + { + MimeType: mime.MimeTypeVP8.String(), + Layers: []*livekit.VideoLayer{ + {Quality: livekit.VideoQuality_LOW}, + {Quality: livekit.VideoQuality_MEDIUM}, + {Quality: livekit.VideoQuality_HIGH}, + }, + }, + }, + }, + mime.MimeTypeVP8, + map[livekit.VideoQuality]string{ + livekit.VideoQuality_LOW: quarterResolutionQ, + livekit.VideoQuality_MEDIUM: halfResolutionH, + livekit.VideoQuality_HIGH: fullResolutionF, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for testQuality, expectedRid := range test.qualityToRid { + actualRid := VideoQualityToRid(test.mimeTye, testQuality, test.trackInfo, DefaultVideoLayersRid) + require.Equal(t, expectedRid, actualRid) + } + }) + } +} + +func TestGetSpatialLayerForRid(t *testing.T) { + tests := []struct { + name string + trackInfo *livekit.TrackInfo + mimeType mime.MimeType + ridToSpatialLayer map[string]int32 + }{ + { + "no track info", + nil, + mime.MimeTypeVP8, + map[string]int32{ + quarterResolutionQ: InvalidLayerSpatial, + halfResolutionH: InvalidLayerSpatial, + fullResolutionF: InvalidLayerSpatial, + }, + }, + { + "no layers", + &livekit.TrackInfo{}, + mime.MimeTypeVP8, + map[string]int32{ + // SIMULCAST-CODEC-TODO + // quarterResolutionQ: InvalidLayerSpatial, + // halfResolutionH: InvalidLayerSpatial, + // fullResolutionF: InvalidLayerSpatial, + quarterResolutionQ: 0, + halfResolutionH: 0, + fullResolutionF: 0, + }, + }, + { + "no rid", + &livekit.TrackInfo{}, + mime.MimeTypeVP8, + map[string]int32{ + "": 0, + }, + }, + { + "single layer", + &livekit.TrackInfo{ + Codecs: []*livekit.SimulcastCodecInfo{ + { + MimeType: mime.MimeTypeVP8.String(), + Layers: []*livekit.VideoLayer{ + {Quality: livekit.VideoQuality_LOW, SpatialLayer: 0}, + }, + }, + }, + }, + mime.MimeTypeVP8, + map[string]int32{ + quarterResolutionQ: 0, + halfResolutionH: 0, + fullResolutionF: 0, + }, + }, + { + "layers", + &livekit.TrackInfo{ + Codecs: []*livekit.SimulcastCodecInfo{ + { + MimeType: mime.MimeTypeVP8.String(), + Layers: []*livekit.VideoLayer{ + {Quality: livekit.VideoQuality_LOW, SpatialLayer: 0, Rid: quarterResolutionQ}, + {Quality: livekit.VideoQuality_MEDIUM, SpatialLayer: 1, Rid: halfResolutionH}, + }, + }, + }, + }, + mime.MimeTypeVP8, + map[string]int32{ + quarterResolutionQ: 0, + halfResolutionH: 1, + // SIMULCAST-CODEC-TODO + // fullResolutionF: InvalidLayerSpatial, + fullResolutionF: 0, + }, + }, + { + "layers - no rid", + &livekit.TrackInfo{ + Codecs: []*livekit.SimulcastCodecInfo{ + { + MimeType: mime.MimeTypeVP8.String(), + Layers: []*livekit.VideoLayer{ + {Quality: livekit.VideoQuality_LOW, SpatialLayer: 0}, + {Quality: livekit.VideoQuality_MEDIUM, SpatialLayer: 1}, + }, + }, + }, + }, + mime.MimeTypeVP8, + map[string]int32{ + quarterResolutionQ: 0, + halfResolutionH: 0, + fullResolutionF: 0, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for testRid, expectedSpatialLayer := range test.ridToSpatialLayer { + actualSpatialLayer := GetSpatialLayerForRid(test.mimeType, testRid, test.trackInfo) + require.Equal(t, expectedSpatialLayer, actualSpatialLayer) + } + }) + } +} + +func TestGetSpatialLayerForVideoQuality(t *testing.T) { + tests := []struct { + name string + trackInfo *livekit.TrackInfo + mimeType mime.MimeType + videoQualityToSpatialLayer map[livekit.VideoQuality]int32 + }{ + { + "no track info", + nil, + mime.MimeTypeVP8, + map[livekit.VideoQuality]int32{ + livekit.VideoQuality_LOW: InvalidLayerSpatial, + livekit.VideoQuality_MEDIUM: InvalidLayerSpatial, + livekit.VideoQuality_HIGH: InvalidLayerSpatial, + livekit.VideoQuality_OFF: InvalidLayerSpatial, + }, + }, + { + "no layers", + &livekit.TrackInfo{}, + mime.MimeTypeVP8, + map[livekit.VideoQuality]int32{ + livekit.VideoQuality_LOW: 0, + livekit.VideoQuality_MEDIUM: 0, + livekit.VideoQuality_HIGH: 0, + livekit.VideoQuality_OFF: InvalidLayerSpatial, + }, + }, + { + "not all layers", + &livekit.TrackInfo{ + Codecs: []*livekit.SimulcastCodecInfo{ + { + MimeType: mime.MimeTypeVP8.String(), + Layers: []*livekit.VideoLayer{ + {Quality: livekit.VideoQuality_LOW, SpatialLayer: 0, Rid: quarterResolutionQ}, + {Quality: livekit.VideoQuality_MEDIUM, SpatialLayer: 1, Rid: halfResolutionH}, + }, + }, + }, + }, + mime.MimeTypeVP8, + map[livekit.VideoQuality]int32{ + livekit.VideoQuality_LOW: 0, + livekit.VideoQuality_MEDIUM: 1, + livekit.VideoQuality_HIGH: 1, + livekit.VideoQuality_OFF: InvalidLayerSpatial, + }, + }, + { + "all layers", + &livekit.TrackInfo{ + Codecs: []*livekit.SimulcastCodecInfo{ + { + MimeType: mime.MimeTypeVP8.String(), + Layers: []*livekit.VideoLayer{ + {Quality: livekit.VideoQuality_LOW, SpatialLayer: 0, Rid: quarterResolutionQ}, + {Quality: livekit.VideoQuality_MEDIUM, SpatialLayer: 1, Rid: halfResolutionH}, + {Quality: livekit.VideoQuality_HIGH, SpatialLayer: 2, Rid: fullResolutionF}, + }, + }, + }, + }, + mime.MimeTypeVP8, + map[livekit.VideoQuality]int32{ + livekit.VideoQuality_LOW: 0, + livekit.VideoQuality_MEDIUM: 1, + livekit.VideoQuality_HIGH: 2, + livekit.VideoQuality_OFF: InvalidLayerSpatial, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for testVideoQuality, expectedSpatialLayer := range test.videoQualityToSpatialLayer { + actualSpatialLayer := GetSpatialLayerForVideoQuality(test.mimeType, testVideoQuality, test.trackInfo) + require.Equal(t, expectedSpatialLayer, actualSpatialLayer) + } + }) + } +} + +func TestGetVideoQualityorSpatialLayer(t *testing.T) { + tests := []struct { + name string + trackInfo *livekit.TrackInfo + mimeType mime.MimeType + spatialLayerToVideoQuality map[int32]livekit.VideoQuality + }{ + { + "no track info", + nil, + mime.MimeTypeVP8, + map[int32]livekit.VideoQuality{ + InvalidLayerSpatial: livekit.VideoQuality_OFF, + 0: livekit.VideoQuality_OFF, + 1: livekit.VideoQuality_OFF, + 2: livekit.VideoQuality_OFF, + }, + }, + { + "no layers", + &livekit.TrackInfo{}, + mime.MimeTypeVP8, + map[int32]livekit.VideoQuality{ + InvalidLayerSpatial: livekit.VideoQuality_OFF, + 0: livekit.VideoQuality_OFF, + 1: livekit.VideoQuality_OFF, + 2: livekit.VideoQuality_OFF, + }, + }, + { + "layers", + &livekit.TrackInfo{ + Codecs: []*livekit.SimulcastCodecInfo{ + { + MimeType: mime.MimeTypeVP8.String(), + Layers: []*livekit.VideoLayer{ + {Quality: livekit.VideoQuality_LOW, SpatialLayer: 0, Rid: quarterResolutionQ}, + {Quality: livekit.VideoQuality_MEDIUM, SpatialLayer: 1, Rid: halfResolutionH}, + }, + }, + }, + }, + mime.MimeTypeVP8, + map[int32]livekit.VideoQuality{ + InvalidLayerSpatial: livekit.VideoQuality_OFF, + 0: livekit.VideoQuality_LOW, + 1: livekit.VideoQuality_MEDIUM, + 2: livekit.VideoQuality_OFF, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for testSpatialLayer, expectedVideoQuality := range test.spatialLayerToVideoQuality { + actualVideoQuality := GetVideoQualityForSpatialLayer(test.mimeType, testSpatialLayer, test.trackInfo) + require.Equal(t, expectedVideoQuality, actualVideoQuality) + } + }) + } +} + +func TestNormalizeVideoLayersRid(t *testing.T) { + tests := []struct { + name string + rids VideoLayersRid + normalized VideoLayersRid + }{ + { + "empty", + VideoLayersRid{}, + VideoLayersRid{}, + }, + { + "unknown pattern", + VideoLayersRid{"3", "2", "1"}, + VideoLayersRid{"3", "2", "1"}, + }, + { + "qhf", + videoLayersRidQHF, + videoLayersRidQHF, + }, + { + "scrambled qhf", + VideoLayersRid{"f", "h", "q"}, + videoLayersRidQHF, + }, + { + "partial qhf", + VideoLayersRid{"h", "q"}, + VideoLayersRid{"q", "h", ""}, + }, + { + "210", + videoLayersRid210, + videoLayersRid210, + }, + { + "scrambled 210", + VideoLayersRid{"2", "0", "1"}, + videoLayersRid210, + }, + { + "partial 210", + VideoLayersRid{"1", "2"}, + VideoLayersRid{"2", "1", ""}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + normalizedRids := NormalizeVideoLayersRid(test.rids) + require.Equal(t, test.normalized, normalizedRids) + }) + } +} diff --git a/livekit/pkg/sfu/bwe/bwe.go b/livekit/pkg/sfu/bwe/bwe.go new file mode 100644 index 0000000..ad61a70 --- /dev/null +++ b/livekit/pkg/sfu/bwe/bwe.go @@ -0,0 +1,123 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bwe + +import ( + "fmt" + "time" + + "github.com/livekit/livekit-server/pkg/sfu/ccutils" + "github.com/pion/rtcp" +) + +// ------------------------------------------------ + +const ( + DefaultRTT = float64(0.070) // 70 ms + RTTSmoothingFactor = float64(0.5) +) + +// ------------------------------------------------ + +type BWEType int + +const ( + BWETypeNone BWEType = iota + BWETypeRemote + BWETypeSendSide +) + +func (b BWEType) String() string { + switch b { + case BWETypeNone: + return "NONE" + case BWETypeRemote: + return "REMOTE" + case BWETypeSendSide: + return "SEND_SIDE" + default: + return fmt.Sprintf("%d", int(b)) + } +} + +// ------------------------------------------------ + +type CongestionState int + +const ( + CongestionStateNone CongestionState = iota + CongestionStateEarlyWarning + CongestionStateCongested +) + +func (c CongestionState) String() string { + switch c { + case CongestionStateNone: + return "NONE" + case CongestionStateEarlyWarning: + return "EARLY_WARNING" + case CongestionStateCongested: + return "CONGESTED" + default: + return fmt.Sprintf("%d", int(c)) + } +} + +// ------------------------------------------------ + +type BWE interface { + Type() BWEType + + SetBWEListener(bweListner BWEListener) + + Reset() + + HandleREMB( + receivedEstimate int64, + expectedBandwidthUsage int64, + sentPackets uint32, + repeatedNacks uint32, + ) + + // TWCC sequence number + RecordPacketSendAndGetSequenceNumber( + atMicro int64, + size int, + isRTX bool, + probeClusterId ccutils.ProbeClusterId, + isProbe bool, + ) uint16 + + HandleTWCCFeedback(report *rtcp.TransportLayerCC) + + UpdateRTT(rtt float64) + + CongestionState() CongestionState + + CanProbe() bool + ProbeDuration() time.Duration + ProbeClusterStarting(pci ccutils.ProbeClusterInfo) + ProbeClusterDone(pci ccutils.ProbeClusterInfo) + ProbeClusterIsGoalReached() bool + ProbeClusterFinalize() (ccutils.ProbeSignal, int64, bool) +} + +// ------------------------------------------------ + +type BWEListener interface { + OnCongestionStateChange(fromState CongestionState, toState CongestionState, estimatedAvailableChannelCapacity int64) +} + +// ------------------------------------------------ diff --git a/livekit/pkg/sfu/bwe/null_bwe.go b/livekit/pkg/sfu/bwe/null_bwe.go new file mode 100644 index 0000000..2a1c394 --- /dev/null +++ b/livekit/pkg/sfu/bwe/null_bwe.go @@ -0,0 +1,77 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bwe + +import ( + "time" + + "github.com/livekit/livekit-server/pkg/sfu/ccutils" + "github.com/pion/rtcp" +) + +type NullBWE struct { +} + +func (n *NullBWE) SetBWEListener(_bweListener BWEListener) {} + +func (n *NullBWE) Reset() {} + +func (n *NullBWE) RecordPacketSendAndGetSequenceNumber( + _atMicro int64, + _size int, + _isRTX bool, + _probeClusterId ccutils.ProbeClusterId, + _isProbe bool, +) uint16 { + return 0 +} + +func (n *NullBWE) HandleREMB( + _receivedEstimate int64, + _expectedBandwidthUsage int64, + _sentPackets uint32, + _repeatedNacks uint32, +) { +} + +func (n *NullBWE) HandleTWCCFeedback(_report *rtcp.TransportLayerCC) {} + +func (n *NullBWE) UpdateRTT(rtt float64) {} + +func (n *NullBWE) CongestionState() CongestionState { + return CongestionStateNone +} + +func (n *NullBWE) CanProbe() bool { + return false +} + +func (n *NullBWE) ProbeDuration() time.Duration { + return 0 +} + +func (n *NullBWE) ProbeClusterStarting(_pci ccutils.ProbeClusterInfo) {} + +func (n *NullBWE) ProbeClusterDone(_pci ccutils.ProbeClusterInfo) {} + +func (n *NullBWE) ProbeClusterIsGoalReached() bool { + return false +} + +func (n *NullBWE) ProbeClusterFinalize() (ccutils.ProbeSignal, int64, bool) { + return ccutils.ProbeSignalInconclusive, 0, false +} + +// ------------------------------------------------ diff --git a/livekit/pkg/sfu/bwe/remotebwe/channel_observer.go b/livekit/pkg/sfu/bwe/remotebwe/channel_observer.go new file mode 100644 index 0000000..bf112e6 --- /dev/null +++ b/livekit/pkg/sfu/bwe/remotebwe/channel_observer.go @@ -0,0 +1,202 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package remotebwe + +import ( + "fmt" + "time" + + "go.uber.org/zap/zapcore" + + "github.com/livekit/livekit-server/pkg/sfu/ccutils" + "github.com/livekit/protocol/logger" +) + +// ------------------------------------------------ + +type channelTrend int + +const ( + channelTrendInconclusive channelTrend = iota + channelTrendClearing + channelTrendCongesting +) + +func (c channelTrend) String() string { + switch c { + case channelTrendInconclusive: + return "INCONCLUSIVE" + case channelTrendClearing: + return "CLEARING" + case channelTrendCongesting: + return "CONGESTING" + default: + return fmt.Sprintf("%d", int(c)) + } +} + +// ------------------------------------------------ + +type channelCongestionReason int + +const ( + channelCongestionReasonNone channelCongestionReason = iota + channelCongestionReasonEstimate + channelCongestionReasonLoss +) + +func (c channelCongestionReason) String() string { + switch c { + case channelCongestionReasonNone: + return "NONE" + case channelCongestionReasonEstimate: + return "ESTIMATE" + case channelCongestionReasonLoss: + return "LOSS" + default: + return fmt.Sprintf("%d", int(c)) + } +} + +// ------------------------------------------------ + +type ChannelObserverConfig struct { + Estimate ccutils.TrendDetectorConfig `yaml:"estimate,omitempty"` + Nack NackTrackerConfig `yaml:"nack,omitempty"` +} + +var ( + defaultTrendDetectorConfigProbe = ccutils.TrendDetectorConfig{ + RequiredSamples: 3, + RequiredSamplesMin: 3, + DownwardTrendThreshold: 0.0, + DownwardTrendMaxWait: 5 * time.Second, + CollapseThreshold: 0, + ValidityWindow: 10 * time.Second, + } + + defaultChannelObserverConfigProbe = ChannelObserverConfig{ + Estimate: defaultTrendDetectorConfigProbe, + Nack: defaultNackTrackerConfigProbe, + } + + defaultTrendDetectorConfigNonProbe = ccutils.TrendDetectorConfig{ + RequiredSamples: 12, + RequiredSamplesMin: 8, + DownwardTrendThreshold: -0.6, + DownwardTrendMaxWait: 5 * time.Second, + CollapseThreshold: 500 * time.Millisecond, + ValidityWindow: 10 * time.Second, + } + + defaultChannelObserverConfigNonProbe = ChannelObserverConfig{ + Estimate: defaultTrendDetectorConfigNonProbe, + Nack: defaultNackTrackerConfigNonProbe, + } +) + +// ------------------------------------------------ + +type channelObserverParams struct { + Name string + Config ChannelObserverConfig +} + +type channelObserver struct { + params channelObserverParams + logger logger.Logger + + estimateTrend *ccutils.TrendDetector[int64] + nackTracker *nackTracker +} + +func newChannelObserver(params channelObserverParams, logger logger.Logger) *channelObserver { + return &channelObserver{ + params: params, + logger: logger, + estimateTrend: ccutils.NewTrendDetector[int64](ccutils.TrendDetectorParams{ + Name: params.Name + "-estimate", + Logger: logger, + Config: params.Config.Estimate, + }), + nackTracker: newNackTracker(nackTrackerParams{ + Name: params.Name + "-nack", + Logger: logger, + Config: params.Config.Nack, + }), + } +} + +func (c *channelObserver) SeedEstimate(estimate int64) { + c.estimateTrend.Seed(estimate) +} + +func (c *channelObserver) AddEstimate(estimate int64) { + c.estimateTrend.AddValue(estimate) +} + +func (c *channelObserver) AddNack(packets uint32, repeatedNacks uint32) { + c.nackTracker.Add(packets, repeatedNacks) +} + +func (c *channelObserver) GetLowestEstimate() int64 { + return c.estimateTrend.GetLowest() +} + +func (c *channelObserver) GetHighestEstimate() int64 { + return c.estimateTrend.GetHighest() +} + +func (c *channelObserver) HasEnoughEstimateSamples() bool { + return c.estimateTrend.HasEnoughSamples() +} + +func (c *channelObserver) GetNackRatio() float64 { + return c.nackTracker.GetRatio() +} + +func (c *channelObserver) GetTrend() (channelTrend, channelCongestionReason) { + estimateDirection := c.estimateTrend.GetDirection() + + switch { + case estimateDirection == ccutils.TrendDirectionDownward: + return channelTrendCongesting, channelCongestionReasonEstimate + + case c.nackTracker.IsTriggered(): + return channelTrendCongesting, channelCongestionReasonLoss + + case estimateDirection == ccutils.TrendDirectionUpward: + return channelTrendClearing, channelCongestionReasonNone + } + + return channelTrendInconclusive, channelCongestionReasonNone +} + +func (c *channelObserver) MarshalLogObject(e zapcore.ObjectEncoder) error { + if c == nil { + return nil + } + + e.AddString("name", c.params.Name) + e.AddObject("estimate", c.estimateTrend) + e.AddObject("nack", c.nackTracker) + + channelTrend, channelCongestionReason := c.GetTrend() + e.AddString("channelTrend", channelTrend.String()) + e.AddString("channelCongestionReason", channelCongestionReason.String()) + return nil +} + +// ------------------------------------------------ diff --git a/livekit/pkg/sfu/bwe/remotebwe/nack_tracker.go b/livekit/pkg/sfu/bwe/remotebwe/nack_tracker.go new file mode 100644 index 0000000..106115f --- /dev/null +++ b/livekit/pkg/sfu/bwe/remotebwe/nack_tracker.go @@ -0,0 +1,130 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package remotebwe + +import ( + "time" + + "go.uber.org/zap/zapcore" + + "github.com/livekit/protocol/logger" +) + +// ------------------------------------------------ + +type NackTrackerConfig struct { + WindowMinDuration time.Duration `yaml:"window_min_duration,omitempty"` + WindowMaxDuration time.Duration `yaml:"window_max_duration,omitempty"` + RatioThreshold float64 `yaml:"ratio_threshold,omitempty"` +} + +var ( + defaultNackTrackerConfigProbe = NackTrackerConfig{ + WindowMinDuration: 500 * time.Millisecond, + WindowMaxDuration: 1 * time.Second, + RatioThreshold: 0.04, + } + + defaultNackTrackerConfigNonProbe = NackTrackerConfig{ + WindowMinDuration: 2 * time.Second, + WindowMaxDuration: 3 * time.Second, + RatioThreshold: 0.08, + } +) + +// ------------------------------------------------ + +type nackTrackerParams struct { + Name string + Logger logger.Logger + Config NackTrackerConfig +} + +type nackTracker struct { + params nackTrackerParams + + windowStartTime time.Time + packets uint32 + repeatedNacks uint32 +} + +func newNackTracker(params nackTrackerParams) *nackTracker { + return &nackTracker{ + params: params, + } +} + +func (n *nackTracker) Add(packets uint32, repeatedNacks uint32) { + if n.params.Config.WindowMaxDuration != 0 && !n.windowStartTime.IsZero() && time.Since(n.windowStartTime) > n.params.Config.WindowMaxDuration { + n.windowStartTime = time.Time{} + n.packets = 0 + n.repeatedNacks = 0 + } + + // + // Start NACK monitoring window only when a repeated NACK happens. + // This allows locking tightly to when NACKs start happening and + // check if the NACKs keep adding up (potentially a sign of congestion) + // or isolated losses + // + if n.repeatedNacks == 0 && repeatedNacks != 0 { + n.windowStartTime = time.Now() + } + + if !n.windowStartTime.IsZero() { + n.packets += packets + n.repeatedNacks += repeatedNacks + } +} + +func (n *nackTracker) GetRatio() float64 { + ratio := 0.0 + if n.packets != 0 { + ratio = float64(n.repeatedNacks) / float64(n.packets) + if ratio > 1.0 { + ratio = 1.0 + } + } + + return ratio +} + +func (n *nackTracker) IsTriggered() bool { + if n.params.Config.WindowMinDuration != 0 && !n.windowStartTime.IsZero() && time.Since(n.windowStartTime) > n.params.Config.WindowMinDuration { + return n.GetRatio() > n.params.Config.RatioThreshold + } + + return false +} + +func (n *nackTracker) MarshalLogObject(e zapcore.ObjectEncoder) error { + if n == nil { + return nil + } + + e.AddString("name", n.params.Name) + if n.windowStartTime.IsZero() { + e.AddString("window", "inactive") + } else { + e.AddTime("windowStartTime", n.windowStartTime) + e.AddDuration("windowDuration", time.Since(n.windowStartTime)) + e.AddUint32("packets", n.packets) + e.AddUint32("repeatedNacks", n.repeatedNacks) + e.AddFloat64("nackRatio", n.GetRatio()) + } + return nil +} + +// ------------------------------------------------ diff --git a/livekit/pkg/sfu/bwe/remotebwe/probe_controller.go b/livekit/pkg/sfu/bwe/remotebwe/probe_controller.go new file mode 100644 index 0000000..1efa3f6 --- /dev/null +++ b/livekit/pkg/sfu/bwe/remotebwe/probe_controller.go @@ -0,0 +1,184 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package remotebwe + +import ( + "fmt" + "time" + + "github.com/livekit/livekit-server/pkg/sfu/bwe" + "github.com/livekit/livekit-server/pkg/sfu/ccutils" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/utils/mono" +) + +// --------------------------------------------------------------------------- + +type probeControllerState int + +const ( + probeControllerStateNone probeControllerState = iota + probeControllerStateProbing + probeControllerStateHangover +) + +func (p probeControllerState) String() string { + switch p { + case probeControllerStateNone: + return "NONE" + case probeControllerStateProbing: + return "PROBING" + case probeControllerStateHangover: + return "HANGOVER" + default: + return fmt.Sprintf("%d", int(p)) + } +} + +// ------------------------------------------------ + +type ProbeControllerConfig struct { + ProbeRegulator ccutils.ProbeRegulatorConfig `yaml:"probe_regulator,omitempty"` + + SettleWaitNumRTT uint32 `yaml:"settle_wait_num_rtt,omitempty"` + SettleWaitMin time.Duration `yaml:"settle_wait_min,omitempty"` + SettleWaitMax time.Duration `yaml:"settle_wait_max,omitempty"` +} + +var ( + DefaultProbeControllerConfig = ProbeControllerConfig{ + ProbeRegulator: ccutils.DefaultProbeRegulatorConfig, + + SettleWaitNumRTT: 5, + SettleWaitMin: 250 * time.Millisecond, + SettleWaitMax: 5 * time.Second, + } +) + +// --------------------------------------------------------------------------- + +type probeControllerParams struct { + Config ProbeControllerConfig + Logger logger.Logger +} + +type probeController struct { + params probeControllerParams + + state probeControllerState + stateSwitchedAt time.Time + + pci ccutils.ProbeClusterInfo + rtt float64 + + *ccutils.ProbeRegulator +} + +func newProbeController(params probeControllerParams) *probeController { + return &probeController{ + params: params, + state: probeControllerStateNone, + stateSwitchedAt: mono.Now(), + pci: ccutils.ProbeClusterInfoInvalid, + rtt: bwe.DefaultRTT, + ProbeRegulator: ccutils.NewProbeRegulator( + ccutils.ProbeRegulatorParams{ + Config: params.Config.ProbeRegulator, + Logger: params.Logger, + }, + ), + } +} + +func (p *probeController) UpdateRTT(rtt float64) { + if rtt == 0 { + p.rtt = bwe.DefaultRTT + } else { + if p.rtt == 0 { + p.rtt = rtt + } else { + p.rtt = bwe.RTTSmoothingFactor*rtt + (1.0-bwe.RTTSmoothingFactor)*p.rtt + } + } +} + +func (p *probeController) GetRTT() float64 { + return p.rtt +} + +func (p *probeController) CanProbe() bool { + return p.state == probeControllerStateNone && p.ProbeRegulator.CanProbe() +} + +func (p *probeController) IsInProbe() bool { + return p.state != probeControllerStateNone +} + +func (p *probeController) ProbeClusterStarting(pci ccutils.ProbeClusterInfo) { + if p.state != probeControllerStateNone { + p.params.Logger.Warnw("unexpected probe controller state", nil, "state", p.state) + } + + p.setState(probeControllerStateProbing) + p.pci = pci +} + +func (p *probeController) ProbeClusterDone(pci ccutils.ProbeClusterInfo) { + if p.pci.Id != pci.Id { + return + } + + p.pci.Result = pci.Result + p.setState(probeControllerStateHangover) +} + +func (p *probeController) ProbeClusterIsGoalReached(estimate int64) bool { + if p.pci.Id == ccutils.ProbeClusterIdInvalid { + return false + } + + return estimate > int64(p.pci.Goal.DesiredBps) +} + +func (p *probeController) MaybeFinalizeProbe() (ccutils.ProbeClusterInfo, bool) { + if p.state != probeControllerStateHangover { + return ccutils.ProbeClusterInfoInvalid, false + } + + settleWait := time.Duration(float64(p.params.Config.SettleWaitNumRTT) * p.rtt * float64(time.Second)) + if settleWait < p.params.Config.SettleWaitMin { + settleWait = p.params.Config.SettleWaitMin + } + if settleWait > p.params.Config.SettleWaitMax { + settleWait = p.params.Config.SettleWaitMax + } + if time.Since(p.stateSwitchedAt) < settleWait { + return ccutils.ProbeClusterInfoInvalid, false + } + + p.setState(probeControllerStateNone) + return p.pci, true +} + +func (p *probeController) setState(state probeControllerState) { + if state == p.state { + return + } + + p.state = state + p.stateSwitchedAt = mono.Now() +} + +// ------------------------------------------------ diff --git a/livekit/pkg/sfu/bwe/remotebwe/remote_bwe.go b/livekit/pkg/sfu/bwe/remotebwe/remote_bwe.go new file mode 100644 index 0000000..f6eb4ba --- /dev/null +++ b/livekit/pkg/sfu/bwe/remotebwe/remote_bwe.go @@ -0,0 +1,359 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package remotebwe + +import ( + "sync" + "time" + + "github.com/livekit/livekit-server/pkg/sfu/bwe" + "github.com/livekit/livekit-server/pkg/sfu/ccutils" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/utils/mono" +) + +var _ bwe.BWE = (*RemoteBWE)(nil) + +// --------------------------------------------------------------------------- + +type RemoteBWEConfig struct { + NackRatioAttenuator float64 `yaml:"nack_ratio_attenuator,omitempty"` + ExpectedUsageThreshold float64 `yaml:"expected_usage_threshold,omitempty"` + ChannelObserverProbe ChannelObserverConfig `yaml:"channel_observer_probe,omitempty"` + ChannelObserverNonProbe ChannelObserverConfig `yaml:"channel_observer_non_probe,omitempty"` + ProbeController ProbeControllerConfig `yaml:"probe_controller,omitempty"` +} + +var ( + DefaultRemoteBWEConfig = RemoteBWEConfig{ + NackRatioAttenuator: 0.4, + ExpectedUsageThreshold: 0.95, + ChannelObserverProbe: defaultChannelObserverConfigProbe, + ChannelObserverNonProbe: defaultChannelObserverConfigNonProbe, + ProbeController: DefaultProbeControllerConfig, + } +) + +// --------------------------------------------------------------------------- + +type RemoteBWEParams struct { + Config RemoteBWEConfig + Logger logger.Logger +} + +type RemoteBWE struct { + bwe.NullBWE + + params RemoteBWEParams + + lock sync.RWMutex + + lastReceivedEstimate int64 + lastExpectedBandwidthUsage int64 + committedChannelCapacity int64 + + probeController *probeController + + channelObserver *channelObserver + + congestionState bwe.CongestionState + congestionStateSwitchedAt time.Time + + bweListener bwe.BWEListener +} + +func NewRemoteBWE(params RemoteBWEParams) *RemoteBWE { + r := &RemoteBWE{ + params: params, + } + + r.Reset() + return r +} + +func (r *RemoteBWE) Type() bwe.BWEType { + return bwe.BWETypeRemote +} + +func (r *RemoteBWE) SetBWEListener(bweListener bwe.BWEListener) { + r.lock.Lock() + defer r.lock.Unlock() + + r.bweListener = bweListener +} + +func (r *RemoteBWE) getBWEListener() bwe.BWEListener { + r.lock.RLock() + defer r.lock.RUnlock() + + return r.bweListener +} + +func (r *RemoteBWE) Reset() { + r.lock.Lock() + defer r.lock.Unlock() + + r.lastReceivedEstimate = 0 + r.lastExpectedBandwidthUsage = 0 + r.committedChannelCapacity = 100_000_000 + + r.congestionState = bwe.CongestionStateNone + r.congestionStateSwitchedAt = mono.Now() + + r.probeController = newProbeController(probeControllerParams{ + Config: r.params.Config.ProbeController, + Logger: r.params.Logger, + }) + + r.newChannelObserver() +} + +func (r *RemoteBWE) HandleREMB( + receivedEstimate int64, + expectedBandwidthUsage int64, + sentPackets uint32, + repeatedNacks uint32, +) { + r.lock.Lock() + r.lastReceivedEstimate = receivedEstimate + r.lastExpectedBandwidthUsage = expectedBandwidthUsage + + // in probe, freeze channel observer state if probe causes congestion till the probe is done, + // this is to ensure that probe result is not marked a success, + // an unsuccessful probe will not up allocate any tracks + if r.congestionState != bwe.CongestionStateNone && r.probeController.IsInProbe() { + r.lock.Unlock() + return + } + + r.channelObserver.AddEstimate(r.lastReceivedEstimate) + r.channelObserver.AddNack(sentPackets, repeatedNacks) + + shouldNotify, fromState, toState, committedChannelCapacity := r.congestionDetectionStateMachine() + r.lock.Unlock() + + if shouldNotify { + if bweListener := r.getBWEListener(); bweListener != nil { + bweListener.OnCongestionStateChange(fromState, toState, committedChannelCapacity) + } + } +} + +func (r *RemoteBWE) UpdateRTT(rtt float64) { + r.lock.Lock() + defer r.lock.Unlock() + + r.probeController.UpdateRTT(rtt) +} + +func (r *RemoteBWE) congestionDetectionStateMachine() (bool, bwe.CongestionState, bwe.CongestionState, int64) { + fromState := r.congestionState + toState := r.congestionState + update := false + trend, reason := r.channelObserver.GetTrend() + + switch fromState { + case bwe.CongestionStateNone: + if trend == channelTrendCongesting { + if r.probeController.IsInProbe() || r.estimateAvailableChannelCapacity(reason) { + // when in probe, if congested, stays there till probe is done, + // the estimate stays at pre-probe level + toState = bwe.CongestionStateCongested + } + } + + case bwe.CongestionStateCongested: + if trend == channelTrendCongesting { + if r.estimateAvailableChannelCapacity(reason) { + // update state as this needs to reset switch time to wait for congestion min duration again + update = true + } + } else { + toState = bwe.CongestionStateNone + } + } + + shouldNotify := false + if toState != fromState || update { + fromState, toState = r.updateCongestionState(toState, reason) + shouldNotify = true + } + + return shouldNotify, fromState, toState, r.committedChannelCapacity +} + +func (r *RemoteBWE) estimateAvailableChannelCapacity(reason channelCongestionReason) bool { + var estimateToCommit int64 + switch reason { + case channelCongestionReasonLoss: + estimateToCommit = int64(float64(r.lastExpectedBandwidthUsage) * (1.0 - r.params.Config.NackRatioAttenuator*r.channelObserver.GetNackRatio())) + default: + estimateToCommit = r.lastReceivedEstimate + } + if estimateToCommit > r.lastReceivedEstimate { + estimateToCommit = r.lastReceivedEstimate + } + + commitThreshold := int64(r.params.Config.ExpectedUsageThreshold * float64(r.lastExpectedBandwidthUsage)) + if estimateToCommit > commitThreshold || r.committedChannelCapacity == estimateToCommit { + return false + } + + r.params.Logger.Infow( + "remote bwe: channel congestion detected, applying channel capacity update", + "reason", reason, + "old(bps)", r.committedChannelCapacity, + "new(bps)", estimateToCommit, + "lastReceived(bps)", r.lastReceivedEstimate, + "expectedUsage(bps)", r.lastExpectedBandwidthUsage, + "commitThreshold(bps)", commitThreshold, + "channel", r.channelObserver, + ) + r.committedChannelCapacity = estimateToCommit + return true +} + +func (r *RemoteBWE) updateCongestionState(state bwe.CongestionState, reason channelCongestionReason) (bwe.CongestionState, bwe.CongestionState) { + r.params.Logger.Debugw( + "remote bwe: congestion state change", + "from", r.congestionState, + "to", state, + "reason", reason, + "committedChannelCapacity", r.committedChannelCapacity, + ) + + fromState := r.congestionState + r.congestionState = state + r.congestionStateSwitchedAt = mono.Now() + return fromState, r.congestionState +} + +func (r *RemoteBWE) CongestionState() bwe.CongestionState { + r.lock.Lock() + defer r.lock.Unlock() + + return r.congestionState +} + +func (r *RemoteBWE) CanProbe() bool { + r.lock.Lock() + defer r.lock.Unlock() + + return r.congestionState == bwe.CongestionStateNone && r.probeController.CanProbe() +} + +func (r *RemoteBWE) ProbeDuration() time.Duration { + r.lock.Lock() + defer r.lock.Unlock() + + return r.probeController.ProbeDuration() +} + +func (r *RemoteBWE) ProbeClusterStarting(pci ccutils.ProbeClusterInfo) { + r.lock.Lock() + defer r.lock.Unlock() + + r.lastExpectedBandwidthUsage = int64(pci.Goal.ExpectedUsageBps) + + r.params.Logger.Debugw( + "remote bwe: starting probe", + "lastReceived", r.lastReceivedEstimate, + "expectedBandwidthUsage", r.lastExpectedBandwidthUsage, + "channel", r.channelObserver, + ) + + r.probeController.ProbeClusterStarting(pci) + r.newChannelObserver() +} + +func (r *RemoteBWE) ProbeClusterDone(pci ccutils.ProbeClusterInfo) { + r.lock.Lock() + defer r.lock.Unlock() + + r.probeController.ProbeClusterDone(pci) +} + +func (r *RemoteBWE) ProbeClusterIsGoalReached() bool { + r.lock.Lock() + defer r.lock.Unlock() + + if !r.probeController.IsInProbe() || + r.congestionState != bwe.CongestionStateNone || + !r.channelObserver.HasEnoughEstimateSamples() { + return false + } + + return r.probeController.ProbeClusterIsGoalReached(r.channelObserver.GetHighestEstimate()) +} + +func (r *RemoteBWE) ProbeClusterFinalize() (ccutils.ProbeSignal, int64, bool) { + r.lock.Lock() + defer r.lock.Unlock() + + pci, isFinalized := r.probeController.MaybeFinalizeProbe() + if !isFinalized { + return ccutils.ProbeSignalInconclusive, 0, isFinalized + } + + // switch to a non-probe channel observer on probe end, + // reset congestion state to get a fresh trend + pco := r.channelObserver + probeCongestionState := r.congestionState + + r.congestionState = bwe.CongestionStateNone + r.newChannelObserver() + + r.params.Logger.Infow( + "remote bwe: probe finalized", + "lastReceived", r.lastReceivedEstimate, + "expectedBandwidthUsage", r.lastExpectedBandwidthUsage, + "channel", pco, + "isSignalValid", pco.HasEnoughEstimateSamples(), + "probeClusterInfo", pci, + "rtt", r.probeController.GetRTT(), + ) + + probeSignal := ccutils.ProbeSignalNotCongesting + if probeCongestionState != bwe.CongestionStateNone { + probeSignal = ccutils.ProbeSignalCongesting + } else if !pco.HasEnoughEstimateSamples() { + probeSignal = ccutils.ProbeSignalInconclusive + } else { + highestEstimate := pco.GetHighestEstimate() + if highestEstimate > r.committedChannelCapacity { + r.committedChannelCapacity = highestEstimate + } + } + + r.probeController.ProbeSignal(probeSignal, pci.CreatedAt) + return probeSignal, r.committedChannelCapacity, true +} + +func (r *RemoteBWE) newChannelObserver() { + var params channelObserverParams + if r.probeController.IsInProbe() { + params = channelObserverParams{ + Name: "probe", + Config: r.params.Config.ChannelObserverProbe, + } + } else { + params = channelObserverParams{ + Name: "non-probe", + Config: r.params.Config.ChannelObserverNonProbe, + } + } + + r.channelObserver = newChannelObserver(params, r.params.Logger) +} diff --git a/livekit/pkg/sfu/bwe/sendsidebwe/congestion_detector.go b/livekit/pkg/sfu/bwe/sendsidebwe/congestion_detector.go new file mode 100644 index 0000000..f8a3e53 --- /dev/null +++ b/livekit/pkg/sfu/bwe/sendsidebwe/congestion_detector.go @@ -0,0 +1,1162 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sendsidebwe + +import ( + "fmt" + "sync" + "time" + + "github.com/livekit/livekit-server/pkg/sfu/bwe" + "github.com/livekit/livekit-server/pkg/sfu/ccutils" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/utils/mono" + "github.com/pion/rtcp" + "go.uber.org/zap/zapcore" +) + +// ------------------------------------------------------------------------------- + +type CongestionSignalConfig struct { + MinNumberOfGroups int `yaml:"min_number_of_groups,omitempty"` + MinDuration time.Duration `yaml:"min_duration,omitempty"` +} + +func (c CongestionSignalConfig) IsTriggered(numGroups int, duration int64) bool { + return numGroups >= c.MinNumberOfGroups && duration >= c.MinDuration.Microseconds() +} + +var ( + defaultQueuingDelayEarlyWarningJQRConfig = CongestionSignalConfig{ + MinNumberOfGroups: 2, + MinDuration: 200 * time.Millisecond, + } + + defaultQueuingDelayEarlyWarningDQRConfig = CongestionSignalConfig{ + MinNumberOfGroups: 3, + MinDuration: 300 * time.Millisecond, + } + + defaultLossEarlyWarningJQRConfig = CongestionSignalConfig{ + MinNumberOfGroups: 3, + MinDuration: 300 * time.Millisecond, + } + + defaultLossEarlyWarningDQRConfig = CongestionSignalConfig{ + MinNumberOfGroups: 4, + MinDuration: 400 * time.Millisecond, + } + + defaultQueuingDelayCongestedJQRConfig = CongestionSignalConfig{ + MinNumberOfGroups: 4, + MinDuration: 400 * time.Millisecond, + } + + defaultQueuingDelayCongestedDQRConfig = CongestionSignalConfig{ + MinNumberOfGroups: 5, + MinDuration: 500 * time.Millisecond, + } + + defaultLossCongestedJQRConfig = CongestionSignalConfig{ + MinNumberOfGroups: 6, + MinDuration: 600 * time.Millisecond, + } + + defaultLossCongestedDQRConfig = CongestionSignalConfig{ + MinNumberOfGroups: 6, + MinDuration: 600 * time.Millisecond, + } +) + +// ------------------------------------------------------------------------------- + +type ProbeSignalConfig struct { + MinBytesRatio float64 `yaml:"min_bytes_ratio,omitempty"` + MinDurationRatio float64 `yaml:"min_duration_ratio,omitempty"` + + JQRMinDelay time.Duration `yaml:"jqr_min_delay,omitempty"` + DQRMaxDelay time.Duration `yaml:"dqr_max_delay,omitempty"` + + WeightedLoss WeightedLossConfig `yaml:"weighted_loss,omitempty"` + JQRMinWeightedLoss float64 `yaml:"jqr_min_weighted_loss,omitempty"` + DQRMaxWeightedLoss float64 `yaml:"dqr_max_weighted_loss,omitempty"` +} + +func (p ProbeSignalConfig) IsValid(pci ccutils.ProbeClusterInfo) bool { + return pci.Result.Bytes() > int(p.MinBytesRatio*float64(pci.Goal.DesiredBytes)) && pci.Result.Duration() > time.Duration(p.MinDurationRatio*float64(pci.Goal.Duration)) +} + +func (p ProbeSignalConfig) ProbeSignal(ppg *probePacketGroup) (ccutils.ProbeSignal, int64) { + ts := newTrafficStats(trafficStatsParams{ + Config: p.WeightedLoss, + }) + ts.Merge(ppg.Traffic()) + + pqd := ppg.PropagatedQueuingDelay() + if pqd > p.JQRMinDelay.Microseconds() || ts.WeightedLoss() > p.JQRMinWeightedLoss { + return ccutils.ProbeSignalCongesting, ts.AcknowledgedBitrate() + } + + if pqd < p.DQRMaxDelay.Microseconds() && ts.WeightedLoss() < p.DQRMaxWeightedLoss { + return ccutils.ProbeSignalNotCongesting, ts.AcknowledgedBitrate() + } + + return ccutils.ProbeSignalInconclusive, ts.AcknowledgedBitrate() +} + +var ( + defaultProbeSignalConfig = ProbeSignalConfig{ + MinBytesRatio: 0.5, + MinDurationRatio: 0.5, + + JQRMinDelay: 50 * time.Millisecond, + DQRMaxDelay: 20 * time.Millisecond, + + WeightedLoss: defaultWeightedLossConfig, + JQRMinWeightedLoss: 0.25, + DQRMaxWeightedLoss: 0.1, + } +) + +// ------------------------------------------------------------------------------- + +type queuingRegion int + +const ( + queuingRegionDQR queuingRegion = iota + queuingRegionIndeterminate + queuingRegionJQR +) + +func (q queuingRegion) String() string { + switch q { + case queuingRegionDQR: + return "DQR" + case queuingRegionIndeterminate: + return "INDETERMINATE" + case queuingRegionJQR: + return "JQR" + default: + return fmt.Sprintf("%d", int(q)) + } +} + +// ------------------------------------------------------------------------------- + +type congestionReason int + +const ( + congestionReasonNone congestionReason = iota + congestionReasonQueuingDelay + congestionReasonLoss +) + +func (c congestionReason) String() string { + switch c { + case congestionReasonNone: + return "NONE" + case congestionReasonQueuingDelay: + return "QUEUING_DELAY" + case congestionReasonLoss: + return "LOSS" + default: + return fmt.Sprintf("%d", int(c)) + } +} + +// ------------------------------------------------------------------------------- + +type qdMeasurement struct { + jqrConfig CongestionSignalConfig + dqrConfig CongestionSignalConfig + jqrMinDelay int64 + jqrMinTrendCoefficient float64 + dqrMaxDelay int64 + + numGroups int + numJQRGroups int + numDQRGroups int + minSendTime int64 + maxSendTime int64 + propagatedQueuingDelays []int64 + + isSealed bool + minGroupIdx int + maxGroupIdx int + + queuingRegion queuingRegion +} + +func newQDMeasurement(jqrConfig CongestionSignalConfig, dqrConfig CongestionSignalConfig, jqrMinDelay int64, jqrMinTrendCoefficient float64, dqrMaxDelay int64) *qdMeasurement { + return &qdMeasurement{ + jqrConfig: jqrConfig, + dqrConfig: dqrConfig, + jqrMinDelay: jqrMinDelay, + jqrMinTrendCoefficient: jqrMinTrendCoefficient, + dqrMaxDelay: dqrMaxDelay, + queuingRegion: queuingRegionIndeterminate, + } +} + +func (q *qdMeasurement) ProcessPacketGroup(pg *packetGroup, groupIdx int) { + if q.isSealed { + return + } + + pqd, pqdOk := pg.FinalizedPropagatedQueuingDelay() + if !pqdOk { + return + } + + q.numGroups++ + if q.minGroupIdx == 0 || q.minGroupIdx > groupIdx { + q.minGroupIdx = groupIdx + } + q.maxGroupIdx = max(q.maxGroupIdx, groupIdx) + + minSendTime, maxSendTime := pg.SendWindow() + if q.minSendTime == 0 || minSendTime < q.minSendTime { + q.minSendTime = minSendTime + } + q.maxSendTime = max(q.maxSendTime, maxSendTime) + + q.propagatedQueuingDelays = append(q.propagatedQueuingDelays, pqd) + + switch { + case pqd < q.dqrMaxDelay: + q.numDQRGroups++ + if q.numJQRGroups > 0 { + // broken continuity, seal + q.isSealed = true + } else if q.dqrConfig.IsTriggered(q.numDQRGroups, q.maxSendTime-q.minSendTime) { + q.isSealed = true + q.queuingRegion = queuingRegionDQR + } + + case pqd > q.jqrMinDelay: + q.numJQRGroups++ + if q.numDQRGroups > 0 { + // broken continuity, seal + q.isSealed = true + } else if q.jqrConfig.IsTriggered(q.numJQRGroups, q.maxSendTime-q.minSendTime) && q.trendCoefficient() > q.jqrMinTrendCoefficient { + q.isSealed = true + q.queuingRegion = queuingRegionJQR + } + + default: + if q.numDQRGroups > 0 || q.numJQRGroups > 0 { + // broken continuity, seal + q.isSealed = true + } + } +} + +func (q *qdMeasurement) IsSealed() bool { + return q.isSealed +} + +func (q *qdMeasurement) QueuingRegion() queuingRegion { + return q.queuingRegion +} + +func (q *qdMeasurement) GroupRange() (int, int) { + return max(0, q.minGroupIdx), max(0, q.maxGroupIdx) +} + +func (q *qdMeasurement) MarshalLogObject(e zapcore.ObjectEncoder) error { + if q == nil { + return nil + } + + e.AddInt("numGroups", q.numGroups) + e.AddInt("numJQRGroups", q.numJQRGroups) + e.AddInt("numDQRGroups", q.numDQRGroups) + e.AddInt64("minSendTime", q.minSendTime) + e.AddInt64("maxSendTime", q.maxSendTime) + e.AddArray("propagatedQueuingDelays", logger.Int64Slice(q.propagatedQueuingDelays)) + e.AddFloat64("trendCoefficient", q.trendCoefficient()) + e.AddDuration("duration", time.Duration((q.maxSendTime-q.minSendTime)*1000)) + e.AddBool("isSealed", q.isSealed) + e.AddInt("minGroupIdx", q.minGroupIdx) + e.AddInt("maxGroupIdx", q.maxGroupIdx) + e.AddString("queuingRegion", q.queuingRegion.String()) + return nil +} + +func (q *qdMeasurement) trendCoefficient() float64 { + concordantPairs := 0 + discordantPairs := 0 + + // the packet groups are processed from newest to oldest, + // so a concordant pair is when the value drops, + // i. e. the propagated queuing delay is increasing if newer (earlier entity in slice) is higher than older (later entity in slice) + for i := 0; i < len(q.propagatedQueuingDelays)-1; i++ { + for j := i + 1; j < len(q.propagatedQueuingDelays); j++ { + if q.propagatedQueuingDelays[i] > q.propagatedQueuingDelays[j] { + concordantPairs++ + } else if q.propagatedQueuingDelays[i] < q.propagatedQueuingDelays[j] { + discordantPairs++ + } + } + } + + if (concordantPairs + discordantPairs) == 0 { + // if the min requirements is only one sample, trend calculation is not possible, declare highest trend value + if len(q.propagatedQueuingDelays) == 1 { + return 1.0 + } + + return 0.0 + } + + return (float64(concordantPairs) - float64(discordantPairs)) / (float64(concordantPairs) + float64(discordantPairs)) +} + +// ------------------------------------------------------------------------------- + +type lossMeasurement struct { + jqrConfig CongestionSignalConfig + dqrConfig CongestionSignalConfig + jqrMinLoss float64 + dqrMaxLoss float64 + + numGroups int + ts *trafficStats + + isJQRSealed bool + isDQRSealed bool + minGroupIdx int + maxGroupIdx int + + weightedLoss float64 + + queuingRegion queuingRegion +} + +func newLossMeasurement( + jqrConfig CongestionSignalConfig, + dqrConfig CongestionSignalConfig, + weightedLossConfig WeightedLossConfig, + jqrMinLoss float64, + dqrMaxLoss float64, + logger logger.Logger, +) *lossMeasurement { + return &lossMeasurement{ + jqrConfig: jqrConfig, + dqrConfig: dqrConfig, + jqrMinLoss: jqrMinLoss, + dqrMaxLoss: dqrMaxLoss, + ts: newTrafficStats(trafficStatsParams{ + Config: weightedLossConfig, + Logger: logger, + }), + queuingRegion: queuingRegionIndeterminate, + } +} + +func (l *lossMeasurement) ProcessPacketGroup(pg *packetGroup, groupIdx int) { + if (l.isJQRSealed && l.isDQRSealed) || !pg.IsFinalized() { + return + } + + l.numGroups++ + if l.minGroupIdx == 0 || l.minGroupIdx > groupIdx { + l.minGroupIdx = groupIdx + } + l.maxGroupIdx = max(l.maxGroupIdx, groupIdx) + + l.ts.Merge(pg.Traffic()) + + if !l.isJQRSealed && l.jqrConfig.IsTriggered(l.numGroups, l.ts.Duration()) { + l.isJQRSealed = true + + weightedLoss := l.ts.WeightedLoss() + if weightedLoss > l.jqrMinLoss { + l.weightedLoss = weightedLoss + l.queuingRegion = queuingRegionJQR + l.isDQRSealed = true // seal DQR also as queuing region has been determined + return + } + } + + if !l.isDQRSealed && l.dqrConfig.IsTriggered(l.numGroups, l.ts.Duration()) { + l.isDQRSealed = true + + weightedLoss := l.ts.WeightedLoss() + if weightedLoss < l.dqrMaxLoss { + l.weightedLoss = weightedLoss + l.queuingRegion = queuingRegionDQR + l.isJQRSealed = true // seal JQR also as queuing region has been determined + return + } + } +} + +func (l *lossMeasurement) IsSealed() bool { + return l.isJQRSealed && l.isDQRSealed +} + +func (l *lossMeasurement) QueuingRegion() queuingRegion { + return l.queuingRegion +} + +func (l *lossMeasurement) GroupRange() (int, int) { + return max(0, l.minGroupIdx), max(0, l.maxGroupIdx) +} + +func (l *lossMeasurement) MarshalLogObject(e zapcore.ObjectEncoder) error { + if l == nil { + return nil + } + + e.AddInt("numGroups", l.numGroups) + e.AddObject("ts", l.ts) + e.AddBool("isJQRSealed", l.isJQRSealed) + e.AddBool("isDQRSealed", l.isDQRSealed) + e.AddInt("minGroupIdx", l.minGroupIdx) + e.AddInt("maxGroupIdx", l.maxGroupIdx) + e.AddFloat64("weightedLoss", l.weightedLoss) + e.AddString("queuingRegion", l.queuingRegion.String()) + return nil +} + +// ------------------------------------------------------------------------------- + +type CongestionDetectorConfig struct { + PacketGroup PacketGroupConfig `yaml:"packet_group,omitempty"` + PacketGroupMaxAge time.Duration `yaml:"packet_group_max_age,omitempty"` + + ProbePacketGroup ProbePacketGroupConfig `yaml:"probe_packet_group,omitempty"` + ProbeRegulator ccutils.ProbeRegulatorConfig `yaml:"probe_regulator,omitempty"` + ProbeSignal ProbeSignalConfig `yaml:"probe_signal,omitempty"` + + JQRMinDelay time.Duration `yaml:"jqr_min_delay,omitempty"` + JQRMinTrendCoefficient float64 `yaml:"jqr_min_trend_coefficient,omitempty"` + DQRMaxDelay time.Duration `yaml:"dqr_max_delay,omitempty"` + + WeightedLoss WeightedLossConfig `yaml:"weighted_loss,omitempty"` + JQRMinWeightedLoss float64 `yaml:"jqr_min_weighted_loss,omitempty"` + DQRMaxWeightedLoss float64 `yaml:"dqr_max_weighted_loss,omitempty"` + + QueuingDelayEarlyWarningJQR CongestionSignalConfig `yaml:"queuing_delay_early_warning_jqr,omitempty"` + QueuingDelayEarlyWarningDQR CongestionSignalConfig `yaml:"queuing_delay_early_warning_dqr,omitempty"` + LossEarlyWarningJQR CongestionSignalConfig `yaml:"loss_early_warning_jqr,omitempty"` + LossEarlyWarningDQR CongestionSignalConfig `yaml:"loss_early_warning_dqr,omitempty"` + + QueuingDelayCongestedJQR CongestionSignalConfig `yaml:"queuing_delay_congested_jqr,omitempty"` + QueuingDelayCongestedDQR CongestionSignalConfig `yaml:"queuing_delay_congested_dqr,omitempty"` + LossCongestedJQR CongestionSignalConfig `yaml:"loss_congested_jqr,omitempty"` + LossCongestedDQR CongestionSignalConfig `yaml:"loss_congested_dqr,omitempty"` + + CongestedCTRTrend ccutils.TrendDetectorConfig `yaml:"congested_ctr_trend,omitempty"` + CongestedCTREpsilon float64 `yaml:"congested_ctr_epsilon,omitempty"` + CongestedPacketGroup PacketGroupConfig `yaml:"congested_packet_group,omitempty"` + + EstimationWindowDuration time.Duration `yaml:"estimaton_window_duration,omitempty"` +} + +var ( + defaultTrendDetectorConfigCongestedCTR = ccutils.TrendDetectorConfig{ + RequiredSamples: 4, + RequiredSamplesMin: 2, + DownwardTrendThreshold: -0.5, + DownwardTrendMaxWait: 2 * time.Second, + CollapseThreshold: 500 * time.Millisecond, + ValidityWindow: 10 * time.Second, + } + + defaultCongestedPacketGroupConfig = PacketGroupConfig{ + MinPackets: 20, + MaxWindowDuration: 150 * time.Millisecond, + } + + defaultCongestionDetectorConfig = CongestionDetectorConfig{ + PacketGroup: defaultPacketGroupConfig, + PacketGroupMaxAge: 10 * time.Second, + + ProbePacketGroup: defaultProbePacketGroupConfig, + ProbeRegulator: ccutils.DefaultProbeRegulatorConfig, + ProbeSignal: defaultProbeSignalConfig, + + JQRMinDelay: 50 * time.Millisecond, + JQRMinTrendCoefficient: 0.8, + DQRMaxDelay: 20 * time.Millisecond, + + WeightedLoss: defaultWeightedLossConfig, + JQRMinWeightedLoss: 0.25, + DQRMaxWeightedLoss: 0.1, + + QueuingDelayEarlyWarningJQR: defaultQueuingDelayEarlyWarningJQRConfig, + QueuingDelayEarlyWarningDQR: defaultQueuingDelayEarlyWarningDQRConfig, + LossEarlyWarningJQR: defaultLossEarlyWarningJQRConfig, + LossEarlyWarningDQR: defaultLossEarlyWarningDQRConfig, + + QueuingDelayCongestedJQR: defaultQueuingDelayCongestedJQRConfig, + QueuingDelayCongestedDQR: defaultQueuingDelayCongestedDQRConfig, + LossCongestedJQR: defaultLossCongestedJQRConfig, + LossCongestedDQR: defaultLossCongestedDQRConfig, + + CongestedCTRTrend: defaultTrendDetectorConfigCongestedCTR, + CongestedCTREpsilon: 0.05, + CongestedPacketGroup: defaultCongestedPacketGroupConfig, + + EstimationWindowDuration: time.Second, + } +) + +// ------------------------------------------------------------------------------- + +type congestionDetectorParams struct { + Config CongestionDetectorConfig + Logger logger.Logger +} + +type congestionDetector struct { + params congestionDetectorParams + + lock sync.Mutex + + rtt float64 + + *packetTracker + twccFeedback *twccFeedback + + packetGroups []*packetGroup + + probePacketGroup *probePacketGroup + probeRegulator *ccutils.ProbeRegulator + + estimatedAvailableChannelCapacity int64 + estimateTrafficStats *trafficStats + + congestionState bwe.CongestionState + congestionStateSwitchedAt time.Time + + congestedCTRTrend *ccutils.TrendDetector[float64] + congestedTrafficStats *trafficStats + congestedPacketGroup *packetGroup + + congestionReason congestionReason + qdMeasurement *qdMeasurement + lossMeasurement *lossMeasurement + + bweListener bwe.BWEListener +} + +func newCongestionDetector(params congestionDetectorParams) *congestionDetector { + c := &congestionDetector{ + params: params, + packetTracker: newPacketTracker(packetTrackerParams{Logger: params.Logger}), + twccFeedback: newTWCCFeedback(twccFeedbackParams{Logger: params.Logger}), + } + c.Reset() + + return c +} + +func (c *congestionDetector) Reset() { + c.lock.Lock() + defer c.lock.Unlock() + + c.rtt = bwe.DefaultRTT + + c.packetGroups = nil + + c.probePacketGroup = nil + c.probeRegulator = ccutils.NewProbeRegulator(ccutils.ProbeRegulatorParams{ + Config: c.params.Config.ProbeRegulator, + Logger: c.params.Logger, + }) + + c.estimatedAvailableChannelCapacity = 100_000_000 + c.estimateTrafficStats = nil + + c.congestionState = bwe.CongestionStateNone + c.congestionStateSwitchedAt = mono.Now() + + c.clearCTRTrend() + + c.congestionReason = congestionReasonNone + c.qdMeasurement = nil + c.lossMeasurement = nil +} + +func (c *congestionDetector) SetBWEListener(bweListener bwe.BWEListener) { + c.lock.Lock() + defer c.lock.Unlock() + + c.bweListener = bweListener +} + +func (c *congestionDetector) getBWEListener() bwe.BWEListener { + c.lock.Lock() + defer c.lock.Unlock() + + return c.bweListener +} + +func (c *congestionDetector) HandleTWCCFeedback(report *rtcp.TransportLayerCC) { + c.lock.Lock() + recvRefTime, isOutOfOrder := c.twccFeedback.ProcessReport(report, mono.Now()) + if isOutOfOrder { + c.params.Logger.Infow("send side bwe: received out-of-order feedback report") + } + + if len(c.packetGroups) == 0 { + c.packetGroups = append( + c.packetGroups, + newPacketGroup( + packetGroupParams{ + Config: c.params.Config.PacketGroup, + WeightedLoss: c.params.Config.WeightedLoss, + Logger: c.params.Logger, + }, + 0, + ), + ) + } + + pg := c.packetGroups[len(c.packetGroups)-1] + trackPacketGroup := func(pi *packetInfo, sendDelta, recvDelta int64, isLost bool) { + if pi == nil { + return + } + + c.updateCTRTrend(pi, sendDelta, recvDelta, isLost) + + if c.probePacketGroup != nil { + c.probePacketGroup.Add(pi, sendDelta, recvDelta, isLost) + } + + err := pg.Add(pi, sendDelta, recvDelta, isLost) + if err == nil { + return + } + + if err == errGroupFinalized { + // previous group ended, start a new group + pg = newPacketGroup( + packetGroupParams{ + Config: c.params.Config.PacketGroup, + WeightedLoss: c.params.Config.WeightedLoss, + Logger: c.params.Logger, + }, + pg.PropagatedQueuingDelay(), + ) + c.packetGroups = append(c.packetGroups, pg) + + if err = pg.Add(pi, sendDelta, recvDelta, isLost); err != nil { + c.params.Logger.Warnw("send side bwe: could not add packet to new packet group", err, "packetInfo", pi, "packetGroup", pg) + } + return + } + + // try an older group + for idx := len(c.packetGroups) - 2; idx >= 0; idx-- { + opg := c.packetGroups[idx] + if err := opg.Add(pi, sendDelta, recvDelta, isLost); err == nil { + return + } else if err == errGroupFinalized { + c.params.Logger.Infow("send side bwe: unexpected finalized group", "packetInfo", pi, "packetGroup", opg) + } + } + } + + // 1. go through the TWCC feedback report and record receive time as reported by remote + // 2. process acknowledged packet and group them + // + // losses are not recorded if a feedback report is completely lost. + // RFC recommends treating lost reports by ignoring packets that would have been in it. + // ----------------------------------------------------------------------------------- + // | From a congestion control perspective, lost feedback messages are | + // | handled by ignoring packets which would have been reported as lost or | + // | received in the lost feedback messages. This behavior is similar to | + // | how a lost RTCP receiver report is handled. | + // ----------------------------------------------------------------------------------- + // Reference: https://datatracker.ietf.org/doc/html/draft-holmer-rmcat-transport-wide-cc-extensions-01#page-4 + sequenceNumber := report.BaseSequenceNumber + endSequenceNumberExclusive := sequenceNumber + report.PacketStatusCount + deltaIdx := 0 + processSymbol := func(symbol uint16) { + recvTime := int64(0) + isLost := false + if symbol != rtcp.TypeTCCPacketNotReceived { + recvRefTime += report.RecvDeltas[deltaIdx].Delta + deltaIdx++ + + recvTime = recvRefTime + } else { + isLost = true + } + pi, sendDelta, recvDelta := c.packetTracker.RecordPacketIndicationFromRemote(sequenceNumber, recvTime) + if pi.sendTime != 0 { + trackPacketGroup(&pi, sendDelta, recvDelta, isLost) + } + sequenceNumber++ + } + for _, chunk := range report.PacketChunks { + if sequenceNumber == endSequenceNumberExclusive { + break + } + + switch chunk := chunk.(type) { + case *rtcp.RunLengthChunk: + for i := uint16(0); i < chunk.RunLength; i++ { + if sequenceNumber == endSequenceNumberExclusive { + break + } + + processSymbol(chunk.PacketStatusSymbol) + } + + case *rtcp.StatusVectorChunk: + for _, symbol := range chunk.SymbolList { + if sequenceNumber == endSequenceNumberExclusive { + break + } + + processSymbol(symbol) + } + } + } + + c.prunePacketGroups() + shouldNotify, fromState, toState, committedChannelCapacity := c.congestionDetectionStateMachine() + c.lock.Unlock() + + if shouldNotify { + if bweListener := c.getBWEListener(); bweListener != nil { + bweListener.OnCongestionStateChange(fromState, toState, committedChannelCapacity) + } + } +} + +func (c *congestionDetector) UpdateRTT(rtt float64) { + c.lock.Lock() + defer c.lock.Unlock() + + if rtt == 0 { + c.rtt = bwe.DefaultRTT + } else { + if c.rtt == 0 { + c.rtt = rtt + } else { + c.rtt = bwe.RTTSmoothingFactor*rtt + (1.0-bwe.RTTSmoothingFactor)*c.rtt + } + } +} + +func (c *congestionDetector) CongestionState() bwe.CongestionState { + c.lock.Lock() + defer c.lock.Unlock() + + return c.congestionState +} + +func (c *congestionDetector) CanProbe() bool { + c.lock.Lock() + defer c.lock.Unlock() + + return c.congestionState == bwe.CongestionStateNone && c.probePacketGroup == nil && c.probeRegulator.CanProbe() +} + +func (c *congestionDetector) ProbeDuration() time.Duration { + c.lock.Lock() + defer c.lock.Unlock() + + return c.probeRegulator.ProbeDuration() +} + +func (c *congestionDetector) ProbeClusterStarting(pci ccutils.ProbeClusterInfo) { + c.lock.Lock() + defer c.lock.Unlock() + + c.probePacketGroup = newProbePacketGroup( + probePacketGroupParams{ + Config: c.params.Config.ProbePacketGroup, + WeightedLoss: c.params.Config.WeightedLoss, + Logger: c.params.Logger, + }, + pci, + ) + + c.packetTracker.ProbeClusterStarting(pci.Id) +} + +func (c *congestionDetector) ProbeClusterDone(pci ccutils.ProbeClusterInfo) { + c.lock.Lock() + defer c.lock.Unlock() + + c.packetTracker.ProbeClusterDone(pci.Id) + if c.probePacketGroup != nil { + c.probePacketGroup.ProbeClusterDone(pci) + } +} + +func (c *congestionDetector) ProbeClusterIsGoalReached() bool { + c.lock.Lock() + defer c.lock.Unlock() + + if c.probePacketGroup == nil || c.congestionState != bwe.CongestionStateNone { + return false + } + + pci := c.probePacketGroup.ProbeClusterInfo() + if !c.params.Config.ProbeSignal.IsValid(pci) { + return false + } + + probeSignal, estimatedAvailableChannelCapacity := c.params.Config.ProbeSignal.ProbeSignal(c.probePacketGroup) + return probeSignal != ccutils.ProbeSignalNotCongesting && estimatedAvailableChannelCapacity > int64(pci.Goal.DesiredBps) +} + +func (c *congestionDetector) ProbeClusterFinalize() (ccutils.ProbeSignal, int64, bool) { + c.lock.Lock() + defer c.lock.Unlock() + + if c.probePacketGroup == nil { + return ccutils.ProbeSignalInconclusive, 0, false + } + + pci, isFinalized := c.probePacketGroup.MaybeFinalizeProbe(c.packetTracker.ProbeMaxSequenceNumber(), c.rtt) + if !isFinalized { + return ccutils.ProbeSignalInconclusive, 0, isFinalized + } + + isSignalValid := c.params.Config.ProbeSignal.IsValid(pci) + c.params.Logger.Infow( + "send side bwe: probe finalized", + "isSignalValid", isSignalValid, + "probeClusterInfo", pci, + "probePacketGroup", c.probePacketGroup, + "congestionState", c.congestionState, + "rtt", c.rtt, + ) + + // if congestion signal changed during probe, defer to that signal + if c.congestionState != bwe.CongestionStateNone { + probeSignal := ccutils.ProbeSignalCongesting + c.probeRegulator.ProbeSignal(probeSignal, pci.CreatedAt) + c.probePacketGroup = nil + return probeSignal, c.estimatedAvailableChannelCapacity, true + } + + probeSignal, estimatedAvailableChannelCapacity := c.params.Config.ProbeSignal.ProbeSignal(c.probePacketGroup) + if probeSignal == ccutils.ProbeSignalNotCongesting && estimatedAvailableChannelCapacity > c.estimatedAvailableChannelCapacity { + c.estimatedAvailableChannelCapacity = estimatedAvailableChannelCapacity + } + + c.probeRegulator.ProbeSignal(probeSignal, pci.CreatedAt) + c.probePacketGroup = nil + return probeSignal, c.estimatedAvailableChannelCapacity, true +} + +func (c *congestionDetector) prunePacketGroups() { + if len(c.packetGroups) == 0 { + return + } + + threshold, ok := c.packetTracker.BaseSendTimeThreshold(c.params.Config.PacketGroupMaxAge.Microseconds()) + if !ok { + return + } + + for idx, pg := range c.packetGroups { + if mst := pg.MinSendTime(); mst > threshold { + c.packetGroups = c.packetGroups[idx:] + return + } + } +} + +func (c *congestionDetector) updateCongestionSignal( + qdJQRConfig CongestionSignalConfig, + qdDQRConfig CongestionSignalConfig, + lossJQRConfig CongestionSignalConfig, + lossDQRConfig CongestionSignalConfig, +) queuingRegion { + c.qdMeasurement = newQDMeasurement( + qdJQRConfig, + qdDQRConfig, + c.params.Config.JQRMinDelay.Microseconds(), + c.params.Config.JQRMinTrendCoefficient, + c.params.Config.DQRMaxDelay.Microseconds(), + ) + c.lossMeasurement = newLossMeasurement( + lossJQRConfig, + lossDQRConfig, + c.params.Config.WeightedLoss, + c.params.Config.JQRMinWeightedLoss, + c.params.Config.DQRMaxWeightedLoss, + c.params.Logger, + ) + + var idx int + for idx = len(c.packetGroups) - 1; idx >= 0; idx-- { + pg := c.packetGroups[idx] + c.qdMeasurement.ProcessPacketGroup(pg, idx) + c.lossMeasurement.ProcessPacketGroup(pg, idx) + + // if both measurements have enough data to make a decision, stop processing groups + if c.qdMeasurement.IsSealed() && c.lossMeasurement.IsSealed() { + break + } + } + + qr := queuingRegionIndeterminate + qdQueuingRegion := c.qdMeasurement.QueuingRegion() + lossQueuingRegion := c.lossMeasurement.QueuingRegion() + switch { + case qdQueuingRegion == queuingRegionJQR: + qr = queuingRegionJQR + c.congestionReason = congestionReasonQueuingDelay + case lossQueuingRegion == queuingRegionJQR: + qr = queuingRegionJQR + c.congestionReason = congestionReasonLoss + case qdQueuingRegion == queuingRegionDQR && lossQueuingRegion == queuingRegionDQR: + qr = queuingRegionDQR + c.congestionReason = congestionReasonNone + } + + return qr +} + +func (c *congestionDetector) updateEarlyWarningSignal() queuingRegion { + return c.updateCongestionSignal( + c.params.Config.QueuingDelayEarlyWarningJQR, + c.params.Config.QueuingDelayEarlyWarningDQR, + c.params.Config.LossEarlyWarningJQR, + c.params.Config.LossEarlyWarningDQR, + ) +} + +func (c *congestionDetector) updateCongestedSignal() queuingRegion { + return c.updateCongestionSignal( + c.params.Config.QueuingDelayCongestedJQR, + c.params.Config.QueuingDelayCongestedDQR, + c.params.Config.LossCongestedJQR, + c.params.Config.LossCongestedDQR, + ) +} + +func (c *congestionDetector) congestionDetectionStateMachine() (bool, bwe.CongestionState, bwe.CongestionState, int64) { + fromState := c.congestionState + toState := c.congestionState + + switch fromState { + case bwe.CongestionStateNone: + if c.updateEarlyWarningSignal() == queuingRegionJQR { + toState = bwe.CongestionStateEarlyWarning + } + + case bwe.CongestionStateEarlyWarning: + if c.updateCongestedSignal() == queuingRegionJQR { + toState = bwe.CongestionStateCongested + } else if c.updateEarlyWarningSignal() == queuingRegionDQR { + toState = bwe.CongestionStateNone + } + + case bwe.CongestionStateCongested: + if c.updateCongestedSignal() == queuingRegionDQR { + toState = bwe.CongestionStateNone + } + } + + shouldNotify := false + if toState != fromState { + c.estimateAvailableChannelCapacity() + fromState, toState = c.updateCongestionState(toState) + shouldNotify = true + } + + if c.congestedCTRTrend != nil && c.congestedCTRTrend.GetDirection() == ccutils.TrendDirectionDownward { + congestedAckedBitrate := c.congestedTrafficStats.AcknowledgedBitrate() + if congestedAckedBitrate < c.estimatedAvailableChannelCapacity { + c.estimatedAvailableChannelCapacity = congestedAckedBitrate + + c.params.Logger.Infow( + "send side bwe: captured traffic ratio is trending downward", + "channel", c.congestedCTRTrend, + "trafficStats", c.congestedTrafficStats, + "estimatedAvailableChannelCapacity", c.estimatedAvailableChannelCapacity, + ) + + shouldNotify = true + } + + // reset to get new set of samples for next trend + c.resetCTRTrend() + } + + return shouldNotify, fromState, toState, c.estimatedAvailableChannelCapacity +} + +func (c *congestionDetector) createCTRTrend() { + c.resetCTRTrend() + c.congestedPacketGroup = nil +} + +func (c *congestionDetector) resetCTRTrend() { + c.congestedCTRTrend = ccutils.NewTrendDetector[float64](ccutils.TrendDetectorParams{ + Name: "ssbwe-ctr", + Logger: c.params.Logger, + Config: c.params.Config.CongestedCTRTrend, + }) + c.congestedTrafficStats = newTrafficStats(trafficStatsParams{ + Config: c.params.Config.WeightedLoss, + Logger: c.params.Logger, + }) +} + +func (c *congestionDetector) clearCTRTrend() { + c.congestedCTRTrend = nil + c.congestedTrafficStats = nil + c.congestedPacketGroup = nil +} + +func (c *congestionDetector) updateCTRTrend(pi *packetInfo, sendDelta, recvDelta int64, isLost bool) { + if c.congestedCTRTrend == nil { + return + } + + if c.congestedPacketGroup == nil { + c.congestedPacketGroup = newPacketGroup( + packetGroupParams{ + Config: c.params.Config.CongestedPacketGroup, + WeightedLoss: c.params.Config.WeightedLoss, + Logger: c.params.Logger, + }, + 0, + ) + } + + if err := c.congestedPacketGroup.Add(pi, sendDelta, recvDelta, isLost); err == errGroupFinalized { + // progressively keep increasing the window and make measurements over longer windows, + // if congestion is not relieving, CTR will trend down + c.congestedTrafficStats.Merge(c.congestedPacketGroup.Traffic()) + + ts := newTrafficStats(trafficStatsParams{ + Config: c.params.Config.WeightedLoss, + Logger: c.params.Logger, + }) + ts.Merge(c.congestedPacketGroup.Traffic()) + ctr := ts.CapturedTrafficRatio() + + // quantise CTR to filter out small changes + c.congestedCTRTrend.AddValue(float64(int((ctr+(c.params.Config.CongestedCTREpsilon/2))/c.params.Config.CongestedCTREpsilon)) * c.params.Config.CongestedCTREpsilon) + + c.congestedPacketGroup = newPacketGroup( + packetGroupParams{ + Config: c.params.Config.CongestedPacketGroup, + WeightedLoss: c.params.Config.WeightedLoss, + Logger: c.params.Logger, + }, + c.congestedPacketGroup.PropagatedQueuingDelay(), + ) + } +} + +func (c *congestionDetector) estimateAvailableChannelCapacity() { + c.estimateTrafficStats = nil + if len(c.packetGroups) == 0 || c.probePacketGroup != nil { + return + } + + // when congested, use contributing groups, + // else use a time windowed measurement + useWindow := false + isAggValid := true + minGroupIdx := 0 + maxGroupIdx := len(c.packetGroups) - 1 + switch c.congestionReason { + case congestionReasonQueuingDelay: + minGroupIdx, maxGroupIdx = c.qdMeasurement.GroupRange() + case congestionReasonLoss: + minGroupIdx, maxGroupIdx = c.lossMeasurement.GroupRange() + default: + useWindow = true + isAggValid = false + } + + agg := newTrafficStats(trafficStatsParams{ + Config: c.params.Config.WeightedLoss, + Logger: c.params.Logger, + }) + for idx := maxGroupIdx; idx >= minGroupIdx; idx-- { + pg := c.packetGroups[idx] + if !pg.IsFinalized() { + continue + } + + agg.Merge(pg.Traffic()) + if useWindow && agg.Duration() > c.params.Config.EstimationWindowDuration.Microseconds() { + isAggValid = true + break + } + } + + if isAggValid { + c.estimatedAvailableChannelCapacity = agg.AcknowledgedBitrate() + c.estimateTrafficStats = agg + } +} + +func (c *congestionDetector) updateCongestionState(state bwe.CongestionState) (bwe.CongestionState, bwe.CongestionState) { + loggingFields := []any{ + "from", c.congestionState, + "to", state, + "congestionReason", c.congestionReason, + "qdMeasurement", c.qdMeasurement, + "lossMeasurement", c.lossMeasurement, + "numPacketGroups", len(c.packetGroups), + "estimatedAvailableChannelCapacity", c.estimatedAvailableChannelCapacity, + "estimateTrafficStats", c.estimateTrafficStats, + } + if c.congestionReason != congestionReasonNone { + var minGroupIdx, maxGroupIdx int + switch c.congestionReason { + case congestionReasonQueuingDelay: + minGroupIdx, maxGroupIdx = c.qdMeasurement.GroupRange() + case congestionReasonLoss: + minGroupIdx, maxGroupIdx = c.lossMeasurement.GroupRange() + } + loggingFields = append( + loggingFields, + "contributingGroups", logger.ObjectSlice(c.packetGroups[minGroupIdx:maxGroupIdx+1]), + ) + } + c.params.Logger.Infow("send side bwe: congestion state change", loggingFields...) + + if state != c.congestionState { + c.congestionStateSwitchedAt = mono.Now() + } + + fromState := c.congestionState + c.congestionState = state + + // when in congested state, monitor changes in captured traffic ratio (CTR) + // to ensure allocations are in line with latest estimates, it is possible that + // the estimate is incorrect when congestion starts and the allocation may be + // sub-optimal and not enough to reduce/relieve congestion, by monitoring CTR + // on a continuous basis allocations can be adjusted in the direction of + // reducing/relieving congestion + if state == bwe.CongestionStateCongested && fromState != bwe.CongestionStateCongested { + c.createCTRTrend() + } else if state != bwe.CongestionStateCongested { + c.clearCTRTrend() + } + + return fromState, c.congestionState +} diff --git a/livekit/pkg/sfu/bwe/sendsidebwe/packet_group.go b/livekit/pkg/sfu/bwe/sendsidebwe/packet_group.go new file mode 100644 index 0000000..b64727f --- /dev/null +++ b/livekit/pkg/sfu/bwe/sendsidebwe/packet_group.go @@ -0,0 +1,310 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sendsidebwe + +import ( + "errors" + "time" + + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/utils" + "go.uber.org/zap/zapcore" +) + +// ------------------------------------------------------------- + +var ( + errGroupFinalized = errors.New("packet group is finalized") + errOldPacket = errors.New("packet is older than packet group start") +) + +// ------------------------------------------------------------- + +type PacketGroupConfig struct { + MinPackets int `yaml:"min_packets,omitempty"` + MaxWindowDuration time.Duration `yaml:"max_window_duration,omitempty"` +} + +var ( + defaultPacketGroupConfig = PacketGroupConfig{ + MinPackets: 30, + MaxWindowDuration: 500 * time.Millisecond, + } +) + +// ------------------------------------------------------------- + +type stat struct { + numPackets int + numBytes int +} + +func (s *stat) add(size int) { + s.numPackets++ + s.numBytes += size +} + +func (s *stat) remove(size int) { + s.numPackets-- + s.numBytes -= size +} + +func (s *stat) getNumPackets() int { + return s.numPackets +} + +func (s *stat) getNumBytes() int { + return s.numBytes +} + +func (s stat) MarshalLogObject(e zapcore.ObjectEncoder) error { + e.AddInt("numPackets", s.numPackets) + e.AddInt("numBytes", s.numBytes) + return nil +} + +// ------------------------------------------------------------- + +type classStat struct { + primary stat + rtx stat + probe stat +} + +func (c *classStat) add(size int, isRTX bool, isProbe bool) { + if isRTX { + c.rtx.add(size) + } else if isProbe { + c.probe.add(size) + } else { + c.primary.add(size) + } +} + +func (c *classStat) remove(size int, isRTX bool, isProbe bool) { + if isRTX { + c.rtx.remove(size) + } else if isProbe { + c.probe.remove(size) + } else { + c.primary.remove(size) + } +} + +func (c *classStat) numPackets() int { + return c.primary.getNumPackets() + c.rtx.getNumPackets() + c.probe.getNumPackets() +} + +func (c *classStat) numBytes() int { + return c.primary.getNumBytes() + c.rtx.getNumBytes() + c.probe.getNumBytes() +} + +func (c classStat) MarshalLogObject(e zapcore.ObjectEncoder) error { + e.AddObject("primary", c.primary) + e.AddObject("rtx", c.rtx) + e.AddObject("probe", c.probe) + return nil +} + +// ------------------------------------------------------------- + +type packetGroupParams struct { + Config PacketGroupConfig + WeightedLoss WeightedLossConfig + Logger logger.Logger +} + +type packetGroup struct { + params packetGroupParams + + minSequenceNumber uint64 + maxSequenceNumber uint64 + + minSendTime int64 + maxSendTime int64 + + minRecvTime int64 // for information only + maxRecvTime int64 // for information only + + acked classStat + lost classStat + snBitmap *utils.Bitmap[uint64] + + aggregateSendDelta int64 + aggregateRecvDelta int64 + inheritedQueuingDelay int64 + + isFinalized bool +} + +func newPacketGroup(params packetGroupParams, inheritedQueuingDelay int64) *packetGroup { + return &packetGroup{ + params: params, + inheritedQueuingDelay: inheritedQueuingDelay, + snBitmap: utils.NewBitmap[uint64](params.Config.MinPackets), + } +} + +func (p *packetGroup) Add(pi *packetInfo, sendDelta, recvDelta int64, isLost bool) error { + if isLost { + return p.lostPacket(pi) + } + + if err := p.inGroup(pi.sequenceNumber); err != nil { + return err + } + + if p.minSequenceNumber == 0 || pi.sequenceNumber < p.minSequenceNumber { + p.minSequenceNumber = pi.sequenceNumber + } + p.maxSequenceNumber = max(p.maxSequenceNumber, pi.sequenceNumber) + + if p.minSendTime == 0 || (pi.sendTime-sendDelta) < p.minSendTime { + p.minSendTime = pi.sendTime - sendDelta + } + p.maxSendTime = max(p.maxSendTime, pi.sendTime) + + if p.minRecvTime == 0 || (pi.recvTime-recvDelta) < p.minRecvTime { + p.minRecvTime = pi.recvTime - recvDelta + } + p.maxRecvTime = max(p.maxRecvTime, pi.recvTime) + + p.acked.add(int(pi.size), pi.isRTX, pi.isProbe) + if int(pi.sequenceNumber-p.minSequenceNumber) < p.snBitmap.Len() && p.snBitmap.IsSet(pi.sequenceNumber-p.minSequenceNumber) { + // an earlier packet reported as lost has been received + p.snBitmap.Clear(pi.sequenceNumber - p.minSequenceNumber) + p.lost.remove(int(pi.size), pi.isRTX, pi.isProbe) + } + + // note that out-of-order deliveries will amplify the queueing delay. + // for e.g. a, b, c getting delivered as a, c, b. + // let us say packets are delivered with interval of `x` + // send delta aggregate will go up by x((a, c) = 2x + (c, b) -1x) + // recv delta aggregate will go up by 3x((a, c) = 2x + (c, b) 1x) + p.aggregateSendDelta += sendDelta + p.aggregateRecvDelta += recvDelta + + if p.acked.numPackets() == p.params.Config.MinPackets || (pi.sendTime-p.minSendTime) > p.params.Config.MaxWindowDuration.Microseconds() { + p.isFinalized = true + } + return nil +} + +func (p *packetGroup) lostPacket(pi *packetInfo) error { + if pi.recvTime != 0 { + // previously received packet, so not lost + return nil + } + + if err := p.inGroup(pi.sequenceNumber); err != nil { + return err + } + + if p.minSequenceNumber == 0 || pi.sequenceNumber < p.minSequenceNumber { + p.minSequenceNumber = pi.sequenceNumber + } + p.maxSequenceNumber = max(p.maxSequenceNumber, pi.sequenceNumber) + p.snBitmap.Set(pi.sequenceNumber - p.minSequenceNumber) + + p.lost.add(int(pi.size), pi.isRTX, pi.isProbe) + return nil +} + +func (p *packetGroup) MinSendTime() int64 { + return p.minSendTime +} + +func (p *packetGroup) SendWindow() (int64, int64) { + return p.minSendTime, p.maxSendTime +} + +func (p *packetGroup) PropagatedQueuingDelay() int64 { + if p.inheritedQueuingDelay+p.aggregateRecvDelta-p.aggregateSendDelta > 0 { + return p.inheritedQueuingDelay + p.aggregateRecvDelta - p.aggregateSendDelta + } + + return max(0, p.aggregateRecvDelta-p.aggregateSendDelta) +} + +func (p *packetGroup) FinalizedPropagatedQueuingDelay() (int64, bool) { + if !p.isFinalized { + return 0, false + } + + return p.PropagatedQueuingDelay(), true +} + +func (p *packetGroup) IsFinalized() bool { + return p.isFinalized +} + +func (p *packetGroup) Traffic() *trafficStats { + return &trafficStats{ + minSendTime: p.minSendTime, + maxSendTime: p.maxSendTime, + sendDelta: p.aggregateSendDelta, + recvDelta: p.aggregateRecvDelta, + ackedPackets: p.acked.numPackets(), + ackedBytes: p.acked.numBytes(), + lostPackets: p.lost.numPackets(), + lostBytes: p.lost.numBytes(), + } +} + +func (p *packetGroup) MarshalLogObject(e zapcore.ObjectEncoder) error { + if p == nil { + return nil + } + + e.AddUint64("minSequenceNumber", p.minSequenceNumber) + e.AddUint64("maxSequenceNumber", p.maxSequenceNumber) + e.AddObject("acked", p.acked) + e.AddObject("lost", p.lost) + + e.AddInt64("minRecvTime", p.minRecvTime) + e.AddInt64("maxRecvTime", p.maxRecvTime) + recvDuration := time.Duration((p.maxRecvTime - p.minRecvTime) * 1000) + e.AddDuration("recvDuration", recvDuration) + + recvBitrate := float64(0) + if recvDuration != 0 { + recvBitrate = float64(p.acked.numBytes()*8) / recvDuration.Seconds() + e.AddFloat64("recvBitrate", recvBitrate) + } + + ts := newTrafficStats(trafficStatsParams{ + Config: p.params.WeightedLoss, + Logger: p.params.Logger, + }) + ts.Merge(p.Traffic()) + e.AddObject("trafficStats", ts) + e.AddInt64("inheritedQueuingDelay", p.inheritedQueuingDelay) + e.AddInt64("propagatedQueuingDelay", p.PropagatedQueuingDelay()) + + e.AddBool("isFinalized", p.isFinalized) + return nil +} + +func (p *packetGroup) inGroup(sequenceNumber uint64) error { + if p.isFinalized && sequenceNumber > p.maxSequenceNumber { + return errGroupFinalized + } + + if sequenceNumber < p.minSequenceNumber { + return errOldPacket + } + + return nil +} diff --git a/livekit/pkg/sfu/bwe/sendsidebwe/packet_info.go b/livekit/pkg/sfu/bwe/sendsidebwe/packet_info.go new file mode 100644 index 0000000..730e2a7 --- /dev/null +++ b/livekit/pkg/sfu/bwe/sendsidebwe/packet_info.go @@ -0,0 +1,45 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sendsidebwe + +import ( + "github.com/livekit/livekit-server/pkg/sfu/ccutils" + "go.uber.org/zap/zapcore" +) + +type packetInfo struct { + sequenceNumber uint64 + sendTime int64 + recvTime int64 + probeClusterId ccutils.ProbeClusterId + size uint16 + isRTX bool + isProbe bool +} + +func (pi *packetInfo) MarshalLogObject(e zapcore.ObjectEncoder) error { + if pi == nil { + return nil + } + + e.AddUint64("sequenceNumber", pi.sequenceNumber) + e.AddInt64("sendTime", pi.sendTime) + e.AddInt64("recvTime", pi.recvTime) + e.AddUint32("probeClusterId", uint32(pi.probeClusterId)) + e.AddUint16("size", pi.size) + e.AddBool("isRTX", pi.isRTX) + e.AddBool("isProbe", pi.isProbe) + return nil +} diff --git a/livekit/pkg/sfu/bwe/sendsidebwe/packet_tracker.go b/livekit/pkg/sfu/bwe/sendsidebwe/packet_tracker.go new file mode 100644 index 0000000..a41e50f --- /dev/null +++ b/livekit/pkg/sfu/bwe/sendsidebwe/packet_tracker.go @@ -0,0 +1,168 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sendsidebwe + +import ( + "math/rand" + "sync" + + "github.com/livekit/livekit-server/pkg/sfu/ccutils" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/utils/mono" +) + +// ------------------------------------------------------------------------------- + +type packetTrackerParams struct { + Logger logger.Logger +} + +type packetTracker struct { + params packetTrackerParams + + lock sync.Mutex + + sequenceNumber uint64 + + baseSendTime int64 + packetInfos [2048]packetInfo + + baseRecvTime int64 + piLastRecv *packetInfo + + probeClusterId ccutils.ProbeClusterId + probeMaxSequenceNumber uint64 +} + +func newPacketTracker(params packetTrackerParams) *packetTracker { + return &packetTracker{ + params: params, + sequenceNumber: uint64(rand.Intn(1<<14)) + uint64(1<<15), // a random number in third quartile of sequence number space + } +} + +func (p *packetTracker) RecordPacketSendAndGetSequenceNumber( + atMicro int64, + size int, + isRTX bool, + probeClusterId ccutils.ProbeClusterId, + isProbe bool, +) uint16 { + p.lock.Lock() + defer p.lock.Unlock() + + if p.baseSendTime == 0 { + p.baseSendTime = atMicro + } + + pi := p.getPacketInfo(uint16(p.sequenceNumber)) + *pi = packetInfo{ + sequenceNumber: p.sequenceNumber, + sendTime: atMicro - p.baseSendTime, + size: uint16(size), + isRTX: isRTX, + probeClusterId: probeClusterId, + isProbe: isProbe, + } + + p.sequenceNumber++ + + // extreme case of wrap around before receiving any feedback + if pi == p.piLastRecv { + p.piLastRecv = nil + } + + if p.probeClusterId != ccutils.ProbeClusterIdInvalid && p.probeClusterId == pi.probeClusterId && pi.sequenceNumber > p.probeMaxSequenceNumber { + p.probeMaxSequenceNumber = pi.sequenceNumber + } + + return uint16(pi.sequenceNumber) +} + +func (p *packetTracker) BaseSendTimeThreshold(threshold int64) (int64, bool) { + p.lock.Lock() + defer p.lock.Unlock() + + if p.baseSendTime == 0 { + return 0, false + } + + return mono.UnixMicro() - p.baseSendTime - threshold, true +} + +func (p *packetTracker) RecordPacketIndicationFromRemote(sn uint16, recvTime int64) (piRecv packetInfo, sendDelta, recvDelta int64) { + p.lock.Lock() + defer p.lock.Unlock() + + pi := p.getPacketInfoExisting(sn) + if pi == nil { + return + } + + if recvTime == 0 { + // maybe lost OR already received but reported lost in a later report + piRecv = *pi + return + } + + if p.baseRecvTime == 0 { + p.baseRecvTime = recvTime + p.piLastRecv = pi + } + + pi.recvTime = recvTime - p.baseRecvTime + piRecv = *pi + if p.piLastRecv != nil { + sendDelta, recvDelta = pi.sendTime-p.piLastRecv.sendTime, pi.recvTime-p.piLastRecv.recvTime + } + p.piLastRecv = pi + return +} + +func (p *packetTracker) getPacketInfo(sn uint16) *packetInfo { + return &p.packetInfos[int(sn)%len(p.packetInfos)] +} + +func (p *packetTracker) getPacketInfoExisting(sn uint16) *packetInfo { + pi := &p.packetInfos[int(sn)%len(p.packetInfos)] + if uint16(pi.sequenceNumber) == sn { + return pi + } + + return nil +} + +func (p *packetTracker) ProbeClusterStarting(probeClusterId ccutils.ProbeClusterId) { + p.lock.Lock() + defer p.lock.Unlock() + + p.probeClusterId = probeClusterId +} + +func (p *packetTracker) ProbeClusterDone(probeClusterId ccutils.ProbeClusterId) { + p.lock.Lock() + defer p.lock.Unlock() + + if p.probeClusterId == probeClusterId { + p.probeClusterId = ccutils.ProbeClusterIdInvalid + } +} + +func (p *packetTracker) ProbeMaxSequenceNumber() uint64 { + p.lock.Lock() + defer p.lock.Unlock() + + return p.probeMaxSequenceNumber +} diff --git a/livekit/pkg/sfu/bwe/sendsidebwe/probe_packet_group.go b/livekit/pkg/sfu/bwe/sendsidebwe/probe_packet_group.go new file mode 100644 index 0000000..68a1b0c --- /dev/null +++ b/livekit/pkg/sfu/bwe/sendsidebwe/probe_packet_group.go @@ -0,0 +1,137 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sendsidebwe + +import ( + "time" + + "github.com/livekit/livekit-server/pkg/sfu/ccutils" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/utils/mono" + "go.uber.org/zap/zapcore" +) + +// ------------------------------------------------------------- + +type ProbePacketGroupConfig struct { + PacketGroup PacketGroupConfig `yaml:"packet_group,omitempty"` + + SettleWaitNumRTT uint32 `yaml:"settle_wait_num_rtt,omitempty"` + SettleWaitMin time.Duration `yaml:"settle_wait_min,omitempty"` + SettleWaitMax time.Duration `yaml:"settle_wait_max,omitempty"` +} + +var ( + // large numbers to treat a probe packet group as one + defaultProbePacketGroupConfig = ProbePacketGroupConfig{ + PacketGroup: PacketGroupConfig{ + MinPackets: 16384, + MaxWindowDuration: time.Minute, + }, + + SettleWaitNumRTT: 5, + SettleWaitMin: 250 * time.Millisecond, + SettleWaitMax: 5 * time.Second, + } +) + +// ------------------------------------------------------------- + +type probePacketGroupParams struct { + Config ProbePacketGroupConfig + WeightedLoss WeightedLossConfig + Logger logger.Logger +} + +type probePacketGroup struct { + params probePacketGroupParams + pci ccutils.ProbeClusterInfo + *packetGroup + maxSequenceNumber uint64 + doneAt time.Time +} + +func newProbePacketGroup(params probePacketGroupParams, pci ccutils.ProbeClusterInfo) *probePacketGroup { + return &probePacketGroup{ + params: params, + pci: pci, + packetGroup: newPacketGroup( + packetGroupParams{ + Config: params.Config.PacketGroup, + WeightedLoss: params.WeightedLoss, + Logger: params.Logger, + }, + 0, + ), + } +} + +func (p *probePacketGroup) ProbeClusterDone(pci ccutils.ProbeClusterInfo) { + if p.pci.Id != pci.Id { + return + } + + p.pci.Result = pci.Result + p.doneAt = mono.Now() +} + +func (p *probePacketGroup) ProbeClusterInfo() ccutils.ProbeClusterInfo { + return p.pci +} + +func (p *probePacketGroup) MaybeFinalizeProbe(maxSequenceNumber uint64, rtt float64) (ccutils.ProbeClusterInfo, bool) { + if p.doneAt.IsZero() { + return ccutils.ProbeClusterInfoInvalid, false + } + + if maxSequenceNumber != 0 && p.maxSequenceNumber >= maxSequenceNumber { + return p.pci, true + } + + settleWait := time.Duration(float64(p.params.Config.SettleWaitNumRTT) * rtt * float64(time.Second)) + if settleWait < p.params.Config.SettleWaitMin { + settleWait = p.params.Config.SettleWaitMin + } + if settleWait > p.params.Config.SettleWaitMax { + settleWait = p.params.Config.SettleWaitMax + } + if time.Since(p.doneAt) < settleWait { + return ccutils.ProbeClusterInfoInvalid, false + } + + return p.pci, true +} + +func (p *probePacketGroup) Add(pi *packetInfo, sendDelta, recvDelta int64, isLost bool) error { + if pi.probeClusterId != p.pci.Id { + return nil + } + + p.maxSequenceNumber = max(p.maxSequenceNumber, pi.sequenceNumber) + + return p.packetGroup.Add(pi, sendDelta, recvDelta, isLost) +} + +func (p *probePacketGroup) MarshalLogObject(e zapcore.ObjectEncoder) error { + if p == nil { + return nil + } + + e.AddObject("pci", p.pci) + e.AddObject("packetGroup", p.packetGroup) + e.AddUint64("maxSequenceNumber", p.maxSequenceNumber) + e.AddTime("doneAt", p.doneAt) + return nil +} diff --git a/livekit/pkg/sfu/bwe/sendsidebwe/send_side_bwe.go b/livekit/pkg/sfu/bwe/sendsidebwe/send_side_bwe.go new file mode 100644 index 0000000..c31be1e --- /dev/null +++ b/livekit/pkg/sfu/bwe/sendsidebwe/send_side_bwe.go @@ -0,0 +1,151 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sendsidebwe + +import ( + "time" + + "github.com/livekit/livekit-server/pkg/sfu/bwe" + "github.com/livekit/livekit-server/pkg/sfu/ccutils" + "github.com/livekit/protocol/logger" + "github.com/pion/rtcp" +) + +var _ bwe.BWE = (*SendSideBWE)(nil) + +// +// Based on a simplified/modified version of JitterPath paper +// (https://homepage.iis.sinica.edu.tw/papers/lcs/2114-F.pdf) +// +// TWCC feedback is uesed to calcualte delta one-way-delay. +// It is accumulated/propagated to determine in which region +// groups of packets are operating in. +// +// In simplified terms, +// o JQR (Join Queuing Region) is when channel is congested. +// o DQR (Disjoint Queuing Region) is when channel is not. +// +// Packets are grouped and thresholds applied to smooth over +// small variations. For example, in the paper, +// if propagated_queuing_delay + delta_one_way_delay > 0 { +// possibly_operating_in_jqr +// } +// But, in this implementation it is checked at packet group level, +// i. e. using queuing delay and aggreated delta one-way-delay of +// the group and a minimum value threshold is applied before declaring +// that a group is in JQR. +// +// There is also hysteresis to make transisitons smoother, i.e. if the +// metric is above a certain threshold, it is JQR and it is DQR only if it +// is below a certain value and the gap in between those two thresholds +// are treated as interdeterminate groups. +// + +// --------------------------------------------------------------------------- + +type SendSideBWEConfig struct { + CongestionDetector CongestionDetectorConfig `yaml:"congestion_detector,omitempty"` +} + +var ( + DefaultSendSideBWEConfig = SendSideBWEConfig{ + CongestionDetector: defaultCongestionDetectorConfig, + } +) + +// --------------------------------------------------------------------------- + +type SendSideBWEParams struct { + Config SendSideBWEConfig + Logger logger.Logger +} + +type SendSideBWE struct { + bwe.NullBWE + + params SendSideBWEParams + + *congestionDetector +} + +func NewSendSideBWE(params SendSideBWEParams) *SendSideBWE { + return &SendSideBWE{ + params: params, + congestionDetector: newCongestionDetector(congestionDetectorParams{ + Config: params.Config.CongestionDetector, + Logger: params.Logger, + }), + } +} + +func (r *SendSideBWE) Type() bwe.BWEType { + return bwe.BWETypeSendSide +} + +func (s *SendSideBWE) SetBWEListener(bweListener bwe.BWEListener) { + s.congestionDetector.SetBWEListener(bweListener) +} + +func (s *SendSideBWE) Reset() { + s.congestionDetector.Reset() +} + +func (s *SendSideBWE) RecordPacketSendAndGetSequenceNumber( + atMicro int64, + size int, + isRTX bool, + probeClusterId ccutils.ProbeClusterId, + isProbe bool, +) uint16 { + return s.congestionDetector.RecordPacketSendAndGetSequenceNumber(atMicro, size, isRTX, probeClusterId, isProbe) +} + +func (s *SendSideBWE) HandleTWCCFeedback(report *rtcp.TransportLayerCC) { + s.congestionDetector.HandleTWCCFeedback(report) +} + +func (s *SendSideBWE) UpdateRTT(rtt float64) { + s.congestionDetector.UpdateRTT(rtt) +} + +func (s *SendSideBWE) CongestionState() bwe.CongestionState { + return s.congestionDetector.CongestionState() +} + +func (s *SendSideBWE) CanProbe() bool { + return s.congestionDetector.CanProbe() +} + +func (s *SendSideBWE) ProbeDuration() time.Duration { + return s.congestionDetector.ProbeDuration() +} + +func (s *SendSideBWE) ProbeClusterStarting(pci ccutils.ProbeClusterInfo) { + s.congestionDetector.ProbeClusterStarting(pci) +} + +func (s *SendSideBWE) ProbeClusterDone(pci ccutils.ProbeClusterInfo) { + s.congestionDetector.ProbeClusterDone(pci) +} + +func (s *SendSideBWE) ProbeClusterIsGoalReached() bool { + return s.congestionDetector.ProbeClusterIsGoalReached() +} + +func (s *SendSideBWE) ProbeClusterFinalize() (ccutils.ProbeSignal, int64, bool) { + return s.congestionDetector.ProbeClusterFinalize() +} + +// ------------------------------------------------ diff --git a/livekit/pkg/sfu/bwe/sendsidebwe/traffic_stats.go b/livekit/pkg/sfu/bwe/sendsidebwe/traffic_stats.go new file mode 100644 index 0000000..f1c07d5 --- /dev/null +++ b/livekit/pkg/sfu/bwe/sendsidebwe/traffic_stats.go @@ -0,0 +1,191 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sendsidebwe + +import ( + "math" + "time" + + "github.com/livekit/protocol/logger" + "go.uber.org/zap/zapcore" +) + +// ----------------------------------------------------------- + +type WeightedLossConfig struct { + MinDurationForLossValidity time.Duration `yaml:"min_duration_for_loss_validity,omitempty"` + BaseDuration time.Duration `yaml:"base_duration,omitempty"` + BasePPS int `yaml:"base_pps,omitempty"` + LossPenaltyFactor float64 `yaml:"loss_penalty_factor,omitempty"` +} + +var ( + defaultWeightedLossConfig = WeightedLossConfig{ + MinDurationForLossValidity: 100 * time.Millisecond, + BaseDuration: 500 * time.Millisecond, + BasePPS: 30, + LossPenaltyFactor: 0.25, + } +) + +// ----------------------------------------------------------- + +type trafficStatsParams struct { + Config WeightedLossConfig + Logger logger.Logger +} + +type trafficStats struct { + params trafficStatsParams + + minSendTime int64 + maxSendTime int64 + sendDelta int64 + recvDelta int64 + ackedPackets int + ackedBytes int + lostPackets int + lostBytes int +} + +func newTrafficStats(params trafficStatsParams) *trafficStats { + return &trafficStats{ + params: params, + } +} + +func (ts *trafficStats) Merge(rhs *trafficStats) { + if ts.minSendTime == 0 || rhs.minSendTime < ts.minSendTime { + ts.minSendTime = rhs.minSendTime + } + if rhs.maxSendTime > ts.maxSendTime { + ts.maxSendTime = rhs.maxSendTime + } + ts.sendDelta += rhs.sendDelta + ts.recvDelta += rhs.recvDelta + ts.ackedPackets += rhs.ackedPackets + ts.ackedBytes += rhs.ackedBytes + ts.lostPackets += rhs.lostPackets + ts.lostBytes += rhs.lostBytes +} + +func (ts *trafficStats) NumBytes() int { + return ts.ackedBytes + ts.lostBytes +} + +func (ts *trafficStats) Duration() int64 { + return ts.maxSendTime - ts.minSendTime +} + +func (ts *trafficStats) AcknowledgedBitrate() int64 { + duration := ts.Duration() + if duration == 0 { + return 0 + } + + ackedBitrate := float64(ts.ackedBytes) * 8 * 1e6 / float64(ts.Duration()) + return int64(ackedBitrate * ts.CapturedTrafficRatio()) +} + +func (ts *trafficStats) CapturedTrafficRatio() float64 { + if ts.recvDelta == 0 { + return 0.0 + } + + // apply a penalty for lost packets, + // the rationale being packet dropping is a strategy to relieve congestion + // and if they were not dropped, they would have increased queuing delay, + // as it is not possible to know the reason for the losses, + // apply a small penalty to receive delta aggregate to simulate those packets + // building up queuing delay. + return min(1.0, float64(ts.sendDelta)/float64(ts.recvDelta+ts.lossPenalty())) +} + +func (ts *trafficStats) WeightedLoss() float64 { + durationMicro := ts.Duration() + if time.Duration(durationMicro*1000) < ts.params.Config.MinDurationForLossValidity { + return 0.0 + } + + totalPackets := float64(ts.lostPackets + ts.ackedPackets) + pps := totalPackets * 1e6 / float64(durationMicro) + + // longer duration, i. e. more time resolution, lower pps is acceptable as the measurement is more stable + deltaDuration := time.Duration(durationMicro*1000) - ts.params.Config.BaseDuration + if deltaDuration < 0 { + deltaDuration = 0 + } + threshold := math.Exp(-deltaDuration.Seconds()) * float64(ts.params.Config.BasePPS) + if pps < threshold { + return 0.0 + } + + lossRatio := float64(0.0) + if totalPackets != 0 { + lossRatio = float64(ts.lostPackets) / totalPackets + } + + // Log10 is used to give higher weight for the same loss ratio at higher packet rates, + // for e.g. + // - 10% loss at 20 pps = 0.1 * log10(20) = 0.130 + // - 10% loss at 100 pps = 0.1 * log10(100) = 0.2 + // - 10% loss at 1000 pps = 0.1 * log10(1000) = 0.3 + return lossRatio * math.Log10(pps) +} + +func (ts *trafficStats) lossPenalty() int64 { + return int64(float64(ts.recvDelta) * ts.WeightedLoss() * ts.params.Config.LossPenaltyFactor) +} + +func (ts *trafficStats) MarshalLogObject(e zapcore.ObjectEncoder) error { + if ts == nil { + return nil + } + + e.AddInt64("minSendTime", ts.minSendTime) + e.AddInt64("maxSendTime", ts.maxSendTime) + duration := time.Duration(ts.Duration() * 1000) + e.AddDuration("duration", duration) + + e.AddInt("ackedPackets", ts.ackedPackets) + e.AddInt("ackedBytes", ts.ackedBytes) + e.AddInt("lostPackets", ts.lostPackets) + e.AddInt("lostBytes", ts.lostBytes) + + bitrate := float64(0) + if duration != 0 { + bitrate = float64(ts.ackedBytes*8) / duration.Seconds() + e.AddFloat64("bitrate", bitrate) + } + + e.AddInt64("sendDelta", ts.sendDelta) + e.AddInt64("recvDelta", ts.recvDelta) + e.AddInt64("groupDelay", ts.recvDelta-ts.sendDelta) + + totalPackets := ts.lostPackets + ts.ackedPackets + if duration != 0 { + e.AddFloat64("pps", float64(totalPackets)/duration.Seconds()) + } + if (totalPackets) != 0 { + e.AddFloat64("rawLoss", float64(ts.lostPackets)/float64(totalPackets)) + } + e.AddFloat64("weightedLoss", ts.WeightedLoss()) + e.AddInt64("lossPenalty", ts.lossPenalty()) + + capturedTrafficRatio := ts.CapturedTrafficRatio() + e.AddFloat64("capturedTrafficRatio", capturedTrafficRatio) + e.AddFloat64("estimatedAvailableChannelCapacity", bitrate*capturedTrafficRatio) + return nil +} diff --git a/livekit/pkg/sfu/bwe/sendsidebwe/twcc_feedback.go b/livekit/pkg/sfu/bwe/sendsidebwe/twcc_feedback.go new file mode 100644 index 0000000..f6e10e1 --- /dev/null +++ b/livekit/pkg/sfu/bwe/sendsidebwe/twcc_feedback.go @@ -0,0 +1,128 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sendsidebwe + +import ( + "errors" + "time" + + "github.com/livekit/protocol/logger" + "github.com/pion/rtcp" + "go.uber.org/zap/zapcore" +) + +// ------------------------------------------------------ + +const ( + cOutlierReportFactor = 3 + cEstimatedFeedbackIntervalAlpha = float64(0.9) + + cReferenceTimeMask = (1 << 24) - 1 + cReferenceTimeResolution = 64 // 64 ms +) + +// ------------------------------------------------------ + +var ( + errFeedbackReportOutOfOrder = errors.New("feedback report out-of-order") +) + +// ------------------------------------------------------ + +type twccFeedbackParams struct { + Logger logger.Logger +} + +type twccFeedback struct { + params twccFeedbackParams + + lastFeedbackTime time.Time + estimatedFeedbackInterval time.Duration + numReports int + numReportsOutOfOrder int + + highestFeedbackCount uint8 + + cycles int64 + highestReferenceTime uint32 +} + +func newTWCCFeedback(params twccFeedbackParams) *twccFeedback { + return &twccFeedback{ + params: params, + } +} + +func (t *twccFeedback) ProcessReport(report *rtcp.TransportLayerCC, at time.Time) (int64, bool) { + t.numReports++ + if t.lastFeedbackTime.IsZero() { + t.lastFeedbackTime = at + t.highestReferenceTime = report.ReferenceTime + t.highestFeedbackCount = report.FbPktCount + return (t.cycles + int64(report.ReferenceTime)) * cReferenceTimeResolution * 1000, false + } + + isOutOfOrder := false + if (report.FbPktCount - t.highestFeedbackCount) > (1 << 7) { + t.numReportsOutOfOrder++ + isOutOfOrder = true + } + + // reference time wrap around handling + var referenceTime int64 + if (report.ReferenceTime-t.highestReferenceTime)&cReferenceTimeMask < (1 << 23) { + if report.ReferenceTime < t.highestReferenceTime { + t.cycles += (1 << 24) + } + t.highestReferenceTime = report.ReferenceTime + referenceTime = t.cycles + int64(report.ReferenceTime) + } else { + cycles := t.cycles + if report.ReferenceTime > t.highestReferenceTime && cycles >= (1<<24) { + cycles -= (1 << 24) + } + referenceTime = cycles + int64(report.ReferenceTime) + } + + if !isOutOfOrder { + sinceLast := at.Sub(t.lastFeedbackTime) + if t.estimatedFeedbackInterval == 0 { + t.estimatedFeedbackInterval = sinceLast + } else { + // filter out outliers from estimate + if sinceLast > t.estimatedFeedbackInterval/cOutlierReportFactor && sinceLast < cOutlierReportFactor*t.estimatedFeedbackInterval { + // smoothed version of inter feedback interval + t.estimatedFeedbackInterval = time.Duration(cEstimatedFeedbackIntervalAlpha*float64(t.estimatedFeedbackInterval) + (1.0-cEstimatedFeedbackIntervalAlpha)*float64(sinceLast)) + } + } + t.lastFeedbackTime = at + t.highestFeedbackCount = report.FbPktCount + } + + return referenceTime * cReferenceTimeResolution * 1000, isOutOfOrder +} + +func (t *twccFeedback) MarshalLogObject(e zapcore.ObjectEncoder) error { + if t == nil { + return nil + } + + e.AddTime("lastFeedbackTime", t.lastFeedbackTime) + e.AddDuration("estimatedFeedbackInterval", t.estimatedFeedbackInterval) + e.AddInt("numReports", t.numReports) + e.AddInt("numReportsOutOfOrder", t.numReportsOutOfOrder) + e.AddInt64("cycles", t.cycles/(1<<24)) + return nil +} diff --git a/livekit/pkg/sfu/ccutils/probe_regulator.go b/livekit/pkg/sfu/ccutils/probe_regulator.go new file mode 100644 index 0000000..ed0940b --- /dev/null +++ b/livekit/pkg/sfu/ccutils/probe_regulator.go @@ -0,0 +1,106 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ccutils + +import ( + "time" + + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/utils/mono" +) + +// ------------------------------------------------ + +type ProbeRegulatorConfig struct { + BaseInterval time.Duration `yaml:"base_interval,omitempty"` + BackoffFactor float64 `yaml:"backoff_factor,omitempty"` + MaxInterval time.Duration `yaml:"max_interval,omitempty"` + + MinDuration time.Duration `yaml:"min_duration,omitempty"` + MaxDuration time.Duration `yaml:"max_duration,omitempty"` + DurationIncreaseFactor float64 `yaml:"duration_increase_factor,omitempty"` +} + +var ( + DefaultProbeRegulatorConfig = ProbeRegulatorConfig{ + BaseInterval: 3 * time.Second, + BackoffFactor: 1.5, + MaxInterval: 2 * time.Minute, + + MinDuration: 200 * time.Millisecond, + MaxDuration: 20 * time.Second, + DurationIncreaseFactor: 1.5, + } +) + +// --------------------------------------------------------------------------- + +type ProbeRegulatorParams struct { + Config ProbeRegulatorConfig + Logger logger.Logger +} + +type ProbeRegulator struct { + params ProbeRegulatorParams + + probeInterval time.Duration + probeDuration time.Duration + nextProbeEarliestAt time.Time +} + +func NewProbeRegulator(params ProbeRegulatorParams) *ProbeRegulator { + return &ProbeRegulator{ + params: params, + probeInterval: params.Config.BaseInterval, + probeDuration: params.Config.MinDuration, + nextProbeEarliestAt: mono.Now(), + } +} + +func (p *ProbeRegulator) CanProbe() bool { + return mono.Now().After(p.nextProbeEarliestAt) +} + +func (p *ProbeRegulator) ProbeDuration() time.Duration { + return p.probeDuration +} + +func (p *ProbeRegulator) ProbeSignal(probeSignal ProbeSignal, baseTime time.Time) { + if probeSignal == ProbeSignalCongesting { + // wait longer till next probe + p.probeInterval = time.Duration(p.probeInterval.Seconds()*p.params.Config.BackoffFactor) * time.Second + if p.probeInterval > p.params.Config.MaxInterval { + p.probeInterval = p.params.Config.MaxInterval + } + + // revert back to starting with shortest probe + p.probeDuration = p.params.Config.MinDuration + } else { + // probe can be started again after minimal interval as previous congestion signal indicated congestion clearing + p.probeInterval = p.params.Config.BaseInterval + + // can do longer probe after a good probe + p.probeDuration = time.Duration(float64(p.probeDuration.Milliseconds())*p.params.Config.DurationIncreaseFactor) * time.Millisecond + if p.probeDuration > p.params.Config.MaxDuration { + p.probeDuration = p.params.Config.MaxDuration + } + } + + if baseTime.IsZero() { + p.nextProbeEarliestAt = mono.Now().Add(p.probeInterval) + } else { + p.nextProbeEarliestAt = baseTime.Add(p.probeInterval) + } +} diff --git a/livekit/pkg/sfu/ccutils/probe_signal.go b/livekit/pkg/sfu/ccutils/probe_signal.go new file mode 100644 index 0000000..22c3bf8 --- /dev/null +++ b/livekit/pkg/sfu/ccutils/probe_signal.go @@ -0,0 +1,40 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ccutils + +import "fmt" + +// ------------------------------------------------ + +type ProbeSignal int + +const ( + ProbeSignalInconclusive ProbeSignal = iota + ProbeSignalCongesting + ProbeSignalNotCongesting +) + +func (p ProbeSignal) String() string { + switch p { + case ProbeSignalInconclusive: + return "INCONCLUSIVE" + case ProbeSignalCongesting: + return "CONGESTING" + case ProbeSignalNotCongesting: + return "NOT_CONGESTING" + default: + return fmt.Sprintf("%d", int(p)) + } +} diff --git a/livekit/pkg/sfu/ccutils/prober.go b/livekit/pkg/sfu/ccutils/prober.go new file mode 100644 index 0000000..577042a --- /dev/null +++ b/livekit/pkg/sfu/ccutils/prober.go @@ -0,0 +1,596 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Design of Prober +// +// Probing is used to check for existence of excess channel capacity. +// This is especially useful in the downstream direction of SFU. +// SFU forwards audio/video streams from one or more publishers to +// all the subscribers. But, the downstream channel of a subscriber +// may not be big enough to carry all the streams. It is also a time +// varying quantity. +// +// When there is not enough capacity, some streams will be paused. +// To resume a stream, SFU would need to know that the channel has +// enough capacity. That's where probing comes in. When conditions +// are favorable, SFU can send probe packets so that the bandwidth +// estimator has more data to estimate available channel capacity +// better. +// NOTE: What defines `favorable conditions` is implementation dependent. +// +// There are two options for probing +// - Use padding only RTP packets: This one is preferable as +// probe rate can be controlled more tightly. +// - Resume a paused stream or forward a higher spatial layer: +// Have to find a stream at probing rate. Also, a stream could +// get a key frame unexpectedly boosting rate in the probing +// window. +// +// The strategy used depends on stream allocator implementation. +// This module can be used if the stream allocator decides to use +// padding only RTP packets for probing purposes. +// +// Implementation: +// There are a couple of options +// - Check prober in the forwarding path (pull from prober). +// This is preferred for scalability reasons. But, this +// suffers from not being able to probe when all streams +// are paused (could be due to downstream bandwidth +// constraints or the corresponding upstream tracks may +// have paused due to upstream bandwidth constraints). +// Another issue is not being able to have tight control on +// probing window boundary as the packet forwarding path +// may not have a packet to forward. But, it should not +// be a major concern as long as some stream(s) is/are +// forwarded as there should be a packet at least every +// 60 ms or so (forwarding only one stream at 15 fps). +// Usually, it will be serviced much more frequently when +// there are multiple streams getting forwarded. +// - Run it a go routine. But, that would have to wake up +// very often to prevent bunching up of probe +// packets. So, a scalability concern as there is one prober +// per subscriber peer connection. But, probe windows +// should be very short (of the order of 100s of ms). +// So, this approach might be fine. +// +// The implementation here follows the second approach of using a +// go routine. +// +// Pacing: +// ------ +// Ideally, the subscriber peer connection should have a pacer which +// trickles data out at the estimated channel capacity rate (and +// estimated channel capacity + probing rate when actively probing). +// +// But, there a few significant challenges +// 1. Pacer will require buffering of forwarded packets. That means +// more memory, more CPU (have to make copy of packets) and +// more latency in the media stream. +// 2. Scalability concern as SFU may be handling hundreds of +// subscriber peer connections and each one processing the pacing +// loop at 5ms interval will add up. +// +// So, this module assumes that pacing is inherently provided by the +// publishers for media streams. That is a reasonable assumption given +// that publishing clients will run their own pacer and pacing data out +// at a steady rate. +// +// A further assumption is that if there are multiple publishers for +// a subscriber peer connection, all the publishers are not pacing +// in sync, i.e. each publisher's pacer is completely independent +// and SFU will be receiving the media packets with a good spread and +// not clumped together. +// +// Given those assumptions, this module monitors media send rate and +// adjusts probing packet sends accordingly. Although the probing may +// have a high enough wake up frequency, it is for short windows. +// For example, probing at 5 Mbps for 1/2 second and sending 1000 byte +// probe per iteration will wake up every 1.6 ms. That is very high, +// but should last for 1/2 second or so. +// +// 5 Mbps over 1/2 second = 2.5 Mbps +// 2.5 Mbps = 312500 bytes = 313 probes at 1000 byte probes +// 313 probes over 1/2 second = 1.6 ms between probes +// +// A few things to note +// 1. When a probe cluster is added, the expected media rate is provided. +// So, the wake-up interval takes that into account. For example, +// if probing at 5 Mbps for 1/2 second and if 4 Mbps of it is expected +// to be provided by media traffic, the wake-up interval becomes 8 ms. +// 2. The amount of probing should actually be capped at some value to +// avoid too much self-induced congestion. It maybe something like 500 kbps. +// That will increase the wake-up interval to 16 ms in the above example. +// 3. In practice, the probing interval may also be shorter. Typically, +// it can be run for 2 - 3 RTTs to get a good measurement. For +// the longest hauls, RTT could be 250 ms or so leading to the probing +// window being long(ish). But, RTT should be much shorter especially if +// the subscriber peer connection of the client is able to connect to +// the nearest data center. +package ccutils + +import ( + "fmt" + "math" + "sync" + "time" + + "github.com/gammazero/deque" + "go.uber.org/atomic" + "go.uber.org/zap/zapcore" + + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/utils/mono" +) + +type ProberListener interface { + OnProbeClusterSwitch(info ProbeClusterInfo) + OnSendProbe(bytesToSend int) +} + +type ProberParams struct { + Listener ProberListener + Logger logger.Logger +} + +type Prober struct { + params ProberParams + + clusterId atomic.Uint32 + + clustersMu sync.RWMutex + clusters deque.Deque[*Cluster] + activeCluster *Cluster +} + +func NewProber(params ProberParams) *Prober { + p := &Prober{ + params: params, + } + p.clusters.SetBaseCap(2) + return p +} + +func (p *Prober) IsRunning() bool { + p.clustersMu.RLock() + defer p.clustersMu.RUnlock() + + return p.clusters.Len() > 0 +} + +func (p *Prober) Reset(info ProbeClusterInfo) { + p.clustersMu.Lock() + defer p.clustersMu.Unlock() + + if p.activeCluster != nil && p.activeCluster.Id() == info.Id { + p.activeCluster.MarkCompleted(info.Result) + p.params.Logger.Debugw("prober: resetting active cluster", "cluster", p.activeCluster) + } + + p.clusters.Clear() + p.activeCluster = nil +} + +func (p *Prober) AddCluster(mode ProbeClusterMode, pcg ProbeClusterGoal) ProbeClusterInfo { + if pcg.DesiredBps <= 0 { + return ProbeClusterInfoInvalid + } + + clusterId := ProbeClusterId(p.clusterId.Inc()) + cluster := newCluster(clusterId, mode, pcg, p.params.Listener) + p.params.Logger.Debugw("cluster added", "cluster", cluster) + + p.pushBackClusterAndMaybeStart(cluster) + + return cluster.Info() +} + +func (p *Prober) ProbesSent(bytesSent int) { + cluster := p.getFrontCluster() + if cluster == nil { + return + } + + cluster.ProbesSent(bytesSent) +} + +func (p *Prober) ClusterDone(info ProbeClusterInfo) { + cluster := p.getFrontCluster() + if cluster == nil { + return + } + + if cluster.Id() == info.Id { + cluster.MarkCompleted(info.Result) + p.params.Logger.Debugw("cluster done", "cluster", cluster) + p.popFrontCluster(cluster) + } +} + +func (p *Prober) GetActiveClusterId() ProbeClusterId { + p.clustersMu.RLock() + defer p.clustersMu.RUnlock() + + if p.activeCluster != nil { + return p.activeCluster.Id() + } + + return ProbeClusterIdInvalid +} + +func (p *Prober) getFrontCluster() *Cluster { + p.clustersMu.Lock() + defer p.clustersMu.Unlock() + + if p.activeCluster != nil { + return p.activeCluster + } + + if p.clusters.Len() == 0 { + p.activeCluster = nil + } else { + p.activeCluster = p.clusters.Front() + p.activeCluster.Start() + } + return p.activeCluster +} + +func (p *Prober) popFrontCluster(cluster *Cluster) { + p.clustersMu.Lock() + + if p.clusters.Len() == 0 { + p.activeCluster = nil + p.clustersMu.Unlock() + return + } + + if p.clusters.Front() == cluster { + p.clusters.PopFront() + } + + if cluster == p.activeCluster { + p.activeCluster = nil + } + + p.clustersMu.Unlock() +} + +func (p *Prober) pushBackClusterAndMaybeStart(cluster *Cluster) { + p.clustersMu.Lock() + p.clusters.PushBack(cluster) + + if p.clusters.Len() == 1 { + go p.run() + } + p.clustersMu.Unlock() +} + +func (p *Prober) run() { + ticker := time.NewTicker(100 * time.Millisecond) + defer ticker.Stop() + for { + cluster := p.getFrontCluster() + if cluster == nil { + return + } + + sleepDuration := cluster.Process() + if sleepDuration == 0 { + p.popFrontCluster(cluster) + continue + } + + ticker.Reset(sleepDuration) + <-ticker.C + } +} + +// --------------------------------- + +type ProbeClusterId uint32 + +const ( + ProbeClusterIdInvalid ProbeClusterId = 0 + + // padding only packets are 255 bytes max + 20 byte header = 4 packets per probe, + // when not using padding only packets, this is a min and actual sent could be higher + cBytesPerProbe = 1100 + cSleepDuration = 20 * time.Millisecond + cSleepDurationMin = 10 * time.Millisecond +) + +// ----------------------------------- + +type ProbeClusterMode int + +const ( + ProbeClusterModeUniform ProbeClusterMode = iota + ProbeClusterModeLinearChirp +) + +func (p ProbeClusterMode) String() string { + switch p { + case ProbeClusterModeUniform: + return "UNIFORM" + case ProbeClusterModeLinearChirp: + return "LINEAR_CHIRP" + default: + return fmt.Sprintf("%d", int(p)) + } +} + +// --------------------------------------------------------------------------- + +type ProbeClusterGoal struct { + AvailableBandwidthBps int + ExpectedUsageBps int + DesiredBps int + Duration time.Duration + DesiredBytes int +} + +func (p ProbeClusterGoal) MarshalLogObject(e zapcore.ObjectEncoder) error { + e.AddInt("AvailableBandwidthBps", p.AvailableBandwidthBps) + e.AddInt("ExpectedUsageBps", p.ExpectedUsageBps) + e.AddInt("DesiredBps", p.DesiredBps) + e.AddDuration("Duration", p.Duration) + e.AddInt("DesiredBytes", p.DesiredBytes) + return nil +} + +type ProbeClusterResult struct { + StartTime int64 + EndTime int64 + PacketsProbe int + BytesProbe int + PacketsNonProbePrimary int + BytesNonProbePrimary int + PacketsNonProbeRTX int + BytesNonProbeRTX int + IsCompleted bool +} + +func (p ProbeClusterResult) Bytes() int { + return p.BytesProbe + p.BytesNonProbePrimary + p.BytesNonProbeRTX +} + +func (p ProbeClusterResult) Duration() time.Duration { + return time.Duration(p.EndTime - p.StartTime) +} + +func (p ProbeClusterResult) Bitrate() float64 { + duration := p.Duration().Seconds() + if duration != 0 { + return float64(p.Bytes()*8) / duration + } + + return 0 +} + +func (p ProbeClusterResult) MarshalLogObject(e zapcore.ObjectEncoder) error { + e.AddTime("StartTime", time.Unix(0, p.StartTime)) + e.AddTime("EndTime", time.Unix(0, p.EndTime)) + e.AddDuration("Duration", p.Duration()) + e.AddInt("PacketsProbe", p.PacketsProbe) + e.AddInt("BytesProbe", p.BytesProbe) + e.AddInt("PacketsNonProbePrimary", p.PacketsNonProbePrimary) + e.AddInt("BytesNonProbePrimary", p.BytesNonProbePrimary) + e.AddInt("PacketsNonProbeRTX", p.PacketsNonProbeRTX) + e.AddInt("BytesNonProbeRTX", p.BytesNonProbeRTX) + e.AddInt("Bytes", p.Bytes()) + e.AddFloat64("Bitrate", p.Bitrate()) + e.AddBool("IsCompleted", p.IsCompleted) + return nil +} + +type ProbeClusterInfo struct { + Id ProbeClusterId + CreatedAt time.Time + Goal ProbeClusterGoal + Result ProbeClusterResult +} + +var ( + ProbeClusterInfoInvalid = ProbeClusterInfo{Id: ProbeClusterIdInvalid} +) + +func (p ProbeClusterInfo) MarshalLogObject(e zapcore.ObjectEncoder) error { + e.AddUint32("Id", uint32(p.Id)) + e.AddTime("CreatedAt", p.CreatedAt) + e.AddObject("Goal", p.Goal) + e.AddObject("Result", p.Result) + return nil +} + +// --------------------------------------------------------------------------- + +type bucket struct { + expectedElapsedDuration time.Duration + expectedProbeBytesSent int +} + +func (b bucket) MarshalLogObject(e zapcore.ObjectEncoder) error { + e.AddDuration("expectedElapsedDuration", b.expectedElapsedDuration) + e.AddInt("expectedProbesBytesSent", b.expectedProbeBytesSent) + return nil +} + +// --------------------------------------------------------------------------- + +type Cluster struct { + lock sync.RWMutex + + info ProbeClusterInfo + mode ProbeClusterMode + listener ProberListener + + baseSleepDuration time.Duration + buckets []bucket + bucketIdx int + + probeBytesSent int + + startTime time.Time + isComplete bool +} + +func newCluster(id ProbeClusterId, mode ProbeClusterMode, pcg ProbeClusterGoal, listener ProberListener) *Cluster { + c := &Cluster{ + mode: mode, + info: ProbeClusterInfo{ + Id: id, + CreatedAt: mono.Now(), + Goal: pcg, + }, + listener: listener, + } + c.initProbes() + return c +} + +func (c *Cluster) initProbes() { + c.info.Goal.DesiredBytes = int(math.Round(float64(c.info.Goal.DesiredBps)*c.info.Goal.Duration.Seconds()/8 + 0.5)) + + numBuckets := int(math.Round(c.info.Goal.Duration.Seconds()/cSleepDuration.Seconds() + 0.5)) + if numBuckets < 1 { + numBuckets = 1 + } + numIntervals := numBuckets + + // for linear chirp, group intervals with decreasing duration, i.e. incraasing bitrate, + // by aiming to send same number of bytes in each interval, as intervals get shorter, the bitrate is higher + if c.mode == ProbeClusterModeLinearChirp { + sum := 0 + i := 1 + for { + sum += i + if sum >= numBuckets { + break + } + i++ + } + numBuckets = i + numIntervals = sum + } + + c.baseSleepDuration = c.info.Goal.Duration / time.Duration(numIntervals) + if c.baseSleepDuration < cSleepDurationMin { + c.baseSleepDuration = cSleepDurationMin + } + + numIntervals = int(math.Round(c.info.Goal.Duration.Seconds()/c.baseSleepDuration.Seconds() + 0.5)) + desiredProbeBytesPerInterval := int(math.Round(((c.info.Goal.Duration.Seconds()*float64(c.info.Goal.DesiredBps-c.info.Goal.ExpectedUsageBps)/8)+float64(numIntervals)-1)/float64(numIntervals) + 0.5)) + + c.buckets = make([]bucket, numBuckets) + for i := 0; i < numBuckets; i++ { + switch c.mode { + case ProbeClusterModeUniform: + c.buckets[i] = bucket{ + expectedElapsedDuration: c.baseSleepDuration, + } + + case ProbeClusterModeLinearChirp: + c.buckets[i] = bucket{ + expectedElapsedDuration: time.Duration(numBuckets-i) * c.baseSleepDuration, + } + } + if i > 0 { + c.buckets[i].expectedElapsedDuration += c.buckets[i-1].expectedElapsedDuration + } + c.buckets[i].expectedProbeBytesSent = (i + 1) * desiredProbeBytesPerInterval + } +} + +func (c *Cluster) Start() { + if c.listener != nil { + c.listener.OnProbeClusterSwitch(c.info) + } +} + +func (c *Cluster) Id() ProbeClusterId { + return c.info.Id +} + +func (c *Cluster) Info() ProbeClusterInfo { + c.lock.RLock() + defer c.lock.RUnlock() + + return c.info +} + +func (c *Cluster) ProbesSent(bytesSent int) { + c.lock.Lock() + defer c.lock.Unlock() + + c.probeBytesSent += bytesSent +} + +func (c *Cluster) MarkCompleted(result ProbeClusterResult) { + c.lock.Lock() + defer c.lock.Unlock() + + c.isComplete = true + c.info.Result = result +} + +func (c *Cluster) Process() time.Duration { + c.lock.Lock() + if c.isComplete { + c.lock.Unlock() + return 0 + } + + bytesToSend := 0 + if c.startTime.IsZero() { + c.startTime = mono.Now() + bytesToSend = cBytesPerProbe + } else { + sinceStart := time.Since(c.startTime) + if sinceStart > c.buckets[c.bucketIdx].expectedElapsedDuration { + c.bucketIdx++ + overflow := false + if c.bucketIdx >= len(c.buckets) { + // when overflowing, repeat the last bucket + c.bucketIdx = len(c.buckets) - 1 + overflow = true + } + if c.buckets[c.bucketIdx].expectedProbeBytesSent > c.probeBytesSent || overflow { + bytesToSend = max(cBytesPerProbe, c.buckets[c.bucketIdx].expectedProbeBytesSent-c.probeBytesSent) + } + } + } + c.lock.Unlock() + + if bytesToSend != 0 && c.listener != nil { + c.listener.OnSendProbe(bytesToSend) + } + + return cSleepDurationMin +} + +func (c *Cluster) MarshalLogObject(e zapcore.ObjectEncoder) error { + if c != nil { + e.AddString("mode", c.mode.String()) + e.AddObject("info", c.info) + e.AddDuration("baseSleepDuration", c.baseSleepDuration) + e.AddInt("numBuckets", len(c.buckets)) + e.AddInt("bucketIdx", c.bucketIdx) + e.AddInt("probeBytesSent", c.probeBytesSent) + e.AddTime("startTime", c.startTime) + e.AddDuration("elapsed", time.Since(c.startTime)) + e.AddBool("isComplete", c.isComplete) + } + return nil +} + +// ---------------------------------------------------------------------- diff --git a/livekit/pkg/sfu/ccutils/trenddetector.go b/livekit/pkg/sfu/ccutils/trenddetector.go new file mode 100644 index 0000000..c3f8ff6 --- /dev/null +++ b/livekit/pkg/sfu/ccutils/trenddetector.go @@ -0,0 +1,274 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ccutils + +import ( + "fmt" + "time" + + "github.com/livekit/protocol/logger" + "go.uber.org/zap/zapcore" +) + +// ------------------------------------------------ + +type TrendDirection int + +const ( + TrendDirectionInconclusive TrendDirection = iota + TrendDirectionUpward + TrendDirectionDownward +) + +func (t TrendDirection) String() string { + switch t { + case TrendDirectionInconclusive: + return "INCONCLUSIVE" + case TrendDirectionUpward: + return "UPWARD" + case TrendDirectionDownward: + return "DOWNWARD" + default: + return fmt.Sprintf("%d", int(t)) + } +} + +// ------------------------------------------------ + +type trendDetectorNumber interface { + int64 | float64 +} + +// ------------------------------------------------ + +type trendDetectorSample[T trendDetectorNumber] struct { + value T + at time.Time +} + +type trendDetectorSampleElapsed[T trendDetectorNumber] struct { + value T + sinceFirst time.Duration +} + +func (t trendDetectorSampleElapsed[T]) MarshalLogObject(e zapcore.ObjectEncoder) error { + e.AddFloat64("value", float64(t.value)) + e.AddDuration("sinceFirst", t.sinceFirst) + return nil +} + +// ------------------------------------------------ + +type TrendDetectorConfig struct { + RequiredSamples int `yaml:"required_samples,omitempty"` + RequiredSamplesMin int `yaml:"required_samples_min,omitempty"` + DownwardTrendThreshold float64 `yaml:"downward_trend_threshold,omitempty"` + DownwardTrendMaxWait time.Duration `yaml:"downward_trend_max_wait,omitempty"` + CollapseThreshold time.Duration `yaml:"collapse_threshold,omitempty"` + ValidityWindow time.Duration `yaml:"validity_window,omitempty"` +} + +// ------------------------------------------------ + +type TrendDetectorParams struct { + Name string + Logger logger.Logger + Config TrendDetectorConfig +} + +type TrendDetector[T trendDetectorNumber] struct { + params TrendDetectorParams + + startTime time.Time + numSamples int + samples []trendDetectorSample[T] + lowestValue T + highestValue T + + direction TrendDirection +} + +func NewTrendDetector[T trendDetectorNumber](params TrendDetectorParams) *TrendDetector[T] { + return &TrendDetector[T]{ + params: params, + startTime: time.Now(), + direction: TrendDirectionInconclusive, + } +} + +func (t *TrendDetector[T]) Seed(value T) { + if len(t.samples) != 0 { + return + } + + t.samples = append(t.samples, trendDetectorSample[T]{value: value, at: time.Now()}) +} + +func (t *TrendDetector[T]) AddValue(value T) { + t.numSamples++ + if t.lowestValue == 0 || value < t.lowestValue { + t.lowestValue = value + } + if value > t.highestValue { + t.highestValue = value + } + + // Ignore duplicate values in collapse window. + // + // Bandwidth estimate is received periodically. If the estimate does not change, it will be repeated. + // When there is congestion, there are several estimates received with decreasing values. + // + // Using a sliding window, collapsing repeated values and waiting for falling trend to ensure that + // the reaction is not too fast, i. e. reacting to falling values too quick could mean a lot of re-allocation + // resulting in layer switches, key frames and more congestion. + // + // But, on the flip side, estimate could fall once or twice within a sliding window and stay there. + // In those cases, using a collapse window to record a value even if it is duplicate. By doing that, + // a trend could be detected eventually. It will be delayed, but that is fine with slow changing estimates. + var lastSample *trendDetectorSample[T] + if len(t.samples) != 0 { + lastSample = &t.samples[len(t.samples)-1] + } + if lastSample != nil && lastSample.value == value && t.params.Config.CollapseThreshold > 0 && time.Since(lastSample.at) < t.params.Config.CollapseThreshold { + return + } + + t.samples = append(t.samples, trendDetectorSample[T]{value: value, at: time.Now()}) + t.prune() + t.updateDirection() +} + +func (t *TrendDetector[T]) GetLowest() T { + return t.lowestValue +} + +func (t *TrendDetector[T]) GetHighest() T { + return t.highestValue +} + +func (t *TrendDetector[T]) GetDirection() TrendDirection { + return t.direction +} + +func (t *TrendDetector[T]) HasEnoughSamples() bool { + return t.numSamples >= t.params.Config.RequiredSamples +} + +func (t *TrendDetector[T]) MarshalLogObject(e zapcore.ObjectEncoder) error { + if t == nil { + return nil + } + + var samples []trendDetectorSampleElapsed[T] + if len(t.samples) > 0 { + firstTime := t.samples[0].at + for _, sample := range t.samples { + samples = append(samples, trendDetectorSampleElapsed[T]{sample.value, sample.at.Sub(firstTime)}) + } + } + + e.AddString("name", t.params.Name) + e.AddTime("startTime", t.startTime) + e.AddDuration("elapsed", time.Since(t.startTime)) + e.AddInt("numSamples", t.numSamples) + e.AddArray("samples", logger.ObjectSlice(samples)) + e.AddFloat64("lowestValue", float64(t.lowestValue)) + e.AddFloat64("highestValue", float64(t.highestValue)) + e.AddFloat64("kendallsTau", t.kendallsTau()) + e.AddString("direction", t.direction.String()) + return nil +} + +func (t *TrendDetector[T]) prune() { + // prune based on a few rules + + // 1. If there are more than required samples + if len(t.samples) > t.params.Config.RequiredSamples { + t.samples = t.samples[len(t.samples)-t.params.Config.RequiredSamples:] + } + + // 2. drop samples that are too old + if len(t.samples) != 0 && t.params.Config.ValidityWindow > 0 { + cutoffTime := time.Now().Add(-t.params.Config.ValidityWindow) + cutoffIndex := -1 + for i := 0; i < len(t.samples); i++ { + if t.samples[i].at.After(cutoffTime) { + cutoffIndex = i + break + } + } + if cutoffIndex >= 0 { + t.samples = t.samples[cutoffIndex:] + } + } + + // 3. collapse same values at the front to just the last of those samples + if len(t.samples) != 0 { + cutoffIndex := -1 + firstValue := t.samples[0].value + for i := 1; i < len(t.samples); i++ { + if t.samples[i].value != firstValue { + cutoffIndex = i - 1 + break + } + } + + if cutoffIndex >= 0 { + t.samples = t.samples[cutoffIndex:] + } else { + // all values are the same, just keep the last one + t.samples = t.samples[len(t.samples)-1:] + } + } +} + +func (t *TrendDetector[T]) updateDirection() { + if len(t.samples) < t.params.Config.RequiredSamplesMin { + t.direction = TrendDirectionInconclusive + return + } + + // using Kendall's Tau to find trend + kt := t.kendallsTau() + + t.direction = TrendDirectionInconclusive + switch { + case kt > 0 && len(t.samples) >= t.params.Config.RequiredSamples: + t.direction = TrendDirectionUpward + case kt < t.params.Config.DownwardTrendThreshold && (len(t.samples) >= t.params.Config.RequiredSamples || t.samples[len(t.samples)-1].at.Sub(t.samples[0].at) > t.params.Config.DownwardTrendMaxWait): + t.direction = TrendDirectionDownward + } +} + +func (t *TrendDetector[T]) kendallsTau() float64 { + concordantPairs := 0 + discordantPairs := 0 + + for i := 0; i < len(t.samples)-1; i++ { + for j := i + 1; j < len(t.samples); j++ { + if t.samples[i].value < t.samples[j].value { + concordantPairs++ + } else if t.samples[i].value > t.samples[j].value { + discordantPairs++ + } + } + } + + if (concordantPairs + discordantPairs) == 0 { + return 0.0 + } + + return (float64(concordantPairs) - float64(discordantPairs)) / (float64(concordantPairs) + float64(discordantPairs)) +} diff --git a/livekit/pkg/sfu/codecmunger/codecmunger.go b/livekit/pkg/sfu/codecmunger/codecmunger.go new file mode 100644 index 0000000..413af96 --- /dev/null +++ b/livekit/pkg/sfu/codecmunger/codecmunger.go @@ -0,0 +1,39 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package codecmunger + +import ( + "errors" + + "github.com/livekit/livekit-server/pkg/sfu/buffer" +) + +var ( + ErrNotVP8 = errors.New("not VP8") + ErrOutOfOrderVP8PictureIdCacheMiss = errors.New("out-of-order VP8 picture id not found in cache") + ErrFilteredVP8TemporalLayer = errors.New("filtered VP8 temporal layer") +) + +type CodecMunger interface { + GetState() any + SeedState(state any) + + SetLast(extPkt *buffer.ExtPacket) + UpdateOffsets(extPkt *buffer.ExtPacket) + + UpdateAndGet(extPkt *buffer.ExtPacket, snOutOfOrder bool, snHasGap bool, maxTemporal int32) (int, []byte, error) + + UpdateAndGetPadding(newPicture bool) ([]byte, error) +} diff --git a/livekit/pkg/sfu/codecmunger/null.go b/livekit/pkg/sfu/codecmunger/null.go new file mode 100644 index 0000000..616786a --- /dev/null +++ b/livekit/pkg/sfu/codecmunger/null.go @@ -0,0 +1,54 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package codecmunger + +import ( + "github.com/livekit/livekit-server/pkg/sfu/buffer" + "github.com/livekit/protocol/logger" +) + +type Null struct { + seededState any +} + +func NewNull(_logger logger.Logger) *Null { + return &Null{} +} + +func (n *Null) GetState() any { + return nil +} + +func (n *Null) SeedState(state any) { + n.seededState = state +} + +func (n *Null) GetSeededState() any { + return n.seededState +} + +func (n *Null) SetLast(_extPkt *buffer.ExtPacket) { +} + +func (n *Null) UpdateOffsets(_extPkt *buffer.ExtPacket) { +} + +func (n *Null) UpdateAndGet(_extPkt *buffer.ExtPacket, snOutOfOrder bool, snHasGap bool, maxTemporal int32) (int, []byte, error) { + return 0, nil, nil +} + +func (n *Null) UpdateAndGetPadding(newPicture bool) ([]byte, error) { + return nil, nil +} diff --git a/livekit/pkg/sfu/codecmunger/vp8.go b/livekit/pkg/sfu/codecmunger/vp8.go new file mode 100644 index 0000000..71eace0 --- /dev/null +++ b/livekit/pkg/sfu/codecmunger/vp8.go @@ -0,0 +1,490 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package codecmunger + +import ( + "github.com/elliotchance/orderedmap/v2" + + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + + "github.com/livekit/livekit-server/pkg/sfu/buffer" +) + +const ( + missingPictureIdsThreshold = 50 + droppedPictureIdsThreshold = 20 + exemptedPictureIdsThreshold = 20 +) + +// ----------------------------------------------------------- + +type VP8 struct { + logger logger.Logger + + pictureIdWrapHandler VP8PictureIdWrapHandler + extLastPictureId int32 + pictureIdOffset int32 + pictureIdUsed bool + lastTl0PicIdx uint8 + tl0PicIdxOffset uint8 + tl0PicIdxUsed bool + tidUsed bool + lastKeyIdx uint8 + keyIdxOffset uint8 + keyIdxUsed bool + + missingPictureIds *orderedmap.OrderedMap[int32, int32] + droppedPictureIds *orderedmap.OrderedMap[int32, bool] + exemptedPictureIds *orderedmap.OrderedMap[int32, bool] +} + +func NewVP8(logger logger.Logger) *VP8 { + return &VP8{ + logger: logger, + missingPictureIds: orderedmap.NewOrderedMap[int32, int32](), + droppedPictureIds: orderedmap.NewOrderedMap[int32, bool](), + exemptedPictureIds: orderedmap.NewOrderedMap[int32, bool](), + } +} + +func NewVP8FromOther(cm CodecMunger, logger logger.Logger) *VP8 { + v := NewVP8(logger) + switch cm := cm.(type) { + case *Null: + v.SeedState(cm.GetSeededState()) + case *VP8: + v.SeedState(cm.GetState()) + } + return v +} + +func (v *VP8) GetState() any { + return &livekit.VP8MungerState{ + ExtLastPictureId: v.extLastPictureId, + PictureIdUsed: v.pictureIdUsed, + LastTl0PicIdx: uint32(v.lastTl0PicIdx), + Tl0PicIdxUsed: v.tl0PicIdxUsed, + TidUsed: v.tidUsed, + LastKeyIdx: uint32(v.lastKeyIdx), + KeyIdxUsed: v.keyIdxUsed, + } +} + +func (v *VP8) SeedState(seed any) { + var state *livekit.VP8MungerState + switch cm := seed.(type) { + case *livekit.RTPForwarderState_Vp8Munger: + state = cm.Vp8Munger + case *livekit.VP8MungerState: + state = cm + } + if state != nil { + v.extLastPictureId = state.ExtLastPictureId + v.pictureIdUsed = state.PictureIdUsed + v.lastTl0PicIdx = uint8(state.LastTl0PicIdx) + v.tl0PicIdxUsed = state.Tl0PicIdxUsed + v.tidUsed = state.TidUsed + v.lastKeyIdx = uint8(state.LastKeyIdx) + v.keyIdxUsed = state.KeyIdxUsed + } +} + +func (v *VP8) SetLast(extPkt *buffer.ExtPacket) { + vp8, ok := extPkt.Payload.(buffer.VP8) + if !ok { + return + } + + v.pictureIdUsed = vp8.I + if v.pictureIdUsed { + v.pictureIdWrapHandler.Init(int32(vp8.PictureID)-1, vp8.M) + v.extLastPictureId = int32(vp8.PictureID) + } + + v.tl0PicIdxUsed = vp8.L + if v.tl0PicIdxUsed { + v.lastTl0PicIdx = vp8.TL0PICIDX + } + + v.tidUsed = vp8.T + + v.keyIdxUsed = vp8.K + if v.keyIdxUsed { + v.lastKeyIdx = vp8.KEYIDX + } +} + +func (v *VP8) UpdateOffsets(extPkt *buffer.ExtPacket) { + vp8, ok := extPkt.Payload.(buffer.VP8) + if !ok { + return + } + + if v.pictureIdUsed { + v.pictureIdWrapHandler.Init(int32(vp8.PictureID)-1, vp8.M) + v.pictureIdOffset = int32(vp8.PictureID) - v.extLastPictureId - 1 + } + + if v.tl0PicIdxUsed { + v.tl0PicIdxOffset = vp8.TL0PICIDX - v.lastTl0PicIdx - 1 + } + + if v.keyIdxUsed { + v.keyIdxOffset = (vp8.KEYIDX - v.lastKeyIdx - 1) & 0x1f + } + + // clear picture id caches on layer switch + v.missingPictureIds = orderedmap.NewOrderedMap[int32, int32]() + v.droppedPictureIds = orderedmap.NewOrderedMap[int32, bool]() + v.exemptedPictureIds = orderedmap.NewOrderedMap[int32, bool]() +} + +func (v *VP8) UpdateAndGet(extPkt *buffer.ExtPacket, snOutOfOrder bool, snHasGap bool, maxTemporalLayer int32) (int, []byte, error) { + vp8, ok := extPkt.Payload.(buffer.VP8) + if !ok { + return 0, nil, ErrNotVP8 + } + + extPictureId := v.pictureIdWrapHandler.Unwrap(vp8.PictureID, vp8.M) + + // if out-of-order, look up missing picture id cache + if snOutOfOrder { + pictureIdOffset, ok := v.missingPictureIds.Get(extPictureId) + if !ok { + return 0, nil, ErrOutOfOrderVP8PictureIdCacheMiss + } + + // the out-of-order picture id cannot be deleted from the cache + // as there could more than one packet in a picture and more + // than one packet of a picture could come out-of-order. + // To prevent picture id cache from growing, it is truncated + // when it reaches a certain size. + + mungedPictureId := uint16((extPictureId - pictureIdOffset) & 0x7fff) + vp8Packet := &buffer.VP8{ + FirstByte: vp8.FirstByte, + I: vp8.I, + M: mungedPictureId > 127, + PictureID: mungedPictureId, + L: vp8.L, + TL0PICIDX: vp8.TL0PICIDX - v.tl0PicIdxOffset, + T: vp8.T, + TID: vp8.TID, + Y: vp8.Y, + K: vp8.K, + KEYIDX: vp8.KEYIDX - v.keyIdxOffset, + IsKeyFrame: vp8.IsKeyFrame, + HeaderSize: vp8.HeaderSize + buffer.VPxPictureIdSizeDiff(mungedPictureId > 127, vp8.M), + } + vp8HeaderBytes, err := vp8Packet.Marshal() + if err != nil { + return 0, nil, err + } + return vp8.HeaderSize, vp8HeaderBytes, nil + } + + prevMaxPictureId := v.pictureIdWrapHandler.MaxPictureId() + v.pictureIdWrapHandler.UpdateMaxPictureId(extPictureId, vp8.M) + + // if there is a gap in sequence number, record possible pictures that + // the missing packets can belong to in missing picture id cache. + // The missing picture cache should contain the previous picture id + // and the current picture id and all the intervening pictures. + // This is to handle a scenario as follows + // o Packet 10 -> Picture ID 10 + // o Packet 11 -> missing + // o Packet 12 -> Picture ID 11 + // In this case, Packet 11 could belong to either Picture ID 10 (last packet of that picture) + // or Picture ID 11 (first packet of the current picture). Although in this simple case, + // it is possible to deduce that (for example by looking at previous packet's RTP marker + // and check if that was the last packet of Picture 10), it could get complicated when + // the gap is larger. + if snHasGap { + for lostPictureId := prevMaxPictureId; lostPictureId <= extPictureId; lostPictureId++ { + // Record missing only if picture id was not dropped. This is to avoid a subsequent packet of dropped frame going through. + // A sequence like this + // o Packet 10 - Picture 11 - TID that should be dropped + // o Packet 11 - missing - belongs to Picture 11 still + // o Packet 12 - Picture 12 - will be reported as GAP, so missing picture id mapping will be set up for Picture 11 also. + // o Next packet - Packet 11 - this will use the wrong offset from missing pictures cache + _, ok := v.droppedPictureIds.Get(lostPictureId) + if !ok { + v.missingPictureIds.Set(lostPictureId, v.pictureIdOffset) + } + } + + // trim cache if necessary + for v.missingPictureIds.Len() > missingPictureIdsThreshold { + el := v.missingPictureIds.Front() + v.missingPictureIds.Delete(el.Key) + } + + // if there is a gap, packet is forwarded irrespective of temporal layer as it cannot be determined + // which layer the missing packets belong to. A layer could have multiple packets. So, keep track + // of pictures that are forwarded even though they will be filtered out based on temporal layer + // requirements. That allows forwarding of the complete picture. + if extPkt.Temporal > maxTemporalLayer { + v.exemptedPictureIds.Set(extPictureId, true) + // trim cache if necessary + for v.exemptedPictureIds.Len() > exemptedPictureIdsThreshold { + el := v.exemptedPictureIds.Front() + v.exemptedPictureIds.Delete(el.Key) + } + } + } else { + if extPkt.Temporal > maxTemporalLayer { + // drop only if not exempted + _, ok := v.exemptedPictureIds.Get(extPictureId) + if !ok { + // adjust only once per picture as a picture could have multiple packets + if vp8.I && prevMaxPictureId != extPictureId { + // keep track of dropped picture ids so that they do not get into the missing picture cache + v.droppedPictureIds.Set(extPictureId, true) + // trim cache if necessary + for v.droppedPictureIds.Len() > droppedPictureIdsThreshold { + el := v.droppedPictureIds.Front() + v.droppedPictureIds.Delete(el.Key) + } + + v.pictureIdOffset += 1 + } + return 0, nil, ErrFilteredVP8TemporalLayer + } + } + } + + // in-order incoming sequence number, may or may not be contiguous. + // In the case of loss (i.e. incoming sequence number is not contiguous), + // forward even if it is a filtered layer. With temporal scalability, + // it is unclear if the current packet should be dropped if it is not + // contiguous. Hence, forward anything that is not contiguous. + // Reference: http://www.rtcbits.com/2017/04/howto-implement-temporal-scalability.html + extMungedPictureId := extPictureId - v.pictureIdOffset + mungedPictureId := uint16(extMungedPictureId & 0x7fff) + mungedTl0PicIdx := vp8.TL0PICIDX - v.tl0PicIdxOffset + mungedKeyIdx := (vp8.KEYIDX - v.keyIdxOffset) & 0x1f + + v.extLastPictureId = extMungedPictureId + v.lastTl0PicIdx = mungedTl0PicIdx + v.lastKeyIdx = mungedKeyIdx + + vp8Packet := &buffer.VP8{ + FirstByte: vp8.FirstByte, + I: vp8.I, + M: mungedPictureId > 127, + PictureID: mungedPictureId, + L: vp8.L, + TL0PICIDX: mungedTl0PicIdx, + T: vp8.T, + TID: vp8.TID, + Y: vp8.Y, + K: vp8.K, + KEYIDX: mungedKeyIdx, + IsKeyFrame: vp8.IsKeyFrame, + HeaderSize: vp8.HeaderSize + buffer.VPxPictureIdSizeDiff(mungedPictureId > 127, vp8.M), + } + vp8HeaderBytes, err := vp8Packet.Marshal() + if err != nil { + return 0, nil, err + } + return vp8.HeaderSize, vp8HeaderBytes, nil +} + +func (v *VP8) UpdateAndGetPadding(newPicture bool) ([]byte, error) { + offset := 0 + if newPicture { + offset = 1 + } + + headerSize := 1 + if v.pictureIdUsed || v.tl0PicIdxUsed || v.tidUsed || v.keyIdxUsed { + headerSize += 1 + } + + extPictureId := v.extLastPictureId + if v.pictureIdUsed { + extPictureId = v.extLastPictureId + int32(offset) + v.extLastPictureId = extPictureId + v.pictureIdOffset -= int32(offset) + if (extPictureId & 0x7fff) > 127 { + headerSize += 2 + } else { + headerSize += 1 + } + } + pictureId := uint16(extPictureId & 0x7fff) + + tl0PicIdx := uint8(0) + if v.tl0PicIdxUsed { + tl0PicIdx = v.lastTl0PicIdx + uint8(offset) + v.lastTl0PicIdx = tl0PicIdx + v.tl0PicIdxOffset -= uint8(offset) + headerSize += 1 + } + + if v.tidUsed || v.keyIdxUsed { + headerSize += 1 + } + + keyIdx := uint8(0) + if v.keyIdxUsed { + keyIdx = (v.lastKeyIdx + uint8(offset)) & 0x1f + v.lastKeyIdx = keyIdx + v.keyIdxOffset -= uint8(offset) + } + + vp8Packet := &buffer.VP8{ + FirstByte: 0x10, // partition 0, start of VP8 Partition, reference frame + I: v.pictureIdUsed, + M: pictureId > 127, + PictureID: pictureId, + L: v.tl0PicIdxUsed, + TL0PICIDX: tl0PicIdx, + T: v.tidUsed, + TID: 0, + Y: true, + K: v.keyIdxUsed, + KEYIDX: keyIdx, + IsKeyFrame: true, + HeaderSize: headerSize, + } + return vp8Packet.Marshal() +} + +// for testing only +func (v *VP8) PictureIdOffset(extPictureId int32) (int32, bool) { + return v.missingPictureIds.Get(extPictureId) +} + +// ----------------------------- + +// VP8PictureIdWrapHandler +func isWrapping7Bit(val1 int32, val2 int32) bool { + return val2 < val1 && (val1-val2) > (1<<6) +} + +func isWrapping15Bit(val1 int32, val2 int32) bool { + return val2 < val1 && (val1-val2) > (1<<14) +} + +type VP8PictureIdWrapHandler struct { + maxPictureId int32 + maxMBit bool + totalWrap int32 + lastWrap int32 +} + +func (v *VP8PictureIdWrapHandler) Init(extPictureId int32, mBit bool) { + v.maxPictureId = extPictureId + v.maxMBit = mBit + v.totalWrap = 0 + v.lastWrap = 0 +} + +func (v *VP8PictureIdWrapHandler) MaxPictureId() int32 { + return v.maxPictureId +} + +// unwrap picture id and update the maxPictureId. return unwrapped value +func (v *VP8PictureIdWrapHandler) Unwrap(pictureId uint16, mBit bool) int32 { + // + // VP8 Picture ID is specified very flexibly. + // + // Reference: https://datatracker.ietf.org/doc/html/draft-ietf-payload-vp8 + // + // Quoting from the RFC + // ---------------------------- + // PictureID: 7 or 15 bits (shown left and right, respectively, in + // Figure 2) not including the M bit. This is a running index of + // the frames, which MAY start at a random value, MUST increase by + // 1 for each subsequent frame, and MUST wrap to 0 after reaching + // the maximum ID (all bits set). The 7 or 15 bits of the + // PictureID go from most significant to least significant, + // beginning with the first bit after the M bit. The sender + // chooses a 7 or 15 bit index and sets the M bit accordingly. + // The receiver MUST NOT assume that the number of bits in + // PictureID stay the same through the session. Having sent a + // 7-bit PictureID with all bits set to 1, the sender may either + // wrap the PictureID to 0, or extend to 15 bits and continue + // incrementing + // ---------------------------- + // + // While in practice, senders may not switch between modes indiscriminately, + // it is possible that small picture ids are sent in 7 bits and then switch + // to 15 bits. But, to ensure correctness, this code keeps track of how much + // quantity has wrapped and uses that to figure out if the incoming picture id + // is newer OR out-of-order. + // + maxPictureId := v.maxPictureId + // maxPictureId can be -1 at the start + if maxPictureId > 0 { + if v.maxMBit { + maxPictureId = v.maxPictureId & 0x7fff + } else { + maxPictureId = v.maxPictureId & 0x7f + } + } + + var newPictureId int32 + if mBit { + newPictureId = int32(pictureId & 0x7fff) + } else { + newPictureId = int32(pictureId & 0x7f) + } + + // + // if the new picture id is too far ahead of max, i.e. more than half of last wrap, + // it is out-of-order, unwrap backwards + // + if v.totalWrap > 0 { + if (v.maxPictureId + (v.lastWrap >> 1)) < (newPictureId + v.totalWrap) { + return newPictureId + v.totalWrap - v.lastWrap + } + } + + // + // check for wrap around based on mode of previous picture id. + // There are three cases here + // 1. Wrapping from 15-bit -> 8-bit (32767 -> 0) + // 2. Wrapping from 15-bit -> 15-bit (32767 -> 0) + // 3. Wrapping from 8-bit -> 8-bit (127 -> 0) + // In all cases, looking at the mode of previous picture id will + // ensure that we are calculating the wrap properly. + // + wrap := int32(0) + if v.maxMBit { + if isWrapping15Bit(maxPictureId, newPictureId) { + wrap = 1 << 15 + } + } else { + if isWrapping7Bit(maxPictureId, newPictureId) { + wrap = 1 << 7 + } + } + + v.totalWrap += wrap + if wrap != 0 { + v.lastWrap = wrap + } + newPictureId += v.totalWrap + + return newPictureId +} + +func (v *VP8PictureIdWrapHandler) UpdateMaxPictureId(extPictureId int32, mBit bool) { + v.maxPictureId = extPictureId + v.maxMBit = mBit +} diff --git a/livekit/pkg/sfu/codecmunger/vp8_test.go b/livekit/pkg/sfu/codecmunger/vp8_test.go new file mode 100644 index 0000000..31d33dd --- /dev/null +++ b/livekit/pkg/sfu/codecmunger/vp8_test.go @@ -0,0 +1,532 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package codecmunger + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/livekit/protocol/logger" + + "github.com/livekit/livekit-server/pkg/sfu/buffer" + "github.com/livekit/livekit-server/pkg/sfu/testutils" +) + +func compare(expected *VP8, actual *VP8) bool { + return reflect.DeepEqual(expected.pictureIdWrapHandler, actual.pictureIdWrapHandler) && + expected.extLastPictureId == actual.extLastPictureId && + expected.pictureIdOffset == actual.pictureIdOffset && + expected.pictureIdUsed == actual.pictureIdUsed && + expected.lastTl0PicIdx == actual.lastTl0PicIdx && + expected.tl0PicIdxOffset == actual.tl0PicIdxOffset && + expected.tl0PicIdxUsed == actual.tl0PicIdxUsed && + expected.tidUsed == actual.tidUsed && + expected.lastKeyIdx == actual.lastKeyIdx && + expected.keyIdxOffset == actual.keyIdxOffset && + expected.keyIdxUsed == actual.keyIdxUsed +} + +func newVP8() *VP8 { + return NewVP8(logger.GetLogger()) +} + +func TestSetLast(t *testing.T) { + v := newVP8() + + params := &testutils.TestExtPacketParams{ + SequenceNumber: 23333, + Timestamp: 0xabcdef, + SSRC: 0x12345678, + } + vp8 := &buffer.VP8{ + FirstByte: 25, + I: true, + M: true, + PictureID: 13467, + L: true, + TL0PICIDX: 233, + T: true, + TID: 13, + Y: true, + K: true, + KEYIDX: 23, + HeaderSize: 6, + IsKeyFrame: true, + } + extPkt, err := testutils.GetTestExtPacketVP8(params, vp8) + require.NoError(t, err) + require.NotNil(t, extPkt) + + expectedVP8 := VP8{ + pictureIdWrapHandler: VP8PictureIdWrapHandler{ + maxPictureId: 13466, + maxMBit: true, + totalWrap: 0, + lastWrap: 0, + }, + extLastPictureId: 13467, + pictureIdOffset: 0, + pictureIdUsed: true, + lastTl0PicIdx: 233, + tl0PicIdxOffset: 0, + tl0PicIdxUsed: true, + tidUsed: true, + lastKeyIdx: 23, + keyIdxOffset: 0, + keyIdxUsed: true, + } + + v.SetLast(extPkt) + require.True(t, compare(&expectedVP8, v)) +} + +func TestUpdateOffsets(t *testing.T) { + v := newVP8() + + params := &testutils.TestExtPacketParams{ + SequenceNumber: 23333, + Timestamp: 0xabcdef, + SSRC: 0x12345678, + } + vp8 := &buffer.VP8{ + FirstByte: 25, + I: true, + M: true, + PictureID: 13467, + L: true, + TL0PICIDX: 233, + T: true, + TID: 13, + Y: true, + K: true, + KEYIDX: 23, + HeaderSize: 6, + IsKeyFrame: true, + } + extPkt, _ := testutils.GetTestExtPacketVP8(params, vp8) + v.SetLast(extPkt) + + params = &testutils.TestExtPacketParams{ + SequenceNumber: 56789, + Timestamp: 0xabcdef, + SSRC: 0x87654321, + } + vp8 = &buffer.VP8{ + FirstByte: 25, + I: true, + M: true, + PictureID: 345, + L: true, + TL0PICIDX: 12, + T: true, + TID: 13, + Y: true, + K: true, + KEYIDX: 4, + HeaderSize: 6, + IsKeyFrame: true, + } + extPkt, _ = testutils.GetTestExtPacketVP8(params, vp8) + v.UpdateOffsets(extPkt) + + expectedVP8 := VP8{ + pictureIdWrapHandler: VP8PictureIdWrapHandler{ + maxPictureId: 344, + maxMBit: true, + totalWrap: 0, + lastWrap: 0, + }, + extLastPictureId: 13467, + pictureIdOffset: 345 - 13467 - 1, + pictureIdUsed: true, + lastTl0PicIdx: 233, + tl0PicIdxOffset: (12 - 233 - 1) & 0xff, + tl0PicIdxUsed: true, + tidUsed: true, + lastKeyIdx: 23, + keyIdxOffset: (4 - 23 - 1) & 0x1f, + keyIdxUsed: true, + } + require.True(t, compare(&expectedVP8, v)) +} + +func TestOutOfOrderPictureId(t *testing.T) { + v := newVP8() + + params := &testutils.TestExtPacketParams{ + SequenceNumber: 23333, + Timestamp: 0xabcdef, + SSRC: 0x12345678, + } + vp8 := &buffer.VP8{ + FirstByte: 25, + I: true, + M: true, + PictureID: 13467, + L: true, + TL0PICIDX: 233, + T: true, + TID: 1, + Y: true, + K: true, + KEYIDX: 23, + HeaderSize: 6, + IsKeyFrame: true, + } + extPkt, _ := testutils.GetTestExtPacketVP8(params, vp8) + v.SetLast(extPkt) + v.UpdateAndGet(extPkt, false, false, 2) + + // out-of-order sequence number not in the missing picture id cache + vp8.PictureID = 13466 + extPkt, _ = testutils.GetTestExtPacketVP8(params, vp8) + + nIn, buf, err := v.UpdateAndGet(extPkt, true, false, 2) + require.Error(t, err) + require.ErrorIs(t, err, ErrOutOfOrderVP8PictureIdCacheMiss) + require.Equal(t, 0, nIn) + require.Nil(t, buf) + + // create a hole in picture id + vp8.PictureID = 13469 + extPkt, _ = testutils.GetTestExtPacketVP8(params, vp8) + + expectedVP8 := &buffer.VP8{ + FirstByte: 25, + I: true, + M: true, + PictureID: 13469, + L: true, + TL0PICIDX: 233, + T: true, + TID: 1, + Y: true, + K: true, + KEYIDX: 23, + HeaderSize: 6, + IsKeyFrame: true, + } + marshalledVP8, err := expectedVP8.Marshal() + require.NoError(t, err) + nIn, buf, err = v.UpdateAndGet(extPkt, false, true, 2) + require.NoError(t, err) + require.Equal(t, 6, nIn) + require.Equal(t, marshalledVP8, buf) + + // all three, the last, the current and the in-between should have been added to missing picture id cache + value, ok := v.PictureIdOffset(13467) + require.True(t, ok) + require.EqualValues(t, 0, value) + + value, ok = v.PictureIdOffset(13468) + require.True(t, ok) + require.EqualValues(t, 0, value) + + value, ok = v.PictureIdOffset(13469) + require.True(t, ok) + require.EqualValues(t, 0, value) + + // out-of-order sequence number should be in the missing picture id cache + vp8.PictureID = 13468 + extPkt, _ = testutils.GetTestExtPacketVP8(params, vp8) + + expectedVP8 = &buffer.VP8{ + FirstByte: 25, + I: true, + M: true, + PictureID: 13468, + L: true, + TL0PICIDX: 233, + T: true, + TID: 1, + Y: true, + K: true, + KEYIDX: 23, + HeaderSize: 6, + IsKeyFrame: true, + } + marshalledVP8, err = expectedVP8.Marshal() + require.NoError(t, err) + nIn, buf, err = v.UpdateAndGet(extPkt, true, false, 2) + require.NoError(t, err) + require.Equal(t, 6, nIn) + require.Equal(t, marshalledVP8, buf) +} + +func TestTemporalLayerFiltering(t *testing.T) { + v := newVP8() + + params := &testutils.TestExtPacketParams{ + SequenceNumber: 23333, + Timestamp: 0xabcdef, + SSRC: 0x12345678, + } + vp8 := &buffer.VP8{ + FirstByte: 25, + I: true, + M: true, + PictureID: 13467, + L: true, + TL0PICIDX: 233, + T: true, + TID: 1, + Y: true, + K: true, + KEYIDX: 23, + HeaderSize: 6, + IsKeyFrame: true, + } + extPkt, _ := testutils.GetTestExtPacketVP8(params, vp8) + v.SetLast(extPkt) + + // translate + nIn, buf, err := v.UpdateAndGet(extPkt, false, false, 0) + require.Error(t, err) + require.ErrorIs(t, err, ErrFilteredVP8TemporalLayer) + require.Equal(t, 0, nIn) + require.Nil(t, buf) + dropped, _ := v.droppedPictureIds.Get(13467) + require.True(t, dropped) + require.EqualValues(t, 1, v.pictureIdOffset) + + // another packet with the same picture id. + // It should be dropped, but offset should not be updated. + params.SequenceNumber = 23334 + extPkt, _ = testutils.GetTestExtPacketVP8(params, vp8) + + nIn, buf, err = v.UpdateAndGet(extPkt, false, false, 0) + require.Error(t, err) + require.ErrorIs(t, err, ErrFilteredVP8TemporalLayer) + require.Equal(t, 0, nIn) + require.Nil(t, buf) + dropped, _ = v.droppedPictureIds.Get(13467) + require.True(t, dropped) + require.EqualValues(t, 1, v.pictureIdOffset) + + // another packet with the same picture id, but a gap in sequence number. + // It should be dropped, but offset should not be updated. + params.SequenceNumber = 23337 + extPkt, _ = testutils.GetTestExtPacketVP8(params, vp8) + + nIn, buf, err = v.UpdateAndGet(extPkt, false, false, 0) + require.Error(t, err) + require.ErrorIs(t, err, ErrFilteredVP8TemporalLayer) + require.Equal(t, 0, nIn) + require.Nil(t, buf) + dropped, _ = v.droppedPictureIds.Get(13467) + require.True(t, dropped) + require.EqualValues(t, 1, v.pictureIdOffset) +} + +func TestGapInSequenceNumberSamePicture(t *testing.T) { + v := newVP8() + + params := &testutils.TestExtPacketParams{ + SequenceNumber: 65533, + Timestamp: 0xabcdef, + SSRC: 0x12345678, + PayloadSize: 33, + } + vp8 := &buffer.VP8{ + FirstByte: 25, + I: true, + M: true, + PictureID: 13467, + L: true, + TL0PICIDX: 233, + T: true, + TID: 1, + Y: true, + K: true, + KEYIDX: 23, + HeaderSize: 6, + IsKeyFrame: true, + } + extPkt, _ := testutils.GetTestExtPacketVP8(params, vp8) + v.SetLast(extPkt) + + expectedVP8 := &buffer.VP8{ + FirstByte: 25, + I: true, + M: true, + PictureID: 13467, + L: true, + TL0PICIDX: 233, + T: true, + TID: 1, + Y: true, + K: true, + KEYIDX: 23, + HeaderSize: 6, + IsKeyFrame: true, + } + marshalledVP8, err := expectedVP8.Marshal() + require.NoError(t, err) + nIn, buf, err := v.UpdateAndGet(extPkt, false, false, 2) + require.NoError(t, err) + require.Equal(t, 6, nIn) + require.Equal(t, marshalledVP8, buf) + + // telling there is a gap in sequence number will add pictures to missing picture cache + expectedVP8 = &buffer.VP8{ + FirstByte: 25, + I: true, + M: true, + PictureID: 13467, + L: true, + TL0PICIDX: 233, + T: true, + TID: 1, + Y: true, + K: true, + KEYIDX: 23, + HeaderSize: 6, + IsKeyFrame: true, + } + marshalledVP8, err = expectedVP8.Marshal() + require.NoError(t, err) + nIn, buf, err = v.UpdateAndGet(extPkt, false, true, 2) + require.NoError(t, err) + require.Equal(t, 6, nIn) + require.Equal(t, marshalledVP8, buf) + + value, ok := v.PictureIdOffset(13467) + require.True(t, ok) + require.EqualValues(t, 0, value) +} + +func TestUpdateAndGetPadding(t *testing.T) { + v := newVP8() + + params := &testutils.TestExtPacketParams{ + SequenceNumber: 23333, + Timestamp: 0xabcdef, + SSRC: 0x12345678, + PayloadSize: 20, + } + vp8 := &buffer.VP8{ + FirstByte: 25, + I: true, + M: true, + PictureID: 13467, + L: true, + TL0PICIDX: 233, + T: true, + TID: 13, + Y: true, + K: true, + KEYIDX: 23, + HeaderSize: 6, + IsKeyFrame: true, + } + extPkt, _ := testutils.GetTestExtPacketVP8(params, vp8) + + v.SetLast(extPkt) + + // getting padding with repeat of last picture + buf, err := v.UpdateAndGetPadding(false) + require.NoError(t, err) + expectedVP8 := buffer.VP8{ + FirstByte: 16, + I: true, + M: true, + PictureID: 13467, + L: true, + TL0PICIDX: 233, + T: true, + TID: 0, + Y: true, + K: true, + KEYIDX: 23, + HeaderSize: 6, + IsKeyFrame: true, + } + marshalledVP8, err := expectedVP8.Marshal() + require.NoError(t, err) + require.Equal(t, marshalledVP8, buf) + + // getting padding with new picture + buf, err = v.UpdateAndGetPadding(true) + require.NoError(t, err) + expectedVP8 = buffer.VP8{ + FirstByte: 16, + I: true, + M: true, + PictureID: 13468, + L: true, + TL0PICIDX: 234, + T: true, + TID: 0, + Y: true, + K: true, + KEYIDX: 24, + HeaderSize: 6, + IsKeyFrame: true, + } + marshalledVP8, err = expectedVP8.Marshal() + require.NoError(t, err) + require.Equal(t, marshalledVP8, buf) +} + +func TestVP8PictureIdWrapHandler(t *testing.T) { + v := &VP8PictureIdWrapHandler{} + + v.Init(109, false) + require.Equal(t, int32(109), v.MaxPictureId()) + require.False(t, v.maxMBit) + + v.UpdateMaxPictureId(109350, true) + require.Equal(t, int32(109350), v.MaxPictureId()) + require.True(t, v.maxMBit) + + // start with something close to the 15-bit wrap around point + v.Init(32766, true) + + // out-of-order, do not wrap + extPictureId := v.Unwrap(32750, true) + require.Equal(t, int32(32750), extPictureId) + require.Equal(t, int32(0), v.totalWrap) + require.Equal(t, int32(0), v.lastWrap) + + // wrap at 15-bits + extPictureId = v.Unwrap(5, false) + require.Equal(t, int32(32773), extPictureId) // 15-bit wrap at 32768 + 5 = 32773 + require.Equal(t, int32(32768), v.totalWrap) + require.Equal(t, int32(32768), v.lastWrap) + + // set things near 7-bit wrap point + v.UpdateMaxPictureId(32893, false) // 32768 + 125 + + // wrap at 7-bits + extPictureId = v.Unwrap(5, true) + require.Equal(t, int32(32901), extPictureId) // 15-bit wrap at 32768 + 7-bit wrap at 128 + 5 = 32901 + require.Equal(t, int32(32896), v.totalWrap) // one 15-bit wrap + one 7-bit wrap + require.Equal(t, int32(128), v.lastWrap) + + // a new picture in 7-bit mode much with a gap in between. + // A big enough gap which would have been treated as out-of-order in 7-bit mode. + v.UpdateMaxPictureId(32901, false) + extPictureId = v.Unwrap(73, false) + require.Equal(t, int32(32841), extPictureId) // 15-bit wrap at 32768 + 73 = 32841 + + // a new picture in 15-bit mode much with a gap in between. + // A big enough gap which would have been treated as out-of-order in 7-bit mode. + v.UpdateMaxPictureId(32901, true) + v.lastWrap = int32(32768) + extPictureId = v.Unwrap(73, false) + require.Equal(t, int32(32969), extPictureId) // 15-bit wrap at 32768 + 7-bit wrap at 128 + 73 = 32969 +} diff --git a/livekit/pkg/sfu/connectionquality/connectionstats.go b/livekit/pkg/sfu/connectionquality/connectionstats.go new file mode 100644 index 0000000..0668b9c --- /dev/null +++ b/livekit/pkg/sfu/connectionquality/connectionstats.go @@ -0,0 +1,519 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package connectionquality + +import ( + "sync" + "time" + + "github.com/frostbyte73/core" + "go.uber.org/atomic" + "google.golang.org/protobuf/types/known/timestamppb" + + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + + "github.com/livekit/livekit-server/pkg/sfu/buffer" + "github.com/livekit/livekit-server/pkg/sfu/mime" + "github.com/livekit/livekit-server/pkg/sfu/rtpstats" +) + +const ( + UpdateInterval = 5 * time.Second + noReceiverReportTooLongThreshold = 30 * time.Second +) + +type ConnectionStatsReceiverProvider interface { + GetDeltaStats() map[uint32]*buffer.StreamStatsWithLayers + GetLastSenderReportTime() time.Time +} + +type ConnectionStatsSenderProvider interface { + GetDeltaStatsSender() map[uint32]*buffer.StreamStatsWithLayers + GetPrimaryStreamLastReceiverReportTime() time.Time + GetPrimaryStreamPacketsSent() uint64 +} + +type ConnectionStatsParams struct { + UpdateInterval time.Duration + IncludeRTT bool + IncludeJitter bool + EnableBitrateScore bool + ReceiverProvider ConnectionStatsReceiverProvider + SenderProvider ConnectionStatsSenderProvider + Logger logger.Logger +} + +type ConnectionStats struct { + params ConnectionStatsParams + + codecMimeType atomic.Value // mime.MimeType + + isStarted atomic.Bool + isVideo atomic.Bool + + onStatsUpdate func(cs *ConnectionStats, stat *livekit.AnalyticsStat) + + lock sync.RWMutex + packetsSent uint64 + streamingStartedAt time.Time + + scorer *qualityScorer + + done core.Fuse +} + +func NewConnectionStats(params ConnectionStatsParams) *ConnectionStats { + return &ConnectionStats{ + params: params, + scorer: newQualityScorer(qualityScorerParams{ + IncludeRTT: params.IncludeRTT, + IncludeJitter: params.IncludeJitter, + EnableBitrateScore: params.EnableBitrateScore, + Logger: params.Logger, + }), + } +} + +func (cs *ConnectionStats) StartAt(codecMimeType mime.MimeType, isFECEnabled bool, at time.Time) { + if cs.isStarted.Swap(true) { + return + } + + cs.isVideo.Store(mime.IsMimeTypeVideo(codecMimeType)) + cs.codecMimeType.Store(codecMimeType) + cs.scorer.StartAt(getPacketLossWeight(codecMimeType, isFECEnabled), at) + + go cs.updateStatsWorker() +} + +func (cs *ConnectionStats) Start(codecMimeType mime.MimeType, isFECEnabled bool) { + cs.StartAt(codecMimeType, isFECEnabled, time.Now()) +} + +func (cs *ConnectionStats) Close() { + cs.done.Break() +} + +func (cs *ConnectionStats) UpdateCodec(codecMimeType mime.MimeType, isFECEnabled bool) { + cs.isVideo.Store(mime.IsMimeTypeVideo(codecMimeType)) + cs.codecMimeType.Store(codecMimeType) + cs.scorer.UpdatePacketLossWeight(getPacketLossWeight(codecMimeType, isFECEnabled)) +} + +func (cs *ConnectionStats) OnStatsUpdate(fn func(cs *ConnectionStats, stat *livekit.AnalyticsStat)) { + cs.onStatsUpdate = fn +} + +func (cs *ConnectionStats) UpdateMuteAt(isMuted bool, at time.Time) { + if cs.done.IsBroken() { + return + } + + cs.scorer.UpdateMuteAt(isMuted, at) +} + +func (cs *ConnectionStats) UpdateMute(isMuted bool) { + if cs.done.IsBroken() { + return + } + + cs.scorer.UpdateMute(isMuted) +} + +func (cs *ConnectionStats) AddBitrateTransitionAt(bitrate int64, at time.Time) { + if cs.done.IsBroken() { + return + } + + cs.scorer.AddBitrateTransitionAt(bitrate, at) +} + +func (cs *ConnectionStats) AddBitrateTransition(bitrate int64) { + if cs.done.IsBroken() { + return + } + + cs.scorer.AddBitrateTransition(bitrate) +} + +func (cs *ConnectionStats) UpdateLayerMuteAt(isMuted bool, at time.Time) { + if cs.done.IsBroken() { + return + } + + cs.scorer.UpdateLayerMuteAt(isMuted, at) +} + +func (cs *ConnectionStats) UpdateLayerMute(isMuted bool) { + if cs.done.IsBroken() { + return + } + + cs.scorer.UpdateLayerMute(isMuted) +} + +func (cs *ConnectionStats) UpdatePauseAt(isPaused bool, at time.Time) { + if cs.done.IsBroken() { + return + } + + cs.scorer.UpdatePauseAt(isPaused, at) +} + +func (cs *ConnectionStats) UpdatePause(isPaused bool) { + if cs.done.IsBroken() { + return + } + + cs.scorer.UpdatePause(isPaused) +} + +func (cs *ConnectionStats) AddLayerTransitionAt(distance float64, at time.Time) { + if cs.done.IsBroken() { + return + } + + cs.scorer.AddLayerTransitionAt(distance, at) +} + +func (cs *ConnectionStats) AddLayerTransition(distance float64) { + if cs.done.IsBroken() { + return + } + + cs.scorer.AddLayerTransition(distance) +} + +func (cs *ConnectionStats) GetScoreAndQuality() (float32, livekit.ConnectionQuality) { + return cs.scorer.GetMOSAndQuality() +} + +func (cs *ConnectionStats) updateScoreWithAggregate(agg *rtpstats.RTPDeltaInfo, lastRTCPAt time.Time, at time.Time) float32 { + var stat windowStat + if agg != nil { + stat.startedAt = agg.StartTime + stat.duration = agg.EndTime.Sub(agg.StartTime) + stat.packets = agg.Packets + stat.packetsPadding = agg.PacketsPadding + stat.packetsLost = agg.PacketsLost + stat.packetsMissing = agg.PacketsMissing + stat.packetsOutOfOrder = agg.PacketsOutOfOrder + stat.bytes = agg.Bytes - agg.HeaderBytes // only use media payload size + stat.rttMax = agg.RttMax + stat.jitterMax = agg.JitterMax + + stat.lastRTCPAt = lastRTCPAt + } + if at.IsZero() { + cs.scorer.Update(&stat) + } else { + cs.scorer.UpdateAt(&stat, at) + } + + mos, _ := cs.scorer.GetMOSAndQuality() + return mos +} + +func (cs *ConnectionStats) updateScoreFromReceiverReport(at time.Time) (float32, map[uint32]*buffer.StreamStatsWithLayers) { + if cs.params.SenderProvider == nil { + return MinMOS, nil + } + + streamingStartedAt := cs.updateStreamingStart(at) + if streamingStartedAt.IsZero() { + // not streaming, just return current score + mos, _ := cs.scorer.GetMOSAndQuality() + return mos, nil + } + + streams := cs.params.SenderProvider.GetDeltaStatsSender() + if len(streams) == 0 { + // check for receiver report not received for a while + marker := cs.params.SenderProvider.GetPrimaryStreamLastReceiverReportTime() + if marker.IsZero() || streamingStartedAt.After(marker) { + marker = streamingStartedAt + } + if time.Since(marker) > noReceiverReportTooLongThreshold { + // have not received receiver report for a long time when streaming, run with nil stat + return cs.updateScoreWithAggregate(nil, time.Time{}, at), nil + } + + // wait for receiver report, return current score + mos, _ := cs.scorer.GetMOSAndQuality() + return mos, nil + } + + // delta stat duration could be large due to not receiving receiver report for a long time (for example, due to mute), + // adjust to streaming start if necessary + if streamingStartedAt.After(cs.params.SenderProvider.GetPrimaryStreamLastReceiverReportTime()) { + // last receiver report was before streaming started, wait for next one + mos, _ := cs.scorer.GetMOSAndQuality() + return mos, streams + } + + agg := toAggregateDeltaInfo(streams, true) + if agg == nil { + // no receiver report in the window + mos, _ := cs.scorer.GetMOSAndQuality() + return mos, streams + } + if streamingStartedAt.After(agg.StartTime) { + agg.StartTime = streamingStartedAt + } + return cs.updateScoreWithAggregate(agg, time.Time{}, at), streams +} + +func (cs *ConnectionStats) updateScoreAt(at time.Time) (float32, map[uint32]*buffer.StreamStatsWithLayers, bool) { + if cs.params.SenderProvider != nil { + // receiver report based quality scoring, use stats from receiver report for scoring + score, streams := cs.updateScoreFromReceiverReport(at) + return score, streams, true + } + + if cs.params.ReceiverProvider == nil { + return MinMOS, nil, false + } + + streams := cs.params.ReceiverProvider.GetDeltaStats() + if len(streams) == 0 { + mos, _ := cs.scorer.GetMOSAndQuality() + return mos, nil, false + } + + agg := toAggregateDeltaInfo(streams, false) + if agg == nil { + // no receiver report in the window + mos, _ := cs.scorer.GetMOSAndQuality() + return mos, streams, false + } + return cs.updateScoreWithAggregate(agg, cs.params.ReceiverProvider.GetLastSenderReportTime(), at), streams, false +} + +func (cs *ConnectionStats) updateStreamingStart(at time.Time) time.Time { + cs.lock.Lock() + defer cs.lock.Unlock() + + packetsSent := cs.params.SenderProvider.GetPrimaryStreamPacketsSent() + if packetsSent > cs.packetsSent { + if cs.streamingStartedAt.IsZero() { + // the start could be anywhere after last update, but using `at` as this is not required to be accurate + if at.IsZero() { + cs.streamingStartedAt = time.Now() + } else { + cs.streamingStartedAt = at + } + } + } else { + cs.streamingStartedAt = time.Time{} + } + cs.packetsSent = packetsSent + + return cs.streamingStartedAt +} + +func (cs *ConnectionStats) getStat() { + score, streams, isSender := cs.updateScoreAt(time.Time{}) + + if cs.onStatsUpdate != nil && len(streams) != 0 { + analyticsStreams := make([]*livekit.AnalyticsStream, 0, len(streams)) + for ssrc, stream := range streams { + as := toAnalyticsStream(ssrc, stream.RTPStats, stream.RTPStatsRemoteView, isSender) + if as == nil { + continue + } + + // + // add video layer if either + // 1. Simulcast - even if there is only one layer per stream as it provides layer id + // 2. A stream has multiple layers + // + if (len(streams) > 1 || len(stream.Layers) > 1) && cs.isVideo.Load() { + for layer, layerStats := range stream.Layers { + avl := toAnalyticsVideoLayer(layer, layerStats) + if avl != nil { + as.VideoLayers = append(as.VideoLayers, avl) + } + } + } + + analyticsStreams = append(analyticsStreams, as) + } + + if len(analyticsStreams) != 0 { + cs.onStatsUpdate(cs, &livekit.AnalyticsStat{ + Score: score, + Streams: analyticsStreams, + Mime: cs.codecMimeType.Load().(mime.MimeType).String(), + }) + } + } +} + +func (cs *ConnectionStats) updateStatsWorker() { + interval := cs.params.UpdateInterval + if interval == 0 { + interval = UpdateInterval + } + + tk := time.NewTicker(interval) + defer tk.Stop() + + for { + select { + case <-cs.done.Watch(): + cs.getStat() + return + + case <-tk.C: + if cs.done.IsBroken() { + return + } + + cs.getStat() + } + } +} + +// ----------------------------------------------------------------------- + +// how much weight to give to packet loss rate when calculating score. +// It is codec dependent. +// For audio: +// +// o Opus without FEC or RED suffers the most through packet loss, hence has the highest weight +// o RED with two packet redundancy can absorb one out of every two packets lost, so packet loss is not as detrimental and therefore lower weight +// +// For video: +// +// o No in-built codec repair available, hence same for all codecs +func getPacketLossWeight(mimeType mime.MimeType, isFecEnabled bool) float64 { + var plw float64 + switch { + case mimeType == mime.MimeTypeOpus: + // 2.5%: fall to GOOD, 7.5%: fall to POOR + plw = 8.0 + if isFecEnabled { + // 3.75%: fall to GOOD, 11.25%: fall to POOR + plw /= 1.5 + } + + case mimeType == mime.MimeTypeRED: + // 5%: fall to GOOD, 15.0%: fall to POOR + plw = 4.0 + if isFecEnabled { + // 7.5%: fall to GOOD, 22.5%: fall to POOR + plw /= 1.5 + } + + case mime.IsMimeTypeVideo(mimeType): + // 2%: fall to GOOD, 6%: fall to POOR + plw = 10.0 + } + + return plw +} + +func toAggregateDeltaInfo(streams map[uint32]*buffer.StreamStatsWithLayers, useRemoteView bool) *rtpstats.RTPDeltaInfo { + deltaInfoList := make([]*rtpstats.RTPDeltaInfo, 0, len(streams)) + for _, s := range streams { + if useRemoteView { + if s.RTPStatsRemoteView != nil { + // discount jitter from publisher side + internal processing while reporting downstream jitter + if s.RTPStats != nil { + s.RTPStatsRemoteView.JitterMax -= s.RTPStats.JitterMax + if s.RTPStatsRemoteView.JitterMax < 0.0 { + s.RTPStatsRemoteView.JitterMax = 0.0 + } + } + deltaInfoList = append(deltaInfoList, s.RTPStatsRemoteView) + } + } else { + if s.RTPStats != nil { + deltaInfoList = append(deltaInfoList, s.RTPStats) + } + } + } + return rtpstats.AggregateRTPDeltaInfo(deltaInfoList) +} + +func toAnalyticsStream( + ssrc uint32, + deltaStats *rtpstats.RTPDeltaInfo, + deltaStatsRemoteView *rtpstats.RTPDeltaInfo, + isSender bool, +) *livekit.AnalyticsStream { + if deltaStats == nil { + return nil + } + + // discount the feed side loss when reporting forwarded track stats, + // discount jitter from publisher side + internal processing while reporting downstream jitter + packetsLost := deltaStats.PacketsLost + rtt := deltaStats.RttMax + maxJitter := deltaStats.JitterMax + if deltaStatsRemoteView != nil { + packetsLost = deltaStatsRemoteView.PacketsLost + if deltaStatsRemoteView.PacketsMissing > packetsLost { + packetsLost = 0 + } else { + packetsLost -= deltaStatsRemoteView.PacketsMissing + } + + rtt = deltaStatsRemoteView.RttMax + maxJitter = deltaStatsRemoteView.JitterMax + } else if isSender { + rtt = 0 + maxJitter = 0 + } + + return &livekit.AnalyticsStream{ + StartTime: timestamppb.New(deltaStats.StartTime), + EndTime: timestamppb.New(deltaStats.EndTime), + Ssrc: ssrc, + PrimaryPackets: deltaStats.Packets, + PrimaryBytes: deltaStats.Bytes, + RetransmitPackets: deltaStats.PacketsDuplicate, + RetransmitBytes: deltaStats.BytesDuplicate, + PaddingPackets: deltaStats.PacketsPadding, + PaddingBytes: deltaStats.BytesPadding, + PacketsLost: packetsLost, + PacketsOutOfOrder: deltaStats.PacketsOutOfOrder, + Frames: deltaStats.Frames, + Rtt: rtt, + Jitter: uint32(maxJitter), + Nacks: deltaStats.Nacks, + Plis: deltaStats.Plis, + Firs: deltaStats.Firs, + } +} + +func toAnalyticsVideoLayer(layer int32, layerStats *rtpstats.RTPDeltaInfo) *livekit.AnalyticsVideoLayer { + if layerStats == nil { + return nil + } + + avl := &livekit.AnalyticsVideoLayer{ + Layer: layer, + Packets: layerStats.Packets + layerStats.PacketsDuplicate + layerStats.PacketsPadding, + Bytes: layerStats.Bytes + layerStats.BytesDuplicate + layerStats.BytesPadding, + Frames: layerStats.Frames, + } + if avl.Packets == 0 || avl.Bytes == 0 || avl.Frames == 0 { + return nil + } + + return avl +} diff --git a/livekit/pkg/sfu/connectionquality/connectionstats_test.go b/livekit/pkg/sfu/connectionquality/connectionstats_test.go new file mode 100644 index 0000000..73ba4ae --- /dev/null +++ b/livekit/pkg/sfu/connectionquality/connectionstats_test.go @@ -0,0 +1,903 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package connectionquality + +import ( + "math" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/livekit/livekit-server/pkg/sfu/buffer" + "github.com/livekit/livekit-server/pkg/sfu/mime" + "github.com/livekit/livekit-server/pkg/sfu/rtpstats" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" +) + +// ----------------------------------------------- + +type testReceiverProvider struct { + streams map[uint32]*buffer.StreamStatsWithLayers + lastSenderReportTime time.Time +} + +func newTestReceiverProvider() *testReceiverProvider { + return &testReceiverProvider{} +} + +func (trp *testReceiverProvider) setStreams(streams map[uint32]*buffer.StreamStatsWithLayers) { + trp.streams = streams +} + +func (trp *testReceiverProvider) GetDeltaStats() map[uint32]*buffer.StreamStatsWithLayers { + return trp.streams +} + +func (trp *testReceiverProvider) setLastSenderReportTime(at time.Time) { + trp.lastSenderReportTime = at +} + +func (trp *testReceiverProvider) GetLastSenderReportTime() time.Time { + return trp.lastSenderReportTime +} + +// ----------------------------------------------- + +func TestConnectionQuality(t *testing.T) { + trp := newTestReceiverProvider() + t.Run("quality scorer operation", func(t *testing.T) { + cs := NewConnectionStats(ConnectionStatsParams{ + IncludeRTT: true, + IncludeJitter: true, + EnableBitrateScore: true, + ReceiverProvider: trp, + Logger: logger.GetLogger(), + }) + + duration := 5 * time.Second + now := time.Now() + cs.StartAt(mime.MimeTypeOpus, false, now.Add(-duration)) + cs.UpdateMuteAt(false, now.Add(-1*time.Second)) + + // no data and not enough unmute time should return default state which is EXCELLENT quality + cs.updateScoreAt(now) + mos, quality := cs.GetScoreAndQuality() + require.Greater(t, float32(4.6), mos) + require.Equal(t, livekit.ConnectionQuality_EXCELLENT, quality) + + // best conditions (no loss, jitter/rtt = 0) - quality should stay EXCELLENT + trp.setStreams(map[uint32]*buffer.StreamStatsWithLayers{ + 1: { + RTPStats: &rtpstats.RTPDeltaInfo{ + StartTime: now, + EndTime: now.Add(duration), + Packets: 250, + }, + }, + }) + cs.updateScoreAt(now.Add(duration)) + mos, quality = cs.GetScoreAndQuality() + require.Greater(t, float32(4.6), mos) + require.Equal(t, livekit.ConnectionQuality_EXCELLENT, quality) + + // introduce loss and the score should drop - 12% loss for Opus -> POOR + now = now.Add(duration) + trp.setStreams(map[uint32]*buffer.StreamStatsWithLayers{ + 1: { + RTPStats: &rtpstats.RTPDeltaInfo{ + StartTime: now, + EndTime: now.Add(duration), + Packets: 120, + PacketsLost: 30, + }, + }, + 2: { + RTPStats: &rtpstats.RTPDeltaInfo{ + StartTime: now, + EndTime: now.Add(duration), + Packets: 130, + PacketsLost: 0, + }, + }, + }) + cs.updateScoreAt(now.Add(duration)) + mos, quality = cs.GetScoreAndQuality() + require.Greater(t, float32(2.1), mos) + require.Equal(t, livekit.ConnectionQuality_POOR, quality) + + // should climb to GOOD quality in one iteration if the conditions improve. + // although significant loss (12%) in the previous window, lowest score is + // bound so that climbing back does not take too long even under excellent conditions. + now = now.Add(duration) + trp.setStreams(map[uint32]*buffer.StreamStatsWithLayers{ + 1: { + RTPStats: &rtpstats.RTPDeltaInfo{ + StartTime: now, + EndTime: now.Add(duration), + Packets: 250, + }, + }, + }) + cs.updateScoreAt(now.Add(duration)) + mos, quality = cs.GetScoreAndQuality() + require.Greater(t, float32(4.1), mos) + require.Equal(t, livekit.ConnectionQuality_GOOD, quality) + + // should stay at GOOD if conditions continue to be good + now = now.Add(duration) + trp.setStreams(map[uint32]*buffer.StreamStatsWithLayers{ + 1: { + RTPStats: &rtpstats.RTPDeltaInfo{ + StartTime: now, + EndTime: now.Add(duration), + Packets: 250, + }, + }, + }) + cs.updateScoreAt(now.Add(duration)) + mos, quality = cs.GetScoreAndQuality() + require.Greater(t, float32(4.1), mos) + require.Equal(t, livekit.ConnectionQuality_GOOD, quality) + + // should climb up to EXCELLENT if conditions continue to be good + now = now.Add(duration) + trp.setStreams(map[uint32]*buffer.StreamStatsWithLayers{ + 1: { + RTPStats: &rtpstats.RTPDeltaInfo{ + StartTime: now, + EndTime: now.Add(duration), + Packets: 250, + }, + }, + }) + cs.updateScoreAt(now.Add(duration)) + mos, quality = cs.GetScoreAndQuality() + require.Greater(t, float32(4.6), mos) + require.Equal(t, livekit.ConnectionQuality_EXCELLENT, quality) + + // introduce loss and the score should drop - 5% loss for Opus -> GOOD + now = now.Add(duration) + trp.setStreams(map[uint32]*buffer.StreamStatsWithLayers{ + 1: { + RTPStats: &rtpstats.RTPDeltaInfo{ + StartTime: now, + EndTime: now.Add(duration), + Packets: 250, + PacketsLost: 13, + }, + }, + }) + cs.updateScoreAt(now.Add(duration)) + mos, quality = cs.GetScoreAndQuality() + require.Greater(t, float32(4.1), mos) + require.Equal(t, livekit.ConnectionQuality_GOOD, quality) + + // should stay at GOOD quality for another iteration even if the conditions improve + now = now.Add(duration) + trp.setStreams(map[uint32]*buffer.StreamStatsWithLayers{ + 1: { + RTPStats: &rtpstats.RTPDeltaInfo{ + StartTime: now, + EndTime: now.Add(duration), + Packets: 250, + }, + }, + }) + cs.updateScoreAt(now.Add(duration)) + mos, quality = cs.GetScoreAndQuality() + require.Greater(t, float32(4.1), mos) + require.Equal(t, livekit.ConnectionQuality_GOOD, quality) + + // should climb up to EXCELLENT if conditions continue to be good + now = now.Add(duration) + trp.setStreams(map[uint32]*buffer.StreamStatsWithLayers{ + 1: { + RTPStats: &rtpstats.RTPDeltaInfo{ + StartTime: now, + EndTime: now.Add(duration), + Packets: 250, + }, + }, + }) + cs.updateScoreAt(now.Add(duration)) + mos, quality = cs.GetScoreAndQuality() + require.Greater(t, float32(4.6), mos) + require.Equal(t, livekit.ConnectionQuality_EXCELLENT, quality) + + // mute when quality is POOR should return quality to EXCELLENT + now = now.Add(duration) + trp.setStreams(map[uint32]*buffer.StreamStatsWithLayers{ + 1: { + RTPStats: &rtpstats.RTPDeltaInfo{ + StartTime: now, + EndTime: now.Add(duration), + Packets: 250, + PacketsLost: 30, + }, + }, + }) + cs.updateScoreAt(now.Add(duration)) + mos, quality = cs.GetScoreAndQuality() + require.Greater(t, float32(2.1), mos) + require.Equal(t, livekit.ConnectionQuality_POOR, quality) + + now = now.Add(duration) + cs.UpdateMuteAt(true, now.Add(1*time.Second)) + mos, quality = cs.GetScoreAndQuality() + require.Greater(t, float32(4.6), mos) + require.Equal(t, livekit.ConnectionQuality_EXCELLENT, quality) + + // unmute at specific time to ensure next window does not satisfy the unmute time threshold. + // that means even if the next update has 0 packets, it should hold state and stay at EXCELLENT quality + cs.UpdateMuteAt(false, now.Add(3*time.Second)) + + trp.setStreams(map[uint32]*buffer.StreamStatsWithLayers{ + 1: { + RTPStats: &rtpstats.RTPDeltaInfo{ + StartTime: now, + EndTime: now.Add(duration), + Packets: 0, + }, + }, + }) + cs.updateScoreAt(now.Add(duration)) + mos, quality = cs.GetScoreAndQuality() + require.Greater(t, float32(4.6), mos) + require.Equal(t, livekit.ConnectionQuality_EXCELLENT, quality) + + // next update with no packets, + // but last RTCP is not set, should knock quality down to POOR + now = now.Add(duration) + trp.setStreams(map[uint32]*buffer.StreamStatsWithLayers{ + 1: { + RTPStats: &rtpstats.RTPDeltaInfo{ + StartTime: now, + EndTime: now.Add(duration), + Packets: 0, + }, + }, + }) + cs.updateScoreAt(now.Add(duration)) + mos, quality = cs.GetScoreAndQuality() + require.Greater(t, float32(2.1), mos) + require.Equal(t, livekit.ConnectionQuality_POOR, quality) + + // another dry spell, but last RTCP is not stale, should keep quality at POOR + now = now.Add(duration) + trp.setLastSenderReportTime(now.Add(time.Second)) + trp.setStreams(map[uint32]*buffer.StreamStatsWithLayers{ + 1: { + RTPStats: &rtpstats.RTPDeltaInfo{ + StartTime: now, + EndTime: now.Add(duration), + Packets: 0, + }, + }, + }) + cs.updateScoreAt(now.Add(duration)) + mos, quality = cs.GetScoreAndQuality() + require.Greater(t, float32(2.1), mos) + require.Equal(t, livekit.ConnectionQuality_POOR, quality) + + // yet another dry spell, but last RTCP is stale, should knock down quality at LOST + now = now.Add(duration) + trp.setStreams(map[uint32]*buffer.StreamStatsWithLayers{ + 1: { + RTPStats: &rtpstats.RTPDeltaInfo{ + StartTime: now, + EndTime: now.Add(duration), + Packets: 0, + }, + }, + }) + cs.updateScoreAt(now.Add(duration)) + mos, quality = cs.GetScoreAndQuality() + require.Greater(t, float32(1.3), mos) + require.Equal(t, livekit.ConnectionQuality_LOST, quality) + + // mute when LOST should not bump up score/quality + now = now.Add(duration) + cs.UpdateMuteAt(true, now.Add(1*time.Second)) + mos, quality = cs.GetScoreAndQuality() + require.Greater(t, float32(1.3), mos) + require.Equal(t, livekit.ConnectionQuality_LOST, quality) + + // unmute and send packets to bring quality back up + now = now.Add(duration) + cs.UpdateMuteAt(false, now.Add(2*time.Second)) + for range 3 { + trp.setStreams(map[uint32]*buffer.StreamStatsWithLayers{ + 1: { + RTPStats: &rtpstats.RTPDeltaInfo{ + StartTime: now, + EndTime: now.Add(duration), + Packets: 250, + PacketsLost: 0, + }, + }, + }) + cs.updateScoreAt(now.Add(duration)) + now = now.Add(duration) + } + cs.updateScoreAt(now.Add(duration)) + mos, quality = cs.GetScoreAndQuality() + require.Greater(t, float32(4.6), mos) + require.Equal(t, livekit.ConnectionQuality_EXCELLENT, quality) + + // with lesser number of packet (simulating DTX). + // even higher loss (like 10%) should not knock down quality due to quadratic weighting of packet loss ratio + trp.setStreams(map[uint32]*buffer.StreamStatsWithLayers{ + 1: { + RTPStats: &rtpstats.RTPDeltaInfo{ + StartTime: now, + EndTime: now.Add(duration), + Packets: 50, + PacketsLost: 5, + }, + }, + }) + cs.updateScoreAt(now.Add(duration)) + mos, quality = cs.GetScoreAndQuality() + require.Greater(t, float32(4.6), mos) + require.Equal(t, livekit.ConnectionQuality_EXCELLENT, quality) + + // mute/unmute to bring quality back up + now = now.Add(duration) + cs.UpdateMuteAt(true, now.Add(1*time.Second)) + cs.UpdateMuteAt(false, now.Add(2*time.Second)) + + // RTT and jitter can knock quality down. + // at 2% loss, quality should stay at EXCELLENT purely based on loss, but with added RTT/jitter, should drop to GOOD + trp.setStreams(map[uint32]*buffer.StreamStatsWithLayers{ + 1: { + RTPStats: &rtpstats.RTPDeltaInfo{ + StartTime: now, + EndTime: now.Add(duration), + Packets: 250, + PacketsLost: 5, + RttMax: 400, + JitterMax: 30000, + }, + }, + }) + cs.updateScoreAt(now.Add(duration)) + mos, quality = cs.GetScoreAndQuality() + require.Greater(t, float32(4.1), mos) + require.Equal(t, livekit.ConnectionQuality_GOOD, quality) + + // mute/unmute to bring quality back up + now = now.Add(duration) + cs.UpdateMuteAt(true, now.Add(1*time.Second)) + cs.UpdateMuteAt(false, now.Add(2*time.Second)) + + // bitrate based calculation can drop quality even if there is no loss + cs.AddBitrateTransitionAt(1_000_000, now) + cs.AddBitrateTransitionAt(2_000_000, now.Add(2*time.Second)) + + trp.setStreams(map[uint32]*buffer.StreamStatsWithLayers{ + 1: { + RTPStats: &rtpstats.RTPDeltaInfo{ + StartTime: now, + EndTime: now.Add(duration), + Packets: 250, + Bytes: 8_000_000 / 8 / 5, + }, + }, + }) + cs.updateScoreAt(now.Add(duration)) + mos, quality = cs.GetScoreAndQuality() + require.Greater(t, float32(4.1), mos) + require.Equal(t, livekit.ConnectionQuality_GOOD, quality) + + // test layer mute via UpdateLayerMute API + cs.AddBitrateTransitionAt(1_000_000, now) + cs.AddBitrateTransitionAt(2_000_000, now.Add(2*time.Second)) + + trp.setStreams(map[uint32]*buffer.StreamStatsWithLayers{ + 1: { + RTPStats: &rtpstats.RTPDeltaInfo{ + StartTime: now, + EndTime: now.Add(duration), + Packets: 250, + Bytes: 8_000_000 / 8 / 5, + }, + }, + }) + cs.updateScoreAt(now.Add(duration)) + mos, quality = cs.GetScoreAndQuality() + require.Greater(t, float32(4.1), mos) + require.Equal(t, livekit.ConnectionQuality_GOOD, quality) + + now = now.Add(duration) + cs.UpdateLayerMuteAt(true, now) + mos, quality = cs.GetScoreAndQuality() + require.Greater(t, float32(4.6), mos) + require.Equal(t, livekit.ConnectionQuality_EXCELLENT, quality) + + // unmute layer + cs.UpdateLayerMuteAt(false, now.Add(2*time.Second)) + + trp.setStreams(map[uint32]*buffer.StreamStatsWithLayers{ + 1: { + RTPStats: &rtpstats.RTPDeltaInfo{ + StartTime: now, + EndTime: now.Add(duration), + Packets: 250, + Bytes: 8_000_000 / 8 / 5, + }, + }, + }) + cs.updateScoreAt(now.Add(duration)) + mos, quality = cs.GetScoreAndQuality() + require.Greater(t, float32(4.6), mos) + require.Equal(t, livekit.ConnectionQuality_EXCELLENT, quality) + + // pause + now = now.Add(duration) + cs.UpdatePauseAt(true, now) + mos, quality = cs.GetScoreAndQuality() + require.Greater(t, float32(2.1), mos) + require.Equal(t, livekit.ConnectionQuality_POOR, quality) + + // resume + cs.UpdatePauseAt(false, now.Add(2*time.Second)) + + // although conditions are perfect, climbing back from POOR (because of pause above) + // will only climb to GOOD. + trp.setStreams(map[uint32]*buffer.StreamStatsWithLayers{ + 1: { + RTPStats: &rtpstats.RTPDeltaInfo{ + StartTime: now, + EndTime: now.Add(duration), + Packets: 250, + Bytes: 8_000_000 / 8 / 5, + }, + }, + }) + cs.updateScoreAt(now.Add(duration)) + mos, quality = cs.GetScoreAndQuality() + require.Greater(t, float32(4.1), mos) + require.Equal(t, livekit.ConnectionQuality_GOOD, quality) + }) + + t.Run("quality scorer dependent rtt", func(t *testing.T) { + cs := NewConnectionStats(ConnectionStatsParams{ + IncludeRTT: false, + IncludeJitter: true, + ReceiverProvider: trp, + Logger: logger.GetLogger(), + }) + + duration := 5 * time.Second + now := time.Now() + cs.StartAt(mime.MimeTypeOpus, false, now.Add(-duration)) + cs.UpdateMuteAt(false, now.Add(-1*time.Second)) + + // RTT does not knock quality down because it is dependent and hence not taken into account + // at 2% loss, quality should stay at EXCELLENT purely based on loss. With high RTT (700 ms) + // quality should drop to GOOD if RTT were taken into consideration + trp.setStreams(map[uint32]*buffer.StreamStatsWithLayers{ + 1: { + RTPStats: &rtpstats.RTPDeltaInfo{ + StartTime: now, + EndTime: now.Add(duration), + Packets: 250, + PacketsLost: 5, + RttMax: 700, + }, + }, + }) + cs.updateScoreAt(now.Add(duration)) + mos, quality := cs.GetScoreAndQuality() + require.Greater(t, float32(4.6), mos) + require.Equal(t, livekit.ConnectionQuality_EXCELLENT, quality) + }) + + t.Run("quality scorer dependent jitter", func(t *testing.T) { + cs := NewConnectionStats(ConnectionStatsParams{ + IncludeRTT: true, + IncludeJitter: false, + ReceiverProvider: trp, + Logger: logger.GetLogger(), + }) + + duration := 5 * time.Second + now := time.Now() + cs.StartAt(mime.MimeTypeOpus, false, now.Add(-duration)) + cs.UpdateMuteAt(false, now.Add(-1*time.Second)) + + // Jitter does not knock quality down because it is dependent and hence not taken into account + // at 2% loss, quality should stay at EXCELLENT purely based on loss. With high jitter (200 ms) + // quality should drop to GOOD if jitter were taken into consideration + trp.setStreams(map[uint32]*buffer.StreamStatsWithLayers{ + 1: { + RTPStats: &rtpstats.RTPDeltaInfo{ + StartTime: now, + EndTime: now.Add(duration), + Packets: 250, + PacketsLost: 5, + JitterMax: 200, + }, + }, + }) + cs.updateScoreAt(now.Add(duration)) + mos, quality := cs.GetScoreAndQuality() + require.Greater(t, float32(4.6), mos) + require.Equal(t, livekit.ConnectionQuality_EXCELLENT, quality) + }) + + t.Run("codecs - packet", func(t *testing.T) { + type expectedQuality struct { + packetLossPercentage float64 + expectedMOS float32 + expectedQuality livekit.ConnectionQuality + } + testCases := []struct { + name string + mimeType mime.MimeType + isFECEnabled bool + packetsExpected uint32 + expectedQualities []expectedQuality + }{ + // NOTE: Because of EWMA (Exponentially Weighted Moving Average), these cut off points are not exact + // "audio/opus" - no fec - 0 <= loss < 2.5%: EXCELLENT, 2.5% <= loss < 7.5%: GOOD, >= 7.5%: POOR + { + name: "audio/opus - no fec", + mimeType: mime.MimeTypeOpus, + isFECEnabled: false, + packetsExpected: 200, + expectedQualities: []expectedQuality{ + { + packetLossPercentage: 1.0, + expectedMOS: 4.6, + expectedQuality: livekit.ConnectionQuality_EXCELLENT, + }, + { + packetLossPercentage: 4.0, + expectedMOS: 4.1, + expectedQuality: livekit.ConnectionQuality_GOOD, + }, + { + packetLossPercentage: 9.2, + expectedMOS: 2.1, + expectedQuality: livekit.ConnectionQuality_POOR, + }, + }, + }, + // "audio/opus" - fec - 0 <= loss < 3.75%: EXCELLENT, 3.75% <= loss < 11.25%: GOOD, >= 11.25%: POOR + { + name: "audio/opus - fec", + mimeType: mime.MimeTypeOpus, + isFECEnabled: true, + packetsExpected: 200, + expectedQualities: []expectedQuality{ + { + packetLossPercentage: 3.0, + expectedMOS: 4.6, + expectedQuality: livekit.ConnectionQuality_EXCELLENT, + }, + { + packetLossPercentage: 4.4, + expectedMOS: 4.1, + expectedQuality: livekit.ConnectionQuality_GOOD, + }, + { + packetLossPercentage: 15.0, + expectedMOS: 2.1, + expectedQuality: livekit.ConnectionQuality_POOR, + }, + }, + }, + // "audio/red" - no fec - 0 <= loss < 5%: EXCELLENT, 5% <= loss < 15%: GOOD, >= 15%: POOR + { + name: "audio/red - no fec", + mimeType: mime.MimeTypeRED, + isFECEnabled: false, + packetsExpected: 200, + expectedQualities: []expectedQuality{ + { + packetLossPercentage: 4.0, + expectedMOS: 4.6, + expectedQuality: livekit.ConnectionQuality_EXCELLENT, + }, + { + packetLossPercentage: 6.0, + expectedMOS: 4.1, + expectedQuality: livekit.ConnectionQuality_GOOD, + }, + { + packetLossPercentage: 19.5, + expectedMOS: 2.1, + expectedQuality: livekit.ConnectionQuality_POOR, + }, + }, + }, + // "audio/red" - fec - 0 <= loss < 7.5%: EXCELLENT, 7.5% <= loss < 22.5%: GOOD, >= 22.5%: POOR + { + name: "audio/red - fec", + mimeType: mime.MimeTypeRED, + isFECEnabled: true, + packetsExpected: 200, + expectedQualities: []expectedQuality{ + { + packetLossPercentage: 6.0, + expectedMOS: 4.6, + expectedQuality: livekit.ConnectionQuality_EXCELLENT, + }, + { + packetLossPercentage: 10.0, + expectedMOS: 4.1, + expectedQuality: livekit.ConnectionQuality_GOOD, + }, + { + packetLossPercentage: 30.0, + expectedMOS: 2.1, + expectedQuality: livekit.ConnectionQuality_POOR, + }, + }, + }, + // "video/*" - 0 <= loss < 2%: EXCELLENT, 2% <= loss < 6%: GOOD, >= 6%: POOR + { + name: "video/*", + mimeType: mime.MimeTypeVP8, + isFECEnabled: false, + packetsExpected: 200, + expectedQualities: []expectedQuality{ + { + packetLossPercentage: 1.0, + expectedMOS: 4.6, + expectedQuality: livekit.ConnectionQuality_EXCELLENT, + }, + { + packetLossPercentage: 3.5, + expectedMOS: 4.1, + expectedQuality: livekit.ConnectionQuality_GOOD, + }, + { + packetLossPercentage: 8.0, + expectedMOS: 2.1, + expectedQuality: livekit.ConnectionQuality_POOR, + }, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + cs := NewConnectionStats(ConnectionStatsParams{ + IncludeRTT: true, + IncludeJitter: true, + ReceiverProvider: trp, + Logger: logger.GetLogger(), + }) + + duration := 5 * time.Second + now := time.Now() + cs.StartAt(tc.mimeType, tc.isFECEnabled, now.Add(-duration)) + + for _, eq := range tc.expectedQualities { + trp.setStreams(map[uint32]*buffer.StreamStatsWithLayers{ + 123: { + RTPStats: &rtpstats.RTPDeltaInfo{ + StartTime: now, + EndTime: now.Add(duration), + Packets: tc.packetsExpected, + PacketsLost: uint32(math.Ceil(eq.packetLossPercentage * float64(tc.packetsExpected) / 100.0)), + }, + }, + }) + cs.updateScoreAt(now.Add(duration)) + mos, quality := cs.GetScoreAndQuality() + require.Greater(t, eq.expectedMOS, mos) + require.Equal(t, eq.expectedQuality, quality) + + now = now.Add(duration) + } + }) + } + }) + + t.Run("bitrate", func(t *testing.T) { + type transition struct { + bitrate int64 + offset time.Duration + } + testCases := []struct { + name string + transitions []transition + bytes uint64 + expectedMOS float32 + expectedQuality livekit.ConnectionQuality + }{ + // NOTE: Because of EWMA (Exponentially Weighted Moving Average), these cut off points are not exact + // 1.0 <= expectedBits / actualBits < ~2.7 = EXCELLENT + // ~2.7 <= expectedBits / actualBits < ~20.1 = GOOD + // expectedBits / actualBits >= ~20.1 = POOR + { + name: "excellent", + transitions: []transition{ + { + bitrate: 1_000_000, + }, + { + bitrate: 2_000_000, + offset: 3 * time.Second, + }, + }, + bytes: 6_000_000 / 8, + expectedMOS: 4.6, + expectedQuality: livekit.ConnectionQuality_EXCELLENT, + }, + { + name: "good", + transitions: []transition{ + { + bitrate: 1_000_000, + }, + { + bitrate: 2_000_000, + offset: 3 * time.Second, + }, + }, + bytes: uint64(math.Ceil(7_000_000.0 / 8.0 / 4.2)), + expectedMOS: 4.1, + expectedQuality: livekit.ConnectionQuality_GOOD, + }, + { + name: "poor", + transitions: []transition{ + { + bitrate: 2_000_000, + }, + { + bitrate: 1_000_000, + offset: 3 * time.Second, + }, + }, + bytes: uint64(math.Ceil(8_000_000.0 / 8.0 / 75.0)), + expectedMOS: 2.1, + expectedQuality: livekit.ConnectionQuality_POOR, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + cs := NewConnectionStats(ConnectionStatsParams{ + IncludeRTT: true, + IncludeJitter: true, + EnableBitrateScore: true, + ReceiverProvider: trp, + Logger: logger.GetLogger(), + }) + + duration := 5 * time.Second + now := time.Now() + cs.StartAt(mime.MimeTypeVP8, false, now) + + for _, tr := range tc.transitions { + cs.AddBitrateTransitionAt(tr.bitrate, now.Add(tr.offset)) + } + + trp.setStreams(map[uint32]*buffer.StreamStatsWithLayers{ + 123: { + RTPStats: &rtpstats.RTPDeltaInfo{ + StartTime: now, + EndTime: now.Add(duration), + Packets: 100, + Bytes: tc.bytes, + }, + }, + }) + cs.updateScoreAt(now.Add(duration)) + mos, quality := cs.GetScoreAndQuality() + require.Greater(t, tc.expectedMOS, mos) + require.Equal(t, tc.expectedQuality, quality) + }) + } + }) + + t.Run("layer", func(t *testing.T) { + type transition struct { + distance float64 + offset time.Duration + } + testCases := []struct { + name string + transitions []transition + expectedMOS float32 + expectedQuality livekit.ConnectionQuality + }{ + // NOTE: Because of EWMA (Exponentially Weighted Moving Average), these cut off points are not exact + // each spatial layer missed drops o quality level + { + name: "excellent", + transitions: []transition{ + { + distance: 0.5, + }, + { + distance: 0.0, + offset: 3 * time.Second, + }, + }, + expectedMOS: 4.6, + expectedQuality: livekit.ConnectionQuality_EXCELLENT, + }, + { + name: "good", + transitions: []transition{ + { + distance: 1.0, + }, + { + distance: 1.5, + offset: 2 * time.Second, + }, + }, + expectedMOS: 4.1, + expectedQuality: livekit.ConnectionQuality_GOOD, + }, + { + name: "poor", + transitions: []transition{ + { + distance: 2.0, + }, + { + distance: 2.6, + offset: 1 * time.Second, + }, + }, + expectedMOS: 2.1, + expectedQuality: livekit.ConnectionQuality_POOR, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + cs := NewConnectionStats(ConnectionStatsParams{ + IncludeRTT: true, + IncludeJitter: true, + ReceiverProvider: trp, + Logger: logger.GetLogger(), + }) + + duration := 5 * time.Second + now := time.Now() + cs.StartAt(mime.MimeTypeVP8, false, now) + + for _, tr := range tc.transitions { + cs.AddLayerTransitionAt(tr.distance, now.Add(tr.offset)) + } + + trp.setStreams(map[uint32]*buffer.StreamStatsWithLayers{ + 123: { + RTPStats: &rtpstats.RTPDeltaInfo{ + StartTime: now, + EndTime: now.Add(duration), + Packets: 200, + }, + }, + }) + cs.updateScoreAt(now.Add(duration)) + mos, quality := cs.GetScoreAndQuality() + require.Greater(t, tc.expectedMOS, mos) + require.Equal(t, tc.expectedQuality, quality) + }) + } + }) +} diff --git a/livekit/pkg/sfu/connectionquality/scorer.go b/livekit/pkg/sfu/connectionquality/scorer.go new file mode 100644 index 0000000..41977af --- /dev/null +++ b/livekit/pkg/sfu/connectionquality/scorer.go @@ -0,0 +1,644 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package connectionquality + +import ( + "fmt" + "math" + "sync" + "time" + + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/utils" + "go.uber.org/zap/zapcore" +) + +const ( + MaxMOS = float32(4.5) + MinMOS = float32(1.0) + + cMaxScore = float64(100.0) + cMinScore = float64(30.0) + + cIncreaseFactor = float64(0.4) // slower increase, i. e. when score is recovering move up slower -> conservative + cDecreaseFactor = float64(0.8) // faster decrease, i. e. when score is dropping move down faster -> aggressive to be responsive to quality drops + + cDistanceWeight = float64(35.0) // each spatial layer missed drops a quality level + + cUnmuteTimeThreshold = float64(0.5) + + cPPSQuantization = float64(2) + cPPSMinReadings = 10 + cModeCalculationInterval = 2 * time.Minute +) + +var ( + qualityTransitionScore = map[livekit.ConnectionQuality]float64{ + livekit.ConnectionQuality_GOOD: 80, + livekit.ConnectionQuality_POOR: 40, + livekit.ConnectionQuality_LOST: 20, + } +) + +// ------------------------------------------ + +type windowStat struct { + startedAt time.Time + duration time.Duration + packets uint32 + packetsPadding uint32 + packetsLost uint32 + packetsMissing uint32 + packetsOutOfOrder uint32 + bytes uint64 + rttMax uint32 + jitterMax float64 + lastRTCPAt time.Time +} + +func (w *windowStat) calculatePacketScore(aplw float64, includeRTT bool, includeJitter bool) float64 { + // this is based on simplified E-model based on packet loss, rtt, jitter as + // outlined at https://www.pingman.com/kb/article/how-is-mos-calculated-in-pingplotter-pro-50.html. + effectiveDelay := 0.0 + // discount the dependent factors if dependency indicated. + // for example, + // 1. in the up stream, RTT cannot be measured without RTCP-XR, it is using down stream RTT. + // 2. in the down stream, up stream jitter affects it. although jitter can be adjusted to account for up stream + // jitter, this lever can be used to discount jitter in scoring. + if includeRTT { + effectiveDelay += float64(w.rttMax) / 2.0 + } + if includeJitter { + effectiveDelay += (w.jitterMax * 2.0) / 1000.0 + } + delayEffect := effectiveDelay / 40.0 + if effectiveDelay > 160.0 { + delayEffect = (effectiveDelay - 120.0) / 10.0 + } + + // discount out-of-order packets from loss to deal with a scenario like + // 1. up stream has loss + // 2. down stream forwards with loss/hole in sequence number + // 3. down stream client reports a certain number of loss via RTCP RR + // 4. while processing that RTCP RR, up stream could have retransmitted missing packets + // 5. those retransmitted packets are forwarded, + // - server's view: it has forwarded those packets + // - client's view: it had not seen those packets when sending RTCP RR + // so those retransmitted packets appear like down stream loss to server. + // + // retransmitted packets would have arrived out-of-order. So, discounting them + // will account for it. + // + // Note that packets can arrive out-of-order in the upstream during regular + // streaming as well, i. e. without loss + NACK + retransmit. Those will be + // discounted too. And that will skew the real loss. For example, let + // us say that 40 out of 100 packets were reported lost by down stream. + // These could be real losses. In the same window, 40 packets could have been + // delivered out-of-order by the up stream, thus cancelling out the real loss. + // But, those situations should be rare and is a compromise for not letting + // up stream loss penalise down stream. + actualLost := w.packetsLost - w.packetsMissing - w.packetsOutOfOrder + if int32(actualLost) < 0 { + actualLost = 0 + } + + var lossEffect float64 + if w.packets+w.packetsPadding > 0 { + lossEffect = float64(actualLost) * 100.0 / float64(w.packets+w.packetsPadding) + } + lossEffect *= aplw + + score := cMaxScore - delayEffect - lossEffect + if score < 0.0 { + score = 0.0 + } + + return score +} + +func (w *windowStat) calculateBitrateScore(expectedBits int64, isEnabled bool) float64 { + if expectedBits == 0 || !isEnabled { + // unsupported mode OR all layers stopped + return cMaxScore + } + + var score float64 + if w.bytes != 0 { + // using the ratio of expectedBits / actualBits + // the quality inflection points are approximately + // GOOD at ~2.7x, POOR at ~20.1x + score = cMaxScore - 20*math.Log(float64(expectedBits)/float64(w.bytes*8)) + if score > cMaxScore { + score = cMaxScore + } + if score < 0.0 { + score = 0.0 + } + } + + return score +} + +func (w *windowStat) String() string { + return fmt.Sprintf("start: %+v, dur: %+v, p: %d, pp: %d, pl: %d, pm: %d, pooo: %d, b: %d, rtt: %d, jitter: %0.2f, lastRTCP: %+v", + w.startedAt, + w.duration, + w.packets, + w.packetsPadding, + w.packetsLost, + w.packetsMissing, + w.packetsOutOfOrder, + w.bytes, + w.rttMax, + w.jitterMax, + w.lastRTCPAt, + ) +} + +func (w *windowStat) MarshalLogObject(e zapcore.ObjectEncoder) error { + if w == nil { + return nil + } + + e.AddTime("startedAt", w.startedAt) + e.AddString("duration", w.duration.String()) + e.AddUint32("packets", w.packets) + e.AddUint32("packetsPadding", w.packetsPadding) + e.AddUint32("packetsLost", w.packetsLost) + e.AddUint32("packetsMissing", w.packetsMissing) + e.AddUint32("packetsOutOfOrder", w.packetsOutOfOrder) + e.AddUint64("bytes", w.bytes) + e.AddUint32("rttMax", w.rttMax) + e.AddFloat64("jitterMax", w.jitterMax) + e.AddTime("lastRTCPAt", w.lastRTCPAt) + return nil +} + +// ------------------------------------------ + +type qualityScorerParams struct { + IncludeRTT bool + IncludeJitter bool + EnableBitrateScore bool + Logger logger.Logger +} + +type qualityScorer struct { + params qualityScorerParams + + lock sync.RWMutex + lastUpdateAt time.Time + + packetLossWeight float64 + + score float64 + stat windowStat + + mutedAt time.Time + unmutedAt time.Time + + layerMutedAt time.Time + layerUnmutedAt time.Time + + pausedAt time.Time + resumedAt time.Time + + ppsHistogram [250]int + numPPSReadings int + ppsMode int + modeCalculatedAt time.Time + + aggregateBitrate *utils.TimedAggregator[int64] + layerDistance *utils.TimedAggregator[float64] +} + +func newQualityScorer(params qualityScorerParams) *qualityScorer { + return &qualityScorer{ + params: params, + score: cMaxScore, + aggregateBitrate: utils.NewTimedAggregator[int64](utils.TimedAggregatorParams{ + CapNegativeValues: true, + }), + layerDistance: utils.NewTimedAggregator[float64](utils.TimedAggregatorParams{ + CapNegativeValues: true, + }), + modeCalculatedAt: time.Now().Add(-cModeCalculationInterval), + } +} + +func (q *qualityScorer) startAtLocked(packetLossWeight float64, at time.Time) { + q.packetLossWeight = packetLossWeight + q.lastUpdateAt = at +} + +func (q *qualityScorer) StartAt(packetLossWeight float64, at time.Time) { + q.lock.Lock() + defer q.lock.Unlock() + + q.startAtLocked(packetLossWeight, at) +} + +func (q *qualityScorer) Start(packetLossWeight float64) { + q.lock.Lock() + defer q.lock.Unlock() + + q.startAtLocked(packetLossWeight, time.Now()) +} + +func (q *qualityScorer) UpdatePacketLossWeight(packetLossWeight float64) { + q.lock.Lock() + defer q.lock.Unlock() + + q.packetLossWeight = packetLossWeight +} + +func (q *qualityScorer) updateMuteAtLocked(isMuted bool, at time.Time) { + if isMuted { + q.mutedAt = at + // muting when LOST should not push quality to EXCELLENT + if q.score != qualityTransitionScore[livekit.ConnectionQuality_LOST] { + q.score = cMaxScore + } + } else { + q.unmutedAt = at + } +} + +func (q *qualityScorer) UpdateMuteAt(isMuted bool, at time.Time) { + q.lock.Lock() + defer q.lock.Unlock() + + q.updateMuteAtLocked(isMuted, at) +} + +func (q *qualityScorer) UpdateMute(isMuted bool) { + q.lock.Lock() + defer q.lock.Unlock() + + q.updateMuteAtLocked(isMuted, time.Now()) +} + +func (q *qualityScorer) addBitrateTransitionAtLocked(bitrate int64, at time.Time) { + q.aggregateBitrate.AddSampleAt(bitrate, at) +} + +func (q *qualityScorer) AddBitrateTransitionAt(bitrate int64, at time.Time) { + q.lock.Lock() + defer q.lock.Unlock() + + q.addBitrateTransitionAtLocked(bitrate, at) +} + +func (q *qualityScorer) AddBitrateTransition(bitrate int64) { + q.lock.Lock() + defer q.lock.Unlock() + + q.addBitrateTransitionAtLocked(bitrate, time.Now()) +} + +func (q *qualityScorer) updateLayerMuteAtLocked(isMuted bool, at time.Time) { + if isMuted { + if !q.isLayerMuted() { + q.aggregateBitrate.Reset() + q.layerDistance.Reset() + q.layerMutedAt = at + q.score = cMaxScore + } + } else { + if q.isLayerMuted() { + q.layerUnmutedAt = at + } + } +} + +func (q *qualityScorer) UpdateLayerMuteAt(isMuted bool, at time.Time) { + q.lock.Lock() + defer q.lock.Unlock() + + q.updateLayerMuteAtLocked(isMuted, at) +} + +func (q *qualityScorer) UpdateLayerMute(isMuted bool) { + q.lock.Lock() + defer q.lock.Unlock() + + q.updateLayerMuteAtLocked(isMuted, time.Now()) +} + +func (q *qualityScorer) updatePauseAtLocked(isPaused bool, at time.Time) { + if isPaused { + if !q.isPaused() { + q.aggregateBitrate.Reset() + q.layerDistance.Reset() + q.pausedAt = at + q.score = cMinScore + } + } else { + if q.isPaused() { + q.resumedAt = at + } + } +} + +func (q *qualityScorer) UpdatePauseAt(isPaused bool, at time.Time) { + q.lock.Lock() + defer q.lock.Unlock() + + q.updatePauseAtLocked(isPaused, at) +} + +func (q *qualityScorer) UpdatePause(isPaused bool) { + q.lock.Lock() + defer q.lock.Unlock() + + q.updatePauseAtLocked(isPaused, time.Now()) +} + +func (q *qualityScorer) addLayerTransitionAtLocked(distance float64, at time.Time) { + q.layerDistance.AddSampleAt(distance, at) +} + +func (q *qualityScorer) AddLayerTransitionAt(distance float64, at time.Time) { + q.lock.Lock() + defer q.lock.Unlock() + + q.addLayerTransitionAtLocked(distance, at) +} + +func (q *qualityScorer) AddLayerTransition(distance float64) { + q.lock.Lock() + defer q.lock.Unlock() + + q.addLayerTransitionAtLocked(distance, time.Now()) +} + +func (q *qualityScorer) updateAtLocked(stat *windowStat, at time.Time) { + // always update transitions + expectedBits, _, err := q.aggregateBitrate.GetAggregateAndRestartAt(at) + if err != nil { + q.params.Logger.Warnw("error getting expected bitrate", err) + } + expectedDistance, err := q.layerDistance.GetAverageAndRestartAt(at) + if err != nil { + q.params.Logger.Warnw("error getting expected distance", err) + } + + // nothing to do when muted or not unmuted for long enough + // NOTE: it is possible that unmute -> mute -> unmute transition happens in the + // same analysis window. On a transition to mute, quality is immediately moved + // EXCELLENT for responsiveness. On an unmute, the entire window data is + // considered (as long as enough time has passed since unmute). + // + // Similarly, when paused (possibly due to congestion), score is immediately + // set to cMinScore for responsiveness. The layer transition is reset. + // On a resume, quality climbs back up using normal operation. + if q.isMuted() || !q.isUnmutedEnough(at) || q.isLayerMuted() || q.isPaused() { + q.lastUpdateAt = at + return + } + + aplw := q.getAdjustedPacketLossWeight(stat) + reason := "none" + var score, packetScore, bitrateScore, layerScore float64 + if stat.packets+stat.packetsPadding == 0 { + if !stat.lastRTCPAt.IsZero() && at.Sub(stat.lastRTCPAt) > stat.duration { + reason = "rtcp" + score = qualityTransitionScore[livekit.ConnectionQuality_LOST] + } else { + reason = "dry" + score = qualityTransitionScore[livekit.ConnectionQuality_POOR] + } + } else { + packetScore = stat.calculatePacketScore(aplw, q.params.IncludeRTT, q.params.IncludeJitter) + bitrateScore = stat.calculateBitrateScore(expectedBits, q.params.EnableBitrateScore) + layerScore = math.Max(math.Min(cMaxScore, cMaxScore-(expectedDistance*cDistanceWeight)), 0.0) + + minScore := math.Min(packetScore, bitrateScore) + minScore = math.Min(minScore, layerScore) + + switch { + case packetScore == minScore: + reason = "packet" + score = packetScore + + case bitrateScore == minScore: + reason = "bitrate" + score = bitrateScore + + case layerScore == minScore: + reason = "layer" + score = layerScore + } + + factor := cIncreaseFactor + if score < q.score { + factor = cDecreaseFactor + } + score = factor*score + (1.0-factor)*q.score + if score < cMinScore { + // lower bound to prevent score from becoming very small values due to extreme conditions. + // Without a lower bound, it can get so low that it takes a long time to climb back to + // better quality even under excellent conditions. + score = cMinScore + } + } + + prevCQ := scoreToConnectionQuality(q.score) + currCQ := scoreToConnectionQuality(score) + ulgr := q.params.Logger.WithUnlikelyValues( + "reason", reason, + "prevScore", q.score, + "prevQuality", prevCQ, + "prevStat", &q.stat, + "score", score, + "packetScore", packetScore, + "layerScore", layerScore, + "bitrateScore", bitrateScore, + "quality", currCQ, + "stat", stat, + "packetLossWeight", q.packetLossWeight, + "adjustedPacketLossWeight", aplw, + "modePPS", q.ppsMode*int(cPPSQuantization), + "expectedBits", expectedBits, + "expectedDistance", expectedDistance, + ) + switch { + case utils.IsConnectionQualityLower(prevCQ, currCQ): + ulgr.Debugw("quality drop") + case utils.IsConnectionQualityHigher(prevCQ, currCQ): + ulgr.Debugw("quality rise") + default: + packets := stat.packets + stat.packetsPadding + if packets != 0 && (stat.packetsLost*100/packets) > 10 { + ulgr.Debugw("quality hold - high loss") + } + } + + q.score = score + q.stat = *stat + q.lastUpdateAt = at +} + +func (q *qualityScorer) UpdateAt(stat *windowStat, at time.Time) { + q.lock.Lock() + defer q.lock.Unlock() + + q.updateAtLocked(stat, at) +} + +func (q *qualityScorer) Update(stat *windowStat) { + q.lock.Lock() + defer q.lock.Unlock() + + q.updateAtLocked(stat, time.Now()) +} + +func (q *qualityScorer) isMuted() bool { + return !q.mutedAt.IsZero() && (q.unmutedAt.IsZero() || q.mutedAt.After(q.unmutedAt)) +} + +func (q *qualityScorer) isUnmutedEnough(at time.Time) bool { + var sinceUnmute time.Duration + if q.unmutedAt.IsZero() { + sinceUnmute = at.Sub(q.lastUpdateAt) + } else { + sinceUnmute = at.Sub(q.unmutedAt) + } + + var sinceLayerUnmute time.Duration + if q.layerUnmutedAt.IsZero() { + sinceLayerUnmute = at.Sub(q.lastUpdateAt) + } else { + sinceLayerUnmute = at.Sub(q.layerUnmutedAt) + } + + validDuration := sinceUnmute + if sinceLayerUnmute < validDuration { + validDuration = sinceLayerUnmute + } + + sinceLastUpdate := at.Sub(q.lastUpdateAt) + + return validDuration.Seconds()/sinceLastUpdate.Seconds() > cUnmuteTimeThreshold +} + +func (q *qualityScorer) isLayerMuted() bool { + return !q.layerMutedAt.IsZero() && (q.layerUnmutedAt.IsZero() || q.layerMutedAt.After(q.layerUnmutedAt)) +} + +func (q *qualityScorer) isPaused() bool { + return !q.pausedAt.IsZero() && (q.resumedAt.IsZero() || q.pausedAt.After(q.resumedAt)) +} + +func (q *qualityScorer) getAdjustedPacketLossWeight(stat *windowStat) float64 { + if stat == nil || stat.duration <= 0 { + return q.packetLossWeight + } + + // packet loss is weighted by comparing against mode of packet rate seen. + // this is to handle situations like DTX in audio and variable bit rate tracks like screen share. + // and the effect of loss is not pronounced in those scenarios (audio silence, static screen share). + // for example, DTX typically uses only 5% of packets of full packet rate. at that rate, + // packet loss weight is reduced to ~22% of configured weight (i. e. sqrt(0.05) * configured weight) + pps := float64(stat.packets) / stat.duration.Seconds() + ppsQuantized := int(pps/cPPSQuantization + 0.5) + if ppsQuantized < len(q.ppsHistogram)-1 { + q.ppsHistogram[ppsQuantized]++ + } else { + q.ppsHistogram[len(q.ppsHistogram)-1]++ + } + q.numPPSReadings++ + + // calculate mode sparingly, do it under the following conditions + // 1. minimum number of readings available (AND) + // 2. enough time has elapsed since last calculation + if q.numPPSReadings > cPPSMinReadings && time.Since(q.modeCalculatedAt) > cModeCalculationInterval { + q.ppsMode = 0 + for i := range len(q.ppsHistogram) { + if q.ppsHistogram[i] > q.ppsMode { + q.ppsMode = i + } + } + q.modeCalculatedAt = time.Now() + q.params.Logger.Debugw("updating pps mode", "expected", stat.packets, "duration", stat.duration.Seconds(), "pps", pps, "ppsMode", q.ppsMode) + } + + if q.ppsMode == 0 || q.ppsMode == len(q.ppsHistogram)-1 { + return q.packetLossWeight + } + + packetRatio := pps / (float64(q.ppsMode) * cPPSQuantization) + if packetRatio > 1.0 { + packetRatio = 1.0 + } + return math.Sqrt(packetRatio) * q.packetLossWeight +} + +func (q *qualityScorer) GetScoreAndQuality() (float32, livekit.ConnectionQuality) { + q.lock.RLock() + defer q.lock.RUnlock() + + return float32(q.score), scoreToConnectionQuality(q.score) +} + +func (q *qualityScorer) GetMOSAndQuality() (float32, livekit.ConnectionQuality) { + q.lock.RLock() + defer q.lock.RUnlock() + + return scoreToMOS(q.score), scoreToConnectionQuality(q.score) +} + +// ------------------------------------------ + +func scoreToConnectionQuality(score float64) livekit.ConnectionQuality { + // R-factor -> livekit.ConnectionQuality scale mapping roughly based on + // https://www.itu.int/ITU-T/2005-2008/com12/emodelv1/tut.htm + // + // As there are only three levels in livekit.ConnectionQuality scale, + // using a larger range for middling quality. Empirical evidence suggests + // that a score of 60 does not correspond to `POOR` quality. Repair + // mechanisms and use of algorithms like de-jittering makes the experience + // better even under harsh conditions. + if score > qualityTransitionScore[livekit.ConnectionQuality_GOOD] { + return livekit.ConnectionQuality_EXCELLENT + } + + if score > qualityTransitionScore[livekit.ConnectionQuality_POOR] { + return livekit.ConnectionQuality_GOOD + } + + if score > qualityTransitionScore[livekit.ConnectionQuality_LOST] { + return livekit.ConnectionQuality_POOR + } + + return livekit.ConnectionQuality_LOST +} + +// ------------------------------------------ + +func scoreToMOS(score float64) float32 { + if score <= 0.0 { + return 1.0 + } + + if score >= 100.0 { + return 4.5 + } + + return float32(1.0 + 0.035*score + (0.000007 * score * (score - 60.0) * (100.0 - score))) +} + +// ------------------------------------------ diff --git a/livekit/pkg/sfu/datachannel/bitrate.go b/livekit/pkg/sfu/datachannel/bitrate.go new file mode 100644 index 0000000..b95198a --- /dev/null +++ b/livekit/pkg/sfu/datachannel/bitrate.go @@ -0,0 +1,109 @@ +package datachannel + +import ( + "sync" + "time" + + "github.com/gammazero/deque" + + "github.com/livekit/protocol/utils/mono" +) + +const ( + BitrateDuration = 2 * time.Second + BitrateWindow = 100 * time.Millisecond +) + +// BitrateCalculator calculates bitrate over sliding window +type BitrateCalculator struct { + lock sync.Mutex + windowDuration time.Duration + duration time.Duration + + windows deque.Deque[bitrateWindow] + active bitrateWindow + + bytes int + lastBufferedAmount int + start time.Time +} + +func NewBitrateCalculator(duration time.Duration, window time.Duration) *BitrateCalculator { + windowCnt := int((duration + (window - 1)) / window) + if windowCnt == 0 { + windowCnt = 1 + } + now := mono.Now() + c := &BitrateCalculator{ + duration: duration, + windowDuration: window, + start: now, + active: bitrateWindow{start: now}, + } + c.windows.SetBaseCap(windowCnt + 1) + + return c +} + +func (c *BitrateCalculator) AddBytes(bytes int, bufferedAmout int, ts time.Time) { + c.lock.Lock() + defer c.lock.Unlock() + + bytes -= bufferedAmout - c.lastBufferedAmount + if bytes < 0 { + // it is possible that internal buffering (non-data like DCEP packet from webrtc) caused bytes to be negative + bytes = 0 + } + c.lastBufferedAmount = bufferedAmout + if ts.Sub(c.active.start) >= c.windowDuration { + c.windows.PushBack(c.active) + c.active.start = ts + c.active.bytes = 0 + + for c.windows.Len() > 0 { + // pop expired windows + if w := c.windows.Front(); ts.Sub(w.start) > (c.duration + c.windowDuration) { + c.bytes -= w.bytes + c.windows.PopFront() + } else { + c.start = w.start + break + } + } + if c.windows.Len() == 0 { + c.start = ts + c.bytes = 0 + } + } + c.bytes += bytes + c.active.bytes += bytes + +} + +func (c *BitrateCalculator) Bitrate(ts time.Time) (int, bool) { + return c.bitrate(ts, false) +} + +func (c *BitrateCalculator) ForceBitrate(ts time.Time) (int, bool) { + return c.bitrate(ts, true) +} + +func (c *BitrateCalculator) bitrate(ts time.Time, force bool) (int, bool) { + c.lock.Lock() + defer c.lock.Unlock() + duration := ts.Sub(c.start) + if duration < c.windowDuration { + if force { + duration = c.windowDuration + } else { + return 0, false + } + } + + return c.bytes * 8 * 1000 / int(duration.Milliseconds()), true +} + +type bitrateWindow struct { + start time.Time + bytes int +} diff --git a/livekit/pkg/sfu/datachannel/bitrate_test.go b/livekit/pkg/sfu/datachannel/bitrate_test.go new file mode 100644 index 0000000..c1414c7 --- /dev/null +++ b/livekit/pkg/sfu/datachannel/bitrate_test.go @@ -0,0 +1,36 @@ +package datachannel + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestBitrateCalculator(t *testing.T) { + c := NewBitrateCalculator(BitrateDuration, BitrateWindow) + require.NotNil(t, c) + + t0 := time.Now() + c.AddBytes(100, 0, t0) + // bytes buffered + c.AddBytes(100, 100, t0.Add(50*time.Millisecond)) + bitrate, ok := c.Bitrate(t0.Add(50 * time.Millisecond)) + require.Equal(t, 0, bitrate) + require.False(t, ok) + // 50 bytes sent (50 bytes buffer flushed) + c.AddBytes(100, 50, t0.Add(time.Second)) + + // 250 bytes sent in 1 second + bitrate, ok = c.Bitrate(t0.Add(time.Second)) + require.Equal(t, 2000, bitrate) + require.True(t, ok) + + // silence for long time + t1 := t0.Add(2 * BitrateDuration) + // 150 bytes sent (50 bytes buffer flushed) + c.AddBytes(100, 0, t1) + bitrate, ok = c.Bitrate(t1.Add(time.Second)) + require.Equal(t, 1200, bitrate) + require.True(t, ok) +} diff --git a/livekit/pkg/sfu/datachannel/datachannel_writer.go b/livekit/pkg/sfu/datachannel/datachannel_writer.go new file mode 100644 index 0000000..8fbbb3c --- /dev/null +++ b/livekit/pkg/sfu/datachannel/datachannel_writer.go @@ -0,0 +1,142 @@ +package datachannel + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/pion/datachannel" + + "github.com/livekit/protocol/utils/mono" +) + +const ( + singleWriteTimeout = 50 * time.Millisecond +) + +var ( + ErrDataDroppedBySlowReader = errors.New("data dropped by slow reader") + ErrDataDroppedByHighBufferedAmount = errors.New("data dropped due to high buffered amount") +) + +type BufferedAmountGetter interface { + BufferedAmount() uint64 +} + +type DataChannelWriter[T BufferedAmountGetter] struct { + bufferGetter T + rawDC datachannel.ReadWriteCloserDeadliner + slowThreshold int + rate *BitrateCalculator + reliable bool + targetLatency time.Duration + minBufferedAmount uint64 +} + +// NewDataChannelWriterReliable creates a new DataChannelWriter for reliable data channel by +// detaching it, when writing to the datachanel times out, it will block and retry if the +// receiver's bitrate is above the slowThreshold or drop the data if it's below the threshold. +// If the slowThreshold is 0, it will never retry on write timeout. +func NewDataChannelWriterReliable[T BufferedAmountGetter](bufferGetter T, rawDC datachannel.ReadWriteCloserDeadliner, slowThreshold int) *DataChannelWriter[T] { + var rate *BitrateCalculator + if slowThreshold > 0 { + rate = NewBitrateCalculator(BitrateDuration, BitrateWindow) + } + return &DataChannelWriter[T]{ + bufferGetter: bufferGetter, + rawDC: rawDC, + slowThreshold: slowThreshold, + rate: rate, + reliable: true, + } +} + +// NewDataChannelWriterUnreliable creates a new DataChannelWriter for unreliable data channel. +// It will drop data when the buffered amount is too high to maintain the target latency. +// The latency is estimated based on the bitrate in past 1 second. If targetLatency is 0, no +// buffering control is applied. +func NewDataChannelWriterUnreliable[T BufferedAmountGetter](bufferGetter T, rawDC datachannel.ReadWriteCloserDeadliner, targetLatency time.Duration, minBufferedAmount uint64) *DataChannelWriter[T] { + var rate *BitrateCalculator + if targetLatency > 0 { + rate = NewBitrateCalculator(BitrateDuration, BitrateWindow) + } + return &DataChannelWriter[T]{ + bufferGetter: bufferGetter, + rawDC: rawDC, + rate: rate, + targetLatency: targetLatency, + minBufferedAmount: minBufferedAmount, + reliable: false, + } +} + +func (w *DataChannelWriter[T]) BufferedAmountGetter() T { + return w.bufferGetter +} + +func (w *DataChannelWriter[T]) Write(p []byte) (n int, err error) { + if w.reliable { + return w.writeReliable(p) + } else { + return w.writeUnreliable(p) + } +} + +func (w *DataChannelWriter[T]) writeReliable(p []byte) (n int, err error) { + for { + err = w.rawDC.SetWriteDeadline(time.Now().Add(singleWriteTimeout)) + if err != nil { + return 0, err + } + n, err = w.rawDC.Write(p) + if w.slowThreshold == 0 { + return + } + + now := mono.Now() + w.rate.AddBytes(n, int(w.bufferGetter.BufferedAmount()), now) + // retry if the write timed out on a non-slow receiver + if errors.Is(err, context.DeadlineExceeded) { + if bitrate, ok := w.rate.Bitrate(now); !ok || bitrate >= w.slowThreshold { + continue + } else { + err = fmt.Errorf("%w: bitrate %d, threshold %d", ErrDataDroppedBySlowReader, bitrate, w.slowThreshold) + } + } + + return + } +} + +func (w *DataChannelWriter[T]) writeUnreliable(p []byte) (n int, err error) { + if w.targetLatency == 0 { + err = w.rawDC.SetWriteDeadline(time.Now().Add(singleWriteTimeout)) + if err != nil { + return 0, err + } + return w.rawDC.Write(p) + } + + if bitrate, ok := w.rate.Bitrate(time.Now()); ok { + // control buffer latency to ~100ms + if w.bufferGetter.BufferedAmount() > uint64(time.Duration(bitrate)*w.targetLatency/8/time.Second) && w.bufferGetter.BufferedAmount() > w.minBufferedAmount { + return 0, ErrDataDroppedByHighBufferedAmount + } + } + + err = w.rawDC.SetWriteDeadline(time.Now().Add(singleWriteTimeout)) + if err != nil { + return 0, err + } + n, err = w.rawDC.Write(p) + if err != nil { + w.rate.AddBytes(n, int(w.bufferGetter.BufferedAmount()), mono.Now()) + } + + return +} + +func (w *DataChannelWriter[T]) Close() error { + return w.rawDC.Close() +} diff --git a/livekit/pkg/sfu/datachannel/datachannel_writer_test.go b/livekit/pkg/sfu/datachannel/datachannel_writer_test.go new file mode 100644 index 0000000..9b99b89 --- /dev/null +++ b/livekit/pkg/sfu/datachannel/datachannel_writer_test.go @@ -0,0 +1,154 @@ +package datachannel + +import ( + "context" + "sync/atomic" + "testing" + "time" + + "github.com/pion/datachannel" + "github.com/pion/transport/v4/deadline" + "github.com/stretchr/testify/require" +) + +func TestDataChannelWriter(t *testing.T) { + mockDC := newMockDataChannelWriter() + // slow threshold is 1000B/s + w := NewDataChannelWriterReliable(mockDC, mockDC, 8000) + require.Equal(t, mockDC, w.BufferedAmountGetter()) + buf := make([]byte, 2000) + // write 2000 bytes so it should not drop in 2 seconds + t0 := time.Now() + n, err := w.Write(buf) + require.NoError(t, err) + require.Equal(t, 2000, n) + + t1 := time.Now() + mockDC.SetNextWriteCompleteAt(t0.Add(time.Second * 3 / 2)) + n, err = w.Write(buf[:10]) + require.NoError(t, err) + require.Equal(t, 10, n) + require.GreaterOrEqual(t, time.Since(t1), time.Second) + + // bitrate below slow threshold(2000bytes/3sec), should drop by timeout + mockDC.SetNextWriteCompleteAt(t0.Add(3 * time.Second)) + n, err = w.Write(buf[:1000]) + require.ErrorIs(t, err, ErrDataDroppedBySlowReader, err) + require.Equal(t, 0, n) +} + +func TestDataChannelWriter_NoSlowThreshold(t *testing.T) { + mockDC := newMockDataChannelWriter() + w := NewDataChannelWriterReliable(mockDC, mockDC, 0) + buf := make([]byte, 2000) + n, err := w.Write(buf) + require.NoError(t, err) + require.Equal(t, 2000, n) + mockDC.SetNextWriteCompleteAt(time.Now().Add(singleWriteTimeout / 2)) + n, err = w.Write(buf[:10]) + require.NoError(t, err) + require.Equal(t, 10, n) + + // slow threshold is 0, should not block & retry + mockDC.SetNextWriteCompleteAt(time.Now().Add(singleWriteTimeout * 2)) + n, err = w.Write(buf[:1000]) + require.ErrorIs(t, err, context.DeadlineExceeded, err) + require.Equal(t, 0, n) +} + +func TestDataChannelWriter_Unreliable(t *testing.T) { + mockDC := newMockLossyDataChannelWriter(8192) + w := NewDataChannelWriterUnreliable(mockDC, mockDC, 100*time.Millisecond, 2000) + for range 10 { + buf := make([]byte, 128) + _, err := w.Write(buf) + require.NoError(t, err) + time.Sleep(100 * time.Millisecond) + } + buf := make([]byte, 4096) + _, err := w.Write(buf) + require.NoError(t, err) + // should drop due to high buffered amount + _, err = w.Write(buf) + require.ErrorIs(t, err, ErrDataDroppedByHighBufferedAmount) +} + +// mockDataChannelWriter +type mockDataChannelWriter struct { + datachannel.ReadWriteCloserDeadliner + nextWriteCompleteAt time.Time + deadline *deadline.Deadline +} + +func newMockDataChannelWriter() *mockDataChannelWriter { + return &mockDataChannelWriter{ + deadline: deadline.New(), + } +} + +func (m *mockDataChannelWriter) BufferedAmount() uint64 { + return 0 +} + +func (m *mockDataChannelWriter) Write(b []byte) (int, error) { + wait := time.Until(m.nextWriteCompleteAt) + if wait <= 0 { + return len(b), nil + } + select { + case <-m.deadline.Done(): + return 0, m.deadline.Err() + case <-time.After(wait): + return len(b), nil + } +} + +func (m *mockDataChannelWriter) SetWriteDeadline(t time.Time) error { + m.deadline.Set(t) + return nil +} + +func (m *mockDataChannelWriter) SetNextWriteCompleteAt(t time.Time) { + m.nextWriteCompleteAt = t +} + +// mockLossyDataChannelWriter +type mockLossyDataChannelWriter struct { + datachannel.ReadWriteCloserDeadliner + bufferedAmount atomic.Int64 + targetBitrate int + lastWriteAt time.Time +} + +func newMockLossyDataChannelWriter(targetBitrate int) *mockLossyDataChannelWriter { + return &mockLossyDataChannelWriter{ + targetBitrate: targetBitrate, + lastWriteAt: time.Now(), + } +} + +func (m *mockLossyDataChannelWriter) BufferedAmount() uint64 { + return uint64(m.bufferedAmount.Load()) +} + +func (m *mockLossyDataChannelWriter) Write(b []byte) (int, error) { + m.bufferedAmount.Add(int64(len(b))) + if time.Now().Before(m.lastWriteAt) { + return len(b), nil + } + + // drain buffer based on target bitrate + canWriteBytes := time.Since(m.lastWriteAt) * time.Duration(m.targetBitrate) / time.Second / 8 + if m.bufferedAmount.Load() <= int64(canWriteBytes) { + m.lastWriteAt = m.lastWriteAt.Add(time.Duration(int64(time.Second) * int64(m.BufferedAmount()) / (int64(m.targetBitrate) / 8))) + m.bufferedAmount.Store(0) + } else { + m.lastWriteAt = time.Now() + m.bufferedAmount.Add(-int64(canWriteBytes)) + } + return len(b), nil +} + +func (m *mockLossyDataChannelWriter) SetWriteDeadline(t time.Time) error { + return nil +} diff --git a/livekit/pkg/sfu/downtrack.go b/livekit/pkg/sfu/downtrack.go new file mode 100644 index 0000000..662caf8 --- /dev/null +++ b/livekit/pkg/sfu/downtrack.go @@ -0,0 +1,2544 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sfu + +import ( + "encoding/binary" + "errors" + "fmt" + "io" + "math" + "math/rand" + "strings" + "sync" + "time" + + "github.com/pion/rtcp" + "github.com/pion/rtp" + "github.com/pion/sdp/v3" + "github.com/pion/transport/v4/packetio" + "github.com/pion/webrtc/v4" + "go.uber.org/atomic" + "go.uber.org/zap/zapcore" + + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/utils/mono" + + "github.com/livekit/livekit-server/pkg/sfu/buffer" + "github.com/livekit/livekit-server/pkg/sfu/bwe" + "github.com/livekit/livekit-server/pkg/sfu/ccutils" + "github.com/livekit/livekit-server/pkg/sfu/connectionquality" + "github.com/livekit/livekit-server/pkg/sfu/mime" + "github.com/livekit/livekit-server/pkg/sfu/pacer" + act "github.com/livekit/livekit-server/pkg/sfu/rtpextension/abscapturetime" + dd "github.com/livekit/livekit-server/pkg/sfu/rtpextension/dependencydescriptor" + pd "github.com/livekit/livekit-server/pkg/sfu/rtpextension/playoutdelay" + "github.com/livekit/livekit-server/pkg/sfu/rtpstats" + "github.com/livekit/livekit-server/pkg/sfu/utils" +) + +// TrackSender defines an interface send media to remote peer +type TrackSender interface { + UpTrackLayersChange() + UpTrackBitrateAvailabilityChange() + UpTrackMaxPublishedLayerChange(maxPublishedLayer int32) + UpTrackMaxTemporalLayerSeenChange(maxTemporalLayerSeen int32) + UpTrackBitrateReport(availableLayers []int32, bitrates Bitrates) + WriteRTP(p *buffer.ExtPacket, layer int32) int32 + Close() + IsClosed() bool + // ID is the globally unique identifier for this Track. + ID() string + SubscriberID() livekit.ParticipantID + HandleRTCPSenderReportData( + payloadType webrtc.PayloadType, + layer int32, + publisherSRData *livekit.RTCPSenderReportState, + ) error + Resync() + SetReceiver(TrackReceiver) + ReceiverRestart() +} + +// ------------------------------------------------------------------- + +const ( + RTPPaddingMaxPayloadSize = 255 + RTPPaddingEstimatedHeaderSize = 20 + RTPBlankFramesMuteSeconds = float32(1.0) + RTPBlankFramesCloseSeconds = float32(0.2) + + FlagStopRTXOnPLI = true + + keyFrameIntervalMin = 200 + keyFrameIntervalMax = 1000 + flushTimeout = 1 * time.Second + + waitBeforeSendPaddingOnMute = 100 * time.Millisecond + maxPaddingOnMuteDuration = 5 * time.Second + paddingOnMuteInterval = 100 * time.Millisecond +) + +// ------------------------------------------------------------------- + +var ( + errUnknownKind = errors.New("unknown kind of codec") + errOutOfOrderSequenceNumberCacheMiss = errors.New("out-of-order sequence number not found in cache") + errPaddingOnlyPacket = errors.New("padding only packet that need not be forwarded") + errDuplicatePacket = errors.New("duplicate packet") + errPaddingNotOnFrameBoundary = errors.New("padding cannot send on non-frame boundary") + errDownTrackAlreadyBound = errors.New("already bound") + errPayloadOverflow = errors.New("payload overflow") +) + +var ( + VP8KeyFrame8x8 = []byte{ + 0x10, 0x02, 0x00, 0x9d, 0x01, 0x2a, 0x08, 0x00, + 0x08, 0x00, 0x00, 0x47, 0x08, 0x85, 0x85, 0x88, + 0x85, 0x84, 0x88, 0x02, 0x02, 0x00, 0x0c, 0x0d, + 0x60, 0x00, 0xfe, 0xff, 0xab, 0x50, 0x80, + } + + H264KeyFrame2x2SPS = []byte{ + 0x67, 0x42, 0xc0, 0x1f, 0x0f, 0xd9, 0x1f, 0x88, + 0x88, 0x84, 0x00, 0x00, 0x03, 0x00, 0x04, 0x00, + 0x00, 0x03, 0x00, 0xc8, 0x3c, 0x60, 0xc9, 0x20, + } + H264KeyFrame2x2PPS = []byte{ + 0x68, 0x87, 0xcb, 0x83, 0xcb, 0x20, + } + H264KeyFrame2x2IDR = []byte{ + 0x65, 0x88, 0x84, 0x0a, 0xf2, 0x62, 0x80, 0x00, + 0xa7, 0xbe, + } + H264KeyFrame2x2 = [][]byte{H264KeyFrame2x2SPS, H264KeyFrame2x2PPS, H264KeyFrame2x2IDR} + + OpusSilenceFrame = []byte{ + 0xf8, 0xff, 0xfe, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + } + + dummyAbsSendTimeExt, _ = rtp.NewAbsSendTimeExtension(mono.Now()).Marshal() + dummyTransportCCExt, _ = rtp.TransportCCExtension{TransportSequence: 12345}.Marshal() +) + +// ------------------------------------------------------------------- + +type DownTrackState struct { + RTPStats *rtpstats.RTPStatsSender + DeltaStatsSenderSnapshotId uint32 + RTPStatsRTX *rtpstats.RTPStatsSender + DeltaStatsRTXSenderSnapshotId uint32 + ForwarderState *livekit.RTPForwarderState + PlayoutDelayControllerState PlayoutDelayControllerState +} + +func (d DownTrackState) MarshalLogObject(e zapcore.ObjectEncoder) error { + e.AddObject("RTPStats", d.RTPStats) + e.AddUint32("DeltaStatsSenderSnapshotId", d.DeltaStatsSenderSnapshotId) + e.AddObject("RTPStatsRTX", d.RTPStatsRTX) + e.AddUint32("DeltaStatsRTXSenderSnapshotId", d.DeltaStatsRTXSenderSnapshotId) + e.AddObject("ForwarderState", logger.Proto(d.ForwarderState)) + e.AddObject("PlayoutDelayControllerState", d.PlayoutDelayControllerState) + return nil +} + +// ------------------------------------------------------------------- + +type DownTrackStreamAllocatorListener interface { + // RTCP received + OnREMB(dt *DownTrack, remb *rtcp.ReceiverEstimatedMaximumBitrate) + OnTransportCCFeedback(dt *DownTrack, cc *rtcp.TransportLayerCC) + + // video layer availability changed + OnAvailableLayersChanged(dt *DownTrack) + + // video layer bitrate availability changed + OnBitrateAvailabilityChanged(dt *DownTrack) + + // max published spatial layer changed + OnMaxPublishedSpatialChanged(dt *DownTrack) + + // max published temporal layer changed + OnMaxPublishedTemporalChanged(dt *DownTrack) + + // subscription changed - mute/unmute + OnSubscriptionChanged(dt *DownTrack) + + // subscribed max video layer changed + OnSubscribedLayerChanged(dt *DownTrack, layers buffer.VideoLayer) + + // stream resumed + OnResume(dt *DownTrack) + + // check if track should participate in BWE + IsBWEEnabled(dt *DownTrack) bool + + // get the BWE type in use + BWEType() bwe.BWEType + + // check if subscription mute can be applied + IsSubscribeMutable(dt *DownTrack) bool +} + +// ------------------------------------------------------------------- + +type DownTrackListener interface { + OnBindAndConnected() + OnStatsUpdate(stat *livekit.AnalyticsStat) + OnMaxSubscribedLayerChanged(layer int32) + OnRttUpdate(rtt uint32) + OnCodecNegotiated(webrtc.RTPCodecCapability) + OnDownTrackClose(isExpectedToResume bool) +} + +// ------------------------------------------------------------------- + +type bindState int + +const ( + bindStateUnbound bindState = iota + // downtrack negotiated, but waiting for receiver to be ready to start forwarding + bindStateWaitForReceiverReady + // downtrack is bound and ready to forward + bindStateBound +) + +func (bs bindState) String() string { + switch bs { + case bindStateUnbound: + return "unbound" + case bindStateWaitForReceiverReady: + return "waitForReceiverReady" + case bindStateBound: + return "bound" + } + return "unknown" +} + +// ------------------------------------------------------------------- + +var _ TrackSender = (*DownTrack)(nil) + +type ReceiverReportListener func(dt *DownTrack, report *rtcp.ReceiverReport) + +type DownTrackParams struct { + Codecs []webrtc.RTPCodecParameters + IsEncrypted bool + Source livekit.TrackSource + Receiver TrackReceiver + BufferFactory *buffer.Factory + SubID livekit.ParticipantID + StreamID string + MaxTrack int + PlayoutDelayLimit *livekit.PlayoutDelay + Pacer pacer.Pacer + Logger logger.Logger + Trailer []byte + RTCPWriter func([]rtcp.Packet) error + DisableSenderReportPassThrough bool + SupportsCodecChange bool + Listener DownTrackListener +} + +// DownTrack implements webrtc.TrackLocal, is the track used to write packets +// to SFU Subscriber, the track handle the packets for simple, simulcast +// and SVC Publisher. +// A DownTrack has the following lifecycle +// - new +// - bound / unbound +// - closed +// once closed, a DownTrack cannot be re-used. +type DownTrack struct { + params DownTrackParams + id livekit.TrackID + kind webrtc.RTPCodecType + ssrc uint32 + ssrcRTX uint32 + payloadType atomic.Uint32 + payloadTypeRTX atomic.Uint32 + sequencer *sequencer + rtxSequenceNumber atomic.Uint64 + + receiverLock sync.RWMutex + receiver TrackReceiver + + forwarder *Forwarder + + upstreamCodecs []webrtc.RTPCodecParameters + codec atomic.Value // webrtc.RTPCodecCapability + clockRate uint32 + negotiatedCodecParameters []webrtc.RTPCodecParameters + + // payload types for red codec only + isRED bool + upstreamPrimaryPT uint8 + primaryPT uint8 + + absSendTimeExtID int + transportWideExtID int + dependencyDescriptorExtID int + playoutDelayExtID int + absCaptureTimeExtID int + transceiver atomic.Pointer[webrtc.RTPTransceiver] + writeStream webrtc.TrackLocalWriter + rtcpReader *buffer.RTCPReader + rtcpReaderRTX *buffer.RTCPReader + + listenerLock sync.RWMutex + receiverReportListeners []ReceiverReportListener + + bindLock sync.Mutex + bindState atomic.Value + onBinding func(error) + bindOnReceiverReady func() + + isClosed atomic.Bool + connected atomic.Bool + bindAndConnectedOnce atomic.Bool + writable atomic.Bool + writeStopped atomic.Bool + isReceiverReady bool + + rtpStats *rtpstats.RTPStatsSender + deltaStatsSenderSnapshotId uint32 + + rtpStatsRTX *rtpstats.RTPStatsSender + deltaStatsRTXSenderSnapshotId uint32 + + totalRepeatedNACKs atomic.Uint32 + + blankFramesGeneration atomic.Uint32 + + connectionStats *connectionquality.ConnectionStats + + isNACKThrottled atomic.Bool + + activePaddingOnMuteUpTrack atomic.Bool + + streamAllocatorLock sync.RWMutex + streamAllocatorListener DownTrackStreamAllocatorListener + probeClusterId atomic.Uint32 + + playoutDelay *PlayoutDelayController + + pacer pacer.Pacer + + maxLayerNotifierChMu sync.RWMutex + maxLayerNotifierCh chan string + maxLayerNotifierChClosed bool + + keyFrameRequesterChMu sync.RWMutex + keyFrameRequesterCh chan struct{} + keyFrameRequesterChClosed bool + + createdAt int64 +} + +// NewDownTrack returns a DownTrack. +func NewDownTrack(params DownTrackParams) (*DownTrack, error) { + mimeType := mime.NormalizeMimeType(params.Codecs[0].MimeType) + var kind webrtc.RTPCodecType + switch { + case mime.IsMimeTypeAudio(mimeType): + kind = webrtc.RTPCodecTypeAudio + case mime.IsMimeTypeVideo(mimeType): + kind = webrtc.RTPCodecTypeVideo + default: + kind = webrtc.RTPCodecType(0) + } + + codec := params.Codecs[0].RTPCodecCapability + d := &DownTrack{ + params: params, + id: params.Receiver.TrackID(), + upstreamCodecs: params.Codecs, + kind: kind, + clockRate: codec.ClockRate, + pacer: params.Pacer, + maxLayerNotifierCh: make(chan string, 1), + keyFrameRequesterCh: make(chan struct{}, 1), + createdAt: time.Now().UnixNano(), + receiver: params.Receiver, + } + d.codec.Store(codec) + d.bindState.Store(bindStateUnbound) + d.params.Logger = params.Logger.WithValues( + "subscriberID", d.SubscriberID(), + ) + + var mdCacheSize, mdCacheSizeRTX int + if d.kind == webrtc.RTPCodecTypeVideo { + mdCacheSize, mdCacheSizeRTX = 8192, 8192 + } else { + mdCacheSize, mdCacheSizeRTX = 8192, 1024 + } + d.rtpStats = rtpstats.NewRTPStatsSender(rtpstats.RTPStatsParams{ + ClockRate: codec.ClockRate, + Logger: d.params.Logger.WithValues( + "stream", "primary", + ), + }, mdCacheSize) + d.deltaStatsSenderSnapshotId = d.rtpStats.NewSenderSnapshotId() + + d.rtpStatsRTX = rtpstats.NewRTPStatsSender(rtpstats.RTPStatsParams{ + ClockRate: codec.ClockRate, + IsRTX: true, + Logger: d.params.Logger.WithValues( + "stream", "rtx", + ), + }, mdCacheSizeRTX) + d.deltaStatsRTXSenderSnapshotId = d.rtpStatsRTX.NewSenderSnapshotId() + + d.forwarder = NewForwarder( + d.kind, + d.params.Logger, + false, // skipReferenceTS + false, // disableOpportunisticAllocation + d.rtpStats, + ) + + d.connectionStats = connectionquality.NewConnectionStats(connectionquality.ConnectionStatsParams{ + SenderProvider: d, + Logger: d.params.Logger.WithValues("direction", "down"), + }) + d.connectionStats.OnStatsUpdate(func(_cs *connectionquality.ConnectionStats, stat *livekit.AnalyticsStat) { + d.params.Listener.OnStatsUpdate(stat) + }) + + if d.kind == webrtc.RTPCodecTypeVideo { + if delay := params.PlayoutDelayLimit; delay.GetEnabled() { + var err error + d.playoutDelay, err = NewPlayoutDelayController(delay.GetMin(), delay.GetMax(), params.Logger, d.rtpStats) + if err != nil { + return nil, err + } + } + go d.maxLayerNotifierWorker() + go d.keyFrameRequester() + } + + d.params.Receiver.AddOnReady(d.handleReceiverReady) + d.rtxSequenceNumber.Store(uint64(rand.Intn(1<<14)) + uint64(1<<15)) // a random number in third quartile of sequence number space + d.params.Logger.Debugw("downtrack created", "upstreamCodecs", d.upstreamCodecs) + + return d, nil +} + +// Bind is called by the PeerConnection after negotiation is complete +// This asserts that the code requested is supported by the remote peer. +// If so it sets up all the state (SSRC and PayloadType) to have a call +func (d *DownTrack) Bind(t webrtc.TrackLocalContext) (webrtc.RTPCodecParameters, error) { + d.bindLock.Lock() + if d.bindState.Load() != bindStateUnbound { + d.bindLock.Unlock() + return webrtc.RTPCodecParameters{}, errDownTrackAlreadyBound + } + + // the TrackLocalContext's codec parameters will be set to the bound codec after Bind returns, + // so keep a copy of the codec parameters here to use it later + d.negotiatedCodecParameters = append([]webrtc.RTPCodecParameters{}, t.CodecParameters()...) + var codec, matchedUpstreamCodec webrtc.RTPCodecParameters + for _, c := range d.upstreamCodecs { + matchCodec, err := utils.CodecParametersFuzzySearch(c, d.negotiatedCodecParameters) + if err == nil { + codec = matchCodec + matchedUpstreamCodec = c + break + } else { + // for encrypyted tracks, should match on primary codec, + // i. e. codec at index 0 if the combination of upstream codecs is opus and RED + if d.params.IsEncrypted { + isRedAndOpus := true + for _, u := range d.upstreamCodecs { + if !mime.IsMimeTypeStringOpus(u.MimeType) || !mime.IsMimeTypeStringRED(u.MimeType) { + isRedAndOpus = false + break + } + } + if isRedAndOpus { + break + } + } + } + } + + if codec.MimeType == "" { + err := webrtc.ErrUnsupportedCodec + onBinding := d.onBinding + d.bindLock.Unlock() + d.params.Logger.Infow( + "bind error for unsupported codec", + "codecs", d.upstreamCodecs, + "remoteParameters", d.negotiatedCodecParameters, + ) + if onBinding != nil { + onBinding(err) + } + // don't return error here, as pion will not start transports if Bind fails at first answer + return webrtc.RTPCodecParameters{}, nil + } + + // if a downtrack is closed before bind, it already unsubscribed from client, don't do subsequent operation and return here. + if d.IsClosed() { + d.params.Logger.Debugw("DownTrack closed before bind") + d.bindLock.Unlock() + return codec, nil + } + + // Bind is called under RTPSender.mu lock, + // call the RTPSender.GetParameters (which setRTPHeaderExtensions invokes) + // in goroutine to avoid deadlock + go d.setRTPHeaderExtensions() + + doBind := func() { + d.bindLock.Lock() + if d.IsClosed() { + d.bindLock.Unlock() + d.params.Logger.Debugw("DownTrack closed before bind") + return + } + + isFECEnabled := false + if mime.IsMimeTypeStringRED(matchedUpstreamCodec.MimeType) { + d.isRED = true + for _, c := range d.upstreamCodecs { + isFECEnabled = strings.Contains(strings.ToLower(c.SDPFmtpLine), "useinbandfec=1") + + // assume upstream primary codec is opus since we only support it for audio now + if mime.IsMimeTypeStringOpus(c.MimeType) { + d.upstreamPrimaryPT = uint8(c.PayloadType) + break + } + } + if d.upstreamPrimaryPT == 0 { + d.params.Logger.Errorw( + "failed to find upstream primary opus payload type for RED", nil, + "matchedCodec", codec, + "upstreamCodec", d.upstreamCodecs, + ) + } + + var primaryPT, secondaryPT int + if n, err := fmt.Sscanf(codec.SDPFmtpLine, "%d/%d", &primaryPT, &secondaryPT); err != nil || n != 2 { + d.params.Logger.Errorw( + "failed to parse primary and secondary payload type for RED", err, + "matchedCodec", codec, + ) + } + d.primaryPT = uint8(primaryPT) + } else if mime.IsMimeTypeStringAudio(matchedUpstreamCodec.MimeType) { + isFECEnabled = strings.Contains(strings.ToLower(matchedUpstreamCodec.SDPFmtpLine), "fec") + } + + logFields := []any{ + "codecs", d.upstreamCodecs, + "matchCodec", codec, + "ssrc", t.SSRC(), + "ssrcRTX", t.SSRCRetransmission(), + "isFECEnabled", isFECEnabled, + } + if d.isRED { + logFields = append( + logFields, + "isRED", d.isRED, + "upstreamPrimaryPT", d.upstreamPrimaryPT, + "primaryPT", d.primaryPT, + ) + } + + d.ssrc = uint32(t.SSRC()) + d.ssrcRTX = uint32(t.SSRCRetransmission()) + d.payloadType.Store(uint32(codec.PayloadType)) + d.payloadTypeRTX.Store(uint32(utils.FindRTXPayloadType(codec.PayloadType, d.negotiatedCodecParameters))) + logFields = append( + logFields, + "payloadType", d.payloadType.Load(), + "payloadTypeRTX", d.payloadTypeRTX.Load(), + "codecParameters", d.negotiatedCodecParameters, + ) + d.params.Logger.Debugw("DownTrack.Bind", logFields...) + + d.writeStream = t.WriteStream() + if rr := d.params.BufferFactory.GetOrNew(packetio.RTCPBufferPacket, d.ssrc).(*buffer.RTCPReader); rr != nil { + rr.OnPacket(func(pkt []byte) { + d.handleRTCP(pkt) + }) + d.rtcpReader = rr + } + if d.ssrcRTX != 0 { + if rr := d.params.BufferFactory.GetOrNew(packetio.RTCPBufferPacket, d.ssrcRTX).(*buffer.RTCPReader); rr != nil { + rr.OnPacket(func(pkt []byte) { + d.handleRTCPRTX(pkt) + }) + d.rtcpReaderRTX = rr + } + } + + d.sequencer = newSequencer(d.params.MaxTrack, d.kind == webrtc.RTPCodecTypeVideo, d.params.Logger) + + d.codec.Store(codec.RTPCodecCapability) + if d.onBinding != nil { + d.onBinding(nil) + } + d.setBindStateLocked(bindStateBound) + d.bindLock.Unlock() + + receiver := d.Receiver() + d.forwarder.DetermineCodec(codec.RTPCodecCapability, receiver.HeaderExtensions(), receiver.VideoLayerMode()) + d.connectionStats.Start(d.Mime(), isFECEnabled) + d.params.Logger.Debugw("downtrack bound") + } + + isReceiverReady := d.isReceiverReady + if !isReceiverReady { + d.params.Logger.Debugw("downtrack bound: receiver not ready", "codec", codec) + d.bindOnReceiverReady = doBind + d.setBindStateLocked(bindStateWaitForReceiverReady) + } + d.bindLock.Unlock() + + d.params.Listener.OnCodecNegotiated(codec.RTPCodecCapability) + + if isReceiverReady { + doBind() + } + return codec, nil +} + +func (d *DownTrack) setBindStateLocked(state bindState) { + if d.bindState.Swap(state) == state { + return + } + + if state == bindStateBound || state == bindStateUnbound { + d.bindOnReceiverReady = nil + d.onBindAndConnectedChange() + } +} + +func (d *DownTrack) handleReceiverReady() { + d.bindLock.Lock() + if d.isReceiverReady { + d.bindLock.Unlock() + return + } + d.params.Logger.Debugw("downtrack receiver ready") + d.isReceiverReady = true + doBind := d.bindOnReceiverReady + d.bindOnReceiverReady = nil + d.bindLock.Unlock() + + if doBind != nil { + doBind() + } +} + +func (d *DownTrack) handleUpstreamCodecChange(mimeType string) { + d.bindLock.Lock() + existingMimeType := d.codec.Load().(webrtc.RTPCodecCapability).MimeType + if mime.IsMimeTypeStringEqual(existingMimeType, mimeType) { + d.bindLock.Unlock() + return + } + + if !d.params.SupportsCodecChange { + d.bindLock.Unlock() + d.params.Logger.Infow("client doesn't support codec change, renegotiate new codec") + go d.Close() + return + } + + oldPT, oldRtxPT, oldCodec := d.payloadType.Load(), d.payloadTypeRTX.Load(), d.codec.Load().(webrtc.RTPCodecCapability) + + var codec webrtc.RTPCodecParameters + for _, c := range d.upstreamCodecs { + if !mime.IsMimeTypeStringEqual(c.MimeType, mimeType) { + continue + } + + matchCodec, err := utils.CodecParametersFuzzySearch(c, d.negotiatedCodecParameters) + if err == nil { + codec = matchCodec + break + } + } + + if codec.MimeType == "" { + // codec not found, should not happen since the upstream codec should only fall back to higher compatibility (vp8) + d.params.Logger.Errorw( + "can't find matched codec for new upstream payload type", nil, + "upstreamCodecs", d.upstreamCodecs, + "remoteParameters", d.negotiatedCodecParameters, + "mime", mimeType, + ) + d.bindLock.Unlock() + return + } + + d.payloadType.Store(uint32(codec.PayloadType)) + d.payloadTypeRTX.Store(uint32(utils.FindRTXPayloadType(codec.PayloadType, d.negotiatedCodecParameters))) + d.codec.Store(codec.RTPCodecCapability) + isFECEnabled := strings.Contains(strings.ToLower(codec.SDPFmtpLine), "fec") + d.bindLock.Unlock() + + d.params.Logger.Infow( + "upstream codec changed", + "oldPT", oldPT, "newPT", d.payloadType.Load(), + "oldRTXPT", oldRtxPT, "newRTXPT", d.payloadTypeRTX.Load(), + "oldCodec", oldCodec, "newCodec", codec.RTPCodecCapability, + ) + + receiver := d.Receiver() + d.forwarder.Restart() + d.forwarder.DetermineCodec(codec.RTPCodecCapability, receiver.HeaderExtensions(), receiver.VideoLayerMode()) + + d.connectionStats.UpdateCodec(d.Mime(), isFECEnabled) +} + +// Unbind implements the teardown logic when the track is no longer needed. This happens +// because a track has been stopped. +func (d *DownTrack) Unbind(_ webrtc.TrackLocalContext) error { + d.bindLock.Lock() + d.setBindStateLocked(bindStateUnbound) + d.bindLock.Unlock() + return nil +} + +func (d *DownTrack) SetStreamAllocatorListener(listener DownTrackStreamAllocatorListener) { + d.streamAllocatorLock.Lock() + d.streamAllocatorListener = listener + d.streamAllocatorLock.Unlock() + + d.setRTPHeaderExtensions() + + if listener != nil { + // kick off a gratuitous allocation + listener.OnSubscriptionChanged(d) + } +} + +func (d *DownTrack) getStreamAllocatorListener() DownTrackStreamAllocatorListener { + d.streamAllocatorLock.RLock() + defer d.streamAllocatorLock.RUnlock() + + return d.streamAllocatorListener +} + +func (d *DownTrack) SetProbeClusterId(probeClusterId ccutils.ProbeClusterId) { + d.probeClusterId.Store(uint32(probeClusterId)) +} + +func (d *DownTrack) SwapProbeClusterId(match ccutils.ProbeClusterId, swap ccutils.ProbeClusterId) { + d.probeClusterId.CompareAndSwap(uint32(match), uint32(swap)) +} + +// ID is the unique identifier for this Track. This should be unique for the +// stream, but doesn't have to globally unique. A common example would be 'audio' or 'video' +// and StreamID would be 'desktop' or 'webcam' +func (d *DownTrack) ID() string { return string(d.id) } + +// Codec returns current track codec capability +func (d *DownTrack) Codec() webrtc.RTPCodecCapability { + return d.codec.Load().(webrtc.RTPCodecCapability) +} + +func (d *DownTrack) Mime() mime.MimeType { + return mime.NormalizeMimeType(d.codec.Load().(webrtc.RTPCodecCapability).MimeType) +} + +// StreamID is the group this track belongs too. This must be unique +func (d *DownTrack) StreamID() string { return d.params.StreamID } + +func (d *DownTrack) SubscriberID() livekit.ParticipantID { + // add `createdAt` to ensure repeated subscriptions from same subscriber to same publisher does not collide + return livekit.ParticipantID(fmt.Sprintf("%s:%d", d.params.SubID, d.createdAt)) +} + +func (d *DownTrack) Receiver() TrackReceiver { + d.receiverLock.RLock() + defer d.receiverLock.RUnlock() + return d.receiver +} + +func (d *DownTrack) SetReceiver(r TrackReceiver) { + d.params.Logger.Debugw("downtrack set receiver", "codec", r.Codec()) + d.bindLock.Lock() + if d.IsClosed() { + d.bindLock.Unlock() + return + } + + d.receiverLock.Lock() + old := d.receiver + d.receiver = r + d.receiverLock.Unlock() + + old.DeleteDownTrack(d.SubscriberID()) + d.bindLock.Unlock() + + r.AddOnReady(d.handleReceiverReady) + d.handleUpstreamCodecChange(r.Codec().MimeType) + + d.bindLock.Lock() + if err := r.AddDownTrack(d); err != nil { + d.params.Logger.Warnw("failed to add downtrack to receiver", err) + } + d.bindLock.Unlock() + + if sal := d.getStreamAllocatorListener(); sal != nil { + sal.OnSubscribedLayerChanged(d, d.forwarder.MaxLayer()) + } +} + +// Sets RTP header extensions for this track +func (d *DownTrack) setRTPHeaderExtensions() { + sal := d.getStreamAllocatorListener() + if sal == nil { + return + } + isBWEEnabled := sal.IsBWEEnabled(d) + bweType := sal.BWEType() + + tr := d.transceiver.Load() + if tr == nil { + return + } + var extensions []webrtc.RTPHeaderExtensionParameter + if sender := tr.Sender(); sender != nil { + extensions = sender.GetParameters().HeaderExtensions + d.params.Logger.Debugw("negotiated downtrack extensions", "extensions", extensions) + } + + d.bindLock.Lock() + for _, ext := range extensions { + switch ext.URI { + case sdp.ABSSendTimeURI: + if isBWEEnabled && bweType == bwe.BWETypeRemote { + d.absSendTimeExtID = ext.ID + } else { + d.absSendTimeExtID = 0 + } + case dd.ExtensionURI: + d.dependencyDescriptorExtID = ext.ID + case pd.PlayoutDelayURI: + d.playoutDelayExtID = ext.ID + case sdp.TransportCCURI: + if isBWEEnabled && bweType == bwe.BWETypeSendSide { + d.transportWideExtID = ext.ID + } else { + d.transportWideExtID = 0 + } + case act.AbsCaptureTimeURI: + d.absCaptureTimeExtID = ext.ID + } + } + d.params.Logger.Debugw( + "negotiated extension ids", + "absSendTimeExtID", d.absSendTimeExtID, + "dependencyDescriptorExtID", d.dependencyDescriptorExtID, + "playoutDelayExtID", d.playoutDelayExtID, + "transportWideExtID", d.transportWideExtID, + "absCaptureTimeExtID", d.absCaptureTimeExtID, + ) + d.bindLock.Unlock() +} + +// Kind controls if this TrackLocal is audio or video +func (d *DownTrack) Kind() webrtc.RTPCodecType { + return d.kind +} + +// RID is required by `webrtc.TrackLocal` interface +func (d *DownTrack) RID() string { + return "" +} + +func (d *DownTrack) SSRC() uint32 { + return d.ssrc +} + +func (d *DownTrack) SSRCRTX() uint32 { + return d.ssrcRTX +} + +func (d *DownTrack) SetTransceiver(transceiver *webrtc.RTPTransceiver) { + d.transceiver.Store(transceiver) + d.setRTPHeaderExtensions() +} + +func (d *DownTrack) GetTransceiver() *webrtc.RTPTransceiver { + return d.transceiver.Load() +} + +func (d *DownTrack) postKeyFrameRequestEvent() { + if d.kind != webrtc.RTPCodecTypeVideo { + return + } + + d.keyFrameRequesterChMu.RLock() + if !d.keyFrameRequesterChClosed { + select { + case d.keyFrameRequesterCh <- struct{}{}: + default: + } + } + d.keyFrameRequesterChMu.RUnlock() +} + +func (d *DownTrack) keyFrameRequester() { + getInterval := func() time.Duration { + interval := min(max(2*d.rtpStats.GetRtt(), keyFrameIntervalMin), keyFrameIntervalMax) + return time.Duration(interval) * time.Millisecond + } + + timer := time.NewTimer(math.MaxInt64) + timer.Stop() + + defer timer.Stop() + + for !d.IsClosed() { + timer.Reset(getInterval()) + + select { + case _, more := <-d.keyFrameRequesterCh: + if !more { + return + } + if !timer.Stop() { + <-timer.C + } + case <-timer.C: + } + + locked, layer := d.forwarder.CheckSync() + if !locked && layer != buffer.InvalidLayerSpatial && d.writable.Load() { + d.params.Logger.Debugw("sending PLI for layer lock", "layer", layer) + d.Receiver().SendPLI(layer, false) + d.rtpStats.UpdateLayerLockPliAndTime(1) + } + } +} + +func (d *DownTrack) postMaxLayerNotifierEvent(event string) { + if d.kind != webrtc.RTPCodecTypeVideo { + return + } + + d.maxLayerNotifierChMu.RLock() + if !d.maxLayerNotifierChClosed { + select { + case d.maxLayerNotifierCh <- event: + default: + d.params.Logger.Debugw("max layer notifier channel busy", "event", event) + } + } + d.maxLayerNotifierChMu.RUnlock() +} + +func (d *DownTrack) maxLayerNotifierWorker() { + for event := range d.maxLayerNotifierCh { + maxLayerSpatial := d.forwarder.GetMaxSubscribedSpatial() + d.params.Logger.Debugw("max subscribed layer processed", "layer", maxLayerSpatial, "event", event) + + d.params.Logger.Debugw( + "notifying max subscribed layer", + "layer", maxLayerSpatial, + "event", event, + ) + d.params.Listener.OnMaxSubscribedLayerChanged(maxLayerSpatial) + } + + d.params.Logger.Debugw( + "notifying max subscribed layer", + "layer", buffer.InvalidLayerSpatial, + "event", "close", + ) + d.params.Listener.OnMaxSubscribedLayerChanged(buffer.InvalidLayerSpatial) +} + +// WriteRTP writes an RTP Packet to the DownTrack +func (d *DownTrack) WriteRTP(extPkt *buffer.ExtPacket, layer int32) int32 { + if !d.writable.Load() { + return 0 + } + + tp, err := d.forwarder.GetTranslationParams(extPkt, layer) + if tp.shouldDrop { + if err != nil { + d.params.Logger.Errorw("could not get translation params", err) + } + return 0 + } + + poolEntity := PacketFactory.Get().(*[]byte) + payload := *poolEntity + copy(payload, tp.codecBytes) + n := copy(payload[len(tp.codecBytes):], extPkt.Packet.Payload[tp.incomingHeaderSize:]) + if n != len(extPkt.Packet.Payload[tp.incomingHeaderSize:]) { + d.params.Logger.Errorw( + "payload overflow", errPayloadOverflow, + "want", len(extPkt.Packet.Payload[tp.incomingHeaderSize:]), + "have", n, + ) + PacketFactory.Put(poolEntity) + return 0 + } + payload = payload[:len(tp.codecBytes)+n] + + // translate RTP header + hdr := RTPHeaderFactory.Get().(*rtp.Header) + *hdr = rtp.Header{ + Version: extPkt.Packet.Version, + Padding: extPkt.Packet.Padding, + Marker: tp.marker, + PayloadType: d.getTranslatedPayloadType(extPkt.Packet.PayloadType), + SequenceNumber: uint16(tp.rtp.extSequenceNumber), + Timestamp: uint32(tp.rtp.extTimestamp), + SSRC: d.ssrc, + } + + // add extensions + if d.dependencyDescriptorExtID != 0 && tp.ddBytes != nil { + hdr.SetExtension(uint8(d.dependencyDescriptorExtID), tp.ddBytes) + } + if d.playoutDelayExtID != 0 && d.playoutDelay != nil { + if val := d.playoutDelay.GetDelayExtension(hdr.SequenceNumber); val != nil { + hdr.SetExtension(uint8(d.playoutDelayExtID), val) + + // NOTE: play out delay extension is not cached in sequencer, + // i. e. they will not be added to retransmitted packet. + // But, it is okay as the extension is added till a RTCP Receiver Report for + // the corresponding sequence number is received. + // The extreme case is all packets containing the play out delay are lost and + // all of them retransmitted and an RTCP Receiver Report received for those + // retransmitted sequence numbers. But, that is highly improbable, if not impossible. + } + } + var actBytes []byte + if extPkt.AbsCaptureTimeExt != nil && d.absCaptureTimeExtID != 0 { + // normalize capture time to SFU clock. + // NOTE: even if there is estimated offset populated, just re-map the + // absolute capture time stamp as it should be the same RTCP sender report + // clock domain of publisher. SFU is normalising sender reports of publisher + // to SFU clock before sending to subscribers. So, capture time should be + // normalized to the same clock. Clear out any offset. + _, _, _, refSenderReport := d.forwarder.GetSenderReportParams() + if refSenderReport != nil { + actExtCopy := *extPkt.AbsCaptureTimeExt + if err = actExtCopy.Rewrite( + rtpstats.RTCPSenderReportPropagationDelay( + refSenderReport, + !d.params.DisableSenderReportPassThrough, + ), + ); err == nil { + actBytes, err = actExtCopy.Marshal() + if err == nil { + hdr.SetExtension(uint8(d.absCaptureTimeExtID), actBytes) + } + } + } + } + d.addDummyExtensions(hdr) + + if d.sequencer != nil { + d.sequencer.push( + extPkt.Arrival, + extPkt.ExtSequenceNumber, + tp.rtp.extSequenceNumber, + tp.rtp.extTimestamp, + hdr.Marker, + int8(layer), + payload[:len(tp.codecBytes)], + tp.incomingHeaderSize, + tp.ddBytes, + actBytes, + ) + } + + headerSize := hdr.MarshalSize() + d.rtpStats.Update( + extPkt.Arrival, + tp.rtp.extSequenceNumber, + tp.rtp.extTimestamp, + hdr.Marker, + headerSize, + len(payload), + 0, + extPkt.IsOutOfOrder, + ) + pacerPacket := pacer.PacketFactory.Get().(*pacer.Packet) + *pacerPacket = pacer.Packet{ + Header: hdr, + HeaderPool: RTPHeaderFactory, + HeaderSize: headerSize, + Payload: payload, + ProbeClusterId: ccutils.ProbeClusterId(d.probeClusterId.Load()), + AbsSendTimeExtID: uint8(d.absSendTimeExtID), + TransportWideExtID: uint8(d.transportWideExtID), + WriteStream: d.writeStream, + Pool: PacketFactory, + PoolEntity: poolEntity, + } + d.pacer.Enqueue(pacerPacket) + + if extPkt.IsKeyFrame { + d.isNACKThrottled.Store(false) + d.rtpStats.UpdateKeyFrame(1) + d.params.Logger.Debugw( + "forwarded key frame", + "layer", layer, + "rtpsn", tp.rtp.extSequenceNumber, + "rtpts", tp.rtp.extTimestamp, + ) + } + + if tp.isSwitching { + d.postMaxLayerNotifierEvent("switching") + } + + if tp.isResuming { + if sal := d.getStreamAllocatorListener(); sal != nil { + sal.OnResume(d) + } + } + return 1 +} + +// WritePaddingRTP tries to write as many padding only RTP packets as necessary +// to satisfy given size to the DownTrack +func (d *DownTrack) WritePaddingRTP(bytesToSend int, paddingOnMute bool, forceMarker bool) int { + if !d.writable.Load() { + return 0 + } + + if !paddingOnMute { + if !d.rtpStats.IsActive() { + return 0 + } + + // Ideally should look at header extensions negotiated for + // track and decide if padding can be sent. But, browsers behave + // in unexpected ways when using audio for bandwidth estimation and + // padding is mainly used to probe for excess available bandwidth. + // So, to be safe, limit to video tracks + if d.kind == webrtc.RTPCodecTypeAudio { + return 0 + } + + // LK-TODO-START + // Potentially write padding even if muted. Given that padding + // can be sent only on frame boundaries, writing on disabled tracks + // will give more options. + // LK-TODO-END + if d.forwarder.IsMuted() { + return 0 + } + + // Hold sending padding packets till first RTCP-RR is received for this RTP stream. + // That is definitive proof that the remote side knows about this RTP stream. + if d.rtpStats.LastReceiverReportTime() == 0 { + return 0 + } + } + + // RTP padding maximum is 255 bytes. Break it up. + // Use 20 byte as estimate of RTP header size (12 byte header + 8 byte extension) + num := (bytesToSend + RTPPaddingMaxPayloadSize + RTPPaddingEstimatedHeaderSize - 1) / (RTPPaddingMaxPayloadSize + RTPPaddingEstimatedHeaderSize) + if num == 0 { + return 0 + } + + frameRate := uint32(0) + if paddingOnMute { + // advance timestamps when sending dummy padding packets to start a stream + // to ensure receiver sees proper timestamp and starts the stream + frameRate = uint32(time.Second / paddingOnMuteInterval) + } + + snts, err := d.forwarder.GetSnTsForPadding(num, frameRate, forceMarker) + if err != nil { + return 0 + } + + // + // Register with sequencer as padding only so that NACKs for these can be filtered out. + // Retransmission is probably a sign of network congestion/badness. + // So, retransmitting padding only packets is only going to make matters worse. + // + if d.sequencer != nil { + d.sequencer.pushPadding(snts[0].extSequenceNumber, snts[len(snts)-1].extSequenceNumber) + } + + bytesSent := 0 + payloads := make([]byte, RTPPaddingMaxPayloadSize*len(snts)) + for i := range snts { + hdr := RTPHeaderFactory.Get().(*rtp.Header) + *hdr = rtp.Header{ + Version: 2, + Padding: true, + Marker: false, + PayloadType: uint8(d.payloadType.Load()), + SequenceNumber: uint16(snts[i].extSequenceNumber), + Timestamp: uint32(snts[i].extTimestamp), + SSRC: d.ssrc, + } + d.addDummyExtensions(hdr) + + payload := payloads[i*RTPPaddingMaxPayloadSize : (i+1)*RTPPaddingMaxPayloadSize : (i+1)*RTPPaddingMaxPayloadSize] + // last byte of padding has padding size including that byte + payload[RTPPaddingMaxPayloadSize-1] = byte(RTPPaddingMaxPayloadSize) + + hdrSize := hdr.MarshalSize() + payloadSize := len(payload) + d.rtpStats.Update( + mono.UnixNano(), + snts[i].extSequenceNumber, + snts[i].extTimestamp, + hdr.Marker, + hdrSize, + 0, + payloadSize, + false, + ) + + pacerPacket := pacer.PacketFactory.Get().(*pacer.Packet) + *pacerPacket = pacer.Packet{ + Header: hdr, + HeaderPool: RTPHeaderFactory, + HeaderSize: hdrSize, + Payload: payload, + ProbeClusterId: ccutils.ProbeClusterId(d.probeClusterId.Load()), + IsProbe: true, + AbsSendTimeExtID: uint8(d.absSendTimeExtID), + TransportWideExtID: uint8(d.transportWideExtID), + WriteStream: d.writeStream, + } + d.pacer.Enqueue(pacerPacket) + + bytesSent += hdrSize + payloadSize + } + + return bytesSent +} + +// Mute enables or disables media forwarding - subscriber triggered +func (d *DownTrack) Mute(muted bool) { + isSubscribeMutable := true + if sal := d.getStreamAllocatorListener(); sal != nil { + isSubscribeMutable = sal.IsSubscribeMutable(d) + } + changed := d.forwarder.Mute(muted, isSubscribeMutable) + d.handleMute(muted, changed) +} + +// PubMute enables or disables media forwarding - publisher side +func (d *DownTrack) PubMute(pubMuted bool) { + changed := d.forwarder.PubMute(pubMuted) + d.handleMute(pubMuted, changed) +} + +func (d *DownTrack) handleMute(muted bool, changed bool) { + if !changed { + return + } + + d.connectionStats.UpdateMute(d.forwarder.IsAnyMuted()) + + // + // Subscriber mute changes trigger a max layer notification. + // That could result in encoding layers getting turned on/off on publisher side + // (depending on aggregate layer requirements of all subscribers of the track). + // + // Publisher mute changes should not trigger notification. + // If publisher turns off all layers because of subscribers indicating + // no layers required due to publisher mute (bit of circular dependency), + // there will be a delay in layers turning back on when unmute happens. + // Unmute path will require + // 1. unmute signalling out-of-band from publisher received by downtrack(s) + // 2. downtrack(s) notifying max layer + // 3. out-of-band notification about max layer sent back to the publisher + // 4. publisher starts layer(s) + // Ideally, on publisher mute, whatever layers were active remain active and + // can be restarted by publisher immediately on unmute. + // + // Note that while publisher mute is active, subscriber changes can also happen + // and that could turn on/off layers on publisher side. + // + d.postMaxLayerNotifierEvent("mute") + + if sal := d.getStreamAllocatorListener(); sal != nil { + sal.OnSubscriptionChanged(d) + } + + // when muting, send a few silence frames to ensure residual noise does not + // put the comfort noise generator on decoder side in a bad state where it + // generates noise that is not so comfortable. + // + // One possibility is not to inject blank frames when publisher is muted + // and let forwarding continue. When publisher is muted, unless the media + // stream is stopped, publisher will send silence frames which should have + // comfort noise information. But, in case the publisher stops at an + // inopportune frame (due to media stream stop or injecting audio from a file), + // the decoder could be in a noisy state. So, inject blank frames on publisher + // mute too. + d.blankFramesGeneration.Inc() + if d.kind == webrtc.RTPCodecTypeAudio && muted { + d.writeBlankFrameRTP(RTPBlankFramesMuteSeconds, d.blankFramesGeneration.Load()) + } +} + +func (d *DownTrack) IsClosed() bool { + return d.isClosed.Load() +} + +func (d *DownTrack) Close() { + d.CloseWithFlush(true, true) +} + +// CloseWithFlush - flush used to indicate whether send blank frame to flush +// decoder of client. +// 1. When transceiver is reused by other participant's video track, +// set flush=true to avoid previous video shows before new stream is displayed. +// 2. in case of session migration, participant migrate from other node, video track should +// be resumed with same participant, set flush=false since we don't need to flush decoder. +func (d *DownTrack) CloseWithFlush(flush bool, isEnding bool) { + d.bindLock.Lock() + if d.isClosed.Swap(true) { + // already closed + d.bindLock.Unlock() + return + } + + d.params.Logger.Debugw("close downtrack", "flushBlankFrame", flush) + if d.bindState.Load() == bindStateBound { + d.forwarder.Mute(true, true) + + // write blank frames after disabling so that other frames do not interfere. + // Idea here is to send blank key frames to flush the decoder buffer at the remote end. + // Otherwise, with transceiver re-use last frame from previous stream is held in the + // display buffer and there could be a brief moment where the previous stream is displayed. + if flush { + doneFlushing := d.writeBlankFrameRTP(RTPBlankFramesCloseSeconds, d.blankFramesGeneration.Inc()) + + // wait a limited time to flush + timer := time.NewTimer(flushTimeout) + defer timer.Stop() + + select { + case <-doneFlushing: + case <-timer.C: + d.blankFramesGeneration.Inc() // in case flush is still running + } + } + + d.params.Logger.Debugw("closing sender", "kind", d.kind) + } + + d.setBindStateLocked(bindStateUnbound) + d.Receiver().DeleteDownTrack(d.SubscriberID()) + + if d.rtcpReader != nil && isEnding { + d.params.Logger.Debugw("downtrack close rtcp reader") + d.rtcpReader.Close() + d.rtcpReader.OnPacket(nil) + } + if d.rtcpReaderRTX != nil && isEnding { + d.params.Logger.Debugw("downtrack close rtcp rtx reader") + d.rtcpReaderRTX.Close() + d.rtcpReaderRTX.OnPacket(nil) + } + d.bindLock.Unlock() + + d.connectionStats.Close() + + d.rtpStats.Stop() + d.rtpStatsRTX.Stop() + d.params.Logger.Debugw( + "rtp stats", + "direction", "downstream", + "mime", d.Mime().String(), + "ssrc", d.ssrc, + "stats", d.rtpStats, + "statsRTX", d.rtpStatsRTX, + ) + + d.maxLayerNotifierChMu.Lock() + d.maxLayerNotifierChClosed = true + close(d.maxLayerNotifierCh) + d.maxLayerNotifierChMu.Unlock() + + d.keyFrameRequesterChMu.Lock() + d.keyFrameRequesterChClosed = true + close(d.keyFrameRequesterCh) + d.keyFrameRequesterChMu.Unlock() + + d.params.Listener.OnDownTrackClose(!isEnding) +} + +func (d *DownTrack) SetMaxSpatialLayer(spatialLayer int32) { + changed, maxLayer := d.forwarder.SetMaxSpatialLayer(spatialLayer) + if !changed { + return + } + + d.postMaxLayerNotifierEvent("max-subscribed") + d.postKeyFrameRequestEvent() + + if sal := d.getStreamAllocatorListener(); sal != nil { + sal.OnSubscribedLayerChanged(d, maxLayer) + } +} + +func (d *DownTrack) SetMaxTemporalLayer(temporalLayer int32) { + changed, maxLayer := d.forwarder.SetMaxTemporalLayer(temporalLayer) + if !changed { + return + } + + if sal := d.getStreamAllocatorListener(); sal != nil { + sal.OnSubscribedLayerChanged(d, maxLayer) + } +} + +func (d *DownTrack) MaxLayer() buffer.VideoLayer { + return d.forwarder.MaxLayer() +} + +func (d *DownTrack) GetState() DownTrackState { + dts := DownTrackState{ + RTPStats: d.rtpStats, + DeltaStatsSenderSnapshotId: d.deltaStatsSenderSnapshotId, + RTPStatsRTX: d.rtpStatsRTX, + DeltaStatsRTXSenderSnapshotId: d.deltaStatsRTXSenderSnapshotId, + ForwarderState: d.forwarder.GetState(), + } + + if d.playoutDelay != nil { + dts.PlayoutDelayControllerState = d.playoutDelay.GetState() + } + return dts +} + +func (d *DownTrack) SeedState(state DownTrackState) { + if d.writable.Load() { + return + } + + if state.RTPStats != nil || state.ForwarderState != nil { + d.params.Logger.Debugw("seeding downtrack state", "state", state) + } + if state.RTPStats != nil { + d.rtpStats.Seed(state.RTPStats) + d.deltaStatsSenderSnapshotId = state.DeltaStatsSenderSnapshotId + if d.playoutDelay != nil { + d.playoutDelay.SeedState(state.PlayoutDelayControllerState) + } + } + if state.RTPStatsRTX != nil { + d.rtpStatsRTX.Seed(state.RTPStatsRTX) + d.deltaStatsRTXSenderSnapshotId = state.DeltaStatsRTXSenderSnapshotId + + d.rtxSequenceNumber.Store(d.rtpStatsRTX.ExtHighestSequenceNumber()) + } + d.forwarder.SeedState(state.ForwarderState) +} + +func (d *DownTrack) StopWriteAndGetState() DownTrackState { + d.params.Logger.Debugw("stopping write") + d.bindLock.Lock() + d.writable.Store(false) + d.writeStopped.Store(true) + d.bindLock.Unlock() + + return d.GetState() +} + +func (d *DownTrack) UpTrackLayersChange() { + if sal := d.getStreamAllocatorListener(); sal != nil { + sal.OnAvailableLayersChanged(d) + } +} + +func (d *DownTrack) UpTrackBitrateAvailabilityChange() { + if sal := d.getStreamAllocatorListener(); sal != nil { + sal.OnBitrateAvailabilityChanged(d) + } +} + +func (d *DownTrack) UpTrackMaxPublishedLayerChange(maxPublishedLayer int32) { + if d.forwarder.SetMaxPublishedLayer(maxPublishedLayer) { + if sal := d.getStreamAllocatorListener(); sal != nil { + sal.OnMaxPublishedSpatialChanged(d) + } + } +} + +func (d *DownTrack) UpTrackMaxTemporalLayerSeenChange(maxTemporalLayerSeen int32) { + if d.forwarder.SetMaxTemporalLayerSeen(maxTemporalLayerSeen) { + if sal := d.getStreamAllocatorListener(); sal != nil { + sal.OnMaxPublishedTemporalChanged(d) + } + } +} + +func (d *DownTrack) maybeAddTransition(bitrate int64, distance float64, pauseReason VideoPauseReason) { + if d.kind == webrtc.RTPCodecTypeAudio { + return + } + + if pauseReason == VideoPauseReasonBandwidth { + d.connectionStats.UpdatePause(true) + } else { + d.connectionStats.UpdatePause(false) + d.connectionStats.AddLayerTransition(distance) + d.connectionStats.AddBitrateTransition(bitrate) + } +} + +func (d *DownTrack) UpTrackBitrateReport(availableLayers []int32, bitrates Bitrates) { + d.maybeAddTransition( + d.forwarder.GetOptimalBandwidthNeeded(bitrates), + d.forwarder.DistanceToDesired(availableLayers, bitrates), + d.forwarder.PauseReason(), + ) +} + +func (d *DownTrack) OnBinding(fn func(error)) { + d.bindLock.Lock() + defer d.bindLock.Unlock() + + d.onBinding = fn +} + +func (d *DownTrack) AddReceiverReportListener(listener ReceiverReportListener) { + d.listenerLock.Lock() + defer d.listenerLock.Unlock() + + d.receiverReportListeners = append(d.receiverReportListeners, listener) +} + +func (d *DownTrack) IsDeficient() bool { + return d.forwarder.IsDeficient() +} + +func (d *DownTrack) BandwidthRequested() int64 { + _, brs := d.Receiver().GetLayeredBitrate() + return d.forwarder.BandwidthRequested(brs) +} + +func (d *DownTrack) DistanceToDesired() float64 { + al, brs := d.Receiver().GetLayeredBitrate() + return d.forwarder.DistanceToDesired(al, brs) +} + +func (d *DownTrack) AllocateOptimal(allowOvershoot bool, hold bool) VideoAllocation { + al, brs := d.Receiver().GetLayeredBitrate() + allocation := d.forwarder.AllocateOptimal(al, brs, allowOvershoot, hold) + d.postKeyFrameRequestEvent() + d.maybeAddTransition(allocation.BandwidthNeeded, allocation.DistanceToDesired, allocation.PauseReason) + return allocation +} + +func (d *DownTrack) ProvisionalAllocatePrepare() { + al, brs := d.Receiver().GetLayeredBitrate() + d.forwarder.ProvisionalAllocatePrepare(al, brs) +} + +func (d *DownTrack) ProvisionalAllocateReset() { + d.forwarder.ProvisionalAllocateReset() +} + +func (d *DownTrack) ProvisionalAllocate(availableChannelCapacity int64, layers buffer.VideoLayer, allowPause bool, allowOvershoot bool) (bool, int64) { + return d.forwarder.ProvisionalAllocate(availableChannelCapacity, layers, allowPause, allowOvershoot) +} + +func (d *DownTrack) ProvisionalAllocateGetCooperativeTransition(allowOvershoot bool) VideoTransition { + transition, availableLayers, brs := d.forwarder.ProvisionalAllocateGetCooperativeTransition(allowOvershoot) + d.params.Logger.Debugw( + "stream: cooperative transition", + "transition", &transition, + "availableLayers", availableLayers, + "bitrates", brs, + ) + return transition +} + +func (d *DownTrack) ProvisionalAllocateGetBestWeightedTransition() VideoTransition { + transition, availableLayers, brs := d.forwarder.ProvisionalAllocateGetBestWeightedTransition() + d.params.Logger.Debugw( + "stream: best weighted transition", + "transition", &transition, + "availableLayers", availableLayers, + "bitrates", brs, + ) + return transition +} + +func (d *DownTrack) ProvisionalAllocateCommit() VideoAllocation { + allocation := d.forwarder.ProvisionalAllocateCommit() + d.postKeyFrameRequestEvent() + d.maybeAddTransition(allocation.BandwidthNeeded, allocation.DistanceToDesired, allocation.PauseReason) + return allocation +} + +func (d *DownTrack) AllocateNextHigher(availableChannelCapacity int64, allowOvershoot bool) (VideoAllocation, bool) { + al, brs := d.Receiver().GetLayeredBitrate() + allocation, available := d.forwarder.AllocateNextHigher(availableChannelCapacity, al, brs, allowOvershoot) + d.postKeyFrameRequestEvent() + d.maybeAddTransition(allocation.BandwidthNeeded, allocation.DistanceToDesired, allocation.PauseReason) + return allocation, available +} + +func (d *DownTrack) GetNextHigherTransition(allowOvershoot bool) (VideoTransition, bool) { + availableLayers, brs := d.Receiver().GetLayeredBitrate() + transition, available := d.forwarder.GetNextHigherTransition(brs, allowOvershoot) + d.params.Logger.Debugw( + "stream: get next higher layer", + "transition", transition, + "available", available, + "availableLayers", availableLayers, + "bitrates", brs, + ) + return transition, available +} + +func (d *DownTrack) Pause() VideoAllocation { + al, brs := d.Receiver().GetLayeredBitrate() + allocation := d.forwarder.Pause(al, brs) + d.maybeAddTransition(allocation.BandwidthNeeded, allocation.DistanceToDesired, allocation.PauseReason) + return allocation +} + +func (d *DownTrack) Resync() { + d.forwarder.Resync() +} + +func (d *DownTrack) ReceiverRestart() { + d.bindLock.Lock() + codec := d.codec.Load().(webrtc.RTPCodecCapability) + d.bindLock.Unlock() + + d.params.Logger.Infow("upstream receiver restart") + + receiver := d.Receiver() + d.forwarder.Restart() + d.forwarder.DetermineCodec(codec, receiver.HeaderExtensions(), receiver.VideoLayerMode()) +} + +func (d *DownTrack) CreateSourceDescriptionChunks() []rtcp.SourceDescriptionChunk { + transceiver := d.transceiver.Load() + if d.bindState.Load() != bindStateBound || transceiver == nil { + return nil + } + return []rtcp.SourceDescriptionChunk{ + { + Source: d.ssrc, + Items: []rtcp.SourceDescriptionItem{ + { + Type: rtcp.SDESCNAME, + Text: d.params.StreamID, + }, + { + Type: rtcp.SDESType(15), + Text: transceiver.Mid(), + }, + }, + }, + } +} + +func (d *DownTrack) CreateSenderReport() *rtcp.SenderReport { + if d.bindState.Load() != bindStateBound { + return nil + } + + _, _, tsOffset, refSenderReport := d.forwarder.GetSenderReportParams() + return d.rtpStats.GetRtcpSenderReport(d.ssrc, refSenderReport, tsOffset, !d.params.DisableSenderReportPassThrough) + + // not sending RTCP Sender Report for RTX +} + +func (d *DownTrack) writeBlankFrameRTP(duration float32, generation uint32) chan struct{} { + done := make(chan struct{}) + go func() { + // don't send if not writable OR nothing has been sent + if !d.writable.Load() || !d.rtpStats.IsActive() { + close(done) + return + } + + mimeType := d.Mime() + var getBlankFrame func(bool) ([]byte, error) + switch mimeType { + case mime.MimeTypeOpus: + getBlankFrame = d.getOpusBlankFrame + case mime.MimeTypeRED: + getBlankFrame = d.getOpusRedBlankFrame + case mime.MimeTypeVP8: + getBlankFrame = d.getVP8BlankFrame + case mime.MimeTypeH264: + getBlankFrame = d.getH264BlankFrame + default: + close(done) + return + } + + frameRate := uint32(30) + if mimeType == mime.MimeTypeOpus || mimeType == mime.MimeTypeRED { + frameRate = 50 + } + + // send a number of blank frames just in case there is loss. + // Intentionally ignoring check for mute or bandwidth constrained mute + // as this is used to clear client side buffer. + numFrames := int(float32(frameRate) * duration) + frameDuration := time.Duration(1000/frameRate) * time.Millisecond + + ticker := time.NewTicker(frameDuration) + defer ticker.Stop() + + for { + if generation != d.blankFramesGeneration.Load() || numFrames <= 0 || !d.writable.Load() || !d.rtpStats.IsActive() { + close(done) + return + } + + snts, frameEndNeeded, err := d.forwarder.GetSnTsForBlankFrames(frameRate, 1) + if err != nil { + d.params.Logger.Warnw("could not get SN/TS for blank frame", err) + close(done) + return + } + + for i := range snts { + hdr := &rtp.Header{ + Version: 2, + Padding: false, + Marker: true, + PayloadType: uint8(d.payloadType.Load()), + SequenceNumber: uint16(snts[i].extSequenceNumber), + Timestamp: uint32(snts[i].extTimestamp), + SSRC: d.ssrc, + } + d.addDummyExtensions(hdr) + + payload, err := getBlankFrame(frameEndNeeded) + if err != nil { + d.params.Logger.Warnw("could not get blank frame", err) + close(done) + return + } + + headerSize := hdr.MarshalSize() + d.rtpStats.Update( + mono.UnixNano(), + snts[i].extSequenceNumber, + snts[i].extTimestamp, + hdr.Marker, + headerSize, + len(payload), + 0, + false, + ) + pacerPacket := pacer.PacketFactory.Get().(*pacer.Packet) + *pacerPacket = pacer.Packet{ + Header: hdr, + HeaderSize: headerSize, + Payload: payload, + ProbeClusterId: ccutils.ProbeClusterId(d.probeClusterId.Load()), + AbsSendTimeExtID: uint8(d.absSendTimeExtID), + TransportWideExtID: uint8(d.transportWideExtID), + WriteStream: d.writeStream, + } + d.pacer.Enqueue(pacerPacket) + + // only the first frame will need frameEndNeeded to close out the + // previous picture, rest are small key frames (for the video case) + frameEndNeeded = false + } + + numFrames-- + <-ticker.C + } + }() + + return done +} + +func (d *DownTrack) maybeAddTrailer(buf []byte) int { + if len(buf) < len(d.params.Trailer) { + d.params.Logger.Warnw("trailer too big", nil, "bufLen", len(buf), "trailerLen", len(d.params.Trailer)) + return 0 + } + + copy(buf, d.params.Trailer) + return len(d.params.Trailer) +} + +func (d *DownTrack) getOpusBlankFrame(_frameEndNeeded bool) ([]byte, error) { + // silence frame + // Used shortly after muting to ensure residual noise does not keep + // generating noise at the decoder after the stream is stopped + // i. e. comfort noise generation actually not producing something comfortable. + payload := make([]byte, 1000) + copy(payload[0:], OpusSilenceFrame) + trailerLen := d.maybeAddTrailer(payload[len(OpusSilenceFrame):]) + return payload[:len(OpusSilenceFrame)+trailerLen], nil +} + +func (d *DownTrack) getOpusRedBlankFrame(_frameEndNeeded bool) ([]byte, error) { + // primary only silence frame for opus/red, there is no need to contain redundant silent frames + payload := make([]byte, 1000) + + // primary header + // 0 1 2 3 4 5 6 7 + // +-+-+-+-+-+-+-+-+ + // |0| Block PT | + // +-+-+-+-+-+-+-+-+ + payload[0] = opusPT + copy(payload[1:], OpusSilenceFrame) + trailerLen := d.maybeAddTrailer(payload[1+len(OpusSilenceFrame):]) + return payload[:1+len(OpusSilenceFrame)+trailerLen], nil +} + +func (d *DownTrack) getVP8BlankFrame(frameEndNeeded bool) ([]byte, error) { + // 8x8 key frame + // Used even when closing out a previous frame. Looks like receivers + // do not care about content (it will probably end up being an undecodable + // frame, but that should be okay as there are key frames following) + header, err := d.forwarder.GetPadding(frameEndNeeded) + if err != nil { + return nil, err + } + + payload := make([]byte, 1000) + copy(payload, header) + copy(payload[len(header):], VP8KeyFrame8x8) + trailerLen := d.maybeAddTrailer(payload[len(header)+len(VP8KeyFrame8x8):]) + return payload[:len(header)+len(VP8KeyFrame8x8)+trailerLen], nil +} + +func (d *DownTrack) getH264BlankFrame(_frameEndNeeded bool) ([]byte, error) { + // TODO - Jie Zeng + // now use STAP-A to compose sps, pps, idr together, most decoder support packetization-mode 1. + // if client only support packetization-mode 0, use single nalu unit packet + buf := make([]byte, 1000) + offset := 0 + buf[0] = 0x18 // STAP-A + offset++ + for _, payload := range H264KeyFrame2x2 { + binary.BigEndian.PutUint16(buf[offset:], uint16(len(payload))) + offset += 2 + copy(buf[offset:offset+len(payload)], payload) + offset += len(payload) + } + offset += d.maybeAddTrailer(buf[offset:]) + return buf[:offset], nil +} + +func (d *DownTrack) handleRTCP(bytes []byte) { + pkts, err := rtcp.Unmarshal(bytes) + if err != nil { + d.params.Logger.Errorw("could not unmarshal rtcp receiver packet", err) + return + } + + pliOnce := true + sendPliOnce := func() { + _, layer := d.forwarder.CheckSync() + if pliOnce { + if layer != buffer.InvalidLayerSpatial { + d.params.Logger.Debugw("sending PLI RTCP", "layer", layer) + d.Receiver().SendPLI(layer, false) + d.isNACKThrottled.Store(true) + d.rtpStats.UpdatePliTime() + pliOnce = false + } + } + } + + rttToReport := uint32(0) + + var numNACKs uint32 + var numPLIs uint32 + var numFIRs uint32 + for _, pkt := range pkts { + switch p := pkt.(type) { + case *rtcp.PictureLossIndication: + if p.MediaSSRC == d.ssrc { + numPLIs++ + sendPliOnce() + } + + case *rtcp.FullIntraRequest: + if p.MediaSSRC == d.ssrc { + numFIRs++ + sendPliOnce() + } + + case *rtcp.ReceiverEstimatedMaximumBitrate: + if sal := d.getStreamAllocatorListener(); sal != nil { + sal.OnREMB(d, p) + } + + case *rtcp.ReceiverReport: + // create new receiver report w/ only valid reception reports + rr := &rtcp.ReceiverReport{ + SSRC: p.SSRC, + ProfileExtensions: p.ProfileExtensions, + } + for _, r := range p.Reports { + if r.SSRC != d.ssrc { + continue + } + + rtt, isRttChanged := d.rtpStats.UpdateFromReceiverReport(r) + if isRttChanged { + rttToReport = rtt + } + + if d.playoutDelay != nil { + d.playoutDelay.OnSeqAcked(uint16(r.LastSequenceNumber)) + // screen share track has inaccuracy jitter due to its low frame rate and bursty traffic + if d.params.Source != livekit.TrackSource_SCREEN_SHARE { + jitterMs := uint64(r.Jitter*1e3) / uint64(d.clockRate) + d.playoutDelay.SetJitter(uint32(jitterMs)) + } + } + } + // RTX-TODO: This is used for media loss proxying only as of 2024-12-15. + // Ideally, this should keep deltas between previous RTCP Receiver Report + // and current report, calculate the loss in the window and reconcile it with + // data in a similar window from RTX stream (to ensure losses are discounted + // for NACKs), but keeping this simple for several reasons + // - media loss proxying is a configurable setting and could be disabled + // - media loss proxying is used for audio only and audio may not have NACKing + // - to keep it simple + if len(rr.Reports) > 0 { + d.listenerLock.RLock() + rrListeners := d.receiverReportListeners + d.listenerLock.RUnlock() + for _, l := range rrListeners { + l(d, rr) + } + } + + case *rtcp.TransportLayerNack: + if p.MediaSSRC == d.ssrc { + var nacks []uint16 + for _, pair := range p.Nacks { + packetList := pair.PacketList() + numNACKs += uint32(len(packetList)) + nacks = append(nacks, packetList...) + } + go d.retransmitPackets(nacks) + } + + case *rtcp.TransportLayerCC: + if p.MediaSSRC == d.ssrc { + if sal := d.getStreamAllocatorListener(); sal != nil { + sal.OnTransportCCFeedback(d, p) + } + } + + case *rtcp.ExtendedReport: + // SFU only responds with the DLRRReport for the track has the sender SSRC, the behavior is different with + // browser's implementation, which includes all sent tracks. It is ok since all the tracks + // use the same connection, and server-sdk-go can get the rtt from the first DLRRReport + // (libwebrtc/browsers don't send XR to calculate rtt, it only responds) + var lastRR uint32 + for _, report := range p.Reports { + if rr, ok := report.(*rtcp.ReceiverReferenceTimeReportBlock); ok { + lastRR = uint32(rr.NTPTimestamp >> 16) + break + } + } + + if lastRR > 0 { + d.params.RTCPWriter([]rtcp.Packet{&rtcp.ExtendedReport{ + SenderSSRC: d.ssrc, + Reports: []rtcp.ReportBlock{ + &rtcp.DLRRReportBlock{ + Reports: []rtcp.DLRRReport{{ + SSRC: p.SenderSSRC, + LastRR: lastRR, + DLRR: 0, // no delay + }}, + }, + }, + }}) + } + } + } + + d.rtpStats.UpdateNack(numNACKs) + d.rtpStats.UpdatePli(numPLIs) + d.rtpStats.UpdateFir(numFIRs) + + if rttToReport != 0 { + if d.sequencer != nil { + d.sequencer.setRTT(rttToReport) + } + + d.params.Listener.OnRttUpdate(rttToReport) + } +} + +func (d *DownTrack) handleRTCPRTX(bytes []byte) { + pkts, err := rtcp.Unmarshal(bytes) + if err != nil { + d.params.Logger.Errorw("could not unmarshal rtcp rtx receiver packet", err) + return + } + + for _, pkt := range pkts { + switch p := pkt.(type) { + case *rtcp.ReceiverReport: + for _, r := range p.Reports { + if r.SSRC != d.ssrcRTX { + continue + } + + d.rtpStatsRTX.UpdateFromReceiverReport(r) + } + + case *rtcp.ReceiverEstimatedMaximumBitrate: + if sal := d.getStreamAllocatorListener(); sal != nil { + sal.OnREMB(d, p) + } + + case *rtcp.TransportLayerCC: + if p.MediaSSRC == d.ssrcRTX { + if sal := d.getStreamAllocatorListener(); sal != nil { + sal.OnTransportCCFeedback(d, p) + } + } + } + } +} + +func (d *DownTrack) SetConnected() { + d.bindLock.Lock() + if !d.connected.Swap(true) { + d.onBindAndConnectedChange() + } + d.params.Logger.Debugw("downtrack connected") + d.bindLock.Unlock() +} + +// SetActivePaddingOnMuteUpTrack will enable padding on the track when its uptrack is muted. +// Pion will not fire OnTrack event until it receives packet for the track, +// so we send padding packets to help pion client (go-sdk) to fire the event. +func (d *DownTrack) SetActivePaddingOnMuteUpTrack() { + d.activePaddingOnMuteUpTrack.Store(true) +} + +func (d *DownTrack) retransmitPacket(epm *extPacketMeta, sourcePkt []byte, isProbe bool) (int, error) { + var pkt rtp.Packet + if err := pkt.Unmarshal(sourcePkt); err != nil { + d.params.Logger.Errorw("could not unmarshal rtp packet to send via RTX", err) + return 0, err + } + hdr := RTPHeaderFactory.Get().(*rtp.Header) + *hdr = rtp.Header{ + Version: pkt.Header.Version, + Padding: pkt.Header.Padding, + Marker: epm.marker, + PayloadType: d.getTranslatedPayloadType(pkt.Header.PayloadType), + SequenceNumber: epm.targetSeqNo, + Timestamp: epm.timestamp, + SSRC: d.ssrc, + } + rtxOffset := 0 + var rtxExtSequenceNumber uint64 + if rtxPT := d.payloadTypeRTX.Load(); rtxPT != 0 && d.ssrcRTX != 0 { + rtxExtSequenceNumber = d.rtxSequenceNumber.Inc() + rtxOffset = 2 + + hdr.PayloadType = uint8(rtxPT) + hdr.SequenceNumber = uint16(rtxExtSequenceNumber) + hdr.SSRC = d.ssrcRTX + } + + if d.dependencyDescriptorExtID != 0 { + var ddBytes []byte + if len(epm.ddBytesSlice) != 0 { + ddBytes = epm.ddBytesSlice + } else { + ddBytes = epm.ddBytes[:epm.ddBytesSize] + } + if len(ddBytes) != 0 { + hdr.SetExtension(uint8(d.dependencyDescriptorExtID), ddBytes) + } + } + if d.absCaptureTimeExtID != 0 && len(epm.actBytes) != 0 { + hdr.SetExtension(uint8(d.absCaptureTimeExtID), epm.actBytes) + } + d.addDummyExtensions(hdr) + + poolEntity := PacketFactory.Get().(*[]byte) + payload := *poolEntity + if rtxOffset != 0 { + // write OSN (Original Sequence Number) + binary.BigEndian.PutUint16(payload[0:2], epm.targetSeqNo) + } + if len(epm.codecBytesSlice) != 0 { + n := copy(payload[rtxOffset:], epm.codecBytesSlice) + m := copy(payload[rtxOffset+n:], pkt.Payload[epm.numCodecBytesIn:]) + payload = payload[:rtxOffset+n+m] + } else { + copy(payload[rtxOffset:], epm.codecBytes[:epm.numCodecBytesOut]) + copy(payload[rtxOffset+int(epm.numCodecBytesOut):], pkt.Payload[epm.numCodecBytesIn:]) + payload = payload[:rtxOffset+int(epm.numCodecBytesOut)+len(pkt.Payload)-int(epm.numCodecBytesIn)] + } + + headerSize := hdr.MarshalSize() + var ( + payloadSize, paddingSize int + isOutOfOrder bool + ) + if isProbe { + // although not padding only packets, marking it as padding for accounting as padding is used to signify probing, + // also not marking them as out-of-order although sequence numbers in packets are out-of-order because of re-sending packets + payloadSize, paddingSize, isOutOfOrder = 0, len(payload), false + } else { + payloadSize, paddingSize, isOutOfOrder = len(payload), 0, true + } + if hdr.SSRC == d.ssrcRTX { + d.rtpStatsRTX.Update( + mono.UnixNano(), + rtxExtSequenceNumber, + 0, + hdr.Marker, + headerSize, + payloadSize, + paddingSize, + isOutOfOrder, + ) + } else { + d.rtpStats.Update( + mono.UnixNano(), + epm.extSequenceNumber, + epm.extTimestamp, + hdr.Marker, + headerSize, + payloadSize, + paddingSize, + isOutOfOrder, + ) + } + pacerPacket := pacer.PacketFactory.Get().(*pacer.Packet) + *pacerPacket = pacer.Packet{ + Header: hdr, + HeaderPool: RTPHeaderFactory, + HeaderSize: headerSize, + Payload: payload, + ProbeClusterId: ccutils.ProbeClusterId(d.probeClusterId.Load()), + IsProbe: isProbe, + IsRTX: !isProbe, + AbsSendTimeExtID: uint8(d.absSendTimeExtID), + TransportWideExtID: uint8(d.transportWideExtID), + WriteStream: d.writeStream, + Pool: PacketFactory, + PoolEntity: poolEntity, + } + d.pacer.Enqueue(pacerPacket) + return headerSize + len(payload), nil +} + +func (d *DownTrack) retransmitPackets(nacks []uint16) { + if d.sequencer == nil { + return + } + + if FlagStopRTXOnPLI && d.isNACKThrottled.Load() { + return + } + + filtered, disallowedLayers := d.forwarder.FilterRTX(nacks) + if len(filtered) == 0 { + return + } + + src := PacketFactory.Get().(*[]byte) + defer PacketFactory.Put(src) + + receiver := d.Receiver() + + nackAcks := uint32(0) + nackMisses := uint32(0) + numRepeatedNACKs := uint32(0) + for _, epm := range d.sequencer.getExtPacketMetas(filtered) { + if disallowedLayers[epm.layer] { + continue + } + + nackAcks++ + + pktBuff := *src + n, err := receiver.ReadRTP(pktBuff, uint8(epm.layer), epm.sourceSeqNo) + if err != nil { + if err == io.EOF { + break + } + nackMisses++ + continue + } + + if epm.nacked > 1 { + numRepeatedNACKs++ + } + + d.retransmitPacket(&epm, pktBuff[:n], false) + } + + d.totalRepeatedNACKs.Add(numRepeatedNACKs) + + d.rtpStats.UpdateNackProcessed(nackAcks, nackMisses, numRepeatedNACKs) +} + +func (d *DownTrack) WriteProbePackets(bytesToSend int, usePadding bool) int { + rtxPT := uint8(d.payloadTypeRTX.Load()) + if rtxPT == 0 || d.ssrcRTX == 0 { + return d.WritePaddingRTP(bytesToSend, false, false) + } + + if !d.writable.Load() || + !d.rtpStats.IsActive() || + (d.absSendTimeExtID == 0 && d.transportWideExtID == 0) || + d.rtpStats.LastReceiverReportTime() == 0 || + d.sequencer == nil { + return 0 + } + + bytesSent := 0 + + if usePadding { + num := (bytesToSend + RTPPaddingMaxPayloadSize + RTPPaddingEstimatedHeaderSize - 1) / (RTPPaddingMaxPayloadSize + RTPPaddingEstimatedHeaderSize) + if num == 0 { + return 0 + } + + payloads := make([]byte, RTPPaddingMaxPayloadSize*num) + for i := range num { + rtxExtSequenceNumber := d.rtxSequenceNumber.Inc() + hdr := RTPHeaderFactory.Get().(*rtp.Header) + *hdr = rtp.Header{ + Version: 2, + Padding: true, + Marker: false, + PayloadType: rtxPT, + SequenceNumber: uint16(rtxExtSequenceNumber), + Timestamp: 0, + SSRC: d.ssrcRTX, + } + d.addDummyExtensions(hdr) + + payload := payloads[i*RTPPaddingMaxPayloadSize : (i+1)*RTPPaddingMaxPayloadSize : (i+1)*RTPPaddingMaxPayloadSize] + // last byte of padding has padding size including that byte + payload[RTPPaddingMaxPayloadSize-1] = byte(RTPPaddingMaxPayloadSize) + + hdrSize := hdr.MarshalSize() + payloadSize := len(payload) + d.rtpStatsRTX.Update( + mono.UnixNano(), + rtxExtSequenceNumber, + 0, + hdr.Marker, + hdrSize, + 0, + payloadSize, + false, + ) + pacerPacket := pacer.PacketFactory.Get().(*pacer.Packet) + *pacerPacket = pacer.Packet{ + Header: hdr, + HeaderPool: RTPHeaderFactory, + HeaderSize: hdrSize, + Payload: payload, + ProbeClusterId: ccutils.ProbeClusterId(d.probeClusterId.Load()), + IsProbe: true, + AbsSendTimeExtID: uint8(d.absSendTimeExtID), + TransportWideExtID: uint8(d.transportWideExtID), + WriteStream: d.writeStream, + } + d.pacer.Enqueue(pacerPacket) + + bytesSent += hdrSize + payloadSize + } + } else { + src := PacketFactory.Get().(*[]byte) + defer PacketFactory.Put(src) + + receiver := d.Receiver() + + endExtHighestSequenceNumber := d.rtpStats.ExtHighestSequenceNumber() + startExtHighestSequenceNumber := endExtHighestSequenceNumber - 5 + for esn := startExtHighestSequenceNumber; esn <= endExtHighestSequenceNumber; esn++ { + epm := d.sequencer.lookupExtPacketMeta(esn) + if epm == nil { + continue + } + + pktBuff := *src + n, err := receiver.ReadRTP(pktBuff, uint8(epm.layer), epm.sourceSeqNo) + if err != nil { + if err == io.EOF { + break + } + continue + } + + sent, _ := d.retransmitPacket(epm, pktBuff[:n], true) + bytesSent += sent + if bytesSent >= bytesToSend { + break + } + } + } + + return bytesSent +} + +func (d *DownTrack) addDummyExtensions(hdr *rtp.Header) { + // add dummy extensions (actual ones will be filed by pacer) to get header size + if d.absSendTimeExtID != 0 { + hdr.SetExtension(uint8(d.absSendTimeExtID), dummyAbsSendTimeExt) + } + if d.transportWideExtID != 0 { + hdr.SetExtension(uint8(d.transportWideExtID), dummyTransportCCExt) + } +} + +func (d *DownTrack) getTranslatedPayloadType(srcPT uint8) uint8 { + // send primary codec to subscriber if the publisher sent primary codec when red is negotiated, + // this will happen when the payload is too large to encode into red payload (exceeds mtu). + if d.isRED && srcPT == d.upstreamPrimaryPT && d.primaryPT != 0 { + return d.primaryPT + } + return uint8(d.payloadType.Load()) +} + +func (d *DownTrack) DebugInfo() map[string]any { + stats := map[string]any{ + "LastPli": d.rtpStats.LastPli(), + } + stats["RTPMunger"] = d.forwarder.RTPMungerDebugInfo() + + senderReport := d.CreateSenderReport() + if senderReport != nil { + stats["NTPTime"] = senderReport.NTPTime + stats["RTPTime"] = senderReport.RTPTime + stats["PacketCount"] = senderReport.PacketCount + } + + return map[string]any{ + "SubscriberID": d.params.SubID, + "TrackID": d.id, + "StreamID": d.params.StreamID, + "SSRC": d.ssrc, + "MimeType": d.Mime().String(), + "BindState": d.bindState.Load().(bindState), + "Muted": d.forwarder.IsMuted(), + "PubMuted": d.forwarder.IsPubMuted(), + "CurrentSpatialLayer": d.forwarder.CurrentLayer().Spatial, + "Stats": stats, + } +} + +func (d *DownTrack) GetConnectionScoreAndQuality() (float32, livekit.ConnectionQuality) { + return d.connectionStats.GetScoreAndQuality() +} + +func (d *DownTrack) GetTrackStats() *livekit.RTPStats { + return rtpstats.ReconcileRTPStatsWithRTX(d.rtpStats.ToProto(), d.rtpStatsRTX.ToProto()) +} + +func (d *DownTrack) deltaStats(ds *rtpstats.RTPDeltaInfo, dsrv *rtpstats.RTPDeltaInfo) map[uint32]*buffer.StreamStatsWithLayers { + if ds == nil && dsrv == nil { + return nil + } + + streamStats := make(map[uint32]*buffer.StreamStatsWithLayers, 1) + streamStats[d.ssrc] = &buffer.StreamStatsWithLayers{ + RTPStats: ds, + RTPStatsRemoteView: dsrv, + Layers: map[int32]*rtpstats.RTPDeltaInfo{ + 0: ds, + }, + } + + return streamStats +} + +func (d *DownTrack) GetDeltaStatsSender() map[uint32]*buffer.StreamStatsWithLayers { + ds, dsrv := d.rtpStats.DeltaInfoSender(d.deltaStatsSenderSnapshotId) + dsRTX, dsrvRTX := d.rtpStatsRTX.DeltaInfoSender(d.deltaStatsRTXSenderSnapshotId) + return d.deltaStats( + rtpstats.ReconcileRTPDeltaInfoWithRTX(ds, dsRTX), + rtpstats.ReconcileRTPDeltaInfoWithRTX(dsrv, dsrvRTX), + ) +} + +func (d *DownTrack) GetPrimaryStreamLastReceiverReportTime() time.Time { + return time.Unix(0, d.rtpStats.LastReceiverReportTime()) +} + +func (d *DownTrack) GetPrimaryStreamPacketsSent() uint64 { + return d.rtpStats.GetPacketsSeenMinusPadding() +} + +func (d *DownTrack) GetNackStats() (totalPackets uint32, totalRepeatedNACKs uint32) { + totalPackets = uint32(d.rtpStats.GetPacketsSeenMinusPadding()) + totalRepeatedNACKs = d.totalRepeatedNACKs.Load() + return +} + +func (d *DownTrack) onBindAndConnectedChange() { + if d.writeStopped.Load() { + return + } + d.writable.Store(d.connected.Load() && d.bindState.Load() == bindStateBound) + if d.connected.Load() && d.bindState.Load() == bindStateBound && !d.bindAndConnectedOnce.Swap(true) { + go d.params.Listener.OnBindAndConnected() + + if d.activePaddingOnMuteUpTrack.Load() { + go d.sendPaddingOnMute() + } + + // kick off PLI request if allocation is pending + d.postKeyFrameRequestEvent() + } +} + +func (d *DownTrack) sendPaddingOnMute() { + // let uptrack have chance to send packet before we send padding + time.Sleep(waitBeforeSendPaddingOnMute) + + if d.kind == webrtc.RTPCodecTypeVideo { + d.sendPaddingOnMuteForVideo() + } else if d.Mime() == mime.MimeTypeOpus { + d.sendSilentFrameOnMuteForOpus() + } +} + +func (d *DownTrack) sendPaddingOnMuteForVideo() { + numPackets := maxPaddingOnMuteDuration / paddingOnMuteInterval + for i := range int(numPackets) { + if d.rtpStats.IsActive() || d.IsClosed() { + return + } + if i == 0 { + d.params.Logger.Debugw("sending padding on mute") + } + d.WritePaddingRTP(20, true, true) + time.Sleep(paddingOnMuteInterval) + } +} + +func (d *DownTrack) sendSilentFrameOnMuteForOpus() { + frameRate := uint32(50) + frameDuration := time.Duration(1000/frameRate) * time.Millisecond + numFrames := frameRate * uint32(maxPaddingOnMuteDuration/time.Second) + first := true + for { + if d.rtpStats.IsActive() || d.IsClosed() || numFrames <= 0 { + return + } + if first { + first = false + d.params.Logger.Debugw("sending padding on mute") + } + snts, _, err := d.forwarder.GetSnTsForBlankFrames(frameRate, 1) + if err != nil { + d.params.Logger.Warnw("could not get SN/TS for blank frame", err) + return + } + for i := range len(snts) { + hdr := &rtp.Header{ + Version: 2, + Padding: false, + Marker: true, + PayloadType: uint8(d.payloadType.Load()), + SequenceNumber: uint16(snts[i].extSequenceNumber), + Timestamp: uint32(snts[i].extTimestamp), + SSRC: d.ssrc, + } + d.addDummyExtensions(hdr) + + payload, err := d.getOpusBlankFrame(false) + if err != nil { + d.params.Logger.Warnw("could not get blank frame", err) + return + } + + headerSize := hdr.MarshalSize() + d.rtpStats.Update( + mono.UnixNano(), + snts[i].extSequenceNumber, + snts[i].extTimestamp, + hdr.Marker, + headerSize, + 0, + len(payload), // although this is using empty frames, mark as padding as these are used to trigger Pion OnTrack only + false, + ) + pacerPacket := pacer.PacketFactory.Get().(*pacer.Packet) + *pacerPacket = pacer.Packet{ + Header: hdr, + HeaderSize: headerSize, + Payload: payload, + ProbeClusterId: ccutils.ProbeClusterId(d.probeClusterId.Load()), + AbsSendTimeExtID: uint8(d.absSendTimeExtID), + TransportWideExtID: uint8(d.transportWideExtID), + WriteStream: d.writeStream, + } + d.pacer.Enqueue(pacerPacket) + } + + numFrames-- + time.Sleep(frameDuration) + } +} + +func (d *DownTrack) HandleRTCPSenderReportData( + _payloadType webrtc.PayloadType, + layer int32, + publisherSRData *livekit.RTCPSenderReportState, +) error { + d.forwarder.SetRefSenderReport(layer, publisherSRData) + + currentLayer, isSingleStream, tsOffset, refSenderReport := d.forwarder.GetSenderReportParams() + if layer == currentLayer || (layer == 0 && isSingleStream) { + d.handleRTCPSenderReportData(refSenderReport, tsOffset) + } + return nil +} + +func (d *DownTrack) handleRTCPSenderReportData(publisherSRData *livekit.RTCPSenderReportState, tsOffset uint64) { + d.rtpStats.MaybeAdjustFirstPacketTime(publisherSRData, tsOffset) +} + +// ------------------------------------------------------------------------------- diff --git a/livekit/pkg/sfu/errors.go b/livekit/pkg/sfu/errors.go new file mode 100644 index 0000000..1074380 --- /dev/null +++ b/livekit/pkg/sfu/errors.go @@ -0,0 +1,15 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sfu diff --git a/livekit/pkg/sfu/forwarder.go b/livekit/pkg/sfu/forwarder.go new file mode 100644 index 0000000..d52bcaa --- /dev/null +++ b/livekit/pkg/sfu/forwarder.go @@ -0,0 +1,2399 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sfu + +import ( + "errors" + "fmt" + "math" + "math/rand" + "sync" + "time" + + "github.com/pion/rtp" + "github.com/pion/webrtc/v4" + "go.uber.org/zap/zapcore" + + "github.com/livekit/mediatransportutil" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/utils" + "github.com/livekit/protocol/utils/mono" + + "github.com/livekit/livekit-server/pkg/sfu/buffer" + "github.com/livekit/livekit-server/pkg/sfu/codecmunger" + "github.com/livekit/livekit-server/pkg/sfu/mime" + dd "github.com/livekit/livekit-server/pkg/sfu/rtpextension/dependencydescriptor" + "github.com/livekit/livekit-server/pkg/sfu/rtpstats" + sfuutils "github.com/livekit/livekit-server/pkg/sfu/utils" + "github.com/livekit/livekit-server/pkg/sfu/videolayerselector" + "github.com/livekit/livekit-server/pkg/sfu/videolayerselector/temporallayerselector" +) + +const ( + FlagPauseOnDowngrade = true + FlagFilterRTX = false + FlagFilterRTXLayers = true + TransitionCostSpatial = 10 + + ResumeBehindThresholdSeconds = float64(0.2) // 200ms + ResumeBehindHighThresholdSeconds = float64(2.0) // 2 seconds + LayerSwitchBehindThresholdSeconds = float64(0.05) // 50ms + SwitchAheadThresholdSeconds = float64(0.025) // 25ms +) + +var ( + errSkipStartOnOutOfOrderPacket = errors.New("skip start on out-of-order packet") + errSwitchPointTooFarBehind = errors.New("switch point too far behind") +) + +// ------------------------------------------------------------------- + +type VideoPauseReason int + +const ( + VideoPauseReasonNone VideoPauseReason = iota + VideoPauseReasonMuted + VideoPauseReasonPubMuted + VideoPauseReasonFeedDry + VideoPauseReasonBandwidth +) + +func (v VideoPauseReason) String() string { + switch v { + case VideoPauseReasonNone: + return "NONE" + case VideoPauseReasonMuted: + return "MUTED" + case VideoPauseReasonPubMuted: + return "PUB_MUTED" + case VideoPauseReasonFeedDry: + return "FEED_DRY" + case VideoPauseReasonBandwidth: + return "BANDWIDTH" + default: + return fmt.Sprintf("%d", int(v)) + } +} + +// ------------------------------------------------------------------- + +type VideoAllocation struct { + PauseReason VideoPauseReason + IsDeficient bool + BandwidthRequested int64 + BandwidthDelta int64 + BandwidthNeeded int64 + Bitrates Bitrates + TargetLayer buffer.VideoLayer + RequestLayerSpatial int32 + MaxLayer buffer.VideoLayer + DistanceToDesired float64 +} + +func (v *VideoAllocation) String() string { + return fmt.Sprintf("VideoAllocation{pause: %s, def: %+v, bwr: %d, del: %d, bwn: %d, rates: %+v, target: %s, req: %d, max: %s, dist: %0.2f}", + v.PauseReason, + v.IsDeficient, + v.BandwidthRequested, + v.BandwidthDelta, + v.BandwidthNeeded, + v.Bitrates, + v.TargetLayer, + v.RequestLayerSpatial, + v.MaxLayer, + v.DistanceToDesired, + ) +} + +func (v *VideoAllocation) MarshalLogObject(e zapcore.ObjectEncoder) error { + if v == nil { + return nil + } + + e.AddString("PauseReason", v.PauseReason.String()) + e.AddBool("IsDeficient", v.IsDeficient) + e.AddInt64("BandwidthRquested", v.BandwidthRequested) + e.AddInt64("BandwidthDelta", v.BandwidthDelta) + e.AddInt64("BandwidthNeeded", v.BandwidthNeeded) + e.AddReflected("Bitrates", v.Bitrates) + e.AddReflected("TargetLayer", v.TargetLayer) + e.AddInt32("RequestLayerSpatial", v.RequestLayerSpatial) + e.AddReflected("MaxLayer", v.MaxLayer) + e.AddFloat64("DistanceToDesired", v.DistanceToDesired) + return nil +} + +var ( + VideoAllocationDefault = VideoAllocation{ + PauseReason: VideoPauseReasonFeedDry, // start with no feed till feed is seen + TargetLayer: buffer.InvalidLayer, + RequestLayerSpatial: buffer.InvalidLayerSpatial, + MaxLayer: buffer.InvalidLayer, + } +) + +// ------------------------------------------------------------------- + +type VideoAllocationProvisional struct { + muted bool + pubMuted bool + maxSeenLayer buffer.VideoLayer + availableLayers []int32 + bitrates Bitrates + maxLayer buffer.VideoLayer + currentLayer buffer.VideoLayer + allocatedLayer buffer.VideoLayer +} + +// ------------------------------------------------------------------- + +type VideoTransition struct { + From buffer.VideoLayer + To buffer.VideoLayer + BandwidthDelta int64 +} + +func (v *VideoTransition) String() string { + return fmt.Sprintf("VideoTransition{from: %s, to: %s, del: %d}", v.From, v.To, v.BandwidthDelta) +} + +func (v *VideoTransition) MarshalLogObject(e zapcore.ObjectEncoder) error { + if v == nil { + return nil + } + + e.AddReflected("From", v.From) + e.AddReflected("To", v.To) + e.AddInt64("BandwidthDelta", v.BandwidthDelta) + return nil +} + +// ------------------------------------------------------------------- + +type TranslationParams struct { + shouldDrop bool + isResuming bool + isSwitching bool + rtp TranslationParamsRTP + ddBytes []byte + incomingHeaderSize int + codecBytes []byte + marker bool +} + +// ------------------------------------------------------------------- + +type refInfo struct { + senderReport *livekit.RTCPSenderReportState + tsOffset uint64 + isTSOffsetValid bool +} + +func (r refInfo) MarshalLogObject(e zapcore.ObjectEncoder) error { + e.AddObject("senderReport", rtpstats.WrappedRTCPSenderReportStateLogger{ + RTCPSenderReportState: r.senderReport, + }) + e.AddUint64("tsOffset", r.tsOffset) + e.AddBool("isTSOffsetValid", r.isTSOffsetValid) + return nil +} + +// ------------------------------------------------------------------- + +type Forwarder struct { + lock sync.RWMutex + mime mime.MimeType + clockRate uint32 + kind webrtc.RTPCodecType + logger logger.Logger + skipReferenceTS bool + disableOpportunisticAllocation bool + rtpStats *rtpstats.RTPStatsSender + + muted bool + pubMuted bool + resumeBehindThreshold float64 + + started bool + preStartTime time.Time + extFirstTS uint64 + lastSSRC uint32 + lastReferencePayloadType int8 + lastSwitchExtIncomingTS uint64 + referenceLayerSpatial int32 + dummyStartTSOffset uint64 + refInfos [buffer.DefaultMaxLayerSpatial + 1]refInfo + refVideoLayerMode livekit.VideoLayer_Mode + isDDAvailable bool + + provisional *VideoAllocationProvisional + + lastAllocation VideoAllocation + + rtpMunger *RTPMunger + + vls videolayerselector.VideoLayerSelector + + codecMunger codecmunger.CodecMunger +} + +func NewForwarder( + kind webrtc.RTPCodecType, + logger logger.Logger, + skipReferenceTS bool, + disableOpportunisticAllocation bool, + rtpStats *rtpstats.RTPStatsSender, +) *Forwarder { + f := &Forwarder{ + mime: mime.MimeTypeUnknown, + kind: kind, + logger: logger, + skipReferenceTS: skipReferenceTS, + disableOpportunisticAllocation: disableOpportunisticAllocation, + rtpStats: rtpStats, + referenceLayerSpatial: buffer.InvalidLayerSpatial, + lastAllocation: VideoAllocationDefault, + lastReferencePayloadType: -1, + rtpMunger: NewRTPMunger(logger), + vls: videolayerselector.NewNull(logger), + codecMunger: codecmunger.NewNull(logger), + } + + if f.kind == webrtc.RTPCodecTypeVideo { + f.vls.SetMaxTemporal(buffer.DefaultMaxLayerTemporal) + } + return f +} + +func (f *Forwarder) SetMaxPublishedLayer(maxPublishedLayer int32) bool { + f.lock.Lock() + defer f.lock.Unlock() + + existingMaxSeen := f.vls.GetMaxSeen() + if maxPublishedLayer <= existingMaxSeen.Spatial { + return false + } + + f.vls.SetMaxSeenSpatial(maxPublishedLayer) + f.logger.Debugw("setting max published layer", "layer", maxPublishedLayer) + return true +} + +func (f *Forwarder) SetMaxTemporalLayerSeen(maxTemporalLayerSeen int32) bool { + f.lock.Lock() + defer f.lock.Unlock() + + existingMaxSeen := f.vls.GetMaxSeen() + if maxTemporalLayerSeen <= existingMaxSeen.Temporal { + return false + } + + f.vls.SetMaxSeenTemporal(maxTemporalLayerSeen) + f.logger.Debugw("setting max temporal layer seen", "maxTemporalLayerSeen", maxTemporalLayerSeen) + return true +} + +func (f *Forwarder) DetermineCodec(codec webrtc.RTPCodecCapability, extensions []webrtc.RTPHeaderExtensionParameter, videoLayerMode livekit.VideoLayer_Mode) { + f.lock.Lock() + defer f.lock.Unlock() + + if videoLayerMode == livekit.VideoLayer_ONE_SPATIAL_LAYER_PER_STREAM_INCOMPLETE_RTCP_SR { + f.skipReferenceTS = true + } + + toMimeType := mime.NormalizeMimeType(codec.MimeType) + codecChanged := f.mime != mime.MimeTypeUnknown && f.mime != toMimeType + if codecChanged { + f.logger.Debugw("forwarder codec changed", "from", f.mime, "to", toMimeType) + } + f.mime = toMimeType + f.clockRate = codec.ClockRate + f.refVideoLayerMode = videoLayerMode + + ddAvailable := func(exts []webrtc.RTPHeaderExtensionParameter) bool { + for _, ext := range exts { + if ext.URI == dd.ExtensionURI { + return true + } + } + return false + } + + switch f.mime { + case mime.MimeTypeVP8: + f.codecMunger = codecmunger.NewVP8FromOther(f.codecMunger, f.logger) + if f.vls != nil { + if vls := videolayerselector.NewSimulcastFromOther(f.vls); vls != nil { + f.vls = vls + } else { + f.logger.Errorw("failed to create simulcast on codec change", nil) + } + } else { + f.vls = videolayerselector.NewSimulcast(f.logger) + } + f.vls.SetTemporalLayerSelector(temporallayerselector.NewVP8(f.logger)) + + case mime.MimeTypeH264, mime.MimeTypeH265: + f.codecMunger = codecmunger.NewNull(f.logger) + if f.vls != nil { + if vls := videolayerselector.NewSimulcastFromOther(f.vls); vls != nil { + f.vls = vls + } else { + f.logger.Errorw("failed to create simulcast on codec change", nil) + } + } else { + f.vls = videolayerselector.NewSimulcast(f.logger) + } + + case mime.MimeTypeVP9: + f.codecMunger = codecmunger.NewNull(f.logger) + if sfuutils.IsSimulcastMode(videoLayerMode) { + if f.vls != nil { + f.vls = videolayerselector.NewSimulcastFromOther(f.vls) + } else { + f.vls = videolayerselector.NewDependencyDescriptor(f.logger) + } + } else { + f.isDDAvailable = ddAvailable(extensions) + if f.isDDAvailable { + if f.vls != nil { + f.vls = videolayerselector.NewDependencyDescriptorFromOther(f.vls) + } else { + f.vls = videolayerselector.NewDependencyDescriptor(f.logger) + } + } else { + if f.vls != nil { + f.vls = videolayerselector.NewVP9FromOther(f.vls) + } else { + f.vls = videolayerselector.NewVP9(f.logger) + } + } + } + + case mime.MimeTypeAV1: + f.codecMunger = codecmunger.NewNull(f.logger) + if sfuutils.IsSimulcastMode(videoLayerMode) { + if f.vls != nil { + f.vls = videolayerselector.NewSimulcastFromOther(f.vls) + } else { + f.vls = videolayerselector.NewSimulcast(f.logger) + } + } else { + f.isDDAvailable = ddAvailable(extensions) + if f.isDDAvailable { + if f.vls != nil { + f.vls = videolayerselector.NewDependencyDescriptorFromOther(f.vls) + } else { + f.vls = videolayerselector.NewDependencyDescriptor(f.logger) + } + } else { + if f.vls != nil { + f.vls = videolayerselector.NewSimulcastFromOther(f.vls) + } else { + f.vls = videolayerselector.NewSimulcast(f.logger) + } + } + } + } +} + +func (f *Forwarder) GetState() *livekit.RTPForwarderState { + f.lock.RLock() + defer f.lock.RUnlock() + + if !f.started { + return nil + } + + state := &livekit.RTPForwarderState{ + Started: f.started, + ReferenceLayerSpatial: f.referenceLayerSpatial, + ExtFirstTimestamp: f.extFirstTS, + DummyStartTimestampOffset: f.dummyStartTSOffset, + RtpMunger: f.rtpMunger.GetState(), + } + if !f.preStartTime.IsZero() { + state.PreStartTime = f.preStartTime.UnixNano() + } + + codecMungerState := f.codecMunger.GetState() + if vp8MungerState, ok := codecMungerState.(*livekit.VP8MungerState); ok { + state.CodecMunger = &livekit.RTPForwarderState_Vp8Munger{ + Vp8Munger: vp8MungerState, + } + } + + state.SenderReportState = make([]*livekit.RTCPSenderReportState, len(f.refInfos)) + for layer, refInfo := range f.refInfos { + state.SenderReportState[layer] = utils.CloneProto(refInfo.senderReport) + } + return state +} + +func (f *Forwarder) SeedState(state *livekit.RTPForwarderState) { + if state == nil || !state.Started { + return + } + + f.lock.Lock() + defer f.lock.Unlock() + + for layer, rtcpSenderReportState := range state.SenderReportState { + f.refInfos[layer] = refInfo{} + if senderReport := utils.CloneProto(rtcpSenderReportState); senderReport != nil && senderReport.NtpTimestamp != 0 { + f.refInfos[layer].senderReport = senderReport + } + } + + f.rtpMunger.SeedState(state.RtpMunger) + f.codecMunger.SeedState(state.CodecMunger) + + f.started = true + f.referenceLayerSpatial = state.ReferenceLayerSpatial + if state.PreStartTime != 0 { + f.preStartTime = time.Unix(0, state.PreStartTime) + } + f.extFirstTS = state.ExtFirstTimestamp + f.dummyStartTSOffset = state.DummyStartTimestampOffset +} + +func (f *Forwarder) Mute(muted bool, isSubscribeMutable bool) bool { + f.lock.Lock() + defer f.lock.Unlock() + + if f.muted == muted { + return false + } + + // Do not mute when paused due to bandwidth limitation. + // There are two issues + // 1. Muting means probing cannot happen on this track. + // 2. Muting also triggers notification to publisher about layers this forwarder needs. + // If this forwarder does not need any layer, publisher could turn off all layers. + // So, muting could lead to not being able to restart the track. + // To avoid that, ignore mute when paused due to bandwidth limitations. + // + // NOTE: The above scenario refers to mute getting triggered due + // to video stream visibility changes. When a stream is paused, it is possible + // that the receiver hides the video tile triggering subscription mute. + // The work around here to ignore mute does ignore an intentional mute. + // It could result in some bandwidth consumed for stream without visibility in + // the case of intentional mute. + if muted && !isSubscribeMutable { + f.logger.Infow( + "ignoring forwarder mute, paused due to congestion", + "targetLayers", f.vls.GetTarget(), + "currentLayers", f.vls.GetCurrent(), + "lastAllocation", f.lastAllocation, + ) + return false + } + + f.logger.Debugw("setting forwarder mute", "muted", muted) + f.muted = muted + + // resync when muted so that sequence numbers do not jump on unmute + if muted { + f.resyncLocked() + } + + return true +} + +func (f *Forwarder) IsMuted() bool { + f.lock.RLock() + defer f.lock.RUnlock() + + return f.muted +} + +func (f *Forwarder) PubMute(pubMuted bool) bool { + f.lock.Lock() + defer f.lock.Unlock() + + if f.pubMuted == pubMuted { + return false + } + + f.logger.Debugw("setting forwarder pub mute", "muted", pubMuted) + f.pubMuted = pubMuted + + // resync when pub muted so that sequence numbers do not jump on unmute + if pubMuted { + f.resyncLocked() + } + return true +} + +func (f *Forwarder) IsPubMuted() bool { + f.lock.RLock() + defer f.lock.RUnlock() + + return f.pubMuted +} + +func (f *Forwarder) IsAnyMuted() bool { + f.lock.RLock() + defer f.lock.RUnlock() + + return f.muted || f.pubMuted +} + +func (f *Forwarder) SetMaxSpatialLayer(spatialLayer int32) (bool, buffer.VideoLayer) { + f.lock.Lock() + defer f.lock.Unlock() + + if f.kind == webrtc.RTPCodecTypeAudio { + return false, buffer.InvalidLayer + } + + existingMax := f.vls.GetMax() + if spatialLayer == existingMax.Spatial { + return false, existingMax + } + + f.logger.Debugw("setting max spatial layer", "layer", spatialLayer) + f.vls.SetMaxSpatial(spatialLayer) + if f.disableOpportunisticAllocation { + return true, f.vls.GetMax() + } + + if f.vls.GetTarget().Spatial != buffer.InvalidLayerSpatial || + f.isDeficientLocked() || + f.lastAllocation.PauseReason == VideoPauseReasonMuted || + f.lastAllocation.PauseReason == VideoPauseReasonPubMuted { + return true, f.vls.GetMax() + } + + f.logger.Debugw("opportunistically setting target spatial layer", "layer", spatialLayer) + + alloc := f.lastAllocation + + // bitrates are not known + alloc.BandwidthRequested = 0 + alloc.BandwidthDelta = 0 + alloc.Bitrates = Bitrates{} + + alloc.TargetLayer = f.vls.GetMax() + alloc.RequestLayerSpatial = f.vls.GetMax().Spatial + alloc.MaxLayer = f.vls.GetMax() + + alloc.DistanceToDesired = getDistanceToDesired( + f.muted, + f.pubMuted, + f.vls.GetMaxSeen(), + nil, + alloc.Bitrates, + alloc.TargetLayer, + f.vls.GetMax(), + ) + + f.updateAllocation(alloc, "opportunistic") + return true, f.vls.GetMax() +} + +func (f *Forwarder) SetMaxTemporalLayer(temporalLayer int32) (bool, buffer.VideoLayer) { + f.lock.Lock() + defer f.lock.Unlock() + + if f.kind == webrtc.RTPCodecTypeAudio { + return false, buffer.InvalidLayer + } + + existingMax := f.vls.GetMax() + if temporalLayer == existingMax.Temporal { + return false, existingMax + } + + f.logger.Debugw("setting max temporal layer", "layer", temporalLayer) + f.vls.SetMaxTemporal(temporalLayer) + return true, f.vls.GetMax() +} + +func (f *Forwarder) MaxLayer() buffer.VideoLayer { + f.lock.RLock() + defer f.lock.RUnlock() + + return f.vls.GetMax() +} + +func (f *Forwarder) CurrentLayer() buffer.VideoLayer { + f.lock.RLock() + defer f.lock.RUnlock() + + return f.vls.GetCurrent() +} + +func (f *Forwarder) TargetLayer() buffer.VideoLayer { + f.lock.RLock() + defer f.lock.RUnlock() + + return f.vls.GetTarget() +} + +func (f *Forwarder) GetMaxSubscribedSpatial() int32 { + f.lock.RLock() + defer f.lock.RUnlock() + + layer := buffer.InvalidLayerSpatial // covers muted case + if !f.muted { + // If current is higher, mark the current layer as max subscribed layer + // to prevent the current layer from stopping before forwarder switches + // to the new and lower max layer, + layer = max(f.vls.GetMax().Spatial, f.vls.GetCurrent().Spatial) + + // if reference layer is higher, hold there until an RTCP Sender Report from + // publisher is available as that is used for reference time stamp between layers. + if f.referenceLayerSpatial != buffer.InvalidLayerSpatial && + layer < f.referenceLayerSpatial && + f.refInfos[f.referenceLayerSpatial].senderReport == nil { + layer = f.referenceLayerSpatial + } + } + + return layer +} + +func (f *Forwarder) getRefLayer() (int32, int32) { + if f.lastSSRC == 0 { + return buffer.InvalidLayerSpatial, buffer.InvalidLayerSpatial + } + + if f.kind == webrtc.RTPCodecTypeAudio { + return 0, 0 + } + + currentLayerSpatial := f.vls.GetCurrent().Spatial + if currentLayerSpatial < 0 || currentLayerSpatial > buffer.DefaultMaxLayerSpatial { + return buffer.InvalidLayerSpatial, buffer.InvalidLayerSpatial + } + + if f.refVideoLayerMode == livekit.VideoLayer_MULTIPLE_SPATIAL_LAYERS_PER_STREAM { + return 0, currentLayerSpatial + } + + return currentLayerSpatial, currentLayerSpatial +} + +func (f *Forwarder) SetRefSenderReport(layer int32, srData *livekit.RTCPSenderReportState) { + f.lock.Lock() + defer f.lock.Unlock() + + if layer >= 0 && int(layer) < len(f.refInfos) { + if layer == f.referenceLayerSpatial && f.refInfos[layer].senderReport == nil { + f.logger.Debugw( + "received RTCP sender report for reference layer spatial", + "layer", layer, + "srData", rtpstats.WrappedRTCPSenderReportStateLogger{RTCPSenderReportState: srData}, + ) + } + f.refInfos[layer] = refInfo{srData, 0, false} + + // Mark validity of time stamp offset. + // + // It is possible to implement mute using pause/unpause + // which can be implemented using replaceTrack(null)/replaceTrack(track). + // In those cases, the RTP time stamp may not jump across + // the mute/pause valley (for the time it is replaced with null track). + // So, relying on a report that happened before unmute/unpause + // could result in incorrect RTCP sender report on subscriber side. + // + // It could happen like this + // 1. Normal operation: publisher sending sender reports and + // subscribers use reports from publisher to calculate and send + // RTCP sender report. + // 2. Publisher pauses: there are no more reports. + // 3. When paused, subscriber can still use the publisher side sender + // report to send reports. Although the time since last publisher + // sender report is increasing, the reports would still be correct + // as they referencing a previous (albeit older) correct report. + // 4. Publisher unpauses after 20 seconds. But, it may not have advanced + // RTP Timestamp by that much. Let us say, it advances only by 5 seconds. + // 5. When subscriber starts forwarding packets, it will calculate + // a new time stamp offset to adjust to the new time stamp of publisher. + // 6. But, when that same offset is used on an old publisher sender report + // (i. e. a report from before the pause), the subscriber side sender + // reports jumps ahead in time by 15 seconds. + // + // So, mark valid for reports after last switch. + refLayer, _ := f.getRefLayer() + if layer == refLayer && srData.RtpTimestampExt >= f.lastSwitchExtIncomingTS { + f.refInfos[layer].tsOffset = f.rtpMunger.GetTSOffset() + f.refInfos[layer].isTSOffsetValid = true + } + } +} + +func (f *Forwarder) GetSenderReportParams() (int32, bool, uint64, *livekit.RTCPSenderReportState) { + f.lock.RLock() + defer f.lock.RUnlock() + + refLayer, currentLayerSpatial := f.getRefLayer() + if refLayer == buffer.InvalidLayerSpatial || + f.refInfos[refLayer].senderReport == nil || + !f.refInfos[refLayer].isTSOffsetValid { + return buffer.InvalidLayerSpatial, false, 0, nil + } + + return currentLayerSpatial, f.refVideoLayerMode == livekit.VideoLayer_MULTIPLE_SPATIAL_LAYERS_PER_STREAM, f.refInfos[refLayer].tsOffset, f.refInfos[refLayer].senderReport +} + +func (f *Forwarder) isDeficientLocked() bool { + return f.lastAllocation.IsDeficient +} + +func (f *Forwarder) IsDeficient() bool { + f.lock.RLock() + defer f.lock.RUnlock() + + return f.isDeficientLocked() +} + +func (f *Forwarder) PauseReason() VideoPauseReason { + f.lock.RLock() + defer f.lock.RUnlock() + + return f.lastAllocation.PauseReason +} + +func (f *Forwarder) BandwidthRequested(brs Bitrates) int64 { + f.lock.RLock() + defer f.lock.RUnlock() + + return getBandwidthNeeded(brs, f.vls.GetTarget(), f.lastAllocation.BandwidthRequested) +} + +func (f *Forwarder) DistanceToDesired(availableLayers []int32, brs Bitrates) float64 { + f.lock.RLock() + defer f.lock.RUnlock() + + return getDistanceToDesired( + f.muted, + f.pubMuted, + f.vls.GetMaxSeen(), + availableLayers, + brs, + f.vls.GetTarget(), + f.vls.GetMax(), + ) +} + +func (f *Forwarder) GetOptimalBandwidthNeeded(brs Bitrates) int64 { + f.lock.RLock() + defer f.lock.RUnlock() + + return getOptimalBandwidthNeeded(f.muted, f.pubMuted, f.vls.GetMaxSeen().Spatial, brs, f.vls.GetMax()) +} + +func (f *Forwarder) AllocateOptimal(availableLayers []int32, brs Bitrates, allowOvershoot bool, hold bool) VideoAllocation { + f.lock.Lock() + defer f.lock.Unlock() + + if f.kind == webrtc.RTPCodecTypeAudio { + return f.lastAllocation + } + + maxLayer := f.vls.GetMax() + maxSeenLayer := f.vls.GetMaxSeen() + currentLayer := f.vls.GetCurrent() + requestSpatial := f.vls.GetRequestSpatial() + alloc := VideoAllocation{ + PauseReason: VideoPauseReasonNone, + Bitrates: brs, + TargetLayer: buffer.InvalidLayer, + RequestLayerSpatial: requestSpatial, + MaxLayer: maxLayer, + } + optimalBandwidthNeeded := getOptimalBandwidthNeeded(f.muted, f.pubMuted, maxSeenLayer.Spatial, brs, maxLayer) + if optimalBandwidthNeeded == 0 { + alloc.PauseReason = VideoPauseReasonFeedDry + } + alloc.BandwidthNeeded = optimalBandwidthNeeded + + getMaxTemporal := func() int32 { + maxTemporal := maxLayer.Temporal + if maxSeenLayer.Temporal != buffer.InvalidLayerTemporal && maxSeenLayer.Temporal < maxTemporal { + maxTemporal = maxSeenLayer.Temporal + } + return maxTemporal + } + + opportunisticAlloc := func() { + // opportunistically latch on to anything + maxSpatial := maxLayer.Spatial + if allowOvershoot && f.vls.IsOvershootOkay() && maxSeenLayer.Spatial > maxSpatial { + maxSpatial = maxSeenLayer.Spatial + } + + alloc.TargetLayer = buffer.VideoLayer{ + Spatial: min(maxSeenLayer.Spatial, maxSpatial), + Temporal: getMaxTemporal(), + } + } + + switch { + case !maxLayer.IsValid() || maxSeenLayer.Spatial == buffer.InvalidLayerSpatial: + // nothing to do when max layers are not valid OR max published layer is invalid + + case f.muted: + alloc.PauseReason = VideoPauseReasonMuted + + case f.pubMuted: + alloc.PauseReason = VideoPauseReasonPubMuted + + default: + // lots of different events could end up here + // 1. Publisher side layer resuming/stopping + // 2. Bitrate becoming available + // 3. New max published spatial layer or max temporal layer seen + // 4. Subscriber layer changes + // + // to handle all of the above + // 1. Find highest that can be requested - takes into account available layers and overshoot. + // This should catch scenarios like layers resuming/stopping. + // 2. If current is a valid layer, check against currently available layers and continue at current + // if possible. Else, choose the highest available layer as the next target. + // 3. If current is not valid, set next target to be opportunistic. + maxLayerSpatialLimit := min(maxLayer.Spatial, maxSeenLayer.Spatial) + highestAvailableLayer := buffer.InvalidLayerSpatial + lowestAvailableLayer := buffer.InvalidLayerSpatial + requestLayerSpatial := buffer.InvalidLayerSpatial + for _, al := range availableLayers { + if al > requestLayerSpatial && al <= maxLayerSpatialLimit { + requestLayerSpatial = al + } + if al > highestAvailableLayer { + highestAvailableLayer = al + } + if lowestAvailableLayer == buffer.InvalidLayerSpatial || al < lowestAvailableLayer { + lowestAvailableLayer = al + } + } + if requestLayerSpatial == buffer.InvalidLayerSpatial && highestAvailableLayer != buffer.InvalidLayerSpatial && allowOvershoot && f.vls.IsOvershootOkay() { + requestLayerSpatial = highestAvailableLayer + } + + if currentLayer.IsValid() { + if (requestLayerSpatial == requestSpatial && currentLayer.Spatial == requestSpatial) || requestLayerSpatial == buffer.InvalidLayerSpatial { + // 1. current is locked to desired, stay there + // OR + // 2. feed may be dry, let it continue at current layer if valid. + // covers the cases of + // 1. mis-detection of layer stop - can continue streaming + // 2. current layer resuming - can latch on when it starts + alloc.TargetLayer = buffer.VideoLayer{ + Spatial: currentLayer.Spatial, + Temporal: getMaxTemporal(), + } + } else { + // current layer has stopped, switch to lowest available if `hold`ing, else switch to highest available + if hold { + // if `hold` is requested, may be set due to early warning congestion + // signal, in that case layers are not increased as increasing layers + // will result in more load on the channel + alloc.TargetLayer = buffer.VideoLayer{ + Spatial: lowestAvailableLayer, + Temporal: 0, + } + } else { + alloc.TargetLayer = buffer.VideoLayer{ + Spatial: requestLayerSpatial, + Temporal: getMaxTemporal(), + } + } + } + alloc.RequestLayerSpatial = alloc.TargetLayer.Spatial + } else { + if hold { + // allocate minimal to make the stream active while `hold`ing. + if lowestAvailableLayer == buffer.InvalidLayerSpatial { + alloc.TargetLayer = buffer.VideoLayer{ + Spatial: 0, + Temporal: 0, + } + } else { + alloc.TargetLayer = buffer.VideoLayer{ + Spatial: lowestAvailableLayer, + Temporal: 0, + } + } + alloc.RequestLayerSpatial = alloc.TargetLayer.Spatial + } else { + // opportunistically latch on to anything + opportunisticAlloc() + if requestLayerSpatial == buffer.InvalidLayerSpatial { + alloc.RequestLayerSpatial = maxLayerSpatialLimit + } else { + alloc.RequestLayerSpatial = requestLayerSpatial + } + } + } + } + + if !alloc.TargetLayer.IsValid() { + alloc.TargetLayer = buffer.InvalidLayer + alloc.RequestLayerSpatial = buffer.InvalidLayerSpatial + } + if alloc.TargetLayer.IsValid() { + alloc.BandwidthRequested = getOptimalBandwidthNeeded(f.muted, f.pubMuted, maxSeenLayer.Spatial, brs, alloc.TargetLayer) + } + alloc.BandwidthDelta = alloc.BandwidthRequested - getBandwidthNeeded(brs, f.vls.GetTarget(), f.lastAllocation.BandwidthRequested) + alloc.DistanceToDesired = getDistanceToDesired( + f.muted, + f.pubMuted, + f.vls.GetMaxSeen(), + availableLayers, + brs, + alloc.TargetLayer, + f.vls.GetMax(), + ) + + return f.updateAllocation(alloc, "optimal") +} + +func (f *Forwarder) ProvisionalAllocatePrepare(availableLayers []int32, bitrates Bitrates) { + f.lock.Lock() + defer f.lock.Unlock() + + f.provisional = &VideoAllocationProvisional{ + allocatedLayer: buffer.InvalidLayer, + muted: f.muted, + pubMuted: f.pubMuted, + maxSeenLayer: f.vls.GetMaxSeen(), + bitrates: bitrates, + maxLayer: f.vls.GetMax(), + currentLayer: f.vls.GetCurrent(), + } + + f.provisional.availableLayers = make([]int32, len(availableLayers)) + copy(f.provisional.availableLayers, availableLayers) +} + +func (f *Forwarder) ProvisionalAllocateReset() { + f.lock.Lock() + defer f.lock.Unlock() + + f.provisional.allocatedLayer = buffer.InvalidLayer +} + +func (f *Forwarder) ProvisionalAllocate(availableChannelCapacity int64, layer buffer.VideoLayer, allowPause bool, allowOvershoot bool) (bool, int64) { + f.lock.Lock() + defer f.lock.Unlock() + + if f.provisional.muted || + f.provisional.pubMuted || + f.provisional.maxSeenLayer.Spatial == buffer.InvalidLayerSpatial || + !f.provisional.maxLayer.IsValid() || + ((!allowOvershoot || !f.vls.IsOvershootOkay()) && layer.GreaterThan(f.provisional.maxLayer)) { + return false, 0 + } + + requiredBitrate := f.provisional.bitrates[layer.Spatial][layer.Temporal] + if requiredBitrate == 0 { + return false, 0 + } + + alreadyAllocatedBitrate := int64(0) + if f.provisional.allocatedLayer.IsValid() { + alreadyAllocatedBitrate = f.provisional.bitrates[f.provisional.allocatedLayer.Spatial][f.provisional.allocatedLayer.Temporal] + } + + // a layer under maximum fits, take it + if !layer.GreaterThan(f.provisional.maxLayer) && requiredBitrate <= (availableChannelCapacity+alreadyAllocatedBitrate) { + f.provisional.allocatedLayer = layer + return true, requiredBitrate - alreadyAllocatedBitrate + } + + // + // Given layer does not fit. + // + // Could be one of + // 1. a layer below maximum that does not fit + // 2. a layer above maximum which may or may not fit, but overshoot is allowed. + // In any of those cases, take the lowest possible layer if pause is not allowed + // + if !allowPause && (!f.provisional.allocatedLayer.IsValid() || !layer.GreaterThan(f.provisional.allocatedLayer)) { + f.provisional.allocatedLayer = layer + return true, requiredBitrate - alreadyAllocatedBitrate + } + + return false, 0 +} + +func (f *Forwarder) ProvisionalAllocateGetCooperativeTransition(allowOvershoot bool) (VideoTransition, []int32, Bitrates) { + // + // This is called when a track needs a change (could be mute/unmute, subscribed layers changed, published layers changed) + // when channel is congested. + // + // The goal is to provide a co-operative transition. Co-operative stream allocation aims to keep all the streams active + // as much as possible. + // + // When channel is congested, effecting a transition which will consume more bits will lead to more congestion. + // So, this routine does the following + // 1. When muting, it is not going to increase consumption. + // 2. If the stream is currently active and the transition needs more bits (higher layers = more bits), do not make the up move. + // The higher layer requirement could be due to a new published layer becoming available or subscribed layers changing. + // 3. If the new target layers are lower than current target, take the move down and save bits. + // 4. If not currently streaming, find the minimum layers that can unpause the stream. + // + // To summarize, co-operative streaming means + // - Try to keep tracks streaming, i.e. no pauses at the expense of some streams not being at optimal layers + // - Do not make an upgrade as it could affect other tracks + // + f.lock.Lock() + defer f.lock.Unlock() + + existingTargetLayer := f.vls.GetTarget() + if f.provisional.muted || f.provisional.pubMuted { + f.provisional.allocatedLayer = buffer.InvalidLayer + return VideoTransition{ + From: existingTargetLayer, + To: f.provisional.allocatedLayer, + BandwidthDelta: -getBandwidthNeeded(f.provisional.bitrates, existingTargetLayer, f.lastAllocation.BandwidthRequested), + }, f.provisional.availableLayers, f.provisional.bitrates + } + + // check if we should preserve current target + if existingTargetLayer.IsValid() { + // what is the highest that is available + maximalLayer := buffer.InvalidLayer + maximalBandwidthRequired := int64(0) + for s := f.provisional.maxLayer.Spatial; s >= 0; s-- { + for t := f.provisional.maxLayer.Temporal; t >= 0; t-- { + if f.provisional.bitrates[s][t] != 0 { + maximalLayer = buffer.VideoLayer{Spatial: s, Temporal: t} + maximalBandwidthRequired = f.provisional.bitrates[s][t] + break + } + } + + if maximalBandwidthRequired != 0 { + break + } + } + + if maximalLayer.IsValid() { + if !existingTargetLayer.GreaterThan(maximalLayer) && f.provisional.bitrates[existingTargetLayer.Spatial][existingTargetLayer.Temporal] != 0 { + // currently streaming and maybe wanting an upgrade (existingTargetLayer <= maximalLayer), + // just preserve current target in the cooperative scheme of things + f.provisional.allocatedLayer = existingTargetLayer + return VideoTransition{ + From: existingTargetLayer, + To: existingTargetLayer, + BandwidthDelta: 0, + }, f.provisional.availableLayers, f.provisional.bitrates + } + + if existingTargetLayer.GreaterThan(maximalLayer) { + // maximalLayer < existingTargetLayer, make the down move + f.provisional.allocatedLayer = maximalLayer + return VideoTransition{ + From: existingTargetLayer, + To: maximalLayer, + BandwidthDelta: maximalBandwidthRequired - getBandwidthNeeded(f.provisional.bitrates, existingTargetLayer, f.lastAllocation.BandwidthRequested), + }, f.provisional.availableLayers, f.provisional.bitrates + } + } + } + + findNextLayer := func( + minSpatial, maxSpatial int32, + minTemporal, maxTemporal int32, + ) (buffer.VideoLayer, int64) { + layers := buffer.InvalidLayer + bw := int64(0) + for s := minSpatial; s <= maxSpatial; s++ { + for t := minTemporal; t <= maxTemporal; t++ { + if f.provisional.bitrates[s][t] != 0 { + layers = buffer.VideoLayer{Spatial: s, Temporal: t} + bw = f.provisional.bitrates[s][t] + break + } + } + + if bw != 0 { + break + } + } + + return layers, bw + } + + targetLayer := buffer.InvalidLayer + bandwidthRequired := int64(0) + if !existingTargetLayer.IsValid() { + // currently not streaming, find minimal + // NOTE: a layer in feed could have paused and there could be other options than going back to minimal, + // but the cooperative scheme knocks things back to minimal + targetLayer, bandwidthRequired = findNextLayer( + 0, f.provisional.maxLayer.Spatial, + 0, f.provisional.maxLayer.Temporal, + ) + + // could not find a minimal layer, overshoot if allowed + if bandwidthRequired == 0 && f.provisional.maxLayer.IsValid() && allowOvershoot && f.vls.IsOvershootOkay() { + targetLayer, bandwidthRequired = findNextLayer( + f.provisional.maxLayer.Spatial+1, buffer.DefaultMaxLayerSpatial, + 0, buffer.DefaultMaxLayerTemporal, + ) + } + } + + // if nothing available, just leave target at current to enable opportunistic forwarding in case current resumes + if !targetLayer.IsValid() { + targetLayer = f.provisional.currentLayer + if targetLayer.IsValid() { + bandwidthRequired = f.provisional.bitrates[targetLayer.Spatial][targetLayer.Temporal] + } + } + + f.provisional.allocatedLayer = targetLayer + return VideoTransition{ + From: f.vls.GetTarget(), + To: targetLayer, + BandwidthDelta: bandwidthRequired - getBandwidthNeeded(f.provisional.bitrates, existingTargetLayer, f.lastAllocation.BandwidthRequested), + }, f.provisional.availableLayers, f.provisional.bitrates +} + +func (f *Forwarder) ProvisionalAllocateGetBestWeightedTransition() (VideoTransition, []int32, Bitrates) { + // + // This is called when a track needs a change (could be mute/unmute, subscribed layers changed, published layers changed) + // when channel is congested. This is called on tracks other than the one needing the change. When the track + // needing the change requires bits, this is called to check if this track can contribute some bits to the pool. + // + // The goal is to keep all tracks streaming as much as possible. So, the track that needs a change needs bandwidth to be unpaused. + // + // This tries to figure out how much this track can contribute back to the pool to enable the track that needs to be unpaused. + // 1. Track muted OR feed dry - can contribute everything back in case it was using bandwidth. + // 2. Look at all possible down transitions from current target and find the best offer. + // Best offer is calculated as bandwidth saved moving to a down layer divided by cost. + // Cost has two components + // a. Transition cost: Spatial layer switch is expensive due to key frame requirement, but temporal layer switch is free. + // b. Quality cost: The farther away from desired layers, the higher the quality cost. + // + f.lock.Lock() + defer f.lock.Unlock() + + targetLayer := f.vls.GetTarget() + if f.provisional.muted || f.provisional.pubMuted { + f.provisional.allocatedLayer = buffer.InvalidLayer + return VideoTransition{ + From: targetLayer, + To: f.provisional.allocatedLayer, + BandwidthDelta: 0 - getBandwidthNeeded(f.provisional.bitrates, targetLayer, f.lastAllocation.BandwidthRequested), + }, f.provisional.availableLayers, f.provisional.bitrates + } + + maxReachableLayerTemporal := buffer.InvalidLayerTemporal + for t := f.provisional.maxLayer.Temporal; t >= 0; t-- { + for s := f.provisional.maxLayer.Spatial; s >= 0; s-- { + if f.provisional.bitrates[s][t] != 0 { + maxReachableLayerTemporal = t + break + } + } + if maxReachableLayerTemporal != buffer.InvalidLayerTemporal { + break + } + } + + if maxReachableLayerTemporal == buffer.InvalidLayerTemporal { + // feed has gone dry, just leave target at current to enable opportunistic forwarding in case current resumes. + // Note that this is giving back bits and opportunistic forwarding resuming might trigger congestion again, + // but that should be handled by stream allocator. + f.provisional.allocatedLayer = f.provisional.currentLayer + return VideoTransition{ + From: targetLayer, + To: f.provisional.allocatedLayer, + BandwidthDelta: 0 - getBandwidthNeeded(f.provisional.bitrates, targetLayer, f.lastAllocation.BandwidthRequested), + }, f.provisional.availableLayers, f.provisional.bitrates + } + + // starting from minimum to target, find transition which gives the best + // transition taking into account bits saved vs cost of such a transition + existingBandwidthNeeded := getBandwidthNeeded(f.provisional.bitrates, targetLayer, f.lastAllocation.BandwidthRequested) + bestLayer := buffer.InvalidLayer + bestBandwidthDelta := int64(0) + bestValue := float32(0) + for s := int32(0); s <= targetLayer.Spatial; s++ { + for t := int32(0); t <= targetLayer.Temporal; t++ { + if s == targetLayer.Spatial && t == targetLayer.Temporal { + break + } + + bandwidthDelta := max(0, existingBandwidthNeeded-f.provisional.bitrates[s][t]) + + transitionCost := int32(0) + // SVC-TODO: SVC will need a different cost transition + if targetLayer.Spatial != s { + transitionCost = TransitionCostSpatial + } + + qualityCost := (maxReachableLayerTemporal+1)*(targetLayer.Spatial-s) + (targetLayer.Temporal - t) + + value := float32(0) + if (transitionCost + qualityCost) != 0 { + value = float32(bandwidthDelta) / float32(transitionCost+qualityCost) + } + if value > bestValue || (value == bestValue && bandwidthDelta > bestBandwidthDelta) { + bestValue = value + bestBandwidthDelta = bandwidthDelta + bestLayer = buffer.VideoLayer{Spatial: s, Temporal: t} + } + } + } + + f.provisional.allocatedLayer = bestLayer + return VideoTransition{ + From: targetLayer, + To: bestLayer, + BandwidthDelta: -bestBandwidthDelta, + }, f.provisional.availableLayers, f.provisional.bitrates +} + +func (f *Forwarder) ProvisionalAllocateCommit() VideoAllocation { + f.lock.Lock() + defer f.lock.Unlock() + + optimalBandwidthNeeded := getOptimalBandwidthNeeded( + f.provisional.muted, + f.provisional.pubMuted, + f.provisional.maxSeenLayer.Spatial, + f.provisional.bitrates, + f.provisional.maxLayer, + ) + alloc := VideoAllocation{ + BandwidthRequested: 0, + BandwidthDelta: 0 - getBandwidthNeeded(f.provisional.bitrates, f.vls.GetTarget(), f.lastAllocation.BandwidthRequested), + Bitrates: f.provisional.bitrates, + BandwidthNeeded: optimalBandwidthNeeded, + TargetLayer: f.provisional.allocatedLayer, + RequestLayerSpatial: f.provisional.allocatedLayer.Spatial, + MaxLayer: f.provisional.maxLayer, + DistanceToDesired: getDistanceToDesired( + f.provisional.muted, + f.provisional.pubMuted, + f.provisional.maxSeenLayer, + f.provisional.availableLayers, + f.provisional.bitrates, + f.provisional.allocatedLayer, + f.provisional.maxLayer, + ), + } + + switch { + case f.provisional.muted: + alloc.PauseReason = VideoPauseReasonMuted + + case f.provisional.pubMuted: + alloc.PauseReason = VideoPauseReasonPubMuted + + case optimalBandwidthNeeded == 0: + if f.provisional.allocatedLayer.IsValid() { + // overshoot + alloc.BandwidthRequested = f.provisional.bitrates[f.provisional.allocatedLayer.Spatial][f.provisional.allocatedLayer.Temporal] + alloc.BandwidthDelta = alloc.BandwidthRequested - getBandwidthNeeded(f.provisional.bitrates, f.vls.GetTarget(), f.lastAllocation.BandwidthRequested) + } else { + alloc.PauseReason = VideoPauseReasonFeedDry + + // leave target at current for opportunistic forwarding + if f.provisional.currentLayer.IsValid() && f.provisional.currentLayer.Spatial <= f.provisional.maxLayer.Spatial { + f.provisional.allocatedLayer = f.provisional.currentLayer + alloc.TargetLayer = f.provisional.allocatedLayer + alloc.RequestLayerSpatial = alloc.TargetLayer.Spatial + } + } + + default: + if f.provisional.allocatedLayer.IsValid() { + alloc.BandwidthRequested = f.provisional.bitrates[f.provisional.allocatedLayer.Spatial][f.provisional.allocatedLayer.Temporal] + } + alloc.BandwidthDelta = alloc.BandwidthRequested - getBandwidthNeeded(f.provisional.bitrates, f.vls.GetTarget(), f.lastAllocation.BandwidthRequested) + + if f.provisional.allocatedLayer.GreaterThan(f.provisional.maxLayer) || + alloc.BandwidthRequested >= getOptimalBandwidthNeeded( + f.provisional.muted, + f.provisional.pubMuted, + f.provisional.maxSeenLayer.Spatial, + f.provisional.bitrates, + f.provisional.maxLayer, + ) { + // could be greater than optimal if overshooting + alloc.IsDeficient = false + } else { + alloc.IsDeficient = true + if !f.provisional.allocatedLayer.IsValid() { + alloc.PauseReason = VideoPauseReasonBandwidth + } + } + } + + return f.updateAllocation(alloc, "cooperative") +} + +func (f *Forwarder) AllocateNextHigher(availableChannelCapacity int64, availableLayers []int32, brs Bitrates, allowOvershoot bool) (VideoAllocation, bool) { + f.lock.Lock() + defer f.lock.Unlock() + + if f.kind == webrtc.RTPCodecTypeAudio { + return f.lastAllocation, false + } + + // if not deficient, nothing to do + if !f.isDeficientLocked() { + return f.lastAllocation, false + } + + maxLayer := f.vls.GetMax() + maxSeenLayer := f.vls.GetMaxSeen() + optimalBandwidthNeeded := getOptimalBandwidthNeeded(f.muted, f.pubMuted, maxSeenLayer.Spatial, brs, maxLayer) + + alreadyAllocated := int64(0) + targetLayer := f.vls.GetTarget() + if targetLayer.IsValid() { + alreadyAllocated = brs[targetLayer.Spatial][targetLayer.Temporal] + } + + doAllocation := func( + minSpatial, maxSpatial int32, + minTemporal, maxTemporal int32, + ) (bool, VideoAllocation, bool) { + for s := minSpatial; s <= maxSpatial; s++ { + for t := minTemporal; t <= maxTemporal; t++ { + bandwidthRequested := brs[s][t] + if bandwidthRequested == 0 { + continue + } + + if (!allowOvershoot || !f.vls.IsOvershootOkay()) && bandwidthRequested-alreadyAllocated > availableChannelCapacity { + // next higher available layer does not fit, return + return true, f.lastAllocation, false + } + + newTargetLayer := buffer.VideoLayer{Spatial: s, Temporal: t} + alloc := VideoAllocation{ + IsDeficient: true, + BandwidthRequested: bandwidthRequested, + BandwidthDelta: bandwidthRequested - alreadyAllocated, + BandwidthNeeded: optimalBandwidthNeeded, + Bitrates: brs, + TargetLayer: newTargetLayer, + RequestLayerSpatial: newTargetLayer.Spatial, + MaxLayer: maxLayer, + DistanceToDesired: getDistanceToDesired( + f.muted, + f.pubMuted, + maxSeenLayer, + availableLayers, + brs, + newTargetLayer, + maxLayer, + ), + } + if newTargetLayer.GreaterThan(maxLayer) || bandwidthRequested >= optimalBandwidthNeeded { + alloc.IsDeficient = false + } + + return true, f.updateAllocation(alloc, "next-higher"), true + } + } + + return false, VideoAllocation{}, false + } + + done := false + var allocation VideoAllocation + boosted := false + + // try moving temporal layer up in currently streaming spatial layer + if targetLayer.IsValid() { + done, allocation, boosted = doAllocation( + targetLayer.Spatial, targetLayer.Spatial, + targetLayer.Temporal+1, maxLayer.Temporal, + ) + if done { + return allocation, boosted + } + } + + // try moving spatial layer up if temporal layer move up is not available + done, allocation, boosted = doAllocation( + targetLayer.Spatial+1, maxLayer.Spatial, + 0, maxLayer.Temporal, + ) + if done { + return allocation, boosted + } + + if allowOvershoot && f.vls.IsOvershootOkay() && maxLayer.IsValid() { + done, allocation, boosted = doAllocation( + maxLayer.Spatial+1, buffer.DefaultMaxLayerSpatial, + 0, buffer.DefaultMaxLayerTemporal, + ) + if done { + return allocation, boosted + } + } + + return f.lastAllocation, false +} + +func (f *Forwarder) GetNextHigherTransition(brs Bitrates, allowOvershoot bool) (VideoTransition, bool) { + f.lock.Lock() + defer f.lock.Unlock() + + if f.kind == webrtc.RTPCodecTypeAudio { + return VideoTransition{}, false + } + + // if not deficient, nothing to do + if !f.isDeficientLocked() { + return VideoTransition{}, false + } + + // if targets are still pending, don't increase + targetLayer := f.vls.GetTarget() + if targetLayer.IsValid() && targetLayer != f.vls.GetCurrent() { + return VideoTransition{}, false + } + + alreadyAllocated := int64(0) + if targetLayer.IsValid() { + alreadyAllocated = brs[targetLayer.Spatial][targetLayer.Temporal] + } + + findNextHigher := func( + minSpatial, maxSpatial int32, + minTemporal, maxTemporal int32, + ) (bool, VideoTransition, bool) { + for s := minSpatial; s <= maxSpatial; s++ { + for t := minTemporal; t <= maxTemporal; t++ { + bandwidthRequested := brs[s][t] + // traverse till finding a layer requiring more bits. + // NOTE: it possible that higher temporal layer of lower spatial layer + // could use more bits than lower temporal layer of higher spatial layer. + if bandwidthRequested == 0 || bandwidthRequested < alreadyAllocated { + continue + } + + transition := VideoTransition{ + From: targetLayer, + To: buffer.VideoLayer{Spatial: s, Temporal: t}, + BandwidthDelta: bandwidthRequested - alreadyAllocated, + } + + return true, transition, true + } + } + + return false, VideoTransition{}, false + } + + done := false + var transition VideoTransition + isAvailable := false + + // try moving temporal layer up in currently streaming spatial layer + maxLayer := f.vls.GetMax() + if targetLayer.IsValid() { + done, transition, isAvailable = findNextHigher( + targetLayer.Spatial, targetLayer.Spatial, + targetLayer.Temporal+1, maxLayer.Temporal, + ) + if done { + return transition, isAvailable + } + } + + // try moving spatial layer up if temporal layer move up is not available + done, transition, isAvailable = findNextHigher( + targetLayer.Spatial+1, maxLayer.Spatial, + 0, maxLayer.Temporal, + ) + if done { + return transition, isAvailable + } + + if allowOvershoot && f.vls.IsOvershootOkay() && maxLayer.IsValid() { + done, transition, isAvailable = findNextHigher( + maxLayer.Spatial+1, buffer.DefaultMaxLayerSpatial, + 0, buffer.DefaultMaxLayerTemporal, + ) + if done { + return transition, isAvailable + } + } + + return VideoTransition{}, false +} + +func (f *Forwarder) Pause(availableLayers []int32, brs Bitrates) VideoAllocation { + f.lock.Lock() + defer f.lock.Unlock() + + maxLayer := f.vls.GetMax() + maxSeenLayer := f.vls.GetMaxSeen() + optimalBandwidthNeeded := getOptimalBandwidthNeeded(f.muted, f.pubMuted, maxSeenLayer.Spatial, brs, maxLayer) + alloc := VideoAllocation{ + BandwidthRequested: 0, + BandwidthDelta: 0 - getBandwidthNeeded(brs, f.vls.GetTarget(), f.lastAllocation.BandwidthRequested), + Bitrates: brs, + BandwidthNeeded: optimalBandwidthNeeded, + TargetLayer: buffer.InvalidLayer, + RequestLayerSpatial: buffer.InvalidLayerSpatial, + MaxLayer: maxLayer, + DistanceToDesired: getDistanceToDesired( + f.muted, + f.pubMuted, + maxSeenLayer, + availableLayers, + brs, + buffer.InvalidLayer, + maxLayer, + ), + } + + switch { + case f.muted: + alloc.PauseReason = VideoPauseReasonMuted + + case f.pubMuted: + alloc.PauseReason = VideoPauseReasonPubMuted + + case optimalBandwidthNeeded == 0: + alloc.PauseReason = VideoPauseReasonFeedDry + + default: + // pausing due to lack of bandwidth + alloc.IsDeficient = true + alloc.PauseReason = VideoPauseReasonBandwidth + } + + return f.updateAllocation(alloc, "pause") +} + +func (f *Forwarder) updateAllocation(alloc VideoAllocation, reason string) VideoAllocation { + // restrict target temporal to 0 if codec does not support temporal layers + if alloc.TargetLayer.IsValid() && f.mime == mime.MimeTypeH264 { + alloc.TargetLayer.Temporal = 0 + } + + if alloc.IsDeficient != f.lastAllocation.IsDeficient || + alloc.PauseReason != f.lastAllocation.PauseReason || + alloc.TargetLayer != f.lastAllocation.TargetLayer || + alloc.RequestLayerSpatial != f.lastAllocation.RequestLayerSpatial { + f.logger.Debugw( + fmt.Sprintf("stream allocation: %s", reason), + "allocation", &alloc, + "lastAllocation", &f.lastAllocation, + ) + } + f.lastAllocation = alloc + + f.setTargetLayer(f.lastAllocation.TargetLayer, f.lastAllocation.RequestLayerSpatial) + if !f.vls.GetTarget().IsValid() { + f.resyncLocked() + } + + return f.lastAllocation +} + +func (f *Forwarder) setTargetLayer(targetLayer buffer.VideoLayer, requestLayerSpatial int32) { + f.vls.SetTarget(targetLayer) + if targetLayer.IsValid() { + f.vls.SetRequestSpatial(requestLayerSpatial) + } else { + f.vls.SetRequestSpatial(buffer.InvalidLayerSpatial) + } +} + +func (f *Forwarder) Resync() { + f.lock.Lock() + defer f.lock.Unlock() + + f.resyncLocked() +} + +func (f *Forwarder) resyncLocked() { + f.vls.SetCurrent(buffer.InvalidLayer) + f.lastSSRC = 0 + if f.pubMuted { + f.resumeBehindThreshold = ResumeBehindThresholdSeconds + } +} + +func (f *Forwarder) CheckSync() (bool, int32) { + f.lock.RLock() + defer f.lock.RUnlock() + + return f.vls.CheckSync() +} + +func (f *Forwarder) Restart() { + f.lock.Lock() + defer f.lock.Unlock() + + f.resyncLocked() + f.setTargetLayer(buffer.InvalidLayer, buffer.InvalidLayerSpatial) + f.referenceLayerSpatial = buffer.InvalidLayerSpatial + f.lastReferencePayloadType = -1 + + for layer := range len(f.refInfos) { + f.refInfos[layer] = refInfo{} + } + f.lastSwitchExtIncomingTS = 0 + f.refVideoLayerMode = livekit.VideoLayer_MODE_UNUSED +} + +func (f *Forwarder) FilterRTX(nacks []uint16) (filtered []uint16, disallowedLayers [buffer.DefaultMaxLayerSpatial + 1]bool) { + f.lock.RLock() + defer f.lock.RUnlock() + + if !FlagFilterRTX { + filtered = nacks + } else { + filtered = f.rtpMunger.FilterRTX(nacks) + } + + // + // Curb RTX when deficient for two cases + // 1. Target layer is lower than current layer. When current hits target, a key frame should flush the decoder. + // 2. Requested layer is higher than current. Current layer's key frame should have flushed encoder. + // Remote might ask for older layer because of its jitter buffer, but let it starve as channel is already congested. + // + // Without the curb, when congestion hits, RTX rate could be so high that it further congests the channel. + // + if FlagFilterRTXLayers { + currentLayer := f.vls.GetCurrent() + targetLayer := f.vls.GetTarget() + for layer := range buffer.DefaultMaxLayerSpatial + 1 { + if f.isDeficientLocked() && (targetLayer.Spatial < currentLayer.Spatial || layer > currentLayer.Spatial) { + disallowedLayers[layer] = true + } + } + } + return +} + +func (f *Forwarder) GetTranslationParams(extPkt *buffer.ExtPacket, layer int32) (TranslationParams, error) { + f.lock.Lock() + defer f.lock.Unlock() + + if f.muted || f.pubMuted { + return TranslationParams{ + shouldDrop: true, + }, nil + } + + switch f.kind { + case webrtc.RTPCodecTypeAudio: + return f.getTranslationParamsAudio(extPkt, layer) + + case webrtc.RTPCodecTypeVideo: + return f.getTranslationParamsVideo(extPkt, layer) + } + + return TranslationParams{ + shouldDrop: true, + }, errUnknownKind +} + +func (f *Forwarder) getRefLayerRTPTimestamp(ts uint32, refLayer, targetLayer int32) (uint32, error) { + if refLayer < 0 || int(refLayer) > len(f.refInfos) || targetLayer < 0 || int(targetLayer) > len(f.refInfos) { + return 0, fmt.Errorf("invalid layer(s), refLayer: %d, targetLayer: %d", refLayer, targetLayer) + } + + if refLayer == targetLayer || f.refVideoLayerMode == livekit.VideoLayer_MULTIPLE_SPATIAL_LAYERS_PER_STREAM { + return ts, nil + } + + srRef := f.refInfos[refLayer].senderReport + srTarget := f.refInfos[targetLayer].senderReport + if srRef == nil || srRef.NtpTimestamp == 0 { + return 0, fmt.Errorf("unavailable layer ref, refLayer: %d, targetLayer: %d", refLayer, targetLayer) + } + if srTarget == nil || srTarget.NtpTimestamp == 0 { + return 0, fmt.Errorf("unavailable layer target, refLayer: %d, targetLayer: %d", refLayer, targetLayer) + } + + ntpDiff := mediatransportutil.NtpTime(srRef.NtpTimestamp).Time().Sub(mediatransportutil.NtpTime(srTarget.NtpTimestamp).Time()) + rtpDiff := ntpDiff.Nanoseconds() * int64(f.clockRate) / 1e9 + + // calculate other layer's time stamp at the same time as ref layer's NTP time + normalizedOtherTS := srTarget.RtpTimestamp + uint32(rtpDiff) + + // now both layers' time stamp refer to the same NTP time and the diff is the offset between the layers + offset := srRef.RtpTimestamp - normalizedOtherTS + + return ts + offset, nil +} + +func (f *Forwarder) processSourceSwitch(extPkt *buffer.ExtPacket, layer int32) error { + if !f.started { + if extPkt.IsOutOfOrder { + return errSkipStartOnOutOfOrderPacket + } + + f.started = true + f.referenceLayerSpatial = layer + f.rtpMunger.SetLastSnTs(extPkt) + f.codecMunger.SetLast(extPkt) + f.logger.Debugw( + "starting forwarding", + "sequenceNumber", extPkt.Packet.SequenceNumber, + "extSequenceNumber", extPkt.ExtSequenceNumber, + "timestamp", extPkt.Packet.Timestamp, + "extTimestamp", extPkt.ExtTimestamp, + "layer", layer, + "referenceLayerSpatial", f.referenceLayerSpatial, + ) + return nil + } else if f.referenceLayerSpatial == buffer.InvalidLayerSpatial { + if extPkt.IsOutOfOrder { + return errSkipStartOnOutOfOrderPacket + } + + f.referenceLayerSpatial = layer + f.codecMunger.SetLast(extPkt) + f.logger.Debugw( + "catch up forwarding", + "sequenceNumber", extPkt.Packet.SequenceNumber, + "extSequenceNumber", extPkt.ExtSequenceNumber, + "timestamp", extPkt.Packet.Timestamp, + "extTimestamp", extPkt.ExtTimestamp, + "layer", layer, + "referenceLayerSpatial", f.referenceLayerSpatial, + ) + } + + logTransition := func(message string, extExpectedTS, extRefTS, extLastTS uint64, diffSeconds float64) { + f.logger.Debugw( + message, + "layer", layer, + "referenceLayerSpatial", f.referenceLayerSpatial, + "extExpectedTS", extExpectedTS, + "incomingTS", extPkt.Packet.Timestamp, + "extIncomingTS", extPkt.ExtTimestamp, + "extRefTS", extRefTS, + "extLastTS", extLastTS, + "diffSeconds", math.Abs(diffSeconds), + "refInfos", logger.ObjectSlice(f.refInfos[:]), + "lastSwitchExtIncomingTS", f.lastSwitchExtIncomingTS, + "rtpStats", f.rtpStats, + ) + } + // TODO-REMOVE-AFTER-DATA-COLLECTION + logTransitionInfo := func(message string, extExpectedTS, extRefTS, extLastTS uint64, diffSeconds float64) { + f.logger.Infow( + message, + "layer", layer, + "referenceLayerSpatial", f.referenceLayerSpatial, + "extExpectedTS", extExpectedTS, + "incomingTS", extPkt.Packet.Timestamp, + "extIncomingTS", extPkt.ExtTimestamp, + "extRefTS", extRefTS, + "extLastTS", extLastTS, + "diffSeconds", math.Abs(diffSeconds), + "refInfos", logger.ObjectSlice(f.refInfos[:]), + "lastSwitchExtIncomingTS", f.lastSwitchExtIncomingTS, + "rtpStats", f.rtpStats, + ) + } + + // Compute how much time passed between the previous forwarded packet + // and the current incoming (to be forwarded) packet and calculate + // timestamp offset on source change. + // + // There are three timestamps to consider here + // 1. extLastTS -> timestamp of last sent packet + // 2. extRefTS -> timestamp of this packet (after munging) calculated using feed's RTCP sender report + // 3. extExpectedTS -> expected timestamp of this packet calculated based on elapsed time since first packet + // Ideally, extRefTS and extExpectedTS should be very close and extLastTS should be before both of those. + // But, cases like muting/unmuting, clock vagaries, pacing, etc. make them not satisfy those conditions always. + rtpMungerState := f.rtpMunger.GetState() + extLastTS := rtpMungerState.ExtLastTimestamp + extExpectedTS := extLastTS + extRefTS := extLastTS + refTS := uint32(extRefTS) + switchingAt := mono.Now() + if !f.skipReferenceTS { + var err error + refTS, err = f.getRefLayerRTPTimestamp(extPkt.Packet.Timestamp, f.referenceLayerSpatial, layer) + if err != nil { + // error out if refTS is not available. It can happen when there is no sender report + // for the layer being switched to. Can especially happen at the start of the track when layer switches are + // potentially happening very quickly. Erroring out and waiting for a layer for which a sender report has been + // received will calculate a better offset, but may result in initial adaptation to take a bit longer depending + // on how often publisher/remote side sends RTCP sender report. + f.logger.Debugw( + "could not get ref layer timestamp", + "referenceLayerSpatial", f.referenceLayerSpatial, + "layer", layer, + "error", err, + ) + return err + } + } + + // adjust extRefTS to current packet's timestamp mapped to that of reference layer's + extRefTS = (extRefTS & 0xFFFF_FFFF_0000_0000) + uint64(refTS) + f.dummyStartTSOffset + lastTS := uint32(extLastTS) + refTS = uint32(extRefTS) + if (refTS-lastTS) < 1<<31 && refTS < lastTS { + extRefTS += (1 << 32) + } + if (lastTS-refTS) < 1<<31 && lastTS < refTS && extRefTS >= 1<<32 { + extRefTS -= (1 << 32) + } + + if f.rtpStats != nil { + tsExt, err := f.rtpStats.GetExpectedRTPTimestamp(switchingAt) + if err == nil { + extExpectedTS = tsExt + if f.lastReferencePayloadType == -1 { + f.dummyStartTSOffset = extExpectedTS - uint64(refTS) + extRefTS = extExpectedTS + } + } else { + if !f.preStartTime.IsZero() { + timeSinceFirst := time.Since(f.preStartTime) + rtpDiff := uint64(timeSinceFirst.Nanoseconds() * int64(f.clockRate) / 1e9) + extExpectedTS = f.extFirstTS + rtpDiff + if f.dummyStartTSOffset == 0 { + f.dummyStartTSOffset = extExpectedTS - uint64(refTS) + extRefTS = extExpectedTS + f.logger.Infow( + "calculating dummyStartTSOffset", + "preStartTime", f.preStartTime, + "extFirstTS", f.extFirstTS, + "timeSinceFirst", timeSinceFirst, + "rtpDiff", rtpDiff, + "extRefTS", extRefTS, + "incomingTS", extPkt.Packet.Timestamp, + "referenceLayerSpatial", f.referenceLayerSpatial, + "dummyStartTSOffset", f.dummyStartTSOffset, + ) + } + } + } + } + + bigJump := false + var extNextTS uint64 + if f.lastSSRC == 0 { + // If resuming (e. g. on unmute), keep next timestamp close to expected timestamp. + // + // Rationale: + // Case 1: If mute is implemented via something like stopping a track and resuming it on unmute, + // the RTP timestamp may not have jumped across mute valley. In this case, old timestamp + // should not be used. + // + // Case 2: OTOH, something like pacing may be adding latency in the publisher path (even if + // the timestamps incremented correctly across the mute valley). In this case, reference + // timestamp should be used as things will catch up to real time when channel capacity + // increases and pacer starts sending at faster rate. + // + // But, the challenge is distinguishing between the two cases. As a compromise, the difference + // between extExpectedTS and extRefTS is thresholded. Difference below the threshold is treated as Case 2 + // and above as Case 1. + // + // In the event of extRefTS > extExpectedTS, use extRefTS. + // Ideally, extRefTS should not be ahead of extExpectedTS, but extExpectedTS uses the first packet's + // wall clock time. So, if the first packet experienced abmormal latency, it is possible + // for extRefTS > extExpectedTS + diffSeconds := float64(int64(extExpectedTS-extRefTS)) / float64(f.clockRate) + if diffSeconds >= 0.0 { + if f.resumeBehindThreshold > 0 && diffSeconds > f.resumeBehindThreshold { + logTransitionInfo("resume, reference too far behind", extExpectedTS, extRefTS, extLastTS, diffSeconds) + extNextTS = extExpectedTS + bigJump = true + } else if diffSeconds > ResumeBehindHighThresholdSeconds { + // could be due to incoming time stamp lagging a lot, like an unpause of the track + logTransitionInfo("resume, reference very far behind", extExpectedTS, extRefTS, extLastTS, diffSeconds) + extNextTS = extExpectedTS + bigJump = true + } else { + extNextTS = extRefTS + } + } else { + if math.Abs(diffSeconds) > SwitchAheadThresholdSeconds { + logTransition("resume, reference too far ahead", extExpectedTS, extRefTS, extLastTS, diffSeconds) + } + extNextTS = extRefTS + } + f.resumeBehindThreshold = 0.0 + } else { + // switching between layers, check if extRefTS is too far behind the last sent + diffSeconds := float64(int64(extRefTS-extLastTS)) / float64(f.clockRate) + if diffSeconds < 0.0 { + if math.Abs(diffSeconds) > LayerSwitchBehindThresholdSeconds { + // this could be due to pacer trickling out this layer. Error out and wait for a more opportune time. + // AVSYNC-TODO: Consider some forcing function to do the switch + // (like "have waited for too long for layer switch, nothing available, switch to whatever is available" kind of condition). + logTransition("layer switch, reference too far behind", extExpectedTS, extRefTS, extLastTS, diffSeconds) + + return errSwitchPointTooFarBehind + } + + // use a nominal increase to ensure that timestamp is always moving forward + logTransition("layer switch, reference is slightly behind", extExpectedTS, extRefTS, extLastTS, diffSeconds) + extNextTS = extLastTS + 1 + } else { + diffSeconds = float64(int64(extRefTS-extExpectedTS)) / float64(f.clockRate) + if diffSeconds > SwitchAheadThresholdSeconds { + logTransition("layer switch, reference too far ahead", extExpectedTS, extRefTS, extLastTS, diffSeconds) + } + + extNextTS = extRefTS + } + } + + if int64(extNextTS-extLastTS) <= 0 { + f.logger.Debugw("next timestamp is before last, adjusting", "extNextTS", extNextTS, "extLastTS", extLastTS) + // nominal increase + extNextTS = extLastTS + 1 + } + if bigJump { // TODO-REMOVE-AFTER-DATA-COLLECTION + f.logger.Infow( + "next timestamp on switch", + "switchingAt", switchingAt, + "layer", layer, + "extLastTS", extLastTS, + "lastMarker", rtpMungerState.LastMarker, + "extRefTS", extRefTS, + "dummyStartTSOffset", f.dummyStartTSOffset, + "referenceLayerSpatial", f.referenceLayerSpatial, + "extExpectedTS", extExpectedTS, + "extNextTS", extNextTS, + "tsJump", extNextTS-extLastTS, + "nextSN", rtpMungerState.ExtLastSequenceNumber+1, + "extIncomingSN", extPkt.ExtSequenceNumber, + "incomingTS", extPkt.Packet.Timestamp, + "extIncomingTS", extPkt.ExtTimestamp, + "rtpStats", f.rtpStats, + ) + } else { + f.logger.Debugw( + "next timestamp on switch", + "switchingAt", switchingAt, + "layer", layer, + "extLastTS", extLastTS, + "lastMarker", rtpMungerState.LastMarker, + "extRefTS", extRefTS, + "dummyStartTSOffset", f.dummyStartTSOffset, + "referenceLayerSpatial", f.referenceLayerSpatial, + "extExpectedTS", extExpectedTS, + "extNextTS", extNextTS, + "tsJump", extNextTS-extLastTS, + "nextSN", rtpMungerState.ExtLastSequenceNumber+1, + "extIncomingSN", extPkt.ExtSequenceNumber, + "extIncomingTS", extPkt.ExtTimestamp, + "rtpStats", f.rtpStats, + ) + } + + f.rtpMunger.UpdateSnTsOffsets(extPkt, 1, extNextTS-extLastTS) + f.codecMunger.UpdateOffsets(extPkt) + return nil +} + +// should be called with lock held +func (f *Forwarder) getTranslationParamsCommon(extPkt *buffer.ExtPacket, layer int32, tp *TranslationParams) error { + if f.lastSSRC != extPkt.Packet.SSRC { + if err := f.processSourceSwitch(extPkt, layer); err != nil { + f.logger.Debugw( + "could not switch feed", + "error", err, + "layer", layer, + "refInfos", logger.ObjectSlice(f.refInfos[:]), + "lastSwitchExtIncomingTS", f.lastSwitchExtIncomingTS, + "rtpStats", f.rtpStats, + "currentLayer", f.vls.GetCurrent(), + "targetLayer", f.vls.GetCurrent(), + "maxLayer", f.vls.GetMax(), + ) + tp.shouldDrop = true + f.vls.Rollback() + return nil + } + f.logger.Debugw( + "switching feed", + "fromSSRC", f.lastSSRC, + "toSSRC", extPkt.Packet.SSRC, + "fromPayloadType", f.lastReferencePayloadType, + "toPayloadType", extPkt.Packet.PayloadType, + "layer", layer, + "refInfos", logger.ObjectSlice(f.refInfos[:]), + "lastSwitchExtIncomingTS", f.lastSwitchExtIncomingTS, + "currentLayer", f.vls.GetCurrent(), + "targetLayer", f.vls.GetCurrent(), + "maxLayer", f.vls.GetMax(), + ) + f.lastSSRC = extPkt.Packet.SSRC + f.lastReferencePayloadType = int8(extPkt.Packet.PayloadType) + f.lastSwitchExtIncomingTS = extPkt.ExtTimestamp + } + + tpRTP, err := f.rtpMunger.UpdateAndGetSnTs(extPkt, tp.marker) + if err != nil { + tp.shouldDrop = true + if err == errPaddingOnlyPacket || err == errDuplicatePacket || err == errOutOfOrderSequenceNumberCacheMiss { + return nil + } + return err + } + + tp.rtp = tpRTP + + if len(extPkt.Packet.Payload) > 0 { + return f.translateCodecHeader(extPkt, tp) + } + + return nil +} + +// should be called with lock held +func (f *Forwarder) getTranslationParamsAudio(extPkt *buffer.ExtPacket, layer int32) (TranslationParams, error) { + tp := TranslationParams{} + if err := f.getTranslationParamsCommon(extPkt, layer, &tp); err != nil { + tp.shouldDrop = true + return tp, err + } + return tp, nil +} + +// should be called with lock held +func (f *Forwarder) getTranslationParamsVideo(extPkt *buffer.ExtPacket, layer int32) (TranslationParams, error) { + tp := TranslationParams{} + if !f.vls.GetTarget().IsValid() { + // stream is paused by streamallocator + tp.shouldDrop = true + return tp, nil + } + + result := f.vls.Select(extPkt, layer) + if !result.IsSelected { + if f.isDDAvailable && extPkt.DependencyDescriptor == nil { + f.logger.Infow( + "turning off dependency descriptor", + "layer", layer, + "refInfos", logger.ObjectSlice(f.refInfos[:]), + "lastSwitchExtIncomingTS", f.lastSwitchExtIncomingTS, + "currentLayer", f.vls.GetCurrent(), + "targetLayer", f.vls.GetCurrent(), + "maxLayer", f.vls.GetMax(), + ) + f.isDDAvailable = false + switch f.mime { + case mime.MimeTypeVP9: + f.vls = videolayerselector.NewVP9FromOther(f.vls) + case mime.MimeTypeAV1: + f.vls = videolayerselector.NewSimulcastFromOther(f.vls) + } + } + tp.shouldDrop = true + if f.started && result.IsRelevant { + // call to update highest incoming sequence number and other internal structures + if tpRTP, err := f.rtpMunger.UpdateAndGetSnTs(extPkt, result.RTPMarker); err == nil { + if tpRTP.snOrdering == SequenceNumberOrderingContiguous { + f.rtpMunger.PacketDropped(extPkt) + } + } + } + return tp, nil + } + tp.isResuming = result.IsResuming + tp.isSwitching = result.IsSwitching + tp.ddBytes = result.DependencyDescriptorExtension + tp.marker = result.RTPMarker + + err := f.getTranslationParamsCommon(extPkt, layer, &tp) + if tp.shouldDrop { + return tp, err + } + + if FlagPauseOnDowngrade && f.isDeficientLocked() && f.vls.GetTarget().Spatial < f.vls.GetCurrent().Spatial { + // + // If target layer is lower than both the current and + // maximum subscribed layer, it is due to bandwidth + // constraints that the target layer has been switched down. + // Continuing to send higher layer will only exacerbate the + // situation by putting more stress on the channel. So, drop it. + // + // In the other direction, it is okay to keep forwarding till + // switch point to get a smoother stream till the higher + // layer key frame arrives. + // + // Note that it is possible for client subscription layer restriction + // to coincide with server restriction due to bandwidth limitation, + // In the case of subscription change, higher should continue streaming + // to ensure smooth transition. + // + // To differentiate between the two cases, drop only when in DEFICIENT state. + // + tp.shouldDrop = true + return tp, nil + } + + return tp, nil +} + +func (f *Forwarder) translateCodecHeader(extPkt *buffer.ExtPacket, tp *TranslationParams) error { + // codec specific forwarding check and any needed packet munging + tl := f.vls.SelectTemporal(extPkt) + inputSize, codecBytes, err := f.codecMunger.UpdateAndGet( + extPkt, + tp.rtp.snOrdering == SequenceNumberOrderingOutOfOrder, + tp.rtp.snOrdering == SequenceNumberOrderingGap, + tl, + ) + if err != nil { + tp.shouldDrop = true + if err == codecmunger.ErrFilteredVP8TemporalLayer || err == codecmunger.ErrOutOfOrderVP8PictureIdCacheMiss { + if err == codecmunger.ErrFilteredVP8TemporalLayer { + // filtered temporal layer, update sequence number offset to prevent holes + f.rtpMunger.PacketDropped(extPkt) + } + return nil + } + + return err + } + tp.incomingHeaderSize = inputSize + tp.codecBytes = codecBytes + return nil +} + +func (f *Forwarder) maybeStart() { + if f.started { + return + } + + f.started = true + f.preStartTime = time.Now() + + sequenceNumber := uint16(rand.Intn(1<<14)) + uint16(1<<15) // a random number in third quartile of sequence number space + timestamp := uint32(rand.Intn(1<<30)) + uint32(1<<31) // a random number in third quartile of timestamp space + extPkt := &buffer.ExtPacket{ + Packet: &rtp.Packet{ + Header: rtp.Header{ + SequenceNumber: sequenceNumber, + Timestamp: timestamp, + }, + }, + ExtSequenceNumber: uint64(sequenceNumber), + ExtTimestamp: uint64(timestamp), + } + f.rtpMunger.SetLastSnTs(extPkt) + + f.extFirstTS = uint64(timestamp) + f.logger.Infow( + "starting with dummy forwarding", + "sequenceNumber", extPkt.Packet.SequenceNumber, + "timestamp", extPkt.Packet.Timestamp, + "preStartTime", f.preStartTime, + ) +} + +func (f *Forwarder) GetSnTsForPadding(num int, frameRate uint32, forceMarker bool) ([]SnTs, error) { + f.lock.Lock() + defer f.lock.Unlock() + + f.maybeStart() + + // padding is used for probing. Padding packets should only + // be at frame boundaries to ensure decoder sequencer does + // not get out-of-sync. But, when a stream is paused, + // force a frame marker as a restart of the stream will + // start with a key frame which will reset the decoder. + if !f.vls.GetTarget().IsValid() { + forceMarker = true + } + return f.rtpMunger.UpdateAndGetPaddingSnTs( + num, + f.clockRate, + frameRate, + forceMarker, + f.rtpMunger.GetState().ExtLastTimestamp, + ) +} + +func (f *Forwarder) GetSnTsForBlankFrames(frameRate uint32, numPackets int) ([]SnTs, bool, error) { + f.lock.Lock() + defer f.lock.Unlock() + + f.maybeStart() + + frameEndNeeded := !f.rtpMunger.IsOnFrameBoundary() + if frameEndNeeded { + numPackets++ + } + + extLastTS := f.rtpMunger.GetState().ExtLastTimestamp + extExpectedTS := extLastTS + if f.rtpStats != nil { + tsExt, err := f.rtpStats.GetExpectedRTPTimestamp(mono.Now()) + if err == nil { + extExpectedTS = tsExt + } + } + if int64(extExpectedTS-extLastTS) <= 0 { + extExpectedTS = extLastTS + 1 + } + snts, err := f.rtpMunger.UpdateAndGetPaddingSnTs( + numPackets, + f.clockRate, + frameRate, + frameEndNeeded, + extExpectedTS, + ) + return snts, frameEndNeeded, err +} + +func (f *Forwarder) GetPadding(frameEndNeeded bool) ([]byte, error) { + f.lock.Lock() + defer f.lock.Unlock() + + return f.codecMunger.UpdateAndGetPadding(!frameEndNeeded) +} + +func (f *Forwarder) RTPMungerDebugInfo() map[string]any { + f.lock.RLock() + defer f.lock.RUnlock() + + return f.rtpMunger.DebugInfo() +} + +// ----------------------------------------------------------------------------- + +func getOptimalBandwidthNeeded(muted bool, pubMuted bool, maxPublishedLayer int32, brs Bitrates, maxLayer buffer.VideoLayer) int64 { + if muted || pubMuted || maxPublishedLayer == buffer.InvalidLayerSpatial { + return 0 + } + + for i := maxLayer.Spatial; i >= 0; i-- { + for j := maxLayer.Temporal; j >= 0; j-- { + if brs[i][j] == 0 { + continue + } + + return brs[i][j] + } + } + + // could be 0 due to either + // 1. publisher has stopped all layers ==> feed dry. + // 2. stream tracker has declared all layers stopped, functionally same as above. + // But, listed differently as this could be a mis-detection. + // 3. Bitrate measurement is pending. + return 0 +} + +func getBandwidthNeeded(brs Bitrates, layer buffer.VideoLayer, fallback int64) int64 { + if layer.IsValid() && brs[layer.Spatial][layer.Temporal] > 0 { + return brs[layer.Spatial][layer.Temporal] + } + + return fallback +} + +func getDistanceToDesired( + muted bool, + pubMuted bool, + maxSeenLayer buffer.VideoLayer, + availableLayers []int32, + brs Bitrates, + targetLayer buffer.VideoLayer, + maxLayer buffer.VideoLayer, +) float64 { + if muted || pubMuted || !maxSeenLayer.IsValid() || !maxLayer.IsValid() { + return 0.0 + } + + adjustedMaxLayer := maxLayer + + maxAvailableSpatial := buffer.InvalidLayerSpatial + maxAvailableTemporal := buffer.InvalidLayerTemporal + + // max available spatial is min(subscribedMax, publishedMax, availableMax) + // subscribedMax = subscriber requested max spatial layer + // publishedMax = max spatial layer ever published + // availableMax = based on bit rate measurement, available max spatial layer +done: + for s := int32(len(brs)) - 1; s >= 0; s-- { + for t := int32(len(brs[0])) - 1; t >= 0; t-- { + if brs[s][t] != 0 { + maxAvailableSpatial = s + break done + } + } + } + + // before bit rate measurement is available, stream tracker could declare layer seen, account for that + for _, layer := range availableLayers { + if layer > maxAvailableSpatial { + maxAvailableSpatial = layer + maxAvailableTemporal = maxSeenLayer.Temporal // till bit rate measurement is available, assume max seen as temporal + } + } + + if maxAvailableSpatial < adjustedMaxLayer.Spatial { + adjustedMaxLayer.Spatial = maxAvailableSpatial + } + + if maxSeenLayer.Spatial < adjustedMaxLayer.Spatial { + adjustedMaxLayer.Spatial = maxSeenLayer.Spatial + } + + // max available temporal is min(subscribedMax, temporalLayerSeenMax, availableMax) + // subscribedMax = subscriber requested max temporal layer + // temporalLayerSeenMax = max temporal layer ever published/seen + // availableMax = based on bit rate measurement, available max temporal in the adjusted max spatial layer + if adjustedMaxLayer.Spatial != buffer.InvalidLayerSpatial { + for t := int32(len(brs[0])) - 1; t >= 0; t-- { + if brs[adjustedMaxLayer.Spatial][t] != 0 { + maxAvailableTemporal = t + break + } + } + } + if maxAvailableTemporal < adjustedMaxLayer.Temporal { + adjustedMaxLayer.Temporal = maxAvailableTemporal + } + + if maxSeenLayer.Temporal < adjustedMaxLayer.Temporal { + adjustedMaxLayer.Temporal = maxSeenLayer.Temporal + } + + if !adjustedMaxLayer.IsValid() { + adjustedMaxLayer = buffer.VideoLayer{Spatial: 0, Temporal: 0} + } + + // adjust target layers if they are invalid, i. e. not streaming + adjustedTargetLayer := targetLayer + if !targetLayer.IsValid() { + adjustedTargetLayer = buffer.VideoLayer{Spatial: 0, Temporal: 0} + } + + distance := + ((adjustedMaxLayer.Spatial - adjustedTargetLayer.Spatial) * (maxSeenLayer.Temporal + 1)) + + (adjustedMaxLayer.Temporal - adjustedTargetLayer.Temporal) + if !targetLayer.IsValid() { + distance += (maxSeenLayer.Temporal + 1) + } + + return float64(distance) / float64(maxSeenLayer.Temporal+1) +} diff --git a/livekit/pkg/sfu/forwarder_test.go b/livekit/pkg/sfu/forwarder_test.go new file mode 100644 index 0000000..8e1df4a --- /dev/null +++ b/livekit/pkg/sfu/forwarder_test.go @@ -0,0 +1,2144 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sfu + +import ( + "testing" + + "github.com/pion/webrtc/v4" + "github.com/stretchr/testify/require" + + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + + "github.com/livekit/livekit-server/pkg/sfu/buffer" + "github.com/livekit/livekit-server/pkg/sfu/testutils" +) + +func disable(f *Forwarder) { + f.vls.SetCurrent(buffer.InvalidLayer) + f.vls.SetTarget(buffer.InvalidLayer) +} + +func newForwarder(codec webrtc.RTPCodecCapability, kind webrtc.RTPCodecType) *Forwarder { + f := NewForwarder( + kind, + logger.GetLogger(), + true, // skipReferenceTS + true, // disableOpportunisticAllocation + nil, + ) + f.DetermineCodec(codec, nil, livekit.VideoLayer_MODE_UNUSED) + return f +} + +func TestForwarderMute(t *testing.T) { + f := newForwarder(testutils.TestOpusCodec, webrtc.RTPCodecTypeAudio) + require.False(t, f.IsMuted()) + muted := f.Mute(false, true) + require.False(t, muted) // no change in mute state + require.False(t, f.IsMuted()) + + muted = f.Mute(true, false) + require.False(t, muted) + require.False(t, f.IsMuted()) + + muted = f.Mute(true, true) + require.True(t, muted) + require.True(t, f.IsMuted()) + + muted = f.Mute(false, true) + require.True(t, muted) + require.False(t, f.IsMuted()) +} + +func TestForwarderLayersAudio(t *testing.T) { + f := newForwarder(testutils.TestOpusCodec, webrtc.RTPCodecTypeAudio) + + require.Equal(t, buffer.InvalidLayer, f.MaxLayer()) + + require.Equal(t, buffer.InvalidLayer, f.CurrentLayer()) + require.Equal(t, buffer.InvalidLayer, f.TargetLayer()) + + changed, maxLayer := f.SetMaxSpatialLayer(1) + require.False(t, changed) + require.Equal(t, buffer.InvalidLayer, maxLayer) + + changed, maxLayer = f.SetMaxTemporalLayer(1) + require.False(t, changed) + require.Equal(t, buffer.InvalidLayer, maxLayer) + + require.Equal(t, buffer.InvalidLayer, f.MaxLayer()) +} + +func TestForwarderLayersVideo(t *testing.T) { + f := newForwarder(testutils.TestVP8Codec, webrtc.RTPCodecTypeVideo) + + maxLayer := f.MaxLayer() + expectedLayers := buffer.VideoLayer{Spatial: buffer.InvalidLayerSpatial, Temporal: buffer.DefaultMaxLayerTemporal} + require.Equal(t, expectedLayers, maxLayer) + + require.Equal(t, buffer.InvalidLayer, f.CurrentLayer()) + require.Equal(t, buffer.InvalidLayer, f.TargetLayer()) + + expectedLayers = buffer.VideoLayer{ + Spatial: buffer.DefaultMaxLayerSpatial, + Temporal: buffer.DefaultMaxLayerTemporal, + } + changed, maxLayer := f.SetMaxSpatialLayer(buffer.DefaultMaxLayerSpatial) + require.True(t, changed) + require.Equal(t, expectedLayers, maxLayer) + + changed, maxLayer = f.SetMaxSpatialLayer(buffer.DefaultMaxLayerSpatial - 1) + require.True(t, changed) + expectedLayers = buffer.VideoLayer{ + Spatial: buffer.DefaultMaxLayerSpatial - 1, + Temporal: buffer.DefaultMaxLayerTemporal, + } + require.Equal(t, expectedLayers, maxLayer) + require.Equal(t, expectedLayers, f.MaxLayer()) + + f.vls.SetCurrent(buffer.VideoLayer{Spatial: 0, Temporal: 1}) + changed, maxLayer = f.SetMaxSpatialLayer(buffer.DefaultMaxLayerSpatial - 1) + require.False(t, changed) + require.Equal(t, expectedLayers, maxLayer) + require.Equal(t, expectedLayers, f.MaxLayer()) + + changed, maxLayer = f.SetMaxTemporalLayer(buffer.DefaultMaxLayerTemporal) + require.False(t, changed) + require.Equal(t, expectedLayers, maxLayer) + + changed, maxLayer = f.SetMaxTemporalLayer(buffer.DefaultMaxLayerTemporal - 1) + require.True(t, changed) + expectedLayers = buffer.VideoLayer{ + Spatial: buffer.DefaultMaxLayerSpatial - 1, + Temporal: buffer.DefaultMaxLayerTemporal - 1, + } + require.Equal(t, expectedLayers, maxLayer) + require.Equal(t, expectedLayers, f.MaxLayer()) +} + +func TestForwarderAllocateOptimal(t *testing.T) { + f := newForwarder(testutils.TestVP8Codec, webrtc.RTPCodecTypeVideo) + + emptyBitrates := Bitrates{} + bitrates := Bitrates{ + {2, 3, 0, 0}, + {4, 0, 0, 5}, + {0, 7, 0, 0}, + } + + // invalid max layers + f.vls.SetMax(buffer.InvalidLayer) + expectedResult := VideoAllocation{ + PauseReason: VideoPauseReasonFeedDry, + BandwidthRequested: 0, + BandwidthDelta: 0, + Bitrates: bitrates, + TargetLayer: buffer.InvalidLayer, + RequestLayerSpatial: buffer.InvalidLayerSpatial, + MaxLayer: buffer.InvalidLayer, + DistanceToDesired: 0, + } + result := f.AllocateOptimal(nil, bitrates, true, false) + require.Equal(t, expectedResult, result) + require.Equal(t, expectedResult, f.lastAllocation) + + f.SetMaxSpatialLayer(buffer.DefaultMaxLayerSpatial) + f.SetMaxTemporalLayer(buffer.DefaultMaxLayerTemporal) + + // should still have target at buffer.InvalidLayer until max publisher layer is available + expectedResult = VideoAllocation{ + PauseReason: VideoPauseReasonFeedDry, + BandwidthRequested: 0, + BandwidthDelta: 0, + Bitrates: bitrates, + TargetLayer: buffer.InvalidLayer, + RequestLayerSpatial: buffer.InvalidLayerSpatial, + MaxLayer: buffer.DefaultMaxLayer, + DistanceToDesired: 0, + } + result = f.AllocateOptimal(nil, bitrates, true, false) + require.Equal(t, expectedResult, result) + require.Equal(t, expectedResult, f.lastAllocation) + + f.SetMaxPublishedLayer(buffer.DefaultMaxLayerSpatial) + + // muted should not consume any bandwidth + f.Mute(true, true) + disable(f) + expectedResult = VideoAllocation{ + PauseReason: VideoPauseReasonMuted, + BandwidthRequested: 0, + BandwidthDelta: 0, + Bitrates: bitrates, + TargetLayer: buffer.InvalidLayer, + RequestLayerSpatial: buffer.InvalidLayerSpatial, + MaxLayer: buffer.DefaultMaxLayer, + DistanceToDesired: 0, + } + result = f.AllocateOptimal(nil, bitrates, true, false) + require.Equal(t, expectedResult, result) + require.Equal(t, expectedResult, f.lastAllocation) + + f.Mute(false, true) + + // pub muted should not consume any bandwidth + f.PubMute(true) + disable(f) + expectedResult = VideoAllocation{ + PauseReason: VideoPauseReasonPubMuted, + BandwidthRequested: 0, + BandwidthDelta: 0, + Bitrates: bitrates, + TargetLayer: buffer.InvalidLayer, + RequestLayerSpatial: buffer.InvalidLayerSpatial, + MaxLayer: buffer.DefaultMaxLayer, + DistanceToDesired: 0, + } + result = f.AllocateOptimal(nil, bitrates, true, false) + require.Equal(t, expectedResult, result) + require.Equal(t, expectedResult, f.lastAllocation) + + f.PubMute(false) + + // when max layers changes, target is opportunistic, but requested spatial layer should be at max + f.SetMaxTemporalLayerSeen(buffer.DefaultMaxLayerTemporal) + f.vls.SetMax(buffer.VideoLayer{Spatial: 1, Temporal: 3}) + expectedResult = VideoAllocation{ + PauseReason: VideoPauseReasonNone, + BandwidthRequested: bitrates[2][1], + BandwidthDelta: bitrates[2][1], + BandwidthNeeded: bitrates[1][3], + Bitrates: bitrates, + TargetLayer: buffer.DefaultMaxLayer, + RequestLayerSpatial: f.vls.GetMax().Spatial, + MaxLayer: f.vls.GetMax(), + DistanceToDesired: -1, + } + result = f.AllocateOptimal(nil, bitrates, true, false) + require.Equal(t, expectedResult, result) + require.Equal(t, expectedResult, f.lastAllocation) + require.Equal(t, buffer.DefaultMaxLayer, f.TargetLayer()) + + // reset max layers for rest of the tests below + f.vls.SetMax(buffer.DefaultMaxLayer) + + // when feed is dry and current is not valid, should set up for opportunistic forwarding + // NOTE: feed is dry due to availableLayers = nil, some valid bitrates may be passed in here for testing purposes only + disable(f) + expectedTargetLayer := buffer.DefaultMaxLayer + expectedResult = VideoAllocation{ + PauseReason: VideoPauseReasonNone, + BandwidthRequested: bitrates[2][1], + BandwidthDelta: 0, + BandwidthNeeded: bitrates[2][1], + Bitrates: bitrates, + TargetLayer: expectedTargetLayer, + RequestLayerSpatial: expectedTargetLayer.Spatial, + MaxLayer: buffer.DefaultMaxLayer, + DistanceToDesired: -0.5, + } + result = f.AllocateOptimal(nil, bitrates, true, false) + require.Equal(t, expectedResult, result) + require.Equal(t, expectedResult, f.lastAllocation) + require.Equal(t, expectedTargetLayer, f.TargetLayer()) + + f.vls.SetTarget(buffer.VideoLayer{Spatial: 0, Temporal: 0}) // set to valid to trigger paths in tests below + f.vls.SetCurrent(buffer.VideoLayer{Spatial: 0, Temporal: 3}) // set to valid to trigger paths in tests below + + // when feed is dry and current is valid, should stay at current + expectedTargetLayer = buffer.VideoLayer{ + Spatial: 0, + Temporal: 3, + } + expectedResult = VideoAllocation{ + PauseReason: VideoPauseReasonFeedDry, + BandwidthRequested: 0, + BandwidthDelta: 0 - bitrates[2][1], + Bitrates: emptyBitrates, + TargetLayer: expectedTargetLayer, + RequestLayerSpatial: expectedTargetLayer.Spatial, + MaxLayer: buffer.DefaultMaxLayer, + DistanceToDesired: -0.75, + } + result = f.AllocateOptimal(nil, emptyBitrates, true, false) + require.Equal(t, expectedResult, result) + require.Equal(t, expectedResult, f.lastAllocation) + require.Equal(t, expectedTargetLayer, f.TargetLayer()) + + f.vls.SetCurrent(buffer.InvalidLayer) + + // opportunistic target if feed is not dry and current is not valid, i. e. not forwarding + expectedResult = VideoAllocation{ + PauseReason: VideoPauseReasonNone, + BandwidthRequested: bitrates[2][1], + BandwidthDelta: bitrates[2][1], + BandwidthNeeded: bitrates[2][1], + Bitrates: bitrates, + TargetLayer: buffer.DefaultMaxLayer, + RequestLayerSpatial: 1, + MaxLayer: buffer.DefaultMaxLayer, + DistanceToDesired: -0.5, + } + result = f.AllocateOptimal([]int32{0, 1}, bitrates, true, false) + require.Equal(t, expectedResult, result) + require.Equal(t, expectedResult, f.lastAllocation) + require.Equal(t, buffer.DefaultMaxLayer, f.TargetLayer()) + + // when holding in above scenario, should choose the lowest available layer + expectedTargetLayer = buffer.VideoLayer{ + Spatial: 1, + Temporal: 0, + } + expectedResult = VideoAllocation{ + PauseReason: VideoPauseReasonNone, + BandwidthRequested: bitrates[1][0], + BandwidthDelta: bitrates[1][0] - bitrates[2][1], + BandwidthNeeded: bitrates[2][1], + Bitrates: bitrates, + TargetLayer: expectedTargetLayer, + RequestLayerSpatial: 1, + MaxLayer: buffer.DefaultMaxLayer, + DistanceToDesired: 1.25, + } + result = f.AllocateOptimal([]int32{1, 2}, bitrates, true, true) + require.Equal(t, expectedResult, result) + require.Equal(t, expectedResult, f.lastAllocation) + require.Equal(t, expectedTargetLayer, f.TargetLayer()) + + // opportunistic target if feed is dry and current is not valid, i. e. not forwarding + expectedResult = VideoAllocation{ + PauseReason: VideoPauseReasonNone, + BandwidthRequested: bitrates[2][1], + BandwidthDelta: bitrates[2][1] - bitrates[1][0], + BandwidthNeeded: bitrates[2][1], + Bitrates: bitrates, + TargetLayer: buffer.DefaultMaxLayer, + RequestLayerSpatial: 2, + MaxLayer: buffer.DefaultMaxLayer, + DistanceToDesired: -0.5, + } + result = f.AllocateOptimal(nil, bitrates, true, false) + require.Equal(t, expectedResult, result) + require.Equal(t, expectedResult, f.lastAllocation) + require.Equal(t, buffer.DefaultMaxLayer, f.TargetLayer()) + + // when holding in above scenario, should choose layer 0 + expectedTargetLayer = buffer.VideoLayer{ + Spatial: 0, + Temporal: 0, + } + expectedResult = VideoAllocation{ + PauseReason: VideoPauseReasonNone, + BandwidthRequested: bitrates[0][0], + BandwidthDelta: bitrates[0][0] - bitrates[2][1], + BandwidthNeeded: bitrates[2][1], + Bitrates: bitrates, + TargetLayer: expectedTargetLayer, + RequestLayerSpatial: 0, + MaxLayer: buffer.DefaultMaxLayer, + DistanceToDesired: 2.25, + } + result = f.AllocateOptimal(nil, bitrates, true, true) + require.Equal(t, expectedResult, result) + require.Equal(t, expectedResult, f.lastAllocation) + require.Equal(t, expectedTargetLayer, f.TargetLayer()) + + // if feed is not dry and current is not locked, should be opportunistic (with and without overshoot) + f.vls.SetTarget(buffer.InvalidLayer) + expectedResult = VideoAllocation{ + PauseReason: VideoPauseReasonFeedDry, + BandwidthRequested: 0, + BandwidthDelta: 0 - bitrates[0][0], + BandwidthNeeded: 0, + Bitrates: emptyBitrates, + TargetLayer: buffer.DefaultMaxLayer, + RequestLayerSpatial: 1, + MaxLayer: buffer.DefaultMaxLayer, + DistanceToDesired: -1.0, + } + result = f.AllocateOptimal([]int32{0, 1}, emptyBitrates, false, false) + require.Equal(t, expectedResult, result) + require.Equal(t, expectedResult, f.lastAllocation) + require.Equal(t, buffer.DefaultMaxLayer, f.TargetLayer()) + + f.vls.SetTarget(buffer.InvalidLayer) + expectedTargetLayer = buffer.VideoLayer{ + Spatial: 2, + Temporal: buffer.DefaultMaxLayerTemporal, + } + expectedResult = VideoAllocation{ + PauseReason: VideoPauseReasonNone, + BandwidthRequested: bitrates[2][1], + BandwidthDelta: bitrates[2][1], + BandwidthNeeded: bitrates[2][1], + Bitrates: bitrates, + TargetLayer: expectedTargetLayer, + RequestLayerSpatial: 1, + MaxLayer: buffer.DefaultMaxLayer, + DistanceToDesired: -0.5, + } + result = f.AllocateOptimal([]int32{0, 1}, bitrates, true, false) + require.Equal(t, expectedResult, result) + require.Equal(t, expectedResult, f.lastAllocation) + require.Equal(t, expectedTargetLayer, f.TargetLayer()) + + // switches request layer to highest available if feed is not dry and current is valid and current is not available + f.vls.SetCurrent(buffer.VideoLayer{Spatial: 0, Temporal: 1}) + expectedTargetLayer = buffer.VideoLayer{ + Spatial: 1, + Temporal: buffer.DefaultMaxLayerTemporal, + } + expectedResult = VideoAllocation{ + PauseReason: VideoPauseReasonNone, + BandwidthRequested: bitrates[1][3], + BandwidthDelta: bitrates[1][3] - bitrates[2][1], + BandwidthNeeded: bitrates[2][1], + Bitrates: bitrates, + TargetLayer: expectedTargetLayer, + RequestLayerSpatial: 1, + MaxLayer: buffer.DefaultMaxLayer, + DistanceToDesired: 0.5, + } + result = f.AllocateOptimal([]int32{1}, bitrates, true, false) + require.Equal(t, expectedResult, result) + require.Equal(t, expectedResult, f.lastAllocation) + require.Equal(t, expectedTargetLayer, f.TargetLayer()) + + // when holding in above scenario, should switch to lowest available layer + expectedTargetLayer = buffer.VideoLayer{ + Spatial: 0, + Temporal: 0, + } + expectedResult = VideoAllocation{ + PauseReason: VideoPauseReasonNone, + BandwidthRequested: bitrates[0][0], + BandwidthDelta: bitrates[0][0] - bitrates[1][3], + BandwidthNeeded: bitrates[2][1], + Bitrates: bitrates, + TargetLayer: expectedTargetLayer, + RequestLayerSpatial: 0, + MaxLayer: buffer.DefaultMaxLayer, + DistanceToDesired: 2.25, + } + result = f.AllocateOptimal([]int32{0, 1}, bitrates, true, true) + require.Equal(t, expectedResult, result) + require.Equal(t, expectedResult, f.lastAllocation) + require.Equal(t, expectedTargetLayer, f.TargetLayer()) + + // stays the same if feed is not dry and current is valid, available and locked + f.vls.SetMax(buffer.VideoLayer{Spatial: 0, Temporal: 1}) + f.vls.SetCurrent(buffer.VideoLayer{Spatial: 0, Temporal: 1}) + f.vls.SetRequestSpatial(0) + expectedTargetLayer = buffer.VideoLayer{ + Spatial: 0, + Temporal: 1, + } + expectedResult = VideoAllocation{ + PauseReason: VideoPauseReasonFeedDry, + BandwidthRequested: 0, + BandwidthDelta: 0 - bitrates[0][0], + Bitrates: emptyBitrates, + TargetLayer: expectedTargetLayer, + RequestLayerSpatial: 0, + MaxLayer: f.vls.GetMax(), + DistanceToDesired: 0.0, + } + result = f.AllocateOptimal([]int32{0}, emptyBitrates, true, false) + require.Equal(t, expectedResult, result) + require.Equal(t, expectedResult, f.lastAllocation) + require.Equal(t, expectedTargetLayer, f.TargetLayer()) +} + +func TestForwarderProvisionalAllocate(t *testing.T) { + f := newForwarder(testutils.TestVP8Codec, webrtc.RTPCodecTypeVideo) + f.SetMaxSpatialLayer(buffer.DefaultMaxLayerSpatial) + f.SetMaxTemporalLayer(buffer.DefaultMaxLayerTemporal) + f.SetMaxPublishedLayer(buffer.DefaultMaxLayerSpatial) + f.SetMaxTemporalLayerSeen(buffer.DefaultMaxLayerTemporal) + + // Reset to invalid layers for testing allocation from scratch + disable(f) + + bitrates := Bitrates{ + {1, 2, 3, 4}, + {5, 6, 7, 8}, + {9, 10, 11, 12}, + } + + f.ProvisionalAllocatePrepare(nil, bitrates) + + isCandidate, usedBitrate := f.ProvisionalAllocate(bitrates[2][3], buffer.VideoLayer{Spatial: 0, Temporal: 0}, true, false) + require.True(t, isCandidate) + require.Equal(t, bitrates[0][0], usedBitrate) + + isCandidate, usedBitrate = f.ProvisionalAllocate(bitrates[2][3], buffer.VideoLayer{Spatial: 2, Temporal: 3}, true, false) + require.True(t, isCandidate) + require.Equal(t, bitrates[2][3]-bitrates[0][0], usedBitrate) + + isCandidate, usedBitrate = f.ProvisionalAllocate(bitrates[2][3], buffer.VideoLayer{Spatial: 0, Temporal: 3}, true, false) + require.True(t, isCandidate) + require.Equal(t, bitrates[0][3]-bitrates[2][3], usedBitrate) + + isCandidate, usedBitrate = f.ProvisionalAllocate(bitrates[2][3], buffer.VideoLayer{Spatial: 1, Temporal: 2}, true, false) + require.True(t, isCandidate) + require.Equal(t, bitrates[1][2]-bitrates[0][3], usedBitrate) + + // available not enough to reach (2, 2), allocating at (2, 2) should not succeed + isCandidate, usedBitrate = f.ProvisionalAllocate(bitrates[2][2]-bitrates[1][2]-1, buffer.VideoLayer{Spatial: 2, Temporal: 2}, true, false) + require.False(t, isCandidate) + require.Equal(t, int64(0), usedBitrate) + + // committing should set target to (1, 2) + expectedTargetLayer := buffer.VideoLayer{ + Spatial: 1, + Temporal: 2, + } + expectedResult := VideoAllocation{ + IsDeficient: true, + BandwidthRequested: bitrates[1][2], + BandwidthDelta: bitrates[1][2], + BandwidthNeeded: bitrates[2][3], + Bitrates: bitrates, + TargetLayer: expectedTargetLayer, + RequestLayerSpatial: expectedTargetLayer.Spatial, + MaxLayer: buffer.DefaultMaxLayer, + DistanceToDesired: 1.25, + } + result := f.ProvisionalAllocateCommit() + require.Equal(t, expectedResult, result) + require.Equal(t, expectedResult, f.lastAllocation) + require.Equal(t, expectedTargetLayer, f.TargetLayer()) + + // when nothing fits and pausing disallowed, should allocate (0, 0) + f.vls.SetTarget(buffer.InvalidLayer) + f.ProvisionalAllocatePrepare(nil, bitrates) + isCandidate, usedBitrate = f.ProvisionalAllocate(0, buffer.VideoLayer{Spatial: 0, Temporal: 0}, false, false) + require.True(t, isCandidate) + require.Equal(t, int64(1), usedBitrate) + + // committing should set target to (0, 0) + expectedTargetLayer = buffer.VideoLayer{ + Spatial: 0, + Temporal: 0, + } + expectedResult = VideoAllocation{ + IsDeficient: true, + BandwidthRequested: bitrates[0][0], + BandwidthDelta: bitrates[0][0] - bitrates[1][2], + BandwidthNeeded: bitrates[2][3], + Bitrates: bitrates, + TargetLayer: expectedTargetLayer, + RequestLayerSpatial: expectedTargetLayer.Spatial, + MaxLayer: buffer.DefaultMaxLayer, + DistanceToDesired: 2.75, + } + result = f.ProvisionalAllocateCommit() + require.Equal(t, expectedResult, result) + require.Equal(t, expectedResult, f.lastAllocation) + require.Equal(t, expectedTargetLayer, f.TargetLayer()) + + // + // Test allowOvershoot. + // Max spatial set to 0 and layer 0 bit rates are not available. + // + f.SetMaxSpatialLayer(0) + bitrates = Bitrates{ + {0, 0, 0, 0}, + {5, 6, 7, 8}, + {9, 10, 11, 12}, + } + + f.ProvisionalAllocatePrepare(nil, bitrates) + + isCandidate, usedBitrate = f.ProvisionalAllocate(bitrates[2][3], buffer.VideoLayer{Spatial: 0, Temporal: 0}, false, true) + require.False(t, isCandidate) + require.Equal(t, int64(0), usedBitrate) + + // overshoot should succeed + isCandidate, usedBitrate = f.ProvisionalAllocate(bitrates[2][3], buffer.VideoLayer{Spatial: 2, Temporal: 3}, false, true) + require.True(t, isCandidate) + require.Equal(t, bitrates[2][3], usedBitrate) + + // overshoot should succeed - this should win as this is lesser overshoot + isCandidate, usedBitrate = f.ProvisionalAllocate(bitrates[2][3], buffer.VideoLayer{Spatial: 1, Temporal: 3}, false, true) + require.True(t, isCandidate) + require.Equal(t, bitrates[1][3]-bitrates[2][3], usedBitrate) + + // committing should set target to (1, 3) + expectedTargetLayer = buffer.VideoLayer{ + Spatial: 1, + Temporal: 3, + } + expectedMaxLayer := buffer.VideoLayer{ + Spatial: 0, + Temporal: 3, + } + expectedResult = VideoAllocation{ + BandwidthRequested: bitrates[1][3], + BandwidthDelta: bitrates[1][3] - 1, // 1 is the last allocation bandwidth requested + Bitrates: bitrates, + TargetLayer: expectedTargetLayer, + RequestLayerSpatial: expectedTargetLayer.Spatial, + MaxLayer: expectedMaxLayer, + DistanceToDesired: -1.75, + } + result = f.ProvisionalAllocateCommit() + require.Equal(t, expectedResult, result) + require.Equal(t, expectedResult, f.lastAllocation) + require.Equal(t, expectedTargetLayer, f.TargetLayer()) + + // + // Even if overshoot is allowed, but if higher layers do not have bit rates, should continue with current layer. + // + bitrates = Bitrates{ + {0, 0, 0, 0}, + {0, 0, 0, 0}, + {0, 0, 0, 0}, + } + + f.vls.SetCurrent(buffer.VideoLayer{Spatial: 0, Temporal: 2}) + f.ProvisionalAllocatePrepare(nil, bitrates) + + // all the provisional allocations should not succeed because the feed is dry + isCandidate, usedBitrate = f.ProvisionalAllocate(bitrates[2][3], buffer.VideoLayer{Spatial: 0, Temporal: 0}, false, true) + require.False(t, isCandidate) + require.Equal(t, int64(0), usedBitrate) + + // overshoot should not succeed + isCandidate, usedBitrate = f.ProvisionalAllocate(bitrates[2][3], buffer.VideoLayer{Spatial: 2, Temporal: 3}, false, true) + require.False(t, isCandidate) + require.Equal(t, int64(0), usedBitrate) + + // overshoot should not succeed + isCandidate, usedBitrate = f.ProvisionalAllocate(bitrates[2][3], buffer.VideoLayer{Spatial: 1, Temporal: 3}, false, true) + require.False(t, isCandidate) + require.Equal(t, int64(0), usedBitrate) + + // committing should set target to (0, 2), i. e. leave it at current for opportunistic forwarding + expectedTargetLayer = buffer.VideoLayer{ + Spatial: 0, + Temporal: 2, + } + expectedResult = VideoAllocation{ + PauseReason: VideoPauseReasonFeedDry, + BandwidthRequested: bitrates[0][2], + BandwidthDelta: bitrates[0][2] - 8, // 8 is the last allocation bandwidth requested + Bitrates: bitrates, + TargetLayer: expectedTargetLayer, + RequestLayerSpatial: expectedTargetLayer.Spatial, + MaxLayer: expectedMaxLayer, + DistanceToDesired: 1.0, + } + result = f.ProvisionalAllocateCommit() + require.Equal(t, expectedResult, result) + require.Equal(t, expectedResult, f.lastAllocation) + require.Equal(t, expectedTargetLayer, f.TargetLayer()) + + // + // Same case as above, but current is above max, so target should go to invalid + // + f.vls.SetCurrent(buffer.VideoLayer{Spatial: 1, Temporal: 2}) + f.ProvisionalAllocatePrepare(nil, bitrates) + + // all the provisional allocations below should not succeed because the feed is dry + isCandidate, usedBitrate = f.ProvisionalAllocate(bitrates[2][3], buffer.VideoLayer{Spatial: 0, Temporal: 0}, false, true) + require.False(t, isCandidate) + require.Equal(t, int64(0), usedBitrate) + + // overshoot should not succeed + isCandidate, usedBitrate = f.ProvisionalAllocate(bitrates[2][3], buffer.VideoLayer{Spatial: 2, Temporal: 3}, false, true) + require.False(t, isCandidate) + require.Equal(t, int64(0), usedBitrate) + + // overshoot should not succeed + isCandidate, usedBitrate = f.ProvisionalAllocate(bitrates[2][3], buffer.VideoLayer{Spatial: 1, Temporal: 3}, false, true) + require.False(t, isCandidate) + require.Equal(t, int64(0), usedBitrate) + + expectedResult = VideoAllocation{ + PauseReason: VideoPauseReasonFeedDry, + BandwidthRequested: 0, + BandwidthDelta: 0, + Bitrates: bitrates, + TargetLayer: buffer.InvalidLayer, + RequestLayerSpatial: buffer.InvalidLayerSpatial, + MaxLayer: expectedMaxLayer, + DistanceToDesired: 1.0, + } + result = f.ProvisionalAllocateCommit() + require.Equal(t, expectedResult, result) + require.Equal(t, expectedResult, f.lastAllocation) + require.Equal(t, buffer.InvalidLayer, f.TargetLayer()) + require.Equal(t, buffer.InvalidLayer, f.CurrentLayer()) +} + +func TestForwarderProvisionalAllocateMute(t *testing.T) { + f := newForwarder(testutils.TestVP8Codec, webrtc.RTPCodecTypeVideo) + f.SetMaxSpatialLayer(buffer.DefaultMaxLayerSpatial) + f.SetMaxTemporalLayer(buffer.DefaultMaxLayerTemporal) + + // Reset to invalid layers for testing muted state + disable(f) + + bitrates := Bitrates{ + {1, 2, 3, 4}, + {5, 6, 7, 8}, + {9, 10, 11, 12}, + } + + f.Mute(true, true) + f.ProvisionalAllocatePrepare(nil, bitrates) + + isCandidate, usedBitrate := f.ProvisionalAllocate(bitrates[2][3], buffer.VideoLayer{Spatial: 0, Temporal: 0}, true, false) + require.False(t, isCandidate) + require.Equal(t, int64(0), usedBitrate) + + isCandidate, usedBitrate = f.ProvisionalAllocate(bitrates[2][3], buffer.VideoLayer{Spatial: 1, Temporal: 2}, true, true) + require.False(t, isCandidate) + require.Equal(t, int64(0), usedBitrate) + + // committing should set target to buffer.InvalidLayer as track is muted + expectedResult := VideoAllocation{ + PauseReason: VideoPauseReasonMuted, + BandwidthRequested: 0, + BandwidthDelta: 0, + Bitrates: bitrates, + TargetLayer: buffer.InvalidLayer, + RequestLayerSpatial: buffer.InvalidLayerSpatial, + MaxLayer: buffer.DefaultMaxLayer, + DistanceToDesired: 0, + } + result := f.ProvisionalAllocateCommit() + require.Equal(t, expectedResult, result) + require.Equal(t, expectedResult, f.lastAllocation) + require.Equal(t, buffer.InvalidLayer, f.TargetLayer()) +} + +func TestForwarderProvisionalAllocateGetCooperativeTransition(t *testing.T) { + f := newForwarder(testutils.TestVP8Codec, webrtc.RTPCodecTypeVideo) + f.SetMaxSpatialLayer(buffer.DefaultMaxLayerSpatial) + f.SetMaxTemporalLayer(buffer.DefaultMaxLayerTemporal) + f.SetMaxPublishedLayer(buffer.DefaultMaxLayerSpatial) + f.SetMaxTemporalLayerSeen(buffer.DefaultMaxLayerTemporal) + + // Reset to invalid layers for testing cooperative transition from scratch + disable(f) + + availableLayers := []int32{0, 1, 2} + bitrates := Bitrates{ + {1, 2, 3, 4}, + {5, 6, 7, 8}, + {9, 10, 0, 0}, + } + + f.ProvisionalAllocatePrepare(availableLayers, bitrates) + + // from scratch (buffer.InvalidLayer) should give back layer (0, 0) + expectedTransition := VideoTransition{ + From: buffer.InvalidLayer, + To: buffer.VideoLayer{Spatial: 0, Temporal: 0}, + BandwidthDelta: 1, + } + transition, al, brs := f.ProvisionalAllocateGetCooperativeTransition(false) + require.Equal(t, expectedTransition, transition) + require.Equal(t, availableLayers, al) + require.Equal(t, bitrates, brs) + + // committing should set target to (0, 0) + expectedLayers := buffer.VideoLayer{Spatial: 0, Temporal: 0} + expectedResult := VideoAllocation{ + IsDeficient: true, + BandwidthRequested: 1, + BandwidthDelta: 1, + BandwidthNeeded: bitrates[2][1], + Bitrates: bitrates, + TargetLayer: expectedLayers, + RequestLayerSpatial: expectedLayers.Spatial, + MaxLayer: buffer.DefaultMaxLayer, + DistanceToDesired: 2.25, + } + result := f.ProvisionalAllocateCommit() + require.Equal(t, expectedResult, result) + require.Equal(t, expectedResult, f.lastAllocation) + require.Equal(t, expectedLayers, f.TargetLayer()) + + // a higher target that is already streaming, just maintain it + targetLayer := buffer.VideoLayer{Spatial: 2, Temporal: 1} + f.vls.SetTarget(targetLayer) + f.lastAllocation.BandwidthRequested = 10 + expectedTransition = VideoTransition{ + From: targetLayer, + To: targetLayer, + BandwidthDelta: 0, + } + transition, al, brs = f.ProvisionalAllocateGetCooperativeTransition(false) + require.Equal(t, expectedTransition, transition) + require.Equal(t, availableLayers, al) + require.Equal(t, bitrates, brs) + + // committing should set target to (2, 1) + expectedLayers = buffer.VideoLayer{Spatial: 2, Temporal: 1} + expectedResult = VideoAllocation{ + BandwidthRequested: 10, + BandwidthDelta: 0, + Bitrates: bitrates, + BandwidthNeeded: bitrates[2][1], + TargetLayer: expectedLayers, + RequestLayerSpatial: expectedLayers.Spatial, + MaxLayer: buffer.DefaultMaxLayer, + DistanceToDesired: 0.0, + } + result = f.ProvisionalAllocateCommit() + require.Equal(t, expectedResult, result) + require.Equal(t, expectedResult, f.lastAllocation) + require.Equal(t, expectedLayers, f.TargetLayer()) + + // from a target that has become unavailable, should switch to lower available layer + targetLayer = buffer.VideoLayer{Spatial: 2, Temporal: 2} + f.vls.SetTarget(targetLayer) + expectedTransition = VideoTransition{ + From: targetLayer, + To: buffer.VideoLayer{Spatial: 2, Temporal: 1}, + BandwidthDelta: 0, + } + transition, al, brs = f.ProvisionalAllocateGetCooperativeTransition(false) + require.Equal(t, expectedTransition, transition) + require.Equal(t, availableLayers, al) + require.Equal(t, bitrates, brs) + + f.ProvisionalAllocateCommit() + + // mute + f.Mute(true, true) + f.ProvisionalAllocatePrepare(availableLayers, bitrates) + + // mute should send target to buffer.InvalidLayer + expectedTransition = VideoTransition{ + From: buffer.VideoLayer{Spatial: 2, Temporal: 1}, + To: buffer.InvalidLayer, + BandwidthDelta: -10, + } + transition, al, brs = f.ProvisionalAllocateGetCooperativeTransition(false) + require.Equal(t, expectedTransition, transition) + require.Equal(t, availableLayers, al) + require.Equal(t, bitrates, brs) + + f.ProvisionalAllocateCommit() + + // + // Test allowOvershoot + // + f.Mute(false, true) + f.SetMaxSpatialLayer(0) + + availableLayers = []int32{1, 2} + bitrates = Bitrates{ + {0, 0, 0, 0}, + {5, 6, 7, 8}, + {9, 10, 0, 0}, + } + + f.vls.SetTarget(buffer.InvalidLayer) + f.ProvisionalAllocatePrepare(availableLayers, bitrates) + + // from scratch (buffer.InvalidLayer) should go to a layer past maximum as overshoot is allowed + expectedTransition = VideoTransition{ + From: buffer.InvalidLayer, + To: buffer.VideoLayer{Spatial: 1, Temporal: 0}, + BandwidthDelta: 5, + } + transition, al, brs = f.ProvisionalAllocateGetCooperativeTransition(true) + require.Equal(t, expectedTransition, transition) + require.Equal(t, availableLayers, al) + require.Equal(t, bitrates, brs) + + // committing should set target to (1, 0) + expectedLayers = buffer.VideoLayer{Spatial: 1, Temporal: 0} + expectedMaxLayer := buffer.VideoLayer{Spatial: 0, Temporal: buffer.DefaultMaxLayerTemporal} + expectedResult = VideoAllocation{ + BandwidthRequested: 5, + BandwidthDelta: 5, + Bitrates: bitrates, + TargetLayer: expectedLayers, + RequestLayerSpatial: expectedLayers.Spatial, + MaxLayer: expectedMaxLayer, + DistanceToDesired: -1.0, + } + result = f.ProvisionalAllocateCommit() + require.Equal(t, expectedResult, result) + require.Equal(t, expectedResult, f.lastAllocation) + require.Equal(t, expectedLayers, f.TargetLayer()) + + // + // Test continuing at current layers when feed is dry + // + bitrates = Bitrates{ + {0, 0, 0, 0}, + {0, 0, 0, 0}, + {0, 0, 0, 0}, + } + + f.vls.SetCurrent(buffer.VideoLayer{Spatial: 0, Temporal: 2}) + f.vls.SetTarget(buffer.InvalidLayer) + f.ProvisionalAllocatePrepare(nil, bitrates) + + // from scratch (buffer.InvalidLayer) should go to current layer + // NOTE: targetLayer is set to buffer.InvalidLayer for testing, but in practice current layers valid and target layers invalid should not happen + expectedTransition = VideoTransition{ + From: buffer.InvalidLayer, + To: buffer.VideoLayer{Spatial: 0, Temporal: 2}, + BandwidthDelta: -5, // 5 was the bandwidth needed for the last allocation + } + transition, al, brs = f.ProvisionalAllocateGetCooperativeTransition(true) + require.Equal(t, expectedTransition, transition) + require.Equal(t, []int32{}, al) + require.Equal(t, bitrates, brs) + + // committing should set target to (0, 2) + expectedLayers = buffer.VideoLayer{Spatial: 0, Temporal: 2} + expectedResult = VideoAllocation{ + BandwidthRequested: 0, + BandwidthDelta: -5, + Bitrates: bitrates, + TargetLayer: expectedLayers, + RequestLayerSpatial: expectedLayers.Spatial, + MaxLayer: expectedMaxLayer, + DistanceToDesired: -0.5, + } + result = f.ProvisionalAllocateCommit() + require.Equal(t, expectedResult, result) + require.Equal(t, expectedResult, f.lastAllocation) + require.Equal(t, expectedLayers, f.TargetLayer()) + + // committing should set target to current layers to enable opportunistic forwarding + expectedResult = VideoAllocation{ + BandwidthRequested: 0, + BandwidthDelta: 0, + Bitrates: bitrates, + TargetLayer: expectedLayers, + RequestLayerSpatial: expectedLayers.Spatial, + MaxLayer: expectedMaxLayer, + DistanceToDesired: -0.5, + } + result = f.ProvisionalAllocateCommit() + require.Equal(t, expectedResult, result) + require.Equal(t, expectedResult, f.lastAllocation) + require.Equal(t, expectedLayers, f.TargetLayer()) +} + +func TestForwarderProvisionalAllocateGetBestWeightedTransition(t *testing.T) { + f := newForwarder(testutils.TestVP8Codec, webrtc.RTPCodecTypeVideo) + f.SetMaxSpatialLayer(buffer.DefaultMaxLayerSpatial) + f.SetMaxTemporalLayer(buffer.DefaultMaxLayerTemporal) + + availableLayers := []int32{0, 1, 2} + bitrates := Bitrates{ + {1, 2, 3, 4}, + {5, 6, 7, 8}, + {9, 10, 11, 12}, + } + + f.ProvisionalAllocatePrepare(availableLayers, bitrates) + + f.vls.SetTarget(buffer.VideoLayer{Spatial: 2, Temporal: 2}) + f.lastAllocation.BandwidthRequested = bitrates[2][2] + expectedTransition := VideoTransition{ + From: f.TargetLayer(), + To: buffer.VideoLayer{Spatial: 2, Temporal: 0}, + BandwidthDelta: -2, + } + transition, al, brs := f.ProvisionalAllocateGetBestWeightedTransition() + require.Equal(t, expectedTransition, transition) + require.Equal(t, availableLayers, al) + require.Equal(t, bitrates, brs) +} + +func TestForwarderAllocateNextHigher(t *testing.T) { + f := newForwarder(testutils.TestOpusCodec, webrtc.RTPCodecTypeAudio) + f.SetMaxSpatialLayer(buffer.DefaultMaxLayerSpatial) + f.SetMaxTemporalLayer(buffer.DefaultMaxLayerTemporal) + f.SetMaxPublishedLayer(buffer.DefaultMaxLayerSpatial) + + emptyBitrates := Bitrates{} + bitrates := Bitrates{ + {2, 3, 0, 0}, + {4, 0, 0, 5}, + {0, 7, 0, 0}, + } + + result, boosted := f.AllocateNextHigher(100_000_000, nil, bitrates, false) + require.Equal(t, VideoAllocationDefault, result) // no layer for audio + require.False(t, boosted) + + f = newForwarder(testutils.TestVP8Codec, webrtc.RTPCodecTypeVideo) + f.SetMaxSpatialLayer(buffer.DefaultMaxLayerSpatial) + f.SetMaxTemporalLayer(buffer.DefaultMaxLayerTemporal) + f.SetMaxPublishedLayer(buffer.DefaultMaxLayerSpatial) + f.SetMaxTemporalLayerSeen(buffer.DefaultMaxLayerTemporal) + + // when not in deficient state, does not boost + result, boosted = f.AllocateNextHigher(100_000_000, nil, bitrates, false) + require.Equal(t, VideoAllocationDefault, result) + require.False(t, boosted) + + // if layers have not caught up, should not allocate next layer even if deficient + f.vls.SetTarget(buffer.VideoLayer{ + Spatial: 0, + Temporal: 0, + }) + result, boosted = f.AllocateNextHigher(100_000_000, nil, bitrates, false) + require.Equal(t, VideoAllocationDefault, result) + require.False(t, boosted) + + f.lastAllocation.IsDeficient = true + f.vls.SetCurrent(buffer.VideoLayer{ + Spatial: 0, + Temporal: 0, + }) + + // move from (0, 0) -> (0, 1), i.e. a higher temporal layer is available in the same spatial layer + expectedTargetLayer := buffer.VideoLayer{ + Spatial: 0, + Temporal: 1, + } + expectedResult := VideoAllocation{ + IsDeficient: true, + BandwidthRequested: 3, + BandwidthDelta: 1, + BandwidthNeeded: bitrates[2][1], + Bitrates: bitrates, + TargetLayer: expectedTargetLayer, + RequestLayerSpatial: expectedTargetLayer.Spatial, + MaxLayer: buffer.DefaultMaxLayer, + DistanceToDesired: 2.0, + } + result, boosted = f.AllocateNextHigher(100_000_000, nil, bitrates, false) + require.Equal(t, expectedResult, result) + require.Equal(t, expectedResult, f.lastAllocation) + require.Equal(t, expectedTargetLayer, f.TargetLayer()) + require.True(t, boosted) + + // empty bitrates cannot increase layer, i. e. last allocation is left unchanged + result, boosted = f.AllocateNextHigher(100_000_000, nil, emptyBitrates, false) + require.Equal(t, expectedResult, result) + require.False(t, boosted) + + // move from (0, 1) -> (1, 0), i.e. a higher spatial layer is available + f.vls.SetCurrent(buffer.VideoLayer{Spatial: f.vls.GetCurrent().Spatial, Temporal: 1}) + expectedTargetLayer = buffer.VideoLayer{ + Spatial: 1, + Temporal: 0, + } + expectedResult = VideoAllocation{ + IsDeficient: true, + BandwidthRequested: 4, + BandwidthDelta: 1, + BandwidthNeeded: bitrates[2][1], + Bitrates: bitrates, + TargetLayer: expectedTargetLayer, + RequestLayerSpatial: expectedTargetLayer.Spatial, + MaxLayer: buffer.DefaultMaxLayer, + DistanceToDesired: 1.25, + } + result, boosted = f.AllocateNextHigher(100_000_000, nil, bitrates, false) + require.Equal(t, expectedResult, result) + require.Equal(t, expectedResult, f.lastAllocation) + require.Equal(t, expectedTargetLayer, f.TargetLayer()) + require.True(t, boosted) + + // next higher, move from (1, 0) -> (1, 3), still deficient though + f.vls.SetCurrent(buffer.VideoLayer{Spatial: 1, Temporal: 0}) + expectedTargetLayer = buffer.VideoLayer{ + Spatial: 1, + Temporal: 3, + } + expectedResult = VideoAllocation{ + IsDeficient: true, + BandwidthRequested: 5, + BandwidthDelta: 1, + BandwidthNeeded: bitrates[2][1], + Bitrates: bitrates, + TargetLayer: expectedTargetLayer, + RequestLayerSpatial: expectedTargetLayer.Spatial, + MaxLayer: buffer.DefaultMaxLayer, + DistanceToDesired: 0.5, + } + result, boosted = f.AllocateNextHigher(100_000_000, nil, bitrates, false) + require.Equal(t, expectedResult, result) + require.Equal(t, expectedResult, f.lastAllocation) + require.Equal(t, expectedTargetLayer, f.TargetLayer()) + require.True(t, boosted) + + // next higher, move from (1, 3) -> (2, 1), optimal allocation + f.vls.SetCurrent(buffer.VideoLayer{Spatial: f.vls.GetCurrent().Spatial, Temporal: 3}) + expectedTargetLayer = buffer.VideoLayer{ + Spatial: 2, + Temporal: 1, + } + expectedResult = VideoAllocation{ + BandwidthRequested: 7, + BandwidthDelta: 2, + Bitrates: bitrates, + BandwidthNeeded: bitrates[2][1], + TargetLayer: expectedTargetLayer, + RequestLayerSpatial: expectedTargetLayer.Spatial, + MaxLayer: buffer.DefaultMaxLayer, + DistanceToDesired: 0.0, + } + result, boosted = f.AllocateNextHigher(100_000_000, nil, bitrates, false) + require.Equal(t, expectedResult, result) + require.Equal(t, expectedResult, f.lastAllocation) + require.Equal(t, expectedTargetLayer, f.TargetLayer()) + require.True(t, boosted) + + // ask again, should return not boosted as there is no room to go higher + f.vls.SetCurrent(buffer.VideoLayer{Spatial: 2, Temporal: 1}) + result, boosted = f.AllocateNextHigher(100_000_000, nil, bitrates, false) + require.Equal(t, expectedResult, result) + require.Equal(t, expectedResult, f.lastAllocation) + require.Equal(t, expectedTargetLayer, f.TargetLayer()) + require.False(t, boosted) + + // turn off everything, allocating next layer should result in streaming lowest layers + disable(f) + f.lastAllocation.IsDeficient = true + f.lastAllocation.BandwidthRequested = 0 + + expectedTargetLayer = buffer.VideoLayer{ + Spatial: 0, + Temporal: 0, + } + expectedResult = VideoAllocation{ + IsDeficient: true, + BandwidthRequested: 2, + BandwidthDelta: 2, + BandwidthNeeded: bitrates[2][1], + Bitrates: bitrates, + TargetLayer: expectedTargetLayer, + RequestLayerSpatial: expectedTargetLayer.Spatial, + MaxLayer: buffer.DefaultMaxLayer, + DistanceToDesired: 2.25, + } + result, boosted = f.AllocateNextHigher(100_000_000, nil, bitrates, false) + require.Equal(t, expectedResult, result) + require.Equal(t, expectedResult, f.lastAllocation) + require.Equal(t, expectedTargetLayer, f.TargetLayer()) + require.True(t, boosted) + + // no new available capacity cannot bump up layer + expectedResult = VideoAllocation{ + IsDeficient: true, + BandwidthRequested: 2, + BandwidthDelta: 2, + BandwidthNeeded: bitrates[2][1], + Bitrates: bitrates, + TargetLayer: expectedTargetLayer, + RequestLayerSpatial: expectedTargetLayer.Spatial, + MaxLayer: buffer.DefaultMaxLayer, + DistanceToDesired: 2.25, + } + result, boosted = f.AllocateNextHigher(0, nil, bitrates, false) + require.Equal(t, expectedResult, result) + require.Equal(t, expectedResult, f.lastAllocation) + require.Equal(t, expectedTargetLayer, f.TargetLayer()) + require.False(t, boosted) + + // test allowOvershoot + f.SetMaxSpatialLayer(0) + + bitrates = Bitrates{ + {0, 0, 0, 0}, + {5, 6, 7, 8}, + {9, 10, 11, 12}, + } + + f.vls.SetCurrent(f.vls.GetTarget()) + + expectedTargetLayer = buffer.VideoLayer{ + Spatial: 1, + Temporal: 0, + } + expectedMaxLayer := buffer.VideoLayer{ + Spatial: 0, + Temporal: buffer.DefaultMaxLayerTemporal, + } + expectedResult = VideoAllocation{ + BandwidthRequested: bitrates[1][0], + BandwidthDelta: bitrates[1][0], + Bitrates: bitrates, + TargetLayer: expectedTargetLayer, + RequestLayerSpatial: expectedTargetLayer.Spatial, + MaxLayer: expectedMaxLayer, + DistanceToDesired: -1.0, + } + // overshoot should return (1, 0) even if there is not enough capacity + result, boosted = f.AllocateNextHigher(bitrates[1][0]-1, nil, bitrates, true) + require.Equal(t, expectedResult, result) + require.Equal(t, expectedResult, f.lastAllocation) + require.Equal(t, expectedTargetLayer, f.TargetLayer()) + require.True(t, boosted) +} + +func TestForwarderPause(t *testing.T) { + f := newForwarder(testutils.TestVP8Codec, webrtc.RTPCodecTypeVideo) + f.SetMaxSpatialLayer(buffer.DefaultMaxLayerSpatial) + f.SetMaxTemporalLayer(buffer.DefaultMaxLayerTemporal) + f.SetMaxPublishedLayer(buffer.DefaultMaxLayerSpatial) + f.SetMaxTemporalLayerSeen(buffer.DefaultMaxLayerTemporal) + + bitrates := Bitrates{ + {1, 2, 3, 4}, + {5, 6, 7, 8}, + {9, 10, 11, 12}, + } + + f.ProvisionalAllocatePrepare(nil, bitrates) + f.ProvisionalAllocate(bitrates[2][3], buffer.VideoLayer{Spatial: 0, Temporal: 0}, true, false) + // should have set target at (0, 0) + f.ProvisionalAllocateCommit() + + expectedResult := VideoAllocation{ + PauseReason: VideoPauseReasonBandwidth, + IsDeficient: true, + BandwidthRequested: 0, + BandwidthDelta: 0 - bitrates[0][0], + BandwidthNeeded: bitrates[2][3], + Bitrates: bitrates, + TargetLayer: buffer.InvalidLayer, + RequestLayerSpatial: buffer.InvalidLayerSpatial, + MaxLayer: buffer.DefaultMaxLayer, + DistanceToDesired: 3.75, + } + result := f.Pause(nil, bitrates) + require.Equal(t, expectedResult, result) + require.Equal(t, expectedResult, f.lastAllocation) + require.Equal(t, buffer.InvalidLayer, f.TargetLayer()) +} + +func TestForwarderPauseMute(t *testing.T) { + f := newForwarder(testutils.TestVP8Codec, webrtc.RTPCodecTypeVideo) + f.SetMaxSpatialLayer(buffer.DefaultMaxLayerSpatial) + f.SetMaxTemporalLayer(buffer.DefaultMaxLayerTemporal) + f.SetMaxPublishedLayer(buffer.DefaultMaxLayerSpatial) + + bitrates := Bitrates{ + {1, 2, 3, 4}, + {5, 6, 7, 8}, + {9, 10, 11, 12}, + } + + f.ProvisionalAllocatePrepare(nil, bitrates) + f.ProvisionalAllocate(bitrates[2][3], buffer.VideoLayer{Spatial: 0, Temporal: 0}, true, true) + // should have set target at (0, 0) + f.ProvisionalAllocateCommit() + + f.Mute(true, true) + expectedResult := VideoAllocation{ + PauseReason: VideoPauseReasonMuted, + BandwidthRequested: 0, + BandwidthDelta: 0 - bitrates[0][0], + Bitrates: bitrates, + TargetLayer: buffer.InvalidLayer, + RequestLayerSpatial: buffer.InvalidLayerSpatial, + MaxLayer: buffer.DefaultMaxLayer, + DistanceToDesired: 0, + } + result := f.Pause(nil, bitrates) + require.Equal(t, expectedResult, result) + require.Equal(t, expectedResult, f.lastAllocation) + require.Equal(t, buffer.InvalidLayer, f.TargetLayer()) +} + +func TestForwarderGetTranslationParamsMuted(t *testing.T) { + f := newForwarder(testutils.TestVP8Codec, webrtc.RTPCodecTypeVideo) + f.Mute(true, true) + + params := &testutils.TestExtPacketParams{ + SequenceNumber: 23333, + Timestamp: 0xabcdef, + SSRC: 0x12345678, + } + extPkt, err := testutils.GetTestExtPacket(params) + require.NoError(t, err) + require.NotNil(t, extPkt) + + expectedTP := TranslationParams{ + shouldDrop: true, + } + actualTP, err := f.GetTranslationParams(extPkt, 0) + require.NoError(t, err) + require.Equal(t, expectedTP, actualTP) +} + +func TestForwarderGetTranslationParamsAudio(t *testing.T) { + f := newForwarder(testutils.TestOpusCodec, webrtc.RTPCodecTypeAudio) + + params := &testutils.TestExtPacketParams{ + SequenceNumber: 23332, + Timestamp: 0xabcdef, + SSRC: 0x12345678, + PayloadSize: 20, + IsOutOfOrder: true, + } + extPkt, _ := testutils.GetTestExtPacket(params) + + // should not start on an out-of-order packet + expectedTP := TranslationParams{ + shouldDrop: true, + } + actualTP, err := f.GetTranslationParams(extPkt, 0) + require.NoError(t, err) + require.Equal(t, expectedTP, actualTP) + require.False(t, f.started) + require.Zero(t, f.lastSSRC) + + params = &testutils.TestExtPacketParams{ + SequenceNumber: 23333, + Timestamp: 0xabcdef, + SSRC: 0x12345678, + PayloadSize: 20, + } + extPkt, _ = testutils.GetTestExtPacket(params) + + // should lock onto the first in-order packet + expectedTP = TranslationParams{ + rtp: TranslationParamsRTP{ + snOrdering: SequenceNumberOrderingContiguous, + extSequenceNumber: 23333, + extTimestamp: 0xabcdef, + }, + } + actualTP, err = f.GetTranslationParams(extPkt, 0) + require.NoError(t, err) + require.Equal(t, expectedTP, actualTP) + require.True(t, f.started) + require.Equal(t, f.lastSSRC, params.SSRC) + + // send a duplicate, should be dropped + expectedTP = TranslationParams{ + shouldDrop: true, + } + actualTP, err = f.GetTranslationParams(extPkt, 0) + require.NoError(t, err) + require.Equal(t, expectedTP, actualTP) + + // add a missing sequence number to the cache + err = f.rtpMunger.snRangeMap.ExcludeRange(23334, 23335) + require.NoError(t, err) + + params = &testutils.TestExtPacketParams{ + SequenceNumber: 23336, + Timestamp: 0xabcdef, + SSRC: 0x12345678, + PayloadSize: 20, + } + extPkt, _ = testutils.GetTestExtPacket(params) + + _, err = f.GetTranslationParams(extPkt, 0) + require.NoError(t, err) + + // out-of-order packet should get offset from cache + params = &testutils.TestExtPacketParams{ + SequenceNumber: 23335, + Timestamp: 0xabcdef, + SSRC: 0x12345678, + PayloadSize: 20, + } + extPkt, _ = testutils.GetTestExtPacket(params) + + expectedTP = TranslationParams{ + rtp: TranslationParamsRTP{ + snOrdering: SequenceNumberOrderingOutOfOrder, + extSequenceNumber: 23334, + extTimestamp: 0xabcdef, + }, + } + actualTP, err = f.GetTranslationParams(extPkt, 0) + require.NoError(t, err) + require.Equal(t, expectedTP, actualTP) + + // padding only packet in order should be dropped + params = &testutils.TestExtPacketParams{ + SequenceNumber: 23337, + Timestamp: 0xabcdef, + SSRC: 0x12345678, + } + extPkt, _ = testutils.GetTestExtPacket(params) + + expectedTP = TranslationParams{ + shouldDrop: true, + } + actualTP, err = f.GetTranslationParams(extPkt, 0) + require.NoError(t, err) + require.Equal(t, expectedTP, actualTP) + + // in order packet should be forwarded + params = &testutils.TestExtPacketParams{ + SequenceNumber: 23338, + Timestamp: 0xabcdef, + SSRC: 0x12345678, + PayloadSize: 20, + } + extPkt, _ = testutils.GetTestExtPacket(params) + + expectedTP = TranslationParams{ + rtp: TranslationParamsRTP{ + snOrdering: SequenceNumberOrderingContiguous, + extSequenceNumber: 23336, + extTimestamp: 0xabcdef, + }, + } + actualTP, err = f.GetTranslationParams(extPkt, 0) + require.NoError(t, err) + require.Equal(t, expectedTP, actualTP) + + // padding only packet after a gap should not be dropped + params = &testutils.TestExtPacketParams{ + SequenceNumber: 23340, + Timestamp: 0xabcdef, + SSRC: 0x12345678, + } + extPkt, _ = testutils.GetTestExtPacket(params) + + expectedTP = TranslationParams{ + rtp: TranslationParamsRTP{ + snOrdering: SequenceNumberOrderingGap, + extSequenceNumber: 23338, + extTimestamp: 0xabcdef, + }, + } + actualTP, err = f.GetTranslationParams(extPkt, 0) + require.NoError(t, err) + require.Equal(t, expectedTP, actualTP) + + // out-of-order should be forwarded using cache + params = &testutils.TestExtPacketParams{ + SequenceNumber: 23336, + Timestamp: 0xabcdef, + SSRC: 0x12345678, + PayloadSize: 20, + } + extPkt, _ = testutils.GetTestExtPacket(params) + + expectedTP = TranslationParams{ + rtp: TranslationParamsRTP{ + snOrdering: SequenceNumberOrderingOutOfOrder, + extSequenceNumber: 23335, + extTimestamp: 0xabcdef, + }, + } + actualTP, err = f.GetTranslationParams(extPkt, 0) + require.NoError(t, err) + require.Equal(t, expectedTP, actualTP) + + // switching source should lock onto the new source, but sequence number should be contiguous + params = &testutils.TestExtPacketParams{ + SequenceNumber: 123, + Timestamp: 0xfedcba, + SSRC: 0x87654321, + PayloadSize: 20, + } + extPkt, _ = testutils.GetTestExtPacket(params) + + expectedTP = TranslationParams{ + rtp: TranslationParamsRTP{ + snOrdering: SequenceNumberOrderingContiguous, + extSequenceNumber: 23339, + extTimestamp: 0xabcdf0, + }, + } + actualTP, err = f.GetTranslationParams(extPkt, 0) + require.NoError(t, err) + require.Equal(t, expectedTP, actualTP) + require.Equal(t, f.lastSSRC, params.SSRC) +} + +func TestForwarderGetTranslationParamsVideo(t *testing.T) { + f := newForwarder(testutils.TestVP8Codec, webrtc.RTPCodecTypeVideo) + + params := &testutils.TestExtPacketParams{ + SequenceNumber: 23332, + Timestamp: 0xabcdef, + SSRC: 0x12345678, + PayloadSize: 20, + Marker: true, + IsOutOfOrder: true, + } + vp8 := &buffer.VP8{ + FirstByte: 25, + I: true, + M: true, + PictureID: 13467, + L: true, + TL0PICIDX: 233, + T: true, + TID: 0, + Y: true, + K: true, + KEYIDX: 23, + HeaderSize: 6, + IsKeyFrame: false, + } + extPkt, _ := testutils.GetTestExtPacketVP8(params, vp8) + + // should not start on an out-of-order packet + expectedTP := TranslationParams{ + shouldDrop: true, + } + actualTP, err := f.GetTranslationParams(extPkt, 0) + require.NoError(t, err) + require.Equal(t, expectedTP, actualTP) + require.False(t, f.started) + require.Zero(t, f.lastSSRC) + + params = &testutils.TestExtPacketParams{ + SequenceNumber: 23333, + Timestamp: 0xabcdef, + SSRC: 0x12345678, + PayloadSize: 20, + Marker: true, + } + vp8 = &buffer.VP8{ + FirstByte: 25, + I: true, + M: true, + PictureID: 13467, + L: true, + TL0PICIDX: 233, + T: true, + TID: 0, + Y: true, + K: true, + KEYIDX: 23, + HeaderSize: 6, + IsKeyFrame: false, + } + extPkt, _ = testutils.GetTestExtPacketVP8(params, vp8) + + // no target layers, should drop + expectedTP = TranslationParams{ + shouldDrop: true, + } + actualTP, err = f.GetTranslationParams(extPkt, 0) + require.NoError(t, err) + require.Equal(t, expectedTP, actualTP) + + // although target layer matches, not a key frame, so should drop + f.vls.SetTarget(buffer.VideoLayer{ + Spatial: 0, + Temporal: 1, + }) + expectedTP = TranslationParams{ + shouldDrop: true, + } + actualTP, err = f.GetTranslationParams(extPkt, 0) + require.NoError(t, err) + require.Equal(t, expectedTP, actualTP) + + // should lock onto packet (key frame) + vp8 = &buffer.VP8{ + FirstByte: 25, + I: true, + M: true, + PictureID: 13467, + L: true, + TL0PICIDX: 233, + T: true, + TID: 0, + Y: true, + K: true, + KEYIDX: 23, + HeaderSize: 6, + IsKeyFrame: true, + } + extPkt, _ = testutils.GetTestExtPacketVP8(params, vp8) + expectedVP8 := &buffer.VP8{ + FirstByte: 25, + I: true, + M: true, + PictureID: 13467, + L: true, + TL0PICIDX: 233, + T: true, + TID: 0, + Y: true, + K: true, + KEYIDX: 23, + HeaderSize: 6, + IsKeyFrame: true, + } + marshalledVP8, err := expectedVP8.Marshal() + expectedTP = TranslationParams{ + isSwitching: true, + isResuming: true, + rtp: TranslationParamsRTP{ + snOrdering: SequenceNumberOrderingContiguous, + extSequenceNumber: 23333, + extTimestamp: 0xabcdef, + }, + incomingHeaderSize: 6, + codecBytes: marshalledVP8, + marker: true, + } + actualTP, err = f.GetTranslationParams(extPkt, 0) + require.NoError(t, err) + require.Equal(t, expectedTP, actualTP) + require.True(t, f.started) + require.Equal(t, f.lastSSRC, params.SSRC) + + // send a duplicate, should be dropped + expectedTP = TranslationParams{ + shouldDrop: true, + marker: true, + } + actualTP, err = f.GetTranslationParams(extPkt, 0) + require.NoError(t, err) + require.Equal(t, expectedTP, actualTP) + + // out-of-order packet not in cache should be dropped + params = &testutils.TestExtPacketParams{ + SequenceNumber: 23332, + Timestamp: 0xabcdef, + SSRC: 0x12345678, + PayloadSize: 20, + } + extPkt, _ = testutils.GetTestExtPacketVP8(params, vp8) + expectedTP = TranslationParams{ + shouldDrop: true, + } + actualTP, err = f.GetTranslationParams(extPkt, 0) + require.NoError(t, err) + require.Equal(t, expectedTP, actualTP) + + // padding only packet in order should be dropped + params = &testutils.TestExtPacketParams{ + SequenceNumber: 23334, + Timestamp: 0xabcdef, + SSRC: 0x12345678, + } + extPkt, _ = testutils.GetTestExtPacketVP8(params, vp8) + expectedTP = TranslationParams{ + shouldDrop: true, + } + actualTP, err = f.GetTranslationParams(extPkt, 0) + require.NoError(t, err) + require.Equal(t, expectedTP, actualTP) + + // in order packet should be forwarded + params = &testutils.TestExtPacketParams{ + SequenceNumber: 23335, + Timestamp: 0xabcdef, + SSRC: 0x12345678, + PayloadSize: 20, + } + extPkt, _ = testutils.GetTestExtPacketVP8(params, vp8) + expectedVP8 = &buffer.VP8{ + FirstByte: 25, + I: true, + M: true, + PictureID: 13467, + L: true, + TL0PICIDX: 233, + T: true, + TID: 0, + Y: true, + K: true, + KEYIDX: 23, + HeaderSize: 6, + IsKeyFrame: true, + } + marshalledVP8, err = expectedVP8.Marshal() + require.NoError(t, err) + expectedTP = TranslationParams{ + rtp: TranslationParamsRTP{ + snOrdering: SequenceNumberOrderingContiguous, + extSequenceNumber: 23334, + extTimestamp: 0xabcdef, + }, + incomingHeaderSize: 6, + codecBytes: marshalledVP8, + } + actualTP, err = f.GetTranslationParams(extPkt, 0) + require.NoError(t, err) + require.Equal(t, expectedTP, actualTP) + + // temporal layer matching target, should be forwarded + params = &testutils.TestExtPacketParams{ + SequenceNumber: 23336, + Timestamp: 0xabcdef, + SSRC: 0x12345678, + PayloadSize: 20, + } + vp8 = &buffer.VP8{ + FirstByte: 25, + S: true, + I: true, + M: true, + PictureID: 13468, + L: true, + TL0PICIDX: 233, + T: true, + TID: 1, + Y: true, + K: true, + KEYIDX: 23, + HeaderSize: 6, + IsKeyFrame: true, + } + extPkt, _ = testutils.GetTestExtPacketVP8(params, vp8) + expectedVP8 = &buffer.VP8{ + FirstByte: 25, + I: true, + M: true, + PictureID: 13468, + L: true, + TL0PICIDX: 233, + T: true, + TID: 1, + Y: true, + K: true, + KEYIDX: 23, + HeaderSize: 6, + IsKeyFrame: true, + } + marshalledVP8, err = expectedVP8.Marshal() + require.NoError(t, err) + expectedTP = TranslationParams{ + rtp: TranslationParamsRTP{ + snOrdering: SequenceNumberOrderingContiguous, + extSequenceNumber: 23335, + extTimestamp: 0xabcdef, + }, + incomingHeaderSize: 6, + codecBytes: marshalledVP8, + } + actualTP, err = f.GetTranslationParams(extPkt, 0) + require.NoError(t, err) + require.Equal(t, expectedTP, actualTP) + + // temporal layer higher than target, should be dropped + params = &testutils.TestExtPacketParams{ + SequenceNumber: 23337, + Timestamp: 0xabcdef, + SSRC: 0x12345678, + PayloadSize: 20, + } + vp8 = &buffer.VP8{ + FirstByte: 25, + I: true, + M: true, + PictureID: 13468, + L: true, + TL0PICIDX: 233, + T: true, + TID: 2, + Y: true, + K: true, + KEYIDX: 23, + HeaderSize: 6, + IsKeyFrame: true, + } + extPkt, _ = testutils.GetTestExtPacketVP8(params, vp8) + expectedTP = TranslationParams{ + shouldDrop: true, + rtp: TranslationParamsRTP{ + snOrdering: SequenceNumberOrderingContiguous, + extSequenceNumber: 23336, + extTimestamp: 0xabcdef, + }, + } + actualTP, err = f.GetTranslationParams(extPkt, 0) + require.NoError(t, err) + require.Equal(t, expectedTP, actualTP) + + // RTP sequence number and VP8 picture id should be contiguous after dropping higher temporal layer picture + params = &testutils.TestExtPacketParams{ + SequenceNumber: 23338, + Timestamp: 0xabcdef, + SSRC: 0x12345678, + PayloadSize: 20, + } + vp8 = &buffer.VP8{ + FirstByte: 25, + I: true, + M: true, + PictureID: 13469, + L: true, + TL0PICIDX: 234, + T: true, + TID: 0, + Y: true, + K: true, + KEYIDX: 23, + HeaderSize: 6, + IsKeyFrame: false, + } + extPkt, _ = testutils.GetTestExtPacketVP8(params, vp8) + expectedVP8 = &buffer.VP8{ + FirstByte: 25, + I: true, + M: true, + PictureID: 13469, + L: true, + TL0PICIDX: 234, + T: true, + TID: 0, + Y: true, + K: true, + KEYIDX: 23, + HeaderSize: 6, + IsKeyFrame: false, + } + marshalledVP8, err = expectedVP8.Marshal() + require.NoError(t, err) + expectedTP = TranslationParams{ + rtp: TranslationParamsRTP{ + snOrdering: SequenceNumberOrderingContiguous, + extSequenceNumber: 23336, + extTimestamp: 0xabcdef, + }, + incomingHeaderSize: 6, + codecBytes: marshalledVP8, + } + actualTP, err = f.GetTranslationParams(extPkt, 0) + require.NoError(t, err) + require.Equal(t, expectedTP, actualTP) + + // padding only packet after a gap should be forwarded + params = &testutils.TestExtPacketParams{ + SequenceNumber: 23340, + Timestamp: 0xabcdef, + SSRC: 0x12345678, + } + extPkt, _ = testutils.GetTestExtPacket(params) + + expectedTP = TranslationParams{ + rtp: TranslationParamsRTP{ + snOrdering: SequenceNumberOrderingGap, + extSequenceNumber: 23338, + extTimestamp: 0xabcdef, + }, + } + actualTP, err = f.GetTranslationParams(extPkt, 0) + require.NoError(t, err) + require.Equal(t, expectedTP, actualTP) + + // out-of-order should be forwarded using cache, even if it is padding only + params = &testutils.TestExtPacketParams{ + SequenceNumber: 23339, + Timestamp: 0xabcdef, + SSRC: 0x12345678, + } + extPkt, _ = testutils.GetTestExtPacket(params) + + expectedTP = TranslationParams{ + rtp: TranslationParamsRTP{ + snOrdering: SequenceNumberOrderingOutOfOrder, + extSequenceNumber: 23337, + extTimestamp: 0xabcdef, + }, + } + actualTP, err = f.GetTranslationParams(extPkt, 0) + require.NoError(t, err) + require.Equal(t, expectedTP, actualTP) + + // switching SSRC (happens for new layer or new track source) + // should lock onto the new source, but sequence number should be contiguous + f.vls.SetTarget(buffer.VideoLayer{ + Spatial: 1, + Temporal: 1, + }) + + params = &testutils.TestExtPacketParams{ + SequenceNumber: 123, + Timestamp: 0xfedcba, + SSRC: 0x87654321, + PayloadSize: 20, + } + vp8 = &buffer.VP8{ + FirstByte: 25, + I: true, + M: false, + PictureID: 45, + L: true, + TL0PICIDX: 12, + T: true, + TID: 0, + Y: true, + K: true, + KEYIDX: 30, + HeaderSize: 5, + IsKeyFrame: true, + } + extPkt, _ = testutils.GetTestExtPacketVP8(params, vp8) + + expectedVP8 = &buffer.VP8{ + FirstByte: 25, + I: true, + M: true, + PictureID: 13470, + L: true, + TL0PICIDX: 235, + T: true, + TID: 0, + Y: true, + K: true, + KEYIDX: 24, + HeaderSize: 6, + IsKeyFrame: true, + } + marshalledVP8, err = expectedVP8.Marshal() + require.NoError(t, err) + expectedTP = TranslationParams{ + isSwitching: true, + rtp: TranslationParamsRTP{ + snOrdering: SequenceNumberOrderingContiguous, + extSequenceNumber: 23339, + extTimestamp: 0xabcdf0, + }, + incomingHeaderSize: 5, + codecBytes: marshalledVP8, + } + actualTP, err = f.GetTranslationParams(extPkt, 1) + require.NoError(t, err) + require.Equal(t, expectedTP, actualTP) + require.Equal(t, f.lastSSRC, params.SSRC) +} + +func TestForwarderGetSnTsForPadding(t *testing.T) { + f := newForwarder(testutils.TestVP8Codec, webrtc.RTPCodecTypeVideo) + + params := &testutils.TestExtPacketParams{ + SequenceNumber: 23333, + Timestamp: 0xabcdef, + SSRC: 0x12345678, + PayloadSize: 20, + } + vp8 := &buffer.VP8{ + FirstByte: 25, + I: true, + M: true, + PictureID: 13467, + L: true, + TL0PICIDX: 233, + T: true, + TID: 0, + Y: true, + K: true, + KEYIDX: 23, + HeaderSize: 6, + IsKeyFrame: true, + } + extPkt, _ := testutils.GetTestExtPacketVP8(params, vp8) + + f.vls.SetTarget(buffer.VideoLayer{ + Spatial: 0, + Temporal: 1, + }) + f.vls.SetCurrent(buffer.InvalidLayer) + + // send it through so that forwarder locks onto stream + _, _ = f.GetTranslationParams(extPkt, 0) + + // pause stream and get padding, it should still work + disable(f) + + // should get back frame end needed as the last packet did not have RTP marker set + snts, err := f.GetSnTsForPadding(5, 0, false) + require.NoError(t, err) + + numPadding := 5 + clockRate := uint32(0) + frameRate := uint32(5) + var sntsExpected = make([]SnTs, numPadding) + for i := range numPadding { + sntsExpected[i] = SnTs{ + extSequenceNumber: 23333 + uint64(i) + 1, + extTimestamp: 0xabcdef + (uint64(i)*uint64(clockRate))/uint64(frameRate), + } + } + require.Equal(t, sntsExpected, snts) + + // now that there is a marker, timestamp should jump on first padding when asked again + snts, err = f.GetSnTsForPadding(numPadding, 0, false) + require.NoError(t, err) + + for i := range numPadding { + sntsExpected[i] = SnTs{ + extSequenceNumber: 23338 + uint64(i) + 1, + extTimestamp: 0xabcdef + (uint64(i+1)*uint64(clockRate))/uint64(frameRate), + } + } + require.Equal(t, sntsExpected, snts) +} + +func TestForwarderGetSnTsForBlankFrames(t *testing.T) { + f := newForwarder(testutils.TestVP8Codec, webrtc.RTPCodecTypeVideo) + + params := &testutils.TestExtPacketParams{ + SequenceNumber: 23333, + Timestamp: 0xabcdef, + SSRC: 0x12345678, + PayloadSize: 20, + } + vp8 := &buffer.VP8{ + FirstByte: 25, + I: true, + M: true, + PictureID: 13467, + L: true, + TL0PICIDX: 233, + T: true, + TID: 0, + Y: true, + K: true, + KEYIDX: 23, + HeaderSize: 6, + IsKeyFrame: true, + } + extPkt, _ := testutils.GetTestExtPacketVP8(params, vp8) + + f.vls.SetTarget(buffer.VideoLayer{ + Spatial: 0, + Temporal: 1, + }) + f.vls.SetCurrent(buffer.InvalidLayer) + + // send it through so that forwarder locks onto stream + _, _ = f.GetTranslationParams(extPkt, 0) + + // should get back frame end needed as the last packet did not have RTP marker set + numBlankFrames := 6 + snts, frameEndNeeded, err := f.GetSnTsForBlankFrames(30, numBlankFrames) + require.NoError(t, err) + require.True(t, frameEndNeeded) + + // there should be one more than RTPBlankFramesMax as one would have been allocated to end previous frame + numPadding := numBlankFrames + 1 + clockRate := testutils.TestVP8Codec.ClockRate + frameRate := uint32(30) + var sntsExpected = make([]SnTs, numPadding) + for i := 0; i < numPadding; i++ { + // first blank frame should have same timestamp as last frame as end frame is synthesized + ts := params.Timestamp + if i != 0 { + // +1 here due to expected time stamp bumpint by at least one so that time stamp is always moving ahead + ts = params.Timestamp + 1 + ((uint32(i)*clockRate)+frameRate-1)/frameRate + } + sntsExpected[i] = SnTs{ + extSequenceNumber: uint64(params.SequenceNumber) + uint64(i) + 1, + extTimestamp: uint64(ts), + } + } + require.Equal(t, sntsExpected, snts) + + // now that there is a marker, timestamp should jump on first padding when asked again + // also number of padding should be RTPBlankFramesMax + numPadding = numBlankFrames + sntsExpected = sntsExpected[:numPadding] + for i := 0; i < numPadding; i++ { + sntsExpected[i] = SnTs{ + extSequenceNumber: uint64(params.SequenceNumber) + uint64(len(snts)) + uint64(i) + 1, + // +1 here due to expected time stamp bumpint by at least one so that time stamp is always moving ahead + extTimestamp: snts[len(snts)-1].extTimestamp + 1 + ((uint64(i+1)*uint64(clockRate))+uint64(frameRate)-1)/uint64(frameRate), + } + } + snts, frameEndNeeded, err = f.GetSnTsForBlankFrames(30, numBlankFrames) + require.NoError(t, err) + require.False(t, frameEndNeeded) + require.Equal(t, sntsExpected, snts) +} + +func TestForwarderGetPaddingVP8(t *testing.T) { + f := newForwarder(testutils.TestVP8Codec, webrtc.RTPCodecTypeVideo) + + params := &testutils.TestExtPacketParams{ + SequenceNumber: 23333, + Timestamp: 0xabcdef, + SSRC: 0x12345678, + PayloadSize: 20, + } + vp8 := &buffer.VP8{ + FirstByte: 25, + I: true, + M: true, + PictureID: 13467, + L: true, + TL0PICIDX: 233, + T: true, + TID: 13, + Y: true, + K: true, + KEYIDX: 23, + HeaderSize: 6, + IsKeyFrame: true, + } + extPkt, _ := testutils.GetTestExtPacketVP8(params, vp8) + + f.vls.SetTarget(buffer.VideoLayer{ + Spatial: 0, + Temporal: 1, + }) + f.vls.SetCurrent(buffer.InvalidLayer) + + // send it through so that forwarder locks onto stream + _, _ = f.GetTranslationParams(extPkt, 0) + + // getting padding with frame end needed, should repeat the last picture id + expectedVP8 := buffer.VP8{ + FirstByte: 16, + I: true, + M: true, + PictureID: 13467, + L: true, + TL0PICIDX: 233, + T: true, + TID: 0, + Y: true, + K: true, + KEYIDX: 23, + HeaderSize: 6, + IsKeyFrame: true, + } + buf, err := f.GetPadding(true) + require.NoError(t, err) + marshalledVP8, err := expectedVP8.Marshal() + require.NoError(t, err) + require.Equal(t, marshalledVP8, buf) + + // getting padding with no frame end needed, should get next picture id + expectedVP8 = buffer.VP8{ + FirstByte: 16, + I: true, + M: true, + PictureID: 13468, + L: true, + TL0PICIDX: 234, + T: true, + TID: 0, + Y: true, + K: true, + KEYIDX: 24, + HeaderSize: 6, + IsKeyFrame: true, + } + buf, err = f.GetPadding(false) + require.NoError(t, err) + marshalledVP8, err = expectedVP8.Marshal() + require.NoError(t, err) + require.Equal(t, marshalledVP8, buf) +} diff --git a/livekit/pkg/sfu/forwardstats.go b/livekit/pkg/sfu/forwardstats.go new file mode 100644 index 0000000..34daa0c --- /dev/null +++ b/livekit/pkg/sfu/forwardstats.go @@ -0,0 +1,114 @@ +package sfu + +import ( + "sync" + "time" + + "github.com/livekit/livekit-server/pkg/telemetry/prometheus" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/utils" + "github.com/livekit/protocol/utils/mono" +) + +const ( + cHighForwardingLatency = 20 * time.Millisecond + cSkewFactor = 10 +) + +type ForwardStats struct { + lock sync.Mutex + latency *utils.LatencyAggregate + lowest int64 + highest int64 + lastUpdateAt int64 + closeCh chan struct{} +} + +func NewForwardStats(latencyUpdateInterval, reportInterval, latencyWindowLength time.Duration) *ForwardStats { + s := &ForwardStats{ + latency: utils.NewLatencyAggregate(latencyUpdateInterval, latencyWindowLength), + lowest: time.Second.Nanoseconds(), + closeCh: make(chan struct{}), + } + + go s.report(reportInterval) + return s +} + +func (s *ForwardStats) Update(arrival, left int64) (int64, bool) { + transit := left - arrival + isHighForwardingLatency := false + if time.Duration(transit) > cHighForwardingLatency { + isHighForwardingLatency = true + } + + s.lock.Lock() + s.latency.Update(time.Duration(arrival), float64(transit)) + s.lowest = min(transit, s.lowest) + s.highest = max(transit, s.highest) + s.lastUpdateAt = arrival + s.lock.Unlock() + + prometheus.RecordForwardLatencySample(transit) + return transit, isHighForwardingLatency +} + +func (s *ForwardStats) GetStats(shortDuration time.Duration) (time.Duration, time.Duration) { + s.lock.Lock() + // a dummy sample to flush the pipe to current time + now := mono.UnixNano() + if (now - s.lastUpdateAt) > shortDuration.Nanoseconds() { + s.latency.Update(time.Duration(now), 0) + } + + wLong := s.latency.Summarize() + + lowest := s.lowest + s.lowest = time.Second.Nanoseconds() + + highest := s.highest + s.highest = 0 + s.lock.Unlock() + + latencyLong, jitterLong := time.Duration(wLong.Mean()), time.Duration(wLong.StdDev()) + if jitterLong > latencyLong*cSkewFactor { + logger.Infow( + "high jitter in forwarding path", + "lowest", time.Duration(lowest), + "highest", time.Duration(highest), + "countLong", wLong.Count(), + "latencyLong", latencyLong, + "jitterLong", jitterLong, + ) + } + return latencyLong, jitterLong +} + +func (s *ForwardStats) GetShortStats(shortDuration time.Duration) (time.Duration, time.Duration) { + s.lock.Lock() + wShort := s.latency.SummarizeLast(shortDuration) + s.lock.Unlock() + + return time.Duration(wShort.Mean()), time.Duration(wShort.StdDev()) +} + +func (s *ForwardStats) Stop() { + close(s.closeCh) +} + +func (s *ForwardStats) report(reportInterval time.Duration) { + ticker := time.NewTicker(reportInterval) + defer ticker.Stop() + + for { + select { + case <-s.closeCh: + return + + case <-ticker.C: + latencyLong, jitterLong := s.GetStats(reportInterval) + prometheus.RecordForwardJitter(uint32(jitterLong.Nanoseconds())) + prometheus.RecordForwardLatency(uint32(latencyLong.Nanoseconds())) + } + } +} diff --git a/livekit/pkg/sfu/interceptor/rtx.go b/livekit/pkg/sfu/interceptor/rtx.go new file mode 100644 index 0000000..47bf598 --- /dev/null +++ b/livekit/pkg/sfu/interceptor/rtx.go @@ -0,0 +1,194 @@ +// Copyright 2024 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package interceptor + +import ( + "sync" + + "github.com/pion/interceptor" + "github.com/pion/sdp/v3" + "github.com/pion/webrtc/v4" + + "github.com/livekit/livekit-server/pkg/sfu/utils" + "github.com/livekit/protocol/logger" +) + +const ( + SDESRepairRTPStreamIDURI = "urn:ietf:params:rtp-hdrext:sdes:repaired-rtp-stream-id" + + rtxProbeCount = 10 +) + +type streamInfo struct { + mid string + rid string + rsid string +} + +type RTXInfoExtractorFactory struct { + onStreamFound func(*interceptor.StreamInfo) + onRTXPairFound func(repair, base uint32, rsid string) + lock sync.Mutex + streams map[uint32]streamInfo + logger logger.Logger +} + +func NewRTXInfoExtractorFactory( + onStreamFound func(*interceptor.StreamInfo), + onRTXPairFound func(repair, base uint32, rsid string), + logger logger.Logger, +) *RTXInfoExtractorFactory { + return &RTXInfoExtractorFactory{ + onStreamFound: onStreamFound, + onRTXPairFound: onRTXPairFound, + streams: make(map[uint32]streamInfo), + logger: logger, + } +} + +func (f *RTXInfoExtractorFactory) NewInterceptor(id string) (interceptor.Interceptor, error) { + return &RTXInfoExtractor{ + factory: f, + logger: f.logger, + }, nil +} + +func (f *RTXInfoExtractorFactory) SetStreamInfo(ssrc uint32, mid, rid, rsid string) { + var repairSsrc, baseSsrc uint32 + var repairSid string + f.lock.Lock() + + if mid == "" || (rid == "" && rsid == "") { + f.lock.Unlock() + return + } + + if rsid != "" { + // repair stream found, find base stream + for base, info := range f.streams { + if info.mid == mid && info.rid == rsid { + repairSsrc = ssrc + baseSsrc = base + repairSid = rsid + delete(f.streams, base) + break + } + } + } else { + // base stream found, find repair stream + for repair, info := range f.streams { + if info.mid == mid && info.rsid == rid { + repairSsrc = repair + baseSsrc = ssrc + repairSid = info.rid + delete(f.streams, repair) + break + } + } + } + + // no rtx pair found, save it for later + if repairSsrc == 0 || baseSsrc == 0 { + f.streams[ssrc] = streamInfo{ + mid: mid, + rid: rid, + rsid: rsid, + } + } + + f.lock.Unlock() + + if repairSsrc != 0 && baseSsrc != 0 { + f.onRTXPairFound(repairSsrc, baseSsrc, repairSid) + } +} + +type RTXInfoExtractor struct { + interceptor.NoOp + + factory *RTXInfoExtractorFactory + logger logger.Logger +} + +func (u *RTXInfoExtractor) BindRemoteStream(info *interceptor.StreamInfo, reader interceptor.RTPReader) interceptor.RTPReader { + u.factory.onStreamFound(info) + + midExtensionID := utils.GetHeaderExtensionID(info.RTPHeaderExtensions, webrtc.RTPHeaderExtensionCapability{URI: sdp.SDESMidURI}) + streamIDExtensionID := utils.GetHeaderExtensionID(info.RTPHeaderExtensions, webrtc.RTPHeaderExtensionCapability{URI: sdp.SDESRTPStreamIDURI}) + repairStreamIDExtensionID := utils.GetHeaderExtensionID(info.RTPHeaderExtensions, webrtc.RTPHeaderExtensionCapability{URI: SDESRepairRTPStreamIDURI}) + if midExtensionID == 0 || streamIDExtensionID == 0 || repairStreamIDExtensionID == 0 { + return reader + } + + return &rtxInfoReader{ + tryTimes: rtxProbeCount, + reader: reader, + midExtID: uint8(midExtensionID), + ridExtID: uint8(streamIDExtensionID), + rsidExtID: uint8(repairStreamIDExtensionID), + factory: u.factory, + logger: u.logger, + } +} + +type rtxInfoReader struct { + tryTimes int + reader interceptor.RTPReader + midExtID uint8 + ridExtID uint8 + rsidExtID uint8 + factory *RTXInfoExtractorFactory + logger logger.Logger +} + +func (r *rtxInfoReader) Read(b []byte, a interceptor.Attributes) (int, interceptor.Attributes, error) { + n, a, err := r.reader.Read(b, a) + if r.tryTimes < 0 || err != nil { + return n, a, err + } + + if a == nil { + a = make(interceptor.Attributes) + } + header, err := a.GetRTPHeader(b[:n]) + if err != nil { + return n, a, nil + } + + var mid, rid, rsid string + if payload := header.GetExtension(r.midExtID); payload != nil { + mid = string(payload) + } + + if payload := header.GetExtension(r.ridExtID); payload != nil { + rid = string(payload) + } + + if payload := header.GetExtension(r.rsidExtID); payload != nil { + rsid = string(payload) + } + + if mid != "" && (rid != "" || rsid != "") { + r.logger.Debugw("stream found", "mid", mid, "rid", rid, "rsid", rsid, "ssrc", header.SSRC) + r.tryTimes = -1 + go r.factory.SetStreamInfo(header.SSRC, mid, rid, rsid) + } else { + // ignore padding only packet for probe count + if !(header.Padding && n-header.MarshalSize()-int(b[n-1]) == 0) { + r.tryTimes-- + } + } + return n, a, nil +} diff --git a/livekit/pkg/sfu/interceptor/unhandlesimulcast.go b/livekit/pkg/sfu/interceptor/unhandlesimulcast.go new file mode 100644 index 0000000..33fbd13 --- /dev/null +++ b/livekit/pkg/sfu/interceptor/unhandlesimulcast.go @@ -0,0 +1,183 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package interceptor + +import ( + "github.com/pion/interceptor" + "github.com/pion/rtp" + "github.com/pion/sdp/v3" + "github.com/pion/webrtc/v4" + "go.uber.org/zap/zapcore" + + "github.com/livekit/livekit-server/pkg/sfu/utils" + "github.com/livekit/protocol/logger" +) + +const ( + simulcastProbeCount = 10 +) + +type SimulcastTrackInfo struct { + Mid string + StreamID string + RepairSSRC uint32 // set only when `IsRepairStream: false`, i. e. RTX SSRC for the primary stream + IsRepairStream bool +} + +func (s *SimulcastTrackInfo) MarshalLogObject(e zapcore.ObjectEncoder) error { + e.AddString("Mid", s.Mid) + e.AddString("StreamID", s.StreamID) + e.AddUint32("RepairSSRC", s.RepairSSRC) + e.AddBool("IsRepairStream", s.IsRepairStream) + return nil +} + +// ------------------------------------------------------------------- + +type UnhandleSimulcastOption func(u *UnhandleSimulcastInterceptor) error + +func UnhandleSimulcastTracks(logger logger.Logger, tracks map[uint32]SimulcastTrackInfo) UnhandleSimulcastOption { + return func(u *UnhandleSimulcastInterceptor) error { + u.logger = logger + u.simTracks = tracks + return nil + } +} + +type UnhandleSimulcastInterceptorFactory struct { + opts []UnhandleSimulcastOption +} + +func (f *UnhandleSimulcastInterceptorFactory) NewInterceptor(id string) (interceptor.Interceptor, error) { + i := &UnhandleSimulcastInterceptor{simTracks: map[uint32]SimulcastTrackInfo{}} + for _, o := range f.opts { + if err := o(i); err != nil { + return nil, err + } + } + return i, nil +} + +func NewUnhandleSimulcastInterceptorFactory(opts ...UnhandleSimulcastOption) (*UnhandleSimulcastInterceptorFactory, error) { + return &UnhandleSimulcastInterceptorFactory{opts: opts}, nil +} + +type unhandleSimulcastRTPReader struct { + SimulcastTrackInfo + logger logger.Logger + tryTimes int + reader interceptor.RTPReader + midExtID uint8 + ridExtID uint8 + rsidExtID uint8 +} + +func (u *unhandleSimulcastRTPReader) Read(b []byte, a interceptor.Attributes) (int, interceptor.Attributes, error) { + n, a, err := u.reader.Read(b, a) + if u.tryTimes < 0 || err != nil { + return n, a, err + } + + header := rtp.Header{} + hsize, err := header.Unmarshal(b[:n]) + if err != nil { + return n, a, nil + } + var mid, rid, rsid string + if payload := header.GetExtension(u.midExtID); payload != nil { + mid = string(payload) + } + + if payload := header.GetExtension(u.ridExtID); payload != nil { + rid = string(payload) + } + + if payload := header.GetExtension(u.rsidExtID); payload != nil { + rid = string(payload) + } + + if mid != "" && (rid != "" || rsid != "") { + u.logger.Debugw( + "unhandle stream found", + "mid", mid, + "rid", rid, + "rsid", rsid, + "ssrc", header.SSRC, + "simulcastTrackInfo", u.SimulcastTrackInfo, + ) + u.tryTimes = -1 + return n, a, nil + } else { + // ignore padding only packet for probe count + if !(header.Padding && n-header.MarshalSize()-int(b[n-1]) == 0) { + u.tryTimes-- + } + } + + if mid == "" { + header.SetExtension(u.midExtID, []byte(u.Mid)) + } + if rid == "" && !u.IsRepairStream { + header.SetExtension(u.ridExtID, []byte(u.StreamID)) + } + if rsid == "" && u.IsRepairStream { + header.SetExtension(u.rsidExtID, []byte(u.StreamID)) + } + + hsize2 := header.MarshalSize() + + if hsize2-hsize+n > len(b) { // no enough buf to set extension + return n, a, nil + } + copy(b[hsize2:], b[hsize:n]) + header.MarshalTo(b) + u.logger.Debugw( + "unhandle stream injecting", + "mid", mid, + "rid", rid, + "rsid", rsid, + "ssrc", header.SSRC, + "simulcastTrackInfo", u.SimulcastTrackInfo, + ) + return hsize2 - hsize + n, a, nil +} + +type UnhandleSimulcastInterceptor struct { + interceptor.NoOp + logger logger.Logger + simTracks map[uint32]SimulcastTrackInfo +} + +func (u *UnhandleSimulcastInterceptor) BindRemoteStream(info *interceptor.StreamInfo, reader interceptor.RTPReader) interceptor.RTPReader { + if t, ok := u.simTracks[info.SSRC]; ok { + midExtensionID := utils.GetHeaderExtensionID(info.RTPHeaderExtensions, webrtc.RTPHeaderExtensionCapability{URI: sdp.SDESMidURI}) + streamIDExtensionID := utils.GetHeaderExtensionID(info.RTPHeaderExtensions, webrtc.RTPHeaderExtensionCapability{URI: sdp.SDESRTPStreamIDURI}) + repairStreamIDExtensionID := utils.GetHeaderExtensionID(info.RTPHeaderExtensions, webrtc.RTPHeaderExtensionCapability{URI: sdp.SDESRepairRTPStreamIDURI}) + if midExtensionID == 0 || streamIDExtensionID == 0 || repairStreamIDExtensionID == 0 { + return reader + } + + return &unhandleSimulcastRTPReader{ + SimulcastTrackInfo: t, + logger: u.logger, + reader: reader, + tryTimes: simulcastProbeCount, + midExtID: uint8(midExtensionID), + ridExtID: uint8(streamIDExtensionID), + rsidExtID: uint8(repairStreamIDExtensionID), + } + } + return reader +} diff --git a/livekit/pkg/sfu/mime/mimetype.go b/livekit/pkg/sfu/mime/mimetype.go new file mode 100644 index 0000000..d336bcc --- /dev/null +++ b/livekit/pkg/sfu/mime/mimetype.go @@ -0,0 +1,365 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mime + +import ( + "strings" + + "github.com/pion/webrtc/v4" + + "github.com/livekit/protocol/observability/roomobs" +) + +const ( + MimeTypePrefixAudio = "audio/" + MimeTypePrefixVideo = "video/" +) + +type MimeTypeCodec int + +const ( + MimeTypeCodecUnknown MimeTypeCodec = iota + MimeTypeCodecH264 + MimeTypeCodecH265 + MimeTypeCodecOpus + MimeTypeCodecRED + MimeTypeCodecVP8 + MimeTypeCodecVP9 + MimeTypeCodecAV1 + MimeTypeCodecG722 + MimeTypeCodecPCMU + MimeTypeCodecPCMA + MimeTypeCodecRTX + MimeTypeCodecFlexFEC + MimeTypeCodecULPFEC +) + +func (m MimeTypeCodec) String() string { + switch m { + case MimeTypeCodecUnknown: + return "MimeTypeCodecUnknown" + case MimeTypeCodecH264: + return "H264" + case MimeTypeCodecH265: + return "H265" + case MimeTypeCodecOpus: + return "opus" + case MimeTypeCodecRED: + return "red" + case MimeTypeCodecVP8: + return "VP8" + case MimeTypeCodecVP9: + return "VP9" + case MimeTypeCodecAV1: + return "AV1" + case MimeTypeCodecG722: + return "G722" + case MimeTypeCodecPCMU: + return "PCMU" + case MimeTypeCodecPCMA: + return "PCMA" + case MimeTypeCodecRTX: + return "rtx" + case MimeTypeCodecFlexFEC: + return "flexfec" + case MimeTypeCodecULPFEC: + return "ulpfec" + } + + return "MimeTypeCodecUnknown" +} + +func (m MimeTypeCodec) ToMimeType() MimeType { + switch m { + case MimeTypeCodecUnknown: + return MimeTypeUnknown + case MimeTypeCodecH264: + return MimeTypeH264 + case MimeTypeCodecH265: + return MimeTypeH265 + case MimeTypeCodecOpus: + return MimeTypeOpus + case MimeTypeCodecRED: + return MimeTypeRED + case MimeTypeCodecVP8: + return MimeTypeVP8 + case MimeTypeCodecVP9: + return MimeTypeVP9 + case MimeTypeCodecAV1: + return MimeTypeAV1 + case MimeTypeCodecG722: + return MimeTypeG722 + case MimeTypeCodecPCMU: + return MimeTypePCMU + case MimeTypeCodecPCMA: + return MimeTypePCMA + case MimeTypeCodecRTX: + return MimeTypeRTX + case MimeTypeCodecFlexFEC: + return MimeTypeFlexFEC + case MimeTypeCodecULPFEC: + return MimeTypeULPFEC + } + + return MimeTypeUnknown +} + +func NormalizeMimeTypeCodec(codec string) MimeTypeCodec { + switch { + case strings.EqualFold(codec, "h264"): + return MimeTypeCodecH264 + case strings.EqualFold(codec, "h265"): + return MimeTypeCodecH265 + case strings.EqualFold(codec, "opus"): + return MimeTypeCodecOpus + case strings.EqualFold(codec, "red"): + return MimeTypeCodecRED + case strings.EqualFold(codec, "vp8"): + return MimeTypeCodecVP8 + case strings.EqualFold(codec, "vp9"): + return MimeTypeCodecVP9 + case strings.EqualFold(codec, "av1"): + return MimeTypeCodecAV1 + case strings.EqualFold(codec, "g722"): + return MimeTypeCodecG722 + case strings.EqualFold(codec, "pcmu"): + return MimeTypeCodecPCMU + case strings.EqualFold(codec, "pcma"): + return MimeTypeCodecPCMA + case strings.EqualFold(codec, "rtx"): + return MimeTypeCodecRTX + case strings.EqualFold(codec, "flexfec"): + return MimeTypeCodecFlexFEC + case strings.EqualFold(codec, "ulpfec"): + return MimeTypeCodecULPFEC + } + + return MimeTypeCodecUnknown +} + +func GetMimeTypeCodec(mime string) MimeTypeCodec { + i := strings.IndexByte(mime, '/') + if i == -1 { + return MimeTypeCodecUnknown + } + + return NormalizeMimeTypeCodec(mime[i+1:]) +} + +func IsMimeTypeCodecStringOpus(codec string) bool { + return NormalizeMimeTypeCodec(codec) == MimeTypeCodecOpus +} + +func IsMimeTypeCodecStringRED(codec string) bool { + return NormalizeMimeTypeCodec(codec) == MimeTypeCodecRED +} + +func IsMimeTypeCodecStringPCMA(codec string) bool { + return NormalizeMimeTypeCodec(codec) == MimeTypeCodecPCMA +} + +func IsMimeTypeCodecStringPCMU(codec string) bool { + return NormalizeMimeTypeCodec(codec) == MimeTypeCodecPCMU +} + +func IsMimeTypeCodecStringH264(codec string) bool { + return NormalizeMimeTypeCodec(codec) == MimeTypeCodecH264 +} + +type MimeType int + +const ( + MimeTypeUnknown MimeType = iota + MimeTypeH264 + MimeTypeH265 + MimeTypeOpus + MimeTypeRED + MimeTypeVP8 + MimeTypeVP9 + MimeTypeAV1 + MimeTypeG722 + MimeTypePCMU + MimeTypePCMA + MimeTypeRTX + MimeTypeFlexFEC + MimeTypeULPFEC +) + +func (m MimeType) String() string { + switch m { + case MimeTypeUnknown: + return "MimeTypeUnknown" + case MimeTypeH264: + return webrtc.MimeTypeH264 + case MimeTypeH265: + return webrtc.MimeTypeH265 + case MimeTypeOpus: + return webrtc.MimeTypeOpus + case MimeTypeRED: + return "audio/red" + case MimeTypeVP8: + return webrtc.MimeTypeVP8 + case MimeTypeVP9: + return webrtc.MimeTypeVP9 + case MimeTypeAV1: + return webrtc.MimeTypeAV1 + case MimeTypeG722: + return webrtc.MimeTypeG722 + case MimeTypePCMU: + return webrtc.MimeTypePCMU + case MimeTypePCMA: + return webrtc.MimeTypePCMA + case MimeTypeRTX: + return webrtc.MimeTypeRTX + case MimeTypeFlexFEC: + return webrtc.MimeTypeFlexFEC + case MimeTypeULPFEC: + return "video/ulpfec" + } + + return "MimeTypeUnknown" +} + +func (m MimeType) ReporterType() roomobs.MimeType { + switch m { + case MimeTypeUnknown: + return roomobs.MimeTypeUndefined + case MimeTypeH264: + return roomobs.MimeTypeVideoH264 + case MimeTypeH265: + return roomobs.MimeTypeVideoH265 + case MimeTypeOpus: + return roomobs.MimeTypeAudioOpus + case MimeTypeRED: + return roomobs.MimeTypeAudioRed + case MimeTypeVP8: + return roomobs.MimeTypeVideoVp8 + case MimeTypeVP9: + return roomobs.MimeTypeVideoVp9 + case MimeTypeAV1: + return roomobs.MimeTypeVideoAv1 + case MimeTypeG722: + return roomobs.MimeTypeAudioG722 + case MimeTypePCMU: + return roomobs.MimeTypeAudioPcmu + case MimeTypePCMA: + return roomobs.MimeTypeAudioPcma + case MimeTypeRTX: + return roomobs.MimeTypeVideoRtx + case MimeTypeFlexFEC: + return roomobs.MimeTypeVideoFlexfec + case MimeTypeULPFEC: + return roomobs.MimeTypeVideoUlpfec + } + + return roomobs.MimeTypeUndefined +} + +func NormalizeMimeType(mime string) MimeType { + switch { + case strings.EqualFold(mime, webrtc.MimeTypeH264): + return MimeTypeH264 + case strings.EqualFold(mime, webrtc.MimeTypeH265): + return MimeTypeH265 + case strings.EqualFold(mime, webrtc.MimeTypeOpus): + return MimeTypeOpus + case strings.EqualFold(mime, "audio/red"): + return MimeTypeRED + case strings.EqualFold(mime, webrtc.MimeTypeVP8): + return MimeTypeVP8 + case strings.EqualFold(mime, webrtc.MimeTypeVP9): + return MimeTypeVP9 + case strings.EqualFold(mime, webrtc.MimeTypeAV1): + return MimeTypeAV1 + case strings.EqualFold(mime, webrtc.MimeTypeG722): + return MimeTypeG722 + case strings.EqualFold(mime, webrtc.MimeTypePCMU): + return MimeTypePCMU + case strings.EqualFold(mime, webrtc.MimeTypePCMA): + return MimeTypePCMA + case strings.EqualFold(mime, webrtc.MimeTypeRTX): + return MimeTypeRTX + case strings.EqualFold(mime, webrtc.MimeTypeFlexFEC): + return MimeTypeFlexFEC + case strings.EqualFold(mime, "video/ulpfec"): + return MimeTypeULPFEC + } + + return MimeTypeUnknown +} + +func IsMimeTypeStringEqual(mime1 string, mime2 string) bool { + return NormalizeMimeType(mime1) == NormalizeMimeType(mime2) +} + +func IsMimeTypeStringAudio(mime string) bool { + return strings.HasPrefix(mime, MimeTypePrefixAudio) +} + +func IsMimeTypeAudio(mimeType MimeType) bool { + return strings.HasPrefix(mimeType.String(), MimeTypePrefixAudio) +} + +func IsMimeTypeStringVideo(mime string) bool { + return strings.HasPrefix(mime, MimeTypePrefixVideo) +} + +func IsMimeTypeVideo(mimeType MimeType) bool { + return strings.HasPrefix(mimeType.String(), MimeTypePrefixVideo) +} + +func IsMimeTypeSVCCapable(mimeType MimeType) bool { + switch mimeType { + case MimeTypeAV1, MimeTypeVP9: + return true + } + return false +} + +func IsMimeTypeStringSVCCapable(mime string) bool { + return IsMimeTypeSVCCapable(NormalizeMimeType(mime)) +} + +func IsMimeTypeStringRED(mime string) bool { + return NormalizeMimeType(mime) == MimeTypeRED +} + +func IsMimeTypeStringOpus(mime string) bool { + return NormalizeMimeType(mime) == MimeTypeOpus +} + +func IsMimeTypeStringPCMA(mime string) bool { + return NormalizeMimeType(mime) == MimeTypePCMA +} + +func IsMimeTypeStringPCMU(mime string) bool { + return NormalizeMimeType(mime) == MimeTypePCMU +} + +func IsMimeTypeStringRTX(mime string) bool { + return NormalizeMimeType(mime) == MimeTypeRTX +} + +func IsMimeTypeStringVP8(mime string) bool { + return NormalizeMimeType(mime) == MimeTypeVP8 +} + +func IsMimeTypeStringVP9(mime string) bool { + return NormalizeMimeType(mime) == MimeTypeVP9 +} + +func IsMimeTypeStringH264(mime string) bool { + return NormalizeMimeType(mime) == MimeTypeH264 +} diff --git a/livekit/pkg/sfu/pacer/base.go b/livekit/pkg/sfu/pacer/base.go new file mode 100644 index 0000000..96e3be2 --- /dev/null +++ b/livekit/pkg/sfu/pacer/base.go @@ -0,0 +1,138 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pacer + +import ( + "errors" + "io" + "time" + + "github.com/livekit/livekit-server/pkg/sfu/bwe" + "github.com/livekit/mediatransportutil" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/utils/mono" + "github.com/pion/rtp" + "go.uber.org/atomic" +) + +type Base struct { + logger logger.Logger + + bwe bwe.BWE + + lastPacketSentAt atomic.Int64 + + *ProbeObserver +} + +func NewBase(logger logger.Logger, bwe bwe.BWE) *Base { + return &Base{ + logger: logger, + bwe: bwe, + ProbeObserver: NewProbeObserver(logger), + } +} + +func (b *Base) SetInterval(_interval time.Duration) { +} + +func (b *Base) SetBitrate(_bitrate int) { +} + +func (b *Base) TimeSinceLastSentPacket() time.Duration { + return time.Duration(mono.UnixNano() - b.lastPacketSentAt.Load()) +} + +func (b *Base) SendPacket(p *Packet) (int, error) { + defer func() { + if p.HeaderPool != nil && p.Header != nil { + *p.Header = rtp.Header{} + p.HeaderPool.Put(p.Header) + } + + if p.Pool != nil && p.PoolEntity != nil { + p.Pool.Put(p.PoolEntity) + } + + *p = Packet{} + PacketFactory.Put(p) + }() + + err := b.patchRTPHeaderExtensions(p) + if err != nil { + b.logger.Errorw("patching rtp header extensions err", err) + return 0, err + } + + var written int + written, err = p.WriteStream.WriteRTP(p.Header, p.Payload) + if err != nil { + if !errors.Is(err, io.ErrClosedPipe) { + b.logger.Errorw("write rtp packet failed", err) + } + return 0, err + } + + return written, nil +} + +// patch just abs-send-time and transport-cc extensions if applicable +func (b *Base) patchRTPHeaderExtensions(p *Packet) error { + sendingAt := mono.Now() + if p.AbsSendTimeExtID != 0 { + absSendTimeExt := rtp.AbsSendTimeExtension{ + Timestamp: uint64(mediatransportutil.ToNtpTime(sendingAt) >> 14), + } + absSendTimeBytes, err := absSendTimeExt.Marshal() + if err != nil { + return err + } + + if err = p.Header.SetExtension(p.AbsSendTimeExtID, absSendTimeBytes); err != nil { + return err + } + + b.lastPacketSentAt.Store(sendingAt.UnixNano()) + } + + packetSize := p.HeaderSize + len(p.Payload) + if p.TransportWideExtID != 0 && b.bwe != nil { + twccSN := b.bwe.RecordPacketSendAndGetSequenceNumber( + sendingAt.UnixMicro(), + packetSize, + p.IsRTX, + p.ProbeClusterId, + p.IsProbe, + ) + twccExt := rtp.TransportCCExtension{ + TransportSequence: twccSN, + } + twccExtBytes, err := twccExt.Marshal() + if err != nil { + return err + } + + if err = p.Header.SetExtension(p.TransportWideExtID, twccExtBytes); err != nil { + return err + } + + b.lastPacketSentAt.Store(sendingAt.UnixNano()) + } + + b.ProbeObserver.RecordPacket(packetSize, p.IsRTX, p.ProbeClusterId, p.IsProbe) + return nil +} + +// ------------------------------------------------ diff --git a/livekit/pkg/sfu/pacer/leaky_bucket.go b/livekit/pkg/sfu/pacer/leaky_bucket.go new file mode 100644 index 0000000..db5c893 --- /dev/null +++ b/livekit/pkg/sfu/pacer/leaky_bucket.go @@ -0,0 +1,142 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pacer + +import ( + "sync" + "time" + + "github.com/frostbyte73/core" + "github.com/gammazero/deque" + "github.com/livekit/livekit-server/pkg/sfu/bwe" + "github.com/livekit/protocol/logger" +) + +const ( + maxOvershootFactor = 2.0 +) + +type LeakyBucket struct { + *Base + + logger logger.Logger + + lock sync.RWMutex + packets deque.Deque[*Packet] + interval time.Duration + bitrate int + stop core.Fuse +} + +func NewLeakyBucket(logger logger.Logger, bwe bwe.BWE, interval time.Duration, bitrate int) *LeakyBucket { + l := &LeakyBucket{ + Base: NewBase(logger, bwe), + logger: logger, + interval: interval, + bitrate: bitrate, + } + l.packets.SetBaseCap(512) + + go l.sendWorker() + return l +} + +func (l *LeakyBucket) SetInterval(interval time.Duration) { + l.lock.Lock() + defer l.lock.Unlock() + + l.interval = interval +} + +func (l *LeakyBucket) SetBitrate(bitrate int) { + l.lock.Lock() + defer l.lock.Unlock() + + l.bitrate = bitrate +} + +func (l *LeakyBucket) Stop() { + l.stop.Break() +} + +func (l *LeakyBucket) Enqueue(p *Packet) { + l.lock.Lock() + l.packets.PushBack(p) + l.lock.Unlock() +} + +func (l *LeakyBucket) sendWorker() { + l.lock.RLock() + interval := l.interval + bitrate := l.bitrate + l.lock.RUnlock() + + timer := time.NewTimer(interval) + overage := 0 + + for { + <-timer.C + + l.lock.RLock() + interval = l.interval + bitrate = l.bitrate + l.lock.RUnlock() + + // calculate number of bytes that can be sent in this interval + // adjusting for overage. + intervalBytes := int(interval.Seconds() * float64(bitrate) / 8.0) + maxOvershootBytes := int(float64(intervalBytes) * maxOvershootFactor) + toSendBytes := intervalBytes - overage + if toSendBytes < 0 { + // too much overage, wait for next interval + overage = -toSendBytes + timer.Reset(interval) + continue + } + + // do not allow too much overshoot in an interval + if toSendBytes > maxOvershootBytes { + toSendBytes = maxOvershootBytes + } + + for { + if l.stop.IsBroken() { + return + } + + l.lock.Lock() + if l.packets.Len() == 0 { + l.lock.Unlock() + // allow overshoot in next interval with shortage in this interval + overage = -toSendBytes + timer.Reset(interval) + break + } + p := l.packets.PopFront() + l.lock.Unlock() + + written, _ := l.Base.SendPacket(p) + toSendBytes -= written + if toSendBytes < 0 { + // overage, wait for next interval + overage = -toSendBytes + timer.Reset(interval) + break + } + } + } +} + +// ------------------------------------------------ diff --git a/livekit/pkg/sfu/pacer/no_queue.go b/livekit/pkg/sfu/pacer/no_queue.go new file mode 100644 index 0000000..7d78589 --- /dev/null +++ b/livekit/pkg/sfu/pacer/no_queue.go @@ -0,0 +1,90 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pacer + +import ( + "sync" + + "github.com/frostbyte73/core" + "github.com/gammazero/deque" + "github.com/livekit/livekit-server/pkg/sfu/bwe" + "github.com/livekit/protocol/logger" +) + +type NoQueue struct { + *Base + + logger logger.Logger + + lock sync.RWMutex + packets deque.Deque[*Packet] + wake chan struct{} + stop core.Fuse +} + +func NewNoQueue(logger logger.Logger, bwe bwe.BWE) *NoQueue { + n := &NoQueue{ + Base: NewBase(logger, bwe), + logger: logger, + wake: make(chan struct{}, 1), + } + n.packets.SetBaseCap(512) + + go n.sendWorker() + return n +} + +func (n *NoQueue) Stop() { + n.stop.Break() + + select { + case n.wake <- struct{}{}: + default: + } +} + +func (n *NoQueue) Enqueue(p *Packet) { + n.lock.Lock() + n.packets.PushBack(p) + n.lock.Unlock() + + select { + case n.wake <- struct{}{}: + default: + } +} + +func (n *NoQueue) sendWorker() { + for { + <-n.wake + for { + if n.stop.IsBroken() { + return + } + + n.lock.Lock() + if n.packets.Len() == 0 { + n.lock.Unlock() + break + } + p := n.packets.PopFront() + n.lock.Unlock() + + n.Base.SendPacket(p) + } + } +} + +// ------------------------------------------------ diff --git a/livekit/pkg/sfu/pacer/pacer.go b/livekit/pkg/sfu/pacer/pacer.go new file mode 100644 index 0000000..f232125 --- /dev/null +++ b/livekit/pkg/sfu/pacer/pacer.go @@ -0,0 +1,77 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pacer + +import ( + "sync" + "time" + + "github.com/livekit/livekit-server/pkg/sfu/ccutils" + "github.com/pion/rtp" + "github.com/pion/webrtc/v4" +) + +var ( + PacketFactory = &sync.Pool{ + New: func() any { + return &Packet{} + }, + } +) + +// -------------------------------------- + +type PacerBehavior string + +const ( + PacerBehaviorPassThrough PacerBehavior = "pass-through" + PacerBehaviorNoQueue PacerBehavior = "no-queue" + PacerBehaviorLeakybucket PacerBehavior = "leaky-bucket" +) + +type Packet struct { + Header *rtp.Header + HeaderPool *sync.Pool + HeaderSize int + Payload []byte + IsRTX bool + ProbeClusterId ccutils.ProbeClusterId + IsProbe bool + AbsSendTimeExtID uint8 + TransportWideExtID uint8 + WriteStream webrtc.TrackLocalWriter + Pool *sync.Pool + PoolEntity *[]byte +} + +type Pacer interface { + Enqueue(p *Packet) + Stop() + + SetInterval(interval time.Duration) + SetBitrate(bitrate int) + + TimeSinceLastSentPacket() time.Duration + + SetPacerProbeObserverListener(listener PacerProbeObserverListener) + StartProbeCluster(pci ccutils.ProbeClusterInfo) + EndProbeCluster(probeClusterId ccutils.ProbeClusterId) ccutils.ProbeClusterInfo +} + +type PacerProbeObserverListener interface { + OnPacerProbeObserverClusterComplete(probeClusterId ccutils.ProbeClusterId) +} + +// ------------------------------------------------ diff --git a/livekit/pkg/sfu/pacer/pass_through.go b/livekit/pkg/sfu/pacer/pass_through.go new file mode 100644 index 0000000..21fd07e --- /dev/null +++ b/livekit/pkg/sfu/pacer/pass_through.go @@ -0,0 +1,39 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pacer + +import ( + "github.com/livekit/livekit-server/pkg/sfu/bwe" + "github.com/livekit/protocol/logger" +) + +type PassThrough struct { + *Base +} + +func NewPassThrough(logger logger.Logger, bwe bwe.BWE) *PassThrough { + return &PassThrough{ + Base: NewBase(logger, bwe), + } +} + +func (p *PassThrough) Stop() { +} + +func (p *PassThrough) Enqueue(pkt *Packet) { + p.Base.SendPacket(pkt) +} + +// ------------------------------------------------ diff --git a/livekit/pkg/sfu/pacer/probe_observer.go b/livekit/pkg/sfu/pacer/probe_observer.go new file mode 100644 index 0000000..ff51b81 --- /dev/null +++ b/livekit/pkg/sfu/pacer/probe_observer.go @@ -0,0 +1,143 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pacer + +import ( + "sync" + "time" + + "go.uber.org/atomic" + + "github.com/livekit/livekit-server/pkg/sfu/ccutils" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/utils/mono" +) + +type ProbeObserver struct { + logger logger.Logger + + listener PacerProbeObserverListener + + isInProbe atomic.Bool + + lock sync.Mutex + pci ccutils.ProbeClusterInfo +} + +func NewProbeObserver(logger logger.Logger) *ProbeObserver { + return &ProbeObserver{ + logger: logger, + } +} + +func (po *ProbeObserver) SetPacerProbeObserverListener(listener PacerProbeObserverListener) { + po.listener = listener +} + +func (po *ProbeObserver) StartProbeCluster(pci ccutils.ProbeClusterInfo) { + if po.isInProbe.Load() { + po.logger.Warnw( + "ignoring start of a new probe cluster when already active", nil, + "probeClusterInfo", pci, + ) + return + } + + po.lock.Lock() + defer po.lock.Unlock() + + po.pci = pci + po.pci.Result = ccutils.ProbeClusterResult{ + StartTime: mono.UnixNano(), + } + + po.isInProbe.Store(true) +} + +func (po *ProbeObserver) EndProbeCluster(probeClusterId ccutils.ProbeClusterId) ccutils.ProbeClusterInfo { + if !po.isInProbe.Load() { + // probe not active + if probeClusterId != ccutils.ProbeClusterIdInvalid { + po.logger.Debugw( + "ignoring end of a probe cluster when not active", + "probeClusterId", probeClusterId, + ) + } + return ccutils.ProbeClusterInfoInvalid + } + + po.lock.Lock() + defer po.lock.Unlock() + + if po.pci.Id != probeClusterId { + // probe cluster id not active + po.logger.Warnw( + "ignoring end of a probe cluster of a non-active one", nil, + "probeClusterId", probeClusterId, + "active", po.pci.Id, + ) + return ccutils.ProbeClusterInfoInvalid + } + + if po.pci.Result.EndTime == 0 { + po.pci.Result.EndTime = mono.UnixNano() + } + + po.isInProbe.Store(false) + + return po.pci +} + +func (po *ProbeObserver) RecordPacket(size int, isRTX bool, probeClusterId ccutils.ProbeClusterId, isProbe bool) { + if !po.isInProbe.Load() { + return + } + + po.lock.Lock() + if probeClusterId != po.pci.Id || po.pci.Result.EndTime != 0 { + po.lock.Unlock() + return + } + + if isProbe { + po.pci.Result.PacketsProbe++ + po.pci.Result.BytesProbe += size + } else { + if isRTX { + po.pci.Result.PacketsNonProbeRTX++ + po.pci.Result.BytesNonProbeRTX += size + } else { + po.pci.Result.PacketsNonProbePrimary++ + po.pci.Result.BytesNonProbePrimary += size + } + } + + notify := false + var clusterId ccutils.ProbeClusterId + if po.pci.Result.EndTime == 0 && ((po.pci.Result.Bytes() >= po.pci.Goal.DesiredBytes) && time.Duration(mono.UnixNano()-po.pci.Result.StartTime) >= po.pci.Goal.Duration) { + po.pci.Result.EndTime = mono.UnixNano() + po.pci.Result.IsCompleted = true + + notify = true + clusterId = po.pci.Id + } + po.lock.Unlock() + + if notify && po.listener != nil { + po.listener.OnPacerProbeObserverClusterComplete(clusterId) + } +} + +// ------------------------------------------------ diff --git a/livekit/pkg/sfu/playoutdelay.go b/livekit/pkg/sfu/playoutdelay.go new file mode 100644 index 0000000..42217f7 --- /dev/null +++ b/livekit/pkg/sfu/playoutdelay.go @@ -0,0 +1,202 @@ +// Copyright 2024 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sfu + +import ( + "sync" + "time" + + pd "github.com/livekit/livekit-server/pkg/sfu/rtpextension/playoutdelay" + "github.com/livekit/livekit-server/pkg/sfu/rtpstats" + "github.com/livekit/protocol/logger" + "go.uber.org/atomic" + "go.uber.org/zap/zapcore" +) + +const ( + jitterMultiToDelay = 10 + targetDelayLogThreshold = 500 + + // limit max delay change to make it smoother for a/v sync + maxDelayChangePerSec = 80 +) + +// ---------------------------------------------------- + +type PlayoutDelayState int32 + +const ( + PlayoutDelayStateChanged PlayoutDelayState = iota + PlayoutDelaySending + PlayoutDelayAcked +) + +func (s PlayoutDelayState) String() string { + switch s { + case PlayoutDelayStateChanged: + return "StateChanged" + case PlayoutDelaySending: + return "Sending" + case PlayoutDelayAcked: + return "Acked" + } + return "Unknown" +} + +// ---------------------------------------------------- + +type PlayoutDelayControllerState struct { + SenderSnapshotID uint32 +} + +func (p PlayoutDelayControllerState) MarshalLogObject(e zapcore.ObjectEncoder) error { + e.AddUint32("SenderSnapshotID", p.SenderSnapshotID) + return nil +} + +// ---------------------------------------------------- + +type PlayoutDelayController struct { + lock sync.Mutex + state atomic.Int32 + minDelay, maxDelay uint32 + currentDelay uint32 + extBytes atomic.Value //[]byte + sendingAtSeq uint16 + sendingAtTime time.Time + logger logger.Logger + rtpStats *rtpstats.RTPStatsSender + senderSnapshotID uint32 + + highDelayCount atomic.Uint32 +} + +func NewPlayoutDelayController(minDelay, maxDelay uint32, logger logger.Logger, rtpStats *rtpstats.RTPStatsSender) (*PlayoutDelayController, error) { + if maxDelay == 0 && minDelay > 0 { + maxDelay = pd.MaxPlayoutDelayDefault + } + if maxDelay > pd.PlayoutDelayMaxValue { + maxDelay = pd.PlayoutDelayMaxValue + } + c := &PlayoutDelayController{ + currentDelay: minDelay, + minDelay: minDelay, + maxDelay: maxDelay, + logger: logger, + rtpStats: rtpStats, + senderSnapshotID: rtpStats.NewSenderSnapshotId(), + } + return c, c.createExtData() +} + +func (c *PlayoutDelayController) GetState() PlayoutDelayControllerState { + c.lock.Lock() + defer c.lock.Unlock() + + return PlayoutDelayControllerState{ + SenderSnapshotID: c.senderSnapshotID, + } +} + +func (c *PlayoutDelayController) SeedState(pdcs PlayoutDelayControllerState) { + c.lock.Lock() + defer c.lock.Unlock() + + c.senderSnapshotID = pdcs.SenderSnapshotID +} + +func (c *PlayoutDelayController) SetJitter(jitter uint32) { + c.lock.Lock() + deltaInfoSender, _ := c.rtpStats.DeltaInfoSender(c.senderSnapshotID) + var nackPercent uint32 + if deltaInfoSender != nil && deltaInfoSender.Packets > 0 { + nackPercent = deltaInfoSender.Nacks * 100 / deltaInfoSender.Packets + } + + targetDelay := jitter * jitterMultiToDelay + if nackPercent > 60 { + targetDelay += (nackPercent - 60) * 2 + } + + elapsed := time.Since(c.sendingAtTime) + delayChangeLimit := uint32(maxDelayChangePerSec * elapsed.Seconds()) + if delayChangeLimit > maxDelayChangePerSec { + delayChangeLimit = maxDelayChangePerSec + } + + if targetDelay > c.currentDelay+delayChangeLimit { + targetDelay = c.currentDelay + delayChangeLimit + } else if c.currentDelay > targetDelay+delayChangeLimit { + targetDelay = c.currentDelay - delayChangeLimit + } + if targetDelay < c.minDelay { + targetDelay = c.minDelay + } + if targetDelay > c.maxDelay { + targetDelay = c.maxDelay + } + if c.currentDelay == targetDelay { + c.lock.Unlock() + return + } + if targetDelay > targetDelayLogThreshold { + if c.highDelayCount.Add(1)%100 == 1 { + c.logger.Infow("high playout delay", "target", targetDelay, "jitter", jitter, "nackPercent", nackPercent, "current", c.currentDelay) + } + } + c.currentDelay = targetDelay + c.lock.Unlock() + c.createExtData() +} + +func (c *PlayoutDelayController) OnSeqAcked(seq uint16) { + c.lock.Lock() + defer c.lock.Unlock() + if PlayoutDelayState(c.state.Load()) == PlayoutDelaySending && (seq-c.sendingAtSeq) < 0x8000 { + c.state.Store(int32(PlayoutDelayAcked)) + } +} + +func (c *PlayoutDelayController) GetDelayExtension(seq uint16) []byte { + switch PlayoutDelayState(c.state.Load()) { + case PlayoutDelayStateChanged: + c.lock.Lock() + c.state.Store(int32(PlayoutDelaySending)) + c.sendingAtSeq = seq + c.sendingAtTime = time.Now() + c.lock.Unlock() + return c.extBytes.Load().([]byte) + case PlayoutDelaySending: + return c.extBytes.Load().([]byte) + case PlayoutDelayAcked: + return nil + } + return nil +} + +func (c *PlayoutDelayController) createExtData() error { + delay := pd.PlayoutDelayFromValue( + uint16(c.currentDelay), + uint16(c.maxDelay), + ) + b, err := delay.Marshal() + if err == nil { + c.extBytes.Store(b) + c.state.Store(int32(PlayoutDelayStateChanged)) + } else { + c.logger.Errorw("failed to marshal playout delay", err, "playoutDelay", delay) + } + return err +} diff --git a/livekit/pkg/sfu/playoutdelay_test.go b/livekit/pkg/sfu/playoutdelay_test.go new file mode 100644 index 0000000..419f473 --- /dev/null +++ b/livekit/pkg/sfu/playoutdelay_test.go @@ -0,0 +1,79 @@ +// Copyright 2024 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sfu + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" + + pd "github.com/livekit/livekit-server/pkg/sfu/rtpextension/playoutdelay" + "github.com/livekit/livekit-server/pkg/sfu/rtpstats" + "github.com/livekit/protocol/logger" +) + +func TestPlayoutDelay(t *testing.T) { + stats := rtpstats.NewRTPStatsSender(rtpstats.RTPStatsParams{ClockRate: 900000, Logger: logger.GetLogger()}, 128) + c, err := NewPlayoutDelayController(100, 120, logger.GetLogger(), stats) + require.NoError(t, err) + + ext := c.GetDelayExtension(100) + playoutDelayEqual(t, ext, 100, 120) + + ext = c.GetDelayExtension(105) + playoutDelayEqual(t, ext, 100, 120) + + // seq acked before delay changed + c.OnSeqAcked(65534) + ext = c.GetDelayExtension(105) + playoutDelayEqual(t, ext, 100, 120) + + c.OnSeqAcked(90) + ext = c.GetDelayExtension(105) + playoutDelayEqual(t, ext, 100, 120) + + // seq acked, no extension sent for new packet + c.OnSeqAcked(103) + ext = c.GetDelayExtension(106) + require.Nil(t, ext) + + // delay on change(can't go below min), no extension sent + c.SetJitter(0) + ext = c.GetDelayExtension(107) + require.Nil(t, ext) + + // delay changed, generate new extension to send + time.Sleep(200 * time.Millisecond) + c.SetJitter(50) + t.Log(c.currentDelay, c.state.Load()) + ext = c.GetDelayExtension(108) + var delay pd.PlayOutDelay + require.NoError(t, delay.Unmarshal(ext)) + require.Greater(t, delay.Min, uint16(100)) + + // can't go above max + time.Sleep(200 * time.Millisecond) + c.SetJitter(10000) + ext = c.GetDelayExtension(109) + playoutDelayEqual(t, ext, 120, 120) +} + +func playoutDelayEqual(t *testing.T, data []byte, min, max uint16) { + var delay pd.PlayOutDelay + require.NoError(t, delay.Unmarshal(data)) + require.Equal(t, min, delay.Min) + require.Equal(t, max, delay.Max) +} diff --git a/livekit/pkg/sfu/receiver.go b/livekit/pkg/sfu/receiver.go new file mode 100644 index 0000000..5a1482e --- /dev/null +++ b/livekit/pkg/sfu/receiver.go @@ -0,0 +1,352 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sfu + +import ( + "strings" + "sync" + "time" + + "github.com/pion/rtcp" + "github.com/pion/webrtc/v4" + + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + + "github.com/livekit/livekit-server/pkg/sfu/buffer" + "github.com/livekit/livekit-server/pkg/sfu/connectionquality" + "github.com/livekit/livekit-server/pkg/sfu/mime" + "github.com/livekit/livekit-server/pkg/sfu/rtpstats" +) + +var _ TrackReceiver = (*WebRTCReceiver)(nil) + +// WebRTCReceiver receives a media track +type WebRTCReceiver struct { + *ReceiverBase + + receiver *webrtc.RTPReceiver + onCloseHandler func() + + onRTCP func([]rtcp.Packet) + + upTracksMu sync.Mutex + upTracks [buffer.DefaultMaxLayerSpatial + 1]TrackRemote + + connectionStats *connectionquality.ConnectionStats + onStatsUpdate func(w *WebRTCReceiver, stat *livekit.AnalyticsStat) +} + +type ReceiverOpts func(w *WebRTCReceiver) *WebRTCReceiver + +// WithPliThrottleConfig indicates minimum time(ms) between sending PLIs +func WithPliThrottleConfig(pliThrottleConfig PLIThrottleConfig) ReceiverOpts { + return func(w *WebRTCReceiver) *WebRTCReceiver { + w.ReceiverBase.SetPLIThrottleConfig(pliThrottleConfig) + return w + } +} + +// WithAudioConfig sets up parameters for active speaker detection +func WithAudioConfig(audioConfig AudioConfig) ReceiverOpts { + return func(w *WebRTCReceiver) *WebRTCReceiver { + w.ReceiverBase.SetAudioConfig(audioConfig) + return w + } +} + +func WithEnableRTPStreamRestartDetection(enable bool) ReceiverOpts { + return func(w *WebRTCReceiver) *WebRTCReceiver { + w.ReceiverBase.SetEnableRTPStreamRestartDetection(enable) + return w + } +} + +// WithLoadBalanceThreshold enables parallelization of packet writes when downTracks exceeds threshold +// Value should be between 3 and 150. +// For a server handling a few large rooms, use a smaller value (required to handle very large (250+ participant) rooms). +// For a server handling many small rooms, use a larger value or disable. +// Set to 0 (disabled) by default. +func WithLoadBalanceThreshold(downTracks int) ReceiverOpts { + return func(w *WebRTCReceiver) *WebRTCReceiver { + w.ReceiverBase.SetLBThreshold(downTracks) + return w + } +} + +func WithForwardStats(forwardStats *ForwardStats) ReceiverOpts { + return func(w *WebRTCReceiver) *WebRTCReceiver { + w.ReceiverBase.SetForwardStats(forwardStats) + return w + } +} + +// NewWebRTCReceiver creates a new webrtc track receiver +func NewWebRTCReceiver( + receiver *webrtc.RTPReceiver, + track TrackRemote, + trackInfo *livekit.TrackInfo, + logger logger.Logger, + onRTCP func([]rtcp.Packet), + streamTrackerManagerConfig StreamTrackerManagerConfig, + opts ...ReceiverOpts, +) *WebRTCReceiver { + w := &WebRTCReceiver{ + receiver: receiver, + onRTCP: onRTCP, + } + + w.ReceiverBase = NewReceiverBase( + ReceiverBaseParams{ + TrackID: livekit.TrackID(track.ID()), + StreamID: track.StreamID(), + Kind: track.Kind(), + Codec: track.Codec(), + HeaderExtensions: receiver.GetParameters().HeaderExtensions, + Logger: logger, + StreamTrackerManagerConfig: streamTrackerManagerConfig, + StreamTrackerManagerListener: w, + IsSelfClosing: true, + OnClosed: w.onClosed, + }, + trackInfo, + ReceiverCodecStateNormal, + ) + + for _, opt := range opts { + w = opt(w) + } + + w.connectionStats = connectionquality.NewConnectionStats(connectionquality.ConnectionStatsParams{ + ReceiverProvider: w, + Logger: logger.WithValues("direction", "up"), + }) + w.connectionStats.OnStatsUpdate(func(_cs *connectionquality.ConnectionStats, stat *livekit.AnalyticsStat) { + if w.onStatsUpdate != nil { + w.onStatsUpdate(w, stat) + } + }) + codec := track.Codec() + w.connectionStats.Start( + mime.NormalizeMimeType(codec.MimeType), + // TODO: technically not correct to declare FEC on when RED. Need the primary codec's fmtp line to check. + mime.IsMimeTypeStringRED(codec.MimeType) || strings.Contains(strings.ToLower(codec.SDPFmtpLine), "useinbandfec=1"), + ) + + return w +} + +func (w *WebRTCReceiver) OnStatsUpdate(fn func(w *WebRTCReceiver, stat *livekit.AnalyticsStat)) { + w.onStatsUpdate = fn +} + +func (w *WebRTCReceiver) GetConnectionScoreAndQuality() (float32, livekit.ConnectionQuality) { + return w.connectionStats.GetScoreAndQuality() +} + +func (w *WebRTCReceiver) ssrc(layer int) uint32 { + if track := w.upTracks[layer]; track != nil { + return uint32(track.SSRC()) + } + return 0 +} + +func (w *WebRTCReceiver) AddUpTrack(track TrackRemote, buff *buffer.Buffer) error { + if w.isClosed.Load() { + return ErrReceiverClosed + } + + layer := int32(0) + if w.Kind() == webrtc.RTPCodecTypeVideo && w.videoLayerMode != livekit.VideoLayer_MULTIPLE_SPATIAL_LAYERS_PER_STREAM { + layer = buffer.GetSpatialLayerForRid(w.Mime(), track.RID(), w.ReceiverBase.TrackInfo()) + } + if layer < 0 { + w.ReceiverBase.Logger().Warnw( + "invalid layer", nil, + "rid", track.RID(), + "trackInfo", logger.Proto(w.ReceiverBase.TrackInfo()), + ) + return ErrInvalidLayer + } + + w.upTracksMu.Lock() + if w.upTracks[layer] != nil { + w.upTracksMu.Unlock() + return ErrDuplicateLayer + } + w.upTracks[layer] = track + w.upTracksMu.Unlock() + + w.ReceiverBase.AddBuffer(buff, layer) + buff.OnRtcpFeedback(w.sendRTCP) + w.ReceiverBase.StartBuffer(buff, layer) + return nil +} + +func (w *WebRTCReceiver) SetUpTrackPaused(paused bool) { + w.ReceiverBase.SetUpTrackPaused(paused) + + w.connectionStats.UpdateMute(paused) +} + +func (w *WebRTCReceiver) notifyMaxExpectedLayer(layer int32) { + ti := w.TrackInfo() + if ti == nil { + return + } + + if w.Kind() == webrtc.RTPCodecTypeAudio || ti.Source == livekit.TrackSource_SCREEN_SHARE { + // screen share tracks have highly variable bitrate, do not use bit rate based quality for those + return + } + + expectedBitrate := int64(0) + for _, vl := range buffer.GetVideoLayersForMimeType(w.Mime(), ti) { + if vl.SpatialLayer <= layer { + expectedBitrate += int64(vl.Bitrate) + } + } + + w.connectionStats.AddBitrateTransition(expectedBitrate) +} + +func (w *WebRTCReceiver) SetMaxExpectedSpatialLayer(layer int32) { + w.ReceiverBase.SetMaxExpectedSpatialLayer(layer) + + w.notifyMaxExpectedLayer(layer) + + if layer == buffer.InvalidLayerSpatial { + w.connectionStats.UpdateLayerMute(true) + } else { + w.connectionStats.UpdateLayerMute(false) + w.connectionStats.AddLayerTransition(w.ReceiverBase.StreamTrackerManager().DistanceToDesired()) + } +} + +// StreamTrackerManagerListener.OnAvailableLayersChanged +func (w *WebRTCReceiver) OnAvailableLayersChanged() { + w.connectionStats.AddLayerTransition(w.ReceiverBase.StreamTrackerManager().DistanceToDesired()) +} + +// StreamTrackerManagerListener.OnBitrateAvailabilityChanged +func (w *WebRTCReceiver) OnBitrateAvailabilityChanged() { +} + +// StreamTrackerManagerListener.OnMaxPublishedLayerChanged +func (w *WebRTCReceiver) OnMaxPublishedLayerChanged(maxPublishedLayer int32) { + w.notifyMaxExpectedLayer(maxPublishedLayer) + w.connectionStats.AddLayerTransition(w.ReceiverBase.StreamTrackerManager().DistanceToDesired()) +} + +// StreamTrackerManagerListener.OnMaxTemporalLayerSeenChanged +func (w *WebRTCReceiver) OnMaxTemporalLayerSeenChanged(maxTemporalLayerSeen int32) { + w.connectionStats.AddLayerTransition(w.ReceiverBase.StreamTrackerManager().DistanceToDesired()) +} + +// StreamTrackerManagerListener.OnMaxAvailableLayerChanged +func (w *WebRTCReceiver) OnMaxAvailableLayerChanged(maxAvailableLayer int32) { +} + +// StreamTrackerManagerListener.OnBitrateReport +func (w *WebRTCReceiver) OnBitrateReport(availableLayers []int32, bitrates Bitrates) { + w.connectionStats.AddLayerTransition(w.ReceiverBase.StreamTrackerManager().DistanceToDesired()) +} + +// OnCloseHandler method to be called on remote track removed +func (w *WebRTCReceiver) OnCloseHandler(fn func()) { + w.onCloseHandler = fn +} + +func (w *WebRTCReceiver) sendRTCP(packets []rtcp.Packet) { + if packets == nil || w.isClosed.Load() { + return + } + + if w.onRTCP != nil { + w.onRTCP(packets) + } +} + +func (w *WebRTCReceiver) GetDeltaStats() map[uint32]*buffer.StreamStatsWithLayers { + buffers := w.ReceiverBase.GetAllBuffers() + deltaStats := make(map[uint32]*buffer.StreamStatsWithLayers, len(buffers)) + for layer, buff := range buffers { + if buff == nil { + continue + } + + sswl := buff.GetDeltaStats() + if sswl == nil { + continue + } + + // patch buffer stats with correct layer + patched := make(map[int32]*rtpstats.RTPDeltaInfo, 1) + patched[int32(layer)] = sswl.Layers[0] + sswl.Layers = patched + + deltaStats[w.ssrc(layer)] = sswl + } + + return deltaStats +} + +func (w *WebRTCReceiver) GetLastSenderReportTime() time.Time { + buffers := w.ReceiverBase.GetAllBuffers() + latestSRTime := time.Time{} + for _, buff := range buffers { + if buff == nil { + continue + } + + srAt := buff.GetLastSenderReportTime() + if srAt.After(latestSRTime) { + latestSRTime = srAt + } + } + + return latestSRTime +} + +func (w *WebRTCReceiver) onClosed() { + w.connectionStats.Close() + + if w.onCloseHandler != nil { + w.onCloseHandler() + } +} + +func (w *WebRTCReceiver) DebugInfo() map[string]any { + info := w.ReceiverBase.DebugInfo() + + w.upTracksMu.Lock() + upTrackInfo := make([]map[string]any, 0, len(w.upTracks)) + for layer, ut := range w.upTracks { + if ut != nil { + upTrackInfo = append(upTrackInfo, map[string]any{ + "Layer": layer, + "SSRC": ut.SSRC(), + "Msid": ut.Msid(), + "RID": ut.RID(), + }) + } + } + w.upTracksMu.Unlock() + info["UpTracks"] = upTrackInfo + + return info +} + +// ----------------------------------------------------------- diff --git a/livekit/pkg/sfu/receiver_base.go b/livekit/pkg/sfu/receiver_base.go new file mode 100644 index 0000000..e844780 --- /dev/null +++ b/livekit/pkg/sfu/receiver_base.go @@ -0,0 +1,1222 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sfu + +import ( + "errors" + "fmt" + "io" + "slices" + "strings" + "sync" + "time" + + "github.com/pion/webrtc/v4" + "go.uber.org/atomic" + + "github.com/livekit/mediatransportutil/pkg/bucket" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/utils" + "github.com/livekit/protocol/utils/mono" + + "github.com/livekit/livekit-server/pkg/sfu/audio" + "github.com/livekit/livekit-server/pkg/sfu/buffer" + "github.com/livekit/livekit-server/pkg/sfu/mime" + "github.com/livekit/livekit-server/pkg/sfu/rtpstats" + "github.com/livekit/livekit-server/pkg/sfu/streamtracker" + sfuutils "github.com/livekit/livekit-server/pkg/sfu/utils" +) + +var ( + ErrReceiverClosed = errors.New("receiver closed") + ErrDownTrackAlreadyExist = errors.New("DownTrack already exist") + ErrDuplicateLayer = errors.New("duplicate layer") + ErrInvalidLayer = errors.New("invalid layer") +) + +// -------------------------------------- + +type PLIThrottleConfig struct { + LowQuality time.Duration `yaml:"low_quality,omitempty"` + MidQuality time.Duration `yaml:"mid_quality,omitempty"` + HighQuality time.Duration `yaml:"high_quality,omitempty"` +} + +var ( + DefaultPLIThrottleConfig = PLIThrottleConfig{ + LowQuality: 500 * time.Millisecond, + MidQuality: time.Second, + HighQuality: time.Second, + } +) + +// -------------------------------------- + +type AudioConfig struct { + audio.AudioLevelConfig `yaml:",inline"` + + // enable red encoding downtrack for opus only audio up track + ActiveREDEncoding bool `yaml:"active_red_encoding,omitempty"` + // enable proxying weakest subscriber loss to publisher in RTCP Receiver Report + EnableLossProxying bool `yaml:"enable_loss_proxying,omitempty"` +} + +var ( + DefaultAudioConfig = AudioConfig{ + AudioLevelConfig: audio.DefaultAudioLevelConfig, + } +) + +// -------------------------------------- + +type Bitrates [buffer.DefaultMaxLayerSpatial + 1][buffer.DefaultMaxLayerTemporal + 1]int64 + +// -------------------------------------- + +type ReceiverCodecState int + +const ( + ReceiverCodecStateNormal ReceiverCodecState = iota + ReceiverCodecStateSuspended + ReceiverCodecStateInvalid +) + +// -------------------------------------- + +// TrackReceiver defines an interface receive media from remote peer +type TrackReceiver interface { + TrackID() livekit.TrackID + StreamID() string + + // returns the initial codec of the receiver, it is determined by the track's codec + // and will not change if the codec changes during the session (publisher changes codec) + Codec() webrtc.RTPCodecParameters + Mime() mime.MimeType + VideoLayerMode() livekit.VideoLayer_Mode + HeaderExtensions() []webrtc.RTPHeaderExtensionParameter + IsClosed() bool + + ReadRTP(buf []byte, layer uint8, esn uint64) (int, error) + GetLayeredBitrate() ([]int32, Bitrates) + + GetAudioLevel() (float64, bool) + + SendPLI(layer int32, force bool) + + SetUpTrackPaused(paused bool) + SetMaxExpectedSpatialLayer(layer int32) + + AddDownTrack(track TrackSender) error + DeleteDownTrack(participantID livekit.ParticipantID) + GetDownTracks() []TrackSender + + DebugInfo() map[string]any + + TrackInfo() *livekit.TrackInfo + UpdateTrackInfo(ti *livekit.TrackInfo) + + // Get primary receiver if this receiver represents a RED codec; otherwise it will return itself + GetPrimaryReceiverForRed() TrackReceiver + + // Get red receiver for primary codec, used by forward red encodings for opus only codec + GetRedReceiver() TrackReceiver + + GetTemporalLayerFpsForSpatial(layer int32) []float32 + + GetTrackStats() *livekit.RTPStats + + // AddOnReady adds a function to be called when the receiver is ready, the callback + // could be called immediately if the receiver is ready when the callback is added + AddOnReady(func()) + + AddOnCodecStateChange(func(webrtc.RTPCodecParameters, ReceiverCodecState)) + CodecState() ReceiverCodecState + + // VideoSizes returns the video size parsed from rtp packet for each spatial layer. + VideoSizes() []buffer.VideoSize + + // closes all associated buffers and issues a resync to all attached downtracks so that + // they can resync and have proper sequncing without gaps in sequence numbers / timestamps + Restart(reason string) +} + +// -------------------------------------- + +type REDTransformer interface { + TrackReceiver + ForwardRTP(pkt *buffer.ExtPacket, spatialLayer int32) int32 + ForwardRTCPSenderReport( + payloadType webrtc.PayloadType, + layer int32, + publisherSRData *livekit.RTCPSenderReportState, + ) + GetDownTracks() []TrackSender + ResyncDownTracks() + OnStreamRestart() + CanClose() bool + Close() +} + +// -------------------------------------- + +type ReceiverBaseParams struct { + TrackID livekit.TrackID + StreamID string + Kind webrtc.RTPCodecType + Codec webrtc.RTPCodecParameters + HeaderExtensions []webrtc.RTPHeaderExtensionParameter + Logger logger.Logger + StreamTrackerManagerConfig StreamTrackerManagerConfig + StreamTrackerManagerListener StreamTrackerManagerListener + IsSelfClosing bool + OnClosed func() +} + +type ReceiverBase struct { + params ReceiverBaseParams + + pliThrottleConfig PLIThrottleConfig + audioConfig AudioConfig + enableRTPStreamRestartDetection bool + lbThreshold int + forwardStats *ForwardStats + + codecStateLock sync.Mutex + codecState ReceiverCodecState + onCodecStateChange []func(webrtc.RTPCodecParameters, ReceiverCodecState) + + isRED bool + videoLayerMode livekit.VideoLayer_Mode + + bufferMu sync.RWMutex + buffers [buffer.DefaultMaxLayerSpatial + 1]buffer.BufferProvider + trackInfo *livekit.TrackInfo + + videoSizeMu sync.RWMutex + videoSizes [buffer.DefaultMaxLayerSpatial + 1]buffer.VideoSize + onVideoSizeChanged func() + + rtt uint32 + + streamTrackerManager *StreamTrackerManager + + downTrackSpreader *sfuutils.DownTrackSpreader[TrackSender] + + onMaxLayerChange func(mimeType mime.MimeType, maxLayer int32) + + redTransformer atomic.Pointer[REDTransformer] + + forwardersGeneration atomic.Uint32 + forwardersWaitGroup *sync.WaitGroup + restartInProgress bool + + isClosed atomic.Bool +} + +func NewReceiverBase(params ReceiverBaseParams, trackInfo *livekit.TrackInfo, codecState ReceiverCodecState) *ReceiverBase { + r := &ReceiverBase{ + params: params, + codecState: codecState, + isRED: mime.IsMimeTypeStringRED(params.Codec.MimeType), + trackInfo: utils.CloneProto(trackInfo), + videoLayerMode: buffer.GetVideoLayerModeForMimeType(mime.NormalizeMimeType(params.Codec.MimeType), trackInfo), + } + + r.downTrackSpreader = sfuutils.NewDownTrackSpreader[TrackSender](sfuutils.DownTrackSpreaderParams{ + Threshold: r.lbThreshold, + Logger: params.Logger, + }) + + r.streamTrackerManager = NewStreamTrackerManager( + params.Logger, + trackInfo, + r.Mime(), + r.params.Codec.ClockRate, + params.StreamTrackerManagerConfig, + ) + r.streamTrackerManager.SetListener(r) + + r.startForwarderGeneration() + + return r +} + +func (r *ReceiverBase) Close(reason string, clearBuffers bool) { + if r.isClosed.Swap(true) { + return + } + + if clearBuffers { + r.ClearAllBuffers(reason) + } + r.streamTrackerManager.Close() + + closeTrackSenders(r.downTrackSpreader.ResetAndGetDownTracks()) + + if rt := r.loadREDTransformer(); rt != nil { + rt.Close() + } + + if r.params.OnClosed != nil { + r.params.OnClosed() + } +} + +func (r *ReceiverBase) CanClose() bool { + if r.IsClosed() { + return true + } + + if r.downTrackSpreader.DownTrackCount() != 0 { + return false + } + + if rt := r.loadREDTransformer(); rt != nil { + return rt.CanClose() + } + + return true +} + +func (r *ReceiverBase) SetPLIThrottleConfig(pliThrottleConfig PLIThrottleConfig) { + r.pliThrottleConfig = pliThrottleConfig +} + +func (r *ReceiverBase) SetAudioConfig(audioConfig AudioConfig) { + r.audioConfig = audioConfig +} + +func (r *ReceiverBase) SetEnableRTPStreamRestartDetection(enableRTPStremRestartDetection bool) { + r.enableRTPStreamRestartDetection = enableRTPStremRestartDetection +} + +func (r *ReceiverBase) SetLBThreshold(lbThreshold int) { + r.lbThreshold = lbThreshold +} + +func (r *ReceiverBase) SetForwardStats(forwardStats *ForwardStats) { + r.forwardStats = forwardStats +} + +func (r *ReceiverBase) Logger() logger.Logger { + return r.params.Logger +} + +func (r *ReceiverBase) TrackInfo() *livekit.TrackInfo { + r.bufferMu.RLock() + defer r.bufferMu.RUnlock() + + return utils.CloneProto(r.trackInfo) +} + +func (r *ReceiverBase) UpdateTrackInfo(ti *livekit.TrackInfo) { + r.bufferMu.Lock() + existingVersion := utils.TimedVersionFromProto(r.trackInfo.Version) + updateVersion := utils.TimedVersionFromProto(ti.Version) + if updateVersion.Compare(existingVersion) < 0 { + r.bufferMu.Unlock() + r.params.Logger.Debugw( + "not updating to older version", + "existing", logger.Proto(r.trackInfo), + "updated", logger.Proto(ti), + ) + return + } + + shouldResync := utils.TimedVersionFromProto(r.trackInfo.Version) != utils.TimedVersionFromProto(ti.Version) + if shouldResync { + r.params.Logger.Debugw( + "updating track info", + "existing", logger.Proto(r.trackInfo), + "updated", logger.Proto(ti), + "shouldResync", shouldResync, + ) + } + r.trackInfo = utils.CloneProto(ti) + // MUTABLE-TRACKINFO-TODO: notify buffers, buffers may need to resize retransmission buffer if there is layer change + r.bufferMu.Unlock() + + r.streamTrackerManager.UpdateTrackInfo(ti) + + if shouldResync { + r.Restart("update-track-info") + } +} + +func (r *ReceiverBase) Restart(reason string) { + r.params.Logger.Infow("restarting receiver", "reason", reason) + r.restartInternal(reason, false) +} + +func (r *ReceiverBase) restartInternal(reason string, isDetected bool) { + if r.IsClosed() { + return + } + + // 1. guard against concurrent restarts + r.bufferMu.Lock() + if r.restartInProgress { + r.bufferMu.Unlock() + return + } + r.restartInProgress = true + r.bufferMu.Unlock() + + // 2. restart all the buffers + // if a stream was detected, skip external restart + // + // NOTE: The case of external restart and detected restart (which usually comes from one buffer) + // racing will miss restart on all buffers if detected restart from one buffer adds the guard + // against concurrent restart. But, that condition should be very rare if at all. + // External restart happens when the underlying track changes or when seeking + if !isDetected { + for _, buff := range r.GetAllBuffers() { + if buff == nil { + continue + } + + buff.RestartStream(reason) + } + } + + // 3. wait for the forwarders to finish + r.stopForwarderGeneration() + + // 4. reset stream tracker + r.streamTrackerManager.RemoveAllTrackers() + + // 5. signal attached downtracks to resync so that they can have proper sequencing on a receiver restart + r.downTrackSpreader.Broadcast(func(dt TrackSender) { + dt.ReceiverRestart() + }) + if rt := r.loadREDTransformer(); rt != nil { + rt.OnStreamRestart() + } + + // 6. move forwarder generation ahead + r.startForwarderGeneration() + + r.bufferMu.Lock() + // 7. release restart hold + r.restartInProgress = false + + // 8. restart forwarders + for layer, buff := range r.buffers { + if buff == nil { + continue + } + + r.startForwarderForBufferLocked(int32(layer), buff) + } + r.bufferMu.Unlock() +} + +func (r *ReceiverBase) OnMaxLayerChange(fn func(mimeType mime.MimeType, maxLayer int32)) { + r.bufferMu.Lock() + r.onMaxLayerChange = fn + r.bufferMu.Unlock() +} + +func (r *ReceiverBase) getOnMaxLayerChange() func(mimeType mime.MimeType, maxLayer int32) { + r.bufferMu.RLock() + defer r.bufferMu.RUnlock() + + return r.onMaxLayerChange +} + +func (r *ReceiverBase) IsClosed() bool { + return r.isClosed.Load() +} + +func (r *ReceiverBase) SetRTT(rtt uint32) { + r.bufferMu.Lock() + if r.rtt == rtt || rtt == 0 { + r.bufferMu.Unlock() + return + } + + r.rtt = rtt + buffers := r.buffers + r.bufferMu.Unlock() + + for _, buff := range buffers { + if buff == nil { + continue + } + + buff.SetRTT(rtt) + } +} + +func (r *ReceiverBase) TrackID() livekit.TrackID { + return r.params.TrackID +} + +func (r *ReceiverBase) StreamID() string { + return r.params.StreamID +} + +func (r *ReceiverBase) Codec() webrtc.RTPCodecParameters { + return r.params.Codec +} + +func (r *ReceiverBase) Mime() mime.MimeType { + return mime.NormalizeMimeType(r.params.Codec.MimeType) +} + +func (r *ReceiverBase) VideoLayerMode() livekit.VideoLayer_Mode { + return r.videoLayerMode +} + +func (r *ReceiverBase) HeaderExtensions() []webrtc.RTPHeaderExtensionParameter { + return r.params.HeaderExtensions +} + +func (r *ReceiverBase) Kind() webrtc.RTPCodecType { + return r.params.Kind +} + +func (r *ReceiverBase) StreamTrackerManager() *StreamTrackerManager { + return r.streamTrackerManager +} + +// SetUpTrackPaused indicates upstream will not be sending any data. +// this will reflect the "muted" status and will pause streamtracker to ensure we don't turn off +// the layer +func (r *ReceiverBase) SetUpTrackPaused(paused bool) { + r.streamTrackerManager.SetPaused(paused) + + r.bufferMu.RLock() + for _, buff := range r.buffers { + if buff == nil { + continue + } + + buff.SetPaused(paused) + } + r.bufferMu.RUnlock() +} + +func (r *ReceiverBase) AddDownTrack(track TrackSender) error { + if r.IsClosed() { + return ErrReceiverClosed + } + + if r.downTrackSpreader.HasDownTrack(track.SubscriberID()) { + r.params.Logger.Infow("subscriberID already exists, replacing downtrack", "subscriberID", track.SubscriberID()) + } + + track.UpTrackMaxPublishedLayerChange(r.streamTrackerManager.GetMaxPublishedLayer()) + track.UpTrackMaxTemporalLayerSeenChange(r.streamTrackerManager.GetMaxTemporalLayerSeen()) + + r.downTrackSpreader.Store(track) + r.params.Logger.Debugw("downtrack added", "subscriberID", track.SubscriberID()) + return nil +} + +func (r *ReceiverBase) DeleteDownTrack(subscriberID livekit.ParticipantID) { + r.downTrackSpreader.Free(subscriberID) + r.params.Logger.Debugw("downtrack deleted", "subscriberID", subscriberID) +} + +func (r *ReceiverBase) GetDownTracks() []TrackSender { + downTracks := r.downTrackSpreader.GetDownTracks() + if rt := r.loadREDTransformer(); rt != nil { + downTracks = append(downTracks, rt.GetDownTracks()...) + } + return downTracks +} + +func (r *ReceiverBase) SetMaxExpectedSpatialLayer(layer int32) { + prevMax := r.streamTrackerManager.SetMaxExpectedSpatialLayer(layer) + r.params.Logger.Debugw("max expected layer change", "layer", layer, "prevMax", prevMax) + + r.bufferMu.RLock() + // stop key frame seeders of stopped layers + for idx := layer + 1; idx <= prevMax; idx++ { + if r.buffers[idx] != nil { + r.buffers[idx].StopKeyFrameSeeder() + } + } + + // start key frame seeders of newly expected layers + for idx := prevMax + 1; idx <= layer; idx++ { + if r.buffers[idx] != nil { + r.buffers[idx].StartKeyFrameSeeder() + } + } + r.bufferMu.RUnlock() +} + +// StreamTrackerManagerListener.OnAvailableLayersChanged +func (r *ReceiverBase) OnAvailableLayersChanged() { + r.downTrackSpreader.Broadcast(func(dt TrackSender) { + dt.UpTrackLayersChange() + }) + + if r.params.StreamTrackerManagerListener != nil { + r.params.StreamTrackerManagerListener.OnAvailableLayersChanged() + } +} + +// StreamTrackerManagerListener.OnBitrateAvailabilityChanged +func (r *ReceiverBase) OnBitrateAvailabilityChanged() { + r.downTrackSpreader.Broadcast(func(dt TrackSender) { + dt.UpTrackBitrateAvailabilityChange() + }) + + if r.params.StreamTrackerManagerListener != nil { + r.params.StreamTrackerManagerListener.OnBitrateAvailabilityChanged() + } +} + +// StreamTrackerManagerListener.OnMaxPublishedLayerChanged +func (r *ReceiverBase) OnMaxPublishedLayerChanged(maxPublishedLayer int32) { + r.downTrackSpreader.Broadcast(func(dt TrackSender) { + dt.UpTrackMaxPublishedLayerChange(maxPublishedLayer) + }) + + if r.params.StreamTrackerManagerListener != nil { + r.params.StreamTrackerManagerListener.OnMaxPublishedLayerChanged(maxPublishedLayer) + } +} + +// StreamTrackerManagerListener.OnMaxTemporalLayerSeenChanged +func (r *ReceiverBase) OnMaxTemporalLayerSeenChanged(maxTemporalLayerSeen int32) { + r.downTrackSpreader.Broadcast(func(dt TrackSender) { + dt.UpTrackMaxTemporalLayerSeenChange(maxTemporalLayerSeen) + }) + + if r.params.StreamTrackerManagerListener != nil { + r.params.StreamTrackerManagerListener.OnMaxTemporalLayerSeenChanged(maxTemporalLayerSeen) + } +} + +// StreamTrackerManagerListener.OnMaxAvailableLayerChanged +func (r *ReceiverBase) OnMaxAvailableLayerChanged(maxAvailableLayer int32) { + if onMaxLayerChange := r.getOnMaxLayerChange(); onMaxLayerChange != nil { + onMaxLayerChange(r.Mime(), maxAvailableLayer) + } + + if r.params.StreamTrackerManagerListener != nil { + r.params.StreamTrackerManagerListener.OnMaxAvailableLayerChanged(maxAvailableLayer) + } +} + +// StreamTrackerManagerListener.OnBitrateReport +func (r *ReceiverBase) OnBitrateReport(availableLayers []int32, bitrates Bitrates) { + r.downTrackSpreader.Broadcast(func(dt TrackSender) { + dt.UpTrackBitrateReport(availableLayers, bitrates) + }) + + if r.params.StreamTrackerManagerListener != nil { + r.params.StreamTrackerManagerListener.OnBitrateReport(availableLayers, bitrates) + } +} + +func (r *ReceiverBase) GetLayeredBitrate() ([]int32, Bitrates) { + return r.streamTrackerManager.GetLayeredBitrate() +} + +func (r *ReceiverBase) SendPLI(layer int32, force bool) { + // SVC-TODO : should send LRR (Layer Refresh Request) instead of PLI + buff, _ := r.getBuffer(layer) + if buff == nil { + return + } + + buff.SendPLI(force) +} + +func (r *ReceiverBase) getBuffer(layer int32) (buffer.BufferProvider, int32) { + r.bufferMu.RLock() + defer r.bufferMu.RUnlock() + + return r.getBufferLocked(layer) +} + +func (r *ReceiverBase) getBufferLocked(layer int32) (buffer.BufferProvider, int32) { + // for svc codecs, use layer = 0 always. + // spatial layers are in-built and handled by single buffer + if r.videoLayerMode == livekit.VideoLayer_MULTIPLE_SPATIAL_LAYERS_PER_STREAM { + layer = 0 + } + + if layer < 0 || int(layer) >= len(r.buffers) { + return nil, layer + } + + return r.buffers[layer], layer +} + +func (r *ReceiverBase) GetOrCreateBuffer( + layer int32, + creatorFn func(*livekit.TrackInfo) (buffer.BufferProvider, error), +) (buffer.BufferProvider, bool) { + r.bufferMu.Lock() + + if r.IsClosed() { + r.bufferMu.Unlock() + return nil, false + } + + var buff buffer.BufferProvider + if buff, layer = r.getBufferLocked(layer); buff != nil { + r.bufferMu.Unlock() + return buff, false + } + + buff, err := creatorFn(r.trackInfo) + if err != nil { + r.bufferMu.Unlock() + r.params.Logger.Errorw("could not create buffer", err) + return nil, false + } + + r.buffers[layer] = buff + rtt := r.rtt + r.bufferMu.Unlock() + + r.setupBuffer(buff, layer, rtt) + return buff, true +} + +func (r *ReceiverBase) setupBuffer(buff buffer.BufferProvider, layer int32, rtt uint32) { + buff.SetLogger(r.params.Logger.WithValues("layer", layer)) + buff.SetAudioLevelConfig(r.audioConfig.AudioLevelConfig) + buff.SetStreamRestartDetection(r.enableRTPStreamRestartDetection) + buff.OnRtcpSenderReport(func() { + srData := buff.GetSenderReportData() + r.downTrackSpreader.Broadcast(func(dt TrackSender) { + _ = dt.HandleRTCPSenderReportData(r.params.Codec.PayloadType, layer, srData) + }) + + if rt := r.loadREDTransformer(); rt != nil { + rt.ForwardRTCPSenderReport(r.params.Codec.PayloadType, layer, srData) + } + }) + buff.OnVideoSizeChanged(func(videoSize []buffer.VideoSize) { + r.videoSizeMu.Lock() + if r.videoLayerMode == livekit.VideoLayer_MULTIPLE_SPATIAL_LAYERS_PER_STREAM { + copy(r.videoSizes[:], videoSize) + } else { + r.videoSizes[layer] = videoSize[0] + } + r.params.Logger.Debugw("video size changed", "size", r.videoSizes) + cb := r.onVideoSizeChanged + r.videoSizeMu.Unlock() + + if cb != nil { + cb() + } + }) + if r.Kind() == webrtc.RTPCodecTypeVideo && layer == 0 { + buff.OnCodecChange(r.handleCodecChange) + } + buff.OnStreamRestart(func(reason string) { + r.restartInternal(reason, true) + }) + + var duration time.Duration + switch layer { + case 2: + duration = r.pliThrottleConfig.HighQuality + case 1: + duration = r.pliThrottleConfig.MidQuality + case 0: + duration = r.pliThrottleConfig.LowQuality + default: + duration = r.pliThrottleConfig.MidQuality + } + if duration != 0 { + buff.SetPLIThrottle(duration.Nanoseconds()) + } + + buff.SetRTT(rtt) + buff.SetPaused(r.streamTrackerManager.IsPaused()) +} + +func (r *ReceiverBase) AddBuffer(buff buffer.BufferProvider, layer int32) { + r.bufferMu.Lock() + r.buffers[layer] = buff + rtt := r.rtt + r.bufferMu.Unlock() + + r.setupBuffer(buff, layer, rtt) +} + +func (r *ReceiverBase) StartBuffer(buff buffer.BufferProvider, layer int32) { + r.bufferMu.Lock() + r.startForwarderForBufferLocked(layer, buff) + r.bufferMu.Unlock() +} + +func (r *ReceiverBase) GetAllBuffers() [buffer.DefaultMaxLayerSpatial + 1]buffer.BufferProvider { + buffers := [buffer.DefaultMaxLayerSpatial + 1]buffer.BufferProvider{} + + r.bufferMu.RLock() + defer r.bufferMu.RUnlock() + + for i := range buffers { + buffers[i] = r.buffers[i] + } + return buffers +} + +func (r *ReceiverBase) ClearAllBuffers(reason string) { + r.bufferMu.Lock() + buffers := r.buffers + for idx := range r.buffers { + r.buffers[idx] = nil + } + r.bufferMu.Unlock() + + for _, buff := range buffers { + if buff == nil { + continue + } + buff.CloseWithReason(reason) + } + + r.streamTrackerManager.RemoveAllTrackers() +} + +func (r *ReceiverBase) ReadRTP(buf []byte, layer uint8, esn uint64) (int, error) { + b, _ := r.getBuffer(int32(layer)) + if b == nil { + return 0, bucket.ErrPacketMismatch + } + + return b.GetPacket(buf, esn) +} + +func (r *ReceiverBase) GetTrackStats() *livekit.RTPStats { + r.bufferMu.RLock() + defer r.bufferMu.RUnlock() + + allStats := make([]*livekit.RTPStats, 0, len(r.buffers)) + for _, buff := range r.buffers { + if buff == nil { + continue + } + + stats := buff.GetStats() + if stats == nil { + continue + } + + allStats = append(allStats, stats) + } + + return rtpstats.AggregateRTPStats(allStats) +} + +func (r *ReceiverBase) GetAudioLevel() (float64, bool) { + if r.Kind() == webrtc.RTPCodecTypeVideo { + return 0, false + } + + r.bufferMu.RLock() + defer r.bufferMu.RUnlock() + + for _, buff := range r.buffers { + if buff == nil { + continue + } + + return buff.GetAudioLevel() + } + + return 0, false +} + +func (r *ReceiverBase) startForwarderGeneration() { + r.bufferMu.Lock() + defer r.bufferMu.Unlock() + + r.forwardersGeneration.Inc() + r.forwardersWaitGroup = &sync.WaitGroup{} +} + +func (r *ReceiverBase) stopForwarderGeneration() { + r.bufferMu.Lock() + r.forwardersGeneration.Inc() + forwarderWaitGroup := r.forwardersWaitGroup + r.bufferMu.Unlock() + + if forwarderWaitGroup != nil { + forwarderWaitGroup.Wait() + } +} + +func (r *ReceiverBase) startForwarderForBufferLocked(layer int32, buff buffer.BufferProvider) { + if r.restartInProgress { + return + } + + r.forwardersWaitGroup.Add(1) + + r.params.Logger.Debugw("starting forwarder", "layer", layer) + go r.forwardRTP(layer, buff, r.forwardersGeneration.Load(), r.forwardersWaitGroup) +} + +func (r *ReceiverBase) forwardRTP( + layer int32, + buff buffer.BufferProvider, + forwarderGeneration uint32, + wg *sync.WaitGroup, +) { + var ( + extPkt *buffer.ExtPacket + err error + ) + + numPacketsForwarded := 0 + numPacketsDropped := 0 + defer func() { + if err == io.EOF { + if r.params.IsSelfClosing { + r.Close("forwarder-done", false) + + r.streamTrackerManager.RemoveTracker(layer) + if r.videoLayerMode == livekit.VideoLayer_MULTIPLE_SPATIAL_LAYERS_PER_STREAM { + r.streamTrackerManager.RemoveAllTrackers() + } + } + } + + r.params.Logger.Debugw( + "closing forwarder", + "layer", layer, + "numPacketsForwarded", numPacketsForwarded, + "numPacketsDropped", numPacketsDropped, + "forwarderGeneration", forwarderGeneration, + "forwardersGeneration", r.forwardersGeneration.Load(), + "error", err, + ) + wg.Done() + }() + + var spatialTrackers [buffer.DefaultMaxLayerSpatial + 1]streamtracker.StreamTrackerWorker + if layer < 0 || int(layer) >= len(spatialTrackers) { + r.params.Logger.Errorw("invalid layer", nil, "layer", layer) + return + } + + pktBuf := make([]byte, bucket.RTPMaxPktSize) + r.params.Logger.Debugw("starting forwarding", "layer", layer, "forwarderGeneration", forwarderGeneration) + + for r.forwardersGeneration.Load() == forwarderGeneration { + extPkt, err = buff.ReadExtended(pktBuf) + if err == io.EOF { + return + } + if extPkt == nil { + continue + } + dequeuedAt := mono.UnixNano() + + if extPkt.Packet.PayloadType != uint8(r.params.Codec.PayloadType) { + // drop packets as we don't support codec fallback directly + r.params.Logger.Debugw( + "dropping packet - payload mismatch", + "packetPayloadType", extPkt.Packet.PayloadType, + "payloadType", r.params.Codec.PayloadType, + ) + numPacketsDropped++ + continue + } + + spatialLayer := layer + if extPkt.Spatial >= 0 { + // svc packet, take spatial layer info from packet + spatialLayer = extPkt.Spatial + } + if int(spatialLayer) >= len(spatialTrackers) { + r.params.Logger.Errorw( + "unexpected spatial layer", nil, + "spatialLayer", spatialLayer, + "pktSpatialLayer", extPkt.Spatial, + ) + numPacketsDropped++ + continue + } + + var writeCount atomic.Int32 + r.downTrackSpreader.Broadcast(func(dt TrackSender) { + writeCount.Add(dt.WriteRTP(extPkt, spatialLayer)) + }) + if rt := r.loadREDTransformer(); rt != nil { + writeCount.Add(rt.ForwardRTP(extPkt, spatialLayer)) + } + + // track delay/jitter + if writeCount.Load() > 0 && r.forwardStats != nil && !extPkt.IsBuffered { + if latency, isHigh := r.forwardStats.Update(extPkt.Arrival, mono.UnixNano()); isHigh { + r.params.Logger.Debugw( + "high forwarding latency", + "latency", time.Duration(latency), + "queuingLatency", time.Duration(dequeuedAt-extPkt.Arrival), + "writeCount", writeCount.Load(), + "isOutOfOrder", extPkt.IsOutOfOrder, + "layer", layer, + ) + } + } + + // track video layers + if r.Kind() == webrtc.RTPCodecTypeVideo { + if spatialTrackers[spatialLayer] == nil { + spatialTrackers[spatialLayer] = r.streamTrackerManager.GetTracker(spatialLayer) + if spatialTrackers[spatialLayer] == nil { + if r.videoLayerMode == livekit.VideoLayer_MULTIPLE_SPATIAL_LAYERS_PER_STREAM && extPkt.DependencyDescriptor != nil { + r.streamTrackerManager.AddDependencyDescriptorTrackers() + } + spatialTrackers[spatialLayer] = r.streamTrackerManager.AddTracker(spatialLayer) + } + } + if spatialTrackers[spatialLayer] != nil { + spatialTrackers[spatialLayer].Observe( + extPkt.Temporal, + len(extPkt.RawPacket), + len(extPkt.Packet.Payload), + extPkt.Packet.Marker, + extPkt.Packet.Timestamp, + extPkt.DependencyDescriptor, + ) + } + } + + numPacketsForwarded++ + + buffer.ReleaseExtPacket(extPkt) + } +} + +func (r *ReceiverBase) DebugInfo() map[string]any { + videoLayerMode := buffer.GetVideoLayerModeForMimeType(r.Mime(), r.TrackInfo()) + info := map[string]any{ + "Mime": r.Mime().String(), + "VideoLayerMode": videoLayerMode.String(), + } + + return info +} + +func (r *ReceiverBase) GetPrimaryReceiverForRed() TrackReceiver { + r.bufferMu.Lock() + defer r.bufferMu.Unlock() + + if !r.isRED || r.IsClosed() { + return r + } + + rt := r.loadREDTransformer() + if rt == nil { + pr := NewRedPrimaryReceiver(r, sfuutils.DownTrackSpreaderParams{ + Threshold: r.lbThreshold, + Logger: r.params.Logger, + }) + r.redTransformer.Store(&pr) + return pr + } else { + if pr, ok := rt.(*RedPrimaryReceiver); ok { + return pr + } + } + return nil +} + +func (r *ReceiverBase) GetRedReceiver() TrackReceiver { + r.bufferMu.Lock() + defer r.bufferMu.Unlock() + + if r.isRED || r.IsClosed() { + return r + } + + rt := r.loadREDTransformer() + if rt == nil { + pr := NewRedReceiver(r, sfuutils.DownTrackSpreaderParams{ + Threshold: r.lbThreshold, + Logger: r.params.Logger, + }) + r.redTransformer.Store(&pr) + return pr + } else { + if pr, ok := rt.(*RedReceiver); ok { + return pr + } + } + return nil +} + +func (r *ReceiverBase) GetTemporalLayerFpsForSpatial(layer int32) []float32 { + b, _ := r.getBuffer(layer) + if b == nil { + return nil + } + + if r.videoLayerMode != livekit.VideoLayer_MULTIPLE_SPATIAL_LAYERS_PER_STREAM { + return b.GetTemporalLayerFpsForSpatial(0) + } + + return b.GetTemporalLayerFpsForSpatial(layer) +} + +func (r *ReceiverBase) AddOnReady(fn func()) { + // receiver is always ready after created + fn() +} + +func (w *ReceiverBase) handleCodecChange(newCodec webrtc.RTPCodecParameters) { + // codec fallback is not supported mid-session, i.e. change of codec via payload type change, + // set the codec state to invalid once it happens + w.SetCodecState(ReceiverCodecStateInvalid) +} + +func (r *ReceiverBase) AddOnCodecStateChange(f func(webrtc.RTPCodecParameters, ReceiverCodecState)) { + r.codecStateLock.Lock() + r.onCodecStateChange = append(r.onCodecStateChange, f) + r.codecStateLock.Unlock() +} + +func (r *ReceiverBase) CodecState() ReceiverCodecState { + r.codecStateLock.Lock() + defer r.codecStateLock.Unlock() + + return r.codecState +} + +func (r *ReceiverBase) SetCodecState(state ReceiverCodecState) { + r.codecStateLock.Lock() + if r.codecState == state || r.codecState == ReceiverCodecStateInvalid { + r.codecStateLock.Unlock() + return + } + + r.codecState = state + fns := r.onCodecStateChange + r.codecStateLock.Unlock() + + for _, f := range fns { + f(r.params.Codec, state) + } +} + +func (r *ReceiverBase) SetCodecWithState(codec webrtc.RTPCodecParameters, headerExtensions []webrtc.RTPHeaderExtensionParameter, codecState ReceiverCodecState) { + r.checkCodecChanged(codec, headerExtensions) + + r.codecStateLock.Lock() + if codecState == r.codecState { + r.codecStateLock.Unlock() + return + } + + var fireChange bool + var reason string + onCodecStateChange := r.onCodecStateChange + r.params.Logger.Infow("codec state changed", "from", r.codecState, "to", codecState) + switch codecState { + case ReceiverCodecStateNormal: + // TODO: support codec recovery + r.codecStateLock.Unlock() + return + + case ReceiverCodecStateSuspended: + reason = "codec suspended" + fallthrough + + case ReceiverCodecStateInvalid: + r.codecState = codecState + fireChange = true + reason = "codec invalid" + } + r.codecStateLock.Unlock() + + if fireChange { + r.ClearAllBuffers(reason) + + for _, fn := range onCodecStateChange { + fn(r.params.Codec, codecState) + } + } +} + +func (r *ReceiverBase) checkCodecChanged(codec webrtc.RTPCodecParameters, headerExtensions []webrtc.RTPHeaderExtensionParameter) { + existingFmtp := strings.Split(r.params.Codec.SDPFmtpLine, ";") + slices.Sort(existingFmtp) + checkFmtp := strings.Split(codec.SDPFmtpLine, ";") + slices.Sort(checkFmtp) + if !mime.IsMimeTypeStringEqual(r.params.Codec.MimeType, codec.MimeType) || !slices.Equal(existingFmtp, checkFmtp) || + r.params.Codec.ClockRate != codec.ClockRate { + err := fmt.Errorf("mime: %s -> %s, fmtp: %s -> %s, clockRate: %d -> %d", + r.params.Codec.MimeType, codec.MimeType, + r.params.Codec.SDPFmtpLine, codec.SDPFmtpLine, + r.params.Codec.ClockRate, codec.ClockRate, + ) + r.params.Logger.Errorw("unexpected change in codec", err) + } + + if len(r.params.HeaderExtensions) != len(headerExtensions) { + err := fmt.Errorf("extensions: %d -> %d", len(r.params.HeaderExtensions), len(headerExtensions)) + r.params.Logger.Errorw("unexpected change in extensions length", err) + } +} + +func (r *ReceiverBase) VideoSizes() []buffer.VideoSize { + var sizes []buffer.VideoSize + r.videoSizeMu.RLock() + defer r.videoSizeMu.RUnlock() + for _, v := range r.videoSizes { + if v.Width == 0 || v.Height == 0 { + break + } + sizes = append(sizes, v) + } + + return sizes +} + +func (r *ReceiverBase) OnVideoSizeChanged(f func()) { + r.videoSizeMu.Lock() + r.onVideoSizeChanged = f + r.videoSizeMu.Unlock() +} + +func (r *ReceiverBase) loadREDTransformer() REDTransformer { + if rt := r.redTransformer.Load(); rt != nil { + return *rt + } + + return nil +} + +// ----------------------------------------------------------- + +// closes all track senders in parallel, returns when all are closed +func closeTrackSenders(senders []TrackSender) { + wg := sync.WaitGroup{} + for _, dt := range senders { + dt := dt + wg.Add(1) + go func() { + defer wg.Done() + dt.Close() + }() + } + wg.Wait() +} diff --git a/livekit/pkg/sfu/receiver_test.go b/livekit/pkg/sfu/receiver_test.go new file mode 100644 index 0000000..3cfde07 --- /dev/null +++ b/livekit/pkg/sfu/receiver_test.go @@ -0,0 +1,213 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sfu + +import ( + "fmt" + "hash/fnv" + "math/rand" + "runtime" + "sync" + "testing" + + "github.com/gammazero/workerpool" + "github.com/stretchr/testify/assert" + "go.uber.org/atomic" +) + +func TestWebRTCReceiver_OnCloseHandler(t *testing.T) { + type args struct { + fn func() + } + tests := []struct { + name string + args args + }{ + { + name: "Must set on close handler function", + args: args{ + fn: func() {}, + }, + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + w := &WebRTCReceiver{} + w.OnCloseHandler(tt.args.fn) + assert.NotNil(t, w.onCloseHandler) + }) + } +} + +func BenchmarkWriteRTP(b *testing.B) { + cases := []int{1, 2, 5, 10, 100, 250, 500} + workers := runtime.NumCPU() + wp := workerpool.New(workers) + for _, c := range cases { + // fills each bucket with a max of 50, i.e. []int{50, 50} for c=100 + fill := make([]int, 0) + for i := 50; ; i += 50 { + if i > c { + fill = append(fill, c%50) + break + } + + fill = append(fill, 50) + if i == c { + break + } + } + + // splits c into numCPU buckets, i.e. []int{9, 9, 9, 9, 8, 8, 8, 8, 8, 8, 8, 8} for 12 cpus and c=100 + split := make([]int, workers) + for i := range split { + split[i] = c / workers + } + for i := 0; i < c%workers; i++ { + split[i]++ + } + + b.Run(fmt.Sprintf("%d-Downtracks/Control", c), func(b *testing.B) { + benchmarkNoPool(b, c) + }) + b.Run(fmt.Sprintf("%d-Downtracks/Pool(Fill)", c), func(b *testing.B) { + benchmarkPool(b, wp, fill) + }) + b.Run(fmt.Sprintf("%d-Downtracks/Pool(Hash)", c), func(b *testing.B) { + benchmarkPool(b, wp, split) + }) + b.Run(fmt.Sprintf("%d-Downtracks/Goroutines", c), func(b *testing.B) { + benchmarkGoroutine(b, split) + }) + b.Run(fmt.Sprintf("%d-Downtracks/LoadBalanced", c), func(b *testing.B) { + benchmarkLoadBalanced(b, workers, 2, c) + }) + b.Run(fmt.Sprintf("%d-Downtracks/LBPool", c), func(b *testing.B) { + benchmarkLoadBalancedPool(b, wp, workers, 2, c) + }) + } +} + +func benchmarkNoPool(b *testing.B, downTracks int) { + for b.Loop() { + for range downTracks { + writeRTP() + } + } +} + +func benchmarkPool(b *testing.B, wp *workerpool.WorkerPool, buckets []int) { + for b.Loop() { + var wg sync.WaitGroup + for j := range buckets { + downTracks := buckets[j] + if downTracks == 0 { + continue + } + wg.Add(1) + wp.Submit(func() { + defer wg.Done() + for dt := 0; dt < downTracks; dt++ { + writeRTP() + } + }) + } + wg.Wait() + } +} + +func benchmarkGoroutine(b *testing.B, buckets []int) { + for b.Loop() { + var wg sync.WaitGroup + for j := range buckets { + downTracks := buckets[j] + if downTracks == 0 { + continue + } + wg.Add(1) + go func() { + defer wg.Done() + for dt := 0; dt < downTracks; dt++ { + writeRTP() + } + }() + } + wg.Wait() + } +} + +func benchmarkLoadBalanced(b *testing.B, numProcs, step, downTracks int) { + for b.Loop() { + start := atomic.NewUint64(0) + step := uint64(step) + end := uint64(downTracks) + + var wg sync.WaitGroup + wg.Add(numProcs) + for p := 0; p < numProcs; p++ { + go func() { + defer wg.Done() + for { + n := start.Add(step) + if n >= end+step { + return + } + + for i := n - step; i < n && i < end; i++ { + writeRTP() + } + } + }() + } + wg.Wait() + } +} + +func benchmarkLoadBalancedPool(b *testing.B, wp *workerpool.WorkerPool, numProcs, step, downTracks int) { + for b.Loop() { + start := atomic.NewUint64(0) + step := uint64(step) + end := uint64(downTracks) + + var wg sync.WaitGroup + wg.Add(numProcs) + for p := 0; p < numProcs; p++ { + wp.Submit(func() { + defer wg.Done() + for { + n := start.Add(step) + if n >= end+step { + return + } + + for i := n - step; i < n && i < end; i++ { + writeRTP() + } + } + }) + } + wg.Wait() + } +} + +func writeRTP() { + s := []byte("simulate some work") + stop := 1900 + rand.Intn(200) + for j := 0; j < stop; j++ { + h := fnv.New128() + s = h.Sum(s) + } +} diff --git a/livekit/pkg/sfu/redprimaryreceiver.go b/livekit/pkg/sfu/redprimaryreceiver.go new file mode 100644 index 0000000..e054481 --- /dev/null +++ b/livekit/pkg/sfu/redprimaryreceiver.go @@ -0,0 +1,348 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sfu + +import ( + "encoding/binary" + "errors" + + "go.uber.org/atomic" + + "github.com/pion/rtp" + "github.com/pion/webrtc/v4" + + "github.com/livekit/livekit-server/pkg/sfu/buffer" + "github.com/livekit/livekit-server/pkg/sfu/utils" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" +) + +var _ REDTransformer = (*RedPrimaryReceiver)(nil) + +var ( + ErrIncompleteRedHeader = errors.New("incomplete red block header") + ErrIncompleteRedBlock = errors.New("incomplete red block payload") +) + +type RedPrimaryReceiver struct { + TrackReceiver + downTrackSpreader *utils.DownTrackSpreader[TrackSender] + logger logger.Logger + closed atomic.Bool + redPT uint8 + + firstPktReceived bool + lastSeq uint16 + + // bitset for upstream packet receive history [lastSeq-8, lastSeq-1], bit 1 represents packet received + pktHistory byte +} + +func NewRedPrimaryReceiver(receiver TrackReceiver, dsp utils.DownTrackSpreaderParams) REDTransformer { + return &RedPrimaryReceiver{ + TrackReceiver: receiver, + downTrackSpreader: utils.NewDownTrackSpreader[TrackSender](dsp), + logger: dsp.Logger, + redPT: uint8(receiver.Codec().PayloadType), + } +} + +func (r *RedPrimaryReceiver) ForwardRTP(pkt *buffer.ExtPacket, spatialLayer int32) int32 { + // extract primary payload from RED and forward to downtracks + if r.downTrackSpreader.DownTrackCount() == 0 { + return 0 + } + + if pkt.Packet.PayloadType != r.redPT { + // forward non-red packet directly + var writeCount atomic.Int32 + r.downTrackSpreader.Broadcast(func(dt TrackSender) { + writeCount.Add(dt.WriteRTP(pkt, spatialLayer)) + }) + return writeCount.Load() + } + + pkts, err := r.getSendPktsFromRed(pkt.Packet) + if err != nil { + r.logger.Errorw("get encoding for red failed", err, "payloadtype", pkt.Packet.PayloadType) + return 0 + } + + var writeCount atomic.Int32 + for i, sendPkt := range pkts { + pPkt := *pkt + if i != len(pkts)-1 { + // patch extended sequence number and time stamp for all but the last packet, + // last packet is the primary payload + pPkt.ExtSequenceNumber -= uint64(pkts[len(pkts)-1].SequenceNumber - pkts[i].SequenceNumber) + pPkt.ExtTimestamp -= uint64(pkts[len(pkts)-1].Timestamp - pkts[i].Timestamp) + } + pPkt.Packet = sendPkt + + // not modify the ExtPacket.RawPacket here for performance since it is not used by the DownTrack, + // otherwise it should be set to the correct value (marshal the primary rtp packet) + r.downTrackSpreader.Broadcast(func(dt TrackSender) { + writeCount.Add(dt.WriteRTP(&pPkt, spatialLayer)) + }) + } + return writeCount.Load() +} + +func (r *RedPrimaryReceiver) ForwardRTCPSenderReport( + payloadType webrtc.PayloadType, + layer int32, + publisherSRData *livekit.RTCPSenderReportState, +) { + r.downTrackSpreader.Broadcast(func(dt TrackSender) { + _ = dt.HandleRTCPSenderReportData(payloadType, layer, publisherSRData) + }) +} + +func (r *RedPrimaryReceiver) AddDownTrack(track TrackSender) error { + if r.closed.Load() { + return ErrReceiverClosed + } + + if r.downTrackSpreader.HasDownTrack(track.SubscriberID()) { + r.logger.Infow("subscriberID already exists, replacing downtrack", "subscriberID", track.SubscriberID()) + } + + r.downTrackSpreader.Store(track) + r.logger.Debugw("red primary receiver downtrack added", "subscriberID", track.SubscriberID()) + return nil +} + +func (r *RedPrimaryReceiver) DeleteDownTrack(subscriberID livekit.ParticipantID) { + if r.closed.Load() { + return + } + + r.downTrackSpreader.Free(subscriberID) + r.logger.Debugw("red primary receiver downtrack deleted", "subscriberID", subscriberID) +} + +func (r *RedPrimaryReceiver) GetDownTracks() []TrackSender { + return r.downTrackSpreader.GetDownTracks() +} + +func (r *RedPrimaryReceiver) ResyncDownTracks() { + r.downTrackSpreader.Broadcast(func(dt TrackSender) { + dt.Resync() + }) +} + +func (r *RedPrimaryReceiver) OnStreamRestart() { + r.downTrackSpreader.Broadcast(func(dt TrackSender) { + dt.ReceiverRestart() + }) +} + +func (r *RedPrimaryReceiver) IsClosed() bool { + return r.closed.Load() +} + +func (r *RedPrimaryReceiver) CanClose() bool { + return r.closed.Load() || r.downTrackSpreader.DownTrackCount() == 0 +} + +func (r *RedPrimaryReceiver) Close() { + r.closed.Store(true) + closeTrackSenders(r.downTrackSpreader.ResetAndGetDownTracks()) +} + +func (r *RedPrimaryReceiver) ReadRTP(buf []byte, layer uint8, esn uint64) (int, error) { + n, err := r.TrackReceiver.ReadRTP(buf, layer, esn) + if err != nil { + return n, err + } + + var pkt rtp.Packet + pkt.Unmarshal(buf[:n]) + payload, err := extractPrimaryEncodingForRED(pkt.Payload) + if err != nil { + return 0, err + } + pkt.Payload = payload + + return pkt.MarshalTo(buf) +} + +func (r *RedPrimaryReceiver) getSendPktsFromRed(rtp *rtp.Packet) ([]*rtp.Packet, error) { + var needRecover bool + if !r.firstPktReceived { + r.lastSeq = rtp.SequenceNumber + r.pktHistory = 0 + r.firstPktReceived = true + } else { + diff := rtp.SequenceNumber - r.lastSeq + switch { + case diff == 0: // duplicate + break + + case diff > 0x8000: // unorder + // in history + if 65535-diff < 8 { + r.pktHistory |= 1 << (65535 - diff) + needRecover = true + } + + case diff > 8: // long jump + r.lastSeq = rtp.SequenceNumber + r.pktHistory = 0 + needRecover = true + + default: + r.lastSeq = rtp.SequenceNumber + r.pktHistory = (r.pktHistory << byte(diff)) | 1<<(diff-1) + needRecover = true + } + } + + var recoverBits byte + if needRecover { + bitIndex := r.lastSeq - rtp.SequenceNumber + for i := range maxRedCount { + if bitIndex > 7 { + break + } + if r.pktHistory&byte(1<>= 10 + tsOffset := blockHead & 0x3FFF + blockHead >>= 14 + pt := uint8(blockHead & 0x7F) + + blocks = append(blocks, block{pt: pt, length: length, tsOffset: tsOffset}) + + blockLength += length + payload = payload[4:] + } + } + + if len(payload) < blockLength { + return nil, ErrIncompleteRedBlock + } + + pkts := make([]*rtp.Packet, 0, len(blocks)) + for i, b := range blocks { + if b.primary { + header := redPkt.Header + header.PayloadType = b.pt + pkts = append(pkts, &rtp.Packet{Header: header, Payload: payload}) + break + } + + recoverIndex := len(blocks) - i - 1 + if recoverIndex < 1 || recoverBits&(1<<(recoverIndex-1)) == 0 { + // skip past packet/block that does not need recovery + payload = payload[b.length:] + continue + } + + // recover missing packet + header := redPkt.Header + header.SequenceNumber -= uint16(recoverIndex) + header.Timestamp -= b.tsOffset + header.PayloadType = b.pt + pkts = append(pkts, &rtp.Packet{Header: header, Payload: payload[:b.length]}) + + payload = payload[b.length:] + } + + return pkts, nil +} + +func extractPrimaryEncodingForRED(payload []byte) ([]byte, error) { + + /* RED payload https://datatracker.ietf.org/doc/html/rfc2198#section-3 + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + |F| block PT | timestamp offset | block length | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + F: 1 bit First bit in header indicates whether another header block + follows. If 1 further header blocks follow, if 0 this is the + last header block. + */ + + var blockLength int + for { + if len(payload) < 1 { + // illegal data, need at least one byte for primary encoding + return nil, ErrIncompleteRedHeader + } + + if payload[0]&0x80 == 0 { + // last block is primary encoding data + payload = payload[1:] + break + } else { + if len(payload) < 4 { + // illegal data + return nil, ErrIncompleteRedHeader + } + + blockLength += int(binary.BigEndian.Uint16(payload[2:]) & 0x03FF) + payload = payload[4:] + } + } + + if len(payload) < blockLength { + return nil, ErrIncompleteRedBlock + } + + return payload[blockLength:], nil +} diff --git a/livekit/pkg/sfu/redreceiver.go b/livekit/pkg/sfu/redreceiver.go new file mode 100644 index 0000000..9c93a3f --- /dev/null +++ b/livekit/pkg/sfu/redreceiver.go @@ -0,0 +1,250 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sfu + +import ( + "encoding/binary" + "fmt" + + "go.uber.org/atomic" + + "github.com/pion/rtp" + "github.com/pion/webrtc/v4" + + "github.com/livekit/livekit-server/pkg/sfu/buffer" + "github.com/livekit/livekit-server/pkg/sfu/utils" + "github.com/livekit/mediatransportutil/pkg/bucket" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" +) + +var _ REDTransformer = (*RedReceiver)(nil) + +const ( + maxRedCount = 2 + mtuSize = 1500 + maxRedPayload = 1 << 10 // fit into 10 bits length field + + // the RedReceiver is only for chrome / native webrtc now, we always negotiate opus payload to 111 with those clients, + // so it is safe to use a fixed payload 111 here for performance(avoid encoding red blocks for each downtrack that + // have a different opus payload type). + opusPT = 111 + opusRedPT = 63 +) + +type RedReceiver struct { + TrackReceiver + downTrackSpreader *utils.DownTrackSpreader[TrackSender] + logger logger.Logger + closed atomic.Bool + pktBuff [maxRedCount]*rtp.Packet + redPayloadBuf [mtuSize]byte +} + +func NewRedReceiver(receiver TrackReceiver, dsp utils.DownTrackSpreaderParams) REDTransformer { + return &RedReceiver{ + TrackReceiver: receiver, + downTrackSpreader: utils.NewDownTrackSpreader[TrackSender](dsp), + logger: dsp.Logger, + } +} + +func (r *RedReceiver) ForwardRTP(pkt *buffer.ExtPacket, spatialLayer int32) int32 { + // encode RED payload from primary payload and forward to downtracks + if r.downTrackSpreader.DownTrackCount() == 0 { + return 0 + } + + // fallback to primary codec if payload size exceeds redundant block length + if len(pkt.Packet.Payload) >= maxRedPayload { + var writeCount atomic.Int32 + r.downTrackSpreader.Broadcast(func(dt TrackSender) { + writeCount.Add(dt.WriteRTP(pkt, spatialLayer)) + }) + return writeCount.Load() + } + + redLen, err := r.encodeRedForPrimary(pkt.Packet, r.redPayloadBuf[:]) + if err != nil { + r.logger.Errorw("red encoding failed", err) + return 0 + } + + pPkt := *pkt + redRtpPacket := *pkt.Packet + redRtpPacket.PayloadType = 63 + redRtpPacket.Payload = r.redPayloadBuf[:redLen] + pPkt.Packet = &redRtpPacket + + // not modify the ExtPacket.RawPacket here for performance since it is not used by the DownTrack, + // otherwise it should be set to the correct value (marshal the primary rtp packet) + var writeCount atomic.Int32 + r.downTrackSpreader.Broadcast(func(dt TrackSender) { + writeCount.Add(dt.WriteRTP(&pPkt, spatialLayer)) + }) + return writeCount.Load() +} + +func (r *RedReceiver) ForwardRTCPSenderReport( + payloadType webrtc.PayloadType, + layer int32, + publisherSRData *livekit.RTCPSenderReportState, +) { + r.downTrackSpreader.Broadcast(func(dt TrackSender) { + _ = dt.HandleRTCPSenderReportData(payloadType, layer, publisherSRData) + }) +} + +func (r *RedReceiver) AddDownTrack(track TrackSender) error { + if r.closed.Load() { + return ErrReceiverClosed + } + + if r.downTrackSpreader.HasDownTrack(track.SubscriberID()) { + r.logger.Infow("subscriberID already exists, replacing downtrack", "subscriberID", track.SubscriberID()) + } + + r.downTrackSpreader.Store(track) + r.logger.Debugw("red receiver downtrack added", "subscriberID", track.SubscriberID()) + return nil +} + +func (r *RedReceiver) DeleteDownTrack(subscriberID livekit.ParticipantID) { + if r.closed.Load() { + return + } + + r.downTrackSpreader.Free(subscriberID) + r.logger.Debugw("red receiver downtrack deleted", "subscriberID", subscriberID) +} + +func (r *RedReceiver) GetDownTracks() []TrackSender { + return r.downTrackSpreader.GetDownTracks() +} + +func (r *RedReceiver) ResyncDownTracks() { + r.downTrackSpreader.Broadcast(func(dt TrackSender) { + dt.Resync() + }) +} + +func (r *RedReceiver) OnStreamRestart() { + r.downTrackSpreader.Broadcast(func(dt TrackSender) { + dt.ReceiverRestart() + }) +} + +func (r *RedReceiver) CanClose() bool { + return r.closed.Load() || r.downTrackSpreader.DownTrackCount() == 0 +} + +func (r *RedReceiver) IsClosed() bool { + return r.closed.Load() +} + +func (r *RedReceiver) Close() { + r.closed.Store(true) + closeTrackSenders(r.downTrackSpreader.ResetAndGetDownTracks()) +} + +func (r *RedReceiver) ReadRTP(buf []byte, layer uint8, esn uint64) (int, error) { + // red encoding doesn't support nack + return 0, bucket.ErrPacketMismatch +} + +func (r *RedReceiver) encodeRedForPrimary(pkt *rtp.Packet, redPayload []byte) (int, error) { + redLength := len(r.pktBuff) + redPkts := make([]*rtp.Packet, 0, redLength+1) + lastNilPkt := -1 + for i := redLength - 1; i >= 0; i-- { + if r.pktBuff[i] == nil { + lastNilPkt = i + break + } + + } + + for _, prev := range r.pktBuff[lastNilPkt+1:] { + if pkt.SequenceNumber == prev.SequenceNumber || + (pkt.SequenceNumber-prev.SequenceNumber) > uint16(redLength) || + (pkt.Timestamp-prev.Timestamp) >= (1<<14) { + continue + } + redPkts = append(redPkts, prev) + } + + // insert primary packet in history buffer + // NOTE: packet is copied from retransmission buffer and used in forwarding path. So, not making another + // copy here and just maintaining pointer to the packet as the forwarding path should not alter the packet. + for i := redLength - 1; i >= 0; i-- { + if r.pktBuff[i] == nil || // history is empty + pkt.SequenceNumber-r.pktBuff[i].SequenceNumber < (1<<15) { // received packet has more recent sequence number + // age out older ones + for j := 0; j < i; j++ { + r.pktBuff[j] = r.pktBuff[j+1] + } + r.pktBuff[i] = pkt + break + } + } + + return encodeRedForPrimary(redPkts, pkt, redPayload) +} + +func encodeRedForPrimary(redPkts []*rtp.Packet, primary *rtp.Packet, redPayload []byte) (int, error) { + payloadSize := len(primary.Payload) + 1 + for _, p := range redPkts { + payloadSize += len(p.Payload) + 4 + } + + // if required payload size is larger than the redPayload buffer, encode the primary packet only + if payloadSize > len(redPayload) { + redPkts = redPkts[:0] + } + + var index int + for _, p := range redPkts { + /* RED payload https://datatracker.ietf.org/doc/html/rfc2198#section-3 + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + |F| block PT | timestamp offset | block length | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + F: 1 bit First bit in header indicates whether another header block + follows. If 1 further header blocks follow, if 0 this is the + last header block. + */ + header := uint32(0x80 | uint8(opusPT)) + header <<= 14 + header |= (primary.Timestamp - p.Timestamp) & 0x3FFF + header <<= 10 + header |= uint32(len(p.Payload)) & 0x3FF + binary.BigEndian.PutUint32(redPayload[index:], header) + index += 4 + } + // last block header + redPayload[index] = uint8(opusPT) + index++ + + // append data blocks + redPkts = append(redPkts, primary) + for _, p := range redPkts { + if copy(redPayload[index:], p.Payload) < len(p.Payload) { + return 0, fmt.Errorf("red payload don't have enough space, needsize %d", len(p.Payload)) + } + index += len(p.Payload) + } + return index, nil +} diff --git a/livekit/pkg/sfu/redreceiver_test.go b/livekit/pkg/sfu/redreceiver_test.go new file mode 100644 index 0000000..bda7307 --- /dev/null +++ b/livekit/pkg/sfu/redreceiver_test.go @@ -0,0 +1,457 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sfu + +import ( + "testing" + + "github.com/pion/rtp" + "github.com/pion/webrtc/v4" + "github.com/stretchr/testify/require" + + "github.com/livekit/livekit-server/pkg/sfu/buffer" + "github.com/livekit/protocol/logger" +) + +const ( + tsStep = uint32(48000 / 1000 * 10) + opusREDPT = 63 +) + +type dummyDowntrack struct { + TrackSender + lastReceivedPkt *rtp.Packet + receivedPkts []*rtp.Packet +} + +func (dt *dummyDowntrack) WriteRTP(p *buffer.ExtPacket, _ int32) int32 { + dt.lastReceivedPkt = p.Packet + dt.receivedPkts = append(dt.receivedPkts, p.Packet) + return 1 +} + +func TestRedReceiver(t *testing.T) { + dt := &dummyDowntrack{TrackSender: &DownTrack{}} + + t.Run("normal", func(t *testing.T) { + w := &WebRTCReceiver{ + ReceiverBase: &ReceiverBase{ + params: ReceiverBaseParams{ + Kind: webrtc.RTPCodecTypeAudio, + Logger: logger.GetLogger(), + }, + isRED: true, + }, + } + require.Equal(t, w.GetRedReceiver(), w.ReceiverBase) + w.isRED = false + red := w.GetRedReceiver().(*RedReceiver) + require.NotNil(t, red) + require.NoError(t, red.AddDownTrack(dt)) + + header := rtp.Header{SequenceNumber: 65534, Timestamp: (uint32(1) << 31) - 2*tsStep, PayloadType: 111} + expectPkt := make([]*rtp.Packet, 0, maxRedCount+1) + for _, pkt := range generatePkts(header, 10, tsStep) { + expectPkt = append(expectPkt, pkt) + if len(expectPkt) > maxRedCount+1 { + expectPkt = expectPkt[1:] + } + red.ForwardRTP(&buffer.ExtPacket{ + Packet: pkt, + }, 0) + verifyRedEncodings(t, dt.lastReceivedPkt, expectPkt) + } + }) + + t.Run("packet lost and jump", func(t *testing.T) { + w := &WebRTCReceiver{ + ReceiverBase: &ReceiverBase{ + params: ReceiverBaseParams{ + Kind: webrtc.RTPCodecTypeAudio, + Logger: logger.GetLogger(), + }, + }, + } + red := w.GetRedReceiver().(*RedReceiver) + require.NoError(t, red.AddDownTrack(dt)) + + header := rtp.Header{SequenceNumber: 65534, Timestamp: (uint32(1) << 31) - 2*tsStep, PayloadType: 111} + expectPkt := make([]*rtp.Packet, 0, maxRedCount+1) + for i := range 10 { + if i%2 == 0 { + header.SequenceNumber++ + header.Timestamp += tsStep + expectPkt = append(expectPkt, nil) + continue + } + hbuf, _ := header.Marshal() + pkt1 := &rtp.Packet{ + Header: header, + Payload: hbuf, + } + expectPkt = append(expectPkt, pkt1) + if len(expectPkt) > maxRedCount+1 { + expectPkt = expectPkt[len(expectPkt)-maxRedCount-1:] + } + red.ForwardRTP(&buffer.ExtPacket{ + Packet: pkt1, + }, 0) + verifyRedEncodings(t, dt.lastReceivedPkt, expectPkt) + header.SequenceNumber++ + header.Timestamp += tsStep + } + + // jump + header.SequenceNumber += 10 + header.Timestamp += 10 * tsStep + + expectPkt = expectPkt[:0] + for _, pkt := range generatePkts(header, 3, tsStep) { + expectPkt = append(expectPkt, pkt) + if len(expectPkt) > maxRedCount+1 { + expectPkt = expectPkt[len(expectPkt)-maxRedCount-1:] + } + red.ForwardRTP(&buffer.ExtPacket{ + Packet: pkt, + }, 0) + verifyRedEncodings(t, dt.lastReceivedPkt, expectPkt) + } + }) + + t.Run("unorder and repeat", func(t *testing.T) { + w := &WebRTCReceiver{ + ReceiverBase: &ReceiverBase{ + params: ReceiverBaseParams{ + Kind: webrtc.RTPCodecTypeAudio, + Logger: logger.GetLogger(), + }, + }, + } + red := w.GetRedReceiver().(*RedReceiver) + require.NoError(t, red.AddDownTrack(dt)) + + header := rtp.Header{SequenceNumber: 65534, Timestamp: (uint32(1) << 31) - 2*tsStep, PayloadType: 111} + var prevPkts []*rtp.Packet + for _, pkt := range generatePkts(header, 10, tsStep) { + red.ForwardRTP(&buffer.ExtPacket{ + Packet: pkt, + }, 0) + prevPkts = append(prevPkts, pkt) + } + + // old unorder data don't have red records + expectPkt := prevPkts[len(prevPkts)-3 : len(prevPkts)-2] + red.ForwardRTP(&buffer.ExtPacket{ + Packet: expectPkt[0], + }, 0) + verifyRedEncodings(t, dt.lastReceivedPkt, expectPkt) + + // repeat packet only have 1 red records + expectPkt = prevPkts[len(prevPkts)-2:] + red.ForwardRTP(&buffer.ExtPacket{ + Packet: expectPkt[1], + }, 0) + verifyRedEncodings(t, dt.lastReceivedPkt, expectPkt) + }) + + t.Run("encoding exceed space", func(t *testing.T) { + w := &WebRTCReceiver{ + ReceiverBase: &ReceiverBase{ + params: ReceiverBaseParams{ + Kind: webrtc.RTPCodecTypeAudio, + Logger: logger.GetLogger(), + }, + isRED: true, + }, + } + require.Equal(t, w.GetRedReceiver(), w.ReceiverBase) + w.isRED = false + red := w.GetRedReceiver().(*RedReceiver) + require.NotNil(t, red) + require.NoError(t, red.AddDownTrack(dt)) + + header := rtp.Header{SequenceNumber: 65534, Timestamp: (uint32(1) << 31) - 2*tsStep, PayloadType: 111} + expectPkt := make([]*rtp.Packet, 0, maxRedCount+1) + for _, pkt := range generatePkts(header, 10, tsStep) { + // make sure red encodings don't have enough space to encoding redundant packet + pkt.Payload = make([]byte, 1000) + expectPkt = append(expectPkt[:0], pkt) + red.ForwardRTP(&buffer.ExtPacket{ + Packet: pkt, + }, 0) + verifyRedEncodings(t, dt.lastReceivedPkt, expectPkt) + } + }) + + t.Run("large timestamp gap", func(t *testing.T) { + w := &WebRTCReceiver{ + ReceiverBase: &ReceiverBase{ + params: ReceiverBaseParams{ + Kind: webrtc.RTPCodecTypeAudio, + Logger: logger.GetLogger(), + }, + isRED: true, + }, + } + require.Equal(t, w.GetRedReceiver(), w.ReceiverBase) + w.isRED = false + red := w.GetRedReceiver().(*RedReceiver) + require.NotNil(t, red) + require.NoError(t, red.AddDownTrack(dt)) + + header := rtp.Header{SequenceNumber: 65534, Timestamp: (uint32(1) << 31) - 2*tsStep, PayloadType: 111} + // first few packets normal + expectPkt := make([]*rtp.Packet, 0, maxRedCount+1) + for _, pkt := range generatePkts(header, 4, tsStep) { + expectPkt = append(expectPkt, pkt) + if len(expectPkt) > maxRedCount+1 { + expectPkt = expectPkt[1:] + } + red.ForwardRTP(&buffer.ExtPacket{ + Packet: pkt, + }, 0) + verifyRedEncodings(t, dt.lastReceivedPkt, expectPkt) + } + + // and then a few packets with a large timestamp jump, should contain only primary + for _, pkt := range generatePkts(header, 4, 40*tsStep) { + red.ForwardRTP(&buffer.ExtPacket{ + Packet: pkt, + }, 0) + verifyRedEncodings(t, dt.lastReceivedPkt, []*rtp.Packet{pkt}) + } + }) +} + +func verifyRedEncodings(t *testing.T, red *rtp.Packet, redPkts []*rtp.Packet) { + solidPkts := make([]*rtp.Packet, 0, len(redPkts)) + for _, pkt := range redPkts { + if pkt != nil { + solidPkts = append(solidPkts, pkt) + } + } + pktsFromRed, err := extractPktsFromRed(red, 0xFF) + require.NoError(t, err) + require.Len(t, pktsFromRed, len(solidPkts)) + for i, pkt := range pktsFromRed { + verifyEncodingEqual(t, pkt, solidPkts[i]) + } +} + +func verifyPktsEqual(t *testing.T, p1s, p2s []*rtp.Packet) { + require.Len(t, p1s, len(p2s)) + for i, pkt := range p1s { + verifyEncodingEqual(t, pkt, p2s[i]) + } +} + +func verifyEncodingEqual(t *testing.T, p1, p2 *rtp.Packet) { + require.Equal(t, p1.Header.Timestamp, p2.Header.Timestamp) + require.Equal(t, p1.PayloadType, p2.PayloadType) + require.EqualValues(t, p1.Payload, p2.Payload, "seq1 %s", p1.SequenceNumber) +} + +func generatePkts(header rtp.Header, count int, tsStep uint32) []*rtp.Packet { + pkts := make([]*rtp.Packet, 0, count) + for range count { + hbuf, _ := header.Marshal() + pkts = append(pkts, &rtp.Packet{ + Header: header, + Payload: hbuf, + }) + header.SequenceNumber++ + header.Timestamp += tsStep + } + return pkts +} + +func generateRedPkts(t *testing.T, pkts []*rtp.Packet, redCount int) []*rtp.Packet { + redPkts := make([]*rtp.Packet, 0, len(pkts)) + for i, pkt := range pkts { + encodingPkts := make([]*rtp.Packet, 0, redCount) + for j := i - redCount; j < i; j++ { + if j < 0 { + continue + } + encodingPkts = append(encodingPkts, pkts[j]) + } + buf := make([]byte, mtuSize) + redPkt := *pkt + redPkt.PayloadType = opusREDPT + encoded, err := encodeRedForPrimary(encodingPkts, pkt, buf) + require.NoError(t, err) + redPkt.Payload = buf[:encoded] + redPkts = append(redPkts, &redPkt) + } + return redPkts +} + +func testRedRedPrimaryReceiver(t *testing.T, maxPktCount, redCount int, sendPktIdx, expectPktIdx []int) { + dt := &dummyDowntrack{TrackSender: &DownTrack{}} + w := &WebRTCReceiver{ + ReceiverBase: &ReceiverBase{ + params: ReceiverBaseParams{ + Kind: webrtc.RTPCodecTypeAudio, + Logger: logger.GetLogger(), + Codec: webrtc.RTPCodecParameters{PayloadType: opusREDPT, RTPCodecCapability: webrtc.RTPCodecCapability{MimeType: "audio/red"}}, + }, + }, + } + require.Equal(t, w.GetPrimaryReceiverForRed(), w.ReceiverBase) + w.isRED = true + red := w.GetPrimaryReceiverForRed().(*RedPrimaryReceiver) + require.NotNil(t, red) + require.NoError(t, red.AddDownTrack(dt)) + + header := rtp.Header{SequenceNumber: 65530, Timestamp: (uint32(1) << 31) - 2*tsStep, PayloadType: 111} + primaryPkts := generatePkts(header, maxPktCount, tsStep) + redPkts := generateRedPkts(t, primaryPkts, redCount) + + for _, i := range sendPktIdx { + red.ForwardRTP(&buffer.ExtPacket{ + Packet: redPkts[i], + }, 0) + } + + expectPkts := make([]*rtp.Packet, 0, len(expectPktIdx)) + for _, i := range expectPktIdx { + expectPkts = append(expectPkts, primaryPkts[i]) + } + + verifyPktsEqual(t, expectPkts, dt.receivedPkts) +} + +func TestRedPrimaryReceiver(t *testing.T) { + w := &WebRTCReceiver{ + ReceiverBase: &ReceiverBase{ + params: ReceiverBaseParams{ + Kind: webrtc.RTPCodecTypeAudio, + Logger: logger.GetLogger(), + }, + }, + } + require.Equal(t, w.GetPrimaryReceiverForRed(), w.ReceiverBase) + w.isRED = true + red := w.GetPrimaryReceiverForRed().(*RedPrimaryReceiver) + require.NotNil(t, red) + + t.Run("packet should send only once", func(t *testing.T) { + maxPktCount := 19 + var sendPktIndex []int + for i := range maxPktCount { + sendPktIndex = append(sendPktIndex, i) + } + testRedRedPrimaryReceiver(t, maxPktCount, maxRedCount, sendPktIndex, sendPktIndex) + }) + + t.Run("packet duplicate and unorder", func(t *testing.T) { + maxPktCount := 19 + var sendPktIndex []int + for i := range maxPktCount { + sendPktIndex = append(sendPktIndex, i) + if i > 0 { + sendPktIndex = append(sendPktIndex, i-1) + } + sendPktIndex = append(sendPktIndex, i) + } + testRedRedPrimaryReceiver(t, maxPktCount, maxRedCount, sendPktIndex, sendPktIndex) + }) + + t.Run("full recover", func(t *testing.T) { + maxPktCount := 19 + var sendPktIndex, recvPktIndex []int + for i := range maxPktCount { + recvPktIndex = append(recvPktIndex, i) + + // drop packets covered by red encoding + if i%(maxRedCount+1) != 0 { + continue + } + sendPktIndex = append(sendPktIndex, i) + } + + testRedRedPrimaryReceiver(t, maxPktCount, maxRedCount, sendPktIndex, recvPktIndex) + }) + + t.Run("lost 2 but red recover 1", func(t *testing.T) { + maxPktCount := 19 + sendPktIndex := []int{0, 3, 6, 9, 12} + recvPktIndex := []int{0, 2, 3, 5, 6, 8, 9, 11, 12} + testRedRedPrimaryReceiver(t, maxPktCount, 1, sendPktIndex, recvPktIndex) + }) + + t.Run("part recover and long jump", func(t *testing.T) { + maxPktCount := 50 + sendPktIndex := []int{0, 5, 12, 21 /*long jump*/, 24, 27} + recvPktIndex := []int{0, 3, 4, 5, 10, 11, 12, 19, 20, 21, 22, 23, 24, 25, 26, 27} + testRedRedPrimaryReceiver(t, maxPktCount, maxRedCount, sendPktIndex, recvPktIndex) + }) + + t.Run("unorder", func(t *testing.T) { + maxPktCount := 50 + sendPktIndex := []int{20, 10 /*unorder can't recover*/, 25, 23, 34} + recvPktIndex := []int{20, 10, 23, 24, 25, 21, 22, 23, 32, 33, 34} + testRedRedPrimaryReceiver(t, maxPktCount, maxRedCount, sendPktIndex, recvPktIndex) + }) + + t.Run("mixed primary codec", func(t *testing.T) { + dt := &dummyDowntrack{TrackSender: &DownTrack{}} + w := &WebRTCReceiver{ + ReceiverBase: &ReceiverBase{ + params: ReceiverBaseParams{ + Kind: webrtc.RTPCodecTypeAudio, + Logger: logger.GetLogger(), + Codec: webrtc.RTPCodecParameters{PayloadType: opusREDPT, RTPCodecCapability: webrtc.RTPCodecCapability{MimeType: "audio/red"}}, + }, + }, + } + require.Equal(t, w.GetPrimaryReceiverForRed(), w.ReceiverBase) + w.isRED = true + red := w.GetPrimaryReceiverForRed().(*RedPrimaryReceiver) + require.NotNil(t, red) + require.NoError(t, red.AddDownTrack(dt)) + + primaryPkt := &rtp.Packet{ + Header: rtp.Header{SequenceNumber: 65530, Timestamp: (uint32(1) << 31) - 2*tsStep, PayloadType: 111}, + Payload: []byte{1, 3, 5, 7, 9}, + } + red.ForwardRTP(&buffer.ExtPacket{ + Packet: primaryPkt, + }, 0) + + verifyPktsEqual(t, []*rtp.Packet{primaryPkt}, dt.receivedPkts) + }) +} + +func TestExtractPrimaryEncodingForRED(t *testing.T) { + header := rtp.Header{SequenceNumber: 65530, Timestamp: (uint32(1) << 31) - 2*tsStep, PayloadType: 111} + pkts := generatePkts(header, 10, tsStep) + redPkts := generateRedPkts(t, pkts, 2) + + primaryPkts := make([]*rtp.Packet, 0, len(redPkts)) + + for _, redPkt := range redPkts { + payload, err := extractPrimaryEncodingForRED(redPkt.Payload) + require.NoError(t, err) + primaryHeader := redPkt.Header + primaryHeader.PayloadType = 111 + primaryPkts = append(primaryPkts, &rtp.Packet{ + Header: primaryHeader, + Payload: payload, + }) + } + + verifyPktsEqual(t, pkts, primaryPkts) +} diff --git a/livekit/pkg/sfu/rtpextension/abscapturetime/abscapturetime.go b/livekit/pkg/sfu/rtpextension/abscapturetime/abscapturetime.go new file mode 100644 index 0000000..1eb3f8d --- /dev/null +++ b/livekit/pkg/sfu/rtpextension/abscapturetime/abscapturetime.go @@ -0,0 +1,114 @@ +// Copyright 2024 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package abscapturetime + +import ( + "encoding/binary" + "errors" + "time" + + "github.com/livekit/mediatransportutil" +) + +const ( + AbsCaptureTimeURI = "http://www.webrtc.org/experiments/rtp-hdrext/abs-capture-time" +) + +var ( + errInvalidData = errors.New("invalid data") + errTooSmall = errors.New("buffer too small") +) + +// Reference: https://webrtc.googlesource.com/src/+/refs/heads/main/docs/native-code/rtp-hdrext/abs-capture-time/ +// +// Data layout of the shortened version of abs-capture-time with a 1-byte header + 8 bytes of data: +// +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | ID | len=7 | absolute capture timestamp (bit 0-23) | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | absolute capture timestamp (bit 24-55) | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | ... (56-63) | +// +-+-+-+-+-+-+-+-+ +// +//Data layout of the extended version of abs-capture-time with a 1-byte header + 16 bytes of data: +// +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | ID | len=15| absolute capture timestamp (bit 0-23) | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | absolute capture timestamp (bit 24-55) | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | ... (56-63) | estimated capture clock offset (bit 0-23) | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | estimated capture clock offset (bit 24-55) | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | ... (56-63) | +// +-+-+-+-+-+-+-+-+ + +type AbsCaptureTime struct { + absoluteCaptureTimestamp mediatransportutil.NtpTime + estimatedCaptureClockOffset int64 +} + +func AbsCaptureTimeFromValue(absoluteCaptureTimestamp uint64, estimatedCaptureClockOffset int64) *AbsCaptureTime { + return &AbsCaptureTime{ + absoluteCaptureTimestamp: mediatransportutil.NtpTime(absoluteCaptureTimestamp), + estimatedCaptureClockOffset: estimatedCaptureClockOffset, + } +} + +func (a *AbsCaptureTime) Rewrite(offset time.Duration) error { + if a.absoluteCaptureTimestamp == 0 { + return errInvalidData + } + + capturedAt := a.absoluteCaptureTimestamp.Time().Add(offset) + a.absoluteCaptureTimestamp = mediatransportutil.ToNtpTime(capturedAt) + a.estimatedCaptureClockOffset = 0 + return nil +} + +func (a *AbsCaptureTime) Marshal() ([]byte, error) { + if a.absoluteCaptureTimestamp == 0 { + return nil, errInvalidData + } + + size := 8 + if a.estimatedCaptureClockOffset != 0 { + size += 8 + } + marshalled := make([]byte, size) + binary.BigEndian.PutUint64(marshalled, uint64(a.absoluteCaptureTimestamp)) + if a.estimatedCaptureClockOffset != 0 { + binary.BigEndian.PutUint64(marshalled[8:], uint64(a.estimatedCaptureClockOffset)) + } + return marshalled, nil +} + +func (a *AbsCaptureTime) Unmarshal(marshalled []byte) error { + if len(marshalled) < 8 { + return errTooSmall + } + + a.absoluteCaptureTimestamp = mediatransportutil.NtpTime(binary.BigEndian.Uint64(marshalled)) + if len(marshalled) >= 16 { + a.estimatedCaptureClockOffset = int64(binary.BigEndian.Uint64(marshalled[8:])) + } + return nil +} diff --git a/livekit/pkg/sfu/rtpextension/dependencydescriptor/bitstreamreader.go b/livekit/pkg/sfu/rtpextension/dependencydescriptor/bitstreamreader.go new file mode 100644 index 0000000..62f4d10 --- /dev/null +++ b/livekit/pkg/sfu/rtpextension/dependencydescriptor/bitstreamreader.go @@ -0,0 +1,138 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dependencydescriptor + +import ( + "errors" + "io" +) + +type BitStreamReader struct { + buf []byte + pos int + remainingBits int +} + +func NewBitStreamReader(buf []byte) *BitStreamReader { + return &BitStreamReader{buf: buf, remainingBits: len(buf) * 8} +} + +func (b *BitStreamReader) RemainingBits() int { + return b.remainingBits +} + +// Reads `bits` from the bitstream. `bits` must be in range [0, 64]. +// Returns an unsigned integer in range [0, 2^bits - 1]. +// On failure sets `BitstreamReader` into the failure state and returns 0. +func (b *BitStreamReader) ReadBits(bits int) (uint64, error) { + if bits < 0 || bits > 64 { + return 0, errors.New("invalid number of bits, expected 0-64") + } + + if b.remainingBits < bits { + b.remainingBits -= bits + return 0, io.EOF + } + + remainingBitsInFirstByte := b.remainingBits % 8 + b.remainingBits -= bits + if bits < remainingBitsInFirstByte { + // Reading fewer bits than what's left in the current byte, just + // return the portion of this byte that is needed. + offset := remainingBitsInFirstByte - bits + return uint64((b.buf[b.pos] >> offset) & ((1 << bits) - 1)), nil + } + var result uint64 + if remainingBitsInFirstByte > 0 { + // Read all bits that were left in the current byte and consume that byte. + bits -= remainingBitsInFirstByte + mask := byte((1 << remainingBitsInFirstByte) - 1) + result = uint64(b.buf[b.pos]&mask) << bits + b.pos++ + } + // Read as many full bytes as we can. + for bits >= 8 { + bits -= 8 + result |= uint64(b.buf[b.pos]) << bits + b.pos++ + } + + // Whatever is left to read is smaller than a byte, so grab just the needed + // bits and shift them into the lowest bits. + if bits > 0 { + result |= uint64(b.buf[b.pos] >> (8 - bits)) + } + return result, nil +} + +func (b *BitStreamReader) ReadBool() (bool, error) { + val, err := b.ReadBits(1) + return val != 0, err +} + +func (b *BitStreamReader) Ok() bool { + return b.remainingBits >= 0 +} + +func (b *BitStreamReader) Invalidate() { + b.remainingBits = -1 +} + +// Reads value in range [0, `num_values` - 1]. +// This encoding is similar to ReadBits(val, Ceil(Log2(num_values)), +// but reduces wastage incurred when encoding non-power of two value ranges +// Non symmetric values are encoded as: +// 1) n = bit_width(num_values) +// 2) k = (1 << n) - num_values +// Value v in range [0, k - 1] is encoded in (n-1) bits. +// Value v in range [k, num_values - 1] is encoded as (v+k) in n bits. +// https://aomediacodec.github.io/av1-spec/#nsn +func (b *BitStreamReader) ReadNonSymmetric(numValues uint32) (uint32, error) { + if numValues >= (uint32(1) << 31) { + return 0, errors.New("invalid number of values, expected 0-2^31") + } + + width := bitwidth(numValues) + numMinBitsValues := (uint32(1) << width) - numValues + + val, err := b.ReadBits(width - 1) + if err != nil { + return 0, err + } + if val < uint64(numMinBitsValues) { + return uint32(val), nil + } + bit, err := b.ReadBits(1) + if err != nil { + return 0, err + } + return uint32((val << 1) + bit - uint64(numMinBitsValues)), nil +} + +func (b *BitStreamReader) BytesRead() int { + if b.remainingBits%8 > 0 { + return b.pos + 1 + } + return b.pos +} + +func bitwidth(n uint32) int { + var w int + for n != 0 { + n >>= 1 + w++ + } + return w +} diff --git a/livekit/pkg/sfu/rtpextension/dependencydescriptor/bitstreamwriter.go b/livekit/pkg/sfu/rtpextension/dependencydescriptor/bitstreamwriter.go new file mode 100644 index 0000000..1e5ffac --- /dev/null +++ b/livekit/pkg/sfu/rtpextension/dependencydescriptor/bitstreamwriter.go @@ -0,0 +1,132 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dependencydescriptor + +import ( + "errors" + "fmt" +) + +type BitStreamWriter struct { + buf []byte + pos int + bitOffset int // bit offset in the current byte +} + +func NewBitStreamWriter(buf []byte) *BitStreamWriter { + return &BitStreamWriter{buf: buf} +} + +func (w *BitStreamWriter) RemainingBits() int { + return (len(w.buf)-w.pos)*8 - w.bitOffset +} + +func (w *BitStreamWriter) WriteBits(val uint64, bitCount int) error { + if bitCount > w.RemainingBits() { + return errors.New("insufficient space") + } + + totalBits := bitCount + + // push bits to the highest bits of uint64 + val <<= 64 - bitCount + + buf := w.buf[w.pos:] + + // The first byte is relatively special; the bit offset to write to may put us + // in the middle of the byte, and the total bit count to write may require we + // save the bits at the end of the byte. + remainingBitsInCurrentByte := 8 - w.bitOffset + bitsInFirstByte := bitCount + if bitsInFirstByte > remainingBitsInCurrentByte { + bitsInFirstByte = remainingBitsInCurrentByte + } + + buf[0] = w.writePartialByte(uint8(val>>56), bitsInFirstByte, buf[0], w.bitOffset) + + if bitCount <= remainingBitsInCurrentByte { + // no bit left to write + return w.consumeBits(totalBits) + } + + // write the rest of the bits + val <<= bitsInFirstByte + buf = buf[1:] + bitCount -= bitsInFirstByte + for bitCount >= 8 { + buf[0] = uint8(val >> 56) + buf = buf[1:] + val <<= 8 + bitCount -= 8 + } + + // write the last bits + if bitCount > 0 { + buf[0] = w.writePartialByte(uint8(val>>56), bitCount, buf[0], 0) + } + return w.consumeBits(totalBits) +} + +func (w *BitStreamWriter) consumeBits(bitCount int) error { + if bitCount > w.RemainingBits() { + return errors.New("insufficient space") + } + + w.pos += (w.bitOffset + bitCount) / 8 + w.bitOffset = (w.bitOffset + bitCount) % 8 + + return nil +} + +func (w *BitStreamWriter) writePartialByte(source uint8, sourceBitCount int, target uint8, targetBitOffset int) uint8 { + // if !(targetBitOffset < 8 && sourceBitCount <= (8-targetBitOffset)) { + // return fmt.Errorf("invalid argument, source %d, sourceBitCount %d, target %d, targetBitOffset %d", source, sourceBitCount, target, targetBitOffset) + // } + + // generate mask for bits to overwrite, shift source bits to highest bits, then position to target bit offset + mask := uint8(0xff<<(8-sourceBitCount)) >> uint8(targetBitOffset) + + // clear target bits and write source bits + return (target & ^mask) | (source >> targetBitOffset) +} + +func (w *BitStreamWriter) WriteNonSymmetric(val, numValues uint32) error { + if !(val < numValues && numValues <= 1<<31) { + return fmt.Errorf("invalid argument, val %d, numValues %d", val, numValues) + } + if numValues == 1 { + // When there is only one possible value, it requires zero bits to store it. + // But WriteBits doesn't support writing zero bits. + return nil + } + + countBits := bitwidth(numValues) + numMinBitsValues := (uint32(1) << countBits) - numValues + if val < numMinBitsValues { + return w.WriteBits(uint64(val), countBits-1) + } else { + return w.WriteBits(uint64(val+numMinBitsValues), countBits) + } +} + +func SizeNonSymmetricBits(val, numValues uint32) int { + countBits := bitwidth(numValues) + numMinBitsValues := (uint32(1) << countBits) - numValues + if val < numMinBitsValues { + return countBits - 1 + } else { + return countBits + } +} diff --git a/livekit/pkg/sfu/rtpextension/dependencydescriptor/dependencydescriptorextension.go b/livekit/pkg/sfu/rtpextension/dependencydescriptor/dependencydescriptorextension.go new file mode 100644 index 0000000..2934823 --- /dev/null +++ b/livekit/pkg/sfu/rtpextension/dependencydescriptor/dependencydescriptorextension.go @@ -0,0 +1,199 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dependencydescriptor + +import ( + "fmt" + "math" + "strconv" +) + +// DependencyDescriptorExtension is a extension payload format in +// https://aomediacodec.github.io/av1-rtp-spec/#dependency-descriptor-rtp-header-extension + +func formatBitmask(b *uint32) string { + if b == nil { + return "-" + } + return strconv.FormatInt(int64(*b), 2) +} + +// ------------------------------------------------------------------------------ + +type DependencyDescriptorExtension struct { + Descriptor *DependencyDescriptor + Structure *FrameDependencyStructure +} + +func (d *DependencyDescriptorExtension) Marshal() ([]byte, error) { + return d.MarshalWithActiveChains(^uint32(0)) +} + +func (d *DependencyDescriptorExtension) MarshalWithActiveChains(activeChains uint32) ([]byte, error) { + writer, err := NewDependencyDescriptorWriter(nil, d.Structure, activeChains, d.Descriptor) + if err != nil { + return nil, err + } + buf := make([]byte, int(math.Ceil(float64(writer.ValueSizeBits())/8))) + writer.ResetBuf(buf) + if err = writer.Write(); err != nil { + return nil, err + } + return buf, nil +} + +func (d *DependencyDescriptorExtension) Unmarshal(buf []byte) (int, error) { + reader := NewDependencyDescriptorReader(buf, d.Structure, d.Descriptor) + return reader.Parse() +} + +// ------------------------------------------------------------------------------ + +const ( + MaxSpatialIds = 4 + MaxTemporalIds = 8 + MaxDecodeTargets = 32 + MaxTemplates = 64 + + AllChainsAreActive = uint32(0) + + ExtensionURI = "https://aomediacodec.github.io/av1-rtp-spec/#dependency-descriptor-rtp-header-extension" +) + +// ------------------------------------------------------------------------------ + +type DependencyDescriptor struct { + FirstPacketInFrame bool + LastPacketInFrame bool + FrameNumber uint16 + FrameDependencies *FrameDependencyTemplate + Resolution *RenderResolution + ActiveDecodeTargetsBitmask *uint32 + AttachedStructure *FrameDependencyStructure +} + +func (d *DependencyDescriptor) MarshalSize() (int, error) { + return d.MarshalSizeWithActiveChains(^uint32(0)) +} + +func (d *DependencyDescriptor) MarshalSizeWithActiveChains(activeChains uint32) (int, error) { + writer, err := NewDependencyDescriptorWriter(nil, d.AttachedStructure, activeChains, d) + if err != nil { + return 0, err + } + return int(math.Ceil(float64(writer.ValueSizeBits()) / 8)), nil +} + +func (d *DependencyDescriptor) String() string { + resolution, dependencies := "-", "-" + if d.Resolution != nil { + resolution = fmt.Sprintf("%+v", *d.Resolution) + } + if d.FrameDependencies != nil { + dependencies = fmt.Sprintf("%+v", *d.FrameDependencies) + } + return fmt.Sprintf("DependencyDescriptor{FirstPacketInFrame: %v, LastPacketInFrame: %v, FrameNumber: %v, FrameDependencies: %s, Resolution: %s, ActiveDecodeTargetsBitmask: %v, AttachedStructure: %v}", + d.FirstPacketInFrame, d.LastPacketInFrame, d.FrameNumber, dependencies, resolution, formatBitmask(d.ActiveDecodeTargetsBitmask), d.AttachedStructure) +} + +// ------------------------------------------------------------------------------ + +// Relationship of a frame to a Decode target. +type DecodeTargetIndication int + +const ( + DecodeTargetNotPresent DecodeTargetIndication = iota // DecodeTargetInfo symbol '-' + DecodeTargetDiscardable // DecodeTargetInfo symbol 'D' + DecodeTargetSwitch // DecodeTargetInfo symbol 'S' + DecodeTargetRequired // DecodeTargetInfo symbol 'R' +) + +func (i DecodeTargetIndication) String() string { + switch i { + case DecodeTargetNotPresent: + return "-" + case DecodeTargetDiscardable: + return "D" + case DecodeTargetSwitch: + return "S" + case DecodeTargetRequired: + return "R" + default: + return "Unknown" + } +} + +// ------------------------------------------------------------------------------ + +type FrameDependencyTemplate struct { + SpatialId int + TemporalId int + DecodeTargetIndications []DecodeTargetIndication + FrameDiffs []int + ChainDiffs []int +} + +func (t *FrameDependencyTemplate) Clone() *FrameDependencyTemplate { + t2 := &FrameDependencyTemplate{ + SpatialId: t.SpatialId, + TemporalId: t.TemporalId, + } + + t2.DecodeTargetIndications = make([]DecodeTargetIndication, len(t.DecodeTargetIndications)) + copy(t2.DecodeTargetIndications, t.DecodeTargetIndications) + + t2.FrameDiffs = make([]int, len(t.FrameDiffs)) + copy(t2.FrameDiffs, t.FrameDiffs) + + t2.ChainDiffs = make([]int, len(t.ChainDiffs)) + copy(t2.ChainDiffs, t.ChainDiffs) + + return t2 +} + +// ------------------------------------------------------------------------------ + +type FrameDependencyStructure struct { + StructureId int + NumDecodeTargets int + NumChains int + // If chains are used (num_chains > 0), maps decode target index into index of + // the chain protecting that target. + DecodeTargetProtectedByChain []int + Resolutions []RenderResolution + Templates []*FrameDependencyTemplate +} + +func (f *FrameDependencyStructure) String() string { + str := fmt.Sprintf("FrameDependencyStructure{StructureId: %v, NumDecodeTargets: %v, NumChains: %v, DecodeTargetProtectedByChain: %v, Resolutions: %+v, Templates: [", + f.StructureId, f.NumDecodeTargets, f.NumChains, f.DecodeTargetProtectedByChain, f.Resolutions) + + // templates + for _, t := range f.Templates { + str += fmt.Sprintf("%+v, ", t) + } + str += "]}" + + return str +} + +// ------------------------------------------------------------------------------ + +type RenderResolution struct { + Width int + Height int +} + +// ------------------------------------------------------------------------------ diff --git a/livekit/pkg/sfu/rtpextension/dependencydescriptor/dependencydescriptorextension_test.go b/livekit/pkg/sfu/rtpextension/dependencydescriptor/dependencydescriptorextension_test.go new file mode 100644 index 0000000..95c3aba --- /dev/null +++ b/livekit/pkg/sfu/rtpextension/dependencydescriptor/dependencydescriptorextension_test.go @@ -0,0 +1,65 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dependencydescriptor + +import ( + "encoding/hex" + "testing" +) + +func TestDependencyDescriptorUnmarshal(t *testing.T) { + + // hex bytes from traffic capture + hexes := []string{ + "c1017280081485214eafffaaaa863cf0430c10c302afc0aaa0063c00430010c002a000a80006000040001d954926e082b04a0941b820ac1282503157f974000ca864330e222222eca8655304224230eca877530077004200ef008601df010d", + "86017340fc", + "46017340fc", + "c3017540fc", + "88017640fc", + "48017640fc", + "c2017840fc", + // + "c1017280081485214eafffaaaa863cf0430c10c302afc0aaa0063c00430010c002a000a80006000040001d954926e082b04a0941b820ac1282503157f974000ca864330e222222eca8655304224230eca877530077004200ef008601df010d", + "860173", + "460173", + "8b0174", + "0b0174", + "0b0174", + "c30175", + } + + var structure *FrameDependencyStructure + + for _, h := range hexes { + buf, err := hex.DecodeString(h) + if err != nil { + t.Fatal(err) + } + + var ddVal DependencyDescriptor + var d = DependencyDescriptorExtension{ + Structure: structure, + Descriptor: &ddVal, + } + if _, err := d.Unmarshal(buf); err != nil { + t.Fatal(err) + } + if ddVal.AttachedStructure != nil { + structure = ddVal.AttachedStructure + } + + t.Log(ddVal.String()) + } +} diff --git a/livekit/pkg/sfu/rtpextension/dependencydescriptor/dependencydescriptorreader.go b/livekit/pkg/sfu/rtpextension/dependencydescriptor/dependencydescriptorreader.go new file mode 100644 index 0000000..2ca21ff --- /dev/null +++ b/livekit/pkg/sfu/rtpextension/dependencydescriptor/dependencydescriptorreader.go @@ -0,0 +1,446 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dependencydescriptor + +import ( + "errors" +) + +var ( + ErrDDReaderNoStructure = errors.New("DependencyDescriptorReader: Structure is nil") + ErrDDReaderTemplateWithoutStructure = errors.New("DependencyDescriptorReader: has templateDependencyStructurePresentFlag but AttachedStructure is nil") + ErrDDReaderTooManyTemplates = errors.New("DependencyDescriptorReader: too many templates") + ErrDDReaderTooManyTemporalLayers = errors.New("DependencyDescriptorReader: too many temporal layers") + ErrDDReaderTooManySpatialLayers = errors.New("DependencyDescriptorReader: too many spatial layers") + ErrDDReaderInvalidTemplateIndex = errors.New("DependencyDescriptorReader: invalid template index") + ErrDDReaderInvalidSpatialLayer = errors.New("DependencyDescriptorReader: invalid spatial layer, should be less than the number of resolutions") + ErrDDReaderNumDTIMismatch = errors.New("DependencyDescriptorReader: decode target indications length mismatch with structure num decode targets") + ErrDDReaderNumChainDiffsMismatch = errors.New("DependencyDescriptorReader: chain diffs length mismatch with structure num chains") +) + +type DependencyDescriptorReader struct { + // Output. + descriptor *DependencyDescriptor + + // Values that are needed while reading the descriptor, but can be discarded + // when reading is complete. + buffer *BitStreamReader + frameDependencyTemplateId int + activeDecodeTargetsPresentFlag bool + customDtisFlag bool + customFdiffsFlag bool + customChainsFlag bool + structure *FrameDependencyStructure +} + +func NewDependencyDescriptorReader(buf []byte, structure *FrameDependencyStructure, descriptor *DependencyDescriptor) *DependencyDescriptorReader { + buffer := NewBitStreamReader(buf) + return &DependencyDescriptorReader{ + buffer: buffer, + descriptor: descriptor, + structure: structure, + } +} + +func (r *DependencyDescriptorReader) Parse() (int, error) { + if err := r.readMandatoryFields(); err != nil { + return 0, err + } + if len(r.buffer.buf) > 3 { + err := r.readExtendedFields() + if err != nil { + return 0, err + } + } + + if r.descriptor.AttachedStructure != nil { + r.structure = r.descriptor.AttachedStructure + } + + if r.structure == nil { + r.buffer.Invalidate() + return 0, ErrDDReaderNoStructure + } + + if r.activeDecodeTargetsPresentFlag { + bitmask, err := r.buffer.ReadBits(r.structure.NumDecodeTargets) + if err != nil { + return 0, err + } + mask := uint32(bitmask) + r.descriptor.ActiveDecodeTargetsBitmask = &mask + } + + err := r.readFrameDependencyDefinition() + if err != nil { + return 0, err + } + return r.buffer.BytesRead(), nil +} + +func (r *DependencyDescriptorReader) readMandatoryFields() error { + var err error + r.descriptor.FirstPacketInFrame, err = r.buffer.ReadBool() + if err != nil { + return err + } + + r.descriptor.LastPacketInFrame, err = r.buffer.ReadBool() + if err != nil { + return err + } + + templateID, err := r.buffer.ReadBits(6) + if err != nil { + return err + } + r.frameDependencyTemplateId = int(templateID) + + frameNumber, err := r.buffer.ReadBits(16) + if err != nil { + return err + } + r.descriptor.FrameNumber = uint16(frameNumber) + return nil +} + +func (r *DependencyDescriptorReader) readExtendedFields() error { + templateDependencyStructurePresentFlag, err := r.buffer.ReadBool() + if err != nil { + return err + } + + flag, err := r.buffer.ReadBool() + if err != nil { + return err + } + r.activeDecodeTargetsPresentFlag = flag + + flag, err = r.buffer.ReadBool() + if err != nil { + return err + } + r.customDtisFlag = flag + + flag, err = r.buffer.ReadBool() + if err != nil { + return err + } + r.customFdiffsFlag = flag + + flag, err = r.buffer.ReadBool() + if err != nil { + return err + } + r.customChainsFlag = flag + + if templateDependencyStructurePresentFlag { + err = r.readTemplateDependencyStructure() + if err != nil { + return err + } + if r.descriptor.AttachedStructure == nil { + return ErrDDReaderTemplateWithoutStructure + } + bitmask := uint32((uint64(1) << r.descriptor.AttachedStructure.NumDecodeTargets) - 1) + r.descriptor.ActiveDecodeTargetsBitmask = &bitmask + } + return nil +} + +func (r *DependencyDescriptorReader) readTemplateDependencyStructure() error { + r.descriptor.AttachedStructure = &FrameDependencyStructure{} + structureId, err := r.buffer.ReadBits(6) + if err != nil { + return err + } + r.descriptor.AttachedStructure.StructureId = int(structureId) + + numDecodeTargets, err := r.buffer.ReadBits(5) + if err != nil { + return err + } + r.descriptor.AttachedStructure.NumDecodeTargets = int(numDecodeTargets) + 1 + + if err = r.readTemplateLayers(); err != nil { + return err + } + if err = r.readTemplateDtis(); err != nil { + return err + } + if err = r.readTemplateFdiffs(); err != nil { + return err + } + if err = r.readTemplateChains(); err != nil { + return err + } + + flag, err := r.buffer.ReadBool() + if err != nil { + return err + } + if flag { + return r.readResolutions() + } + return nil +} + +type nextLayerIdcType int + +const ( + sameLayer nextLayerIdcType = iota + nextTemporalLayer + nextSpatialLayer + noMoreLayer + invalidLayer +) + +func (r *DependencyDescriptorReader) readTemplateLayers() error { + var ( + templates []*FrameDependencyTemplate + temporalId, spatialId int + nextLayerIdc nextLayerIdcType + ) + for { + if len(templates) == MaxTemplates { + return ErrDDReaderTooManyTemplates + } + + var lastTemplate FrameDependencyTemplate + templates = append(templates, &lastTemplate) + lastTemplate.TemporalId = temporalId + lastTemplate.SpatialId = spatialId + + idc, err := r.buffer.ReadBits(2) + if err != nil { + return err + } + nextLayerIdc = nextLayerIdcType(idc) + + if nextLayerIdc == nextTemporalLayer { + temporalId++ + if temporalId >= MaxTemporalIds { + return ErrDDReaderTooManyTemporalLayers + } + } else if nextLayerIdc == nextSpatialLayer { + spatialId++ + temporalId = 0 + if spatialId >= MaxSpatialIds { + return ErrDDReaderTooManySpatialLayers + } + } + + if !(nextLayerIdc != noMoreLayer && r.buffer.Ok()) { + break + } + } + + r.descriptor.AttachedStructure.Templates = templates + return nil +} + +func (r *DependencyDescriptorReader) readTemplateDtis() error { + structure := r.descriptor.AttachedStructure + for _, template := range structure.Templates { + if len(template.DecodeTargetIndications) < structure.NumDecodeTargets { + template.DecodeTargetIndications = append(template.DecodeTargetIndications, make([]DecodeTargetIndication, structure.NumDecodeTargets-len(template.DecodeTargetIndications))...) + } else { + template.DecodeTargetIndications = template.DecodeTargetIndications[:structure.NumDecodeTargets] + } + + for i := range template.DecodeTargetIndications { + indication, err := r.buffer.ReadBits(2) + if err != nil { + return err + } + template.DecodeTargetIndications[i] = DecodeTargetIndication(indication) + } + } + return nil +} + +func (r *DependencyDescriptorReader) readTemplateFdiffs() error { + for _, template := range r.descriptor.AttachedStructure.Templates { + for { + fdiffFollow, err := r.buffer.ReadBool() + if err != nil { + return err + } + if !fdiffFollow { + break + } + fDiffMinusOne, err := r.buffer.ReadBits(4) + if err != nil { + return err + } + template.FrameDiffs = append(template.FrameDiffs, int(fDiffMinusOne+1)) + } + } + + return nil +} + +func (r *DependencyDescriptorReader) readTemplateChains() error { + structure := r.descriptor.AttachedStructure + + numChains, err := r.buffer.ReadNonSymmetric(uint32(structure.NumDecodeTargets) + 1) + if err != nil { + return err + } + structure.NumChains = int(numChains) + if structure.NumChains == 0 { + return nil + } + + for i := 0; i < structure.NumDecodeTargets; i++ { + protectedByChain, err := r.buffer.ReadNonSymmetric(uint32(structure.NumChains)) + if err != nil { + return err + } + structure.DecodeTargetProtectedByChain = append(structure.DecodeTargetProtectedByChain, int(protectedByChain)) + } + + for _, frameTemplate := range structure.Templates { + for chainId := 0; chainId < structure.NumChains; chainId++ { + chainDiff, err := r.buffer.ReadBits(4) + if err != nil { + return err + } + frameTemplate.ChainDiffs = append(frameTemplate.ChainDiffs, int(chainDiff)) + } + } + + return nil +} + +func (r *DependencyDescriptorReader) readResolutions() error { + structure := r.descriptor.AttachedStructure + // The way templates are bitpacked, they are always ordered by spatial_id. + spatialLayers := structure.Templates[len(structure.Templates)-1].SpatialId + 1 + for sid := 0; sid < spatialLayers; sid++ { + widthMinus1, err := r.buffer.ReadBits(16) + if err != nil { + return err + } + heightMinus1, err := r.buffer.ReadBits(16) + if err != nil { + return err + } + structure.Resolutions = append(structure.Resolutions, RenderResolution{ + Width: int(widthMinus1 + 1), + Height: int(heightMinus1 + 1), + }) + } + + return nil +} + +func (r *DependencyDescriptorReader) readFrameDependencyDefinition() error { + templateIndex := (r.frameDependencyTemplateId + MaxTemplates - r.structure.StructureId) % MaxTemplates + + if templateIndex >= len(r.structure.Templates) { + r.buffer.Invalidate() + return ErrDDReaderInvalidTemplateIndex + } + + // Copy all the fields from the matching template + r.descriptor.FrameDependencies = r.structure.Templates[templateIndex].Clone() + + if r.customDtisFlag { + err := r.readFrameDtis() + if err != nil { + return err + } + } + + if r.customFdiffsFlag { + err := r.readFrameFdiffs() + if err != nil { + return err + } + } + + if r.customChainsFlag { + err := r.readFrameChains() + if err != nil { + return err + } + } + + if len(r.structure.Resolutions) == 0 { + r.descriptor.Resolution = nil + } else { + // Format guarantees that if there were resolutions in the last structure, + // then each spatial layer got one. + if r.descriptor.FrameDependencies.SpatialId >= len(r.structure.Resolutions) { + r.buffer.Invalidate() + return ErrDDReaderInvalidSpatialLayer + } + res := r.structure.Resolutions[r.descriptor.FrameDependencies.SpatialId] + r.descriptor.Resolution = &res + } + + return nil +} + +func (r *DependencyDescriptorReader) readFrameDtis() error { + if len(r.descriptor.FrameDependencies.DecodeTargetIndications) != r.structure.NumDecodeTargets { + return ErrDDReaderNumDTIMismatch + } + + for i := range r.descriptor.FrameDependencies.DecodeTargetIndications { + indication, err := r.buffer.ReadBits(2) + if err != nil { + return err + } + r.descriptor.FrameDependencies.DecodeTargetIndications[i] = DecodeTargetIndication(indication) + } + return nil +} + +func (r *DependencyDescriptorReader) readFrameFdiffs() error { + r.descriptor.FrameDependencies.FrameDiffs = r.descriptor.FrameDependencies.FrameDiffs[:0] + for { + nexFdiffSize, err := r.buffer.ReadBits(2) + if err != nil { + return err + } + if nexFdiffSize == 0 { + break + } + fDiffMinusOne, err := r.buffer.ReadBits(int(nexFdiffSize * 4)) + if err != nil { + return err + } + r.descriptor.FrameDependencies.FrameDiffs = append(r.descriptor.FrameDependencies.FrameDiffs, int(fDiffMinusOne+1)) + } + + return nil +} + +func (r *DependencyDescriptorReader) readFrameChains() error { + if len(r.descriptor.FrameDependencies.ChainDiffs) != r.structure.NumChains { + return ErrDDReaderNumChainDiffsMismatch + } + + for i := range r.descriptor.FrameDependencies.ChainDiffs { + chainDiff, err := r.buffer.ReadBits(8) + if err != nil { + return err + } + r.descriptor.FrameDependencies.ChainDiffs[i] = int(chainDiff) + } + return nil +} diff --git a/livekit/pkg/sfu/rtpextension/dependencydescriptor/dependencydescriptorwriter.go b/livekit/pkg/sfu/rtpextension/dependencydescriptor/dependencydescriptorwriter.go new file mode 100644 index 0000000..d392c90 --- /dev/null +++ b/livekit/pkg/sfu/rtpextension/dependencydescriptor/dependencydescriptorwriter.go @@ -0,0 +1,501 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dependencydescriptor + +import ( + "fmt" + + "golang.org/x/exp/slices" +) + +type TemplateMatch struct { + TemplateIdx int + NeedCustomDtis bool + NeedCustomFdiffs bool + NeedCustomChains bool + // Size in bits to store frame-specific details, i.e. + // excluding mandatory fields and template dependency structure. + ExtraSizeBits int +} + +type DependencyDescriptorWriter struct { + descriptor *DependencyDescriptor + structure *FrameDependencyStructure + activeChains uint32 + writer *BitStreamWriter + bestTemplate TemplateMatch +} + +func NewDependencyDescriptorWriter(buf []byte, structure *FrameDependencyStructure, activeChains uint32, descriptor *DependencyDescriptor) (*DependencyDescriptorWriter, error) { + writer := NewBitStreamWriter(buf) + w := &DependencyDescriptorWriter{ + descriptor: descriptor, + structure: structure, + activeChains: activeChains, + writer: writer, + } + return w, w.findBestTemplate() +} + +func (w *DependencyDescriptorWriter) ResetBuf(buf []byte) { + w.writer = NewBitStreamWriter(buf) +} + +func (w *DependencyDescriptorWriter) Write() error { + if err := w.findBestTemplate(); err != nil { + return err + } + + if err := w.writeMandatoryFields(); err != nil { + return err + } + + if w.hasExtendedFields() { + if err := w.writeExtendedFields(); err != nil { + return err + } + + if err := w.writeFrameDependencyDefinition(); err != nil { + return err + } + } + + remainingBits := w.writer.RemainingBits() + // Zero remaining memory to avoid leaving it uninitialized. + if remainingBits%64 != 0 { + if err := w.writeBits(0, remainingBits%64); err != nil { + return err + } + } + + for i := 0; i < remainingBits/64; i++ { + if err := w.writeBits(0, 64); err != nil { + return err + } + } + + return nil +} + +func (w *DependencyDescriptorWriter) findBestTemplate() error { + // Find templates with same spatial and temporal layer of frame dependency. + var ( + firstSameLayer *FrameDependencyTemplate + firstSameLayerIdx, lastSameLayerIdx int + ) + for i, t := range w.structure.Templates { + if w.descriptor.FrameDependencies.SpatialId == t.SpatialId && + w.descriptor.FrameDependencies.TemporalId == t.TemporalId { + firstSameLayer = t + firstSameLayerIdx = i + break + } + } + + if firstSameLayer == nil { + return fmt.Errorf("no template found for spatial layer %d and temporal layer %d", w.descriptor.FrameDependencies.SpatialId, w.descriptor.FrameDependencies.TemporalId) + } + + for i, t := range w.structure.Templates[firstSameLayerIdx:] { + if w.descriptor.FrameDependencies.SpatialId != t.SpatialId || + w.descriptor.FrameDependencies.TemporalId != t.TemporalId { + lastSameLayerIdx = i + firstSameLayerIdx + } + } + + // Search if there any better template that have small extra size. + w.bestTemplate = w.calculateMatch(firstSameLayerIdx, firstSameLayer) + for i := firstSameLayerIdx + 1; i <= lastSameLayerIdx; i++ { + t := w.structure.Templates[i] + match := w.calculateMatch(i, t) + if match.ExtraSizeBits < w.bestTemplate.ExtraSizeBits { + w.bestTemplate = match + } + } + return nil +} + +func (w *DependencyDescriptorWriter) calculateMatch(idx int, template *FrameDependencyTemplate) TemplateMatch { + var result TemplateMatch + result.TemplateIdx = idx + result.NeedCustomFdiffs = w.descriptor.FrameDependencies.FrameDiffs != nil && !slices.Equal(w.descriptor.FrameDependencies.FrameDiffs, template.FrameDiffs) + result.NeedCustomDtis = w.descriptor.FrameDependencies.DecodeTargetIndications != nil && !slices.Equal(w.descriptor.FrameDependencies.DecodeTargetIndications, template.DecodeTargetIndications) + + for i := 0; i < w.structure.NumChains; i++ { + if w.activeChains&(1< 0 || w.descriptor.AttachedStructure != nil || w.descriptor.ActiveDecodeTargetsBitmask != nil +} + +func (w *DependencyDescriptorWriter) writeExtendedFields() error { + // template_dependency_structure_present_flag + if err := w.writeBool(w.descriptor.AttachedStructure != nil); err != nil { + return err + } + + // active_decode_targets_present_flag + activeDecodeTargetsPresentFlag := w.shouldWriteActiveDecodeTargetsBitmask() + if err := w.writeBool(activeDecodeTargetsPresentFlag); err != nil { + return err + } + + // need_custom_dtis + if err := w.writeBool(w.bestTemplate.NeedCustomDtis); err != nil { + return err + } + + // need_custom_fdiffs + if err := w.writeBool(w.bestTemplate.NeedCustomFdiffs); err != nil { + return err + } + + // need_custom_chains + if err := w.writeBool(w.bestTemplate.NeedCustomChains); err != nil { + return err + } + + // template_dependency_structure + if w.descriptor.AttachedStructure != nil { + if err := w.writeTemplateDependencyStructure(); err != nil { + return err + } + } + + // active_decode_targets_bitmask + if activeDecodeTargetsPresentFlag { + if err := w.writeBits(uint64(*w.descriptor.ActiveDecodeTargetsBitmask), w.structure.NumDecodeTargets); err != nil { + return err + } + } + + return nil +} + +func (w *DependencyDescriptorWriter) writeTemplateDependencyStructure() error { + if !(w.structure.StructureId >= 0 && w.structure.StructureId < MaxTemplates && + w.structure.NumDecodeTargets > 0 && w.structure.NumDecodeTargets <= MaxDecodeTargets) { + return fmt.Errorf("invalid arguments, structureId: %d, numDecodeTargets: %d", w.structure.StructureId, w.structure.NumDecodeTargets) + } + + if err := w.writeBits(uint64(w.structure.StructureId), 6); err != nil { + return err + } + + if err := w.writeBits(uint64(w.structure.NumDecodeTargets-1), 5); err != nil { + return err + } + + if err := w.writeTemplateLayers(); err != nil { + return err + } + + if err := w.writeTemplateDtis(); err != nil { + return err + } + + if err := w.writeTemplateFdiffs(); err != nil { + return err + } + + if err := w.writeTemplateChains(); err != nil { + return err + } + + hasResolutions := len(w.structure.Resolutions) > 0 + if err := w.writeBool(hasResolutions); err != nil { + return err + } + return w.writeResolutions() +} + +func (w *DependencyDescriptorWriter) writeTemplateLayers() error { + if !(len(w.structure.Templates) > 0 && len(w.structure.Templates) <= MaxTemplates && + w.structure.Templates[0].SpatialId == 0 && w.structure.Templates[0].TemporalId == 0) { + return fmt.Errorf("invalid templates, len %d, templates[0]: spatialId %d, temporalId %d", len(w.structure.Templates), w.structure.Templates[0].SpatialId, w.structure.Templates[0].TemporalId) + } + for i := 1; i < len(w.structure.Templates); i++ { + nextLayerIdc := getNextLayerIdc(w.structure.Templates[i-1], w.structure.Templates[i]) + if nextLayerIdc >= 3 { + return fmt.Errorf("invalid next_layer_idc %d", nextLayerIdc) + } + if err := w.writeBits(uint64(nextLayerIdc), 2); err != nil { + return err + } + } + return w.writeBits(uint64(noMoreLayer), 2) +} + +func getNextLayerIdc(prevTemplate, nextTemplate *FrameDependencyTemplate) nextLayerIdcType { + if nextTemplate.SpatialId == prevTemplate.SpatialId && nextTemplate.TemporalId == prevTemplate.TemporalId { + return sameLayer + } else if nextTemplate.SpatialId == prevTemplate.SpatialId && nextTemplate.TemporalId == prevTemplate.TemporalId+1 { + return nextTemporalLayer + } else if nextTemplate.SpatialId == prevTemplate.SpatialId+1 && nextTemplate.TemporalId == 0 { + return nextSpatialLayer + } + + return invalidLayer +} + +func (w *DependencyDescriptorWriter) writeTemplateDtis() error { + for _, t := range w.structure.Templates { + for _, dti := range t.DecodeTargetIndications { + if err := w.writeBits(uint64(dti), 2); err != nil { + return err + } + } + } + + return nil +} + +func (w *DependencyDescriptorWriter) writeTemplateFdiffs() error { + for _, t := range w.structure.Templates { + for _, fdiff := range t.FrameDiffs { + if err := w.writeBits(uint64(1<<4)|uint64(fdiff-1), 1+4); err != nil { + return err + } + } + // no more fdiffs for this template + if err := w.writeBits(uint64(0), 1); err != nil { + return err + } + } + + return nil +} + +func (w *DependencyDescriptorWriter) writeTemplateChains() error { + if err := w.writeNonSymmetric(uint32(w.structure.NumChains), uint32(w.structure.NumDecodeTargets+1)); err != nil { + return err + } + + if w.structure.NumChains == 0 { + return nil + } + + for _, protectedBy := range w.structure.DecodeTargetProtectedByChain { + if err := w.writeNonSymmetric(uint32(protectedBy), uint32(w.structure.NumChains)); err != nil { + return err + } + } + + for _, t := range w.structure.Templates { + for _, chainDiff := range t.ChainDiffs { + if err := w.writeBits(uint64(chainDiff), 4); err != nil { + return err + } + } + } + return nil +} + +func (w *DependencyDescriptorWriter) writeNonSymmetric(value, numValues uint32) error { + return w.writer.WriteNonSymmetric(value, numValues) +} + +func (w *DependencyDescriptorWriter) writeResolutions() error { + for _, res := range w.structure.Resolutions { + if err := w.writeBits(uint64(res.Width)-1, 16); err != nil { + return err + } + if err := w.writeBits(uint64(res.Height)-1, 16); err != nil { + return err + } + } + return nil +} + +func (w *DependencyDescriptorWriter) writeFrameDependencyDefinition() error { + if w.bestTemplate.NeedCustomDtis { + if err := w.writeFrameDtis(); err != nil { + return err + } + } + + if w.bestTemplate.NeedCustomFdiffs { + if err := w.writeFrameFdiffs(); err != nil { + return err + } + } + + if w.bestTemplate.NeedCustomChains { + if err := w.writeFrameChains(); err != nil { + return err + } + } + + return nil +} + +func (w *DependencyDescriptorWriter) writeFrameDtis() error { + for _, dti := range w.descriptor.FrameDependencies.DecodeTargetIndications { + if err := w.writeBits(uint64(dti), 2); err != nil { + return err + } + } + return nil +} + +func (w *DependencyDescriptorWriter) writeFrameFdiffs() error { + for _, fdiff := range w.descriptor.FrameDependencies.FrameDiffs { + if fdiff <= (1 << 4) { + if err := w.writeBits(uint64(1<<4)|uint64(fdiff-1), 2+4); err != nil { + return err + } + } else if fdiff <= (1 << 8) { + if err := w.writeBits(uint64(2<<8)|uint64(fdiff-1), 2+8); err != nil { + return err + } + } else { // fdiff <= (1<<12) + if err := w.writeBits(uint64(3<<12)|uint64(fdiff-1), 2+12); err != nil { + return err + } + } + } + // no more fdiffs + return w.writeBits(uint64(0), 2) +} + +func (w *DependencyDescriptorWriter) writeFrameChains() error { + for i := 0; i < w.structure.NumChains; i++ { + chainDiff := 0 + if w.activeChains&(1< 0 { + for _, protectedBy := range w.structure.DecodeTargetProtectedByChain { + bits += SizeNonSymmetricBits(uint32(protectedBy), uint32(w.structure.NumChains)) + } + bits += 4 * len(w.structure.Templates) * w.structure.NumChains + } + + // resolutions + bits += 1 + 32*len(w.structure.Resolutions) + + return bits +} diff --git a/livekit/pkg/sfu/rtpextension/playoutdelay/playoutdelay.go b/livekit/pkg/sfu/rtpextension/playoutdelay/playoutdelay.go new file mode 100644 index 0000000..d101784 --- /dev/null +++ b/livekit/pkg/sfu/rtpextension/playoutdelay/playoutdelay.go @@ -0,0 +1,72 @@ +// Copyright 2024 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package playoutdelay + +import ( + "encoding/binary" + "errors" +) + +const ( + PlayoutDelayURI = "http://www.webrtc.org/experiments/rtp-hdrext/playout-delay" + MaxPlayoutDelayDefault = 10000 // 10s, equal to chrome's default max playout delay + PlayoutDelayMaxValue = 10 * (1<<12 - 1) // max value for playout delay can be represented + + playoutDelayExtensionSize = 3 +) + +var ( + errPlayoutDelayOverflow = errors.New("playout delay overflow") + errTooSmall = errors.New("buffer too small") +) + +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | ID | len=2 | MIN delay | MAX delay | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// The wired MIN/MAX delay is in 10ms unit + +type PlayOutDelay struct { + Min, Max uint16 // delay in ms +} + +func PlayoutDelayFromValue(min, max uint16) PlayOutDelay { + if min > PlayoutDelayMaxValue { + min = PlayoutDelayMaxValue + } + if max > PlayoutDelayMaxValue { + max = PlayoutDelayMaxValue + } + return PlayOutDelay{Min: min, Max: max} +} + +func (p PlayOutDelay) Marshal() ([]byte, error) { + min, max := p.Min/10, p.Max/10 + if min >= 1<<12 || max >= 1<<12 { + return nil, errPlayoutDelayOverflow + } + + return []byte{byte(min >> 4), byte(min<<4) | byte(max>>8), byte(max)}, nil +} + +func (p *PlayOutDelay) Unmarshal(rawData []byte) error { + if len(rawData) < playoutDelayExtensionSize { + return errTooSmall + } + + p.Min = (binary.BigEndian.Uint16(rawData) >> 4) * 10 + p.Max = (binary.BigEndian.Uint16(rawData[1:]) & 0x0FFF) * 10 + return nil +} diff --git a/livekit/pkg/sfu/rtpextension/playoutdelay/playoutdelay_test.go b/livekit/pkg/sfu/rtpextension/playoutdelay/playoutdelay_test.go new file mode 100644 index 0000000..5a2dc07 --- /dev/null +++ b/livekit/pkg/sfu/rtpextension/playoutdelay/playoutdelay_test.go @@ -0,0 +1,57 @@ +// Copyright 2024 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package playoutdelay + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestPlayoutDelay(t *testing.T) { + p1 := PlayOutDelay{Min: 100, Max: 200} + b, err := p1.Marshal() + require.NoError(t, err) + require.Len(t, b, playoutDelayExtensionSize) + var p2 PlayOutDelay + err = p2.Unmarshal(b) + require.NoError(t, err) + require.Equal(t, p1, p2) + + // overflow + p3 := PlayOutDelay{Min: 100, Max: (1 << 12) * 10} + _, err = p3.Marshal() + require.ErrorIs(t, err, errPlayoutDelayOverflow) + + // too small + p4 := PlayOutDelay{} + err = p4.Unmarshal([]byte{0x00, 0x00}) + require.ErrorIs(t, err, errTooSmall) + + // from value + p5 := PlayoutDelayFromValue(1<<12*10, 1<<12*10+10) + _, err = p5.Marshal() + require.NoError(t, err) + require.Equal(t, uint16((1<<12)-1)*10, p5.Min) + require.Equal(t, uint16((1<<12)-1)*10, p5.Max) + + p6 := PlayOutDelay{Min: 100, Max: PlayoutDelayMaxValue} + bytes, err := p6.Marshal() + require.NoError(t, err) + p6Unmarshal := PlayOutDelay{} + err = p6Unmarshal.Unmarshal(bytes) + require.NoError(t, err) + require.Equal(t, p6, p6Unmarshal) +} diff --git a/livekit/pkg/sfu/rtpmunger.go b/livekit/pkg/sfu/rtpmunger.go new file mode 100644 index 0000000..882ae04 --- /dev/null +++ b/livekit/pkg/sfu/rtpmunger.go @@ -0,0 +1,349 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sfu + +import ( + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + + "github.com/livekit/livekit-server/pkg/sfu/buffer" + "github.com/livekit/livekit-server/pkg/sfu/utils" +) + +// RTPMunger +type SequenceNumberOrdering int + +const ( + SequenceNumberOrderingContiguous SequenceNumberOrdering = iota + SequenceNumberOrderingOutOfOrder + SequenceNumberOrderingGap + SequenceNumberOrderingDuplicate +) + +const ( + RtxGateWindow = 2000 +) + +type TranslationParamsRTP struct { + snOrdering SequenceNumberOrdering + extSequenceNumber uint64 + extTimestamp uint64 +} + +type SnTs struct { + extSequenceNumber uint64 + extTimestamp uint64 +} + +// ---------------------------------------------------------------------- + +type RTPMunger struct { + logger logger.Logger + + extHighestIncomingSN uint64 + snRangeMap *utils.RangeMap[uint64, uint64] + + extLastSN uint64 + extSecondLastSN uint64 + snOffset uint64 + + extLastTS uint64 + extSecondLastTS uint64 + tsOffset uint64 + + lastMarker bool + secondLastMarker bool + + extRtxGateSn uint64 + isInRtxGateRegion bool +} + +func NewRTPMunger(logger logger.Logger) *RTPMunger { + return &RTPMunger{ + logger: logger, + snRangeMap: utils.NewRangeMap[uint64, uint64](100), + } +} + +func (r *RTPMunger) DebugInfo() map[string]any { + return map[string]any{ + "ExtHighestIncomingSN": r.extHighestIncomingSN, + "ExtLastSN": r.extLastSN, + "ExtSecondLastSN": r.extSecondLastSN, + "SNOffset": r.snOffset, + "ExtLastTS": r.extLastTS, + "ExtSecondLastTS": r.extSecondLastTS, + "TSOffset": r.tsOffset, + "LastMarker": r.lastMarker, + "SecondLastMarker": r.secondLastMarker, + } +} + +func (r *RTPMunger) GetState() *livekit.RTPMungerState { + return &livekit.RTPMungerState{ + ExtLastSequenceNumber: r.extLastSN, + ExtSecondLastSequenceNumber: r.extSecondLastSN, + ExtLastTimestamp: r.extLastTS, + ExtSecondLastTimestamp: r.extSecondLastTS, + LastMarker: r.lastMarker, + SecondLastMarker: r.secondLastMarker, + } +} + +func (r *RTPMunger) GetTSOffset() uint64 { + return r.tsOffset +} + +func (r *RTPMunger) SeedState(state *livekit.RTPMungerState) { + r.extLastSN = state.ExtLastSequenceNumber + r.extSecondLastSN = state.ExtSecondLastSequenceNumber + r.extLastTS = state.ExtLastTimestamp + r.extSecondLastTS = state.ExtSecondLastTimestamp + r.lastMarker = state.LastMarker + r.secondLastMarker = state.SecondLastMarker +} + +func (r *RTPMunger) SetLastSnTs(extPkt *buffer.ExtPacket) { + r.extHighestIncomingSN = extPkt.ExtSequenceNumber - 1 + + r.extLastSN = extPkt.ExtSequenceNumber + r.extSecondLastSN = r.extLastSN - 1 + r.snRangeMap.ClearAndResetValue(extPkt.ExtSequenceNumber, 0) + r.updateSnOffset() + + r.extLastTS = extPkt.ExtTimestamp + r.extSecondLastTS = extPkt.ExtTimestamp + r.tsOffset = 0 +} + +func (r *RTPMunger) UpdateSnTsOffsets(extPkt *buffer.ExtPacket, snAdjust uint64, tsAdjust uint64) { + r.extHighestIncomingSN = extPkt.ExtSequenceNumber - 1 + + r.snRangeMap.ClearAndResetValue(extPkt.ExtSequenceNumber, extPkt.ExtSequenceNumber-r.extLastSN-snAdjust) + r.updateSnOffset() + + r.tsOffset = extPkt.ExtTimestamp - r.extLastTS - tsAdjust +} + +func (r *RTPMunger) PacketDropped(extPkt *buffer.ExtPacket) { + if r.extHighestIncomingSN != extPkt.ExtSequenceNumber { + return + } + + snOffset, err := r.snRangeMap.GetValue(extPkt.ExtSequenceNumber) + if err == nil { + outSN := extPkt.ExtSequenceNumber - snOffset + if outSN != r.extLastSN { + r.logger.Warnw("last outgoing sequence number mismatch", nil, "expected", r.extLastSN, "got", outSN) + } + } + if r.extLastSN == r.extSecondLastSN { + r.logger.Warnw("cannot roll back on drop", nil, "extLastSN", r.extLastSN, "secondLastSN", r.extSecondLastSN) + } + + if err := r.snRangeMap.ExcludeRange(r.extHighestIncomingSN, r.extHighestIncomingSN+1); err != nil { + r.logger.Errorw("could not exclude range", err, "sn", r.extHighestIncomingSN) + } + + r.extLastSN = r.extSecondLastSN + r.updateSnOffset() + + r.extLastTS = r.extSecondLastTS + r.lastMarker = r.secondLastMarker +} + +func (r *RTPMunger) UpdateAndGetSnTs(extPkt *buffer.ExtPacket, marker bool) (TranslationParamsRTP, error) { + diff := int64(extPkt.ExtSequenceNumber - r.extHighestIncomingSN) + if (diff == 1 && len(extPkt.Packet.Payload) != 0) || diff > 1 { + // in-order - either contiguous packet with payload OR packet following a gap, may or may not have payload + r.extHighestIncomingSN = extPkt.ExtSequenceNumber + + ordering := SequenceNumberOrderingContiguous + if diff > 1 { + ordering = SequenceNumberOrderingGap + } + + extMungedSN := extPkt.ExtSequenceNumber - r.snOffset + extMungedTS := extPkt.ExtTimestamp - r.tsOffset + + r.extSecondLastSN = r.extLastSN + r.extLastSN = extMungedSN + r.extSecondLastTS = r.extLastTS + r.extLastTS = extMungedTS + r.secondLastMarker = r.lastMarker + r.lastMarker = marker + + if extPkt.IsKeyFrame { + r.extRtxGateSn = extMungedSN + r.isInRtxGateRegion = true + } + + if r.isInRtxGateRegion && (extMungedSN-r.extRtxGateSn) > RtxGateWindow { + r.isInRtxGateRegion = false + } + + return TranslationParamsRTP{ + snOrdering: ordering, + extSequenceNumber: extMungedSN, + extTimestamp: extMungedTS, + }, nil + } + + if diff < 0 { + // out-of-order, look up sequence number offset cache + snOffset, err := r.snRangeMap.GetValue(extPkt.ExtSequenceNumber) + if err != nil { + return TranslationParamsRTP{ + snOrdering: SequenceNumberOrderingOutOfOrder, + }, errOutOfOrderSequenceNumberCacheMiss + } + + extSequenceNumber := extPkt.ExtSequenceNumber - snOffset + if extSequenceNumber >= r.extLastSN { + // should not happen, just being paranoid + r.logger.Errorw( + "unexpected packet ordering", nil, + "extIncomingSN", extPkt.ExtSequenceNumber, + "extHighestIncomingSN", r.extHighestIncomingSN, + "extLastSN", r.extLastSN, + "snOffsetIncoming", snOffset, + "snOffsetHighest", r.snOffset, + ) + return TranslationParamsRTP{ + snOrdering: SequenceNumberOrderingOutOfOrder, + }, errOutOfOrderSequenceNumberCacheMiss + } + + return TranslationParamsRTP{ + snOrdering: SequenceNumberOrderingOutOfOrder, + extSequenceNumber: extSequenceNumber, + extTimestamp: extPkt.ExtTimestamp - r.tsOffset, + }, nil + } + + // if padding only packet, can be dropped and sequence number adjusted, if contiguous + if diff == 1 { + r.extHighestIncomingSN = extPkt.ExtSequenceNumber + + if err := r.snRangeMap.ExcludeRange(r.extHighestIncomingSN, r.extHighestIncomingSN+1); err != nil { + r.logger.Errorw("could not exclude range", err, "sn", r.extHighestIncomingSN) + } + + r.updateSnOffset() + + return TranslationParamsRTP{ + snOrdering: SequenceNumberOrderingContiguous, + }, errPaddingOnlyPacket + } + + // can get duplicate packet due to FEC + return TranslationParamsRTP{ + snOrdering: SequenceNumberOrderingDuplicate, + }, errDuplicatePacket +} + +func (r *RTPMunger) FilterRTX(nacks []uint16) []uint16 { + if !r.isInRtxGateRegion { + return nacks + } + + filtered := make([]uint16, 0, len(nacks)) + for _, sn := range nacks { + if (sn - uint16(r.extRtxGateSn)) < (1 << 15) { + filtered = append(filtered, sn) + } + } + + return filtered +} + +func (r *RTPMunger) UpdateAndGetPaddingSnTs( + num int, + clockRate uint32, + frameRate uint32, + forceMarker bool, + extRtpTimestamp uint64, +) ([]SnTs, error) { + if num == 0 { + return nil, nil + } + + useLastTSForFirst := false + tsOffset := 0 + if !r.lastMarker { + if !forceMarker { + return nil, errPaddingNotOnFrameBoundary + } + + // if forcing frame end, use timestamp of latest received frame for the first one + useLastTSForFirst = true + tsOffset = 1 + } + + extLastSN := r.extLastSN + extLastTS := r.extLastTS + vals := make([]SnTs, num) + for i := range num { + extLastSN++ + vals[i].extSequenceNumber = extLastSN + + if frameRate != 0 { + if useLastTSForFirst && i == 0 { + vals[i].extTimestamp = r.extLastTS + } else { + ets := extRtpTimestamp + uint64(((uint32(i+1-tsOffset)*clockRate)+frameRate-1)/frameRate) + if int64(ets-extLastTS) <= 0 { + ets = extLastTS + 1 + } + extLastTS = ets + vals[i].extTimestamp = ets + } + } else { + vals[i].extTimestamp = r.extLastTS + } + } + + r.extSecondLastSN = extLastSN - 1 + r.extLastSN = extLastSN + r.snRangeMap.DecValue(r.extHighestIncomingSN, uint64(num)) + r.updateSnOffset() + + if len(vals) == 1 { + r.extSecondLastTS = r.extLastTS + } else { + r.extSecondLastTS = vals[len(vals)-2].extTimestamp + } + r.tsOffset -= extLastTS - r.extLastTS + r.extLastTS = extLastTS + + r.secondLastMarker = r.lastMarker + if forceMarker { + r.lastMarker = true + } + + return vals, nil +} + +func (r *RTPMunger) IsOnFrameBoundary() bool { + return r.lastMarker +} + +func (r *RTPMunger) updateSnOffset() { + snOffset, err := r.snRangeMap.GetValue(r.extHighestIncomingSN + 1) + if err != nil { + r.logger.Errorw("could not get sequence number offset", err) + } + r.snOffset = snOffset +} diff --git a/livekit/pkg/sfu/rtpmunger_test.go b/livekit/pkg/sfu/rtpmunger_test.go new file mode 100644 index 0000000..e6d62c3 --- /dev/null +++ b/livekit/pkg/sfu/rtpmunger_test.go @@ -0,0 +1,585 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sfu + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/livekit/protocol/logger" + + "github.com/livekit/livekit-server/pkg/sfu/testutils" +) + +func newRTPMunger() *RTPMunger { + return NewRTPMunger(logger.GetLogger()) +} + +func TestSetLastSnTs(t *testing.T) { + r := newRTPMunger() + + params := &testutils.TestExtPacketParams{ + SequenceNumber: 23333, + Timestamp: 0xabcdef, + SSRC: 0x12345678, + } + extPkt, err := testutils.GetTestExtPacket(params) + require.NoError(t, err) + require.NotNil(t, extPkt) + + r.SetLastSnTs(extPkt) + require.Equal(t, uint64(23332), r.extHighestIncomingSN) + require.Equal(t, uint64(23333), r.extLastSN) + require.Equal(t, uint64(0xabcdef), r.extLastTS) + snOffset, err := r.snRangeMap.GetValue(r.extHighestIncomingSN) + require.Error(t, err) + snOffset, err = r.snRangeMap.GetValue(r.extLastSN) + require.NoError(t, err) + require.Equal(t, uint64(0), snOffset) + require.Equal(t, uint64(0), r.snOffset) + require.Equal(t, uint64(0), r.tsOffset) +} + +func TestUpdateSnTsOffsets(t *testing.T) { + r := newRTPMunger() + + params := &testutils.TestExtPacketParams{ + SequenceNumber: 23333, + Timestamp: 0xabcdef, + SSRC: 0x12345678, + } + extPkt, _ := testutils.GetTestExtPacket(params) + r.SetLastSnTs(extPkt) + + params = &testutils.TestExtPacketParams{ + SequenceNumber: 33333, + Timestamp: 0xabcdef, + SSRC: 0x87654321, + } + extPkt, _ = testutils.GetTestExtPacket(params) + r.UpdateSnTsOffsets(extPkt, 1, 1) + require.Equal(t, uint64(33332), r.extHighestIncomingSN) + require.Equal(t, uint64(23333), r.extLastSN) + require.Equal(t, uint64(0xabcdef), r.extLastTS) + _, err := r.snRangeMap.GetValue(r.extHighestIncomingSN) + require.Error(t, err) + _, err = r.snRangeMap.GetValue(r.extLastSN) + require.Error(t, err) + require.Equal(t, uint64(9999), r.snOffset) + require.Equal(t, uint64(0xffff_ffff_ffff_ffff), r.tsOffset) +} + +func TestPacketDropped(t *testing.T) { + r := newRTPMunger() + + params := &testutils.TestExtPacketParams{ + SequenceNumber: 23333, + Timestamp: 0xabcdef, + SSRC: 0x12345678, + PayloadSize: 10, + } + extPkt, _ := testutils.GetTestExtPacket(params) + r.SetLastSnTs(extPkt) + require.Equal(t, uint64(23332), r.extHighestIncomingSN) + require.Equal(t, uint64(23333), r.extLastSN) + require.Equal(t, uint64(0xabcdef), r.extLastTS) + snOffset, err := r.snRangeMap.GetValue(r.extHighestIncomingSN) + require.Error(t, err) + snOffset, err = r.snRangeMap.GetValue(r.extLastSN) + require.NoError(t, err) + require.Equal(t, uint64(0), snOffset) + require.Equal(t, uint64(0), r.tsOffset) + + r.UpdateAndGetSnTs(extPkt, extPkt.Packet.Marker) // update sequence number offset + + // drop a non-head packet, should cause no change in internals + params = &testutils.TestExtPacketParams{ + SequenceNumber: 33333, + Timestamp: 0xabcdef, + SSRC: 0x12345678, + } + extPkt, _ = testutils.GetTestExtPacket(params) + r.PacketDropped(extPkt) + require.Equal(t, uint64(23333), r.extHighestIncomingSN) + require.Equal(t, uint64(23333), r.extLastSN) + snOffset, err = r.snRangeMap.GetValue(r.extHighestIncomingSN) + require.NoError(t, err) + require.Equal(t, uint64(0), snOffset) + + // drop a head packet and check offset increases + params = &testutils.TestExtPacketParams{ + SequenceNumber: 44444, + Timestamp: 0xabcdef, + SSRC: 0x12345678, + PayloadSize: 20, + } + extPkt, _ = testutils.GetTestExtPacket(params) + + r.UpdateAndGetSnTs(extPkt, extPkt.Packet.Marker) // update sequence number offset + require.Equal(t, uint64(44444), r.extLastSN) + + r.PacketDropped(extPkt) + require.Equal(t, uint64(23333), r.extLastSN) + snOffset, err = r.snRangeMap.GetValue(r.extHighestIncomingSN) + require.Error(t, err) + snOffset, err = r.snRangeMap.GetValue(r.extHighestIncomingSN + 1) + require.NoError(t, err) + require.Equal(t, uint64(1), snOffset) + + params = &testutils.TestExtPacketParams{ + SequenceNumber: 44445, + Timestamp: 0xabcdef, + SSRC: 0x12345678, + PayloadSize: 20, + } + extPkt, _ = testutils.GetTestExtPacket(params) + + r.UpdateAndGetSnTs(extPkt, extPkt.Packet.Marker) // update sequence number offset + require.Equal(t, r.extLastSN, uint64(44444)) + snOffset, err = r.snRangeMap.GetValue(r.extHighestIncomingSN) + require.NoError(t, err) + require.Equal(t, uint64(1), snOffset) +} + +func TestOutOfOrderSequenceNumber(t *testing.T) { + r := newRTPMunger() + + params := &testutils.TestExtPacketParams{ + SequenceNumber: 23333, + Timestamp: 0xabcdef, + SSRC: 0x12345678, + PayloadSize: 10, + } + extPkt, _ := testutils.GetTestExtPacket(params) + r.SetLastSnTs(extPkt) + r.UpdateAndGetSnTs(extPkt, extPkt.Packet.Marker) + + // should not be able to add a missing sequence number to the cache that is before start + err := r.snRangeMap.ExcludeRange(23332, 23333) + require.Error(t, err) + + // out-of-order sequence number before start should miss + params = &testutils.TestExtPacketParams{ + SequenceNumber: 23331, + Timestamp: 0xabcdef, + SSRC: 0x12345678, + PayloadSize: 10, + } + extPkt, _ = testutils.GetTestExtPacket(params) + + tp, err := r.UpdateAndGetSnTs(extPkt, extPkt.Packet.Marker) + require.Error(t, err) + + // add a missing sequence number to the cache + err = r.snRangeMap.ExcludeRange(23334, 23335) + require.NoError(t, err) + + params = &testutils.TestExtPacketParams{ + SequenceNumber: 23336, + Timestamp: 0xabcdef, + SSRC: 0x12345678, + PayloadSize: 10, + } + extPkt, _ = testutils.GetTestExtPacket(params) + r.UpdateAndGetSnTs(extPkt, extPkt.Packet.Marker) + + // out-of-order sequence number should be munged from cache + params = &testutils.TestExtPacketParams{ + SequenceNumber: 23335, + Timestamp: 0xabcdef, + SSRC: 0x12345678, + PayloadSize: 10, + } + extPkt, _ = testutils.GetTestExtPacket(params) + + tpExpected := TranslationParamsRTP{ + snOrdering: SequenceNumberOrderingOutOfOrder, + extSequenceNumber: 23334, + extTimestamp: 0xabcdef, + } + + tp, err = r.UpdateAndGetSnTs(extPkt, extPkt.Packet.Marker) + require.NoError(t, err) + require.Equal(t, tpExpected, tp) + + params = &testutils.TestExtPacketParams{ + SequenceNumber: 23332, + Timestamp: 0xabcdef, + SSRC: 0x12345678, + PayloadSize: 10, + } + extPkt, _ = testutils.GetTestExtPacket(params) + + tpExpected = TranslationParamsRTP{ + snOrdering: SequenceNumberOrderingOutOfOrder, + } + + tp, err = r.UpdateAndGetSnTs(extPkt, extPkt.Packet.Marker) + require.Error(t, err, errOutOfOrderSequenceNumberCacheMiss) + require.Equal(t, tpExpected, tp) +} + +func TestDuplicateSequenceNumber(t *testing.T) { + r := newRTPMunger() + + params := &testutils.TestExtPacketParams{ + SequenceNumber: 23333, + Timestamp: 0xabcdef, + SSRC: 0x12345678, + } + extPkt, _ := testutils.GetTestExtPacket(params) + r.SetLastSnTs(extPkt) + + // send first packet through + r.UpdateAndGetSnTs(extPkt, extPkt.Packet.Marker) + + // send it again - duplicate packet + tpExpected := TranslationParamsRTP{ + snOrdering: SequenceNumberOrderingDuplicate, + } + + tp, err := r.UpdateAndGetSnTs(extPkt, extPkt.Packet.Marker) + require.ErrorIs(t, err, errDuplicatePacket) + require.Equal(t, tpExpected, tp) +} + +func TestPaddingOnlyPacket(t *testing.T) { + r := newRTPMunger() + + params := &testutils.TestExtPacketParams{ + SequenceNumber: 23333, + Timestamp: 0xabcdef, + SSRC: 0x12345678, + } + extPkt, _ := testutils.GetTestExtPacket(params) + r.SetLastSnTs(extPkt) + + // contiguous padding only packet should report an error + tpExpected := TranslationParamsRTP{ + snOrdering: SequenceNumberOrderingContiguous, + } + + tp, err := r.UpdateAndGetSnTs(extPkt, extPkt.Packet.Marker) + require.Error(t, err) + require.ErrorIs(t, err, errPaddingOnlyPacket) + require.Equal(t, tpExpected, tp) + require.Equal(t, uint64(23333), r.extHighestIncomingSN) + require.Equal(t, uint64(23333), r.extLastSN) + snOffset, err := r.snRangeMap.GetValue(r.extHighestIncomingSN) + require.Error(t, err) + + // padding only packet with a gap should not report an error + params = &testutils.TestExtPacketParams{ + SequenceNumber: 23335, + Timestamp: 0xabcdef, + SSRC: 0x12345678, + } + extPkt, _ = testutils.GetTestExtPacket(params) + + tpExpected = TranslationParamsRTP{ + snOrdering: SequenceNumberOrderingGap, + extSequenceNumber: 23334, + extTimestamp: 0xabcdef, + } + + tp, err = r.UpdateAndGetSnTs(extPkt, extPkt.Packet.Marker) + require.NoError(t, err) + require.Equal(t, tpExpected, tp) + require.Equal(t, uint64(23335), r.extHighestIncomingSN) + require.Equal(t, uint64(23334), r.extLastSN) + snOffset, err = r.snRangeMap.GetValue(r.extHighestIncomingSN) + require.NoError(t, err) + require.Equal(t, uint64(1), snOffset) +} + +func TestGapInSequenceNumber(t *testing.T) { + r := newRTPMunger() + + params := &testutils.TestExtPacketParams{ + SequenceNumber: 65533, + Timestamp: 0xabcdef, + SSRC: 0x12345678, + PayloadSize: 33, + } + extPkt, _ := testutils.GetTestExtPacket(params) + r.SetLastSnTs(extPkt) + + _, err := r.UpdateAndGetSnTs(extPkt, extPkt.Packet.Marker) + require.NoError(t, err) + + // three lost packets + params = &testutils.TestExtPacketParams{ + SequenceNumber: 1, + SNCycles: 1, + Timestamp: 0xabcdef, + SSRC: 0x12345678, + PayloadSize: 33, + } + extPkt, _ = testutils.GetTestExtPacket(params) + + tpExpected := TranslationParamsRTP{ + snOrdering: SequenceNumberOrderingGap, + extSequenceNumber: 65536 + 1, + extTimestamp: 0xabcdef, + } + + tp, err := r.UpdateAndGetSnTs(extPkt, extPkt.Packet.Marker) + require.NoError(t, err) + require.Equal(t, tpExpected, tp) + require.Equal(t, uint64(65536+1), r.extHighestIncomingSN) + require.Equal(t, uint64(65536+1), r.extLastSN) + snOffset, err := r.snRangeMap.GetValue(r.extHighestIncomingSN) + require.NoError(t, err) + require.Equal(t, uint64(0), snOffset) + + // ensure missing sequence numbers have correct cached offset + for i := uint64(65534); i != 65536+1; i++ { + offset, err := r.snRangeMap.GetValue(i) + require.NoError(t, err) + require.Equal(t, uint64(0), offset) + } + + // a padding only packet should be dropped + params = &testutils.TestExtPacketParams{ + SequenceNumber: 2, + SNCycles: 1, + Timestamp: 0xabcdef, + SSRC: 0x12345678, + } + extPkt, _ = testutils.GetTestExtPacket(params) + + tpExpected = TranslationParamsRTP{ + snOrdering: SequenceNumberOrderingContiguous, + } + + tp, err = r.UpdateAndGetSnTs(extPkt, extPkt.Packet.Marker) + require.ErrorIs(t, err, errPaddingOnlyPacket) + require.Equal(t, tpExpected, tp) + require.Equal(t, uint64(65536+2), r.extHighestIncomingSN) + require.Equal(t, uint64(65536+1), r.extLastSN) + snOffset, err = r.snRangeMap.GetValue(r.extHighestIncomingSN) + require.Error(t, err) + + // a packet with a gap should be adjusting for dropped padding packet + params = &testutils.TestExtPacketParams{ + SequenceNumber: 4, + SNCycles: 1, + Timestamp: 0xabcdef, + SSRC: 0x12345678, + PayloadSize: 22, + } + extPkt, _ = testutils.GetTestExtPacket(params) + + tpExpected = TranslationParamsRTP{ + snOrdering: SequenceNumberOrderingGap, + extSequenceNumber: 65536 + 3, + extTimestamp: 0xabcdef, + } + + tp, err = r.UpdateAndGetSnTs(extPkt, extPkt.Packet.Marker) + require.NoError(t, err) + require.Equal(t, tpExpected, tp) + require.Equal(t, uint64(65536+4), r.extHighestIncomingSN) + require.Equal(t, uint64(65536+3), r.extLastSN) + snOffset, err = r.snRangeMap.GetValue(r.extHighestIncomingSN) + require.NoError(t, err) + require.Equal(t, uint64(1), snOffset) + + // ensure missing sequence number has correct cached offset + offset, err := r.snRangeMap.GetValue(65536 + 3) + require.NoError(t, err) + require.Equal(t, uint64(1), offset) + + // another contiguous padding only packet should be dropped + params = &testutils.TestExtPacketParams{ + SequenceNumber: 5, + SNCycles: 1, + Timestamp: 0xabcdef, + SSRC: 0x12345678, + } + extPkt, _ = testutils.GetTestExtPacket(params) + + tpExpected = TranslationParamsRTP{ + snOrdering: SequenceNumberOrderingContiguous, + } + + tp, err = r.UpdateAndGetSnTs(extPkt, extPkt.Packet.Marker) + require.ErrorIs(t, err, errPaddingOnlyPacket) + require.Equal(t, tpExpected, tp) + require.Equal(t, uint64(65536+5), r.extHighestIncomingSN) + require.Equal(t, uint64(65536+3), r.extLastSN) + snOffset, err = r.snRangeMap.GetValue(r.extHighestIncomingSN) + require.Error(t, err) + + // a packet with a gap should be adjusting for dropped packets + params = &testutils.TestExtPacketParams{ + SequenceNumber: 7, + SNCycles: 1, + Timestamp: 0xabcdef, + SSRC: 0x12345678, + PayloadSize: 22, + } + extPkt, _ = testutils.GetTestExtPacket(params) + + tpExpected = TranslationParamsRTP{ + snOrdering: SequenceNumberOrderingGap, + extSequenceNumber: 65536 + 5, + extTimestamp: 0xabcdef, + } + + tp, err = r.UpdateAndGetSnTs(extPkt, extPkt.Packet.Marker) + require.NoError(t, err) + require.Equal(t, tpExpected, tp) + require.Equal(t, uint64(65536+7), r.extHighestIncomingSN) + require.Equal(t, uint64(65536+5), r.extLastSN) + snOffset, err = r.snRangeMap.GetValue(r.extHighestIncomingSN) + require.NoError(t, err) + require.Equal(t, uint64(2), snOffset) + + // ensure missing sequence number has correct cached offset + offset, err = r.snRangeMap.GetValue(65536 + 3) + require.NoError(t, err) + require.Equal(t, uint64(1), offset) + + offset, err = r.snRangeMap.GetValue(65536 + 6) + require.NoError(t, err) + require.Equal(t, uint64(2), offset) + + // check the missing packets + params = &testutils.TestExtPacketParams{ + SequenceNumber: 6, + SNCycles: 1, + Timestamp: 0xabcdef, + SSRC: 0x12345678, + } + extPkt, _ = testutils.GetTestExtPacket(params) + + tpExpected = TranslationParamsRTP{ + snOrdering: SequenceNumberOrderingOutOfOrder, + extSequenceNumber: 65536 + 4, + extTimestamp: 0xabcdef, + } + + tp, err = r.UpdateAndGetSnTs(extPkt, extPkt.Packet.Marker) + require.NoError(t, err) + require.Equal(t, tpExpected, tp) + require.Equal(t, uint64(65536+7), r.extHighestIncomingSN) + require.Equal(t, uint64(65536+5), r.extLastSN) + snOffset, err = r.snRangeMap.GetValue(r.extHighestIncomingSN) + require.NoError(t, err) + require.Equal(t, uint64(2), snOffset) + + params = &testutils.TestExtPacketParams{ + SequenceNumber: 3, + SNCycles: 1, + Timestamp: 0xabcdef, + SSRC: 0x12345678, + } + extPkt, _ = testutils.GetTestExtPacket(params) + + tpExpected = TranslationParamsRTP{ + snOrdering: SequenceNumberOrderingOutOfOrder, + extSequenceNumber: 65536 + 2, + extTimestamp: 0xabcdef, + } + + tp, err = r.UpdateAndGetSnTs(extPkt, extPkt.Packet.Marker) + require.NoError(t, err) + require.Equal(t, tpExpected, tp) + require.Equal(t, uint64(65536+7), r.extHighestIncomingSN) + require.Equal(t, uint64(65536+5), r.extLastSN) + snOffset, err = r.snRangeMap.GetValue(r.extHighestIncomingSN) + require.NoError(t, err) + require.Equal(t, uint64(2), snOffset) +} + +func TestUpdateAndGetPaddingSnTs(t *testing.T) { + r := newRTPMunger() + + params := &testutils.TestExtPacketParams{ + SequenceNumber: 23333, + Timestamp: 0xabcdef, + SSRC: 0x12345678, + PayloadSize: 20, + } + extPkt, _ := testutils.GetTestExtPacket(params) + r.SetLastSnTs(extPkt) + + // getting padding without forcing marker should fail + _, err := r.UpdateAndGetPaddingSnTs(10, 10, 5, false, 0) + require.Error(t, err) + require.ErrorIs(t, err, errPaddingNotOnFrameBoundary) + + // forcing a marker should not error out. + // And timestamp on first padding should be the same as the last one. + numPadding := 10 + clockRate := uint64(10) + frameRate := uint64(5) + var sntsExpected = make([]SnTs, numPadding) + for i := range numPadding { + sntsExpected[i] = SnTs{ + extSequenceNumber: uint64(params.SequenceNumber) + uint64(i) + 1, + extTimestamp: uint64(params.Timestamp) + ((uint64(i)*clockRate)+frameRate-1)/frameRate, + } + } + snts, err := r.UpdateAndGetPaddingSnTs(numPadding, uint32(clockRate), uint32(frameRate), true, extPkt.ExtTimestamp) + require.NoError(t, err) + require.Equal(t, sntsExpected, snts) + + // now that there is a marker, timestamp should jump on first padding when asked again + for i := range numPadding { + sntsExpected[i] = SnTs{ + extSequenceNumber: uint64(params.SequenceNumber) + uint64(len(snts)) + uint64(i) + 1, + extTimestamp: snts[len(snts)-1].extTimestamp + ((uint64(i+1)*clockRate)+frameRate-1)/frameRate, + } + } + snts, err = r.UpdateAndGetPaddingSnTs(numPadding, uint32(clockRate), uint32(frameRate), false, snts[len(snts)-1].extTimestamp) + require.NoError(t, err) + require.Equal(t, sntsExpected, snts) +} + +func TestIsOnFrameBoundary(t *testing.T) { + r := newRTPMunger() + + params := &testutils.TestExtPacketParams{ + SequenceNumber: 23333, + Timestamp: 0xabcdef, + SSRC: 0x12345678, + PayloadSize: 20, + } + extPkt, _ := testutils.GetTestExtPacket(params) + r.SetLastSnTs(extPkt) + + // send it through + _, err := r.UpdateAndGetSnTs(extPkt, extPkt.Packet.Marker) + require.NoError(t, err) + require.False(t, r.IsOnFrameBoundary()) + + // packet with RTP marker + params = &testutils.TestExtPacketParams{ + Marker: true, + SequenceNumber: 23334, + Timestamp: 0xabcdef, + SSRC: 0x12345678, + PayloadSize: 20, + } + extPkt, _ = testutils.GetTestExtPacket(params) + + // send it through + _, err = r.UpdateAndGetSnTs(extPkt, extPkt.Packet.Marker) + require.NoError(t, err) + require.True(t, r.IsOnFrameBoundary()) +} diff --git a/livekit/pkg/sfu/rtpstats/rtpstats_base.go b/livekit/pkg/sfu/rtpstats/rtpstats_base.go new file mode 100644 index 0000000..fe8adf2 --- /dev/null +++ b/livekit/pkg/sfu/rtpstats/rtpstats_base.go @@ -0,0 +1,990 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rtpstats + +import ( + "errors" + "time" + + "go.uber.org/zap/zapcore" + "google.golang.org/protobuf/types/known/timestamppb" + + "github.com/livekit/mediatransportutil" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/utils" + "github.com/livekit/protocol/utils/mono" + "github.com/livekit/protocol/utils/rtputil" +) + +const ( + cFirstPacketTimeAdjustWindow = 2 * time.Minute + cFirstPacketTimeAdjustThreshold = 15 * 1e9 + + cSequenceNumberLargeJumpThreshold = 100 +) + +// ------------------------------------------------------- + +type RTPDeltaInfo struct { + StartTime time.Time + EndTime time.Time + Packets uint32 + Bytes uint64 + HeaderBytes uint64 + PacketsDuplicate uint32 + BytesDuplicate uint64 + HeaderBytesDuplicate uint64 + PacketsPadding uint32 + BytesPadding uint64 + HeaderBytesPadding uint64 + PacketsLost uint32 + PacketsMissing uint32 + PacketsOutOfOrder uint32 + Frames uint32 + RttMax uint32 + JitterMax float64 + Nacks uint32 + NackRepeated uint32 + Plis uint32 + Firs uint32 +} + +func (r *RTPDeltaInfo) MarshalLogObject(e zapcore.ObjectEncoder) error { + if r == nil { + return nil + } + + e.AddTime("StartTime", r.StartTime) + e.AddTime("EndTime", r.EndTime) + e.AddUint32("Packets", r.Packets) + e.AddUint64("Bytes", r.Bytes) + e.AddUint64("HeaderBytes", r.HeaderBytes) + e.AddUint32("PacketsDuplicate", r.PacketsDuplicate) + e.AddUint64("BytesDuplicate", r.BytesDuplicate) + e.AddUint64("HeaderBytesDuplicate", r.HeaderBytesDuplicate) + e.AddUint32("PacketsPadding", r.PacketsPadding) + e.AddUint64("BytesPadding", r.BytesPadding) + e.AddUint64("HeaderBytesPadding", r.HeaderBytesPadding) + e.AddUint32("PacketsLost", r.PacketsLost) + e.AddUint32("PacketsMissing", r.PacketsMissing) + e.AddUint32("PacketsOutOfOrder", r.PacketsOutOfOrder) + e.AddUint32("Frames", r.Frames) + e.AddUint32("RttMax", r.RttMax) + e.AddFloat64("JitterMax", r.JitterMax) + e.AddUint32("Nacks", r.Nacks) + e.AddUint32("NackRepeated", r.NackRepeated) + e.AddUint32("Plis", r.Plis) + e.AddUint32("Firs", r.Firs) + return nil +} + +// ------------------------------------------------------- + +type snapshot struct { + snapshotLite + + headerBytes uint64 + + packetsDuplicate uint64 + bytesDuplicate uint64 + headerBytesDuplicate uint64 + + packetsPadding uint64 + bytesPadding uint64 + headerBytesPadding uint64 + + frames uint32 + + plis uint32 + firs uint32 + + maxRtt uint32 + maxJitter float64 +} + +func (s *snapshot) MarshalLogObject(e zapcore.ObjectEncoder) error { + if s == nil { + return nil + } + + e.AddObject("snapshotLite", &s.snapshotLite) + e.AddUint64("headerBytes", s.headerBytes) + e.AddUint64("packetsDuplicate", s.packetsDuplicate) + e.AddUint64("bytesDuplicate", s.bytesDuplicate) + e.AddUint64("headerBytesDuplicate", s.headerBytesDuplicate) + e.AddUint64("packetsPadding", s.packetsPadding) + e.AddUint64("bytesPadding", s.bytesPadding) + e.AddUint64("headerBytesPadding", s.headerBytesPadding) + e.AddUint32("frames", s.frames) + e.AddUint32("plis", s.plis) + e.AddUint32("firs", s.firs) + e.AddUint32("maxRtt", s.maxRtt) + e.AddFloat64("maxJitter", s.maxJitter) + return nil +} + +func (s *snapshot) maybeUpdateMaxRTT(rtt uint32) { + if rtt > s.maxRtt { + s.maxRtt = rtt + } +} + +func (s *snapshot) maybeUpdateMaxJitter(jitter float64) { + if jitter > s.maxJitter { + s.maxJitter = jitter + } +} + +// ------------------------------------------------------------------ + +type wrappedRTPDriftLogger struct { + *livekit.RTPDrift +} + +func (w wrappedRTPDriftLogger) MarshalLogObject(e zapcore.ObjectEncoder) error { + rd := w.RTPDrift + if rd == nil { + return nil + } + + e.AddTime("StartTime", rd.StartTime.AsTime()) + e.AddTime("EndTime", rd.EndTime.AsTime()) + e.AddFloat64("Duration", rd.Duration) + e.AddUint64("StartTimestamp", rd.StartTimestamp) + e.AddUint64("EndTimestamp", rd.EndTimestamp) + e.AddUint64("RtpClockTicks", rd.RtpClockTicks) + e.AddInt64("DriftSamples", rd.DriftSamples) + e.AddFloat64("DriftMs", rd.DriftMs) + e.AddFloat64("ClockRate", rd.ClockRate) + return nil +} + +// ------------------------------------------------------------------ + +type WrappedRTCPSenderReportStateLogger struct { + *livekit.RTCPSenderReportState +} + +func (w WrappedRTCPSenderReportStateLogger) MarshalLogObject(e zapcore.ObjectEncoder) error { + rsrs := w.RTCPSenderReportState + if rsrs == nil { + return nil + } + + e.AddUint32("RtpTimestamp", rsrs.RtpTimestamp) + e.AddUint64("RtpTimestampExt", rsrs.RtpTimestampExt) + e.AddTime("NtpTimestamp", mediatransportutil.NtpTime(rsrs.NtpTimestamp).Time()) + e.AddTime("At", time.Unix(0, rsrs.At)) + e.AddTime("AtAdjusted", time.Unix(0, rsrs.AtAdjusted)) + e.AddUint32("Packets", rsrs.Packets) + e.AddUint64("Octets", rsrs.Octets) + return nil +} + +// ------------------------------------------------------------------ + +func RTCPSenderReportPropagationDelay(rsrs *livekit.RTCPSenderReportState, passThrough bool) time.Duration { + if passThrough { + return 0 + } + + return time.Unix(0, rsrs.AtAdjusted).Sub(mediatransportutil.NtpTime(rsrs.NtpTimestamp).Time()) +} + +// ------------------------------------------------------------------ + +type rtpStatsBase struct { + *rtpStatsBaseLite + + rtpConverter *rtputil.RTPConverter + + firstTime int64 + firstTimeAdjustment time.Duration + highestTime int64 + + lastTransit uint64 + lastJitterExtTimestamp uint64 + + headerBytes uint64 + + packetsDuplicate uint64 + bytesDuplicate uint64 + headerBytesDuplicate uint64 + + packetsPadding uint64 + bytesPadding uint64 + headerBytesPadding uint64 + + frames uint32 + + jitter float64 + maxJitter float64 + + firs uint32 + lastFir time.Time + + keyFrames uint32 + lastKeyFrame time.Time + + rtt uint32 + maxRtt uint32 + + srFirst *livekit.RTCPSenderReportState + srNewest *livekit.RTCPSenderReportState + + nextSnapshotID uint32 + snapshots []snapshot +} + +func newRTPStatsBase(params RTPStatsParams) *rtpStatsBase { + return &rtpStatsBase{ + rtpStatsBaseLite: newRTPStatsBaseLite(params), + rtpConverter: rtputil.NewRTPConverter(int64(params.ClockRate)), + nextSnapshotID: cFirstSnapshotID, + snapshots: make([]snapshot, 2), + } +} + +func (r *rtpStatsBase) seed(from *rtpStatsBase) bool { + if !r.rtpStatsBaseLite.seed(from.rtpStatsBaseLite) { + return false + } + + r.firstTime = from.firstTime + r.firstTimeAdjustment = from.firstTimeAdjustment + r.highestTime = from.highestTime + + r.lastTransit = from.lastTransit + r.lastJitterExtTimestamp = from.lastJitterExtTimestamp + + r.headerBytes = from.headerBytes + + r.packetsDuplicate = from.packetsDuplicate + r.bytesDuplicate = from.bytesDuplicate + r.headerBytesDuplicate = from.headerBytesDuplicate + + r.packetsPadding = from.packetsPadding + r.bytesPadding = from.bytesPadding + r.headerBytesPadding = from.headerBytesPadding + + r.frames = from.frames + + r.jitter = from.jitter + r.maxJitter = from.maxJitter + + r.firs = from.firs + r.lastFir = from.lastFir + + r.keyFrames = from.keyFrames + r.lastKeyFrame = from.lastKeyFrame + + r.rtt = from.rtt + r.maxRtt = from.maxRtt + + r.srFirst = utils.CloneProto(from.srFirst) + r.srNewest = utils.CloneProto(from.srNewest) + + r.nextSnapshotID = from.nextSnapshotID + r.snapshots = make([]snapshot, cap(from.snapshots)) + copy(r.snapshots, from.snapshots) + return true +} + +func (r *rtpStatsBase) newSnapshotID(extStartSN uint64) uint32 { + id := r.nextSnapshotID + r.nextSnapshotID++ + + if cap(r.snapshots) < int(r.nextSnapshotID-cFirstSnapshotID) { + snapshots := make([]snapshot, r.nextSnapshotID-cFirstSnapshotID) + copy(snapshots, r.snapshots) + r.snapshots = snapshots + } + + if r.initialized { + r.snapshots[id-cFirstSnapshotID] = initSnapshot(mono.UnixNano(), extStartSN) + } + return id +} + +func (r *rtpStatsBase) UpdateFir(firCount uint32) { + r.lock.Lock() + defer r.lock.Unlock() + + if r.endTime != 0 { + return + } + + r.firs += firCount +} + +func (r *rtpStatsBase) UpdateFirTime() { + r.lock.Lock() + defer r.lock.Unlock() + + if r.endTime != 0 { + return + } + + r.lastFir = time.Now() +} + +func (r *rtpStatsBase) UpdateKeyFrame(kfCount uint32) { + r.lock.Lock() + defer r.lock.Unlock() + + if r.endTime != 0 { + return + } + + r.keyFrames += kfCount + r.lastKeyFrame = time.Now() +} + +func (r *rtpStatsBase) KeyFrame() (uint32, time.Time) { + r.lock.RLock() + defer r.lock.RUnlock() + + return r.keyFrames, r.lastKeyFrame +} + +func (r *rtpStatsBase) UpdateRtt(rtt uint32) { + r.lock.Lock() + defer r.lock.Unlock() + + if r.endTime != 0 { + return + } + + r.rtt = rtt + if rtt > r.maxRtt { + r.maxRtt = rtt + } + + for i := uint32(0); i < r.nextSnapshotID-cFirstSnapshotID; i++ { + s := &r.snapshots[i] + if rtt > s.maxRtt { + s.maxRtt = rtt + } + } +} + +func (r *rtpStatsBase) GetRtt() uint32 { + r.lock.RLock() + defer r.lock.RUnlock() + + return r.rtt +} + +func (r *rtpStatsBase) maybeAdjustFirstPacketTime( + srData *livekit.RTCPSenderReportState, + tsOffset uint64, + extStartTS uint64, +) (adjustment int64, err error, loggingFields []any) { + nowNano := mono.UnixNano() + if time.Duration(nowNano-r.startTime) > cFirstPacketTimeAdjustWindow { + return + } + + // for some time after the start, adjust time of first packet. + // Helps improve accuracy of expected timestamp calculation. + // Adjusting only one way, i. e. if the first sample experienced + // abnormal delay (maybe due to pacing or maybe due to queuing + // in some network element along the way), push back first time + // to an earlier instance. + timeSinceReceive := time.Duration(nowNano - srData.AtAdjusted) + extNowTS := srData.RtpTimestampExt - tsOffset + r.rtpConverter.ToRTPExt(timeSinceReceive) + samplesDiff := int64(extNowTS - extStartTS) + if samplesDiff < 0 { + // out-of-order, skip + return + } + + samplesDuration := r.rtpConverter.ToDurationExt(uint64(samplesDiff)) + timeSinceFirst := time.Duration(nowNano - r.firstTime) + now := r.firstTime + timeSinceFirst.Nanoseconds() + firstTime := now - samplesDuration.Nanoseconds() + adjustment = r.firstTime - firstTime + + getFields := func() []any { + return []any{ + "startTime", time.Unix(0, r.startTime), + "nowTime", time.Unix(0, now), + "before", time.Unix(0, r.firstTime), + "after", time.Unix(0, firstTime), + "adjustment", time.Duration(adjustment), + "firstTimeAdjustment", r.firstTimeAdjustment, + "extNowTS", extNowTS, + "extStartTS", extStartTS, + "srData", WrappedRTCPSenderReportStateLogger{srData}, + "tsOffset", tsOffset, + "timeSinceReceive", timeSinceReceive, + "timeSinceFirst", timeSinceFirst, + "samplesDiff", samplesDiff, + "samplesDuration", samplesDuration, + } + } + + if firstTime < r.firstTime { + if adjustment > cFirstPacketTimeAdjustThreshold { + err = errors.New("adjusting first packet time, too big, ignoring") + loggingFields = getFields() + } else { + r.firstTimeAdjustment += time.Duration(adjustment) + r.logger.Debugw("adjusting first packet time", getFields()...) + r.firstTime = firstTime + } + } + + return +} + +func (r *rtpStatsBase) getPacketsSeenMinusPadding(extStartSN, extHighestSN uint64) uint64 { + packetsSeen := r.getPacketsSeen(extStartSN, extHighestSN) + if r.packetsPadding > packetsSeen { + return 0 + } + + return packetsSeen - r.packetsPadding +} + +func (r *rtpStatsBase) getPacketsSeenPlusDuplicates(extStartSN, extHighestSN uint64) uint64 { + return r.getPacketsSeen(extStartSN, extHighestSN) + r.packetsDuplicate +} + +func (r *rtpStatsBase) deltaInfo( + snapshotID uint32, + extStartSN uint64, + extHighestSN uint64, +) (deltaInfo *RTPDeltaInfo, err error, loggingFields []any) { + then, now := r.getAndResetSnapshot(snapshotID, extStartSN, extHighestSN) + if now == nil || then == nil { + return + } + + startTime := then.startTime + endTime := now.startTime + + packetsExpected := now.extStartSN - then.extStartSN + if then.extStartSN > extHighestSN { + packetsExpected = 0 + } + if packetsExpected > cNumSequenceNumbers { + loggingFields = []any{ + "snapshotID", snapshotID, + "snapshotNow", now, + "snapshotThen", then, + "duration", time.Duration(endTime - startTime), + "packetsExpected", packetsExpected, + } + err = errors.New("too many packets expected in delta") + return + } + if packetsExpected == 0 { + deltaInfo = &RTPDeltaInfo{ + StartTime: time.Unix(0, startTime), + EndTime: time.Unix(0, endTime), + } + return + } + + packetsLost := uint32(now.packetsLost - then.packetsLost) + if int32(packetsLost) < 0 { + packetsLost = 0 + } + + // padding packets delta could be higher than expected due to out-of-order padding packets + packetsPadding := now.packetsPadding - then.packetsPadding + if packetsExpected < packetsPadding { + loggingFields = []any{ + "snapshotID", snapshotID, + "snapshotNow", now, + "snapshotThen", then, + "duration", time.Duration(endTime - startTime), + "packetsExpected", packetsExpected, + "packetsPadding", packetsPadding, + "packetsLost", packetsLost, + } + err = errors.New("padding packets more than expected") + packetsExpected = 0 + } else { + packetsExpected -= packetsPadding + } + + deltaInfo = &RTPDeltaInfo{ + StartTime: time.Unix(0, startTime), + EndTime: time.Unix(0, endTime), + Packets: uint32(packetsExpected), + Bytes: now.bytes - then.bytes, + HeaderBytes: now.headerBytes - then.headerBytes, + PacketsDuplicate: uint32(now.packetsDuplicate - then.packetsDuplicate), + BytesDuplicate: now.bytesDuplicate - then.bytesDuplicate, + HeaderBytesDuplicate: now.headerBytesDuplicate - then.headerBytesDuplicate, + PacketsPadding: uint32(packetsPadding), + BytesPadding: now.bytesPadding - then.bytesPadding, + HeaderBytesPadding: now.headerBytesPadding - then.headerBytesPadding, + PacketsLost: packetsLost, + PacketsOutOfOrder: uint32(now.packetsOutOfOrder - then.packetsOutOfOrder), + Frames: now.frames - then.frames, + RttMax: then.maxRtt, + JitterMax: then.maxJitter / float64(r.params.ClockRate) * 1e6, + Nacks: now.nacks - then.nacks, + Plis: now.plis - then.plis, + Firs: now.firs - then.firs, + } + return +} + +func (r *rtpStatsBase) marshalLogObject( + e zapcore.ObjectEncoder, + packetsExpected, packetsSeenMinusPadding uint64, + extStartTS, extHighestTS uint64, +) (float64, error) { + if r == nil { + return 0, nil + } + + elapsedSeconds, err := r.rtpStatsBaseLite.marshalLogObject(e, packetsExpected, packetsSeenMinusPadding) + if err != nil { + return 0, err + } + + e.AddTime("firstTime", time.Unix(0, r.firstTime)) + e.AddDuration("firstTimeAdjustment", r.firstTimeAdjustment) + e.AddTime("highestTime", time.Unix(0, r.highestTime)) + + e.AddUint64("headerBytes", r.headerBytes) + + e.AddUint64("packetsDuplicate", r.packetsDuplicate) + e.AddFloat64("packetsDuplicateRate", float64(r.packetsDuplicate)/elapsedSeconds) + e.AddUint64("bytesDuplicate", r.bytesDuplicate) + e.AddFloat64("bitrateDuplicate", float64(r.bytesDuplicate)*8.0/elapsedSeconds) + e.AddUint64("headerBytesDuplicate", r.headerBytesDuplicate) + + e.AddUint64("packetsPadding", r.packetsPadding) + e.AddFloat64("packetsPaddingRate", float64(r.packetsPadding)/elapsedSeconds) + e.AddUint64("bytesPadding", r.bytesPadding) + e.AddFloat64("bitratePadding", float64(r.bytesPadding)*8.0/elapsedSeconds) + e.AddUint64("headerBytesPadding", r.headerBytesPadding) + + e.AddUint32("frames", r.frames) + e.AddFloat64("frameRate", float64(r.frames)/elapsedSeconds) + + e.AddFloat64("jitter", r.jitter) + e.AddFloat64("maxJitter", r.maxJitter) + + e.AddUint32("firs", r.firs) + e.AddTime("lastFir", r.lastFir) + + e.AddUint32("keyFrames", r.keyFrames) + e.AddTime("lastKeyFrame", r.lastKeyFrame) + + e.AddUint32("rtt", r.rtt) + e.AddUint32("maxRtt", r.maxRtt) + + e.AddObject("srFirst", WrappedRTCPSenderReportStateLogger{r.srFirst}) + e.AddObject("srNewest", WrappedRTCPSenderReportStateLogger{r.srNewest}) + + packetDrift, ntpReportDrift, receivedReportDrift, rebasedReportDrift := r.getDrift(extStartTS, extHighestTS) + e.AddObject("packetDrift", wrappedRTPDriftLogger{packetDrift}) + e.AddObject("ntpReportDrift", wrappedRTPDriftLogger{ntpReportDrift}) + e.AddObject("receivedReportDrift", wrappedRTPDriftLogger{receivedReportDrift}) + e.AddObject("rebasedReportDrift", wrappedRTPDriftLogger{rebasedReportDrift}) + return elapsedSeconds, nil +} + +func (r *rtpStatsBase) toProto( + packetsExpected, packetsSeenMinusPadding, packetsLost uint64, + extStartTS, extHighestTS uint64, + jitter, maxJitter float64, +) *livekit.RTPStats { + p := r.rtpStatsBaseLite.toProto(packetsExpected, packetsSeenMinusPadding, packetsLost) + if p == nil { + return nil + } + + p.HeaderBytes = r.headerBytes + + p.PacketsDuplicate = uint32(r.packetsDuplicate) + p.PacketDuplicateRate = float64(r.packetsDuplicate) / p.Duration + p.BytesDuplicate = r.bytesDuplicate + p.BitrateDuplicate = float64(r.bytesDuplicate) * 8.0 / p.Duration + p.HeaderBytesDuplicate = r.headerBytesDuplicate + + p.PacketsPadding = uint32(r.packetsPadding) + p.PacketPaddingRate = float64(r.packetsPadding) / p.Duration + p.BytesPadding = r.bytesPadding + p.BitratePadding = float64(r.bytesPadding) * 8.0 / p.Duration + p.HeaderBytesPadding = r.headerBytesPadding + + p.Frames = r.frames + p.FrameRate = float64(r.frames) / p.Duration + + p.KeyFrames = r.keyFrames + p.LastKeyFrame = timestamppb.New(r.lastKeyFrame) + + p.JitterCurrent = jitter / float64(r.params.ClockRate) * 1e6 + p.JitterMax = maxJitter / float64(r.params.ClockRate) * 1e6 + + p.Firs = r.firs + p.LastFir = timestamppb.New(r.lastFir) + + p.RttCurrent = r.rtt + p.RttMax = r.maxRtt + + p.PacketDrift, p.NtpReportDrift, p.ReceivedReportDrift, p.RebasedReportDrift = r.getDrift(extStartTS, extHighestTS) + return p +} + +func (r *rtpStatsBase) updateJitter(ets uint64, packetTime int64) float64 { + // Do not update jitter on multiple packets of same frame. + // All packets of a frame have the same time stamp. + // NOTE: This does not protect against using more than one packet of the same frame + // if packets arrive out-of-order. For example, + // p1f1 -> p1f2 -> p2f1 + // In this case, p2f1 (packet 2, frame 1) will still be used in jitter calculation + // although it is the second packet of a frame because of out-of-order receival. + if r.lastJitterExtTimestamp != ets { + timeSinceFirst := packetTime - r.firstTime + packetTimeRTP := r.rtpConverter.ToRTPExt(time.Duration(timeSinceFirst)) + transit := packetTimeRTP - ets + + if r.lastTransit != 0 { + d := int64(transit - r.lastTransit) + if d < 0 { + d = -d + } + r.jitter += (float64(d) - r.jitter) / 16 + if r.jitter > r.maxJitter { + r.maxJitter = r.jitter + } + + for i := uint32(0); i < r.nextSnapshotID-cFirstSnapshotID; i++ { + r.snapshots[i].maybeUpdateMaxJitter(r.jitter) + } + } + + r.lastTransit = transit + r.lastJitterExtTimestamp = ets + } + return r.jitter +} + +func (r *rtpStatsBase) getAndResetSnapshot(snapshotID uint32, extStartSN uint64, extHighestSN uint64) (*snapshot, *snapshot) { + if !r.initialized || snapshotID < cFirstSnapshotID { + return nil, nil + } + + idx := snapshotID - cFirstSnapshotID + then := r.snapshots[idx] + if !then.isValid { + then = initSnapshot(r.startTime, extStartSN) + r.snapshots[idx] = then + } + + // snapshot now + now := r.getSnapshot(mono.UnixNano(), extHighestSN+1) + r.snapshots[idx] = now + return &then, &now +} + +func (r *rtpStatsBase) getDrift(extStartTS, extHighestTS uint64) ( + packetDrift *livekit.RTPDrift, + ntpReportDrift *livekit.RTPDrift, + receivedReportDrift *livekit.RTPDrift, + rebasedReportDrift *livekit.RTPDrift, +) { + if r.firstTime != 0 { + elapsed := time.Duration(r.highestTime - r.firstTime) + rtpClockTicks := extHighestTS - extStartTS + driftSamples := int64(rtpClockTicks - r.rtpConverter.ToRTPExt(elapsed)) + if elapsed > 0 { + packetDrift = &livekit.RTPDrift{ + StartTime: timestamppb.New(time.Unix(0, r.firstTime)), + EndTime: timestamppb.New(time.Unix(0, r.highestTime)), + Duration: elapsed.Seconds(), + StartTimestamp: extStartTS, + EndTimestamp: extHighestTS, + RtpClockTicks: rtpClockTicks, + DriftSamples: driftSamples, + DriftMs: (float64(driftSamples) * 1000) / float64(r.params.ClockRate), + ClockRate: float64(rtpClockTicks) / elapsed.Seconds(), + } + } + } + + if r.srFirst != nil && r.srNewest != nil && r.srFirst.RtpTimestamp != r.srNewest.RtpTimestamp { + rtpClockTicks := r.srNewest.RtpTimestampExt - r.srFirst.RtpTimestampExt + + elapsed := mediatransportutil.NtpTime(r.srNewest.NtpTimestamp).Time().Sub(mediatransportutil.NtpTime(r.srFirst.NtpTimestamp).Time()) + if elapsed.Seconds() > 0.0 { + driftSamples := int64(rtpClockTicks - r.rtpConverter.ToRTPExt(elapsed)) + ntpReportDrift = &livekit.RTPDrift{ + StartTime: timestamppb.New(mediatransportutil.NtpTime(r.srFirst.NtpTimestamp).Time()), + EndTime: timestamppb.New(mediatransportutil.NtpTime(r.srNewest.NtpTimestamp).Time()), + Duration: elapsed.Seconds(), + StartTimestamp: r.srFirst.RtpTimestampExt, + EndTimestamp: r.srNewest.RtpTimestampExt, + RtpClockTicks: rtpClockTicks, + DriftSamples: driftSamples, + DriftMs: (float64(driftSamples) * 1000) / float64(r.params.ClockRate), + ClockRate: float64(rtpClockTicks) / elapsed.Seconds(), + } + } + + elapsed = time.Duration(r.srNewest.At - r.srFirst.At) + if elapsed.Seconds() > 0.0 { + driftSamples := int64(rtpClockTicks - r.rtpConverter.ToRTPExt(elapsed)) + receivedReportDrift = &livekit.RTPDrift{ + StartTime: timestamppb.New(time.Unix(0, r.srFirst.At)), + EndTime: timestamppb.New(time.Unix(0, r.srNewest.At)), + Duration: elapsed.Seconds(), + StartTimestamp: r.srFirst.RtpTimestampExt, + EndTimestamp: r.srNewest.RtpTimestampExt, + RtpClockTicks: rtpClockTicks, + DriftSamples: driftSamples, + DriftMs: (float64(driftSamples) * 1000) / float64(r.params.ClockRate), + ClockRate: float64(rtpClockTicks) / elapsed.Seconds(), + } + } + + elapsed = time.Duration(r.srNewest.AtAdjusted - r.srFirst.AtAdjusted) + if elapsed.Seconds() > 0.0 { + driftSamples := int64(rtpClockTicks - r.rtpConverter.ToRTPExt(elapsed)) + rebasedReportDrift = &livekit.RTPDrift{ + StartTime: timestamppb.New(time.Unix(0, r.srFirst.AtAdjusted)), + EndTime: timestamppb.New(time.Unix(0, r.srNewest.AtAdjusted)), + Duration: elapsed.Seconds(), + StartTimestamp: r.srFirst.RtpTimestampExt, + EndTimestamp: r.srNewest.RtpTimestampExt, + RtpClockTicks: rtpClockTicks, + DriftSamples: driftSamples, + DriftMs: (float64(driftSamples) * 1000) / float64(r.params.ClockRate), + ClockRate: float64(rtpClockTicks) / elapsed.Seconds(), + } + } + } + return +} + +func (r *rtpStatsBase) updateGapHistogram(gap int) { + if gap < 2 { + return + } + + missing := gap - 1 + if missing > len(r.gapHistogram) { + r.gapHistogram[len(r.gapHistogram)-1]++ + } else { + r.gapHistogram[missing-1]++ + } +} + +func (r *rtpStatsBase) getSnapshot(startTime int64, extStartSN uint64) snapshot { + return snapshot{ + snapshotLite: r.getSnapshotLite(startTime, extStartSN), + headerBytes: r.headerBytes, + packetsDuplicate: r.packetsDuplicate, + bytesDuplicate: r.bytesDuplicate, + headerBytesDuplicate: r.headerBytesDuplicate, + packetsPadding: r.packetsPadding, + bytesPadding: r.bytesPadding, + headerBytesPadding: r.headerBytesPadding, + frames: r.frames, + plis: r.plis, + firs: r.firs, + maxRtt: r.rtt, + maxJitter: r.jitter, + } +} + +// ---------------------------------- + +func initSnapshot(startTime int64, extStartSN uint64) snapshot { + return snapshot{ + snapshotLite: initSnapshotLite(startTime, extStartSN), + } +} + +func AggregateRTPStats(statsList []*livekit.RTPStats) *livekit.RTPStats { + return utils.AggregateRTPStats(statsList, cGapHistogramNumBins) +} + +func AggregateRTPDeltaInfo(deltaInfoList []*RTPDeltaInfo) *RTPDeltaInfo { + if len(deltaInfoList) == 0 { + return nil + } + + startTime := int64(0) + endTime := int64(0) + + packets := uint32(0) + bytes := uint64(0) + headerBytes := uint64(0) + + packetsDuplicate := uint32(0) + bytesDuplicate := uint64(0) + headerBytesDuplicate := uint64(0) + + packetsPadding := uint32(0) + bytesPadding := uint64(0) + headerBytesPadding := uint64(0) + + packetsLost := uint32(0) + packetsMissing := uint32(0) + packetsOutOfOrder := uint32(0) + + frames := uint32(0) + + maxRtt := uint32(0) + maxJitter := float64(0) + + nacks := uint32(0) + plis := uint32(0) + firs := uint32(0) + + for _, deltaInfo := range deltaInfoList { + if deltaInfo == nil { + continue + } + + if startTime == 0 || startTime > deltaInfo.StartTime.UnixNano() { + startTime = deltaInfo.StartTime.UnixNano() + } + + if endTime == 0 || endTime < deltaInfo.EndTime.UnixNano() { + endTime = deltaInfo.EndTime.UnixNano() + } + + packets += deltaInfo.Packets + bytes += deltaInfo.Bytes + headerBytes += deltaInfo.HeaderBytes + + packetsDuplicate += deltaInfo.PacketsDuplicate + bytesDuplicate += deltaInfo.BytesDuplicate + headerBytesDuplicate += deltaInfo.HeaderBytesDuplicate + + packetsPadding += deltaInfo.PacketsPadding + bytesPadding += deltaInfo.BytesPadding + headerBytesPadding += deltaInfo.HeaderBytesPadding + + packetsLost += deltaInfo.PacketsLost + packetsMissing += deltaInfo.PacketsMissing + packetsOutOfOrder += deltaInfo.PacketsOutOfOrder + + frames += deltaInfo.Frames + + if deltaInfo.RttMax > maxRtt { + maxRtt = deltaInfo.RttMax + } + + if deltaInfo.JitterMax > maxJitter { + maxJitter = deltaInfo.JitterMax + } + + nacks += deltaInfo.Nacks + plis += deltaInfo.Plis + firs += deltaInfo.Firs + } + if startTime == 0 || endTime == 0 { + return nil + } + + return &RTPDeltaInfo{ + StartTime: time.Unix(0, startTime), + EndTime: time.Unix(0, endTime), + Packets: packets, + Bytes: bytes, + HeaderBytes: headerBytes, + PacketsDuplicate: packetsDuplicate, + BytesDuplicate: bytesDuplicate, + HeaderBytesDuplicate: headerBytesDuplicate, + PacketsPadding: packetsPadding, + BytesPadding: bytesPadding, + HeaderBytesPadding: headerBytesPadding, + PacketsLost: packetsLost, + PacketsMissing: packetsMissing, + PacketsOutOfOrder: packetsOutOfOrder, + Frames: frames, + RttMax: maxRtt, + JitterMax: maxJitter, + Nacks: nacks, + Plis: plis, + Firs: firs, + } +} + +func ReconcileRTPStatsWithRTX(primaryStats *livekit.RTPStats, rtxStats *livekit.RTPStats) *livekit.RTPStats { + if primaryStats == nil || rtxStats == nil { + return primaryStats + } + + primaryStats.PacketsDuplicate += rtxStats.Packets + primaryStats.PacketDuplicateRate = float64(primaryStats.PacketsDuplicate) / primaryStats.Duration + + primaryStats.BytesDuplicate += rtxStats.Bytes + primaryStats.HeaderBytesDuplicate += rtxStats.HeaderBytes + primaryStats.BitrateDuplicate = float64(primaryStats.BytesDuplicate) * 8.0 / primaryStats.Duration + + primaryStats.PacketsPadding += rtxStats.PacketsPadding + primaryStats.PacketPaddingRate = float64(primaryStats.PacketsPadding) / primaryStats.Duration + + primaryStats.BytesPadding += rtxStats.BytesPadding + primaryStats.HeaderBytesPadding += rtxStats.HeaderBytesPadding + primaryStats.BitratePadding = float64(primaryStats.BytesPadding) * 8.0 / primaryStats.Duration + + // RTX non-padding packets are responses to NACKs, that should discount packets lost, + lossAdjustment := rtxStats.Packets - rtxStats.PacketsLost - primaryStats.NackRepeated + if int32(lossAdjustment) < 0 { + lossAdjustment = 0 + } + if lossAdjustment >= primaryStats.PacketsLost { + primaryStats.PacketsLost = 0 + } else { + primaryStats.PacketsLost -= lossAdjustment + } + primaryStats.PacketLossRate = float64(primaryStats.PacketsLost) / primaryStats.Duration + primaryStats.PacketLossPercentage = float32(primaryStats.PacketsLost) / float32(primaryStats.Packets+primaryStats.PacketsPadding+primaryStats.PacketsLost) * 100.0 + return primaryStats +} + +func ReconcileRTPDeltaInfoWithRTX(primaryDeltaInfo *RTPDeltaInfo, rtxDeltaInfo *RTPDeltaInfo) *RTPDeltaInfo { + if primaryDeltaInfo == nil || rtxDeltaInfo == nil { + return primaryDeltaInfo + } + + primaryDeltaInfo.PacketsDuplicate += rtxDeltaInfo.Packets + + primaryDeltaInfo.BytesDuplicate += rtxDeltaInfo.Bytes + primaryDeltaInfo.HeaderBytesDuplicate += rtxDeltaInfo.HeaderBytes + + primaryDeltaInfo.PacketsPadding += rtxDeltaInfo.PacketsPadding + + primaryDeltaInfo.BytesPadding += rtxDeltaInfo.BytesPadding + primaryDeltaInfo.HeaderBytesPadding += rtxDeltaInfo.HeaderBytesPadding + + // RTX non-padding packets are responses to NACKs, that should discount packets lost + lossAdjustment := rtxDeltaInfo.Packets - rtxDeltaInfo.PacketsLost - primaryDeltaInfo.NackRepeated + if int32(lossAdjustment) < 0 { + lossAdjustment = 0 + } + if lossAdjustment >= primaryDeltaInfo.PacketsLost { + primaryDeltaInfo.PacketsLost = 0 + } else { + primaryDeltaInfo.PacketsLost -= lossAdjustment + } + return primaryDeltaInfo +} + +// ------------------------------------------------------------------- diff --git a/livekit/pkg/sfu/rtpstats/rtpstats_base_lite.go b/livekit/pkg/sfu/rtpstats/rtpstats_base_lite.go new file mode 100644 index 0000000..061eb10 --- /dev/null +++ b/livekit/pkg/sfu/rtpstats/rtpstats_base_lite.go @@ -0,0 +1,557 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rtpstats + +import ( + "errors" + "fmt" + "sync" + "time" + + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/utils/mono" + "go.uber.org/zap/zapcore" + "google.golang.org/protobuf/types/known/timestamppb" +) + +const ( + cGapHistogramNumBins = 101 + cNumSequenceNumbers = 65536 + cFirstSnapshotID = 1 +) + +// ------------------------------------------------------- + +type RTPDeltaInfoLite struct { + StartTime time.Time + EndTime time.Time + Packets uint32 + Bytes uint64 + PacketsLost uint32 + PacketsOutOfOrder uint32 + Nacks uint32 +} + +func (r *RTPDeltaInfoLite) MarshalLogObject(e zapcore.ObjectEncoder) error { + if r == nil { + return nil + } + + e.AddTime("StartTime", r.StartTime) + e.AddTime("EndTime", r.EndTime) + e.AddUint32("Packets", r.Packets) + e.AddUint64("Bytes", r.Bytes) + e.AddUint32("PacketsLost", r.PacketsLost) + e.AddUint32("PacketsOutOfOrder", r.PacketsOutOfOrder) + e.AddUint32("Nacks", r.Nacks) + return nil +} + +// ------------------------------------------------------- + +type snapshotLite struct { + isValid bool + + startTime int64 + + extStartSN uint64 + bytes uint64 + + packetsOutOfOrder uint64 + + packetsLost uint64 + + nacks uint32 +} + +func (s *snapshotLite) MarshalLogObject(e zapcore.ObjectEncoder) error { + if s == nil { + return nil + } + + e.AddBool("isValid", s.isValid) + e.AddTime("startTime", time.Unix(0, s.startTime)) + e.AddUint64("extStartSN", s.extStartSN) + e.AddUint64("bytes", s.bytes) + e.AddUint64("packetsOutOfOrder", s.packetsOutOfOrder) + e.AddUint64("packetsLost", s.packetsLost) + e.AddUint32("nacks", s.nacks) + return nil +} + +// ------------------------------------------------------------------ + +type RTPStatsParams struct { + ClockRate uint32 + IsRTX bool + Logger logger.Logger +} + +type rtpStatsBaseLite struct { + params RTPStatsParams + logger logger.Logger + + lock sync.RWMutex + + initialized bool + + startTime int64 + endTime int64 + + bytes uint64 + + packetsOutOfOrder uint64 + + packetsLost uint64 + + gapHistogram [cGapHistogramNumBins]uint32 + + nacks uint32 + nackAcks uint32 + nackMisses uint32 + nackRepeated uint32 + + plis uint32 + lastPli int64 + + nextSnapshotLiteID uint32 + snapshotLites []snapshotLite +} + +func newRTPStatsBaseLite(params RTPStatsParams) *rtpStatsBaseLite { + return &rtpStatsBaseLite{ + params: params, + logger: params.Logger, + nextSnapshotLiteID: cFirstSnapshotID, + snapshotLites: make([]snapshotLite, 2), + } +} + +func (r *rtpStatsBaseLite) seed(from *rtpStatsBaseLite) bool { + if from == nil || !from.initialized || r.initialized { + return false + } + + r.initialized = from.initialized + + r.startTime = from.startTime + // do not clone endTime as a non-zero endTime indicates an ended object + + r.bytes = from.bytes + + r.packetsOutOfOrder = from.packetsOutOfOrder + + r.packetsLost = from.packetsLost + + r.gapHistogram = from.gapHistogram + + r.nacks = from.nacks + r.nackAcks = from.nackAcks + r.nackMisses = from.nackMisses + r.nackRepeated = from.nackRepeated + + r.plis = from.plis + r.lastPli = from.lastPli + + r.nextSnapshotLiteID = from.nextSnapshotLiteID + r.snapshotLites = make([]snapshotLite, cap(from.snapshotLites)) + copy(r.snapshotLites, from.snapshotLites) + return true +} + +func (r *rtpStatsBaseLite) SetLogger(logger logger.Logger) { + r.logger = logger +} + +func (r *rtpStatsBaseLite) Stop() { + r.lock.Lock() + defer r.lock.Unlock() + + r.endTime = mono.UnixNano() +} + +func (r *rtpStatsBaseLite) newSnapshotLiteID(extStartSN uint64) uint32 { + id := r.nextSnapshotLiteID + r.nextSnapshotLiteID++ + + if cap(r.snapshotLites) < int(r.nextSnapshotLiteID-cFirstSnapshotID) { + snapshotLites := make([]snapshotLite, r.nextSnapshotLiteID-cFirstSnapshotID) + copy(snapshotLites, r.snapshotLites) + r.snapshotLites = snapshotLites + } + + if r.initialized { + r.snapshotLites[id-cFirstSnapshotID] = initSnapshotLite(mono.UnixNano(), extStartSN) + } + return id +} + +func (r *rtpStatsBaseLite) IsActive() bool { + r.lock.RLock() + defer r.lock.RUnlock() + + return r.initialized && r.endTime == 0 +} + +func (r *rtpStatsBaseLite) UpdateNack(nackCount uint32) { + r.lock.Lock() + defer r.lock.Unlock() + + if r.endTime != 0 { + return + } + + r.nacks += nackCount +} + +func (r *rtpStatsBaseLite) UpdateNackProcessed(nackAckCount uint32, nackMissCount uint32, nackRepeatedCount uint32) { + r.lock.Lock() + defer r.lock.Unlock() + + if r.endTime != 0 { + return + } + + r.nackAcks += nackAckCount + r.nackMisses += nackMissCount + r.nackRepeated += nackRepeatedCount +} + +func (r *rtpStatsBaseLite) CheckAndUpdatePli(throttle int64, force bool) bool { + r.lock.Lock() + defer r.lock.Unlock() + + if r.endTime != 0 || (!force && mono.UnixNano()-r.lastPli < throttle) { + return false + } + r.updatePliLocked(1) + r.updatePliTimeLocked() + return true +} + +func (r *rtpStatsBaseLite) UpdatePliAndTime(pliCount uint32) { + r.lock.Lock() + defer r.lock.Unlock() + + if r.endTime != 0 { + return + } + + r.updatePliLocked(pliCount) + r.updatePliTimeLocked() +} + +func (r *rtpStatsBaseLite) UpdatePli(pliCount uint32) { + r.lock.Lock() + defer r.lock.Unlock() + + if r.endTime != 0 { + return + } + + r.updatePliLocked(pliCount) +} + +func (r *rtpStatsBaseLite) updatePliLocked(pliCount uint32) { + r.plis += pliCount +} + +func (r *rtpStatsBaseLite) UpdatePliTime() { + r.lock.Lock() + defer r.lock.Unlock() + + if r.endTime != 0 { + return + } + + r.updatePliTimeLocked() +} + +func (r *rtpStatsBaseLite) updatePliTimeLocked() { + r.lastPli = mono.UnixNano() +} + +func (r *rtpStatsBaseLite) LastPli() int64 { + r.lock.RLock() + defer r.lock.RUnlock() + + return r.lastPli +} + +func (r *rtpStatsBaseLite) getPacketsSeen(extStartSN, extHighestSN uint64) uint64 { + packetsExpected := getPacketsExpected(extStartSN, extHighestSN) + if r.packetsLost > packetsExpected { + // should not happen + return 0 + } + + return packetsExpected - r.packetsLost +} + +func (r *rtpStatsBaseLite) deltaInfoLite( + snapshotLiteID uint32, + extStartSN uint64, + extHighestSN uint64, +) (deltaInfoLite *RTPDeltaInfoLite, err error, loggingFields []any) { + then, now := r.getAndResetSnapshotLite(snapshotLiteID, extStartSN, extHighestSN) + if now == nil || then == nil { + return + } + + startTime := then.startTime + endTime := now.startTime + + packetsExpected := uint32(now.extStartSN - then.extStartSN) + if then.extStartSN > extHighestSN { + packetsExpected = 0 + } + if packetsExpected > cNumSequenceNumbers { + loggingFields = []any{ + "snapshotLiteID", snapshotLiteID, + "snapshotLiteNow", now, + "snapshotLiteThen", then, + "packetsExpected", packetsExpected, + "duration", time.Duration(endTime - startTime), + } + err = errors.New("too many packets expected in delta lite") + return + } + if packetsExpected == 0 { + deltaInfoLite = &RTPDeltaInfoLite{ + StartTime: time.Unix(0, startTime), + EndTime: time.Unix(0, endTime), + } + return + } + + packetsLost := uint32(now.packetsLost - then.packetsLost) + if int32(packetsLost) < 0 { + packetsLost = 0 + } + if packetsLost > packetsExpected { + loggingFields = []any{ + "snapshotLiteID", snapshotLiteID, + "snapshotLiteNow", now, + "snapshotLiteThen", then, + "packetsExpected", packetsExpected, + "packetsLost", packetsLost, + "duration", time.Duration(endTime - startTime), + } + err = errors.New("unexpected number of packets lost in delta lite") + } + + deltaInfoLite = &RTPDeltaInfoLite{ + StartTime: time.Unix(0, startTime), + EndTime: time.Unix(0, endTime), + Packets: packetsExpected, + Bytes: now.bytes - then.bytes, + PacketsLost: packetsLost, + PacketsOutOfOrder: uint32(now.packetsOutOfOrder - then.packetsOutOfOrder), + Nacks: now.nacks - then.nacks, + } + return +} + +func (r *rtpStatsBaseLite) marshalLogObject(e zapcore.ObjectEncoder, packetsExpected, packetsSeenMinusPadding uint64) (float64, error) { + if r == nil || !r.initialized { + return 0, errors.New("not initialized") + } + + endTime := r.endTime + if endTime == 0 { + endTime = mono.UnixNano() + } + elapsed := time.Duration(endTime - r.startTime) + if elapsed == 0 { + return 0, errors.New("no time elapsed") + } + elapsedSeconds := elapsed.Seconds() + + e.AddTime("startTime", time.Unix(0, r.startTime)) + e.AddTime("endTime", time.Unix(0, r.endTime)) + e.AddDuration("elapsed", elapsed) + + e.AddUint64("packetsExpected", packetsExpected) + e.AddFloat64("packetsExpectedRate", float64(packetsExpected)/elapsedSeconds) + e.AddUint64("packetsSeenPrimary", packetsSeenMinusPadding) + e.AddFloat64("packetsSeenPrimaryRate", float64(packetsSeenMinusPadding)/elapsedSeconds) + e.AddUint64("bytes", r.bytes) + e.AddFloat64("bitrate", float64(r.bytes)*8.0/elapsedSeconds) + + e.AddUint64("packetsOutOfOrder", r.packetsOutOfOrder) + + e.AddUint64("packetsLost", r.packetsLost) + e.AddFloat64("packetsLostRate", float64(r.packetsLost)/elapsedSeconds) + if packetsExpected != 0 { + e.AddFloat32("packetLostPercentage", float32(r.packetsLost)/float32(packetsExpected)*100.0) + } + + hasLoss := false + first := true + str := "[" + for burst, count := range r.gapHistogram { + if count == 0 { + continue + } + + hasLoss = true + + if !first { + str += ", " + } + first = false + str += fmt.Sprintf("%d:%d", burst+1, count) + } + str += "]" + if hasLoss { + e.AddString("gapHistogram", str) + } + + e.AddUint32("nacks", r.nacks) + e.AddUint32("nackAcks", r.nackAcks) + e.AddUint32("nackMisses", r.nackMisses) + e.AddUint32("nackRepeated", r.nackRepeated) + + e.AddUint32("plis", r.plis) + e.AddTime("lastPli", time.Unix(0, r.lastPli)) + return elapsedSeconds, nil +} + +func (r *rtpStatsBaseLite) toProto(packetsExpected, packetsSeenMinusPadding, packetsLost uint64) *livekit.RTPStats { + if r.startTime == 0 { + return nil + } + + endTime := r.endTime + if endTime == 0 { + endTime = mono.UnixNano() + } + elapsed := time.Duration(endTime - r.startTime).Seconds() + if elapsed == 0.0 { + return nil + } + + packetRate := float64(packetsSeenMinusPadding) / elapsed + bitrate := float64(r.bytes) * 8.0 / elapsed + + packetLostRate := float64(packetsLost) / elapsed + packetLostPercentage := float32(0) + if packetsExpected != 0 { + packetLostPercentage = float32(packetsLost) / float32(packetsExpected) * 100.0 + } + + p := &livekit.RTPStats{ + StartTime: timestamppb.New(time.Unix(0, r.startTime)), + EndTime: timestamppb.New(time.Unix(0, endTime)), + Duration: elapsed, + Packets: uint32(packetsSeenMinusPadding), + PacketRate: packetRate, + Bytes: r.bytes, + Bitrate: bitrate, + PacketsLost: uint32(packetsLost), + PacketLossRate: packetLostRate, + PacketLossPercentage: packetLostPercentage, + PacketsOutOfOrder: uint32(r.packetsOutOfOrder), + Nacks: r.nacks, + NackAcks: r.nackAcks, + NackMisses: r.nackMisses, + NackRepeated: r.nackRepeated, + Plis: r.plis, + LastPli: timestamppb.New(time.Unix(0, r.lastPli)), + } + + gapsPresent := false + for i := range len(r.gapHistogram) { + if r.gapHistogram[i] == 0 { + continue + } + + gapsPresent = true + break + } + + if gapsPresent { + p.GapHistogram = make(map[int32]uint32, len(r.gapHistogram)) + for i := range len(r.gapHistogram) { + if r.gapHistogram[i] == 0 { + continue + } + + p.GapHistogram[int32(i+1)] = r.gapHistogram[i] + } + } + + return p +} + +func (r *rtpStatsBaseLite) getAndResetSnapshotLite(snapshotLiteID uint32, extStartSN uint64, extHighestSN uint64) (*snapshotLite, *snapshotLite) { + if !r.initialized { + return nil, nil + } + + idx := snapshotLiteID - cFirstSnapshotID + then := r.snapshotLites[idx] + if !then.isValid { + then = initSnapshotLite(r.startTime, extStartSN) + r.snapshotLites[idx] = then + } + + // snapshot now + now := r.getSnapshotLite(mono.UnixNano(), extHighestSN+1) + r.snapshotLites[idx] = now + return &then, &now +} + +func (r *rtpStatsBaseLite) updateGapHistogram(gap int) { + if gap < 2 { + return + } + + missing := gap - 1 + if missing > len(r.gapHistogram) { + r.gapHistogram[len(r.gapHistogram)-1]++ + } else { + r.gapHistogram[missing-1]++ + } +} + +func (r *rtpStatsBaseLite) getSnapshotLite(startTime int64, extStartSN uint64) snapshotLite { + return snapshotLite{ + isValid: true, + startTime: startTime, + extStartSN: extStartSN, + bytes: r.bytes, + packetsOutOfOrder: r.packetsOutOfOrder, + packetsLost: r.packetsLost, + nacks: r.nacks, + } +} + +// ---------------------------------- + +func initSnapshotLite(startTime int64, extStartSN uint64) snapshotLite { + return snapshotLite{ + isValid: true, + startTime: startTime, + extStartSN: extStartSN, + } +} + +func getPacketsExpected(extStartSN, extHighestSN uint64) uint64 { + return extHighestSN - extStartSN + 1 +} + +// ---------------------------------- diff --git a/livekit/pkg/sfu/rtpstats/rtpstats_receiver.go b/livekit/pkg/sfu/rtpstats/rtpstats_receiver.go new file mode 100644 index 0000000..7829262 --- /dev/null +++ b/livekit/pkg/sfu/rtpstats/rtpstats_receiver.go @@ -0,0 +1,872 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rtpstats + +import ( + "fmt" + "math" + "time" + + "github.com/pion/rtcp" + "go.uber.org/zap/zapcore" + + "github.com/livekit/mediatransportutil" + "github.com/livekit/mediatransportutil/pkg/latency" + "github.com/livekit/mediatransportutil/pkg/utils" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + protoutils "github.com/livekit/protocol/utils" + "github.com/livekit/protocol/utils/mono" +) + +const ( + cHistorySize = 8192 + + // number of seconds the current report RTP timestamp can be off from expected RTP timestamp + cReportSlack = float64(60.0) + + cTSJumpTooHighFactor = float64(1.5) + + restartThreshold = 5 +) + +// --------------------------------------------------------------------- + +type RTPFlowUnhandledReason int + +const ( + RTPFlowUnhandledReasonNone RTPFlowUnhandledReason = iota + RTPFlowUnhandledReasonEnded + RTPFlowUnhandledReasonPaddingOnly + RTPFlowUnhandledReasonPreStartTimestamp + RTPFlowUnhandledReasonOldTimestamp + RTPFlowUnhandledReasonPreStartSequenceNumber + RTPFlowUnhandledReasonOldSequenceNumber + RTPFlowUnhandledReasonRestart +) + +func (r RTPFlowUnhandledReason) String() string { + switch r { + case RTPFlowUnhandledReasonNone: + return "NONE" + case RTPFlowUnhandledReasonEnded: + return "ENDED" + case RTPFlowUnhandledReasonPaddingOnly: + return "PADDING_ONLY" + case RTPFlowUnhandledReasonPreStartTimestamp: + return "PRE_START_TIMESTAMP" + case RTPFlowUnhandledReasonOldTimestamp: + return "OLD_TIMESTAMP" + case RTPFlowUnhandledReasonPreStartSequenceNumber: + return "PRE_START_SEQUENCE_NUMBER" + case RTPFlowUnhandledReasonOldSequenceNumber: + return "OLD_SEQUENCE_NUMBER" + case RTPFlowUnhandledReasonRestart: + return "RESTART" + default: + return fmt.Sprintf("UNKNOWN: %d", int(r)) + } +} + +type RTPFlowState struct { + UnhandledReason RTPFlowUnhandledReason + + LossStartInclusive uint64 + LossEndExclusive uint64 + + IsDuplicate bool + IsOutOfOrder bool + + ExtSequenceNumber uint64 + ExtTimestamp uint64 +} + +func (r *RTPFlowState) MarshalLogObject(e zapcore.ObjectEncoder) error { + if r == nil { + return nil + } + + e.AddString("UnhandledReason", r.UnhandledReason.String()) + e.AddUint64("LossStartInclusive", r.LossStartInclusive) + e.AddUint64("LossEndExclusive", r.LossEndExclusive) + e.AddBool("IsDuplicate", r.IsDuplicate) + e.AddBool("IsOutOfOrder", r.IsOutOfOrder) + e.AddUint64("ExtSequenceNumber", r.ExtSequenceNumber) + e.AddUint64("ExtTimestamp", r.ExtTimestamp) + return nil +} + +// --------------------------------------------------------------------- + +type packet struct { + sequenceNumber uint16 + timestamp uint32 +} + +func (p packet) MarshalLogObject(e zapcore.ObjectEncoder) error { + e.AddUint16("sequenceNumber", p.sequenceNumber) + e.AddUint32("timestamp", p.timestamp) + return nil +} + +// --------------------------------------------------------------------- + +type RTPStatsReceiver struct { + *rtpStatsBase + + sequenceNumber *utils.WrapAround[uint16, uint64] + + tsRolloverThreshold int64 + timestamp *utils.WrapAround[uint32, uint64] + + history *protoutils.Bitmap[uint64] + + propagationDelayEstimator *latency.OWDEstimator + + clockSkewCount int + clockSkewMediaPathCount int + outOfOrderSenderReportCount int + largeJumpCount int + largeJumpNegativeCount int + timeReversedCount int + + packetsDroppedPreStartTimestamp int + packetsDroppedOldTimestamp int + packetsDroppedPreStartSequenceNumber int + packetsDroppedOldSequenceNumber int + + restartPackets []packet +} + +func NewRTPStatsReceiver(params RTPStatsParams) *RTPStatsReceiver { + return &RTPStatsReceiver{ + rtpStatsBase: newRTPStatsBase(params), + sequenceNumber: utils.NewWrapAround[uint16, uint64](utils.WrapAroundParams{IsRestartAllowed: false}), + tsRolloverThreshold: (1 << 31) * 1e9 / int64(params.ClockRate), + timestamp: utils.NewWrapAround[uint32, uint64](utils.WrapAroundParams{IsRestartAllowed: false}), + history: protoutils.NewBitmap[uint64](cHistorySize), + propagationDelayEstimator: latency.NewOWDEstimator(latency.OWDEstimatorParamsDefault), + } +} + +func (r *RTPStatsReceiver) NewSnapshotId() uint32 { + r.lock.Lock() + defer r.lock.Unlock() + + return r.newSnapshotID(r.sequenceNumber.GetExtendedHighest()) +} + +func (r *RTPStatsReceiver) getTSRolloverCount(diffNano int64, ts uint32) int { + if diffNano < r.tsRolloverThreshold { + // time not more than rollover threshold + return -1 + } + + excess := (diffNano - r.tsRolloverThreshold*2) * int64(r.params.ClockRate) / 1e9 + roc := max(excess/(1<<32), 0) + if r.timestamp.GetHighest() > ts { + roc++ + } + return int(roc) +} + +func (r *RTPStatsReceiver) Update( + packetTime int64, + sequenceNumber uint16, + timestamp uint32, + marker bool, + hdrSize int, + payloadSize int, + paddingSize int, +) (flowState RTPFlowState) { + r.lock.Lock() + defer r.lock.Unlock() + + if r.endTime != 0 { + flowState.UnhandledReason = RTPFlowUnhandledReasonEnded + return + } + + var resSN utils.WrapAroundUpdateResult[uint64] + var gapSN int64 + var resTS utils.WrapAroundUpdateResult[uint64] + var gapTS int64 + var expectedTSJump int64 + var timeSinceHighest int64 + var tsRolloverCount int + var snRolloverCount int + + logger := func() logger.UnlikelyLogger { + return r.logger.WithUnlikelyValues( + "resSN", resSN, + "gapSN", gapSN, + "resTS", resTS, + "gapTS", gapTS, + "snRolloverCount", snRolloverCount, + "expectedTSJump", expectedTSJump, + "tsRolloverCount", tsRolloverCount, + "packetTime", time.Unix(0, packetTime), + "timeSinceHighest", time.Duration(timeSinceHighest), + "sequenceNumber", sequenceNumber, + "timestamp", timestamp, + "marker", marker, + "hdrSize", hdrSize, + "payloadSize", payloadSize, + "paddingSize", paddingSize, + "rtpStats", lockedRTPStatsReceiverLogEncoder{r}, + ) + } + + undoUpdates := func() { + r.sequenceNumber.UndoUpdate(resSN) + r.timestamp.UndoUpdate(resTS) + } + + if !r.initialized { + if payloadSize == 0 { + // do not start on a padding only packet + flowState.UnhandledReason = RTPFlowUnhandledReasonPaddingOnly + return + } + + r.initialized = true + + r.startTime = mono.UnixNano() + + r.firstTime = packetTime + r.highestTime = packetTime + + resSN = r.sequenceNumber.Update(sequenceNumber) + resTS = r.timestamp.Update(timestamp) + + // initialize snapshots if any + for i := uint32(0); i < r.nextSnapshotID-cFirstSnapshotID; i++ { + r.snapshots[i] = initSnapshot(r.startTime, r.sequenceNumber.GetExtendedStart()) + } + + r.logger.Debugw( + "rtp receiver stream start", + "rtpStats", lockedRTPStatsReceiverLogEncoder{r}, + ) + } else { + resSN = r.sequenceNumber.Update(sequenceNumber) + gapSN = int64(resSN.ExtendedVal - resSN.PreExtendedHighest) + + timeSinceHighest = packetTime - r.highestTime + tsRolloverCount = r.getTSRolloverCount(timeSinceHighest, timestamp) + if tsRolloverCount >= 0 { + logger().Warnw("potential time stamp roll over", nil) + } + resTS = r.timestamp.Rollover(timestamp, tsRolloverCount) + if resTS.IsUnhandled { + undoUpdates() + + r.packetsDroppedPreStartTimestamp++ + logger().Warnw("dropping packet, pre-start timestamp", nil) + + if r.maybeRestart(sequenceNumber, timestamp, payloadSize) { + logger().Infow("potential restart") + r.resetRestart() + flowState.UnhandledReason = RTPFlowUnhandledReasonRestart + } else { + flowState.UnhandledReason = RTPFlowUnhandledReasonPreStartTimestamp + } + return + } + gapTS = int64(resTS.ExtendedVal - resTS.PreExtendedHighest) + + if !resSN.IsUnhandled { + // it is possible to receive old packets in two different scenarios + // as it is not possible to detect how far to roll back, ignore old packets + // + // Case 1: + // Very old time stamp, happens under the following conditions + // - resume after long mute, this casues big time stamp jump ahead for the packets + // after unmute + // - an out of order packet from before the mute arrives (unsure what causes this + // very old packet to be transmitted from remote), causing time stamp to jump back + // to before mute, but it appears like it has rolled over. + // Use a threshold against expected to ignore these. + if gapSN < 0 && gapTS > 0 { + expectedTSJump = int64(r.rtpConverter.ToRTPExt(time.Duration(timeSinceHighest))) + if gapTS > int64(float64(expectedTSJump)*cTSJumpTooHighFactor) { + undoUpdates() + + r.packetsDroppedOldTimestamp++ + logger().Warnw("dropping packet, old timestamp", nil) + + if r.maybeRestart(sequenceNumber, timestamp, payloadSize) { + logger().Infow("potential restart") + r.resetRestart() + flowState.UnhandledReason = RTPFlowUnhandledReasonRestart + } else { + flowState.UnhandledReason = RTPFlowUnhandledReasonOldTimestamp + } + return + } + } + + // Case 2: + // Sequence number looks like it is moving forward, but it is actually a very old packet. + if gapTS < 0 && gapSN > 0 { + undoUpdates() + + r.packetsDroppedOldSequenceNumber++ + expectedTSJump = int64(r.rtpConverter.ToRTPExt(time.Duration(timeSinceHighest))) + logger().Warnw("dropping packet, old sequence number", nil) + + if r.maybeRestart(sequenceNumber, timestamp, payloadSize) { + logger().Infow("potential restart") + r.resetRestart() + flowState.UnhandledReason = RTPFlowUnhandledReasonRestart + } else { + flowState.UnhandledReason = RTPFlowUnhandledReasonOldSequenceNumber + } + return + } + } + + // it is possible that sequence number has rolled over too + if (gapSN < 0 || gapSN > (1<<15)) && gapTS > 0 && payloadSize > 0 { + // not possible to know how many cycles of sequence number roll over could have happened, + // ensure that it at least does not go backwards + snRolloverCount = 0 + if sequenceNumber < r.sequenceNumber.GetHighest() { + snRolloverCount = 1 + } + resSN = r.sequenceNumber.Rollover(sequenceNumber, snRolloverCount) + if !resSN.IsUnhandled { + logger().Warnw("forcing sequence number rollover", nil) + } + } + + if resSN.IsUnhandled { + undoUpdates() + + r.packetsDroppedPreStartSequenceNumber++ + logger().Warnw("dropping packet, pre-start sequence number", nil) + + if r.maybeRestart(sequenceNumber, timestamp, payloadSize) { + logger().Infow("potential restart") + r.resetRestart() + flowState.UnhandledReason = RTPFlowUnhandledReasonRestart + } else { + flowState.UnhandledReason = RTPFlowUnhandledReasonPreStartSequenceNumber + } + return + } + } + gapSN = int64(resSN.ExtendedVal - resSN.PreExtendedHighest) + + pktSize := uint64(hdrSize + payloadSize + paddingSize) + if gapSN <= 0 { // duplicate OR out-of-order + if gapSN != 0 { + r.packetsOutOfOrder++ + } + + if r.isInRange(resSN.ExtendedVal, resSN.PreExtendedHighest) { + if r.history.GetAndSet(resSN.ExtendedVal) { + r.bytesDuplicate += pktSize + r.headerBytesDuplicate += uint64(hdrSize) + r.packetsDuplicate++ + flowState.IsDuplicate = true + } else { + r.packetsLost-- + } + } + + flowState.IsOutOfOrder = true + + if !flowState.IsDuplicate && -gapSN >= cSequenceNumberLargeJumpThreshold { + r.largeJumpNegativeCount++ + if (r.largeJumpNegativeCount-1)%100 == 0 { + logger().Warnw( + "large sequence number gap negative", nil, + "count", r.largeJumpNegativeCount, + ) + } + } + } else { // in-order + if gapSN >= cSequenceNumberLargeJumpThreshold { + r.largeJumpCount++ + if (r.largeJumpCount-1)%100 == 0 { + logger().Warnw( + "large sequence number gap", nil, + "count", r.largeJumpCount, + ) + } + } + + if resTS.ExtendedVal < resTS.PreExtendedHighest { + r.timeReversedCount++ + if (r.timeReversedCount-1)%100 == 0 { + logger().Warnw( + "time reversed", nil, + "count", r.timeReversedCount, + ) + } + } + + // update gap histogram + r.updateGapHistogram(int(gapSN)) + + // update missing sequence numbers + r.history.ClearRange(resSN.PreExtendedHighest+1, resSN.ExtendedVal-1) + r.packetsLost += uint64(gapSN - 1) + + r.history.Set(resSN.ExtendedVal) + + if timestamp != uint32(resTS.PreExtendedHighest) { + // update only on first packet as same timestamp could be in multiple packets. + // NOTE: this may not be the first packet with this time stamp if there is packet loss. + r.highestTime = packetTime + } + + flowState.LossStartInclusive = resSN.PreExtendedHighest + 1 + flowState.LossEndExclusive = resSN.ExtendedVal + } + flowState.ExtSequenceNumber = resSN.ExtendedVal + flowState.ExtTimestamp = resTS.ExtendedVal + + if !flowState.IsDuplicate { + if payloadSize == 0 { + r.packetsPadding++ + r.bytesPadding += pktSize + r.headerBytesPadding += uint64(hdrSize) + } else { + r.bytes += pktSize + r.headerBytes += uint64(hdrSize) + + if marker { + r.frames++ + } + + r.updateJitter(resTS.ExtendedVal, packetTime) + } + } + r.resetRestart() + return +} + +func (r *RTPStatsReceiver) getExtendedSenderReport(srData *livekit.RTCPSenderReportState) *livekit.RTCPSenderReportState { + tsCycles := uint64(0) + if r.srNewest != nil { + // use time since last sender report to ensure long gaps where the time stamp might + // jump more than half the range + timeSinceLastReport := mediatransportutil.NtpTime(srData.NtpTimestamp).Time().Sub(mediatransportutil.NtpTime(r.srNewest.NtpTimestamp).Time()) + expectedRTPTimestampExt := r.srNewest.RtpTimestampExt + r.rtpConverter.ToRTPExt(timeSinceLastReport) + lbound := expectedRTPTimestampExt - uint64(cReportSlack*float64(r.params.ClockRate)) + ubound := expectedRTPTimestampExt + uint64(cReportSlack*float64(r.params.ClockRate)) + isInRange := (srData.RtpTimestamp-uint32(lbound) < (1 << 31)) && (uint32(ubound)-srData.RtpTimestamp < (1 << 31)) + if isInRange { + lbTSCycles := lbound & 0xFFFF_FFFF_0000_0000 + ubTSCycles := ubound & 0xFFFF_FFFF_0000_0000 + if lbTSCycles == ubTSCycles { + tsCycles = lbTSCycles + } else { + if srData.RtpTimestamp < (1 << 31) { + // rolled over + tsCycles = ubTSCycles + } else { + tsCycles = lbTSCycles + } + } + } else { + // ideally this method should not be required, but there are clients + // negotiating one clock rate, but actually send media at a different rate. + tsCycles = r.srNewest.RtpTimestampExt & 0xFFFF_FFFF_0000_0000 + if (srData.RtpTimestamp-r.srNewest.RtpTimestamp) < (1<<31) && srData.RtpTimestamp < r.srNewest.RtpTimestamp { + tsCycles += (1 << 32) + } + + if tsCycles >= (1 << 32) { + if (srData.RtpTimestamp-r.srNewest.RtpTimestamp) >= (1<<31) && srData.RtpTimestamp > r.srNewest.RtpTimestamp { + tsCycles -= (1 << 32) + } + } + } + } + + srDataExt := protoutils.CloneProto(srData) + srDataExt.RtpTimestampExt = uint64(srDataExt.RtpTimestamp) + tsCycles + return srDataExt +} + +func (r *RTPStatsReceiver) checkOutOfOrderSenderReport(srData *livekit.RTCPSenderReportState) bool { + if r.srNewest != nil && srData.RtpTimestampExt < r.srNewest.RtpTimestampExt { + // This can happen when a track is replaced with a null and then restored - + // i. e. muting replacing with null and unmute restoring the original track. + // Or it could be due bad report generation. + // In any case, ignore out-of-order reports. + r.outOfOrderSenderReportCount++ + if (r.outOfOrderSenderReportCount-1)%10 == 0 { + r.logger.Infow( + "received sender report, out-of-order, skipping", + "current", WrappedRTCPSenderReportStateLogger{srData}, + "count", r.outOfOrderSenderReportCount, + "rtpStats", lockedRTPStatsReceiverLogEncoder{r}, + ) + } + return true + } + + return false +} + +func (r *RTPStatsReceiver) checkRTPClockSkewForSenderReport(srData *livekit.RTCPSenderReportState) { + if r.srNewest == nil { + return + } + + timeSinceLast := mediatransportutil.NtpTime(srData.NtpTimestamp).Time().Sub(mediatransportutil.NtpTime(r.srNewest.NtpTimestamp).Time()).Seconds() + rtpDiffSinceLast := srData.RtpTimestampExt - r.srNewest.RtpTimestampExt + calculatedClockRateFromLast := float64(rtpDiffSinceLast) / timeSinceLast + + timeSinceFirst := mediatransportutil.NtpTime(srData.NtpTimestamp).Time().Sub(mediatransportutil.NtpTime(r.srFirst.NtpTimestamp).Time()).Seconds() + rtpDiffSinceFirst := srData.RtpTimestampExt - r.srFirst.RtpTimestampExt + calculatedClockRateFromFirst := float64(rtpDiffSinceFirst) / timeSinceFirst + + if (timeSinceLast > 0.2 && math.Abs(float64(r.params.ClockRate)-calculatedClockRateFromLast) > 0.2*float64(r.params.ClockRate)) || + (timeSinceFirst > 0.2 && math.Abs(float64(r.params.ClockRate)-calculatedClockRateFromFirst) > 0.2*float64(r.params.ClockRate)) { + r.clockSkewCount++ + if (r.clockSkewCount-1)%100 == 0 { + r.logger.Infow( + "received sender report, clock skew", + "current", WrappedRTCPSenderReportStateLogger{srData}, + "timeSinceFirst", timeSinceFirst, + "rtpDiffSinceFirst", rtpDiffSinceFirst, + "calculatedFirst", calculatedClockRateFromFirst, + "timeSinceLast", timeSinceLast, + "rtpDiffSinceLast", rtpDiffSinceLast, + "calculatedLast", calculatedClockRateFromLast, + "count", r.clockSkewCount, + "rtpStats", lockedRTPStatsReceiverLogEncoder{r}, + ) + } + } +} + +func (r *RTPStatsReceiver) checkRTPClockSkewAgainstMediaPathForSenderReport(srData *livekit.RTCPSenderReportState) { + if r.highestTime == 0 { + return + } + + nowNano := mono.UnixNano() + timeSinceSR := time.Duration(nowNano - srData.AtAdjusted) + extNowTSSR := srData.RtpTimestampExt + r.rtpConverter.ToRTPExt(timeSinceSR) + + timeSinceHighest := time.Duration(nowNano - r.highestTime) + extNowTSHighest := r.timestamp.GetExtendedHighest() + r.rtpConverter.ToRTPExt(timeSinceHighest) + diffHighest := extNowTSSR - extNowTSHighest + + timeSinceFirst := time.Duration(nowNano - r.firstTime) + extNowTSFirst := r.timestamp.GetExtendedStart() + r.rtpConverter.ToRTPExt(timeSinceFirst) + diffFirst := extNowTSSR - extNowTSFirst + + // is it more than 5 seconds off? + if uint32(math.Abs(float64(int64(diffHighest)))) > 5*r.params.ClockRate || uint32(math.Abs(float64(int64(diffFirst)))) > 5*r.params.ClockRate { + r.clockSkewMediaPathCount++ + if (r.clockSkewMediaPathCount-1)%100 == 0 { + r.logger.Infow( + "received sender report, clock skew against media path", + "current", WrappedRTCPSenderReportStateLogger{srData}, + "timeSinceSR", timeSinceSR, + "extNowTSSR", extNowTSSR, + "timeSinceHighest", timeSinceHighest, + "extNowTSHighest", extNowTSHighest, + "diffHighest", int64(diffHighest), + "timeSinceFirst", timeSinceFirst, + "extNowTSFirst", extNowTSFirst, + "diffFirst", int64(diffFirst), + "count", r.clockSkewMediaPathCount, + "rtpStats", lockedRTPStatsReceiverLogEncoder{r}, + ) + } + } +} + +func (r *RTPStatsReceiver) updatePropagationDelayAndRecordSenderReport(srData *livekit.RTCPSenderReportState) { + senderClockTime := mediatransportutil.NtpTime(srData.NtpTimestamp).Time().UnixNano() + estimatedPropagationDelay, stepChange := r.propagationDelayEstimator.Update(senderClockTime, srData.At) + if stepChange { + r.logger.Debugw( + "propagation delay step change", + "currentSenderReport", WrappedRTCPSenderReportStateLogger{srData}, + "rtpStats", lockedRTPStatsReceiverLogEncoder{r}, + ) + } + + if r.srFirst == nil { + r.srFirst = srData + } + // adjust receive time to estimated propagation delay + srData.AtAdjusted = senderClockTime + estimatedPropagationDelay + r.srNewest = srData +} + +func (r *RTPStatsReceiver) SetRtcpSenderReportData(srData *livekit.RTCPSenderReportState) bool { + r.lock.Lock() + defer r.lock.Unlock() + + if srData == nil || !r.initialized { + return false + } + + // prevent against extreme case of anachronous sender reports + if r.srNewest != nil && r.srNewest.NtpTimestamp > srData.NtpTimestamp { + r.logger.Infow( + "received sender report, anachronous, dropping", + "current", WrappedRTCPSenderReportStateLogger{srData}, + "rtpStats", lockedRTPStatsReceiverLogEncoder{r}, + ) + return false + } + + srDataExt := r.getExtendedSenderReport(srData) + + if r.checkOutOfOrderSenderReport(srDataExt) { + return false + } + + r.checkRTPClockSkewForSenderReport(srDataExt) + r.updatePropagationDelayAndRecordSenderReport(srDataExt) + r.checkRTPClockSkewAgainstMediaPathForSenderReport(srDataExt) + + adjustment, err, loggingFields := r.maybeAdjustFirstPacketTime(r.srNewest, 0, r.timestamp.GetExtendedStart()) + if err != nil { + r.logger.Infow(err.Error(), append(loggingFields, "rtpStats", lockedRTPStatsReceiverLogEncoder{r})...) + } + r.propagationDelayEstimator.InitialAdjustment(adjustment) + return true +} + +func (r *RTPStatsReceiver) GetRtcpSenderReportData() *livekit.RTCPSenderReportState { + r.lock.RLock() + defer r.lock.RUnlock() + + return protoutils.CloneProto(r.srNewest) +} + +func (r *RTPStatsReceiver) LastSenderReportTime() time.Time { + r.lock.RLock() + defer r.lock.RUnlock() + + if r.srNewest != nil { + return time.Unix(0, r.srNewest.At) + } + + return time.Time{} +} + +func (r *RTPStatsReceiver) GetRtcpReceptionReport(ssrc uint32, proxyFracLost uint8, snapshotID uint32) *rtcp.ReceptionReport { + r.lock.Lock() + defer r.lock.Unlock() + + extHighestSN := r.sequenceNumber.GetExtendedHighest() + then, now := r.getAndResetSnapshot(snapshotID, r.sequenceNumber.GetExtendedStart(), extHighestSN) + if now == nil || then == nil { + return nil + } + + packetsExpected := now.extStartSN - then.extStartSN + if packetsExpected > cNumSequenceNumbers { + r.logger.Warnw( + "too many packets expected in receiver report", + fmt.Errorf("start: %d, end: %d, expected: %d", then.extStartSN, now.extStartSN, packetsExpected), + "rtpStats", lockedRTPStatsReceiverLogEncoder{r}, + ) + return nil + } + if packetsExpected == 0 { + return nil + } + + packetsLost := uint32(now.packetsLost - then.packetsLost) + if int32(packetsLost) < 0 { + packetsLost = 0 + } + lossRate := float32(packetsLost) / float32(packetsExpected) + fracLost := max(proxyFracLost, uint8(lossRate*256.0)) + + totalLost := min(r.packetsLost, 0xffffff) // 24-bits max + + lastSR := uint32(0) + dlsr := uint32(0) + if r.srNewest != nil { + lastSR = uint32(r.srNewest.NtpTimestamp >> 16) + if r.srNewest.At != 0 { + delayUS := time.Since(time.Unix(0, r.srNewest.At)).Microseconds() + dlsr = uint32(delayUS * 65536 / 1e6) + } + } + + return &rtcp.ReceptionReport{ + SSRC: ssrc, + FractionLost: fracLost, + TotalLost: uint32(totalLost), + LastSequenceNumber: uint32(now.extStartSN), + Jitter: uint32(r.jitter), + LastSenderReport: lastSR, + Delay: dlsr, + } +} + +func (r *RTPStatsReceiver) DeltaInfo(snapshotID uint32) *RTPDeltaInfo { + r.lock.Lock() + defer r.lock.Unlock() + + deltaInfo, err, loggingFields := r.deltaInfo( + snapshotID, + r.sequenceNumber.GetExtendedStart(), + r.sequenceNumber.GetExtendedHighest(), + ) + if err != nil { + r.logger.Infow(err.Error(), append(loggingFields, "rtpStats", lockedRTPStatsReceiverLogEncoder{r})...) + } + + return deltaInfo +} + +func (r *RTPStatsReceiver) MarshalLogObject(e zapcore.ObjectEncoder) error { + if r == nil { + return nil + } + + r.lock.RLock() + defer r.lock.RUnlock() + + return lockedRTPStatsReceiverLogEncoder{r}.MarshalLogObject(e) +} + +func (r *RTPStatsReceiver) ToProto() *livekit.RTPStats { + if r == nil { + return nil + } + + r.lock.RLock() + defer r.lock.RUnlock() + + extStartSN, extHighestSN := r.sequenceNumber.GetExtendedStart(), r.sequenceNumber.GetExtendedHighest() + return r.toProto( + getPacketsExpected(extStartSN, extHighestSN), + r.getPacketsSeenMinusPadding(extStartSN, extHighestSN), + r.packetsLost, + r.timestamp.GetExtendedStart(), + r.timestamp.GetExtendedHighest(), + r.jitter, + r.maxJitter, + ) +} + +func (r *RTPStatsReceiver) isInRange(esn uint64, ehsn uint64) bool { + diff := int64(ehsn - esn) + return diff >= 0 && diff < cHistorySize +} + +func (r *RTPStatsReceiver) HighestTimestamp() uint32 { + r.lock.RLock() + defer r.lock.RUnlock() + + return r.timestamp.GetHighest() +} + +// for testing only +func (r *RTPStatsReceiver) HighestSequenceNumber() uint16 { + r.lock.RLock() + defer r.lock.RUnlock() + + return r.sequenceNumber.GetHighest() +} + +// for testing only +func (r *RTPStatsReceiver) ExtendedHighestSequenceNumber() uint64 { + r.lock.RLock() + defer r.lock.RUnlock() + + return r.sequenceNumber.GetExtendedHighest() +} + +func (r *RTPStatsReceiver) maybeRestart(sn uint16, ts uint32, payloadSize int) bool { + if payloadSize > 0 { + r.restartPackets = append(r.restartPackets, packet{sn, ts}) + } + if len(r.restartPackets) < restartThreshold { + return false + } + + r.restartPackets = r.restartPackets[max(len(r.restartPackets)-restartThreshold, 0):] + // check for contiguous sequence numbers and equal or increasing timestamps + for i := 1; i < len(r.restartPackets); i++ { + p := &r.restartPackets[i] + prev := &r.restartPackets[i-1] + if p.sequenceNumber != prev.sequenceNumber+1 || (p.timestamp-prev.timestamp) > (1<<31) { + return false + } + } + + return true +} + +func (r *RTPStatsReceiver) resetRestart() { + r.restartPackets = r.restartPackets[:0] +} + +// ---------------------------------- + +type lockedRTPStatsReceiverLogEncoder struct { + *RTPStatsReceiver +} + +func (r lockedRTPStatsReceiverLogEncoder) MarshalLogObject(e zapcore.ObjectEncoder) error { + if r.RTPStatsReceiver == nil { + return nil + } + + extStartSN, extHighestSN := r.sequenceNumber.GetExtendedStart(), r.sequenceNumber.GetExtendedHighest() + extStartTS, extHighestTS := r.timestamp.GetExtendedStart(), r.timestamp.GetExtendedHighest() + if _, err := r.rtpStatsBase.marshalLogObject( + e, + getPacketsExpected(extStartSN, extHighestSN), + r.getPacketsSeenMinusPadding(extStartSN, extHighestSN), + extStartTS, + extHighestTS, + ); err != nil { + return err + } + + e.AddUint64("extStartSN", extStartSN) + e.AddUint64("extHighestSN", extHighestSN) + e.AddUint64("extStartTS", extStartTS) + e.AddUint64("extHighestTS", extHighestTS) + + e.AddObject("propagationDelayEstimator", r.propagationDelayEstimator) + + e.AddInt("clockSkewCount", r.clockSkewCount) + e.AddInt("clockSkewMediaPathCount", r.clockSkewMediaPathCount) + e.AddInt("outOfOrderSenderReportCount", r.outOfOrderSenderReportCount) + e.AddInt("largeJumpCount", r.largeJumpCount) + e.AddInt("largeJumpNegativeCount", r.largeJumpNegativeCount) + e.AddInt("timeReversedCount", r.timeReversedCount) + + e.AddInt("packetsDroppedPreStartTimestamp", r.packetsDroppedPreStartTimestamp) + e.AddInt("packetsDroppedOldTimestamp", r.packetsDroppedOldTimestamp) + e.AddInt("packetsDroppedPreStartSequenceNumber", r.packetsDroppedPreStartSequenceNumber) + e.AddInt("packetsDroppedOldSequenceNumber", r.packetsDroppedOldSequenceNumber) + + e.AddArray("restartPackets", logger.ObjectSlice(r.restartPackets)) + return nil +} + +// ---------------------------------- diff --git a/livekit/pkg/sfu/rtpstats/rtpstats_receiver_lite.go b/livekit/pkg/sfu/rtpstats/rtpstats_receiver_lite.go new file mode 100644 index 0000000..14a07da --- /dev/null +++ b/livekit/pkg/sfu/rtpstats/rtpstats_receiver_lite.go @@ -0,0 +1,177 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rtpstats + +import ( + "go.uber.org/zap/zapcore" + + "github.com/livekit/mediatransportutil/pkg/utils" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/utils/mono" +) + +type RTPFlowStateLite struct { + IsNotHandled bool + + LossStartInclusive uint64 + LossEndExclusive uint64 + + ExtSequenceNumber uint64 +} + +func (r *RTPFlowStateLite) MarshalLogObject(e zapcore.ObjectEncoder) error { + if r == nil { + return nil + } + + e.AddBool("IsNotHandled", r.IsNotHandled) + e.AddUint64("LossStartInclusive", r.LossStartInclusive) + e.AddUint64("LossEndExclusive", r.LossEndExclusive) + e.AddUint64("ExtSequenceNumber", r.ExtSequenceNumber) + return nil +} + +// --------------------------------------------------------------------- + +type RTPStatsReceiverLite struct { + *rtpStatsBaseLite + + sequenceNumber *utils.WrapAround[uint16, uint64] +} + +func NewRTPStatsReceiverLite(params RTPStatsParams) *RTPStatsReceiverLite { + return &RTPStatsReceiverLite{ + rtpStatsBaseLite: newRTPStatsBaseLite(params), + sequenceNumber: utils.NewWrapAround[uint16, uint64](utils.WrapAroundParams{IsRestartAllowed: false}), + } +} + +func (r *RTPStatsReceiverLite) NewSnapshotLiteId() uint32 { + r.lock.Lock() + defer r.lock.Unlock() + + return r.newSnapshotLiteID(r.sequenceNumber.GetExtendedHighest()) +} + +func (r *RTPStatsReceiverLite) Update(packetTime int64, packetSize int, sequenceNumber uint16) (flowStateLite RTPFlowStateLite) { + r.lock.Lock() + defer r.lock.Unlock() + + if r.endTime != 0 { + flowStateLite.IsNotHandled = true + return + } + + var resSN utils.WrapAroundUpdateResult[uint64] + if !r.initialized { + r.initialized = true + + r.startTime = mono.UnixNano() + + resSN = r.sequenceNumber.Update(sequenceNumber) + + // initialize lite snapshots if any + for i := uint32(0); i < r.nextSnapshotLiteID-cFirstSnapshotID; i++ { + r.snapshotLites[i] = initSnapshotLite(r.startTime, r.sequenceNumber.GetExtendedStart()) + } + + r.logger.Debugw( + "rtp receiver lite stream start", + "rtpStats", lockedRTPStatsReceiverLiteLogEncoder{r}, + ) + } else { + resSN = r.sequenceNumber.Update(sequenceNumber) + if resSN.IsUnhandled { + flowStateLite.IsNotHandled = true + return + } + } + + gapSN := int64(resSN.ExtendedVal - resSN.PreExtendedHighest) + if gapSN <= 0 { // duplicate OR out-of-order + r.packetsOutOfOrder++ // counting duplicate as out-of-order + r.packetsLost-- + } else { // in-order + r.updateGapHistogram(int(gapSN)) + r.packetsLost += uint64(gapSN - 1) + + flowStateLite.LossStartInclusive = resSN.PreExtendedHighest + 1 + flowStateLite.LossEndExclusive = resSN.ExtendedVal + } + flowStateLite.ExtSequenceNumber = resSN.ExtendedVal + r.bytes += uint64(packetSize) + return +} + +func (r *RTPStatsReceiverLite) DeltaInfoLite(snapshotLiteID uint32) *RTPDeltaInfoLite { + r.lock.Lock() + defer r.lock.Unlock() + + deltaInfoLite, err, loggingFields := r.deltaInfoLite( + snapshotLiteID, + r.sequenceNumber.GetExtendedStart(), + r.sequenceNumber.GetExtendedHighest(), + ) + if err != nil { + r.logger.Infow(err.Error(), append(loggingFields, "rtpStats", lockedRTPStatsReceiverLiteLogEncoder{r})...) + } + + return deltaInfoLite +} + +func (r *RTPStatsReceiverLite) MarshalLogObject(e zapcore.ObjectEncoder) error { + if r == nil { + return nil + } + + r.lock.RLock() + defer r.lock.RUnlock() + + return lockedRTPStatsReceiverLiteLogEncoder{r}.MarshalLogObject(e) +} + +func (r *RTPStatsReceiverLite) ToProto() *livekit.RTPStats { + r.lock.RLock() + defer r.lock.RUnlock() + + return r.rtpStatsBaseLite.toProto(r.sequenceNumber.GetExtendedStart(), r.sequenceNumber.GetExtendedHighest(), r.packetsLost) +} + +// ---------------------------------- + +type lockedRTPStatsReceiverLiteLogEncoder struct { + *RTPStatsReceiverLite +} + +func (r lockedRTPStatsReceiverLiteLogEncoder) MarshalLogObject(e zapcore.ObjectEncoder) error { + if r.RTPStatsReceiverLite == nil { + return nil + } + + extStartSN, extHighestSN := r.sequenceNumber.GetExtendedStart(), r.sequenceNumber.GetExtendedHighest() + if _, err := r.rtpStatsBaseLite.marshalLogObject( + e, + getPacketsExpected(extStartSN, extHighestSN), + getPacketsExpected(extStartSN, extHighestSN), + ); err != nil { + return err + } + + e.AddUint64("extStartSN", r.sequenceNumber.GetExtendedStart()) + e.AddUint64("extHighestSN", r.sequenceNumber.GetExtendedHighest()) + return nil +} + +// ---------------------------------- diff --git a/livekit/pkg/sfu/rtpstats/rtpstats_sender.go b/livekit/pkg/sfu/rtpstats/rtpstats_sender.go new file mode 100644 index 0000000..186dfba --- /dev/null +++ b/livekit/pkg/sfu/rtpstats/rtpstats_sender.go @@ -0,0 +1,1400 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rtpstats + +import ( + "errors" + "fmt" + "math" + "time" + + "github.com/pion/rtcp" + "go.uber.org/zap/zapcore" + "google.golang.org/protobuf/types/known/timestamppb" + + "github.com/livekit/mediatransportutil" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/utils/mono" +) + +// ------------------------------------------------------------------- + +type snInfoFlag byte + +const ( + snInfoFlagMarker snInfoFlag = 1 << iota + snInfoFlagPadding + snInfoFlagOutOfOrder +) + +type snInfo struct { + pktSize uint16 + hdrSize uint8 + flags snInfoFlag +} + +// ------------------------------------------------------------------- + +type intervalStats struct { + packets uint64 + bytes uint64 + headerBytes uint64 + packetsPadding uint64 + bytesPadding uint64 + headerBytesPadding uint64 + packetsLostFeed uint64 + packetsOutOfOrderFeed uint64 + frames uint32 + packetsNotFoundMetadata uint64 +} + +func (is *intervalStats) aggregate(other *intervalStats) { + if is == nil || other == nil { + return + } + + is.packets += other.packets + is.bytes += other.bytes + is.headerBytes += other.headerBytes + is.packetsPadding += other.packetsPadding + is.bytesPadding += other.bytesPadding + is.headerBytesPadding += other.headerBytesPadding + is.packetsLostFeed += other.packetsLostFeed + is.packetsOutOfOrderFeed += other.packetsOutOfOrderFeed + is.frames += other.frames + is.packetsNotFoundMetadata += other.packetsNotFoundMetadata +} + +func (is *intervalStats) MarshalLogObject(e zapcore.ObjectEncoder) error { + if is == nil { + return nil + } + e.AddUint64("packets", is.packets) + e.AddUint64("bytes", is.bytes) + e.AddUint64("headerBytes", is.headerBytes) + e.AddUint64("packetsPadding", is.packetsPadding) + e.AddUint64("bytesPadding", is.bytesPadding) + e.AddUint64("headerBytesPadding", is.headerBytesPadding) + e.AddUint64("packetsLostFeed", is.packetsLostFeed) + e.AddUint64("packetsOutOfOrderFeed", is.packetsOutOfOrderFeed) + e.AddUint32("frames", is.frames) + e.AddUint64("packetsNotFoundMetadata", is.packetsNotFoundMetadata) + + return nil +} + +// ------------------------------------------------------------------- + +type wrappedReceptionReportsLogger struct { + *senderSnapshotReceiverView +} + +func (w wrappedReceptionReportsLogger) MarshalLogObject(e zapcore.ObjectEncoder) error { + for i, rr := range w.senderSnapshotReceiverView.processedReceptionReports { + e.AddReflected(fmt.Sprintf("%d", i), rr) + } + + return nil +} + +// ------------------------------------------------------------------- + +type senderSnapshotWindow struct { + isValid bool + + startTime int64 + + extStartSN uint64 + bytes uint64 + headerBytes uint64 + + packetsPadding uint64 + bytesPadding uint64 + headerBytesPadding uint64 + + packetsDuplicate uint64 + bytesDuplicate uint64 + headerBytesDuplicate uint64 + + packetsOutOfOrderFeed uint64 + + packetsLostFeed uint64 + + frames uint32 + + nacks uint32 + nackRepeated uint32 + plis uint32 + firs uint32 + + maxJitterFeed float64 +} + +func (s *senderSnapshotWindow) MarshalLogObject(e zapcore.ObjectEncoder) error { + if s == nil { + return nil + } + + e.AddBool("isValid", s.isValid) + e.AddTime("startTime", time.Unix(0, s.startTime)) + e.AddUint64("extStartSN", s.extStartSN) + e.AddUint64("bytes", s.bytes) + e.AddUint64("headerBytes", s.headerBytes) + e.AddUint64("packetsPadding", s.packetsPadding) + e.AddUint64("bytesPadding", s.bytesPadding) + e.AddUint64("headerBytesPadding", s.headerBytesPadding) + e.AddUint64("packetsDuplicate", s.packetsDuplicate) + e.AddUint64("bytesDuplicate", s.bytesDuplicate) + e.AddUint64("headerBytesDuplicate", s.headerBytesDuplicate) + e.AddUint64("packetsOutOfOrderFeed", s.packetsOutOfOrderFeed) + e.AddUint64("packetsLostFeed", s.packetsLostFeed) + e.AddUint32("frames", s.frames) + e.AddUint32("nacks", s.nacks) + e.AddUint32("nackRepeated", s.nackRepeated) + e.AddUint32("plis", s.plis) + e.AddUint32("firs", s.firs) + e.AddFloat64("maxJitterFeed", s.maxJitterFeed) + return nil +} + +func (s *senderSnapshotWindow) maybeReinit(oldESN uint64, newESN uint64) { + if s.extStartSN == oldESN { + s.extStartSN = newESN + } +} + +func (s *senderSnapshotWindow) maybeUpdateMaxJitterFeed(jitter float64) { + if jitter > s.maxJitterFeed { + s.maxJitterFeed = jitter + } +} + +// --------- + +type senderSnapshotReceiverView struct { + senderSnapshotWindow + + packetsLost uint64 + + maxRtt uint32 + maxJitter float64 + + extLastRRSN uint64 + intervalStats intervalStats + processedReceptionReports []rtcp.ReceptionReport + metadataCacheOverflowCount int +} + +func (s *senderSnapshotReceiverView) MarshalLogObject(e zapcore.ObjectEncoder) error { + if s == nil { + return nil + } + + s.senderSnapshotWindow.MarshalLogObject(e) + e.AddUint64("packetsLost", s.packetsLost) + e.AddUint32("maxRtt", s.maxRtt) + e.AddFloat64("maxJitter", s.maxJitter) + e.AddUint64("extLastRRSN", s.extLastRRSN) + e.AddObject("intervalStats", &s.intervalStats) + e.AddObject("processedReceptionReports", wrappedReceptionReportsLogger{s}) + e.AddInt("metadataCacheOverflowCount", s.metadataCacheOverflowCount) + return nil +} + +func (s *senderSnapshotReceiverView) maybeReinit(oldESN uint64, newESN uint64) { + if s.extStartSN == oldESN { + s.extStartSN = newESN + if s.extLastRRSN == (oldESN - 1) { + s.extLastRRSN = newESN - 1 + } + } +} + +func (s *senderSnapshotReceiverView) maybeUpdateMaxRTT(rtt uint32) { + if rtt > s.maxRtt { + s.maxRtt = rtt + } +} + +func (s *senderSnapshotReceiverView) maybeUpdateMaxJitter(jitter float64) { + if jitter > s.maxJitter { + s.maxJitter = jitter + } +} + +// --------- + +type senderSnapshot struct { + senderView senderSnapshotWindow + receiverView senderSnapshotReceiverView +} + +func (s *senderSnapshot) MarshalLogObject(e zapcore.ObjectEncoder) error { + if s == nil { + return nil + } + + e.AddObject("senderView", &s.senderView) + e.AddObject("receiverView", &s.receiverView) + return nil +} + +func (s *senderSnapshot) maybeReinit(oldESN uint64, newESN uint64) { + s.senderView.maybeReinit(oldESN, newESN) + s.receiverView.maybeReinit(oldESN, newESN) +} + +func (s *senderSnapshot) maybeUpdateMaxJitterFeed(jitter float64) { + s.senderView.maybeUpdateMaxJitterFeed(jitter) + s.receiverView.maybeUpdateMaxJitterFeed(jitter) +} + +func (s *senderSnapshot) maybeUpdateMaxRTT(rtt uint32) { + s.receiverView.maybeUpdateMaxRTT(rtt) +} + +func (s *senderSnapshot) maybeUpdateMaxJitter(jitter float64) { + s.receiverView.maybeUpdateMaxJitter(jitter) +} + +// ------------------------------------------------------------------- + +type rttMarker struct { + ntpTime mediatransportutil.NtpTime + sentAt time.Time +} + +// ------------------------------------------------------------------- + +type RTPStatsSender struct { + *rtpStatsBase + + extStartSN uint64 + extHighestSN uint64 + extHighestSNFromRR uint64 + extHighestSNFromRRMisalignment uint64 + + rttMarker rttMarker + + lastRRTime int64 + lastRR rtcp.ReceptionReport + + extStartTS uint64 + extHighestTS uint64 + + packetsLostFromRR uint64 + + jitterFromRR float64 + maxJitterFromRR float64 + + snInfos []snInfo + + layerLockPlis uint32 + lastLayerLockPli time.Time + + nextSenderSnapshotID uint32 + senderSnapshots []senderSnapshot + + clockSkewCount int + largeJumpNegativeCount int + largeJumpCount int + timeReversedCount int +} + +func NewRTPStatsSender(params RTPStatsParams, cacheSize int) *RTPStatsSender { + return &RTPStatsSender{ + rtpStatsBase: newRTPStatsBase(params), + snInfos: make([]snInfo, cacheSize), + nextSenderSnapshotID: cFirstSnapshotID, + senderSnapshots: make([]senderSnapshot, 2), + } +} + +func (r *RTPStatsSender) Seed(from *RTPStatsSender) { + r.lock.Lock() + defer r.lock.Unlock() + + if !r.seed(from.rtpStatsBase) { + return + } + + r.extStartSN = from.extStartSN + r.extHighestSN = from.extHighestSN + r.extHighestSNFromRR = from.extHighestSNFromRR + r.extHighestSNFromRRMisalignment = from.extHighestSNFromRRMisalignment + + r.rttMarker = from.rttMarker + + r.lastRRTime = from.lastRRTime + r.lastRR = from.lastRR + + r.extStartTS = from.extStartTS + r.extHighestTS = from.extHighestTS + + r.packetsLostFromRR = from.packetsLostFromRR + + r.jitterFromRR = from.jitterFromRR + r.maxJitterFromRR = from.maxJitterFromRR + + r.snInfos = make([]snInfo, len(from.snInfos)) + copy(r.snInfos, from.snInfos) + + r.layerLockPlis = from.layerLockPlis + r.lastLayerLockPli = from.lastLayerLockPli + + r.nextSenderSnapshotID = from.nextSenderSnapshotID + r.senderSnapshots = make([]senderSnapshot, cap(from.senderSnapshots)) + copy(r.senderSnapshots, from.senderSnapshots) +} + +func (r *RTPStatsSender) NewSnapshotId() uint32 { + r.lock.Lock() + defer r.lock.Unlock() + + return r.newSnapshotID(r.extHighestSN) +} + +func (r *RTPStatsSender) NewSenderSnapshotId() uint32 { + r.lock.Lock() + defer r.lock.Unlock() + + id := r.nextSenderSnapshotID + r.nextSenderSnapshotID++ + + if cap(r.senderSnapshots) < int(r.nextSenderSnapshotID-cFirstSnapshotID) { + senderSnapshots := make([]senderSnapshot, r.nextSenderSnapshotID-cFirstSnapshotID) + copy(senderSnapshots, r.senderSnapshots) + r.senderSnapshots = senderSnapshots + } + + if r.initialized { + r.senderSnapshots[id-cFirstSnapshotID] = initSenderSnapshot(mono.UnixNano(), r.extHighestSN) + } + return id +} + +func (r *RTPStatsSender) Update( + packetTime int64, + extSequenceNumber uint64, + extTimestamp uint64, + marker bool, + hdrSize int, + payloadSize int, + paddingSize int, + isOutOfOrder bool, +) { + r.lock.Lock() + defer r.lock.Unlock() + + if r.endTime != 0 { + return + } + + if !r.initialized { + if payloadSize == 0 && !r.params.IsRTX { + // do not start on a padding only packet + return + } + + r.initialized = true + + r.startTime = mono.UnixNano() + + r.highestTime = packetTime + + r.extStartSN = extSequenceNumber + r.extHighestSN = extSequenceNumber - 1 + + r.extStartTS = extTimestamp + r.extHighestTS = extTimestamp + + // initialize snapshots if any + for i := uint32(0); i < r.nextSnapshotID-cFirstSnapshotID; i++ { + r.snapshots[i] = initSnapshot(r.startTime, r.extStartSN) + } + for i := uint32(0); i < r.nextSenderSnapshotID-cFirstSnapshotID; i++ { + r.senderSnapshots[i] = initSenderSnapshot(r.startTime, r.extStartSN) + } + + r.logger.Debugw( + "rtp sender stream start", + "rtpStats", lockedRTPStatsSenderLogEncoder{r}, + ) + } + if !isOutOfOrder && r.firstTime == 0 { + // do not set first packet time if packet is out-of-order, + // as first packet time is used to calculate expected time stamp, + // using an out-of-order packet would skew that. + r.firstTime = packetTime + } + + pktSize := uint64(hdrSize + payloadSize + paddingSize) + isDuplicate := false + gapSN := int64(extSequenceNumber - r.extHighestSN) + ulgr := func() logger.UnlikelyLogger { + return r.logger.WithUnlikelyValues( + "currSN", extSequenceNumber, + "gapSN", gapSN, + "currTS", extTimestamp, + "gapTS", int64(extTimestamp-r.extHighestTS), + "packetTime", time.Unix(0, packetTime), + "timeSinceHighest", time.Duration(packetTime-r.highestTime), + "marker", marker, + "hdrSize", hdrSize, + "payloadSize", payloadSize, + "paddingSize", paddingSize, + "rtpStats", lockedRTPStatsSenderLogEncoder{r}, + ) + } + if gapSN <= 0 { // duplicate OR out-of-order + if payloadSize == 0 && extSequenceNumber < r.extStartSN && !r.params.IsRTX { + // do not start on a padding only packet + return + } + + if extSequenceNumber < r.extStartSN { + r.packetsLost += r.extStartSN - extSequenceNumber - 1 + + // adjust start of snapshots + for i := uint32(0); i < r.nextSnapshotID-cFirstSnapshotID; i++ { + s := &r.snapshots[i] + if s.extStartSN == r.extStartSN { + s.extStartSN = extSequenceNumber + } + } + for i := uint32(0); i < r.nextSenderSnapshotID-cFirstSnapshotID; i++ { + r.senderSnapshots[i].maybeReinit(r.extStartSN, extSequenceNumber) + } + + ulgr().Infow( + "adjusting start sequence number", + "snAfter", extSequenceNumber, + "tsAfter", extTimestamp, + ) + r.extStartSN = extSequenceNumber + } + + if gapSN != 0 { + r.packetsOutOfOrder++ + } + + if !r.isSnInfoLost(extSequenceNumber, r.extHighestSN) { + r.bytesDuplicate += pktSize + r.headerBytesDuplicate += uint64(hdrSize) + r.packetsDuplicate++ + isDuplicate = true + } else { + r.packetsLost-- + r.setSnInfo(extSequenceNumber, r.extHighestSN, uint16(pktSize), uint8(hdrSize), uint16(payloadSize), marker, true) + } + + if !isDuplicate && -gapSN >= cSequenceNumberLargeJumpThreshold { + r.largeJumpNegativeCount++ + if (r.largeJumpNegativeCount-1)%100 == 0 { + ulgr().Warnw( + "large sequence number gap negative", nil, + "count", r.largeJumpNegativeCount, + ) + } + } + } else { // in-order + if gapSN >= cSequenceNumberLargeJumpThreshold { + r.largeJumpCount++ + if (r.largeJumpCount-1)%100 == 0 { + ulgr().Warnw( + "large sequence number gap", nil, + "count", r.largeJumpCount, + ) + } + } + + if extTimestamp < r.extHighestTS { + r.timeReversedCount++ + if (r.timeReversedCount-1)%100 == 0 { + ulgr().Warnw( + "time reversed", nil, + "count", r.timeReversedCount, + ) + } + } + + // update gap histogram + r.updateGapHistogram(int(gapSN)) + + // update missing sequence numbers + r.clearSnInfos(r.extHighestSN+1, extSequenceNumber) + r.packetsLost += uint64(gapSN - 1) + + r.setSnInfo(extSequenceNumber, r.extHighestSN, uint16(pktSize), uint8(hdrSize), uint16(payloadSize), marker, false) + + r.extHighestSN = extSequenceNumber + } + + if extTimestamp < r.extStartTS { + ulgr().Infow( + "adjusting start timestamp", + "snAfter", extSequenceNumber, + "tsAfter", extTimestamp, + ) + r.extStartTS = extTimestamp + } + + if extTimestamp > r.extHighestTS { + // update only on first packet as same timestamp could be in multiple packets. + // NOTE: this may not be the first packet with this time stamp if there is packet loss. + if payloadSize > 0 { + // skip updating on padding only packets as they could re-use an old timestamp + r.highestTime = packetTime + } + r.extHighestTS = extTimestamp + } + + if !isDuplicate { + if payloadSize == 0 { + r.packetsPadding++ + r.bytesPadding += pktSize + r.headerBytesPadding += uint64(hdrSize) + } else { + r.bytes += pktSize + r.headerBytes += uint64(hdrSize) + + if marker { + r.frames++ + } + + jitter := r.updateJitter(extTimestamp, packetTime) + for i := uint32(0); i < r.nextSenderSnapshotID-cFirstSnapshotID; i++ { + r.senderSnapshots[i].maybeUpdateMaxJitterFeed(jitter) + } + } + } +} + +func (r *RTPStatsSender) UpdateLayerLockPliAndTime(pliCount uint32) { + r.lock.Lock() + defer r.lock.Unlock() + + if r.endTime != 0 { + return + } + + r.layerLockPlis += pliCount + r.lastLayerLockPli = time.Now() +} + +func (r *RTPStatsSender) GetPacketsSeenMinusPadding() uint64 { + r.lock.RLock() + defer r.lock.RUnlock() + + return r.getPacketsSeenMinusPadding(r.extStartSN, r.extHighestSN) +} + +func (r *RTPStatsSender) UpdateFromReceiverReport(rr rtcp.ReceptionReport) (rtt uint32, isRttChanged bool) { + r.lock.Lock() + defer r.lock.Unlock() + + if !r.initialized || r.endTime != 0 { + return + } + + extHighestSNFromRR := r.extHighestSNFromRR&0xFFFF_FFFF_0000_0000 + uint64(rr.LastSequenceNumber) + r.extHighestSNFromRRMisalignment + if r.lastRRTime != 0 { + if (rr.LastSequenceNumber-r.lastRR.LastSequenceNumber) < (1<<31) && rr.LastSequenceNumber < r.lastRR.LastSequenceNumber { + extHighestSNFromRR += (1 << 32) + } + } + if (extHighestSNFromRR + (r.extStartSN & 0xFFFF_FFFF_FFFF_0000)) < r.extStartSN { + // it is possible that the `LastSequenceNumber` in the receiver report is before the starting + // sequence number when dummy packets are used to trigger Pion's OnTrack path. + return + } + + nowNano := mono.UnixNano() + defer func() { + r.lastRRTime = nowNano + r.lastRR = rr + }() + + timeSinceLastRR := func() time.Duration { + if r.lastRRTime != 0 { + return time.Duration(nowNano - r.lastRRTime) + } + return time.Duration(nowNano - r.startTime) + } + + extReceivedRRSN := extHighestSNFromRR + (r.extStartSN & 0xFFFF_FFFF_FFFF_0000) + if int64(r.extHighestSN-extReceivedRRSN) < 0 || int64(r.extHighestSN-extReceivedRRSN) > 4*(1<<16) { + r.logger.Infow( + "receiver report runaway, dropping", + "timeSinceLastRR", timeSinceLastRR(), + "receivedRR", rr, + "extHighestSNFromRR", extHighestSNFromRR, + "extReceivedRRSN", extReceivedRRSN, + "rtpStats", lockedRTPStatsSenderLogEncoder{r}, + ) + return + } + + if r.extHighestSNFromRR != extHighestSNFromRR && int64(r.extHighestSN-extReceivedRRSN) >= (1<<16) { + // there are cases where remote does not send RTCP Receiver Report for extended periods of time, + // some times several minutes, in that interval the sequence number rolls over, + // + // NOTE: even if there is a large gap in time, the sequence number should be higher + // than previous report (extended sequence number in receiver report is 32-bit wide and + // should not roll over for long time, for e. g. it will approximately take 100 days at 500 pps). + // So, there seems to be a remote reporter issue where the sequence number rollover is missed. + // + // catch up till difference between highest sent and highest received via receiver report is + // less than full 16-bit range. + // + // in a different flavor, there are clients that do not report properly, + // i. e. never update the last received sequence number, + // so skip any catch up if the last receeved sequence number reported in + // RTCP RR does not change. + r.logger.Infow( + "receiver report missed rollover, adjusting", + "timeSinceLastRR", timeSinceLastRR(), + "receivedRR", rr, + "extHighestSNFromRR", extHighestSNFromRR, + "extReceivedRRSN", extReceivedRRSN, + "rtpStats", lockedRTPStatsSenderLogEncoder{r}, + ) + for int64(r.extHighestSN-extReceivedRRSN) >= (1 << 16) { + extHighestSNFromRR += (1 << 16) + r.extHighestSNFromRRMisalignment += (1 << 16) + extReceivedRRSN = extHighestSNFromRR + (r.extStartSN & 0xFFFF_FFFF_FFFF_0000) + } + r.logger.Infow( + "receiver report missed rollover, adjusted", + "timeSinceLastRR", timeSinceLastRR(), + "receivedRR", rr, + "extHighestSNFromRR", extHighestSNFromRR, + "extReceivedRRSN", extReceivedRRSN, + "rtpStats", lockedRTPStatsSenderLogEncoder{r}, + ) + } + + if r.extHighestSN < extReceivedRRSN { + // if remote adjusts somehow, roll back alignment + r.logger.Infow( + "receiver report caught up rollover, adjusting", + "timeSinceLastRR", timeSinceLastRR(), + "receivedRR", rr, + "extHighestSNFromRR", extHighestSNFromRR, + "extReceivedRRSN", extReceivedRRSN, + "rtpStats", lockedRTPStatsSenderLogEncoder{r}, + ) + for r.extHighestSN < extReceivedRRSN { + extHighestSNFromRR -= (1 << 16) + r.extHighestSNFromRRMisalignment -= (1 << 16) + extReceivedRRSN = extHighestSNFromRR + (r.extStartSN & 0xFFFF_FFFF_FFFF_0000) + } + r.logger.Infow( + "receiver report caught up rollover, adjusted", + "timeSinceLastRR", timeSinceLastRR(), + "receivedRR", rr, + "extHighestSNFromRR", extHighestSNFromRR, + "extReceivedRRSN", extReceivedRRSN, + "rtpStats", lockedRTPStatsSenderLogEncoder{r}, + ) + } + + if r.extHighestSNFromRR > extHighestSNFromRR { + r.logger.Infow( + "receiver report out-of-order, dropping", + "timeSinceLastRR", timeSinceLastRR(), + "receivedRR", rr, + "extHighestSNFromRR", extHighestSNFromRR, + "extReceivedRRSN", extReceivedRRSN, + "rtpStats", lockedRTPStatsSenderLogEncoder{r}, + ) + return + } + r.extHighestSNFromRR = extHighestSNFromRR + + if r.srNewest != nil { + var err error + rtt, err = mediatransportutil.GetRttMs(&rr, r.rttMarker.ntpTime, r.rttMarker.sentAt) + if err == nil { + isRttChanged = rtt != r.rtt + } else { + r.logger.Debugw("error getting rtt", "error", err) + } + } + + r.packetsLostFromRR = uint64(rr.TotalLost) + lossDelta := (rr.TotalLost - r.lastRR.TotalLost) & ((1 << 24) - 1) + if lossDelta < (1<<23) && rr.TotalLost < r.lastRR.TotalLost { + r.packetsLostFromRR += (1 << 24) + } + + if isRttChanged { + r.rtt = rtt + if rtt > r.maxRtt { + r.maxRtt = rtt + } + } + + r.jitterFromRR = float64(rr.Jitter) + if r.jitterFromRR > r.maxJitterFromRR { + r.maxJitterFromRR = r.jitterFromRR + } + + // update snapshots + for i := uint32(0); i < r.nextSnapshotID-cFirstSnapshotID; i++ { + s := &r.snapshots[i] + if isRttChanged { + s.maybeUpdateMaxRTT(rtt) + } + } + + for i := uint32(0); i < r.nextSenderSnapshotID-cFirstSnapshotID; i++ { + s := &r.senderSnapshots[i] + if isRttChanged { + s.maybeUpdateMaxRTT(rtt) + } + + s.maybeUpdateMaxJitter(r.jitterFromRR) + + // on every RR, calculate delta since last RR using packet metadata cache + is := r.getIntervalStats(s.receiverView.extLastRRSN+1, extReceivedRRSN+1, r.extHighestSN) + eis := &s.receiverView.intervalStats + eis.aggregate(&is) + if is.packetsNotFoundMetadata != 0 { + s.receiverView.metadataCacheOverflowCount++ + if (s.receiverView.metadataCacheOverflowCount-1)%10 == 0 { + r.logger.Infow( + "metadata cache overflow", + "senderSnapshotID", i+cFirstSnapshotID, + "senderSnapshot", s, + "timeSinceLastRR", timeSinceLastRR(), + "receivedRR", rr, + "extReceivedRRSN", extReceivedRRSN, + "packetsInInterval", extReceivedRRSN-s.receiverView.extLastRRSN, + "intervalStats", &is, + "aggregateIntervalStats", eis, + "rtpStats", lockedRTPStatsSenderLogEncoder{r}, + ) + } + } + s.receiverView.extLastRRSN = extReceivedRRSN + s.receiverView.processedReceptionReports = append(s.receiverView.processedReceptionReports, rr) + } + + return +} + +func (r *RTPStatsSender) LastReceiverReportTime() int64 { + r.lock.RLock() + defer r.lock.RUnlock() + + return r.lastRRTime +} + +func (r *RTPStatsSender) MaybeAdjustFirstPacketTime(publisherSRData *livekit.RTCPSenderReportState, tsOffset uint64) { + r.lock.Lock() + defer r.lock.Unlock() + + if !r.initialized || publisherSRData == nil { + return + } + + if _, err, loggingFields := r.maybeAdjustFirstPacketTime(publisherSRData, tsOffset, r.extStartTS); err != nil { + r.logger.Infow(err.Error(), append(loggingFields, "rtpStats", lockedRTPStatsSenderLogEncoder{r})...) + } +} + +func (r *RTPStatsSender) GetExpectedRTPTimestamp(at time.Time) (expectedTSExt uint64, err error) { + r.lock.RLock() + defer r.lock.RUnlock() + + if r.firstTime == 0 { + err = errors.New("uninitialized") + return + } + + timeDiff := at.Sub(time.Unix(0, r.firstTime)) + expectedTSExt = r.extStartTS + r.rtpConverter.ToRTPExt(timeDiff) + return +} + +func (r *RTPStatsSender) GetRtcpSenderReport(ssrc uint32, publisherSRData *livekit.RTCPSenderReportState, tsOffset uint64, passThrough bool) *rtcp.SenderReport { + r.lock.Lock() + defer r.lock.Unlock() + + if !r.initialized || publisherSRData == nil { + return nil + } + + var ( + reportTime int64 + reportTimeAdjusted int64 + nowNTP mediatransportutil.NtpTime + nowRTPExt uint64 + ) + nowNano := mono.UnixNano() + if passThrough { + timeSincePublisherSR := time.Duration(nowNano - publisherSRData.At) + reportTime = publisherSRData.At + timeSincePublisherSR.Nanoseconds() + reportTimeAdjusted = publisherSRData.AtAdjusted + timeSincePublisherSR.Nanoseconds() + + nowNTP = mediatransportutil.ToNtpTime(mediatransportutil.NtpTime(publisherSRData.NtpTimestamp).Time().Add(timeSincePublisherSR)) + nowRTPExt = publisherSRData.RtpTimestampExt - tsOffset + r.rtpConverter.ToRTPExt(timeSincePublisherSR) + } else { + timeSincePublisherSRAdjusted := time.Duration(nowNano - publisherSRData.AtAdjusted) + reportTimeAdjusted = publisherSRData.AtAdjusted + timeSincePublisherSRAdjusted.Nanoseconds() + reportTime = reportTimeAdjusted + + nowNTP = mediatransportutil.ToNtpTime(time.Unix(0, reportTime)) + nowRTPExt = publisherSRData.RtpTimestampExt - tsOffset + r.rtpConverter.ToRTPExt(timeSincePublisherSRAdjusted) + } + + packetCount := uint32(r.getPacketsSeenPlusDuplicates(r.extStartSN, r.extHighestSN)) + octetCount := r.bytes + r.bytesDuplicate + r.bytesPadding + srData := &livekit.RTCPSenderReportState{ + NtpTimestamp: uint64(nowNTP), + RtpTimestamp: uint32(nowRTPExt), + RtpTimestampExt: nowRTPExt, + At: reportTime, + AtAdjusted: reportTimeAdjusted, + Packets: packetCount, + Octets: octetCount, + } + + ulgr := func() logger.UnlikelyLogger { + return r.logger.WithUnlikelyValues( + "curr", WrappedRTCPSenderReportStateLogger{srData}, + "feed", WrappedRTCPSenderReportStateLogger{publisherSRData}, + "tsOffset", tsOffset, + "timeNow", mono.Now(), + "reportTime", time.Unix(0, reportTime), + "reportTimeAdjusted", time.Unix(0, reportTimeAdjusted), + "timeSinceHighest", time.Duration(nowNano-r.highestTime), + "timeSinceFirst", time.Duration(nowNano-r.firstTime), + "timeSincePublisherSRAdjusted", time.Duration(nowNano-publisherSRData.AtAdjusted), + "timeSincePublisherSR", time.Duration(nowNano-publisherSRData.At), + "nowRTPExt", nowRTPExt, + "rtpStats", lockedRTPStatsSenderLogEncoder{r}, + ) + } + + if r.srNewest != nil && nowRTPExt >= r.srNewest.RtpTimestampExt { + timeSinceLastReport := nowNTP.Time().Sub(mediatransportutil.NtpTime(r.srNewest.NtpTimestamp).Time()) + rtpDiffSinceLastReport := nowRTPExt - r.srNewest.RtpTimestampExt + windowClockRate := float64(rtpDiffSinceLastReport) / timeSinceLastReport.Seconds() + if timeSinceLastReport.Seconds() > 0.2 && math.Abs(float64(r.params.ClockRate)-windowClockRate) > 0.2*float64(r.params.ClockRate) { + r.clockSkewCount++ + if (r.clockSkewCount-1)%100 == 0 { + ulgr().Infow( + "sending sender report, clock skew", + "timeSinceLastReport", timeSinceLastReport, + "rtpDiffSinceLastReport", rtpDiffSinceLastReport, + "windowClockRate", windowClockRate, + "count", r.clockSkewCount, + ) + } + } + } + + if r.srNewest != nil && nowRTPExt < r.srNewest.RtpTimestampExt { + // If report being generated is behind the last report, skip it. + // Should not happen. + ulgr().Infow("sending sender report, out-of-order, skipping") + return nil + } + + r.srNewest = srData + if r.srFirst == nil { + r.srFirst = r.srNewest + } + + r.rttMarker = rttMarker{ + ntpTime: nowNTP, + sentAt: mono.Now(), + } + + return &rtcp.SenderReport{ + SSRC: ssrc, + NTPTime: uint64(nowNTP), + RTPTime: uint32(nowRTPExt), + PacketCount: packetCount, + OctetCount: uint32(octetCount), + } +} + +func (r *RTPStatsSender) DeltaInfo(snapshotID uint32) *RTPDeltaInfo { + r.lock.Lock() + defer r.lock.Unlock() + + deltaInfo, err, loggingFields := r.deltaInfo( + snapshotID, + r.extStartSN, + r.extHighestSN, + ) + if err != nil { + r.logger.Infow(err.Error(), append(loggingFields, "rtpStats", lockedRTPStatsSenderLogEncoder{r})...) + } + + return deltaInfo +} + +func (r *RTPStatsSender) DeltaInfoSender(senderSnapshotID uint32) (*RTPDeltaInfo, *RTPDeltaInfo) { + r.lock.Lock() + defer r.lock.Unlock() + + var deltaStatsSenderView *RTPDeltaInfo + thenSenderView, nowSenderView := r.getAndResetSenderSnapshotWindow(senderSnapshotID) + if thenSenderView != nil && nowSenderView != nil { + startTime := thenSenderView.startTime + endTime := nowSenderView.startTime + + packetsExpected := uint32(nowSenderView.extStartSN - thenSenderView.extStartSN) + if packetsExpected > cNumSequenceNumbers { + r.logger.Warnw( + "too many packets expected in delta (sender)", nil, + "senderSnapshotID", senderSnapshotID, + "senderSnapshotNow", nowSenderView, + "senderSnapshotThen", thenSenderView, + "packetsExpected", packetsExpected, + "duration", time.Duration(endTime-startTime), + "rtpStats", lockedRTPStatsSenderLogEncoder{r}, + ) + } else if packetsExpected != 0 { + packetsLostFeed := uint32(nowSenderView.packetsLostFeed - thenSenderView.packetsLostFeed) + if int32(packetsLostFeed) < 0 { + packetsLostFeed = 0 + } + if packetsLostFeed > packetsExpected { + r.logger.Warnw( + "unexpected number of packets lost", nil, + "senderSnapshotID", senderSnapshotID, + "senderSnapshotNow", nowSenderView, + "senderSnapshotThen", thenSenderView, + "packetsExpected", packetsExpected, + "packetsLostFeed", packetsLostFeed, + "duration", time.Duration(endTime-startTime), + "rtpStats", lockedRTPStatsSenderLogEncoder{r}, + ) + packetsLostFeed = packetsExpected + } + + maxJitterTime := thenSenderView.maxJitterFeed / float64(r.params.ClockRate) * 1e6 + + deltaStatsSenderView = &RTPDeltaInfo{ + StartTime: time.Unix(0, startTime), + EndTime: time.Unix(0, endTime), + Packets: packetsExpected - uint32(nowSenderView.packetsPadding-thenSenderView.packetsPadding), + Bytes: nowSenderView.bytes - thenSenderView.bytes, + HeaderBytes: nowSenderView.headerBytes - thenSenderView.headerBytes, + PacketsDuplicate: uint32(nowSenderView.packetsDuplicate - thenSenderView.packetsDuplicate), + BytesDuplicate: nowSenderView.bytesDuplicate - thenSenderView.bytesDuplicate, + HeaderBytesDuplicate: nowSenderView.headerBytesDuplicate - thenSenderView.headerBytesDuplicate, + PacketsPadding: uint32(nowSenderView.packetsPadding - thenSenderView.packetsPadding), + BytesPadding: nowSenderView.bytesPadding - thenSenderView.bytesPadding, + HeaderBytesPadding: nowSenderView.headerBytesPadding - thenSenderView.headerBytesPadding, + PacketsMissing: packetsLostFeed, + PacketsOutOfOrder: uint32(nowSenderView.packetsOutOfOrderFeed - thenSenderView.packetsOutOfOrderFeed), + Frames: nowSenderView.frames - thenSenderView.frames, + JitterMax: maxJitterTime, + Nacks: nowSenderView.nacks - thenSenderView.nacks, + NackRepeated: nowSenderView.nackRepeated - thenSenderView.nackRepeated, + Plis: nowSenderView.plis - thenSenderView.plis, + Firs: nowSenderView.firs - thenSenderView.firs, + } + } + } + + var deltaStatsReceiverView *RTPDeltaInfo + if r.lastRRTime != 0 { + thenReceiverView, nowReceiverView := r.getAndResetSenderSnapshotReceiverView(senderSnapshotID) + if thenReceiverView != nil && nowReceiverView != nil { + startTime := thenReceiverView.startTime + endTime := nowReceiverView.startTime + + packetsExpected := uint32(nowReceiverView.extStartSN - thenReceiverView.extStartSN) + if packetsExpected > cNumSequenceNumbers { + r.logger.Warnw( + "too many packets expected in delta (sender - receiver view)", nil, + "senderSnapshotID", senderSnapshotID, + "senderSnapshotNow", nowReceiverView, + "senderSnapshotThen", thenReceiverView, + "packetsExpected", packetsExpected, + "duration", time.Duration(endTime-startTime), + "rtpStats", lockedRTPStatsSenderLogEncoder{r}, + ) + } else if packetsExpected != 0 { + // do not process if no RTCP RR (OR) publisher is not producing any data + packetsLost := uint32(nowReceiverView.packetsLost - thenReceiverView.packetsLost) + if int32(packetsLost) < 0 { + packetsLost = 0 + } + packetsLostFeed := uint32(nowReceiverView.packetsLostFeed - thenReceiverView.packetsLostFeed) + if int32(packetsLostFeed) < 0 { + packetsLostFeed = 0 + } + if packetsLost > packetsExpected { + r.logger.Warnw( + "unexpected number of packets lost (sender - receiver view)", nil, + "senderSnapshotID", senderSnapshotID, + "senderSnapshotNow", nowReceiverView, + "senderSnapshotThen", thenReceiverView, + "packetsExpected", packetsExpected, + "packetsLost", packetsLost, + "packetsLostFeed", packetsLostFeed, + "duration", time.Duration(endTime-startTime), + "rtpStats", lockedRTPStatsSenderLogEncoder{r}, + ) + packetsLost = packetsExpected + } + + maxJitterTime := thenReceiverView.maxJitter / float64(r.params.ClockRate) * 1e6 + + deltaStatsReceiverView = &RTPDeltaInfo{ + StartTime: time.Unix(0, startTime), + EndTime: time.Unix(0, endTime), + Packets: packetsExpected - uint32(nowReceiverView.packetsPadding-thenReceiverView.packetsPadding), + Bytes: nowReceiverView.bytes - thenReceiverView.bytes, + HeaderBytes: nowReceiverView.headerBytes - thenReceiverView.headerBytes, + PacketsDuplicate: uint32(nowReceiverView.packetsDuplicate - thenReceiverView.packetsDuplicate), + BytesDuplicate: nowReceiverView.bytesDuplicate - thenReceiverView.bytesDuplicate, + HeaderBytesDuplicate: nowReceiverView.headerBytesDuplicate - thenReceiverView.headerBytesDuplicate, + PacketsPadding: uint32(nowReceiverView.packetsPadding - thenReceiverView.packetsPadding), + BytesPadding: nowReceiverView.bytesPadding - thenReceiverView.bytesPadding, + HeaderBytesPadding: nowReceiverView.headerBytesPadding - thenReceiverView.headerBytesPadding, + PacketsLost: packetsLost, + PacketsMissing: packetsLostFeed, + PacketsOutOfOrder: uint32(nowReceiverView.packetsOutOfOrderFeed - thenReceiverView.packetsOutOfOrderFeed), + Frames: nowReceiverView.frames - thenReceiverView.frames, + RttMax: thenReceiverView.maxRtt, + JitterMax: maxJitterTime, + Nacks: nowReceiverView.nacks - thenReceiverView.nacks, + NackRepeated: nowReceiverView.nackRepeated - thenReceiverView.nackRepeated, + Plis: nowReceiverView.plis - thenReceiverView.plis, + Firs: nowReceiverView.firs - thenReceiverView.firs, + } + } + } + } + + return deltaStatsSenderView, deltaStatsReceiverView +} + +func (r *RTPStatsSender) MarshalLogObject(e zapcore.ObjectEncoder) error { + if r == nil { + return nil + } + + r.lock.RLock() + defer r.lock.RUnlock() + + return lockedRTPStatsSenderLogEncoder{r}.MarshalLogObject(e) +} + +func (r *RTPStatsSender) ToProto() *livekit.RTPStats { + r.lock.RLock() + defer r.lock.RUnlock() + + p := r.toProto( + getPacketsExpected(r.extStartSN, r.extHighestSN), + r.getPacketsSeenMinusPadding(r.extStartSN, r.extHighestSN), + r.packetsLostFromRR, + r.extStartTS, + r.extHighestTS, + r.jitterFromRR, + r.maxJitterFromRR, + ) + + if p != nil { + p.LayerLockPlis = r.layerLockPlis + p.LastLayerLockPli = timestamppb.New(r.lastLayerLockPli) + } + return p +} + +func (r *RTPStatsSender) getAndResetSenderSnapshotWindow(senderSnapshotID uint32) (*senderSnapshotWindow, *senderSnapshotWindow) { + if !r.initialized || senderSnapshotID < cFirstSnapshotID { + return nil, nil + } + + idx := senderSnapshotID - cFirstSnapshotID + then := r.senderSnapshots[idx] + if !then.senderView.isValid { + then.senderView = initSenderSnapshotWindow(r.startTime, r.extStartSN) + r.senderSnapshots[idx] = then + } + + // snapshot now + r.senderSnapshots[idx].senderView = r.getSenderSnapshotWindow(mono.UnixNano()) + return &then.senderView, &r.senderSnapshots[idx].senderView +} + +func (r *RTPStatsSender) getSenderSnapshotWindow(startTime int64) senderSnapshotWindow { + return senderSnapshotWindow{ + isValid: true, + startTime: startTime, + extStartSN: r.extHighestSN + 1, + bytes: r.bytes, + headerBytes: r.headerBytes, + packetsPadding: r.packetsPadding, + bytesPadding: r.bytesPadding, + headerBytesPadding: r.headerBytesPadding, + packetsDuplicate: r.packetsDuplicate, + bytesDuplicate: r.bytesDuplicate, + headerBytesDuplicate: r.headerBytesDuplicate, + packetsOutOfOrderFeed: r.packetsOutOfOrder, + packetsLostFeed: r.packetsLost, + frames: r.frames, + nacks: r.nacks, + nackRepeated: r.nackRepeated, + plis: r.plis, + firs: r.firs, + maxJitterFeed: r.jitter, + } +} + +func (r *RTPStatsSender) getAndResetSenderSnapshotReceiverView(senderSnapshotID uint32) (*senderSnapshotReceiverView, *senderSnapshotReceiverView) { + if !r.initialized || r.lastRRTime == 0 || senderSnapshotID < cFirstSnapshotID { + return nil, nil + } + + idx := senderSnapshotID - cFirstSnapshotID + then := r.senderSnapshots[idx] + if !then.receiverView.isValid { + then.receiverView = initSenderSnapshotReceiverView(r.startTime, r.extStartSN) + r.senderSnapshots[idx] = then + } + + // snapshot now + r.senderSnapshots[idx].receiverView = r.getSenderSnapshotReceiverView(r.lastRRTime, &then.receiverView) + return &then.receiverView, &r.senderSnapshots[idx].receiverView +} + +func (r *RTPStatsSender) getSenderSnapshotReceiverView(startTime int64, s *senderSnapshotReceiverView) senderSnapshotReceiverView { + if s == nil { + return senderSnapshotReceiverView{} + } + + return senderSnapshotReceiverView{ + senderSnapshotWindow: senderSnapshotWindow{ + isValid: true, + startTime: startTime, + extStartSN: s.extLastRRSN + 1, + bytes: s.bytes + s.intervalStats.bytes, + headerBytes: s.headerBytes + s.intervalStats.headerBytes, + packetsPadding: s.packetsPadding + s.intervalStats.packetsPadding, + bytesPadding: s.bytesPadding + s.intervalStats.bytesPadding, + headerBytesPadding: s.headerBytesPadding + s.intervalStats.headerBytesPadding, + packetsDuplicate: r.packetsDuplicate, + bytesDuplicate: r.bytesDuplicate, + headerBytesDuplicate: r.headerBytesDuplicate, + packetsOutOfOrderFeed: s.packetsOutOfOrderFeed + s.intervalStats.packetsOutOfOrderFeed, + packetsLostFeed: s.packetsLostFeed + s.intervalStats.packetsLostFeed, + frames: s.frames + s.intervalStats.frames, + nacks: r.nacks, + nackRepeated: r.nackRepeated, + plis: r.plis, + firs: r.firs, + maxJitterFeed: r.jitter, + }, + packetsLost: r.packetsLostFromRR, + maxRtt: r.rtt, + maxJitter: r.jitterFromRR, + extLastRRSN: s.extLastRRSN, + metadataCacheOverflowCount: s.metadataCacheOverflowCount, + } +} + +func (r *RTPStatsSender) getSnInfoOutOfOrderSlot(esn uint64, ehsn uint64) int { + offset := int64(ehsn - esn) + if offset >= int64(len(r.snInfos)) || offset < 0 { + // too old OR too new (i. e. ahead of highest) + return -1 + } + + return int(esn) % len(r.snInfos) +} + +func (r *RTPStatsSender) setSnInfo(esn uint64, ehsn uint64, pktSize uint16, hdrSize uint8, payloadSize uint16, marker bool, isOutOfOrder bool) { + var slot int + if int64(esn-ehsn) < 0 { + slot = r.getSnInfoOutOfOrderSlot(esn, ehsn) + if slot < 0 { + return + } + } else { + slot = int(esn) % len(r.snInfos) + } + + snInfo := &r.snInfos[slot] + snInfo.pktSize = pktSize + snInfo.hdrSize = hdrSize + snInfo.flags = 0 + if marker { + snInfo.flags |= snInfoFlagMarker + } + if payloadSize == 0 { + snInfo.flags |= snInfoFlagPadding + } + if isOutOfOrder { + snInfo.flags |= snInfoFlagOutOfOrder + } +} + +func (r *RTPStatsSender) clearSnInfos(extStartInclusive uint64, extEndExclusive uint64) { + if extEndExclusive <= extStartInclusive { + return + } + + for esn := extStartInclusive; esn != extEndExclusive; esn++ { + snInfo := &r.snInfos[int(esn)%len(r.snInfos)] + snInfo.pktSize = 0 + snInfo.hdrSize = 0 + snInfo.flags = 0 + } +} + +func (r *RTPStatsSender) isSnInfoLost(esn uint64, ehsn uint64) bool { + slot := r.getSnInfoOutOfOrderSlot(esn, ehsn) + if slot < 0 { + return false + } + + return r.snInfos[slot].pktSize == 0 +} + +func (r *RTPStatsSender) getIntervalStats( + extStartInclusive uint64, + extEndExclusive uint64, + ehsn uint64, +) (intervalStats intervalStats) { + upperBound := ehsn + 1 + lowerBound := uint64(0) + if n := uint64(len(r.snInfos)); n != 0 && ehsn >= n-1 { + lowerBound = ehsn - n + 1 + } + extStartInclusiveClamped := max(min(extStartInclusive, upperBound), lowerBound) + extEndExclusiveClamped := max(min(extEndExclusive, upperBound), extStartInclusiveClamped) + + intervalStats.packetsNotFoundMetadata = (extEndExclusive - extStartInclusive) - (extEndExclusiveClamped - extStartInclusiveClamped) + + for esn := extStartInclusiveClamped; esn != extEndExclusiveClamped; esn++ { + slot := r.getSnInfoOutOfOrderSlot(esn, ehsn) + snInfo := &r.snInfos[slot] + switch { + case snInfo.pktSize == 0: + intervalStats.packetsLostFeed++ + + case snInfo.flags&snInfoFlagPadding != 0: + intervalStats.packetsPadding++ + intervalStats.bytesPadding += uint64(snInfo.pktSize) + intervalStats.headerBytesPadding += uint64(snInfo.hdrSize) + + default: + intervalStats.packets++ + intervalStats.bytes += uint64(snInfo.pktSize) + intervalStats.headerBytes += uint64(snInfo.hdrSize) + if (snInfo.flags & snInfoFlagOutOfOrder) != 0 { + intervalStats.packetsOutOfOrderFeed++ + } + } + + if (snInfo.flags & snInfoFlagMarker) != 0 { + intervalStats.frames++ + } + } + return +} + +func (r *RTPStatsSender) ExtHighestSequenceNumber() uint64 { + r.lock.RLock() + defer r.lock.RUnlock() + + return r.extHighestSN +} + +// ------------------------------------------------------------------- + +type lockedRTPStatsSenderLogEncoder struct { + *RTPStatsSender +} + +func (r lockedRTPStatsSenderLogEncoder) MarshalLogObject(e zapcore.ObjectEncoder) error { + if r.RTPStatsSender == nil { + return nil + } + + packetsExpected := getPacketsExpected(r.extStartSN, r.extHighestSN) + elapsedSeconds, err := r.rtpStatsBase.marshalLogObject( + e, + packetsExpected, + r.getPacketsSeenMinusPadding(r.extStartSN, r.extHighestSN), + r.extStartTS, + r.extHighestTS, + ) + if err != nil { + return err + } + + e.AddUint64("extStartSN", r.extStartSN) + e.AddUint64("extHighestSN", r.extHighestSN) + + e.AddUint64("extStartTS", r.extStartTS) + e.AddUint64("extHighestTS", r.extHighestTS) + + e.AddTime("lastRRTime", time.Unix(0, r.lastRRTime)) + e.AddReflected("lastRR", r.lastRR) + e.AddUint64("extHighestSNFromRR", r.extHighestSNFromRR) + e.AddUint64("extHighestSNFromRRMisalignment", r.extHighestSNFromRRMisalignment) + e.AddUint64("packetsLostFromRR", r.packetsLostFromRR) + e.AddFloat64("packetsLostFromRRRate", float64(r.packetsLostFromRR)/elapsedSeconds) + if packetsExpected != 0 { + e.AddFloat32("packetLostFromRRPercentage", float32(r.packetsLostFromRR)/float32(packetsExpected)*100.0) + } + e.AddFloat64("jitterFromRR", r.jitterFromRR) + e.AddFloat64("maxJitterFromRR", r.maxJitterFromRR) + + e.AddUint32("layerLockPlis", r.layerLockPlis) + e.AddTime("lastLayerLockPli", r.lastLayerLockPli) + return nil +} + +// ------------------------------------------------------------------- + +func initSenderSnapshot(startTime int64, extStartSN uint64) senderSnapshot { + return senderSnapshot{ + senderView: initSenderSnapshotWindow(startTime, extStartSN), + receiverView: initSenderSnapshotReceiverView(startTime, extStartSN), + } +} + +func initSenderSnapshotWindow(startTime int64, extStartSN uint64) senderSnapshotWindow { + return senderSnapshotWindow{ + isValid: true, + startTime: startTime, + extStartSN: extStartSN, + } +} + +func initSenderSnapshotReceiverView(startTime int64, extStartSN uint64) senderSnapshotReceiverView { + return senderSnapshotReceiverView{ + senderSnapshotWindow: senderSnapshotWindow{ + isValid: true, + startTime: startTime, + extStartSN: extStartSN, + }, + extLastRRSN: extStartSN - 1, + } +} diff --git a/livekit/pkg/sfu/rtpstats/rtpstats_sender_lite.go b/livekit/pkg/sfu/rtpstats/rtpstats_sender_lite.go new file mode 100644 index 0000000..ae42235 --- /dev/null +++ b/livekit/pkg/sfu/rtpstats/rtpstats_sender_lite.go @@ -0,0 +1,121 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rtpstats + +import ( + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/utils/mono" + "go.uber.org/zap/zapcore" +) + +type RTPStatsSenderLite struct { + *rtpStatsBaseLite + + extStartSN uint64 + extHighestSN uint64 +} + +func NewRTPStatsSenderLite(params RTPStatsParams) *RTPStatsSenderLite { + return &RTPStatsSenderLite{ + rtpStatsBaseLite: newRTPStatsBaseLite(params), + } +} + +func (r *RTPStatsSenderLite) Update(packetTime int64, packetSize int, extSequenceNumber uint64) { + r.lock.Lock() + defer r.lock.Unlock() + + if r.endTime != 0 { + return + } + + if !r.initialized { + r.initialized = true + + r.startTime = mono.UnixNano() + + r.extStartSN = extSequenceNumber + r.extHighestSN = extSequenceNumber - 1 + + r.logger.Debugw( + "rtp sender lite stream start", + "rtpStats", lockedRTPStatsSenderLiteLogEncoder{r}, + ) + } + + gapSN := int64(extSequenceNumber - r.extHighestSN) + if gapSN <= 0 { // duplicate OR out-of-order + r.packetsOutOfOrder++ // counting duplicate as out-of-order + r.packetsLost-- + } else { // in-order + r.updateGapHistogram(int(gapSN)) + r.packetsLost += uint64(gapSN - 1) + + r.extHighestSN = extSequenceNumber + } + + r.bytes += uint64(packetSize) +} + +func (r *RTPStatsSenderLite) MarshalLogObject(e zapcore.ObjectEncoder) error { + if r == nil { + return nil + } + + r.lock.RLock() + defer r.lock.RUnlock() + + return lockedRTPStatsSenderLiteLogEncoder{r}.MarshalLogObject(e) +} + +func (r *RTPStatsSenderLite) ToProto() *livekit.RTPStats { + r.lock.RLock() + defer r.lock.RUnlock() + + return r.rtpStatsBaseLite.toProto(r.extStartSN, r.extHighestSN, r.packetsLost) +} + +func (r *RTPStatsSenderLite) ExtHighestSequenceNumber() uint64 { + r.lock.RLock() + defer r.lock.RUnlock() + + return r.extHighestSN +} + +// ------------------------------------------------------------------- + +type lockedRTPStatsSenderLiteLogEncoder struct { + *RTPStatsSenderLite +} + +func (r lockedRTPStatsSenderLiteLogEncoder) MarshalLogObject(e zapcore.ObjectEncoder) error { + if r.RTPStatsSenderLite == nil { + return nil + } + + if _, err := r.rtpStatsBaseLite.marshalLogObject( + e, + getPacketsExpected(r.extStartSN, r.extHighestSN), + getPacketsExpected(r.extStartSN, r.extHighestSN), + ); err != nil { + return err + } + + e.AddUint64("extStartSN", r.extStartSN) + e.AddUint64("extHighestSN", r.extHighestSN) + return nil +} + +// ------------------------------------------------------------------- diff --git a/livekit/pkg/sfu/rtpstats/rtpstats_test.go b/livekit/pkg/sfu/rtpstats/rtpstats_test.go new file mode 100644 index 0000000..49066d8 --- /dev/null +++ b/livekit/pkg/sfu/rtpstats/rtpstats_test.go @@ -0,0 +1,266 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rtpstats + +import ( + "math/rand" + "testing" + "time" + + "github.com/pion/rtp" + "github.com/stretchr/testify/require" + + "github.com/livekit/protocol/logger" +) + +func getPacket(sn uint16, ts uint32, payloadSize int) *rtp.Packet { + return &rtp.Packet{ + Header: rtp.Header{ + SequenceNumber: sn, + Timestamp: ts, + }, + Payload: make([]byte, payloadSize), + } +} + +func Test_RTPStatsReceiver_Update(t *testing.T) { + clockRate := uint32(90000) + r := NewRTPStatsReceiver(RTPStatsParams{ + ClockRate: clockRate, + Logger: logger.GetLogger(), + }) + + sequenceNumber := uint16(rand.Float64() * float64(1<<16)) + timestamp := uint32(rand.Float64() * float64(1<<32)) + packet := getPacket(sequenceNumber, timestamp, 1000) + flowState := r.Update( + time.Now().UnixNano(), + packet.Header.SequenceNumber, + packet.Header.Timestamp, + packet.Header.Marker, + packet.Header.MarshalSize(), + len(packet.Payload), + 0, + ) + require.True(t, r.initialized) + require.Equal(t, sequenceNumber, r.sequenceNumber.GetHighest()) + require.Equal(t, sequenceNumber, uint16(r.sequenceNumber.GetExtendedHighest())) + require.Equal(t, timestamp, r.timestamp.GetHighest()) + require.Equal(t, timestamp, uint32(r.timestamp.GetExtendedHighest())) + + // in-order, no loss + sequenceNumber++ + timestamp += 3000 + packet = getPacket(sequenceNumber, timestamp, 1000) + flowState = r.Update( + time.Now().UnixNano(), + packet.Header.SequenceNumber, + packet.Header.Timestamp, + packet.Header.Marker, + packet.Header.MarshalSize(), + len(packet.Payload), + 0, + ) + require.Equal(t, sequenceNumber, r.sequenceNumber.GetHighest()) + require.Equal(t, sequenceNumber, uint16(r.sequenceNumber.GetExtendedHighest())) + require.Equal(t, timestamp, r.timestamp.GetHighest()) + require.Equal(t, timestamp, uint32(r.timestamp.GetExtendedHighest())) + + // out-of-order, would cause a restart which is disallowed + packet = getPacket(sequenceNumber-10, timestamp-30000, 1000) + flowState = r.Update( + time.Now().UnixNano(), + packet.Header.SequenceNumber, + packet.Header.Timestamp, + packet.Header.Marker, + packet.Header.MarshalSize(), + len(packet.Payload), + 0, + ) + require.Equal(t, RTPFlowUnhandledReasonPreStartTimestamp, flowState.UnhandledReason) + require.Equal(t, sequenceNumber, r.sequenceNumber.GetHighest()) + require.Equal(t, sequenceNumber, uint16(r.sequenceNumber.GetExtendedHighest())) + require.Equal(t, timestamp, r.timestamp.GetHighest()) + require.Equal(t, timestamp, uint32(r.timestamp.GetExtendedHighest())) + require.Equal(t, uint64(0), r.packetsOutOfOrder) + require.Equal(t, uint64(0), r.packetsDuplicate) + + // duplicate of the above out-of-order packet, but would not be handled as it causes a restart + packet = getPacket(sequenceNumber-10, timestamp-30000, 1000) + flowState = r.Update( + time.Now().UnixNano(), + packet.Header.SequenceNumber, + packet.Header.Timestamp, + packet.Header.Marker, + packet.Header.MarshalSize(), + len(packet.Payload), + 0, + ) + require.Equal(t, RTPFlowUnhandledReasonPreStartTimestamp, flowState.UnhandledReason) + require.Equal(t, sequenceNumber, r.sequenceNumber.GetHighest()) + require.Equal(t, sequenceNumber, uint16(r.sequenceNumber.GetExtendedHighest())) + require.Equal(t, timestamp, r.timestamp.GetHighest()) + require.Equal(t, timestamp, uint32(r.timestamp.GetExtendedHighest())) + require.Equal(t, uint64(0), r.packetsOutOfOrder) + require.Equal(t, uint64(0), r.packetsDuplicate) + + // loss + sequenceNumber += 10 + timestamp += 30000 + packet = getPacket(sequenceNumber, timestamp, 1000) + flowState = r.Update( + time.Now().UnixNano(), + packet.Header.SequenceNumber, + packet.Header.Timestamp, + packet.Header.Marker, + packet.Header.MarshalSize(), + len(packet.Payload), + 0, + ) + require.Equal(t, uint64(sequenceNumber-9), flowState.LossStartInclusive) + require.Equal(t, uint64(sequenceNumber), flowState.LossEndExclusive) + require.Equal(t, uint64(9), r.packetsLost) + + // out-of-order should decrement number of lost packets + packet = getPacket(sequenceNumber-6, timestamp-18000, 1000) + flowState = r.Update( + time.Now().UnixNano(), + packet.Header.SequenceNumber, + packet.Header.Timestamp, + packet.Header.Marker, + packet.Header.MarshalSize(), + len(packet.Payload), + 0, + ) + require.Equal(t, sequenceNumber, r.sequenceNumber.GetHighest()) + require.Equal(t, sequenceNumber, uint16(r.sequenceNumber.GetExtendedHighest())) + require.Equal(t, timestamp, r.timestamp.GetHighest()) + require.Equal(t, timestamp, uint32(r.timestamp.GetExtendedHighest())) + require.Equal(t, uint64(1), r.packetsOutOfOrder) + require.Equal(t, uint64(0), r.packetsDuplicate) + require.Equal(t, uint64(8), r.packetsLost) + + // test sequence number history + // with a gap + sequenceNumber += 2 + timestamp += 6000 + packet = getPacket(sequenceNumber, timestamp, 1000) + flowState = r.Update( + time.Now().UnixNano(), + packet.Header.SequenceNumber, + packet.Header.Timestamp, + packet.Header.Marker, + packet.Header.MarshalSize(), + len(packet.Payload), + 0, + ) + require.Equal(t, uint64(sequenceNumber-1), flowState.LossStartInclusive) + require.Equal(t, uint64(sequenceNumber), flowState.LossEndExclusive) + require.Equal(t, uint64(9), r.packetsLost) + require.False(t, r.history.IsSet(uint64(sequenceNumber)-1)) + + // out-of-order + sequenceNumber-- + timestamp -= 3000 + packet = getPacket(sequenceNumber, timestamp, 999) + flowState = r.Update( + time.Now().UnixNano(), + packet.Header.SequenceNumber, + packet.Header.Timestamp, + packet.Header.Marker, + packet.Header.MarshalSize(), + len(packet.Payload), + 0, + ) + require.Equal(t, uint64(8), r.packetsLost) + require.Equal(t, uint64(2), r.packetsOutOfOrder) + require.True(t, r.history.IsSet(uint64(sequenceNumber))) + + // padding only + sequenceNumber += 2 + timestamp += 3000 + packet = getPacket(sequenceNumber, timestamp, 0) + flowState = r.Update( + time.Now().UnixNano(), + packet.Header.SequenceNumber, + packet.Header.Timestamp, + packet.Header.Marker, + packet.Header.MarshalSize(), + len(packet.Payload), + 25, + ) + require.Equal(t, uint64(8), r.packetsLost) + require.Equal(t, uint64(2), r.packetsOutOfOrder) + require.True(t, r.history.IsSet(uint64(sequenceNumber))) + require.True(t, r.history.IsSet(uint64(sequenceNumber)-1)) + require.True(t, r.history.IsSet(uint64(sequenceNumber)-2)) + + // old packet, but simulating increasing sequence number after roll over + packet = getPacket(sequenceNumber+400, timestamp-6000, 300) + flowState = r.Update( + time.Now().UnixNano(), + packet.Header.SequenceNumber, + packet.Header.Timestamp, + packet.Header.Marker, + packet.Header.MarshalSize(), + len(packet.Payload), + 0, + ) + require.Equal(t, RTPFlowUnhandledReasonOldSequenceNumber, flowState.UnhandledReason) + require.Equal(t, sequenceNumber, r.sequenceNumber.GetHighest()) + require.Equal(t, sequenceNumber, uint16(r.sequenceNumber.GetExtendedHighest())) + require.Equal(t, timestamp, r.timestamp.GetHighest()) + require.Equal(t, timestamp, uint32(r.timestamp.GetExtendedHighest())) + + r.Stop() +} + +func Test_RTPStatsReceiver_Restart(t *testing.T) { + clockRate := uint32(90000) + r := NewRTPStatsReceiver(RTPStatsParams{ + ClockRate: clockRate, + Logger: logger.GetLogger(), + }) + + // should not restart till there are at least threshold packets + require.False(t, r.maybeRestart(10, 20, 1000)) + require.False(t, r.maybeRestart(11, 20, 1000)) + require.False(t, r.maybeRestart(13, 20, 1000)) + require.False(t, r.maybeRestart(14, 20, 1000)) + // although adding 5th packet should have enough packets for a check, + // still should not restart as there is a sequence number gap between 11 and 13 + require.False(t, r.maybeRestart(15, 20, 1000)) + require.False(t, r.maybeRestart(16, 19, 1000)) + // has enough packets, but still cannot restart because timestamps are not increasing + require.False(t, r.maybeRestart(17, 21, 1000)) + require.False(t, r.maybeRestart(18, 21, 1000)) + require.False(t, r.maybeRestart(19, 21, 1000)) + // can restart as there are enough packets with proper sequencing + require.True(t, r.maybeRestart(20, 21, 1000)) + require.Equal(t, restartThreshold, len(r.restartPackets)) + + r.resetRestart() + require.Zero(t, len(r.restartPackets)) + + r.Stop() +} + +func Test_RTPStatsSender_getIntervalStats(t *testing.T) { + t.Run("packetsNotFoundMetadata should match lost packets", func(t *testing.T) { + r := NewRTPStatsSender(RTPStatsParams{}, 1024) + stats := r.getIntervalStats(0, 10000, 10000) + require.EqualValues(t, 8977, stats.packetsNotFoundMetadata) + }) +} diff --git a/livekit/pkg/sfu/sequencer.go b/livekit/pkg/sfu/sequencer.go new file mode 100644 index 0000000..1e08ef9 --- /dev/null +++ b/livekit/pkg/sfu/sequencer.go @@ -0,0 +1,463 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sfu + +import ( + "math" + "sync" + "time" + + "github.com/livekit/livekit-server/pkg/sfu/utils" + "github.com/livekit/protocol/logger" + "go.uber.org/zap/zapcore" +) + +const ( + defaultRtt = 70 + ignoreRetransmission = 100 // Ignore packet retransmission after ignoreRetransmission milliseconds + maxAck = 3 +) + +type packetMeta struct { + // Original extended sequence number from stream. + // The original extended sequence number is used to find the original + // packet from publisher + sourceSeqNo uint64 + // Modified sequence number after offset. + // This sequence number is used for the associated + // down track, is modified according the offsets, and + // must not be shared + targetSeqNo uint16 + // Modified timestamp for current associated + // down track. + timestamp uint32 + // Modified marker + marker bool + // The last time this packet was nack requested. + // Sometimes clients request the same packet more than once, so keep + // track of the requested packets helps to avoid writing multiple times + // the same packet. + // The resolution is 1 ms counting after the sequencer start time. + lastNack uint32 + // number of NACKs this packet has received + nacked uint8 + // Spatial layer of packet + layer int8 + // Information that differs depending on the codec + codecBytes [8]byte + numCodecBytesIn uint8 + numCodecBytesOut uint8 + codecBytesSlice []byte + // Dependency Descriptor of packet + ddBytes [8]byte + ddBytesSize uint8 + ddBytesSlice []byte + // abs-capture-time of packet + actBytes []byte +} + +func (pm packetMeta) MarshalLogObject(e zapcore.ObjectEncoder) error { + e.AddUint64("sourceSeqNo", pm.sourceSeqNo) + e.AddUint16("targetSeqNo", pm.targetSeqNo) + e.AddInt8("layer", pm.layer) + e.AddUint8("nacked", pm.nacked) + e.AddUint8("numCodecBytesIn", pm.numCodecBytesIn) + if len(pm.codecBytesSlice) != 0 { + e.AddInt("codecBytesSlice", len(pm.codecBytesSlice)) + } else { + e.AddUint8("numCodecBytesOut", pm.numCodecBytesOut) + } + if len(pm.ddBytesSlice) != 0 { + e.AddInt("ddBytesSlice", len(pm.ddBytesSlice)) + } else { + e.AddUint8("ddBytesSize", pm.ddBytesSize) + } + if len(pm.actBytes) != 0 { + e.AddInt("actBytes", len(pm.actBytes)) + } + return nil +} + +type extPacketMeta struct { + packetMeta + extSequenceNumber uint64 + extTimestamp uint64 +} + +func (epm extPacketMeta) MarshalLogObject(e zapcore.ObjectEncoder) error { + e.AddObject("packetMeta", epm.packetMeta) + e.AddUint64("extSequenceNumber", epm.extSequenceNumber) + return nil +} + +// Sequencer stores the packet sequence received by the down track +type sequencer struct { + sync.Mutex + size int + startTime int64 + initialized bool + extStartSN uint64 + extHighestSN uint64 + snOffset uint64 + extHighestTS uint64 + meta []packetMeta + snRangeMap *utils.RangeMap[uint64, uint64] + rtt uint32 + logger logger.Logger +} + +func newSequencer(size int, maybeSparse bool, logger logger.Logger) *sequencer { + if size == 0 { + return nil + } + + s := &sequencer{ + size: size, + startTime: time.Now().UnixNano(), + meta: make([]packetMeta, size), + rtt: defaultRtt, + logger: logger, + } + + if maybeSparse { + s.snRangeMap = utils.NewRangeMap[uint64, uint64]((size + 1) / 2) // assume run lengths of at least 2 in between padding bursts + } + return s +} + +func (s *sequencer) setRTT(rtt uint32) { + s.Lock() + defer s.Unlock() + + if rtt == 0 { + s.rtt = defaultRtt + } else { + s.rtt = rtt + } +} + +func (s *sequencer) push( + packetTime int64, + extIncomingSN, extModifiedSN uint64, + extModifiedTS uint64, + marker bool, + layer int8, + codecBytes []byte, + numCodecBytesIn int, + ddBytes []byte, + actBytes []byte, +) { + s.Lock() + defer s.Unlock() + + if !s.initialized { + s.initialized = true + s.extStartSN = extModifiedSN + s.extHighestSN = extModifiedSN + s.extHighestTS = extModifiedTS + s.updateSNOffset() + } + + if extModifiedSN < s.extStartSN { + // old packet, should not happen + return + } + + extHighestSNAdjusted := s.extHighestSN - s.snOffset + extModifiedSNAdjusted := extModifiedSN - s.snOffset + if extModifiedSN < s.extHighestSN { + if s.snRangeMap != nil { + snOffset, err := s.snRangeMap.GetValue(extModifiedSN) + if err != nil { + s.logger.Errorw( + "could not get sequence number offset", err, + "extStartSN", s.extStartSN, + "extHighestSN", s.extHighestSN, + "extIncomingSN", extIncomingSN, + "extModifiedSN", extModifiedSN, + "snOffset", s.snOffset, + ) + return + } + + extModifiedSNAdjusted = extModifiedSN - snOffset + } + } + + if int64(extModifiedSNAdjusted-extHighestSNAdjusted) <= -int64(s.size) { + s.logger.Warnw( + "old packet, cannot be sequenced", nil, + "extHighestSN", s.extHighestSN, + "extIncomingSN", extIncomingSN, + "extModifiedSN", extModifiedSN, + ) + return + } + + // invalidate missing sequence numbers + if extModifiedSNAdjusted > extHighestSNAdjusted { + numInvalidated := 0 + for esn := extHighestSNAdjusted + 1; esn != extModifiedSNAdjusted; esn++ { + s.invalidateSlot(int(esn % uint64(s.size))) + numInvalidated++ + if numInvalidated >= s.size { + break + } + } + } + + slot := extModifiedSNAdjusted % uint64(s.size) + s.meta[slot] = packetMeta{ + sourceSeqNo: extIncomingSN, + targetSeqNo: uint16(extModifiedSN), + timestamp: uint32(extModifiedTS), + marker: marker, + layer: layer, + numCodecBytesIn: uint8(numCodecBytesIn), + lastNack: s.getRefTime(packetTime), // delay retransmissions after the original transmission + } + pm := &s.meta[slot] + + pm.numCodecBytesOut = uint8(len(codecBytes)) + if len(codecBytes) > len(pm.codecBytes) { + pm.codecBytesSlice = append([]byte{}, codecBytes...) + } else { + copy(pm.codecBytes[:pm.numCodecBytesOut], codecBytes) + } + + pm.ddBytesSize = uint8(len(ddBytes)) + if len(ddBytes) > len(pm.ddBytes) { + pm.ddBytesSlice = append([]byte{}, ddBytes...) + } else { + copy(pm.ddBytes[:pm.ddBytesSize], ddBytes) + } + + pm.actBytes = append([]byte{}, actBytes...) + + if extModifiedSN > s.extHighestSN { + s.extHighestSN = extModifiedSN + } + if extModifiedTS > s.extHighestTS { + s.extHighestTS = extModifiedTS + } +} + +func (s *sequencer) pushPadding(extStartSNInclusive uint64, extEndSNInclusive uint64) { + s.Lock() + defer s.Unlock() + + if s.snRangeMap == nil || !s.initialized { + return + } + + if extStartSNInclusive <= s.extHighestSN { + // a higher sequence number has already been recorded with an offset, + // adding an exclusion range before the highest means the offset of sequence numbers + // after the exclusion range will be affected and all those higher sequence numbers + // need to be patched. + // + // Not recording exclusion range means a few slots (of the size of exclusion range) + // are wasted in this cycle. That should be fine as the exclusion ranges should be + // a few packets at a time. + if extEndSNInclusive >= s.extHighestSN { + s.logger.Errorw("cannot exclude overlapping range", nil, "extHighestSN", s.extHighestSN, "startSN", extStartSNInclusive, "endSN", extEndSNInclusive) + } else { + s.logger.Warnw("cannot exclude old range", nil, "extHighestSN", s.extHighestSN, "startSN", extStartSNInclusive, "endSN", extEndSNInclusive) + } + + // if exclusion range is before what has already been sequenced, invalidate exclusion range slots + for esn := extStartSNInclusive; esn != extEndSNInclusive+1; esn++ { + diff := int64(esn - s.extHighestSN) + if diff >= 0 || diff < -int64(s.size) { + // too old OR too new (too new should not happen, just be safe) + continue + } + + snOffset, err := s.snRangeMap.GetValue(esn) + if err != nil { + s.logger.Errorw("could not get sequence number offset", err, "sn", esn) + continue + } + + slot := (esn - snOffset) % uint64(s.size) + s.invalidateSlot(int(slot)) + } + return + } + + if err := s.snRangeMap.ExcludeRange(extStartSNInclusive, extEndSNInclusive+1); err != nil { + s.logger.Errorw("could not exclude range", err, "startSN", extStartSNInclusive, "endSN", extEndSNInclusive) + return + } + + s.extHighestSN = extEndSNInclusive + s.updateSNOffset() +} + +func (s *sequencer) getExtPacketMetas(seqNo []uint16) []extPacketMeta { + s.Lock() + defer s.Unlock() + + if !s.initialized { + return nil + } + + snOffset := uint64(0) + var err error + extPacketMetas := make([]extPacketMeta, 0, len(seqNo)) + refTime := s.getRefTime(time.Now().UnixNano()) + highestSN := uint16(s.extHighestSN) + highestTS := uint32(s.extHighestTS) + for _, sn := range seqNo { + diff := highestSN - sn + if diff > (1 << 15) { + // out-of-order from head (should not happen, just be safe) + continue + } + + // find slot by adjusting for padding only packets that were not recorded in sequencer + extSN := uint64(sn) + (s.extHighestSN & 0xFFFF_FFFF_FFFF_0000) + if sn > highestSN { + extSN -= (1 << 16) + } + + if s.snRangeMap != nil { + snOffset, err = s.snRangeMap.GetValue(extSN) + if err != nil { + // could be padding packet which is excluded and will not have value + continue + } + } + + extSNAdjusted := extSN - snOffset + extHighestSNAdjusted := s.extHighestSN - s.snOffset + if extHighestSNAdjusted-extSNAdjusted >= uint64(s.size) { + // too old + continue + } + + slot := extSNAdjusted % uint64(s.size) + meta := &s.meta[slot] + if meta.targetSeqNo != sn || s.isInvalidSlot(int(slot)) { + // invalid slot access could happen if padding packets exclusion range could not be recorded + continue + } + + if meta.nacked < maxAck && refTime-meta.lastNack > uint32(math.Min(float64(ignoreRetransmission), float64(2*s.rtt))) { + meta.nacked++ + meta.lastNack = refTime + + extTS := uint64(meta.timestamp) + (s.extHighestTS & 0xFFFF_FFFF_0000_0000) + if meta.timestamp > highestTS { + extTS -= (1 << 32) + } + epm := extPacketMeta{ + packetMeta: *meta, + extSequenceNumber: extSN, + extTimestamp: extTS, + } + epm.codecBytesSlice = append([]byte{}, meta.codecBytesSlice...) + epm.ddBytesSlice = append([]byte{}, meta.ddBytesSlice...) + epm.actBytes = append([]byte{}, meta.actBytes...) + extPacketMetas = append(extPacketMetas, epm) + } + } + + return extPacketMetas +} + +func (s *sequencer) lookupExtPacketMeta(extSN uint64) *extPacketMeta { + s.Lock() + defer s.Unlock() + + if !s.initialized { + return nil + } + + snOffset := uint64(0) + var err error + if s.snRangeMap != nil { + snOffset, err = s.snRangeMap.GetValue(extSN) + if err != nil { + return nil + } + } + + extSNAdjusted := extSN - snOffset + extHighestSNAdjusted := s.extHighestSN - s.snOffset + if extHighestSNAdjusted-extSNAdjusted >= uint64(s.size) { + // too old + return nil + } + + slot := extSNAdjusted % uint64(s.size) + meta := &s.meta[slot] + if s.isInvalidSlot(int(slot)) { + // invalid slot access could happen if padding packets exclusion range could not be recorded + return nil + } + + extTS := uint64(meta.timestamp) + (s.extHighestTS & 0xFFFF_FFFF_0000_0000) + if meta.timestamp > uint32(s.extHighestTS) { + extTS -= (1 << 32) + } + epm := extPacketMeta{ + packetMeta: *meta, + extSequenceNumber: extSN, + extTimestamp: extTS, + } + epm.codecBytesSlice = append([]byte{}, meta.codecBytesSlice...) + epm.ddBytesSlice = append([]byte{}, meta.ddBytesSlice...) + epm.actBytes = append([]byte{}, meta.actBytes...) + return &epm +} + +func (s *sequencer) getRefTime(at int64) uint32 { + return uint32((at - s.startTime) / 1e6) +} + +func (s *sequencer) updateSNOffset() { + if s.snRangeMap == nil { + return + } + + snOffset, err := s.snRangeMap.GetValue(s.extHighestSN + 1) + if err != nil { + s.logger.Errorw("could not update sequence number offset", err, "extHighestSN", s.extHighestSN) + return + } + s.snOffset = snOffset +} + +func (s *sequencer) invalidateSlot(slot int) { + if slot >= len(s.meta) { + return + } + + s.meta[slot] = packetMeta{ + sourceSeqNo: 0, + targetSeqNo: 0, + lastNack: 0, + } +} + +func (s *sequencer) isInvalidSlot(slot int) bool { + if slot >= len(s.meta) { + return true + } + + meta := &s.meta[slot] + return meta.sourceSeqNo == 0 && meta.targetSeqNo == 0 && meta.lastNack == 0 +} diff --git a/livekit/pkg/sfu/sequencer_test.go b/livekit/pkg/sfu/sequencer_test.go new file mode 100644 index 0000000..862ac53 --- /dev/null +++ b/livekit/pkg/sfu/sequencer_test.go @@ -0,0 +1,368 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sfu + +import ( + "reflect" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/livekit/protocol/logger" +) + +func Test_sequencer(t *testing.T) { + seq := newSequencer(500, false, logger.GetLogger()) + off := uint16(15) + + for i := uint64(1); i < 518; i++ { + seq.push(time.Now().UnixNano(), i, i+uint64(off), 123, true, 2, nil, 0, nil, nil) + } + // send the last two out-of-order + seq.push(time.Now().UnixNano(), 519, 519+uint64(off), 123, false, 2, nil, 0, nil, nil) + seq.push(time.Now().UnixNano(), 518, 518+uint64(off), 123, true, 2, nil, 0, nil, nil) + + req := []uint16{57, 58, 62, 63, 513, 514, 515, 516, 517} + res := seq.getExtPacketMetas(req) + // nothing should be returned as not enough time has elapsed since sending packet + require.Equal(t, 0, len(res)) + + time.Sleep((ignoreRetransmission + 10) * time.Millisecond) + res = seq.getExtPacketMetas(req) + require.Equal(t, len(req), len(res)) + for i, val := range res { + require.Equal(t, val.targetSeqNo, req[i]) + require.Equal(t, val.sourceSeqNo, uint64(req[i]-off)) + require.Equal(t, val.layer, int8(2)) + require.Equal(t, val.extSequenceNumber, uint64(req[i])) + require.Equal(t, val.extTimestamp, uint64(123)) + } + res = seq.getExtPacketMetas(req) + require.Equal(t, 0, len(res)) + time.Sleep((ignoreRetransmission + 10) * time.Millisecond) + res = seq.getExtPacketMetas(req) + require.Equal(t, len(req), len(res)) + for i, val := range res { + require.Equal(t, val.targetSeqNo, req[i]) + require.Equal(t, val.sourceSeqNo, uint64(req[i]-off)) + require.Equal(t, val.layer, int8(2)) + require.Equal(t, val.extSequenceNumber, uint64(req[i])) + require.Equal(t, val.extTimestamp, uint64(123)) + } + + seq.push(time.Now().UnixNano(), 521, 521+uint64(off), 123, true, 1, nil, 0, nil, nil) + m := seq.getExtPacketMetas([]uint16{521 + off}) + require.Equal(t, 0, len(m)) + time.Sleep((ignoreRetransmission + 10) * time.Millisecond) + m = seq.getExtPacketMetas([]uint16{521 + off}) + require.Equal(t, 1, len(m)) + + seq.push(time.Now().UnixNano(), 505, 505+uint64(off), 123, false, 1, nil, 0, nil, nil) + m = seq.getExtPacketMetas([]uint16{505 + off}) + require.Equal(t, 0, len(m)) + time.Sleep((ignoreRetransmission + 10) * time.Millisecond) + m = seq.getExtPacketMetas([]uint16{505 + off}) + require.Equal(t, 1, len(m)) +} + +func Test_sequencer_getNACKSeqNo_exclusion(t *testing.T) { + type args struct { + seqNo []uint16 + } + type input struct { + seqNo uint64 + isPadding bool + } + type fields struct { + inputs []input + offset uint64 + markerOdd bool + markerEven bool + codecBytesOdd []byte + numCodecBytesInOdd int + codecBytesEven []byte + numCodecBytesInEven int + codecBytesOversized []byte + ddBytesOdd []byte + ddBytesEven []byte + ddBytesOversized []byte + actBytesOdd []byte + actBytesEven []byte + } + + tests := []struct { + name string + fields fields + args args + want []uint16 + }{ + { + name: "Should get correct seq numbers", + fields: fields{ + inputs: []input{ + {65526, false}, + {65524, false}, + {65525, false}, + {65529, false}, + {65530, false}, + {65531, true}, + {65533, false}, + {65532, true}, + {65534, false}, + }, + offset: 5, + markerOdd: true, + markerEven: false, + codecBytesOdd: []byte{1, 2, 3, 4}, + numCodecBytesInOdd: 3, + codecBytesEven: []byte{5, 6, 7}, + numCodecBytesInEven: 4, + codecBytesOversized: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9}, + ddBytesOdd: []byte{8, 9, 10}, + ddBytesEven: []byte{11, 12}, + ddBytesOversized: []byte{11, 12, 13, 14, 15, 16, 17, 18, 19}, + actBytesOdd: []byte{0, 1, 2, 3, 4, 5, 6, 7}, + actBytesEven: []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + }, + args: args{ + seqNo: []uint16{65526 + 5, 65527 + 5, 65530 + 5, 0 /* 65531 input */, 1 /* 65532 input */, 2 /* 65533 input */, 3 /* 65534 input */}, + }, + // although 65526 is originally pushed, that would have been reset by 65532 (padding only packet) + // because of trying to add an exclusion range before highest sequence number which will fail + // and the resulting fix up of the exclusion range slots + want: []uint16{65530, 65533, 65534}, + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + n := newSequencer(5, true, logger.GetLogger()) + + for _, i := range tt.fields.inputs { + if i.isPadding { + n.pushPadding(i.seqNo+tt.fields.offset, i.seqNo+tt.fields.offset) + } else { + if i.seqNo%5 == 0 { + n.push( + time.Now().UnixNano(), + i.seqNo, + i.seqNo+tt.fields.offset, + 123, + tt.fields.markerOdd, + 3, + tt.fields.codecBytesOversized, + len(tt.fields.codecBytesOversized), + tt.fields.ddBytesOversized, + tt.fields.actBytesOdd, + ) + } else { + if i.seqNo%2 == 0 { + n.push( + time.Now().UnixNano(), + i.seqNo, + i.seqNo+tt.fields.offset, + 123, + tt.fields.markerEven, + 3, + tt.fields.codecBytesEven, + tt.fields.numCodecBytesInEven, + tt.fields.ddBytesEven, + tt.fields.actBytesEven, + ) + } else { + n.push( + time.Now().UnixNano(), + i.seqNo, + i.seqNo+tt.fields.offset, + 123, + tt.fields.markerOdd, + 3, + tt.fields.codecBytesOdd, + tt.fields.numCodecBytesInOdd, + tt.fields.ddBytesOdd, + tt.fields.actBytesOdd, + ) + } + } + } + } + + time.Sleep((ignoreRetransmission + 10) * time.Millisecond) + g := n.getExtPacketMetas(tt.args.seqNo) + var got []uint16 + for _, sn := range g { + got = append(got, uint16(sn.sourceSeqNo)) + if sn.sourceSeqNo%5 == 0 { + require.Equal(t, tt.fields.markerOdd, sn.marker) + require.Equal(t, tt.fields.codecBytesOversized, sn.codecBytesSlice) + require.Equal(t, uint8(len(tt.fields.codecBytesOversized)), sn.numCodecBytesIn) + require.Equal(t, tt.fields.ddBytesOversized, sn.ddBytesSlice) + require.Equal(t, uint8(len(tt.fields.codecBytesOversized)), sn.ddBytesSize) + require.Equal(t, tt.fields.actBytesOdd, sn.actBytes) + } else { + if sn.sourceSeqNo%2 == 0 { + require.Equal(t, tt.fields.markerEven, sn.marker) + require.Equal(t, tt.fields.codecBytesEven, sn.codecBytes[:sn.numCodecBytesOut]) + require.Equal(t, uint8(tt.fields.numCodecBytesInEven), sn.numCodecBytesIn) + require.Equal(t, tt.fields.ddBytesEven, sn.ddBytes[:sn.ddBytesSize]) + require.Equal(t, uint8(len(tt.fields.ddBytesEven)), sn.ddBytesSize) + require.Equal(t, tt.fields.actBytesEven, sn.actBytes) + } else { + require.Equal(t, tt.fields.markerOdd, sn.marker) + require.Equal(t, tt.fields.codecBytesOdd, sn.codecBytes[:sn.numCodecBytesOut]) + require.Equal(t, uint8(tt.fields.numCodecBytesInOdd), sn.numCodecBytesIn) + require.Equal(t, tt.fields.ddBytesOdd, sn.ddBytes[:sn.ddBytesSize]) + require.Equal(t, uint8(len(tt.fields.ddBytesOdd)), sn.ddBytesSize) + require.Equal(t, tt.fields.actBytesOdd, sn.actBytes) + } + } + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("getExtPacketMetas() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_sequencer_getNACKSeqNo_no_exclusion(t *testing.T) { + type args struct { + seqNo []uint16 + } + type input struct { + seqNo uint64 + isPadding bool + } + type fields struct { + inputs []input + offset uint64 + markerOdd bool + markerEven bool + codecBytesOdd []byte + numCodecBytesInOdd int + codecBytesEven []byte + numCodecBytesInEven int + ddBytesOdd []byte + ddBytesEven []byte + actBytesOdd []byte + actBytesEven []byte + } + + tests := []struct { + name string + fields fields + args args + want []uint16 + }{ + { + name: "Should get correct seq numbers", + fields: fields{ + inputs: []input{ + {2, false}, + {3, false}, + {4, false}, + {7, false}, + {8, false}, + {9, true}, + {11, false}, + {10, true}, + {12, false}, + {13, false}, + }, + offset: 5, + markerOdd: true, + markerEven: false, + codecBytesOdd: []byte{1, 2, 3, 4}, + numCodecBytesInOdd: 3, + codecBytesEven: []byte{5, 6, 7}, + numCodecBytesInEven: 4, + ddBytesOdd: []byte{8, 9, 10}, + ddBytesEven: []byte{11, 12}, + actBytesOdd: []byte{8, 9, 10}, + actBytesEven: []byte{11, 12}, + }, + args: args{ + seqNo: []uint16{4 + 5, 5 + 5, 8 + 5, 9 + 5, 10 + 5, 11 + 5, 12 + 5}, + }, + // although 4 and 8 were originally added, they would be too old after a cycle of sequencer buffer + want: []uint16{11, 12}, + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + n := newSequencer(5, false, logger.GetLogger()) + + for _, i := range tt.fields.inputs { + if i.isPadding { + n.pushPadding(i.seqNo+tt.fields.offset, i.seqNo+tt.fields.offset) + } else { + if i.seqNo%2 == 0 { + n.push( + time.Now().UnixNano(), + i.seqNo, + i.seqNo+tt.fields.offset, + 123, + tt.fields.markerEven, + 3, + tt.fields.codecBytesEven, + tt.fields.numCodecBytesInEven, + tt.fields.ddBytesEven, + tt.fields.actBytesEven, + ) + } else { + n.push( + time.Now().UnixNano(), + i.seqNo, + i.seqNo+tt.fields.offset, + 123, + tt.fields.markerOdd, + 3, + tt.fields.codecBytesOdd, + tt.fields.numCodecBytesInOdd, + tt.fields.ddBytesOdd, + tt.fields.actBytesOdd, + ) + } + } + } + + time.Sleep((ignoreRetransmission + 10) * time.Millisecond) + g := n.getExtPacketMetas(tt.args.seqNo) + var got []uint16 + for _, sn := range g { + got = append(got, uint16(sn.sourceSeqNo)) + if sn.sourceSeqNo%2 == 0 { + require.Equal(t, tt.fields.markerEven, sn.marker) + require.Equal(t, tt.fields.codecBytesEven, sn.codecBytes[:sn.numCodecBytesOut]) + require.Equal(t, uint8(tt.fields.numCodecBytesInEven), sn.numCodecBytesIn) + require.Equal(t, tt.fields.ddBytesEven, sn.ddBytes[:sn.ddBytesSize]) + require.Equal(t, uint8(len(tt.fields.ddBytesEven)), sn.ddBytesSize) + require.Equal(t, tt.fields.actBytesEven, sn.actBytes) + } else { + require.Equal(t, tt.fields.markerOdd, sn.marker) + require.Equal(t, tt.fields.codecBytesOdd, sn.codecBytes[:sn.numCodecBytesOut]) + require.Equal(t, uint8(tt.fields.numCodecBytesInOdd), sn.numCodecBytesIn) + require.Equal(t, tt.fields.ddBytesOdd, sn.ddBytes[:sn.ddBytesSize]) + require.Equal(t, uint8(len(tt.fields.ddBytesOdd)), sn.ddBytesSize) + require.Equal(t, tt.fields.actBytesOdd, sn.actBytes) + } + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("getExtPacketMetas() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/livekit/pkg/sfu/sfu.go b/livekit/pkg/sfu/sfu.go new file mode 100644 index 0000000..c8760d0 --- /dev/null +++ b/livekit/pkg/sfu/sfu.go @@ -0,0 +1,36 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sfu + +import ( + "sync" + + "github.com/pion/rtp" +) + +var ( + PacketFactory = &sync.Pool{ + New: func() any { + b := make([]byte, 1460) + return &b + }, + } + + RTPHeaderFactory = &sync.Pool{ + New: func() any { + return &rtp.Header{} + }, + } +) diff --git a/livekit/pkg/sfu/streamallocator/streamallocator.go b/livekit/pkg/sfu/streamallocator/streamallocator.go new file mode 100644 index 0000000..8face95 --- /dev/null +++ b/livekit/pkg/sfu/streamallocator/streamallocator.go @@ -0,0 +1,1459 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package streamallocator + +import ( + "fmt" + "sort" + "sync" + "time" + + "github.com/pion/interceptor/pkg/cc" + "github.com/pion/rtcp" + "github.com/pion/webrtc/v4" + "go.uber.org/atomic" + + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + + "github.com/livekit/livekit-server/pkg/sfu" + "github.com/livekit/livekit-server/pkg/sfu/buffer" + "github.com/livekit/livekit-server/pkg/sfu/bwe" + "github.com/livekit/livekit-server/pkg/sfu/ccutils" + "github.com/livekit/livekit-server/pkg/sfu/pacer" + "github.com/livekit/livekit-server/pkg/utils" +) + +const ( + cChannelCapacityInfinity = 100 * 1000 * 1000 // 100 Mbps + + cPriorityMin = uint8(1) + cPriorityMax = uint8(255) + cPriorityDefaultScreenshare = cPriorityMax + cPriorityDefaultVideo = cPriorityMin + + cFlagAllowOvershootWhileOptimal = true + cFlagAllowOvershootWhileDeficient = false + cFlagAllowOvershootExemptTrackWhileDeficient = true + cFlagAllowOvershootInProbe = true + cFlagAllowOvershootInCatchup = false + cFlagAllowOvershootInBoost = true + + cRTTPullInterval = 30 * time.Second + + cPingLong = cRTTPullInterval / 2 + cPingShort = 100 * time.Millisecond +) + +// --------------------------------------------------------------------------- + +type streamAllocatorState int + +const ( + streamAllocatorStateStable streamAllocatorState = iota + streamAllocatorStateDeficient +) + +func (s streamAllocatorState) String() string { + switch s { + case streamAllocatorStateStable: + return "STABLE" + case streamAllocatorStateDeficient: + return "DEFICIENT" + default: + return fmt.Sprintf("UNKNOWN: %d", int(s)) + } +} + +// --------------------------------------------------------------------------- + +type streamAllocatorSignal int + +const ( + streamAllocatorSignalAllocateTrack streamAllocatorSignal = iota + streamAllocatorSignalAllocateAllTracks + streamAllocatorSignalAdjustState + streamAllocatorSignalEstimate + streamAllocatorSignalFeedback + streamAllocatorSignalPeriodicPing + streamAllocatorSignalProbeClusterSwitch + streamAllocatorSignalSendProbe + streamAllocatorSignalPacerProbeObserverClusterComplete + streamAllocatorSignalResume + streamAllocatorSignalSetAllowPause + streamAllocatorSignalSetChannelCapacity + streamAllocatorSignalCongestionStateChange +) + +func (s streamAllocatorSignal) String() string { + switch s { + case streamAllocatorSignalAllocateTrack: + return "ALLOCATE_TRACK" + case streamAllocatorSignalAllocateAllTracks: + return "ALLOCATE_ALL_TRACKS" + case streamAllocatorSignalAdjustState: + return "ADJUST_STATE" + case streamAllocatorSignalEstimate: + return "ESTIMATE" + case streamAllocatorSignalFeedback: + return "FEEDBACK" + case streamAllocatorSignalPeriodicPing: + return "PERIODIC_PING" + case streamAllocatorSignalProbeClusterSwitch: + return "PROBE_CLUSTER_SWITCH" + case streamAllocatorSignalSendProbe: + return "SEND_PROBE" + case streamAllocatorSignalPacerProbeObserverClusterComplete: + return "PACER_PROBE_OBSERVER_CLUSTER_COMPLETE" + case streamAllocatorSignalResume: + return "RESUME" + case streamAllocatorSignalSetAllowPause: + return "SET_ALLOW_PAUSE" + case streamAllocatorSignalSetChannelCapacity: + return "SET_CHANNEL_CAPACITY" + case streamAllocatorSignalCongestionStateChange: + return "CONGESTION_STATE_CHANGE" + default: + return fmt.Sprintf("%d", int(s)) + } +} + +// --------------------------------------------------------------------------- + +type Event struct { + *StreamAllocator + Signal streamAllocatorSignal + TrackID livekit.TrackID + Data any +} + +func (e Event) String() string { + return fmt.Sprintf("StreamAllocator:Event{signal: %s, trackID: %s, data: %+v}", e.Signal, e.TrackID, e.Data) +} + +// --------------------------------------------------------------------------- + +type ( + ProbeMode string +) + +const ( + ProbeModePadding ProbeMode = "padding" + ProbeModeMedia ProbeMode = "media" +) + +type StreamAllocatorConfig struct { + MinChannelCapacity int64 `yaml:"min_channel_capacity,omitempty"` + DisableEstimationUnmanagedTracks bool `yaml:"disable_etimation_unmanaged_tracks,omitempty"` + + ProbeMode ProbeMode `yaml:"probe_mode,omitempty"` + ProbeOveragePct int64 `yaml:"probe_overage_pct,omitempty"` + ProbeMinBps int64 `yaml:"probe_min_bps,omitempty"` + + PausedMinWait time.Duration `yaml:"paused_min_wait,omitempty"` +} + +var ( + DefaultStreamAllocatorConfig = StreamAllocatorConfig{ + MinChannelCapacity: 0, + DisableEstimationUnmanagedTracks: false, + + ProbeMode: ProbeModePadding, + ProbeOveragePct: 120, + ProbeMinBps: 200_000, + + PausedMinWait: 5 * time.Second, + } +) + +// --------------------------------------------------------------------------- + +type StreamAllocatorParams struct { + Config StreamAllocatorConfig + BWE bwe.BWE + Pacer pacer.Pacer + RTTGetter func() (float64, bool) + Logger logger.Logger +} + +type StreamAllocator struct { + params StreamAllocatorParams + + onStreamStateChange func(update *StreamStateUpdate) error + + sendSideBWEInterceptor cc.BandwidthEstimator + + enabled bool + allowPause bool + + committedChannelCapacity int64 + overriddenChannelCapacity int64 + + prober *ccutils.Prober + + videoTracksMu sync.RWMutex + videoTracks map[livekit.TrackID]*Track + isAllocateAllPending bool + rembTrackingSSRC uint32 + + state streamAllocatorState + + activeProbeClusterId ccutils.ProbeClusterId + activeProbeGoalReached bool + activeProbeCongesting bool + + eventsQueue *utils.TypedOpsQueue[Event] + + lastRTTTime time.Time + + pingGeneration atomic.Uint32 + + isStopped atomic.Bool +} + +func NewStreamAllocator(params StreamAllocatorParams, enabled bool, allowPause bool) *StreamAllocator { + s := &StreamAllocator{ + params: params, + enabled: enabled, + allowPause: allowPause, + videoTracks: make(map[livekit.TrackID]*Track), + state: streamAllocatorStateStable, + activeProbeClusterId: ccutils.ProbeClusterIdInvalid, + eventsQueue: utils.NewTypedOpsQueue[Event](utils.OpsQueueParams{ + Name: "stream-allocator", + MinSize: 64, + Logger: params.Logger, + }), + lastRTTTime: time.Now().Add(-cRTTPullInterval), + } + + s.prober = ccutils.NewProber(ccutils.ProberParams{ + Listener: s, + Logger: params.Logger, + }) + + s.params.BWE.SetBWEListener(s) + s.params.Pacer.SetPacerProbeObserverListener(s) + + return s +} + +func (s *StreamAllocator) Start() { + s.eventsQueue.Start() + go s.ping(s.pingGeneration.Inc(), cPingLong) +} + +func (s *StreamAllocator) Stop() { + if s.isStopped.Swap(true) { + return + } + + // wait for eventsQueue to be done + <-s.eventsQueue.Stop() + + s.maybeStopProbe() +} + +func (s *StreamAllocator) OnStreamStateChange(f func(update *StreamStateUpdate) error) { + s.onStreamStateChange = f +} + +func (s *StreamAllocator) SetSendSideBWEInterceptor(sendSideBWEInterceptor cc.BandwidthEstimator) { + if sendSideBWEInterceptor != nil { + sendSideBWEInterceptor.OnTargetBitrateChange(s.onTargetBitrateChange) + } + s.sendSideBWEInterceptor = sendSideBWEInterceptor +} + +type AddTrackParams struct { + Source livekit.TrackSource + Priority uint8 + IsMultiLayered bool + PublisherID livekit.ParticipantID +} + +func (s *StreamAllocator) AddTrack(downTrack *sfu.DownTrack, params AddTrackParams) { + if downTrack.Kind() != webrtc.RTPCodecTypeVideo { + return + } + + track := NewTrack(downTrack, params.Source, params.IsMultiLayered, params.PublisherID, s.params.Logger) + track.SetPriority(params.Priority) + + trackID := livekit.TrackID(downTrack.ID()) + s.videoTracksMu.Lock() + oldTrack := s.videoTracks[trackID] + s.videoTracks[trackID] = track + s.videoTracksMu.Unlock() + + if oldTrack != nil { + oldTrack.DownTrack().SetStreamAllocatorListener(nil) + } + + downTrack.SetStreamAllocatorListener(s) + downTrack.SetProbeClusterId(s.activeProbeClusterId) + + s.maybePostEventAllocateTrack(downTrack) +} + +func (s *StreamAllocator) RemoveTrack(downTrack *sfu.DownTrack) { + s.videoTracksMu.Lock() + if existing := s.videoTracks[livekit.TrackID(downTrack.ID())]; existing != nil && existing.DownTrack() == downTrack { + delete(s.videoTracks, livekit.TrackID(downTrack.ID())) + } + s.videoTracksMu.Unlock() + + // STREAM-ALLOCATOR-TODO: use any saved bandwidth to re-distribute + s.postEvent(Event{ + Signal: streamAllocatorSignalAdjustState, + }) +} + +func (s *StreamAllocator) SetTrackPriority(downTrack *sfu.DownTrack, priority uint8) { + s.videoTracksMu.Lock() + if track := s.videoTracks[livekit.TrackID(downTrack.ID())]; track != nil { + changed := track.SetPriority(priority) + if changed && !s.isAllocateAllPending { + // do a full allocation on a track priority change to keep it simple + s.isAllocateAllPending = true + s.postEvent(Event{ + Signal: streamAllocatorSignalAllocateAllTracks, + }) + } + } + s.videoTracksMu.Unlock() +} + +func (s *StreamAllocator) SetAllowPause(allowPause bool) { + s.postEvent(Event{ + Signal: streamAllocatorSignalSetAllowPause, + Data: allowPause, + }) +} + +func (s *StreamAllocator) SetChannelCapacity(channelCapacity int64) { + s.postEvent(Event{ + Signal: streamAllocatorSignalSetChannelCapacity, + Data: channelCapacity, + }) +} + +// called when a new REMB is received (receive side bandwidth estimation) +func (s *StreamAllocator) OnREMB(downTrack *sfu.DownTrack, remb *rtcp.ReceiverEstimatedMaximumBitrate) { + // + // Channel capacity is estimated at a peer connection level. All down tracks + // in the peer connection will end up calling this for a REMB report with + // the same estimated channel capacity. Use a tracking SSRC to lock onto to + // one report. As SSRCs can be dropped over time, update tracking SSRC as needed + // + // A couple of things to keep in mind + // - REMB reports could be sent gratuitously as a way of providing + // periodic feedback, i.e. even if the estimated capacity does not + // change, there could be REMB packets on the wire. Those gratuitous + // REMBs should not trigger anything bad. + // - As each down track will issue this callback for the same REMB packet + // from the wire, theoretically it is possible that one down track's + // callback from previous REMB comes after another down track's callback + // from the new REMB. REMBs could fire very quickly especially when + // the network is entering congestion. + // STREAM-ALLOCATOR-TODO-START + // Need to check if the same SSRC reports can somehow race, i.e. does pion send + // RTCP dispatch for same SSRC on different threads? If not, the tracking SSRC + // should prevent racing + // STREAM-ALLOCATOR-TODO-END + // + + // if there are no video tracks, ignore any straggler REMB + s.videoTracksMu.Lock() + if len(s.videoTracks) == 0 { + s.videoTracksMu.Unlock() + return + } + + downTrackSSRC := uint32(0) + downTrackSSRCRTX := uint32(0) + track := s.videoTracks[livekit.TrackID(downTrack.ID())] + if track != nil { + downTrackSSRC = track.DownTrack().SSRC() + downTrackSSRCRTX = track.DownTrack().SSRCRTX() + } + + found := false + for _, ssrc := range remb.SSRCs { + if ssrc == s.rembTrackingSSRC { + found = true + break + } + } + if !found { + if len(remb.SSRCs) == 0 { + s.params.Logger.Warnw("stream allocator: no SSRC to track REMB", nil) + s.videoTracksMu.Unlock() + return + } + + // try to lock to track which is sending this update + for _, ssrc := range remb.SSRCs { + if ssrc == 0 { + continue + } + + if ssrc == downTrackSSRC { + s.rembTrackingSSRC = downTrackSSRC + found = true + break + } + if ssrc == downTrackSSRCRTX { + s.rembTrackingSSRC = downTrackSSRCRTX + found = true + break + } + } + + if !found { + s.rembTrackingSSRC = remb.SSRCs[0] + } + } + + if s.rembTrackingSSRC == 0 || (s.rembTrackingSSRC != downTrackSSRC && s.rembTrackingSSRC != downTrackSSRCRTX) { + s.videoTracksMu.Unlock() + return + } + s.videoTracksMu.Unlock() + + s.postEvent(Event{ + Signal: streamAllocatorSignalEstimate, + Data: int64(remb.Bitrate), + }) +} + +// called when a new transport-cc feedback is received +func (s *StreamAllocator) OnTransportCCFeedback(downTrack *sfu.DownTrack, fb *rtcp.TransportLayerCC) { + s.postEvent(Event{ + Signal: streamAllocatorSignalFeedback, + Data: fb, + }) +} + +// called when target bitrate changes (send side bandwidth estimation) +func (s *StreamAllocator) onTargetBitrateChange(bitrate int) { + s.postEvent(Event{ + Signal: streamAllocatorSignalEstimate, + Data: int64(bitrate), + }) +} + +// called when congestion state changes (send side bandwidth estimation) +type congestionStateChangeData struct { + fromState bwe.CongestionState + toState bwe.CongestionState + estimatedAvailableChannelCapacity int64 +} + +// BWEListener implementation +func (s *StreamAllocator) OnCongestionStateChange(fromState bwe.CongestionState, toState bwe.CongestionState, estimatedAvailableChannelCapacity int64) { + s.postEvent(Event{ + Signal: streamAllocatorSignalCongestionStateChange, + Data: congestionStateChangeData{fromState, toState, estimatedAvailableChannelCapacity}, + }) +} + +// called when feeding track's layer availability changes +func (s *StreamAllocator) OnAvailableLayersChanged(downTrack *sfu.DownTrack) { + s.maybePostEventAllocateTrack(downTrack) +} + +// called when feeding track's bitrate measurement of any layer is available +func (s *StreamAllocator) OnBitrateAvailabilityChanged(downTrack *sfu.DownTrack) { + s.maybePostEventAllocateTrack(downTrack) +} + +// called when feeding track's max published spatial layer changes +func (s *StreamAllocator) OnMaxPublishedSpatialChanged(downTrack *sfu.DownTrack) { + s.maybePostEventAllocateTrack(downTrack) +} + +// called when feeding track's max published temporal layer changes +func (s *StreamAllocator) OnMaxPublishedTemporalChanged(downTrack *sfu.DownTrack) { + s.maybePostEventAllocateTrack(downTrack) +} + +// called when subscription settings changes (muting/unmuting of track) +func (s *StreamAllocator) OnSubscriptionChanged(downTrack *sfu.DownTrack) { + s.maybePostEventAllocateTrack(downTrack) +} + +// called when subscribed layer changes (limiting max layer) +func (s *StreamAllocator) OnSubscribedLayerChanged(downTrack *sfu.DownTrack, layer buffer.VideoLayer) { + shouldPost := false + s.videoTracksMu.Lock() + if track := s.videoTracks[livekit.TrackID(downTrack.ID())]; track != nil { + if track.SetMaxLayer(layer) && track.SetDirty(true) { + shouldPost = true + } + } + s.videoTracksMu.Unlock() + + if shouldPost { + s.postEvent(Event{ + Signal: streamAllocatorSignalAllocateTrack, + TrackID: livekit.TrackID(downTrack.ID()), + }) + } +} + +// called when forwarder resumes a track +func (s *StreamAllocator) OnResume(downTrack *sfu.DownTrack) { + s.postEvent(Event{ + Signal: streamAllocatorSignalResume, + TrackID: livekit.TrackID(downTrack.ID()), + }) +} + +// called when probe cluster changes +func (s *StreamAllocator) OnProbeClusterSwitch(pci ccutils.ProbeClusterInfo) { + s.postEvent(Event{ + Signal: streamAllocatorSignalProbeClusterSwitch, + Data: pci, + }) +} + +// called when prober wants to send packet(s) +func (s *StreamAllocator) OnSendProbe(bytesToSend int) { + s.postEvent(Event{ + Signal: streamAllocatorSignalSendProbe, + Data: bytesToSend, + }) +} + +// called when pacer probe observer observes a cluster completion +func (s *StreamAllocator) OnPacerProbeObserverClusterComplete(probeClusterId ccutils.ProbeClusterId) { + s.postEvent(Event{ + Signal: streamAllocatorSignalPacerProbeObserverClusterComplete, + Data: probeClusterId, + }) +} + +// called to check if track should participate in BWE +func (s *StreamAllocator) IsBWEEnabled(downTrack *sfu.DownTrack) bool { + if !s.params.Config.DisableEstimationUnmanagedTracks { + return true + } + + s.videoTracksMu.Lock() + defer s.videoTracksMu.Unlock() + + if track := s.videoTracks[livekit.TrackID(downTrack.ID())]; track != nil { + return track.IsManaged() + } + + return true +} + +func (s *StreamAllocator) BWEType() bwe.BWEType { + return s.params.BWE.Type() +} + +// called to check if track subscription mute can be applied +func (s *StreamAllocator) IsSubscribeMutable(downTrack *sfu.DownTrack) bool { + s.videoTracksMu.Lock() + defer s.videoTracksMu.Unlock() + + if track := s.videoTracks[livekit.TrackID(downTrack.ID())]; track != nil { + return track.IsSubscribeMutable() + } + + return true +} + +func (s *StreamAllocator) maybePostEventAllocateTrack(downTrack *sfu.DownTrack) { + shouldPost := false + s.videoTracksMu.Lock() + if track := s.videoTracks[livekit.TrackID(downTrack.ID())]; track != nil { + shouldPost = track.SetDirty(true) + } + s.videoTracksMu.Unlock() + + if shouldPost { + s.postEvent(Event{ + Signal: streamAllocatorSignalAllocateTrack, + TrackID: livekit.TrackID(downTrack.ID()), + }) + } +} + +func (s *StreamAllocator) ping(pingGeneration uint32, interval time.Duration) { + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + <-ticker.C + if s.isStopped.Load() || (pingGeneration != s.pingGeneration.Load()) { + return + } + + s.postEvent(Event{ + Signal: streamAllocatorSignalPeriodicPing, + }) + } +} + +func (s *StreamAllocator) postEvent(event Event) { + event.StreamAllocator = s + s.eventsQueue.Enqueue(func(event Event) { + switch event.Signal { + case streamAllocatorSignalAllocateTrack: + event.handleSignalAllocateTrack(event) + case streamAllocatorSignalAllocateAllTracks: + event.handleSignalAllocateAllTracks(event) + case streamAllocatorSignalAdjustState: + event.handleSignalAdjustState(event) + case streamAllocatorSignalEstimate: + event.handleSignalEstimate(event) + case streamAllocatorSignalFeedback: + event.handleSignalFeedback(event) + case streamAllocatorSignalPeriodicPing: + event.handleSignalPeriodicPing(event) + case streamAllocatorSignalProbeClusterSwitch: + event.handleSignalProbeClusterSwitch(event) + case streamAllocatorSignalSendProbe: + event.handleSignalSendProbe(event) + case streamAllocatorSignalPacerProbeObserverClusterComplete: + event.handleSignalPacerProbeObserverClusterComplete(event) + case streamAllocatorSignalResume: + event.handleSignalResume(event) + case streamAllocatorSignalSetAllowPause: + event.handleSignalSetAllowPause(event) + case streamAllocatorSignalSetChannelCapacity: + event.handleSignalSetChannelCapacity(event) + case streamAllocatorSignalCongestionStateChange: + s.handleSignalCongestionStateChange(event) + } + }, event) +} + +func (s *StreamAllocator) handleSignalAllocateTrack(event Event) { + s.videoTracksMu.Lock() + track := s.videoTracks[event.TrackID] + if track != nil { + track.SetDirty(false) + } + s.videoTracksMu.Unlock() + + if track != nil { + s.allocateTrack(track) + } +} + +func (s *StreamAllocator) handleSignalAllocateAllTracks(Event) { + s.videoTracksMu.Lock() + s.isAllocateAllPending = false + s.videoTracksMu.Unlock() + + if s.state == streamAllocatorStateDeficient { + s.allocateAllTracks() + } +} + +func (s *StreamAllocator) handleSignalAdjustState(Event) { + s.adjustState() +} + +func (s *StreamAllocator) handleSignalEstimate(event Event) { + receivedEstimate := event.Data.(int64) + + // always update NACKs + packetDelta, repeatedNackDelta := s.getNackDelta() + + s.params.BWE.HandleREMB( + receivedEstimate, + s.getExpectedBandwidthUsage(), + packetDelta, + repeatedNackDelta, + ) +} + +func (s *StreamAllocator) handleSignalFeedback(event Event) { + fb := event.Data.(*rtcp.TransportLayerCC) + if s.sendSideBWEInterceptor != nil { + s.sendSideBWEInterceptor.WriteRTCP([]rtcp.Packet{fb}, nil) + } + + s.params.BWE.HandleTWCCFeedback(fb) +} + +func (s *StreamAllocator) handleSignalPeriodicPing(Event) { + // if pause is allowed, there may be no packets sent and BWE could be in congested state, + // reset BWE if that persists for a while + if s.allowPause && s.state == streamAllocatorStateDeficient && s.params.BWE.CongestionState() != bwe.CongestionStateNone && s.params.Pacer.TimeSinceLastSentPacket() > s.params.Config.PausedMinWait { + s.params.Logger.Infow("stream allocator: resetting bwe to enable probing") + s.maybeStopProbe() + s.params.BWE.Reset() + + // as BWE is reset, there is no finalizing for active cluster, so reset active cluster id + s.activeProbeClusterId = ccutils.ProbeClusterIdInvalid + } + + if s.activeProbeClusterId != ccutils.ProbeClusterIdInvalid { + if !s.activeProbeCongesting && !s.activeProbeGoalReached && s.params.BWE.ProbeClusterIsGoalReached() { + s.params.Logger.Debugw( + "stream allocator: probe goal reached", + "activeProbeClusterId", s.activeProbeClusterId, + ) + s.activeProbeGoalReached = true + s.maybeStopProbe() + } + + // finalize any probe that may have finished/aborted + if probeSignal, channelCapacity, isFinalized := s.params.BWE.ProbeClusterFinalize(); isFinalized { + s.params.Logger.Debugw( + "stream allocator: probe result", + "activeProbeClusterId", s.activeProbeClusterId, + "probeSignal", probeSignal, + "channelCapacity", channelCapacity, + ) + + s.activeProbeClusterId = ccutils.ProbeClusterIdInvalid + + if probeSignal != ccutils.ProbeSignalCongesting { + if channelCapacity > s.committedChannelCapacity { + s.committedChannelCapacity = channelCapacity + } + + s.maybeBoostDeficientTracks() + } + } + } + + // probe if necessary and timing is right + if s.state == streamAllocatorStateDeficient { + s.maybeProbe() + } + + if time.Since(s.lastRTTTime) > cRTTPullInterval { + s.lastRTTTime = time.Now() + + if s.params.RTTGetter != nil { + if rtt, ok := s.params.RTTGetter(); ok { + s.params.BWE.UpdateRTT(rtt) + } + } + } +} + +func (s *StreamAllocator) handleSignalProbeClusterSwitch(event Event) { + pci := event.Data.(ccutils.ProbeClusterInfo) + s.activeProbeClusterId = pci.Id + s.activeProbeGoalReached = false + s.activeProbeCongesting = false + + s.params.BWE.ProbeClusterStarting(pci) + + s.params.Pacer.StartProbeCluster(pci) + + for _, t := range s.getTracks() { + t.DownTrack().SetProbeClusterId(pci.Id) + } +} + +func (s *StreamAllocator) handleSignalSendProbe(event Event) { + bytesToSend := event.Data.(int) + if bytesToSend <= 0 { + return + } + + bytesSent := 0 + for _, track := range s.getTracks() { + sent := track.WriteProbePackets(bytesToSend) + bytesSent += sent + bytesToSend -= sent + if bytesToSend <= 0 { + break + } + } + + s.prober.ProbesSent(bytesSent) +} + +func (s *StreamAllocator) handleSignalPacerProbeObserverClusterComplete(event Event) { + probeClusterId, _ := event.Data.(ccutils.ProbeClusterId) + pci := s.params.Pacer.EndProbeCluster(probeClusterId) + + for _, t := range s.getTracks() { + t.DownTrack().SwapProbeClusterId(pci.Id, ccutils.ProbeClusterIdInvalid) + } + + s.params.BWE.ProbeClusterDone(pci) + s.prober.ClusterDone(pci) +} + +func (s *StreamAllocator) handleSignalResume(event Event) { + s.videoTracksMu.Lock() + track := s.videoTracks[event.TrackID] + updated := track != nil && track.SetStreamState(StreamStateActive) + s.videoTracksMu.Unlock() + + if updated { + update := NewStreamStateUpdate() + update.HandleStreamingChange(track, StreamStateActive) + s.maybeSendUpdate(update) + } +} + +func (s *StreamAllocator) handleSignalSetAllowPause(event Event) { + s.allowPause = event.Data.(bool) +} + +func (s *StreamAllocator) handleSignalSetChannelCapacity(event Event) { + s.overriddenChannelCapacity = event.Data.(int64) + if s.overriddenChannelCapacity > 0 { + s.params.Logger.Infow("allocating on override channel capacity", "override", s.overriddenChannelCapacity) + s.allocateAllTracks() + } else { + s.params.Logger.Infow("clearing override channel capacity") + } +} + +func (s *StreamAllocator) handleSignalCongestionStateChange(event Event) { + cscd := event.Data.(congestionStateChangeData) + if cscd.toState != bwe.CongestionStateNone { + // end/abort any running probe if channel is not clear + s.maybeStopProbe() + } + + // some tracks may have been held at sub-optimal allocation + // during early warning hold (if there was one) + if isHoldableCongestionState(cscd.fromState) && cscd.toState == bwe.CongestionStateNone && s.state == streamAllocatorStateStable { + update := NewStreamStateUpdate() + for _, track := range s.getTracks() { + allocation := track.AllocateOptimal(cFlagAllowOvershootWhileOptimal, false) + updateStreamStateChange(track, allocation, update) + } + s.maybeSendUpdate(update) + } + + if cscd.toState == bwe.CongestionStateCongested { + if s.activeProbeClusterId != ccutils.ProbeClusterIdInvalid { + if !s.activeProbeCongesting { + s.activeProbeCongesting = true + s.params.Logger.Debugw( + "stream allocator: channel congestion detected, not updating channel capacity in active probe", + "old(bps)", s.committedChannelCapacity, + "new(bps)", cscd.estimatedAvailableChannelCapacity, + "expectedUsage(bps)", s.getExpectedBandwidthUsage(), + ) + } + } else { + s.params.Logger.Debugw( + "stream allocator: channel congestion detected, updating channel capacity", + "old(bps)", s.committedChannelCapacity, + "new(bps)", cscd.estimatedAvailableChannelCapacity, + "expectedUsage(bps)", s.getExpectedBandwidthUsage(), + ) + s.committedChannelCapacity = cscd.estimatedAvailableChannelCapacity + + s.allocateAllTracks() + } + } +} + +func (s *StreamAllocator) setState(state streamAllocatorState) { + if s.state == state { + return + } + + s.params.Logger.Infow("stream allocator: state change", "from", s.state, "to", state) + s.state = state + + // restart everything when state is STABLE + if state == streamAllocatorStateStable { + s.maybeStopProbe() + + s.params.BWE.Reset() + + s.activeProbeClusterId = ccutils.ProbeClusterIdInvalid + go s.ping(s.pingGeneration.Inc(), cPingLong) + } else { + go s.ping(s.pingGeneration.Inc(), cPingShort) + } +} + +func (s *StreamAllocator) adjustState() { + for _, track := range s.getTracks() { + if track.IsDeficient() { + s.setState(streamAllocatorStateDeficient) + return + } + } + + s.setState(streamAllocatorStateStable) +} + +func (s *StreamAllocator) allocateTrack(track *Track) { + // end/abort any probe that may be running when a track specific change needs allocation + s.maybeStopProbe() + + // if not deficient, free pass allocate track + bweCongestionState := s.params.BWE.CongestionState() + if !s.enabled || (s.state == streamAllocatorStateStable && !isDeficientCongestionState(bweCongestionState)) || !track.IsManaged() { + update := NewStreamStateUpdate() + allocation := track.AllocateOptimal(cFlagAllowOvershootWhileOptimal, isHoldableCongestionState(bweCongestionState)) + updateStreamStateChange(track, allocation, update) + s.maybeSendUpdate(update) + return + } + + // + // In DEFICIENT state, + // Two possibilities + // 1. Available headroom is enough to accommodate track that needs change. + // Note that the track could be muted, hence stopping. + // 2. Have to steal bits from other tracks currently streaming. + // + // For both cases, do + // a. Find cooperative transition from track that needs allocation. + // b. If track is giving back bits, apply the transition and use bits given + // back to boost any deficient track(s). + // + // If track needs more bits, i.e. upward transition (may need resume or higher layer subscription), + // a. Try to allocate using existing headroom. This can be tried to get the best + // possible fit for the available headroom. + // b. If there is not enough headroom to allocate anything, ask for best offer from + // other tracks that are currently streaming and try to use it. This is done only if the + // track needing change is not currently streaming, i. e. it has to be resumed. + // + track.ProvisionalAllocatePrepare() + transition := track.ProvisionalAllocateGetCooperativeTransition(cFlagAllowOvershootWhileDeficient) + + // downgrade, giving back bits + if transition.From.GreaterThan(transition.To) { + allocation := track.ProvisionalAllocateCommit() + + update := NewStreamStateUpdate() + updateStreamStateChange(track, allocation, update) + s.maybeSendUpdate(update) + + s.adjustState() + + // Use the bits given back to boost deficient track(s). + // Note layer downgrade may actually have positive delta (i.e. consume more bits) + // because of when the measurement is done. But, only available headroom after + // applying the transition will be used to boost deficient track(s). + s.maybeBoostDeficientTracks() + return + } + + // a no-op transition + if transition.From == transition.To { + return + } + + // this track is currently not streaming and needs bits to start OR streaming at some layer and wants more bits. + // NOTE: With co-operative transition, tracks should not be asking for more if already streaming, but handle that case any way. + // first try an allocation using available headroom, current consumption of this track is discounted to calculate headroom. + availableChannelCapacity := s.getAvailableHeadroomWithoutTracks(false, []*Track{track}) + if availableChannelCapacity > 0 { + track.ProvisionalAllocateReset() // to reset allocation from co-operative transition above and try fresh + + bestLayer := buffer.InvalidLayer + + alloc_loop: + for spatial := int32(0); spatial <= buffer.DefaultMaxLayerSpatial; spatial++ { + for temporal := int32(0); temporal <= buffer.DefaultMaxLayerTemporal; temporal++ { + layer := buffer.VideoLayer{ + Spatial: spatial, + Temporal: temporal, + } + + isCandidate, usedChannelCapacity := track.ProvisionalAllocate( + availableChannelCapacity, + layer, + s.allowPause, + cFlagAllowOvershootWhileDeficient, + ) + if availableChannelCapacity < usedChannelCapacity { + break alloc_loop + } + + if isCandidate { + bestLayer = layer + } + } + } + + if bestLayer.IsValid() { + if bestLayer.GreaterThan(transition.From) { + // found layer that can fit in available headroom, take it if it is better than existing + update := NewStreamStateUpdate() + allocation := track.ProvisionalAllocateCommit() + updateStreamStateChange(track, allocation, update) + s.maybeSendUpdate(update) + } + + s.adjustState() + return + } + + track.ProvisionalAllocateReset() + transition = track.ProvisionalAllocateGetCooperativeTransition(cFlagAllowOvershootWhileDeficient) // get transition again to reset above allocation attempt using available headroom + } + + // if there is not enough headroom, try to redistribute starting with tracks that are closest to their desired. + bandwidthAcquired := int64(0) + var contributingTracks []*Track + + minDistanceSorted := s.getMinDistanceSorted(track) + for _, t := range minDistanceSorted { + t.ProvisionalAllocatePrepare() + } + + for _, t := range minDistanceSorted { + tx := t.ProvisionalAllocateGetBestWeightedTransition() + if tx.BandwidthDelta < 0 { + contributingTracks = append(contributingTracks, t) + + bandwidthAcquired += -tx.BandwidthDelta + if bandwidthAcquired >= transition.BandwidthDelta { + break + } + } + } + + update := NewStreamStateUpdate() + if bandwidthAcquired >= transition.BandwidthDelta { + // commit the tracks that contributed + for _, t := range contributingTracks { + allocation := t.ProvisionalAllocateCommit() + updateStreamStateChange(t, allocation, update) + } + + // STREAM-ALLOCATOR-TODO if got too much extra, can potentially give it to some deficient track + } + + // commit the track that needs change if enough could be acquired or pause not allowed + if !s.allowPause || bandwidthAcquired >= transition.BandwidthDelta { + allocation := track.ProvisionalAllocateCommit() + updateStreamStateChange(track, allocation, update) + } else { + // explicitly pause to ensure stream state update happens if a track coming out of mute cannot be allocated + allocation := track.Pause() + updateStreamStateChange(track, allocation, update) + } + + s.maybeSendUpdate(update) + + s.adjustState() +} + +func (s *StreamAllocator) maybeStopProbe() { + if s.activeProbeClusterId == ccutils.ProbeClusterIdInvalid { + return + } + + pci := s.params.Pacer.EndProbeCluster(s.activeProbeClusterId) + + for _, t := range s.getTracks() { + t.DownTrack().SwapProbeClusterId(pci.Id, ccutils.ProbeClusterIdInvalid) + } + + s.params.BWE.ProbeClusterDone(pci) + s.prober.Reset(pci) +} + +func (s *StreamAllocator) maybeBoostDeficientTracks() { + availableChannelCapacity := s.getAvailableHeadroom(false) + if availableChannelCapacity <= 0 { + s.params.Logger.Debugw( + "stream allocator: no available headroom to boost deficient tracks", + "committedChannelCapacity", s.committedChannelCapacity, + "availableChannelCapacity", availableChannelCapacity, + "expectedBandwidthUsage", s.getExpectedBandwidthUsage(), + ) + return + } + + update := NewStreamStateUpdate() + + sortedTracks := s.getMaxDistanceSortedDeficient() +boost_loop: + for { + for idx, track := range sortedTracks { + allocation, boosted := track.AllocateNextHigher(availableChannelCapacity, cFlagAllowOvershootInCatchup) + if !boosted { + if idx == len(sortedTracks)-1 { + // all tracks tried + break boost_loop + } + continue + } + + updateStreamStateChange(track, allocation, update) + + availableChannelCapacity -= allocation.BandwidthDelta + if availableChannelCapacity <= 0 { + break boost_loop + } + + break // sort again below as the track that was just boosted could still be farthest from its desired + } + sortedTracks = s.getMaxDistanceSortedDeficient() + if len(sortedTracks) == 0 { + break // nothing available to boost + } + } + + s.maybeSendUpdate(update) + + s.adjustState() +} + +func (s *StreamAllocator) allocateAllTracks() { + if !s.enabled { + // nothing else to do when disabled + return + } + + // + // Goals: + // 1. Stream as many tracks as possible, i.e. no pauses. + // 2. Try to give fair allocation to all track. + // + // Start with the lowest layer and give each track a chance at that layer and keep going up. + // As long as there is enough bandwidth for tracks to stream at the lowest layer, the first goal is achieved. + // + // Tracks that have higher subscribed layer can use any additional available bandwidth. This tried to achieve the second goal. + // + // If there is not enough bandwidth even for the lowest layer, tracks at lower priorities will be paused. + // + update := NewStreamStateUpdate() + + availableChannelCapacity := s.getAvailableChannelCapacity(true) + + // + // This pass is to find out if there is any leftover channel capacity after allocating exempt tracks. + // Exempt tracks are given optimal allocation (i. e. no bandwidth constraint) so that they do not fail allocation. + // + videoTracks := s.getTracks() + for _, track := range videoTracks { + if track.IsManaged() { + continue + } + + allocation := track.AllocateOptimal(cFlagAllowOvershootExemptTrackWhileDeficient, false) + updateStreamStateChange(track, allocation, update) + + // STREAM-ALLOCATOR-TODO: optimistic allocation before bitrate is available will return 0. How to account for that? + if !s.params.Config.DisableEstimationUnmanagedTracks { + availableChannelCapacity -= allocation.BandwidthRequested + } + } + + if availableChannelCapacity < 0 { + availableChannelCapacity = 0 + } + if availableChannelCapacity == 0 && s.allowPause { + // nothing left for managed tracks, pause them all + for _, track := range videoTracks { + if !track.IsManaged() { + continue + } + + allocation := track.Pause() + updateStreamStateChange(track, allocation, update) + } + } else { + sorted := s.getSorted() + for _, track := range sorted { + track.ProvisionalAllocatePrepare() + } + + for spatial := int32(0); spatial <= buffer.DefaultMaxLayerSpatial; spatial++ { + for temporal := int32(0); temporal <= buffer.DefaultMaxLayerTemporal; temporal++ { + layer := buffer.VideoLayer{ + Spatial: spatial, + Temporal: temporal, + } + + for _, track := range sorted { + _, usedChannelCapacity := track.ProvisionalAllocate(availableChannelCapacity, layer, s.allowPause, cFlagAllowOvershootWhileDeficient) + availableChannelCapacity -= usedChannelCapacity + if availableChannelCapacity < 0 { + availableChannelCapacity = 0 + } + } + } + } + + for _, track := range sorted { + allocation := track.ProvisionalAllocateCommit() + updateStreamStateChange(track, allocation, update) + } + } + + s.maybeSendUpdate(update) + + s.adjustState() +} + +func (s *StreamAllocator) maybeSendUpdate(update *StreamStateUpdate) { + if update.Empty() { + return + } + + // logging individual changes to make it easier for logging systems + for _, streamState := range update.StreamStates { + s.params.Logger.Debugw("streamed tracks changed", + "trackID", streamState.TrackID, + "state", streamState.State, + ) + } + if s.onStreamStateChange != nil { + err := s.onStreamStateChange(update) + if err != nil { + s.params.Logger.Errorw("could not send streamed tracks update", err) + } + } +} + +func (s *StreamAllocator) getAvailableChannelCapacity(allowOverride bool) int64 { + availableChannelCapacity := s.committedChannelCapacity + if s.params.Config.MinChannelCapacity > availableChannelCapacity { + availableChannelCapacity = s.params.Config.MinChannelCapacity + s.params.Logger.Debugw( + "stream allocator: overriding channel capacity with min channel capacity", + "actual", s.committedChannelCapacity, + "override", availableChannelCapacity, + ) + } + if allowOverride && s.overriddenChannelCapacity > 0 { + availableChannelCapacity = s.overriddenChannelCapacity + s.params.Logger.Debugw( + "stream allocator: overriding channel capacity", + "actual", s.committedChannelCapacity, + "override", availableChannelCapacity, + ) + } + + return availableChannelCapacity +} + +func (s *StreamAllocator) getExpectedBandwidthUsage() int64 { + expected := int64(0) + for _, track := range s.getTracks() { + expected += track.BandwidthRequested() + } + + return expected +} + +func (s *StreamAllocator) getExpectedBandwidthUsageWithoutTracks(filteredTracks []*Track) int64 { + expected := int64(0) + for _, track := range s.getTracks() { + filtered := false + for _, ft := range filteredTracks { + if ft == track { + filtered = true + break + } + } + if !filtered { + expected += track.BandwidthRequested() + } + } + + return expected +} + +func (s *StreamAllocator) getAvailableHeadroom(allowOverride bool) int64 { + return s.getAvailableChannelCapacity(allowOverride) - s.getExpectedBandwidthUsage() +} + +func (s *StreamAllocator) getAvailableHeadroomWithoutTracks(allowOverride bool, filteredTracks []*Track) int64 { + return s.getAvailableChannelCapacity(allowOverride) - s.getExpectedBandwidthUsageWithoutTracks(filteredTracks) +} + +func (s *StreamAllocator) getNackDelta() (uint32, uint32) { + aggPacketDelta := uint32(0) + aggRepeatedNackDelta := uint32(0) + for _, track := range s.getTracks() { + packetDelta, nackDelta := track.GetNackDelta() + aggPacketDelta += packetDelta + aggRepeatedNackDelta += nackDelta + } + + return aggPacketDelta, aggRepeatedNackDelta +} + +func (s *StreamAllocator) maybeProbe() { + if s.overriddenChannelCapacity > 0 { + // do not probe if channel capacity is overridden + return + } + + if !s.params.BWE.CanProbe() { + return + } + + switch s.params.Config.ProbeMode { + case ProbeModeMedia: + s.maybeProbeWithMedia() + s.adjustState() + case ProbeModePadding: + s.maybeProbeWithPadding() + } +} + +func (s *StreamAllocator) maybeProbeWithMedia() { + // boost deficient track farthest from desired layer + for _, track := range s.getMaxDistanceSortedDeficient() { + allocation, boosted := track.AllocateNextHigher(cChannelCapacityInfinity, cFlagAllowOvershootInBoost) + if !boosted { + continue + } + + update := NewStreamStateUpdate() + updateStreamStateChange(track, allocation, update) + s.maybeSendUpdate(update) + + s.params.BWE.Reset() + break + } +} + +func (s *StreamAllocator) maybeProbeWithPadding() { + // use deficient track farthest from desired layer to find how much to probe + for _, track := range s.getMaxDistanceSortedDeficient() { + transition, available := track.GetNextHigherTransition(cFlagAllowOvershootInProbe) + if !available || transition.BandwidthDelta < 0 { + continue + } + + // overshoot a bit to account for noise (in measurement/estimate etc) + desiredIncreaseBps := (transition.BandwidthDelta * s.params.Config.ProbeOveragePct) / 100 + if desiredIncreaseBps < s.params.Config.ProbeMinBps { + desiredIncreaseBps = s.params.Config.ProbeMinBps + } + expectedBandwidthUsage := s.getExpectedBandwidthUsage() + pci := s.prober.AddCluster( + ccutils.ProbeClusterModeUniform, + ccutils.ProbeClusterGoal{ + AvailableBandwidthBps: int(s.committedChannelCapacity), + ExpectedUsageBps: int(expectedBandwidthUsage), + DesiredBps: int(expectedBandwidthUsage + desiredIncreaseBps), + Duration: s.params.BWE.ProbeDuration(), + }, + ) + s.params.Logger.Debugw( + "stream allocator: adding probe", + "probeClusterInfo", pci, + ) + break + } +} + +func (s *StreamAllocator) getTracks() []*Track { + s.videoTracksMu.RLock() + tracks := make([]*Track, 0, len(s.videoTracks)) + for _, track := range s.videoTracks { + tracks = append(tracks, track) + } + s.videoTracksMu.RUnlock() + + return tracks +} + +func (s *StreamAllocator) getSorted() TrackSorter { + s.videoTracksMu.RLock() + var trackSorter TrackSorter + for _, track := range s.videoTracks { + if !track.IsManaged() { + continue + } + + trackSorter = append(trackSorter, track) + } + s.videoTracksMu.RUnlock() + + sort.Sort(trackSorter) + + return trackSorter +} + +func (s *StreamAllocator) getMinDistanceSorted(exclude *Track) MinDistanceSorter { + s.videoTracksMu.RLock() + var minDistanceSorter MinDistanceSorter + for _, track := range s.videoTracks { + if !track.IsManaged() || track == exclude { + continue + } + + minDistanceSorter = append(minDistanceSorter, track) + } + s.videoTracksMu.RUnlock() + + sort.Sort(minDistanceSorter) + + return minDistanceSorter +} + +func (s *StreamAllocator) getMaxDistanceSortedDeficient() MaxDistanceSorter { + s.videoTracksMu.RLock() + var maxDistanceSorter MaxDistanceSorter + for _, track := range s.videoTracks { + if !track.IsManaged() || !track.IsDeficient() { + continue + } + + maxDistanceSorter = append(maxDistanceSorter, track) + } + s.videoTracksMu.RUnlock() + + sort.Sort(maxDistanceSorter) + + return maxDistanceSorter +} + +// ------------------------------------------------ + +func updateStreamStateChange(track *Track, allocation sfu.VideoAllocation, update *StreamStateUpdate) { + updated := false + streamState := StreamStateInactive + switch allocation.PauseReason { + case sfu.VideoPauseReasonMuted: + fallthrough + + case sfu.VideoPauseReasonPubMuted: + streamState = StreamStateInactive + updated = track.SetStreamState(streamState) + + case sfu.VideoPauseReasonBandwidth: + streamState = StreamStatePaused + updated = track.SetStreamState(streamState) + } + + if updated { + update.HandleStreamingChange(track, streamState) + } +} + +func isHoldableCongestionState(bweCongestionState bwe.CongestionState) bool { + return bweCongestionState == bwe.CongestionStateEarlyWarning +} + +func isDeficientCongestionState(bweCongestionState bwe.CongestionState) bool { + return bweCongestionState == bwe.CongestionStateCongested +} + +// ------------------------------------------------ diff --git a/livekit/pkg/sfu/streamallocator/streamstateupdate.go b/livekit/pkg/sfu/streamallocator/streamstateupdate.go new file mode 100644 index 0000000..53156de --- /dev/null +++ b/livekit/pkg/sfu/streamallocator/streamstateupdate.go @@ -0,0 +1,85 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package streamallocator + +import ( + "fmt" + + "github.com/livekit/protocol/livekit" +) + +// ------------------------------------------------ + +type StreamState int + +const ( + StreamStateInactive StreamState = iota + StreamStateActive + StreamStatePaused +) + +func (s StreamState) String() string { + switch s { + case StreamStateInactive: + return "INACTIVE" + case StreamStateActive: + return "ACTIVE" + case StreamStatePaused: + return "PAUSED" + default: + return fmt.Sprintf("UNKNOWN: %d", int(s)) + } +} + +// ------------------------------------------------ + +type StreamStateInfo struct { + ParticipantID livekit.ParticipantID + TrackID livekit.TrackID + State StreamState +} + +type StreamStateUpdate struct { + StreamStates []*StreamStateInfo +} + +func NewStreamStateUpdate() *StreamStateUpdate { + return &StreamStateUpdate{} +} + +func (s *StreamStateUpdate) HandleStreamingChange(track *Track, streamState StreamState) { + switch streamState { + case StreamStateInactive: + // inactive is not a notification, could get into this state because of mute + case StreamStateActive: + s.StreamStates = append(s.StreamStates, &StreamStateInfo{ + ParticipantID: track.PublisherID(), + TrackID: track.ID(), + State: StreamStateActive, + }) + case StreamStatePaused: + s.StreamStates = append(s.StreamStates, &StreamStateInfo{ + ParticipantID: track.PublisherID(), + TrackID: track.ID(), + State: StreamStatePaused, + }) + } +} + +func (s *StreamStateUpdate) Empty() bool { + return len(s.StreamStates) == 0 +} + +// ------------------------------------------------ diff --git a/livekit/pkg/sfu/streamallocator/track.go b/livekit/pkg/sfu/streamallocator/track.go new file mode 100644 index 0000000..a094aa5 --- /dev/null +++ b/livekit/pkg/sfu/streamallocator/track.go @@ -0,0 +1,281 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package streamallocator + +import ( + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + + "github.com/livekit/livekit-server/pkg/sfu" + "github.com/livekit/livekit-server/pkg/sfu/buffer" +) + +type Track struct { + downTrack *sfu.DownTrack + source livekit.TrackSource + isMultiLayered bool + priority uint8 + publisherID livekit.ParticipantID + logger logger.Logger + + maxLayer buffer.VideoLayer + + totalPackets uint32 + totalRepeatedNacks uint32 + + isDirty bool + + streamState StreamState +} + +func NewTrack( + downTrack *sfu.DownTrack, + source livekit.TrackSource, + isMultiLayered bool, + publisherID livekit.ParticipantID, + logger logger.Logger, +) *Track { + t := &Track{ + downTrack: downTrack, + source: source, + isMultiLayered: isMultiLayered, + publisherID: publisherID, + logger: logger, + streamState: StreamStateInactive, + } + t.SetPriority(0) + t.SetMaxLayer(downTrack.MaxLayer()) + + return t +} + +func (t *Track) SetDirty(isDirty bool) bool { + if t.isDirty == isDirty { + return false + } + + t.isDirty = isDirty + return true +} + +func (t *Track) SetStreamState(streamState StreamState) bool { + if t.streamState == streamState { + return false + } + + t.streamState = streamState + return true +} + +func (t *Track) IsSubscribeMutable() bool { + return t.streamState != StreamStatePaused +} + +func (t *Track) SetPriority(priority uint8) bool { + if priority == 0 { + switch t.source { + case livekit.TrackSource_SCREEN_SHARE: + priority = cPriorityDefaultScreenshare + default: + priority = cPriorityDefaultVideo + } + } + + if t.priority == priority { + return false + } + + t.priority = priority + return true +} + +func (t *Track) Priority() uint8 { + return t.priority +} + +func (t *Track) DownTrack() *sfu.DownTrack { + return t.downTrack +} + +func (t *Track) IsManaged() bool { + return t.source != livekit.TrackSource_SCREEN_SHARE || t.isMultiLayered +} + +func (t *Track) ID() livekit.TrackID { + return livekit.TrackID(t.downTrack.ID()) +} + +func (t *Track) PublisherID() livekit.ParticipantID { + return t.publisherID +} + +func (t *Track) SetMaxLayer(layer buffer.VideoLayer) bool { + if t.maxLayer == layer { + return false + } + + t.maxLayer = layer + return true +} + +func (t *Track) WritePaddingRTP(bytesToSend int) int { + return t.downTrack.WritePaddingRTP(bytesToSend, false, false) +} + +func (t *Track) WriteProbePackets(bytesToSend int) int { + return t.downTrack.WriteProbePackets(bytesToSend, false) +} + +func (t *Track) AllocateOptimal(allowOvershoot bool, hold bool) sfu.VideoAllocation { + return t.downTrack.AllocateOptimal(allowOvershoot, hold) +} + +func (t *Track) ProvisionalAllocatePrepare() { + t.downTrack.ProvisionalAllocatePrepare() +} + +func (t *Track) ProvisionalAllocateReset() { + t.downTrack.ProvisionalAllocateReset() +} + +func (t *Track) ProvisionalAllocate(availableChannelCapacity int64, layer buffer.VideoLayer, allowPause bool, allowOvershoot bool) (bool, int64) { + return t.downTrack.ProvisionalAllocate(availableChannelCapacity, layer, allowPause, allowOvershoot) +} + +func (t *Track) ProvisionalAllocateGetCooperativeTransition(allowOvershoot bool) sfu.VideoTransition { + return t.downTrack.ProvisionalAllocateGetCooperativeTransition(allowOvershoot) +} + +func (t *Track) ProvisionalAllocateGetBestWeightedTransition() sfu.VideoTransition { + return t.downTrack.ProvisionalAllocateGetBestWeightedTransition() +} + +func (t *Track) ProvisionalAllocateCommit() sfu.VideoAllocation { + return t.downTrack.ProvisionalAllocateCommit() +} + +func (t *Track) AllocateNextHigher(availableChannelCapacity int64, allowOvershoot bool) (sfu.VideoAllocation, bool) { + return t.downTrack.AllocateNextHigher(availableChannelCapacity, allowOvershoot) +} + +func (t *Track) GetNextHigherTransition(allowOvershoot bool) (sfu.VideoTransition, bool) { + return t.downTrack.GetNextHigherTransition(allowOvershoot) +} + +func (t *Track) Pause() sfu.VideoAllocation { + return t.downTrack.Pause() +} + +func (t *Track) IsDeficient() bool { + return t.downTrack.IsDeficient() +} + +func (t *Track) BandwidthRequested() int64 { + return t.downTrack.BandwidthRequested() +} + +func (t *Track) DistanceToDesired() float64 { + return t.downTrack.DistanceToDesired() +} + +func (t *Track) GetNackDelta() (uint32, uint32) { + totalPackets, totalRepeatedNacks := t.downTrack.GetNackStats() + + packetDelta := totalPackets - t.totalPackets + t.totalPackets = totalPackets + + nackDelta := totalRepeatedNacks - t.totalRepeatedNacks + t.totalRepeatedNacks = totalRepeatedNacks + + return packetDelta, nackDelta +} + +// ------------------------------------------------ + +type TrackSorter []*Track + +func (t TrackSorter) Len() int { + return len(t) +} + +func (t TrackSorter) Swap(i, j int) { + t[i], t[j] = t[j], t[i] +} + +func (t TrackSorter) Less(i, j int) bool { + // + // TrackSorter is used to allocate layer-by-layer. + // So, higher priority track should come earlier so that it gets an earlier shot at each layer + // + if t[i].priority != t[j].priority { + return t[i].priority > t[j].priority + } + + if t[i].maxLayer.Spatial != t[j].maxLayer.Spatial { + return t[i].maxLayer.Spatial > t[j].maxLayer.Spatial + } + + return t[i].maxLayer.Temporal > t[j].maxLayer.Temporal +} + +// ------------------------------------------------ + +type MaxDistanceSorter []*Track + +func (m MaxDistanceSorter) Len() int { + return len(m) +} + +func (m MaxDistanceSorter) Swap(i, j int) { + m[i], m[j] = m[j], m[i] +} + +func (m MaxDistanceSorter) Less(i, j int) bool { + // + // MaxDistanceSorter is used to find a deficient track to use for probing during recovery from congestion. + // So, higher priority track should come earlier so that they have a chance to recover sooner. + // + if m[i].priority != m[j].priority { + return m[i].priority > m[j].priority + } + + return m[i].DistanceToDesired() > m[j].DistanceToDesired() +} + +// ------------------------------------------------ + +type MinDistanceSorter []*Track + +func (m MinDistanceSorter) Len() int { + return len(m) +} + +func (m MinDistanceSorter) Swap(i, j int) { + m[i], m[j] = m[j], m[i] +} + +func (m MinDistanceSorter) Less(i, j int) bool { + // + // MinDistanceSorter is used to find excess bandwidth in cooperative allocation. + // So, lower priority track should come earlier so that they contribute bandwidth to higher priority tracks. + // + if m[i].priority != m[j].priority { + return m[i].priority < m[j].priority + } + + return m[i].DistanceToDesired() < m[j].DistanceToDesired() +} + +// ------------------------------------------------ diff --git a/livekit/pkg/sfu/streamtracker/interfaces.go b/livekit/pkg/sfu/streamtracker/interfaces.go new file mode 100644 index 0000000..f3ad5e6 --- /dev/null +++ b/livekit/pkg/sfu/streamtracker/interfaces.go @@ -0,0 +1,70 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package streamtracker + +import ( + "fmt" + "time" + + "github.com/livekit/livekit-server/pkg/sfu/buffer" +) + +// ------------------------------------------------------------ + +type StreamStatusChange int32 + +func (s StreamStatusChange) String() string { + switch s { + case StreamStatusChangeNone: + return "none" + case StreamStatusChangeStopped: + return "stopped" + case StreamStatusChangeActive: + return "active" + default: + return fmt.Sprintf("unknown: %d", int(s)) + } +} + +const ( + StreamStatusChangeNone StreamStatusChange = iota + StreamStatusChangeStopped + StreamStatusChangeActive +) + +// ------------------------------------------------------------ + +type StreamTrackerImpl interface { + Start() + Stop() + Reset() + + GetCheckInterval() time.Duration + + Observe(hasMarker bool, ts uint32) StreamStatusChange + CheckStatus() StreamStatusChange +} + +type StreamTrackerWorker interface { + Start() + Stop() + Reset() + OnStatusChanged(f func(status StreamStatus)) + OnBitrateAvailable(f func()) + Status() StreamStatus + BitrateTemporalCumulative() []int64 + SetPaused(paused bool) + Observe(temporalLayer int32, pktSize int, payloadSize int, hasMarker bool, ts uint32, dd *buffer.ExtDependencyDescriptor) +} diff --git a/livekit/pkg/sfu/streamtracker/streamtracker.go b/livekit/pkg/sfu/streamtracker/streamtracker.go new file mode 100644 index 0000000..34a4a97 --- /dev/null +++ b/livekit/pkg/sfu/streamtracker/streamtracker.go @@ -0,0 +1,307 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package streamtracker + +import ( + "fmt" + "sync" + "time" + + "go.uber.org/atomic" + + "github.com/livekit/livekit-server/pkg/sfu/buffer" + "github.com/livekit/protocol/logger" +) + +// ------------------------------------------------------------ + +type StreamStatus int32 + +func (s StreamStatus) String() string { + switch s { + case StreamStatusStopped: + return "stopped" + case StreamStatusActive: + return "active" + default: + return fmt.Sprintf("unknown: %d", int(s)) + } +} + +const ( + StreamStatusStopped StreamStatus = iota + StreamStatusActive +) + +// ------------------------------------------------------------ + +type StreamTrackerParams struct { + StreamTrackerImpl StreamTrackerImpl + BitrateReportInterval time.Duration + + Logger logger.Logger +} + +type StreamTracker struct { + params StreamTrackerParams + + onStatusChanged func(status StreamStatus) + onBitrateAvailable func() + + lock sync.RWMutex + + paused bool + generation atomic.Uint32 + + status StreamStatus + lastNotifiedStatus StreamStatus + + lastBitrateReport time.Time + bytesForBitrate [4]int64 + bitrate [4]int64 + + isStopped bool +} + +func NewStreamTracker(params StreamTrackerParams) *StreamTracker { + return &StreamTracker{ + params: params, + status: StreamStatusStopped, + } +} + +func (s *StreamTracker) OnStatusChanged(f func(status StreamStatus)) { + s.onStatusChanged = f +} + +func (s *StreamTracker) OnBitrateAvailable(f func()) { + s.onBitrateAvailable = f +} + +func (s *StreamTracker) Status() StreamStatus { + s.lock.RLock() + defer s.lock.RUnlock() + + return s.status +} + +func (s *StreamTracker) setStatusLocked(status StreamStatus) { + s.status = status +} + +func (s *StreamTracker) maybeNotifyStatus() { + var status StreamStatus + notify := false + s.lock.Lock() + if s.status != s.lastNotifiedStatus { + notify = true + status = s.status + s.lastNotifiedStatus = s.status + } + s.lock.Unlock() + + if notify && s.onStatusChanged != nil { + s.onStatusChanged(status) + } +} + +func (s *StreamTracker) Start() { + s.lock.Lock() + defer s.lock.Unlock() + + s.params.StreamTrackerImpl.Start() +} + +func (s *StreamTracker) Stop() { + s.lock.Lock() + defer s.lock.Unlock() + + if s.isStopped { + return + } + s.isStopped = true + + // bump generation to trigger exit of worker + s.generation.Inc() + + s.params.StreamTrackerImpl.Stop() +} + +func (s *StreamTracker) Reset() { + s.lock.Lock() + if s.isStopped { + s.lock.Unlock() + return + } + + s.resetLocked() + s.lock.Unlock() + + s.maybeNotifyStatus() +} + +func (s *StreamTracker) resetLocked() { + // bump generation to trigger exit of current worker + s.generation.Inc() + + s.setStatusLocked(StreamStatusStopped) + + for i := range len(s.bytesForBitrate) { + s.bytesForBitrate[i] = 0 + } + for i := range len(s.bitrate) { + s.bitrate[i] = 0 + } + + s.params.StreamTrackerImpl.Reset() +} + +func (s *StreamTracker) SetPaused(paused bool) { + s.lock.Lock() + s.paused = paused + if !paused { + s.resetLocked() + } else { + // bump generation to trigger exit of current worker + s.generation.Inc() + + s.setStatusLocked(StreamStatusStopped) + } + s.lock.Unlock() + + s.maybeNotifyStatus() +} + +func (s *StreamTracker) Observe( + temporalLayer int32, + pktSize int, + payloadSize int, + hasMarker bool, + ts uint32, + _ *buffer.ExtDependencyDescriptor, +) { + s.lock.Lock() + + if s.isStopped || s.paused || payloadSize == 0 { + s.lock.Unlock() + return + } + + statusChange := s.params.StreamTrackerImpl.Observe(hasMarker, ts) + if statusChange == StreamStatusChangeActive { + s.setStatusLocked(StreamStatusActive) + s.lastBitrateReport = time.Now() + + go s.worker(s.generation.Load()) + } + + if temporalLayer >= 0 { + s.bytesForBitrate[temporalLayer] += int64(pktSize) + } + s.lock.Unlock() + + if statusChange != StreamStatusChangeNone { + s.maybeNotifyStatus() + } +} + +// BitrateTemporalCumulative returns the current stream bitrate temporal layer accumulated with lower temporal layers. +func (s *StreamTracker) BitrateTemporalCumulative() []int64 { + s.lock.RLock() + defer s.lock.RUnlock() + + // copy and process + brs := make([]int64, len(s.bitrate)) + copy(brs, s.bitrate[:]) + + for i := len(brs) - 1; i >= 1; i-- { + if brs[i] != 0 { + for j := i - 1; j >= 0; j-- { + brs[i] += brs[j] + } + } + } + + // clear higher layers + for i := range brs { + if brs[i] == 0 { + for j := i + 1; j < len(brs); j++ { + brs[j] = 0 + } + } + } + + return brs +} + +func (s *StreamTracker) worker(generation uint32) { + ticker := time.NewTicker(s.params.StreamTrackerImpl.GetCheckInterval()) + defer ticker.Stop() + + tickerBitrate := time.NewTicker(s.params.BitrateReportInterval) + defer tickerBitrate.Stop() + + for { + select { + case <-ticker.C: + if generation != s.generation.Load() { + return + } + s.updateStatus() + + case <-tickerBitrate.C: + if generation != s.generation.Load() { + return + } + s.bitrateReport() + } + } +} + +func (s *StreamTracker) updateStatus() { + s.lock.Lock() + switch s.params.StreamTrackerImpl.CheckStatus() { + case StreamStatusChangeStopped: + s.setStatusLocked(StreamStatusStopped) + case StreamStatusChangeActive: + s.setStatusLocked(StreamStatusActive) + } + s.lock.Unlock() + + s.maybeNotifyStatus() +} + +func (s *StreamTracker) bitrateReport() { + // run this even if paused to drain out bitrate if there are no packets coming in + s.lock.Lock() + now := time.Now() + diff := now.Sub(s.lastBitrateReport) + s.lastBitrateReport = now + + bitrateAvailabilityChanged := false + for i := range len(s.bytesForBitrate) { + bitrate := int64(float64(s.bytesForBitrate[i]*8) / diff.Seconds()) + if (s.bitrate[i] == 0 && bitrate > 0) || (s.bitrate[i] > 0 && bitrate == 0) { + bitrateAvailabilityChanged = true + } + s.bitrate[i] = bitrate + s.bytesForBitrate[i] = 0 + } + s.lock.Unlock() + + if bitrateAvailabilityChanged && s.onBitrateAvailable != nil { + s.onBitrateAvailable() + } +} diff --git a/livekit/pkg/sfu/streamtracker/streamtracker_dd.go b/livekit/pkg/sfu/streamtracker/streamtracker_dd.go new file mode 100644 index 0000000..a87f2e7 --- /dev/null +++ b/livekit/pkg/sfu/streamtracker/streamtracker_dd.go @@ -0,0 +1,303 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package streamtracker + +import ( + "sync" + "time" + + "go.uber.org/atomic" + + "github.com/livekit/livekit-server/pkg/sfu/buffer" + dd "github.com/livekit/livekit-server/pkg/sfu/rtpextension/dependencydescriptor" +) + +type StreamTrackerDependencyDescriptor struct { + lock sync.RWMutex + paused bool + generation atomic.Uint32 + params StreamTrackerParams + maxSpatialLayer int32 + maxTemporalLayer int32 + + onStatusChanged [buffer.DefaultMaxLayerSpatial + 1]func(status StreamStatus) + onBitrateAvailable [buffer.DefaultMaxLayerSpatial + 1]func() + + lastBitrateReport time.Time + bytesForBitrate [buffer.DefaultMaxLayerSpatial + 1][buffer.DefaultMaxLayerTemporal + 1]int64 + bitrate [buffer.DefaultMaxLayerSpatial + 1][buffer.DefaultMaxLayerTemporal + 1]int64 + + isStopped bool +} + +func NewStreamTrackerDependencyDescriptor(params StreamTrackerParams) *StreamTrackerDependencyDescriptor { + return &StreamTrackerDependencyDescriptor{ + params: params, + maxSpatialLayer: buffer.InvalidLayerSpatial, + maxTemporalLayer: buffer.InvalidLayerTemporal, + } +} +func (s *StreamTrackerDependencyDescriptor) Start() { +} + +func (s *StreamTrackerDependencyDescriptor) Stop() { + s.lock.Lock() + defer s.lock.Unlock() + + if s.isStopped { + return + } + s.isStopped = true + + // bump generation to trigger exit of worker + s.generation.Inc() +} + +func (s *StreamTrackerDependencyDescriptor) OnStatusChanged(layer int32, f func(status StreamStatus)) { + s.lock.Lock() + s.onStatusChanged[layer] = f + s.lock.Unlock() +} + +func (s *StreamTrackerDependencyDescriptor) OnBitrateAvailable(layer int32, f func()) { + s.lock.Lock() + s.onBitrateAvailable[layer] = f + s.lock.Unlock() +} + +func (s *StreamTrackerDependencyDescriptor) Status(layer int32) StreamStatus { + s.lock.RLock() + defer s.lock.RUnlock() + + if layer > s.maxSpatialLayer { + return StreamStatusStopped + } + + return StreamStatusActive +} + +func (s *StreamTrackerDependencyDescriptor) BitrateTemporalCumulative(layer int32) []int64 { + s.lock.RLock() + defer s.lock.RUnlock() + + if layer > s.maxSpatialLayer { + brs := make([]int64, len(s.bitrate[0])) + return brs + } + + return s.bitrate[layer][:] +} + +func (s *StreamTrackerDependencyDescriptor) Reset() { +} + +func (s *StreamTrackerDependencyDescriptor) resetLocked() { + // bump generation to trigger exit of current worker + s.generation.Inc() + + for i := range len(s.bytesForBitrate) { + for j := range len(s.bytesForBitrate[i]) { + s.bytesForBitrate[i][j] = 0 + } + } + for i := range len(s.bitrate) { + for j := range len(s.bitrate[i]) { + s.bitrate[i][j] = 0 + } + } + + s.maxSpatialLayer = buffer.InvalidLayerSpatial + s.maxTemporalLayer = buffer.InvalidLayerTemporal +} + +func (s *StreamTrackerDependencyDescriptor) SetPaused(paused bool) { + s.lock.Lock() + if s.paused == paused { + s.lock.Unlock() + return + } + s.paused = paused + + var notifyFns []func(status StreamStatus) + var notifyStatus StreamStatus + if !paused { + s.resetLocked() + + notifyStatus = StreamStatusStopped + notifyFns = append(notifyFns, s.onStatusChanged[:]...) + } else { + s.lastBitrateReport = time.Now() + go s.worker(s.generation.Inc()) + + } + s.lock.Unlock() + + for _, fn := range notifyFns { + if fn != nil { + fn(notifyStatus) + } + } +} + +func (s *StreamTrackerDependencyDescriptor) Observe(temporalLayer int32, pktSize int, payloadSize int, hasMarker bool, ts uint32, ddVal *buffer.ExtDependencyDescriptor) { + s.lock.Lock() + + if s.isStopped || s.paused || payloadSize == 0 || ddVal == nil { + s.lock.Unlock() + return + } + + var notifyFns []func(status StreamStatus) + var notifyStatus StreamStatus + if mask := ddVal.Descriptor.ActiveDecodeTargetsBitmask; mask != nil && ddVal.ActiveDecodeTargetsUpdated { + var maxSpatial, maxTemporal int32 + for _, dt := range ddVal.DecodeTargets { + if *mask&(1< buffer.DefaultMaxLayerSpatial { + maxSpatial = buffer.DefaultMaxLayerSpatial + s.params.Logger.Warnw("max spatial layer exceeded", nil, "maxSpatial", maxSpatial) + } + if maxTemporal > buffer.DefaultMaxLayerTemporal { + maxTemporal = buffer.DefaultMaxLayerTemporal + s.params.Logger.Warnw("max temporal layer exceeded", nil, "maxTemporal", maxTemporal) + } + + s.params.Logger.Debugw("max layer changed", "maxSpatial", maxSpatial, "maxTemporal", maxTemporal) + oldMaxSpatial := s.maxSpatialLayer + s.maxSpatialLayer, s.maxTemporalLayer = maxSpatial, maxTemporal + if oldMaxSpatial == -1 { + s.lastBitrateReport = time.Now() + go s.worker(s.generation.Inc()) + } + + if oldMaxSpatial > s.maxSpatialLayer { + notifyStatus = StreamStatusStopped + for i := s.maxSpatialLayer + 1; i <= oldMaxSpatial; i++ { + notifyFns = append(notifyFns, s.onStatusChanged[i]) + } + } else if oldMaxSpatial < s.maxSpatialLayer { + notifyStatus = StreamStatusActive + for i := oldMaxSpatial + 1; i <= s.maxSpatialLayer; i++ { + notifyFns = append(notifyFns, s.onStatusChanged[i]) + } + } + } + + dtis := ddVal.Descriptor.FrameDependencies.DecodeTargetIndications + + for _, dt := range ddVal.DecodeTargets { + if len(dtis) <= dt.Target { + s.params.Logger.Errorw("len(dtis) less than target", nil, "target", dt.Target, "dtis", dtis) + continue + } + // we are not dropping discardable frames now, so only ingore not present frames + if dtis[dt.Target] == dd.DecodeTargetNotPresent { + continue + } + + s.bytesForBitrate[dt.Layer.Spatial][dt.Layer.Temporal] += int64(pktSize) + } + + s.lock.Unlock() + + for _, fn := range notifyFns { + if fn != nil { + fn(notifyStatus) + } + } +} + +func (s *StreamTrackerDependencyDescriptor) worker(generation uint32) { + tickerBitrate := time.NewTicker(s.params.BitrateReportInterval) + defer tickerBitrate.Stop() + + for { + <-tickerBitrate.C + if generation != s.generation.Load() { + return + } + s.bitrateReport() + } +} + +func (s *StreamTrackerDependencyDescriptor) bitrateReport() { + // run this even if paused to drain out bitrate if there are no packets coming in + s.lock.Lock() + now := time.Now() + diff := now.Sub(s.lastBitrateReport) + s.lastBitrateReport = now + + var availableChangedFns []func() + for spatial := range len(s.bytesForBitrate) { + bytesForBitrate := s.bytesForBitrate[spatial][:] + bitrateAvailabilityChanged := false + bitrates := s.bitrate[spatial][:] + for i := range bytesForBitrate { + bitrate := int64(float64(bytesForBitrate[i]*8) / diff.Seconds()) + if (bitrates[i] == 0 && bitrate > 0) || (bitrates[i] > 0 && bitrate == 0) { + bitrateAvailabilityChanged = true + } + bitrates[i] = bitrate + bytesForBitrate[i] = 0 + } + + if bitrateAvailabilityChanged && s.onBitrateAvailable[spatial] != nil { + availableChangedFns = append(availableChangedFns, s.onBitrateAvailable[spatial]) + } + } + s.lock.Unlock() + + for _, fn := range availableChangedFns { + fn() + } +} + +func (s *StreamTrackerDependencyDescriptor) LayeredTracker(layer int32) *StreamTrackerDependencyDescriptorLayered { + return &StreamTrackerDependencyDescriptorLayered{ + StreamTrackerDependencyDescriptor: s, + layer: layer, + } +} + +// ---------------------------- +// Layered wrapper for StreamTrackerWorker +type StreamTrackerDependencyDescriptorLayered struct { + *StreamTrackerDependencyDescriptor + layer int32 +} + +func (s *StreamTrackerDependencyDescriptorLayered) OnStatusChanged(f func(status StreamStatus)) { + s.StreamTrackerDependencyDescriptor.OnStatusChanged(s.layer, f) +} + +func (s *StreamTrackerDependencyDescriptorLayered) OnBitrateAvailable(f func()) { + s.StreamTrackerDependencyDescriptor.OnBitrateAvailable(s.layer, f) +} + +func (s *StreamTrackerDependencyDescriptorLayered) Status() StreamStatus { + return s.StreamTrackerDependencyDescriptor.Status(s.layer) +} + +func (s *StreamTrackerDependencyDescriptorLayered) BitrateTemporalCumulative() []int64 { + return s.StreamTrackerDependencyDescriptor.BitrateTemporalCumulative(s.layer) +} diff --git a/livekit/pkg/sfu/streamtracker/streamtracker_dd_test.go b/livekit/pkg/sfu/streamtracker/streamtracker_dd_test.go new file mode 100644 index 0000000..8e3f400 --- /dev/null +++ b/livekit/pkg/sfu/streamtracker/streamtracker_dd_test.go @@ -0,0 +1,98 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package streamtracker + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/livekit/livekit-server/pkg/sfu/buffer" + dd "github.com/livekit/livekit-server/pkg/sfu/rtpextension/dependencydescriptor" + "github.com/livekit/protocol/logger" +) + +func createDescriptorDependencyForTargets(maxSpatial, maxTemporal int) *buffer.ExtDependencyDescriptor { + var targets []buffer.DependencyDescriptorDecodeTarget + var mask uint32 + for i := 0; i <= maxSpatial; i++ { + for j := 0; j <= maxTemporal; j++ { + targets = append(targets, buffer.DependencyDescriptorDecodeTarget{Target: len(targets), Layer: buffer.VideoLayer{Spatial: int32(i), Temporal: int32(j)}}) + mask |= 1 << uint32(len(targets)-1) + } + } + + dtis := make([]dd.DecodeTargetIndication, len(targets)) + for _, t := range targets { + dtis[t.Target] = dd.DecodeTargetRequired + } + + return &buffer.ExtDependencyDescriptor{ + Descriptor: &dd.DependencyDescriptor{ + ActiveDecodeTargetsBitmask: &mask, + FrameDependencies: &dd.FrameDependencyTemplate{ + DecodeTargetIndications: dtis, + }, + }, + DecodeTargets: targets, + ActiveDecodeTargetsUpdated: true, + } +} + +func checkStatues(t *testing.T, statuses []StreamStatus, expected StreamStatus, maxSpatial int) { + for i := 0; i <= maxSpatial; i++ { + require.Equal(t, expected, statuses[i]) + } + + for i := maxSpatial + 1; i < len(statuses); i++ { + require.NotEqual(t, expected, statuses[i]) + } +} + +func TestStreamTrackerDD(t *testing.T) { + ddTracker := NewStreamTrackerDependencyDescriptor(StreamTrackerParams{ + BitrateReportInterval: 1 * time.Second, + Logger: logger.GetLogger(), + }) + layeredTrackers := make([]StreamTrackerWorker, buffer.DefaultMaxLayerSpatial+1) + statuses := make([]StreamStatus, buffer.DefaultMaxLayerSpatial+1) + for i := 0; i <= int(buffer.DefaultMaxLayerSpatial); i++ { + layeredTrack := ddTracker.LayeredTracker(int32(i)) + layer := i + layeredTrack.OnStatusChanged(func(status StreamStatus) { + statuses[layer] = status + }) + layeredTrack.Start() + layeredTrackers[i] = layeredTrack + } + defer ddTracker.Stop() + + // no active layers + ddTracker.Observe(0, 1000, 1000, false, 0, nil) + checkStatues(t, statuses, StreamStatusActive, int(buffer.InvalidLayerSpatial)) + + // layer seen [0,1] + ddTracker.Observe(0, 1000, 1000, false, 0, createDescriptorDependencyForTargets(1, 1)) + checkStatues(t, statuses, StreamStatusActive, 1) + + // layer seen [0,1,2] + ddTracker.Observe(0, 1000, 1000, false, 0, createDescriptorDependencyForTargets(2, 1)) + checkStatues(t, statuses, StreamStatusActive, 2) + + // layer 2 gone, layer seen [0,1] + ddTracker.Observe(0, 1000, 1000, false, 0, createDescriptorDependencyForTargets(1, 1)) + checkStatues(t, statuses, StreamStatusActive, 1) +} diff --git a/livekit/pkg/sfu/streamtracker/streamtracker_frame.go b/livekit/pkg/sfu/streamtracker/streamtracker_frame.go new file mode 100644 index 0000000..e0f9f57 --- /dev/null +++ b/livekit/pkg/sfu/streamtracker/streamtracker_frame.go @@ -0,0 +1,244 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package streamtracker + +import ( + "math" + "time" + + "github.com/livekit/protocol/logger" +) + +const ( + checkInterval = 500 * time.Millisecond + statusCheckTolerance = 0.98 + frameRateResolution = 0.01 // 1 frame every 100 seconds + frameRateIncreaseFactor = 0.6 // slow increase + frameRateDecreaseFactor = 0.9 // fast decrease +) + +// ------------------------------------------------------- + +type StreamTrackerFrameConfig struct { + MinFPS float64 `yaml:"min_fps,omitempty"` +} + +var ( + DefaultStreamTrackerFrameConfigVideo = map[int32]StreamTrackerFrameConfig{ + 0: { + MinFPS: 5.0, + }, + 1: { + MinFPS: 5.0, + }, + 2: { + MinFPS: 5.0, + }, + } + + DefaultStreamTrackerFrameConfigScreenshare = map[int32]StreamTrackerFrameConfig{ + 0: { + MinFPS: 0.5, + }, + 1: { + MinFPS: 0.5, + }, + 2: { + MinFPS: 0.5, + }, + } +) + +// ------------------------------------------------------- + +type StreamTrackerFrameParams struct { + Config StreamTrackerFrameConfig + ClockRate uint32 + Logger logger.Logger +} + +type StreamTrackerFrame struct { + params StreamTrackerFrameParams + + initialized bool + + tsInitialized bool + oldestTS uint32 + newestTS uint32 + numFrames int + + estimatedFrameRate float64 + evalInterval time.Duration + lastStatusCheckAt time.Time +} + +func NewStreamTrackerFrame(params StreamTrackerFrameParams) StreamTrackerImpl { + s := &StreamTrackerFrame{ + params: params, + } + s.Reset() + return s +} + +func (s *StreamTrackerFrame) Start() { +} + +func (s *StreamTrackerFrame) Stop() { +} + +func (s *StreamTrackerFrame) Reset() { + s.initialized = false + + s.resetFPSCalculator() + + s.lastStatusCheckAt = time.Time{} +} + +func (s *StreamTrackerFrame) resetFPSCalculator() { + s.tsInitialized = false + s.oldestTS = 0 + s.newestTS = 0 + s.numFrames = 0 + + s.estimatedFrameRate = 0.0 + s.updateEvalInterval() +} + +func (s *StreamTrackerFrame) GetCheckInterval() time.Duration { + return checkInterval +} + +func (s *StreamTrackerFrame) Observe(hasMarker bool, ts uint32) StreamStatusChange { + if hasMarker { + if !s.tsInitialized { + s.tsInitialized = true + s.oldestTS = ts + s.newestTS = ts + s.numFrames = 1 + } else { + diff := ts - s.oldestTS + if diff > (1 << 31) { + s.oldestTS = ts + } + diff = ts - s.newestTS + if diff < (1 << 31) { + s.newestTS = ts + } + s.numFrames++ + } + } + + // When starting up, check for first packet and declare active. + // Happens under following conditions + // 1. Start up + // 2. Unmute (stream restarting) + // 3. Layer starting after dynacast pause + if !s.initialized { + s.initialized = true + s.lastStatusCheckAt = time.Now() + return StreamStatusChangeActive + } + + return StreamStatusChangeNone +} + +func (s *StreamTrackerFrame) CheckStatus() StreamStatusChange { + if !s.initialized { + // should not be getting called when not initialized, but be safe + return StreamStatusChangeNone + } + + if !s.updateStatusCheckTime() { + return StreamStatusChangeNone + } + + if s.updateEstimatedFrameRate() == 0.0 { + // when stream is stopped, reset FPS calculator to ensure re-start is not done until at least two frames are available, + // i. e. enough frames available to be able to calculate FPS + s.resetFPSCalculator() + return StreamStatusChangeStopped + } + + return StreamStatusChangeActive +} + +func (s *StreamTrackerFrame) updateStatusCheckTime() bool { + // check only at intervals based on estimated frame rate + if s.lastStatusCheckAt.IsZero() { + s.lastStatusCheckAt = time.Now() + } + if time.Since(s.lastStatusCheckAt) < time.Duration(statusCheckTolerance*float64(s.evalInterval)) { + return false + } + s.lastStatusCheckAt = time.Now() + return true +} + +func (s *StreamTrackerFrame) updateEstimatedFrameRate() float64 { + diff := s.newestTS - s.oldestTS + if diff == 0 || s.numFrames < 2 { + return 0.0 + } + + frameRate := roundFrameRate(float64(s.params.ClockRate) / float64(diff) * float64(s.numFrames-1)) + + // reset for next evaluation interval + s.oldestTS = s.newestTS + s.numFrames = 1 + + factor := 1.0 + switch { + case s.estimatedFrameRate < frameRate: + // slow increase, prevents shortening eval interval too quickly on frame rate going up + factor = frameRateIncreaseFactor + case s.estimatedFrameRate > frameRate: + // fast decrease, prevents declaring stream stop too quickly on frame rate going down + factor = frameRateDecreaseFactor + } + + estimatedFrameRate := roundFrameRate(frameRate*factor + s.estimatedFrameRate*(1.0-factor)) + if s.estimatedFrameRate != estimatedFrameRate { + s.estimatedFrameRate = estimatedFrameRate + s.updateEvalInterval() + s.params.Logger.Debugw("updating estimated frame rate", "estimatedFPS", estimatedFrameRate, "evalInterval", s.evalInterval) + } + + return frameRate +} + +func (s *StreamTrackerFrame) updateEvalInterval() { + // STREAM-TRACKER-FRAME-TODO: This will run into challenges for frame rate falling steeply, How to address that? + // Maybe, look at some referential rules (between layers) for possibilities to solve it. Currently, this is addressed + // by setting a source aware min FPS to ensure evaluation window is long enough to avoid declaring stop too quickly. + s.evalInterval = checkInterval + if s.estimatedFrameRate > 0.0 { + estimatedFrameRateInterval := time.Duration(float64(time.Second) / s.estimatedFrameRate) + if estimatedFrameRateInterval > s.evalInterval { + s.evalInterval = estimatedFrameRateInterval + } + } + if s.params.Config.MinFPS > 0.0 { + minFPSInterval := time.Duration(float64(time.Second) / s.params.Config.MinFPS) + if minFPSInterval > s.evalInterval { + s.evalInterval = minFPSInterval + } + } +} + +// ------------------------------------------------------------------------------ + +func roundFrameRate(frameRate float64) float64 { + return math.Round(frameRate/frameRateResolution) * frameRateResolution +} diff --git a/livekit/pkg/sfu/streamtracker/streamtracker_packet.go b/livekit/pkg/sfu/streamtracker/streamtracker_packet.go new file mode 100644 index 0000000..1ff3be1 --- /dev/null +++ b/livekit/pkg/sfu/streamtracker/streamtracker_packet.go @@ -0,0 +1,141 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package streamtracker + +import ( + "time" + + "github.com/livekit/protocol/logger" +) + +// -------------------------------------------- + +type StreamTrackerPacketConfig struct { + SamplesRequired uint32 `yaml:"samples_required,omitempty"` // number of samples needed per cycle + CyclesRequired uint32 `yaml:"cycles_required,omitempty"` // number of cycles needed to be active + CycleDuration time.Duration `yaml:"cycle_duration,omitempty"` +} + +var ( + DefaultStreamTrackerPacketConfigVideo = map[int32]StreamTrackerPacketConfig{ + 0: {SamplesRequired: 1, + CyclesRequired: 4, + CycleDuration: 500 * time.Millisecond, + }, + 1: {SamplesRequired: 5, + CyclesRequired: 20, + CycleDuration: 500 * time.Millisecond, + }, + 2: {SamplesRequired: 5, + CyclesRequired: 20, + CycleDuration: 500 * time.Millisecond, + }, + } + + DefaultStreamTrackerPacketConfigScreenshare = map[int32]StreamTrackerPacketConfig{ + 0: { + SamplesRequired: 1, + CyclesRequired: 1, + CycleDuration: 2 * time.Second, + }, + 1: { + SamplesRequired: 1, + CyclesRequired: 1, + CycleDuration: 2 * time.Second, + }, + 2: { + SamplesRequired: 1, + CyclesRequired: 1, + CycleDuration: 2 * time.Second, + }, + } +) + +// -------------------------------------------- + +type StreamTrackerPacketParams struct { + Config StreamTrackerPacketConfig + Logger logger.Logger +} + +type StreamTrackerPacket struct { + params StreamTrackerPacketParams + + countSinceLast uint32 // number of packets received since last check + + initialized bool + + cycleCount uint32 +} + +func NewStreamTrackerPacket(params StreamTrackerPacketParams) StreamTrackerImpl { + return &StreamTrackerPacket{ + params: params, + } +} + +func (s *StreamTrackerPacket) Start() { +} + +func (s *StreamTrackerPacket) Stop() { +} + +func (s *StreamTrackerPacket) Reset() { + s.countSinceLast = 0 + s.cycleCount = 0 + + s.initialized = false +} + +func (s *StreamTrackerPacket) GetCheckInterval() time.Duration { + return s.params.Config.CycleDuration +} + +func (s *StreamTrackerPacket) Observe(_hasMarker bool, _ts uint32) StreamStatusChange { + if !s.initialized { + // first packet + s.initialized = true + s.countSinceLast = 1 + return StreamStatusChangeActive + } + + s.countSinceLast++ + return StreamStatusChangeNone +} + +func (s *StreamTrackerPacket) CheckStatus() StreamStatusChange { + if !s.initialized { + // should not be getting called when not initialized, but be safe + return StreamStatusChangeNone + } + + if s.countSinceLast >= s.params.Config.SamplesRequired { + s.cycleCount++ + } else { + s.cycleCount = 0 + } + + statusChange := StreamStatusChangeNone + if s.cycleCount == 0 { + // no packets seen for a period, flip to stopped + statusChange = StreamStatusChangeStopped + } else if s.cycleCount >= s.params.Config.CyclesRequired { + // packets seen for some time after resume, flip to active + statusChange = StreamStatusChangeActive + } + + s.countSinceLast = 0 + return statusChange +} diff --git a/livekit/pkg/sfu/streamtracker/streamtracker_packet_test.go b/livekit/pkg/sfu/streamtracker/streamtracker_packet_test.go new file mode 100644 index 0000000..3f1b65f --- /dev/null +++ b/livekit/pkg/sfu/streamtracker/streamtracker_packet_test.go @@ -0,0 +1,222 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package streamtracker + +import ( + "fmt" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" + "go.uber.org/atomic" + + "github.com/livekit/livekit-server/pkg/testutils" + "github.com/livekit/protocol/logger" +) + +func newStreamTrackerPacket(samplesRequired uint32, cyclesRequired uint32, cycleDuration time.Duration) *StreamTracker { + stp := NewStreamTrackerPacket(StreamTrackerPacketParams{ + Config: StreamTrackerPacketConfig{ + SamplesRequired: samplesRequired, + CyclesRequired: cyclesRequired, + CycleDuration: cycleDuration, + }, + Logger: logger.GetLogger(), + }) + + return NewStreamTracker(StreamTrackerParams{ + StreamTrackerImpl: stp, + BitrateReportInterval: 1 * time.Second, + Logger: logger.GetLogger(), + }) +} + +func TestStreamTracker(t *testing.T) { + t.Run("flips to active on first observe", func(t *testing.T) { + callbackCalled := atomic.NewBool(false) + tracker := newStreamTrackerPacket(5, 60, 500*time.Millisecond) + tracker.Start() + tracker.OnStatusChanged(func(status StreamStatus) { + callbackCalled.Store(true) + }) + require.Equal(t, StreamStatusStopped, tracker.Status()) + + // observe first packet + tracker.Observe(0, 20, 10, false, 0, nil) + + testutils.WithTimeout(t, func() string { + if callbackCalled.Load() { + return "" + } + + return "first packet didn't activate stream" + }) + + require.Equal(t, StreamStatusActive, tracker.Status()) + require.True(t, callbackCalled.Load()) + + tracker.Stop() + }) + + t.Run("flips to inactive immediately", func(t *testing.T) { + tracker := newStreamTrackerPacket(5, 60, 500*time.Millisecond) + tracker.Start() + require.Equal(t, StreamStatusStopped, tracker.Status()) + + callbackStatusMu := sync.RWMutex{} + callbackStatusMu.Lock() + callbackStatus := StreamStatusStopped + callbackStatusMu.Unlock() + tracker.OnStatusChanged(func(status StreamStatus) { + callbackStatusMu.Lock() + callbackStatus = status + callbackStatusMu.Unlock() + }) + + tracker.Observe(0, 20, 10, false, 0, nil) + testutils.WithTimeout(t, func() string { + callbackStatusMu.RLock() + defer callbackStatusMu.RUnlock() + + if callbackStatus == StreamStatusActive { + return "" + } + + return "first packet did not activate stream" + }) + require.Equal(t, StreamStatusActive, tracker.Status()) + + // run a single iteration + tracker.updateStatus() + + testutils.WithTimeout(t, func() string { + callbackStatusMu.RLock() + defer callbackStatusMu.RUnlock() + + if callbackStatus == StreamStatusStopped { + return "" + } + + return "inactive cycle did not declare stream stopped" + }) + require.Equal(t, StreamStatusStopped, tracker.Status()) + require.Equal(t, StreamStatusStopped, callbackStatus) + + tracker.Stop() + }) + + t.Run("flips back to active after iterations", func(t *testing.T) { + tracker := newStreamTrackerPacket(1, 2, 500*time.Millisecond) + tracker.Start() + require.Equal(t, StreamStatusStopped, tracker.Status()) + + tracker.Observe(0, 20, 10, false, 0, nil) + testutils.WithTimeout(t, func() string { + if tracker.Status() == StreamStatusActive { + return "" + } + + return "first packet did not activate stream" + }) + + tracker.setStatusLocked(StreamStatusStopped) + + tracker.Observe(0, 20, 10, false, 0, nil) + tracker.updateStatus() + require.Equal(t, StreamStatusStopped, tracker.Status()) + + tracker.Observe(0, 20, 10, false, 0, nil) + tracker.updateStatus() + require.Equal(t, StreamStatusActive, tracker.Status()) + + tracker.Stop() + }) + + t.Run("changes to inactive when paused", func(t *testing.T) { + tracker := newStreamTrackerPacket(5, 60, 500*time.Millisecond) + tracker.Start() + tracker.Observe(0, 20, 10, false, 0, nil) + testutils.WithTimeout(t, func() string { + if tracker.Status() == StreamStatusActive { + return "" + } + + return "first packet did not activate stream" + }) + + tracker.SetPaused(true) + tracker.updateStatus() + require.Equal(t, StreamStatusStopped, tracker.Status()) + + tracker.Stop() + }) + + t.Run("flips back to active on first observe after reset", func(t *testing.T) { + callbackCalled := atomic.NewUint32(0) + tracker := newStreamTrackerPacket(5, 60, 500*time.Millisecond) + tracker.Start() + tracker.OnStatusChanged(func(status StreamStatus) { + callbackCalled.Inc() + }) + require.Equal(t, StreamStatusStopped, tracker.Status()) + + // observe first packet + tracker.Observe(0, 20, 10, false, 0, nil) + + testutils.WithTimeout(t, func() string { + if callbackCalled.Load() == 1 { + return "" + } + + return fmt.Sprintf("expected onStatusChanged to be called once, actual: %d", callbackCalled.Load()) + }) + + require.Equal(t, StreamStatusActive, tracker.Status()) + require.Equal(t, uint32(1), callbackCalled.Load()) + + // observe a few more + tracker.Observe(0, 20, 10, false, 0, nil) + tracker.Observe(0, 20, 10, false, 0, nil) + tracker.Observe(0, 20, 10, false, 0, nil) + tracker.Observe(0, 20, 10, false, 0, nil) + tracker.updateStatus() + + // should still be active + require.Equal(t, StreamStatusActive, tracker.Status()) + require.Equal(t, uint32(1), callbackCalled.Load()) + + // Reset. The first packet after reset should flip state again + tracker.Reset() + require.Equal(t, StreamStatusStopped, tracker.Status()) + require.Equal(t, uint32(2), callbackCalled.Load()) + + // first packet after reset + tracker.Observe(0, 20, 10, false, 0, nil) + + testutils.WithTimeout(t, func() string { + if callbackCalled.Load() == 3 { + return "" + } + + return fmt.Sprintf("expected onStatusChanged to be called thrice, actual %d", callbackCalled.Load()) + }) + + require.Equal(t, StreamStatusActive, tracker.Status()) + require.Equal(t, uint32(3), callbackCalled.Load()) + + tracker.Stop() + }) +} diff --git a/livekit/pkg/sfu/streamtrackermanager.go b/livekit/pkg/sfu/streamtrackermanager.go new file mode 100644 index 0000000..b102673 --- /dev/null +++ b/livekit/pkg/sfu/streamtrackermanager.go @@ -0,0 +1,661 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sfu + +import ( + "slices" + "sort" + "sync" + "time" + + "github.com/frostbyte73/core" + "go.uber.org/atomic" + + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/utils" + + "github.com/livekit/livekit-server/pkg/sfu/buffer" + "github.com/livekit/livekit-server/pkg/sfu/mime" + "github.com/livekit/livekit-server/pkg/sfu/streamtracker" +) + +// --------------------------------------------------- + +type StreamTrackerManagerListener interface { + OnAvailableLayersChanged() + OnBitrateAvailabilityChanged() + OnMaxPublishedLayerChanged(maxPublishedLayer int32) + OnMaxTemporalLayerSeenChanged(maxTemporalLayerSeen int32) + OnMaxAvailableLayerChanged(maxAvailableLayer int32) + OnBitrateReport(availableLayers []int32, bitrates Bitrates) +} + +// --------------------------------------------------- + +type ( + StreamTrackerType string +) + +const ( + StreamTrackerTypePacket StreamTrackerType = "packet" + StreamTrackerTypeFrame StreamTrackerType = "frame" +) + +type StreamTrackerPacketConfig struct { + SamplesRequired uint32 `yaml:"samples_required,omitempty"` // number of samples needed per cycle + CyclesRequired uint32 `yaml:"cycles_required,omitempty"` // number of cycles needed to be active + CycleDuration time.Duration `yaml:"cycle_duration,omitempty"` +} + +type StreamTrackerFrameConfig struct { + MinFPS float64 `yaml:"min_fps,omitempty"` +} + +type StreamTrackerConfig struct { + StreamTrackerType StreamTrackerType `yaml:"stream_tracker_type,omitempty"` + BitrateReportInterval map[int32]time.Duration `yaml:"bitrate_report_interval,omitempty"` + PacketTracker map[int32]streamtracker.StreamTrackerPacketConfig `yaml:"packet_tracker,omitempty"` + FrameTracker map[int32]streamtracker.StreamTrackerFrameConfig `yaml:"frame_tracker,omitempty"` +} + +var ( + DefaultStreamTrackerConfigVideo = StreamTrackerConfig{ + StreamTrackerType: StreamTrackerTypePacket, + BitrateReportInterval: map[int32]time.Duration{ + 0: 1 * time.Second, + 1: 1 * time.Second, + 2: 1 * time.Second, + }, + PacketTracker: streamtracker.DefaultStreamTrackerPacketConfigVideo, + FrameTracker: streamtracker.DefaultStreamTrackerFrameConfigVideo, + } + + DefaultStreamTrackerConfigScreenshare = StreamTrackerConfig{ + StreamTrackerType: StreamTrackerTypePacket, + BitrateReportInterval: map[int32]time.Duration{ + 0: 4 * time.Second, + 1: 4 * time.Second, + 2: 4 * time.Second, + }, + PacketTracker: streamtracker.DefaultStreamTrackerPacketConfigScreenshare, + FrameTracker: streamtracker.DefaultStreamTrackerFrameConfigScreenshare, + } +) + +// --------------------------------------------------- + +type StreamTrackerManagerConfig struct { + Video StreamTrackerConfig `yaml:"video,omitempty"` + Screenshare StreamTrackerConfig `yaml:"screenshare,omitempty"` +} + +var ( + DefaultStreamTrackerManagerConfig = StreamTrackerManagerConfig{ + Video: DefaultStreamTrackerConfigVideo, + Screenshare: DefaultStreamTrackerConfigScreenshare, + } +) + +// --------------------------------------------------- + +type StreamTrackerManager struct { + logger logger.Logger + trackInfo atomic.Pointer[livekit.TrackInfo] + mimeType mime.MimeType + videoLayerMode livekit.VideoLayer_Mode + clockRate uint32 + + trackerConfig StreamTrackerConfig + + lock sync.RWMutex + maxPublishedLayer int32 + maxTemporalLayerSeen int32 + + ddTracker *streamtracker.StreamTrackerDependencyDescriptor + trackers [buffer.DefaultMaxLayerSpatial + 1]streamtracker.StreamTrackerWorker + + availableLayers []int32 + maxExpectedLayer int32 + paused bool + + closed core.Fuse + + listener StreamTrackerManagerListener +} + +func NewStreamTrackerManager( + logger logger.Logger, + trackInfo *livekit.TrackInfo, + mimeType mime.MimeType, + clockRate uint32, + config StreamTrackerManagerConfig, +) *StreamTrackerManager { + s := &StreamTrackerManager{ + logger: logger, + mimeType: mimeType, + videoLayerMode: buffer.GetVideoLayerModeForMimeType(mimeType, trackInfo), + maxPublishedLayer: buffer.InvalidLayerSpatial, + maxTemporalLayerSeen: buffer.InvalidLayerTemporal, + clockRate: clockRate, + } + s.trackInfo.Store(utils.CloneProto(trackInfo)) + + switch trackInfo.Source { + case livekit.TrackSource_SCREEN_SHARE: + s.trackerConfig = config.Screenshare + case livekit.TrackSource_CAMERA: + s.trackerConfig = config.Video + default: + s.trackerConfig = config.Video + } + + s.maxExpectedLayerFromTrackInfo() + + if trackInfo.Type == livekit.TrackType_VIDEO { + go s.bitrateReporter() + } + return s +} + +func (s *StreamTrackerManager) Close() { + s.closed.Break() +} + +func (s *StreamTrackerManager) SetListener(listener StreamTrackerManagerListener) { + s.lock.Lock() + s.listener = listener + s.lock.Unlock() +} + +func (s *StreamTrackerManager) getListener() StreamTrackerManagerListener { + s.lock.RLock() + defer s.lock.RUnlock() + + return s.listener +} + +func (s *StreamTrackerManager) createStreamTrackerPacket(layer int32) streamtracker.StreamTrackerImpl { + packetTrackerConfig, ok := s.trackerConfig.PacketTracker[layer] + if !ok { + return nil + } + + params := streamtracker.StreamTrackerPacketParams{ + Config: packetTrackerConfig, + Logger: s.logger.WithValues("layer", layer), + } + return streamtracker.NewStreamTrackerPacket(params) +} + +func (s *StreamTrackerManager) createStreamTrackerFrame(layer int32) streamtracker.StreamTrackerImpl { + frameTrackerConfig, ok := s.trackerConfig.FrameTracker[layer] + if !ok { + return nil + } + + params := streamtracker.StreamTrackerFrameParams{ + Config: frameTrackerConfig, + ClockRate: s.clockRate, + Logger: s.logger.WithValues("layer", layer), + } + return streamtracker.NewStreamTrackerFrame(params) +} + +func (s *StreamTrackerManager) AddDependencyDescriptorTrackers() { + bitrateInterval, ok := s.trackerConfig.BitrateReportInterval[0] + if !ok { + return + } + s.lock.Lock() + var addAllTrackers bool + if s.ddTracker == nil { + s.ddTracker = streamtracker.NewStreamTrackerDependencyDescriptor(streamtracker.StreamTrackerParams{ + BitrateReportInterval: bitrateInterval, + Logger: s.logger.WithValues("layer", 0), + }) + addAllTrackers = true + } + s.lock.Unlock() + if addAllTrackers { + for i := 0; i <= int(buffer.DefaultMaxLayerSpatial); i++ { + s.AddTracker(int32(i)) + } + } +} + +func (s *StreamTrackerManager) AddTracker(layer int32) streamtracker.StreamTrackerWorker { + if layer < 0 || int(layer) >= len(s.trackers) { + return nil + } + + var tracker streamtracker.StreamTrackerWorker + s.lock.Lock() + tracker = s.trackers[layer] + if tracker != nil { + s.lock.Unlock() + return tracker + } + + if s.ddTracker != nil { + tracker = s.ddTracker.LayeredTracker(layer) + } + s.lock.Unlock() + + bitrateInterval, ok := s.trackerConfig.BitrateReportInterval[layer] + if !ok { + return nil + } + + if tracker == nil { + var trackerImpl streamtracker.StreamTrackerImpl + switch s.trackerConfig.StreamTrackerType { + case StreamTrackerTypePacket: + trackerImpl = s.createStreamTrackerPacket(layer) + case StreamTrackerTypeFrame: + trackerImpl = s.createStreamTrackerFrame(layer) + } + if trackerImpl == nil { + return nil + } + + tracker = streamtracker.NewStreamTracker(streamtracker.StreamTrackerParams{ + StreamTrackerImpl: trackerImpl, + BitrateReportInterval: bitrateInterval, + Logger: s.logger.WithValues("layer", layer), + }) + } + + s.logger.Debugw("stream tracker add track", "layer", layer) + tracker.OnStatusChanged(func(status streamtracker.StreamStatus) { + s.logger.Debugw("stream tracker status changed", "layer", layer, "status", status) + if status == streamtracker.StreamStatusStopped { + s.removeAvailableLayer(layer) + } else { + s.addAvailableLayer(layer) + } + }) + tracker.OnBitrateAvailable(func() { + if listener := s.getListener(); listener != nil { + listener.OnBitrateAvailabilityChanged() + } + }) + + s.lock.Lock() + paused := s.paused + s.trackers[layer] = tracker + + notify := false + if layer > s.maxPublishedLayer { + s.maxPublishedLayer = layer + notify = true + } + s.lock.Unlock() + + if notify { + if listener := s.getListener(); listener != nil { + go listener.OnMaxPublishedLayerChanged(layer) + } + } + + tracker.SetPaused(paused) + tracker.Start() + return tracker +} + +func (s *StreamTrackerManager) RemoveTracker(layer int32) { + s.lock.Lock() + tracker := s.trackers[layer] + s.trackers[layer] = nil + s.lock.Unlock() + + if tracker != nil { + tracker.Stop() + } +} + +func (s *StreamTrackerManager) RemoveAllTrackers() { + s.lock.Lock() + trackers := s.trackers + for layer := range s.trackers { + s.trackers[layer] = nil + } + s.availableLayers = make([]int32, 0) + s.maxExpectedLayerFromTrackInfoLocked() + s.paused = false + ddTracker := s.ddTracker + s.ddTracker = nil + s.lock.Unlock() + + for _, tracker := range trackers { + if tracker != nil { + tracker.Stop() + } + } + if ddTracker != nil { + ddTracker.Stop() + } +} + +func (s *StreamTrackerManager) GetTracker(layer int32) streamtracker.StreamTrackerWorker { + s.lock.RLock() + defer s.lock.RUnlock() + + if layer < 0 || int(layer) >= len(s.trackers) { + s.logger.Errorw("unexpected layer", nil, "layer", layer) + return nil + } + return s.trackers[layer] +} + +func (s *StreamTrackerManager) SetPaused(paused bool) { + s.lock.Lock() + s.paused = paused + trackers := s.trackers + s.lock.Unlock() + + for _, tracker := range trackers { + if tracker != nil { + tracker.SetPaused(paused) + } + } +} + +func (s *StreamTrackerManager) IsPaused() bool { + s.lock.RLock() + defer s.lock.RUnlock() + + return s.paused +} + +func (s *StreamTrackerManager) UpdateTrackInfo(ti *livekit.TrackInfo) { + s.trackInfo.Store(utils.CloneProto(ti)) + s.maxExpectedLayerFromTrackInfo() +} + +func (s *StreamTrackerManager) SetMaxExpectedSpatialLayer(layer int32) int32 { + s.lock.Lock() + prev := s.maxExpectedLayer + if layer <= s.maxExpectedLayer { + // some higher layer(s) expected to stop, nothing else to do + s.maxExpectedLayer = layer + s.lock.Unlock() + return prev + } + + // + // Some higher layer is expected to start. + // If the layer was not detected as stopped (i.e. it is still in available layers), + // resetting tracker will declare layer available afresh. That's fine as it will be + // a no-op in available layers handling. + // + var trackersToReset []streamtracker.StreamTrackerWorker + for l := s.maxExpectedLayer + 1; l <= layer; l++ { + if s.trackers[l] != nil { + trackersToReset = append(trackersToReset, s.trackers[l]) + } + } + s.maxExpectedLayer = layer + s.lock.Unlock() + + for _, tracker := range trackersToReset { + tracker.Reset() + } + + return prev +} + +func (s *StreamTrackerManager) DistanceToDesired() float64 { + s.lock.RLock() + defer s.lock.RUnlock() + + if s.paused || s.maxExpectedLayer < 0 || s.maxTemporalLayerSeen < 0 { + return 0 + } + + al, brs := s.getLayeredBitrateLocked() + + maxLayer := buffer.InvalidLayer +done: + for s := int32(len(brs)) - 1; s >= 0; s-- { + for t := int32(len(brs[0])) - 1; t >= 0; t-- { + if brs[s][t] != 0 { + maxLayer = buffer.VideoLayer{ + Spatial: s, + Temporal: t, + } + break done + } + } + } + + // before bit rate measurement is available, stream tracker could declare layer seen, account for that + for _, layer := range al { + if layer > maxLayer.Spatial { + maxLayer.Spatial = layer + maxLayer.Temporal = s.maxTemporalLayerSeen // till bit rate measurement is available, assume max seen as temporal + } + } + + adjustedMaxLayers := maxLayer + if !maxLayer.IsValid() { + adjustedMaxLayers = buffer.VideoLayer{Spatial: 0, Temporal: 0} + } + + distance := + ((s.maxExpectedLayer - adjustedMaxLayers.Spatial) * (s.maxTemporalLayerSeen + 1)) + + (s.maxTemporalLayerSeen - adjustedMaxLayers.Temporal) + if !maxLayer.IsValid() { + distance++ + } + + return float64(distance) / float64(s.maxTemporalLayerSeen+1) +} + +func (s *StreamTrackerManager) GetMaxPublishedLayer() int32 { + s.lock.RLock() + defer s.lock.RUnlock() + + return s.maxPublishedLayer +} + +func (s *StreamTrackerManager) GetLayeredBitrate() ([]int32, Bitrates) { + s.lock.RLock() + defer s.lock.RUnlock() + + return s.getLayeredBitrateLocked() +} + +func (s *StreamTrackerManager) getLayeredBitrateLocked() ([]int32, Bitrates) { + var br Bitrates + + for i, tracker := range s.trackers { + if tracker != nil { + tls := make([]int64, buffer.DefaultMaxLayerTemporal+1) + if slices.Contains(s.availableLayers, int32(i)) { + tls = tracker.BitrateTemporalCumulative() + } + + for j := 0; j < len(br[i]); j++ { + br[i][j] = tls[j] + } + } + } + + // accumulate bitrates for SVC streams without dependency descriptor + if s.videoLayerMode == livekit.VideoLayer_MULTIPLE_SPATIAL_LAYERS_PER_STREAM && s.ddTracker == nil { + for i := len(br) - 1; i >= 1; i-- { + for j := len(br[i]) - 1; j >= 0; j-- { + if br[i][j] != 0 { + for k := i - 1; k >= 0; k-- { + br[i][j] += br[k][j] + } + } + } + } + } + + availableLayers := make([]int32, len(s.availableLayers)) + copy(availableLayers, s.availableLayers) + + return availableLayers, br +} + +func (s *StreamTrackerManager) addAvailableLayer(layer int32) { + s.lock.Lock() + hasLayer := false + for _, l := range s.availableLayers { + if l == layer { + hasLayer = true + break + } + } + if hasLayer { + s.lock.Unlock() + return + } + + s.availableLayers = append(s.availableLayers, layer) + sort.Slice(s.availableLayers, func(i, j int) bool { return s.availableLayers[i] < s.availableLayers[j] }) + + // check if new layer is the max layer + isMaxLayerChange := s.availableLayers[len(s.availableLayers)-1] == layer + + s.logger.Debugw( + "available layers changed - layer seen", + "added", layer, + "availableLayers", s.availableLayers, + ) + s.lock.Unlock() + + if listener := s.getListener(); listener != nil { + listener.OnAvailableLayersChanged() + + if isMaxLayerChange { + listener.OnMaxAvailableLayerChanged(layer) + } + } +} + +func (s *StreamTrackerManager) removeAvailableLayer(layer int32) { + s.lock.Lock() + prevMaxLayer := buffer.InvalidLayerSpatial + if len(s.availableLayers) > 0 { + prevMaxLayer = s.availableLayers[len(s.availableLayers)-1] + } + + newLayers := make([]int32, 0, buffer.DefaultMaxLayerSpatial+1) + for _, l := range s.availableLayers { + if l != layer { + newLayers = append(newLayers, l) + } + } + sort.Slice(newLayers, func(i, j int) bool { return newLayers[i] < newLayers[j] }) + s.availableLayers = newLayers + + s.logger.Debugw( + "available layers changed - layer gone", + "removed", layer, + "availableLayers", newLayers, + ) + + curMaxLayer := buffer.InvalidLayerSpatial + if len(s.availableLayers) > 0 { + curMaxLayer = s.availableLayers[len(s.availableLayers)-1] + } + s.lock.Unlock() + + // need to immediately switch off unavailable layers + if listener := s.getListener(); listener != nil { + listener.OnAvailableLayersChanged() + + // if maxLayer was removed, send the new maxLayer + if curMaxLayer != prevMaxLayer { + listener.OnMaxAvailableLayerChanged(curMaxLayer) + } + } +} + +func (s *StreamTrackerManager) maxExpectedLayerFromTrackInfo() { + s.lock.Lock() + defer s.lock.Unlock() + + s.maxExpectedLayerFromTrackInfoLocked() +} + +func (s *StreamTrackerManager) maxExpectedLayerFromTrackInfoLocked() { + s.maxExpectedLayer = buffer.InvalidLayerSpatial + ti := s.trackInfo.Load() + if ti != nil { + for _, layer := range buffer.GetVideoLayersForMimeType(s.mimeType, ti) { + if layer.SpatialLayer > s.maxExpectedLayer { + s.maxExpectedLayer = layer.SpatialLayer + } + } + } +} + +func (s *StreamTrackerManager) GetMaxTemporalLayerSeen() int32 { + s.lock.RLock() + defer s.lock.RUnlock() + + return s.maxTemporalLayerSeen +} + +func (s *StreamTrackerManager) updateMaxTemporalLayerSeen(brs Bitrates) { + maxTemporalLayerSeen := buffer.InvalidLayerTemporal +done: + for t := int32(len(brs[0])) - 1; t >= 0; t-- { + for s := int32(len(brs)) - 1; s >= 0; s-- { + if brs[s][t] != 0 { + maxTemporalLayerSeen = t + break done + } + } + } + + s.lock.Lock() + if maxTemporalLayerSeen <= s.maxTemporalLayerSeen { + s.lock.Unlock() + return + } + + s.maxTemporalLayerSeen = maxTemporalLayerSeen + s.lock.Unlock() + + if listener := s.getListener(); listener != nil { + listener.OnMaxTemporalLayerSeenChanged(maxTemporalLayerSeen) + } +} + +func (s *StreamTrackerManager) bitrateReporter() { + ticker := time.NewTicker(time.Second) + defer ticker.Stop() + + for { + select { + case <-s.closed.Watch(): + return + + case <-ticker.C: + al, brs := s.GetLayeredBitrate() + s.updateMaxTemporalLayerSeen(brs) + + if listener := s.getListener(); listener != nil { + listener.OnBitrateReport(al, brs) + } + } + } +} diff --git a/livekit/pkg/sfu/testutils/data.go b/livekit/pkg/sfu/testutils/data.go new file mode 100644 index 0000000..c283549 --- /dev/null +++ b/livekit/pkg/sfu/testutils/data.go @@ -0,0 +1,108 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testutils + +import ( + "time" + + "github.com/pion/rtp" + "github.com/pion/webrtc/v4" + + "github.com/livekit/livekit-server/pkg/sfu/buffer" +) + +// ----------------------------------------------------------- + +type TestExtPacketParams struct { + Marker bool + IsKeyFrame bool + PayloadType uint8 + SequenceNumber uint16 + SNCycles int + Timestamp uint32 + TSCycles int + SSRC uint32 + PayloadSize int + PaddingSize byte + ArrivalTime time.Time + VideoLayer buffer.VideoLayer + IsOutOfOrder bool +} + +// ----------------------------------------------------------- + +func GetTestExtPacket(params *TestExtPacketParams) (*buffer.ExtPacket, error) { + packet := rtp.Packet{ + Header: rtp.Header{ + Version: 2, + Padding: params.PaddingSize != 0, + Marker: params.Marker, + PayloadType: params.PayloadType, + SequenceNumber: params.SequenceNumber, + Timestamp: params.Timestamp, + SSRC: params.SSRC, + }, + Payload: make([]byte, params.PayloadSize), + PaddingSize: params.PaddingSize, + } + + raw, err := packet.Marshal() + if err != nil { + return nil, err + } + + ep := &buffer.ExtPacket{ + VideoLayer: params.VideoLayer, + ExtSequenceNumber: uint64(params.SNCycles<<16) + uint64(params.SequenceNumber), + ExtTimestamp: uint64(params.TSCycles<<32) + uint64(params.Timestamp), + Arrival: params.ArrivalTime.UnixNano(), + Packet: &packet, + IsKeyFrame: params.IsKeyFrame, + RawPacket: raw, + IsOutOfOrder: params.IsOutOfOrder, + } + + return ep, nil +} + +// -------------------------------------- + +func GetTestExtPacketVP8(params *TestExtPacketParams, vp8 *buffer.VP8) (*buffer.ExtPacket, error) { + ep, err := GetTestExtPacket(params) + if err != nil { + return nil, err + } + + ep.IsKeyFrame = vp8.IsKeyFrame + ep.Payload = *vp8 + if ep.DependencyDescriptor == nil { + ep.Temporal = int32(vp8.TID) + } + return ep, nil +} + +// -------------------------------------- + +var TestVP8Codec = webrtc.RTPCodecCapability{ + MimeType: "video/vp8", + ClockRate: 90000, +} + +var TestOpusCodec = webrtc.RTPCodecCapability{ + MimeType: "audio/opus", + ClockRate: 48000, +} + +// -------------------------------------- diff --git a/livekit/pkg/sfu/track_remote.go b/livekit/pkg/sfu/track_remote.go new file mode 100644 index 0000000..70b08a9 --- /dev/null +++ b/livekit/pkg/sfu/track_remote.go @@ -0,0 +1,52 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sfu + +import "github.com/pion/webrtc/v4" + +type TrackRemote interface { + ID() string + RID() string + Msid() string + SSRC() webrtc.SSRC + RtxSSRC() webrtc.SSRC + StreamID() string + Kind() webrtc.RTPCodecType + Codec() webrtc.RTPCodecParameters + RTCTrack() *webrtc.TrackRemote +} + +// TrackRemoteFromSdp represents a remote track that could be created by the sdp. +// It is a wrapper around the webrtc.TrackRemote and return the Codec from sdp +// before the first RTP packet is received. +type TrackRemoteFromSdp struct { + *webrtc.TrackRemote + sdpCodec webrtc.RTPCodecParameters +} + +func NewTrackRemoteFromSdp(track *webrtc.TrackRemote, codec webrtc.RTPCodecParameters) *TrackRemoteFromSdp { + return &TrackRemoteFromSdp{ + TrackRemote: track, + sdpCodec: codec, + } +} + +func (t *TrackRemoteFromSdp) Codec() webrtc.RTPCodecParameters { + return t.sdpCodec +} + +func (t *TrackRemoteFromSdp) RTCTrack() *webrtc.TrackRemote { + return t.TrackRemote +} diff --git a/livekit/pkg/sfu/utils/debounce.go b/livekit/pkg/sfu/utils/debounce.go new file mode 100644 index 0000000..5fab42b --- /dev/null +++ b/livekit/pkg/sfu/utils/debounce.go @@ -0,0 +1,34 @@ +package utils + +import ( + "sync" + "time" +) + +func NewDebouncer(after time.Duration) *Debouncer { + return &Debouncer{ + after: after, + } +} + +type Debouncer struct { + mu sync.Mutex + after time.Duration + timer *time.Timer +} + +func (d *Debouncer) Add(f func()) { + d.mu.Lock() + defer d.mu.Unlock() + + if d.timer != nil { + d.timer.Stop() + } + d.timer = time.AfterFunc(d.after, f) +} + +func (d *Debouncer) SetDuration(after time.Duration) { + d.mu.Lock() + d.after = after + d.mu.Unlock() +} diff --git a/livekit/pkg/sfu/utils/downtrackspreader.go b/livekit/pkg/sfu/utils/downtrackspreader.go new file mode 100644 index 0000000..996889c --- /dev/null +++ b/livekit/pkg/sfu/utils/downtrackspreader.go @@ -0,0 +1,121 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "sync" + + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/utils" +) + +type sender interface { + SubscriberID() livekit.ParticipantID +} + +type DownTrackSpreaderParams struct { + Threshold int + Logger logger.Logger +} + +type DownTrackSpreader[T sender] struct { + params DownTrackSpreaderParams + + downTrackMu sync.RWMutex + downTracks map[livekit.ParticipantID]T + downTracksShadow []T +} + +func NewDownTrackSpreader[T sender](params DownTrackSpreaderParams) *DownTrackSpreader[T] { + d := &DownTrackSpreader[T]{ + params: params, + downTracks: make(map[livekit.ParticipantID]T), + } + + return d +} + +func (d *DownTrackSpreader[T]) GetDownTracks() []T { + d.downTrackMu.RLock() + defer d.downTrackMu.RUnlock() + return d.downTracksShadow +} + +func (d *DownTrackSpreader[T]) ResetAndGetDownTracks() []T { + d.downTrackMu.Lock() + defer d.downTrackMu.Unlock() + + downTracks := d.downTracksShadow + + d.downTracks = make(map[livekit.ParticipantID]T) + d.downTracksShadow = nil + + return downTracks +} + +func (d *DownTrackSpreader[T]) Store(sender T) { + d.downTrackMu.Lock() + defer d.downTrackMu.Unlock() + + d.downTracks[sender.SubscriberID()] = sender + d.shadowDownTracks() +} + +func (d *DownTrackSpreader[T]) Free(subscriberID livekit.ParticipantID) { + d.downTrackMu.Lock() + defer d.downTrackMu.Unlock() + + delete(d.downTracks, subscriberID) + d.shadowDownTracks() +} + +func (d *DownTrackSpreader[T]) HasDownTrack(subscriberID livekit.ParticipantID) bool { + d.downTrackMu.RLock() + defer d.downTrackMu.RUnlock() + + _, ok := d.downTracks[subscriberID] + return ok +} + +func (d *DownTrackSpreader[T]) Broadcast(writer func(T)) { + downTracks := d.GetDownTracks() + if len(downTracks) == 0 { + return + } + + threshold := uint64(d.params.Threshold) + if threshold == 0 { + threshold = 1000000 + } + + // 100µs is enough to amortize the overhead and provide sufficient load balancing. + // WriteRTP takes about 50µs on average, so we write to 2 down tracks per loop. + step := uint64(2) + utils.ParallelExec(downTracks, threshold, step, writer) +} + +func (d *DownTrackSpreader[T]) DownTrackCount() int { + d.downTrackMu.RLock() + defer d.downTrackMu.RUnlock() + return len(d.downTracksShadow) +} + +func (d *DownTrackSpreader[T]) shadowDownTracks() { + d.downTracksShadow = make([]T, 0, len(d.downTracks)) + for _, dt := range d.downTracks { + d.downTracksShadow = append(d.downTracksShadow, dt) + } +} diff --git a/livekit/pkg/sfu/utils/helpers.go b/livekit/pkg/sfu/utils/helpers.go new file mode 100644 index 0000000..bfd4019 --- /dev/null +++ b/livekit/pkg/sfu/utils/helpers.go @@ -0,0 +1,97 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "errors" + "fmt" + + "github.com/pion/interceptor" + "github.com/pion/rtp" + "github.com/pion/webrtc/v4" + + "github.com/livekit/livekit-server/pkg/sfu/mime" + "github.com/livekit/protocol/livekit" +) + +// Do a fuzzy find for a codec in the list of codecs +// Used for lookup up a codec in an existing list to find a match +func CodecParametersFuzzySearch(needle webrtc.RTPCodecParameters, haystack []webrtc.RTPCodecParameters) (webrtc.RTPCodecParameters, error) { + // First attempt to match on MimeType + SDPFmtpLine + for _, c := range haystack { + if mime.IsMimeTypeStringEqual(c.RTPCodecCapability.MimeType, needle.RTPCodecCapability.MimeType) && + c.RTPCodecCapability.SDPFmtpLine == needle.RTPCodecCapability.SDPFmtpLine { + return c, nil + } + } + + // Fallback to just MimeType + for _, c := range haystack { + if mime.IsMimeTypeStringEqual(c.RTPCodecCapability.MimeType, needle.RTPCodecCapability.MimeType) { + return c, nil + } + } + + return webrtc.RTPCodecParameters{}, webrtc.ErrCodecNotFound +} + +// Given a CodecParameters find the RTX CodecParameters if one exists +func FindRTXPayloadType(needle webrtc.PayloadType, haystack []webrtc.RTPCodecParameters) webrtc.PayloadType { + aptStr := fmt.Sprintf("apt=%d", needle) + for _, c := range haystack { + if aptStr == c.SDPFmtpLine { + return c.PayloadType + } + } + + return webrtc.PayloadType(0) +} + +// GetHeaderExtensionID returns the ID of a header extension, or 0 if not found +func GetHeaderExtensionID(extensions []interceptor.RTPHeaderExtension, extension webrtc.RTPHeaderExtensionCapability) int { + for _, h := range extensions { + if extension.URI == h.URI { + return h.ID + } + } + return 0 +} + +var ( + ErrInvalidRTPVersion = errors.New("invalid RTP version") + ErrRTPPayloadTypeMismatch = errors.New("RTP payload type mismatch") + ErrRTPSSRCMismatch = errors.New("RTP SSRC mismatch") +) + +// ValidateRTPPacket checks for a valid RTP packet and returns an error if fields are incorrect +func ValidateRTPPacket(pkt *rtp.Packet, expectedPayloadType uint8, expectedSSRC uint32) error { + if pkt.Version != 2 { + return fmt.Errorf("%w, expected: 2, actual: %d", ErrInvalidRTPVersion, pkt.Version) + } + + if expectedPayloadType != 0 && pkt.PayloadType != expectedPayloadType { + return fmt.Errorf("%w, expected: %d, actual: %d", ErrRTPPayloadTypeMismatch, expectedPayloadType, pkt.PayloadType) + } + + if expectedSSRC != 0 && pkt.SSRC != expectedSSRC { + return fmt.Errorf("%w, expected: %d, actual: %d", ErrRTPSSRCMismatch, expectedSSRC, pkt.SSRC) + } + + return nil +} + +func IsSimulcastMode(m livekit.VideoLayer_Mode) bool { + return m == livekit.VideoLayer_ONE_SPATIAL_LAYER_PER_STREAM || m == livekit.VideoLayer_ONE_SPATIAL_LAYER_PER_STREAM_INCOMPLETE_RTCP_SR +} diff --git a/livekit/pkg/sfu/utils/rangemap.go b/livekit/pkg/sfu/utils/rangemap.go new file mode 100644 index 0000000..8fb36d5 --- /dev/null +++ b/livekit/pkg/sfu/utils/rangemap.go @@ -0,0 +1,200 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "errors" + "fmt" + "math" + "unsafe" + + "go.uber.org/zap/zapcore" +) + +const ( + minRanges = 1 +) + +var ( + errReversedOrder = errors.New("end <= start") + errKeyNotFound = errors.New("key not found") + errKeyTooOld = errors.New("key too old") + errKeyExcluded = errors.New("key excluded") +) + +type rangeType interface { + uint32 | uint64 +} + +type valueType interface { + uint32 | uint64 +} + +// --------------------------------------------------- + +type rangeVal[RT rangeType, VT valueType] struct { + start RT + end RT + value VT +} + +func (r rangeVal[RT, VT]) MarshalLogObject(e zapcore.ObjectEncoder) error { + e.AddUint64("start", uint64(r.start)) + e.AddUint64("end", uint64(r.end)) + e.AddUint64("value", uint64(r.value)) + return nil +} + +// --------------------------------------------------- + +type RangeMap[RT rangeType, VT valueType] struct { + halfRange RT + + size int + ranges []rangeVal[RT, VT] +} + +func NewRangeMap[RT rangeType, VT valueType](size int) *RangeMap[RT, VT] { + var t RT + r := &RangeMap[RT, VT]{ + halfRange: 1 << ((unsafe.Sizeof(t) * 8) - 1), + size: int(math.Max(float64(size), float64(minRanges))), + } + r.initRanges(0, 0) + return r +} + +func (r *RangeMap[RT, VT]) MarshalLogObject(e zapcore.ObjectEncoder) error { + e.AddInt("numRanges", len(r.ranges)) + + // just the last 10 ranges max + startIdx := max(len(r.ranges)-10, 0) + for i := startIdx; i < len(r.ranges); i++ { + e.AddObject(fmt.Sprintf("range[%d]", i), r.ranges[i]) + } + + return nil +} + +func (r *RangeMap[RT, VT]) ClearAndResetValue(start RT, val VT) { + r.initRanges(start, val) +} + +func (r *RangeMap[RT, VT]) DecValue(end RT, dec VT) { + lr := &r.ranges[len(r.ranges)-1] + if lr.start > end { + // modify existing value if end is in open range + lr.value -= dec + return + } + + // close open range + lr.end = end + + // start a new open one with decremented value + r.ranges = append(r.ranges, rangeVal[RT, VT]{ + start: end + 1, + end: 0, + value: lr.value - dec, + }) + r.prune() +} + +func (r *RangeMap[RT, VT]) initRanges(start RT, val VT) { + r.ranges = []rangeVal[RT, VT]{ + { + start: start, + end: 0, + value: val, + }, + } +} + +func (r *RangeMap[RT, VT]) ExcludeRange(startInclusive RT, endExclusive RT) error { + if endExclusive == startInclusive || endExclusive-startInclusive > r.halfRange { + return fmt.Errorf("%w, start %d, end %d", errReversedOrder, startInclusive, endExclusive) + } + + lr := &r.ranges[len(r.ranges)-1] + if lr.start > startInclusive { + // start of open range is after start of exclusion range, cannot close the open range + return fmt.Errorf("%w, existingStart %d, newStart %d", errReversedOrder, lr.start, startInclusive) + } + + newValue := lr.value + VT(endExclusive-startInclusive) + + // if start of exclusion range matches start of open range, move the open range + if lr.start == startInclusive { + lr.start = endExclusive + lr.value = newValue + return nil + } + + // close previous range + lr.end = startInclusive - 1 + + // start new open one after given exclusion range + r.ranges = append(r.ranges, rangeVal[RT, VT]{ + start: endExclusive, + end: 0, + value: newValue, + }) + + r.prune() + return nil +} + +func (r *RangeMap[RT, VT]) GetValue(key RT) (VT, error) { + numRanges := len(r.ranges) + if numRanges != 0 { + if key >= r.ranges[numRanges-1].start { + // in the open range + return r.ranges[numRanges-1].value, nil + } + + if key < r.ranges[0].start { + // too old + return 0, errKeyTooOld + } + } + + for idx := numRanges - 1; idx >= 0; idx-- { + rv := &r.ranges[idx] + if idx != numRanges-1 { + // open range checked above + if key-rv.start < r.halfRange && rv.end-key < r.halfRange { + return rv.value, nil + } + } + + if idx > 0 { + rvPrev := &r.ranges[idx-1] + beforeDiff := key - rvPrev.end + afterDiff := rv.start - key + if beforeDiff > 0 && beforeDiff < r.halfRange && afterDiff > 0 && afterDiff < r.halfRange { + // in excluded range + return 0, errKeyExcluded + } + } + } + + return 0, errKeyNotFound +} + +func (r *RangeMap[RT, VT]) prune() { + if len(r.ranges) > r.size+1 { // +1 to accommodate the open range + r.ranges = r.ranges[len(r.ranges)-r.size-1:] + } +} diff --git a/livekit/pkg/sfu/utils/rangemap_test.go b/livekit/pkg/sfu/utils/rangemap_test.go new file mode 100644 index 0000000..657481f --- /dev/null +++ b/livekit/pkg/sfu/utils/rangemap_test.go @@ -0,0 +1,361 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestRangeMapUint32(t *testing.T) { + r := NewRangeMap[uint32, uint32](2) + + // getting value for any key should be 0 default + value, err := r.GetValue(33333) + require.NoError(t, err) + require.Equal(t, uint32(0), value) + + expectedRanges := []rangeVal[uint32, uint32]{ + { + start: 0, + end: 0, + value: 0, + }, + } + require.Equal(t, expectedRanges, r.ranges) + + // add an exclusion, should create a new range + err = r.ExcludeRange(10, 11) + require.NoError(t, err) + + expectedRanges = []rangeVal[uint32, uint32]{ + { + start: 0, + end: 9, + value: 0, + }, + { + start: 11, + end: 0, + value: 1, + }, + } + require.Equal(t, expectedRanges, r.ranges) + + // getting value in old range should return 0 + value, err = r.GetValue(6) + require.NoError(t, err) + require.Equal(t, uint32(0), value) + + // newer should return 1 + value, err = r.GetValue(11) + require.NoError(t, err) + require.Equal(t, uint32(1), value) + + // excluded range should return error + value, err = r.GetValue(10) + require.ErrorIs(t, err, errKeyExcluded) + + // out-of-order exclusion should return error + err = r.ExcludeRange(9, 10) + require.ErrorIs(t, err, errReversedOrder) + + // flipped exclusion should return error + err = r.ExcludeRange(12, 11) + require.ErrorIs(t, err, errReversedOrder) + err = r.ExcludeRange(11, 11) + require.ErrorIs(t, err, errReversedOrder) + + // add adjacent exclusion range of length = 1 + err = r.ExcludeRange(11, 12) + require.NoError(t, err) + + expectedRanges = []rangeVal[uint32, uint32]{ + { + start: 0, + end: 9, + value: 0, + }, + { + start: 12, + end: 0, + value: 2, + }, + } + require.Equal(t, expectedRanges, r.ranges) + + // excluded range should return error, now is excluded because exclusion range could be extended + value, err = r.GetValue(11) + require.ErrorIs(t, err, errKeyExcluded) + + // getting value in old range should return 0 + value, err = r.GetValue(6) + require.NoError(t, err) + + // newer should return 2 + value, err = r.GetValue(12) + require.NoError(t, err) + require.Equal(t, uint32(2), value) + + // add adjacent exclusion range of length = 10 + err = r.ExcludeRange(12, 22) + require.NoError(t, err) + + expectedRanges = []rangeVal[uint32, uint32]{ + { + start: 0, + end: 9, + value: 0, + }, + { + start: 22, + end: 0, + value: 12, + }, + } + require.Equal(t, expectedRanges, r.ranges) + + // excluded range should return error, now is excluded because exclusion range could be extended + value, err = r.GetValue(15) + require.ErrorIs(t, err, errKeyExcluded) + + // newer should return 12 + value, err = r.GetValue(25) + require.NoError(t, err) + require.Equal(t, uint32(12), value) + + // add a disjoint exclusion of length = 4 + err = r.ExcludeRange(26, 30) + require.NoError(t, err) + + expectedRanges = []rangeVal[uint32, uint32]{ + { + start: 0, + end: 9, + value: 0, + }, + { + start: 22, + end: 25, + value: 12, + }, + { + start: 30, + end: 0, + value: 16, + }, + } + require.Equal(t, expectedRanges, r.ranges) + + // get a value from newly closed range [22, 25] + value, err = r.GetValue(23) + require.NoError(t, err) + require.Equal(t, uint32(12), value) + + // add a disjoint exclusion of length = 1 + err = r.ExcludeRange(50, 51) + require.NoError(t, err) + + // previously first range would have been pruned due to size limitations + expectedRanges = []rangeVal[uint32, uint32]{ + { + start: 22, + end: 25, + value: 12, + }, + { + start: 30, + end: 49, + value: 16, + }, + { + start: 51, + end: 0, + value: 17, + }, + } + require.Equal(t, expectedRanges, r.ranges) + + // excluded range should return error + value, err = r.GetValue(50) + require.ErrorIs(t, err, errKeyExcluded) + value, err = r.GetValue(28) + require.ErrorIs(t, err, errKeyExcluded) + value, err = r.GetValue(17) + require.ErrorIs(t, err, errKeyTooOld) + + // previously valid, but aged out key should return error + value, err = r.GetValue(5) + require.ErrorIs(t, err, errKeyTooOld) + + // valid range access should return values + value, err = r.GetValue(24) + require.NoError(t, err) + require.Equal(t, uint32(12), value) + + value, err = r.GetValue(34) + require.NoError(t, err) + require.Equal(t, uint32(16), value) + + value, err = r.GetValue(49) + require.NoError(t, err) + require.Equal(t, uint32(16), value) + + value, err = r.GetValue(55555555) + require.NoError(t, err) + require.Equal(t, uint32(17), value) + + // reset + r.ClearAndResetValue(24, 23) + expectedRanges = []rangeVal[uint32, uint32]{ + { + start: 24, + end: 0, + value: 23, + }, + } + require.Equal(t, expectedRanges, r.ranges) + + value, err = r.GetValue(55555555) + require.NoError(t, err) + require.Equal(t, uint32(23), value) + + // decrement value and ensure that any key after start in ClearAndResetValue above returns that value + // (as given end is higher than open range start, open range should be closed and a new range added) + r.DecValue(34, 12) + + expectedRanges = []rangeVal[uint32, uint32]{ + { + start: 24, + end: 34, + value: 23, + }, + { + start: 35, + end: 0, + value: 11, + }, + } + require.Equal(t, expectedRanges, r.ranges) + + value, err = r.GetValue(55555555) + require.NoError(t, err) + require.Equal(t, uint32(11), value) + + // add an exclusion and then decrement value + err = r.ExcludeRange(40, 45) + require.NoError(t, err) + + expectedRanges = []rangeVal[uint32, uint32]{ + { + start: 24, + end: 34, + value: 23, + }, + { + start: 35, + end: 39, + value: 11, + }, + { + start: 45, + end: 0, + value: 16, + }, + } + require.Equal(t, expectedRanges, r.ranges) + + // before first range access + value, err = r.GetValue(5) + require.ErrorIs(t, err, errKeyTooOld) + + // first range access + value, err = r.GetValue(25) + require.NoError(t, err) + require.Equal(t, uint32(23), value) + + // second range access + value, err = r.GetValue(35) + require.NoError(t, err) + require.Equal(t, uint32(11), value) + + // open range access + value, err = r.GetValue(55555555) + require.NoError(t, err) + require.Equal(t, uint32(16), value) + + r.DecValue(66, 6) + + expectedRanges = []rangeVal[uint32, uint32]{ + { + start: 35, + end: 39, + value: 11, + }, + { + start: 45, + end: 66, + value: 16, + }, + { + start: 67, + end: 0, + value: 10, + }, + } + require.Equal(t, expectedRanges, r.ranges) + + // aged out range access + value, err = r.GetValue(25) + require.ErrorIs(t, err, errKeyTooOld) + + // access closed range before decrementing value + value, err = r.GetValue(66) + require.NoError(t, err) + require.Equal(t, uint32(16), value) + + // open range access + value, err = r.GetValue(67) + require.NoError(t, err) + require.Equal(t, uint32(10), value) + + // decrement with old end and check that open range gets decremented + r.DecValue(66, 6) + + expectedRanges = []rangeVal[uint32, uint32]{ + { + start: 35, + end: 39, + value: 11, + }, + { + start: 45, + end: 66, + value: 16, + }, + { + start: 67, + end: 0, + value: 4, + }, + } + require.Equal(t, expectedRanges, r.ranges) + + // open range access should get decremented value + value, err = r.GetValue(67) + require.NoError(t, err) + require.Equal(t, uint32(4), value) +} diff --git a/livekit/pkg/sfu/videolayerselector/base.go b/livekit/pkg/sfu/videolayerselector/base.go new file mode 100644 index 0000000..181e5d3 --- /dev/null +++ b/livekit/pkg/sfu/videolayerselector/base.go @@ -0,0 +1,172 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package videolayerselector + +import ( + "github.com/livekit/livekit-server/pkg/sfu/buffer" + "github.com/livekit/livekit-server/pkg/sfu/videolayerselector/temporallayerselector" + "github.com/livekit/protocol/logger" +) + +type Base struct { + logger logger.Logger + + tls temporallayerselector.TemporalLayerSelector + + maxLayer buffer.VideoLayer + maxSeenLayer buffer.VideoLayer + + targetLayer buffer.VideoLayer + previousTargetLayer buffer.VideoLayer + + requestSpatial int32 + + currentLayer buffer.VideoLayer + previousLayer buffer.VideoLayer +} + +func NewBase(logger logger.Logger) *Base { + return &Base{ + logger: logger, + maxLayer: buffer.InvalidLayer, + maxSeenLayer: buffer.InvalidLayer, + targetLayer: buffer.InvalidLayer, // start off with nothing, let streamallocator/opportunistic forwarder set the target + previousTargetLayer: buffer.InvalidLayer, + requestSpatial: buffer.InvalidLayerSpatial, + currentLayer: buffer.InvalidLayer, + previousLayer: buffer.InvalidLayer, + } +} + +func (b *Base) getBase() *Base { + return b +} + +func (b *Base) getLogger() logger.Logger { + return b.logger +} + +func (b *Base) IsOvershootOkay() bool { + return false +} + +func (b *Base) SetTemporalLayerSelector(tls temporallayerselector.TemporalLayerSelector) { + b.tls = tls +} + +func (b *Base) SetMax(maxLayer buffer.VideoLayer) { + b.maxLayer = maxLayer +} + +func (b *Base) SetMaxSpatial(layer int32) { + b.maxLayer.Spatial = layer +} + +func (b *Base) SetMaxTemporal(layer int32) { + b.maxLayer.Temporal = layer +} + +func (b *Base) GetMax() buffer.VideoLayer { + return b.maxLayer +} + +func (b *Base) SetTarget(targetLayer buffer.VideoLayer) { + b.previousTargetLayer = targetLayer + b.targetLayer = targetLayer +} + +func (b *Base) GetTarget() buffer.VideoLayer { + return b.targetLayer +} + +func (b *Base) SetRequestSpatial(layer int32) { + b.requestSpatial = layer +} + +func (b *Base) GetRequestSpatial() int32 { + return b.requestSpatial +} + +func (b *Base) CheckSync() (locked bool, layer int32) { + layer = b.GetRequestSpatial() + locked = layer == b.GetCurrent().Spatial + return +} + +func (b *Base) SetMaxSeen(maxSeenLayer buffer.VideoLayer) { + b.maxSeenLayer = maxSeenLayer +} + +func (b *Base) SetMaxSeenSpatial(layer int32) { + b.maxSeenLayer.Spatial = layer +} + +func (b *Base) SetMaxSeenTemporal(layer int32) { + b.maxSeenLayer.Temporal = layer +} + +func (b *Base) GetMaxSeen() buffer.VideoLayer { + return b.maxSeenLayer +} + +func (b *Base) SetCurrent(currentLayer buffer.VideoLayer) { + b.currentLayer = currentLayer +} + +func (b *Base) GetCurrent() buffer.VideoLayer { + return b.currentLayer +} + +func (b *Base) Select(_extPkt *buffer.ExtPacket, _layer int32) (result VideoLayerSelectorResult) { + return +} + +func (b *Base) Rollback() { + b.logger.Debugw( + "rolling back", + "previous", b.previousLayer, + "current", b.currentLayer, + "previousTarget", b.previousTargetLayer, + "target", b.targetLayer, + "max", b.maxLayer, + "req", b.requestSpatial, + "maxSeen", b.maxSeenLayer, + ) + b.currentLayer = b.previousLayer + b.targetLayer = b.previousTargetLayer +} + +func (b *Base) SelectTemporal(extPkt *buffer.ExtPacket) int32 { + if b.tls != nil { + this, next := b.tls.Select(extPkt, b.currentLayer.Temporal, b.targetLayer.Temporal) + if next != b.currentLayer.Temporal { + previousLayer := b.currentLayer + b.currentLayer.Temporal = next + + b.logger.Debugw( + "updating temporal layer", + "previous", previousLayer, + "current", b.currentLayer, + "target", b.targetLayer, + "max", b.maxLayer, + "req", b.requestSpatial, + "maxSeen", b.maxSeenLayer, + ) + } + return this + } + + return b.currentLayer.Temporal +} diff --git a/livekit/pkg/sfu/videolayerselector/decodetarget.go b/livekit/pkg/sfu/videolayerselector/decodetarget.go new file mode 100644 index 0000000..70114f2 --- /dev/null +++ b/livekit/pkg/sfu/videolayerselector/decodetarget.go @@ -0,0 +1,70 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package videolayerselector + +import ( + "fmt" + + "github.com/livekit/livekit-server/pkg/sfu/buffer" + dd "github.com/livekit/livekit-server/pkg/sfu/rtpextension/dependencydescriptor" +) + +type DecodeTarget struct { + buffer.DependencyDescriptorDecodeTarget + chain *FrameChain + active bool +} + +type FrameDetectionResult struct { + TargetValid bool + DTI dd.DecodeTargetIndication +} + +func NewDecodeTarget(target buffer.DependencyDescriptorDecodeTarget, chain *FrameChain) *DecodeTarget { + return &DecodeTarget{ + DependencyDescriptorDecodeTarget: target, + chain: chain, + } +} + +func (dt *DecodeTarget) Valid() bool { + return dt.chain == nil || !dt.chain.Broken() +} + +func (dt *DecodeTarget) Active() bool { + return dt.active +} + +func (dt *DecodeTarget) UpdateActive(activeBitmask uint32) { + active := (activeBitmask & (1 << dt.Target)) != 0 + dt.active = active + if dt.chain != nil { + dt.chain.UpdateActive(active) + } +} + +func (dt *DecodeTarget) OnFrame(extFrameNum uint64, fd *dd.FrameDependencyTemplate) (FrameDetectionResult, error) { + result := FrameDetectionResult{} + if len(fd.DecodeTargetIndications) <= dt.Target { + return result, fmt.Errorf("mismatch target %d and len(DecodeTargetIndications) %d", dt.Target, len(fd.DecodeTargetIndications)) + } + + result.DTI = fd.DecodeTargetIndications[dt.Target] + // The encoder can choose not to use frame chain in theory, and we need to trace every required frame is decodable in this case. + // But we don't observe this in browser and it makes no sense to not use the chain with svc, so only use chain to detect decode target broken now, + // and always return decodable if it is not protect by chain. + result.TargetValid = dt.Valid() + return result, nil +} diff --git a/livekit/pkg/sfu/videolayerselector/dependencydescriptor.go b/livekit/pkg/sfu/videolayerselector/dependencydescriptor.go new file mode 100644 index 0000000..3e2e090 --- /dev/null +++ b/livekit/pkg/sfu/videolayerselector/dependencydescriptor.go @@ -0,0 +1,464 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package videolayerselector + +import ( + "fmt" + "runtime/debug" + "sync" + + "github.com/livekit/livekit-server/pkg/sfu/buffer" + dede "github.com/livekit/livekit-server/pkg/sfu/rtpextension/dependencydescriptor" + "github.com/livekit/protocol/logger" +) + +const ( + decisionCacheMaxElements = 256 + decisionCacheNackEntries = 80 +) + +type DependencyDescriptor struct { + *Base + + decisions *SelectorDecisionCache + + previousActiveDecodeTargetsBitmask *uint32 + activeDecodeTargetsBitmask *uint32 + structure *dede.FrameDependencyStructure + extKeyFrameNum uint64 + keyFrameValid bool + + chains []*FrameChain + + decodeTargetsLock sync.RWMutex + decodeTargets []*DecodeTarget + fnWrapper FrameNumberWrapper + + restartGeneration int +} + +func NewDependencyDescriptor(logger logger.Logger) *DependencyDescriptor { + return &DependencyDescriptor{ + Base: NewBase(logger), + decisions: NewSelectorDecisionCache(decisionCacheMaxElements, decisionCacheNackEntries), + fnWrapper: FrameNumberWrapper{logger: logger}, + } +} + +func NewDependencyDescriptorFromOther(vls VideoLayerSelector) *DependencyDescriptor { + return &DependencyDescriptor{ + Base: vls.getBase(), + decisions: NewSelectorDecisionCache(256, 80), + fnWrapper: FrameNumberWrapper{logger: vls.getLogger()}, + } +} + +func (d *DependencyDescriptor) IsOvershootOkay() bool { + return false +} + +func (d *DependencyDescriptor) Select(extPkt *buffer.ExtPacket, _layer int32) (result VideoLayerSelectorResult) { + // a packet is always relevant for the svc codec + if d.currentLayer.IsValid() { + result.IsRelevant = true + } + + ddwdt := extPkt.DependencyDescriptor + if ddwdt == nil { + // packet doesn't have dependency descriptor + // d.logger.Debugw(fmt.Sprintf("drop packet, no DD, incoming %v, sn: %d, isKeyFrame: %v", extPkt.VideoLayer, extPkt.Packet.SequenceNumber, extPkt.KeyFrame)) + return + } + + if ddwdt.RestartGeneration > d.restartGeneration { + d.logger.Debugw( + "stream restarted", + "packet", ddwdt.RestartGeneration, + "current", d.restartGeneration, + "structureKeyFrame", d.extKeyFrameNum, + "efn", ddwdt.ExtFrameNum, + "lastEfn", d.fnWrapper.LastOrigin(), + ) + d.restart(ddwdt.RestartGeneration) + } else if ddwdt.RestartGeneration < d.restartGeneration { + // must not happen + d.logger.Warnw("packet from old generation", nil, "packet", ddwdt.RestartGeneration, "current", d.restartGeneration) + } + + dd := ddwdt.Descriptor + + extFrameNum := ddwdt.ExtFrameNum + + fd := dd.FrameDependencies + incomingLayer := buffer.VideoLayer{ + Spatial: int32(fd.SpatialId), + Temporal: int32(fd.TemporalId), + } + + if !d.keyFrameValid && dd.AttachedStructure == nil { + // d.logger.Debugw(fmt.Sprintf("drop packet, no attached structure, incoming %v, sn: %d, isKeyFrame: %v", extPkt.VideoLayer, extPkt.Packet.SequenceNumber, extPkt.KeyFrame)) + return + } + + // early return if this frame is already forwarded or dropped + sd, err := d.decisions.GetDecision(extFrameNum) + if err != nil { + // do not mark as dropped as only error is an old frame + // d.logger.Debugw(fmt.Sprintf("drop packet on decision error, incoming %v, fn: %d/%d, sn: %d", + // incomingLayer, + // dd.FrameNumber, + // extFrameNum, + // extPkt.Packet.SequenceNumber, + // ), "err", err) + return + } + switch sd { + case selectorDecisionDropped: + // a packet of an alreadty dropped frame, maintain decision + // d.logger.Debugw(fmt.Sprintf("drop packet already dropped, incoming %v, fn: %d/%d, sn: %d", + // incomingLayer, + // dd.FrameNumber, + // extFrameNum, + // extPkt.Packet.SequenceNumber, + // )) + return + } + + if ddwdt.StructureUpdated { + // d.logger.Debugw("update dependency structure", + // "structureID", dd.AttachedStructure.StructureId, + // "structure", dd.AttachedStructure, + // "decodeTargets", ddwdt.DecodeTargets, + // "efn", extFrameNum, + // "sn", extPkt.Packet.SequenceNumber, + // "isKeyFrame", extPkt.KeyFrame, + // "currentKeyframe", d.extKeyFrameNum, + // ) + + d.updateDependencyStructure(dd.AttachedStructure, ddwdt.DecodeTargets, extFrameNum) + } + + if ddwdt.ExtKeyFrameNum != d.extKeyFrameNum { + // keyframe mismatch, drop and reset chains + d.logger.Debugw("drop packet for keyframe mismatch", "incoming", incomingLayer, "efn", extFrameNum, "sn", extPkt.Packet.SequenceNumber, "requiredKeyFrame", ddwdt.ExtKeyFrameNum, "structureKeyFrame", d.extKeyFrameNum) + d.decisions.AddDropped(extFrameNum) + d.invalidateKeyFrame() + return + } + + if ddwdt.ActiveDecodeTargetsUpdated { + d.updateActiveDecodeTargets(*dd.ActiveDecodeTargetsBitmask) + } + + if len(fd.ChainDiffs) != len(d.chains) { + d.logger.Debugw("frame chain diff length mismatch", nil, + "incoming", incomingLayer, + "efn", extFrameNum, + "sn", extPkt.Packet.SequenceNumber, + "chainDiffs", fd.ChainDiffs, + "chains", len(d.chains), + "requiredKeyFrame", ddwdt.ExtKeyFrameNum, + "structureKeyFrame", d.extKeyFrameNum, + ) + d.decisions.AddDropped(extFrameNum) + return + } + + for _, chain := range d.chains { + chain.OnFrame(extFrameNum, fd) + } + + // find decode target closest to targetLayer + highestDecodeTarget := buffer.DependencyDescriptorDecodeTarget{ + Target: -1, + Layer: buffer.InvalidLayer, + } + var dti dede.DecodeTargetIndication + d.decodeTargetsLock.RLock() + + // decodeTargets be sorted from high to low, find the highest decode target that is active and integrity + for _, dt := range d.decodeTargets { + if !dt.Active() || dt.Layer.Spatial > d.targetLayer.Spatial || dt.Layer.Temporal > d.targetLayer.Temporal { + continue + } + + frameResult, err := dt.OnFrame(extFrameNum, fd) + if err != nil { + d.decodeTargetsLock.RUnlock() + // dtis error, dependency descriptor might lost + d.logger.Warnw(fmt.Sprintf("drop packet for frame detection error, incoming: %v", incomingLayer), err) + d.decisions.AddDropped(extFrameNum) + return + } + + if frameResult.TargetValid { + highestDecodeTarget = dt.DependencyDescriptorDecodeTarget + dti = frameResult.DTI + break + } + } + d.decodeTargetsLock.RUnlock() + + if highestDecodeTarget.Target < 0 { + // no active decode target, do not select + // d.logger.Debugw( + // "drop packet for no target found", + // "highestDecodeTarget", highestDecodeTarget, + // "decodeTargets", d.decodeTargets, + // "tagetLayer", d.targetLayer, + // "incoming", incomingLayer, + // "fn", dd.FrameNumber, + // "efn", extFrameNum, + // "sn", extPkt.Packet.SequenceNumber, + // "isKeyFrame", extPkt.KeyFrame, + // ) + d.decisions.AddDropped(extFrameNum) + return + } + + // DD-TODO : if bandwidth in congest, could drop the 'Discardable' frame + if dti == dede.DecodeTargetNotPresent { + // d.logger.Debugw( + // "drop packet for decode target not present", + // "highestDecodeTarget", highestDecodeTarget, + // "decodeTargets", d.decodeTargets, + // "tagetLayer", d.targetLayer, + // "incoming", incomingLayer, + // "fn", dd.FrameNumber, + // "efn", extFrameNum, + // "sn", extPkt.Packet.SequenceNumber, + // "isKeyFrame", extPkt.KeyFrame, + // ) + d.decisions.AddDropped(extFrameNum) + return + } + + // check decodability using reference frames + isDecodable := true + for _, fdiff := range fd.FrameDiffs { + if fdiff == 0 { + continue + } + + // use relaxed check for frame diff that we have chain intact detection and don't want + // to drop packet due to out-of-order packet or recoverable packet loss + if sd, _ := d.decisions.GetDecision(extFrameNum - uint64(fdiff)); sd == selectorDecisionDropped { + isDecodable = false + break + } + } + if !isDecodable { + // d.logger.Debugw( + // "drop packet for not decodable", + // "highestDecodeTarget", highestDecodeTarget, + // "decodeTargets", d.decodeTargets, + // "tagetLayer", d.targetLayer, + // "incoming", incomingLayer, + // "fn", dd.FrameNumber, + // "efn", extFrameNum, + // "sn", extPkt.Packet.SequenceNumber, + // "isKeyFrame", extPkt.KeyFrame, + // ) + d.decisions.AddDropped(extFrameNum) + return + } + + if d.currentLayer != highestDecodeTarget.Layer { + result.IsSwitching = true + if !d.currentLayer.IsValid() { + result.IsResuming = true + d.logger.Debugw( + "resuming at layer", + "current", incomingLayer, + "target", d.targetLayer, + "max", d.maxLayer, + "layer", fd.SpatialId, + "req", d.requestSpatial, + "maxSeen", d.maxSeenLayer, + "feed", extPkt.Packet.SSRC, + "fn", dd.FrameNumber, + "efn", extFrameNum, + "sn", extPkt.Packet.SequenceNumber, + "isKeyFrame", extPkt.IsKeyFrame, + ) + } + + d.previousLayer = d.currentLayer + d.currentLayer = highestDecodeTarget.Layer + + d.previousActiveDecodeTargetsBitmask = d.activeDecodeTargetsBitmask + d.activeDecodeTargetsBitmask = buffer.GetActiveDecodeTargetBitmask(d.currentLayer, ddwdt.DecodeTargets) + d.logger.Debugw( + "switch to target", + "highestDecodeTarget", highestDecodeTarget, + "previous", d.previousLayer, + "bitmask", *d.activeDecodeTargetsBitmask, + "fn", dd.FrameNumber, + "efn", extFrameNum, + "sn", extPkt.Packet.SequenceNumber, + "isKeyFrame", extPkt.IsKeyFrame, + ) + + result.IsRelevant = true + } + + ddExtension := &dede.DependencyDescriptorExtension{ + Descriptor: dd, + Structure: d.structure, + } + + unWrapFn := uint16(d.fnWrapper.UpdateAndGet(extFrameNum, ddwdt.StructureUpdated)) + var ddClone *dede.DependencyDescriptor + if unWrapFn != dd.FrameNumber { + clone := *dd + ddClone = &clone + ddClone.FrameNumber = unWrapFn + ddExtension.Descriptor = ddClone + } + + if dd.AttachedStructure == nil { + if d.activeDecodeTargetsBitmask != nil { + if ddClone == nil { + // clone and override activebitmask + // DD-TODO: if the packet that contains the bitmask is acknowledged by RR, then we don't need it until it changed. + clone := *dd + ddClone = &clone + ddExtension.Descriptor = ddClone + } + ddClone.ActiveDecodeTargetsBitmask = d.activeDecodeTargetsBitmask + // d.logger.Debugw("set active decode targets bitmask", "activeDecodeTargetsBitmask", d.activeDecodeTargetsBitmask) + } + } + + var ddMarshaled bool + func() { + defer func() { + if r := recover(); r != nil { + d.logger.Errorw("panic marshalling dependency descriptor extension", nil, + "efn", extFrameNum, + "sn", extPkt.Packet.SequenceNumber, + "keyframeRequired", ddwdt.ExtKeyFrameNum, + "currentKeyframe", d.extKeyFrameNum, + "panic", r, + "stack", string(debug.Stack())) + } + }() + bytes, err := ddExtension.Marshal() + if err != nil { + d.logger.Warnw("error marshalling dependency descriptor extension", err) + } else { + result.DependencyDescriptorExtension = bytes + ddMarshaled = true + } + }() + + if !ddMarshaled { + // drop packet if we can't marshal dependency descriptor + d.decisions.AddDropped(extFrameNum) + return + } + + if ddwdt.Integrity { + d.decisions.AddForwarded(extFrameNum) + } + result.RTPMarker = extPkt.Packet.Header.Marker || (dd.LastPacketInFrame && d.currentLayer.Spatial == int32(fd.SpatialId)) + result.IsSelected = true + return +} + +func (d *DependencyDescriptor) Rollback() { + d.activeDecodeTargetsBitmask = d.previousActiveDecodeTargetsBitmask + + d.Base.Rollback() +} + +func (d *DependencyDescriptor) updateDependencyStructure(structure *dede.FrameDependencyStructure, decodeTargets []buffer.DependencyDescriptorDecodeTarget, extFrameNum uint64) { + d.structure = structure + d.extKeyFrameNum = extFrameNum + d.keyFrameValid = true + + d.chains = d.chains[:0] + + for chainIdx := 0; chainIdx < structure.NumChains; chainIdx++ { + d.chains = append(d.chains, NewFrameChain(d.decisions, chainIdx, d.logger)) + } + + newTargets := make([]*DecodeTarget, 0, len(decodeTargets)) + for _, dt := range decodeTargets { + var chain *FrameChain + // When chain_cnt > 0, each Decode target MUST be protected by exactly one Chain. + if structure.NumChains > 0 { + chainIdx := structure.DecodeTargetProtectedByChain[dt.Target] + if chainIdx >= len(d.chains) { + // should not happen + d.logger.Errorw("DecodeTargetProtectedByChain chainIdx out of range", nil, "chainIdx", chainIdx, "NumChains", len(d.chains)) + } else { + chain = d.chains[chainIdx] + } + } + newTargets = append(newTargets, NewDecodeTarget(dt, chain)) + } + d.decodeTargetsLock.Lock() + d.decodeTargets = newTargets + d.decodeTargetsLock.Unlock() +} + +func (d *DependencyDescriptor) updateActiveDecodeTargets(activeDecodeTargetsBitmask uint32) { + for _, chain := range d.chains { + chain.BeginUpdateActive() + } + + d.decodeTargetsLock.RLock() + for _, dt := range d.decodeTargets { + dt.UpdateActive(activeDecodeTargetsBitmask) + } + d.decodeTargetsLock.RUnlock() + + for _, chain := range d.chains { + chain.EndUpdateActive() + } +} + +func (d *DependencyDescriptor) invalidateKeyFrame() { + d.keyFrameValid = false + d.chains = d.chains[:0] + d.decodeTargetsLock.Lock() + d.decodeTargets = d.decodeTargets[:0] + d.decodeTargetsLock.Unlock() +} + +func (d *DependencyDescriptor) CheckSync() (locked bool, layer int32) { + layer = d.GetRequestSpatial() + if !d.currentLayer.IsValid() || !d.keyFrameValid { + // always declare not locked when trying to resume from nothing + return false, layer + } + + d.decodeTargetsLock.RLock() + defer d.decodeTargetsLock.RUnlock() + for _, dt := range d.decodeTargets { + if dt.Active() && dt.Layer.Spatial == layer && dt.Valid() { + return true, layer + } + } + + return false, layer +} + +func (d *DependencyDescriptor) restart(generation int) { + d.restartGeneration = generation + d.invalidateKeyFrame() + d.decisions = NewSelectorDecisionCache(decisionCacheMaxElements, decisionCacheNackEntries) +} diff --git a/livekit/pkg/sfu/videolayerselector/dependencydescriptor_test.go b/livekit/pkg/sfu/videolayerselector/dependencydescriptor_test.go new file mode 100644 index 0000000..91d78d8 --- /dev/null +++ b/livekit/pkg/sfu/videolayerselector/dependencydescriptor_test.go @@ -0,0 +1,405 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package videolayerselector + +import ( + "sort" + "testing" + + "github.com/pion/rtp" + "github.com/stretchr/testify/require" + + "github.com/livekit/livekit-server/pkg/sfu/buffer" + dd "github.com/livekit/livekit-server/pkg/sfu/rtpextension/dependencydescriptor" + "github.com/livekit/protocol/logger" +) + +func TestDecodeTarget(t *testing.T) { + target := buffer.DependencyDescriptorDecodeTarget{ + Target: 1, + Layer: buffer.VideoLayer{Spatial: 1, Temporal: 2}, + } + + t.Run("No Chain", func(t *testing.T) { + dt := NewDecodeTarget(target, nil) + require.True(t, dt.Valid()) + // no indication found + _, err := dt.OnFrame(1, &dd.FrameDependencyTemplate{ + DecodeTargetIndications: []dd.DecodeTargetIndication{}, + }) + require.Error(t, err) + + ret, err := dt.OnFrame(1, &dd.FrameDependencyTemplate{ + DecodeTargetIndications: []dd.DecodeTargetIndication{dd.DecodeTargetNotPresent, dd.DecodeTargetRequired}, + }) + require.NoError(t, err) + require.True(t, ret.TargetValid) + require.Equal(t, dd.DecodeTargetRequired, ret.DTI) + }) + + t.Run("With Chain", func(t *testing.T) { + decisions := NewSelectorDecisionCache(256, 80) + chain := NewFrameChain(decisions, 1, logger.GetLogger()) + dt := NewDecodeTarget(target, chain) + chain.BeginUpdateActive() + dt.UpdateActive(1 << dt.Target) + chain.EndUpdateActive() + require.True(t, dt.Active()) + require.False(t, dt.Valid()) + + // chain intact + frame := &dd.FrameDependencyTemplate{ + DecodeTargetIndications: []dd.DecodeTargetIndication{dd.DecodeTargetNotPresent, dd.DecodeTargetRequired}, + ChainDiffs: []int{0, 0}, + } + chain.OnFrame(1, frame) + require.True(t, dt.Valid()) + ret, err := dt.OnFrame(1, frame) + require.NoError(t, err) + require.True(t, ret.TargetValid) + require.Equal(t, dd.DecodeTargetRequired, ret.DTI) + + }) +} + +func TestFrameChain(t *testing.T) { + decisions := NewSelectorDecisionCache(256, 3) + chain := NewFrameChain(decisions, 0, logger.GetLogger()) + require.True(t, chain.Broken()) + + // chain intact + frameNoDiff := &dd.FrameDependencyTemplate{ + ChainDiffs: []int{0}, + } + // not active + require.False(t, chain.OnFrame(1, frameNoDiff)) + + chain.BeginUpdateActive() + chain.UpdateActive(true) + chain.EndUpdateActive() + + require.True(t, chain.OnFrame(1, frameNoDiff)) + decisions.AddForwarded(1) + + frameDiff1 := &dd.FrameDependencyTemplate{ + ChainDiffs: []int{1}, + } + + require.True(t, chain.OnFrame(2, frameDiff1)) + decisions.AddForwarded(2) + + // frame 5 arrives first , but frame 4 can be recovered by NACK + require.True(t, chain.OnFrame(5, frameDiff1)) + decisions.AddForwarded(5) + + // frame 4 arrives, chain remains intact + require.True(t, chain.OnFrame(4, frameDiff1)) + decisions.AddForwarded(4) + + // frame 3 missed by out of nack range, chain broken + decisions.AddForwarded(7) + require.True(t, chain.Broken()) + + // recovery by non-diff frame + require.True(t, chain.OnFrame(1000, frameNoDiff)) + require.False(t, chain.Broken()) + decisions.AddForwarded(1000) + + // broken by dropped frame + require.True(t, chain.OnFrame(1002, frameDiff1)) + decisions.AddDropped(1001) + require.True(t, chain.Broken()) + + // recovery by non-diff frame + require.True(t, chain.OnFrame(2000, frameNoDiff)) + decisions.AddForwarded(2000) + decisions.AddDropped(2001) + require.False(t, chain.OnFrame(2002, frameDiff1)) + require.True(t, chain.Broken()) +} + +func TestDependencyDescriptor(t *testing.T) { + ddSelector := NewDependencyDescriptor(logger.GetLogger()) + targetLayer := buffer.VideoLayer{Spatial: 1, Temporal: 2} + ddSelector.SetTarget(targetLayer) + ddSelector.SetRequestSpatial(1) + + // no dd ext, dropped + ret := ddSelector.Select(&buffer.ExtPacket{Packet: &rtp.Packet{}}, 0) + require.False(t, ret.IsSelected) + require.False(t, ret.IsRelevant) + + // non key frame, dropped + ret = ddSelector.Select(&buffer.ExtPacket{ + IsKeyFrame: false, + DependencyDescriptor: &buffer.ExtDependencyDescriptor{ + Descriptor: &dd.DependencyDescriptor{ + FrameNumber: 1, + FrameDependencies: &dd.FrameDependencyTemplate{ + SpatialId: int(targetLayer.Spatial), + TemporalId: int(targetLayer.Temporal), + }, + }, + }, + Packet: &rtp.Packet{}, + }, 0) + require.False(t, ret.IsSelected) + require.False(t, ret.IsRelevant) + + frames := createDDFrames(buffer.VideoLayer{Spatial: 2, Temporal: 2}, 3) + // key frame, update structure and decode targets + ret = ddSelector.Select(frames[0], 0) + require.True(t, ret.IsSelected) + require.Equal(t, ddSelector.GetCurrent(), ddSelector.GetTarget()) + sync, _ := ddSelector.CheckSync() + require.True(t, sync) + + // forward frame belongs to target layer + // drop frame exceeds target layer (not present in target layer or lower layer) + // forward frame not present in target layer but present in lower layer + var ( + belongTargetCase bool + exceedTargetCase bool + lowerTargetCase bool + ) + idx := 1 + var frameForwarded, frameDropped []*buffer.ExtPacket + for ; idx < len(frames); idx++ { + fd := frames[idx].DependencyDescriptor.Descriptor.FrameDependencies + ret = ddSelector.Select(frames[idx], 0) + switch { + case fd.SpatialId == int(targetLayer.Spatial) && fd.TemporalId == int(targetLayer.Temporal): + require.True(t, ret.IsSelected) + belongTargetCase = true + frameForwarded = append(frameForwarded, frames[idx]) + case fd.SpatialId < int(targetLayer.Spatial) && fd.TemporalId == 0: + require.True(t, ret.IsSelected) + lowerTargetCase = true + frameForwarded = append(frameForwarded, frames[idx]) + case fd.SpatialId > int(targetLayer.Spatial) || fd.TemporalId > int(targetLayer.Temporal): + require.False(t, ret.IsSelected) + exceedTargetCase = true + frameDropped = append(frameDropped, frames[idx]) + } + + if belongTargetCase && exceedTargetCase && lowerTargetCase { + break + } + } + + require.True(t, belongTargetCase && exceedTargetCase && lowerTargetCase) + + // select frame already forwarded + ret = ddSelector.Select(frameForwarded[0], 0) + require.True(t, ret.IsSelected) + + // drop frame already dropped + ret = ddSelector.Select(frameDropped[0], 0) + require.False(t, ret.IsSelected) + + // drop frame present but not decodable (dependency frame missed) + idx++ + for ; idx < len(frames); idx++ { + fd := frames[idx].DependencyDescriptor.Descriptor.FrameDependencies + ret = ddSelector.Select(frames[idx], 0) + if fd.SpatialId == int(targetLayer.Spatial) && fd.TemporalId == int(targetLayer.Temporal) { + break + } + } + notDecodableFrame := frames[idx] + notDecodableFrame.DependencyDescriptor.Descriptor.FrameDependencies.FrameDiffs = []int{ + int(notDecodableFrame.DependencyDescriptor.Descriptor.FrameNumber - frameDropped[0].DependencyDescriptor.Descriptor.FrameNumber), + } + ret = ddSelector.Select(notDecodableFrame, 0) + require.False(t, ret.IsSelected) + + // target layer broken + idx++ + for ; idx < len(frames); idx++ { + fd := frames[idx].DependencyDescriptor.Descriptor.FrameDependencies + ret = ddSelector.Select(frames[idx], 0) + if fd.SpatialId == int(targetLayer.Spatial) && fd.TemporalId == int(targetLayer.Temporal) { + break + } + } + brokenFrame := frames[idx] + brokenFrame.DependencyDescriptor.Descriptor.FrameDependencies.ChainDiffs[targetLayer.Spatial] = + int(notDecodableFrame.DependencyDescriptor.Descriptor.FrameNumber - frameDropped[0].DependencyDescriptor.Descriptor.FrameNumber) + ret = ddSelector.Select(brokenFrame, 0) + require.False(t, ret.IsSelected) + + // switch to lower layer, forward frame + idx++ + var switchToLower bool + for ; idx < len(frames); idx++ { + ret = ddSelector.Select(frames[idx], 0) + if ret.IsSelected { + require.True(t, targetLayer.GreaterThan(ddSelector.GetCurrent())) + switchToLower = true + break + } + } + require.True(t, switchToLower) + + // not sync with requested layer + ddSelector.SetRequestSpatial(targetLayer.Spatial) + locked, layer := ddSelector.CheckSync() + require.False(t, locked) + require.Equal(t, targetLayer.Spatial, layer) + // request to current layer, sync + ddSelector.SetRequestSpatial(ddSelector.GetCurrent().Spatial) + locked, _ = ddSelector.CheckSync() + require.True(t, locked) + + // should drop frame that relies on a keyframe is not present in current selection + framesPrevious := createDDFrames(buffer.VideoLayer{Spatial: 2, Temporal: 2}, 1000) + ret = ddSelector.Select(framesPrevious[1], 0) + require.False(t, ret.IsSelected) + // keyframe lost, out of sync + locked, _ = ddSelector.CheckSync() + require.False(t, locked) +} + +func createDDFrames(maxLayer buffer.VideoLayer, startFrameNumber uint16) []*buffer.ExtPacket { + var frames []*buffer.ExtPacket + var activeBitMask uint32 + var decodeTargets []buffer.DependencyDescriptorDecodeTarget + var decodeTargetsProtectByChain []int + for i := 0; i <= int(maxLayer.Spatial); i++ { + for j := 0; j <= int(maxLayer.Temporal); j++ { + decodeTargets = append(decodeTargets, buffer.DependencyDescriptorDecodeTarget{ + Target: i*int(maxLayer.Temporal+1) + j, + Layer: buffer.VideoLayer{Spatial: int32(i), Temporal: int32(j)}, + }) + decodeTargetsProtectByChain = append(decodeTargetsProtectByChain, i) + activeBitMask |= 1 << uint(i*int(maxLayer.Temporal+1)+j) + } + } + sort.Slice(decodeTargets, func(i, j int) bool { + return decodeTargets[i].Layer.GreaterThan(decodeTargets[j].Layer) + }) + + chainDiffs := make([]int, int(maxLayer.Spatial)+1) + dtis := make([]dd.DecodeTargetIndication, len(decodeTargets)) + for _, dt := range decodeTargets { + dtis[dt.Target] = dd.DecodeTargetSwitch + } + + templates := make([]*dd.FrameDependencyTemplate, len(decodeTargets)) + + for _, dt := range decodeTargets { + templates[dt.Target] = &dd.FrameDependencyTemplate{ + SpatialId: int(dt.Layer.Spatial), + TemporalId: int(dt.Layer.Temporal), + ChainDiffs: chainDiffs, + DecodeTargetIndications: dtis, + } + } + keyFrame := &buffer.ExtPacket{ + IsKeyFrame: true, + DependencyDescriptor: &buffer.ExtDependencyDescriptor{ + Descriptor: &dd.DependencyDescriptor{ + FrameNumber: startFrameNumber, + FrameDependencies: &dd.FrameDependencyTemplate{ + SpatialId: 0, + TemporalId: 0, + ChainDiffs: chainDiffs, + DecodeTargetIndications: dtis, + }, + AttachedStructure: &dd.FrameDependencyStructure{ + NumDecodeTargets: int((maxLayer.Spatial + 1) * (maxLayer.Temporal + 1)), + NumChains: int(maxLayer.Spatial) + 1, + DecodeTargetProtectedByChain: decodeTargetsProtectByChain, + Templates: templates, + }, + ActiveDecodeTargetsBitmask: &activeBitMask, + }, + DecodeTargets: decodeTargets, + StructureUpdated: true, + ActiveDecodeTargetsUpdated: true, + Integrity: true, + ExtFrameNum: uint64(startFrameNumber), + ExtKeyFrameNum: uint64(startFrameNumber), + }, + Packet: &rtp.Packet{ + Header: rtp.Header{ + SSRC: 1234, + }, + }, + } + + frames = append(frames, keyFrame) + + chainPrevFrame := make(map[int]int) + for i := 0; i <= int(maxLayer.Spatial); i++ { + chainPrevFrame[i] = int(startFrameNumber) + } + startFrameNumber++ + for i := 0; i < 10; i++ { + for j := len(decodeTargets) - 1; j >= 0; j-- { + dt := decodeTargets[j] + frameChainDiffs := make([]int, len(chainDiffs)) + for i := range frameChainDiffs { + frameChainDiffs[i] = int(startFrameNumber) - chainPrevFrame[i] + } + + frameDtis := make([]dd.DecodeTargetIndication, len(dtis)) + for k := range frameDtis { + if k >= dt.Target { + if dt.Layer.Temporal == 0 { + frameDtis[k] = dd.DecodeTargetRequired + } else { + frameDtis[k] = dd.DecodeTargetDiscardable + } + } else { + frameDtis[k] = dd.DecodeTargetNotPresent + } + } + + frame := &buffer.ExtPacket{ + DependencyDescriptor: &buffer.ExtDependencyDescriptor{ + Descriptor: &dd.DependencyDescriptor{ + FrameNumber: startFrameNumber, + FrameDependencies: &dd.FrameDependencyTemplate{ + SpatialId: int(dt.Layer.Spatial), + TemporalId: int(dt.Layer.Temporal), + ChainDiffs: frameChainDiffs, + DecodeTargetIndications: frameDtis, + }, + }, + DecodeTargets: decodeTargets, + Integrity: true, + ExtFrameNum: uint64(startFrameNumber), + ExtKeyFrameNum: keyFrame.DependencyDescriptor.ExtFrameNum, + }, + Packet: &rtp.Packet{ + Header: rtp.Header{ + SSRC: 1234, + }, + }, + } + + startFrameNumber++ + + if dt.Layer.Temporal == 0 { + chainPrevFrame[int(dt.Layer.Spatial)] = int(startFrameNumber) + } + + frames = append(frames, frame) + } + } + + return frames +} diff --git a/livekit/pkg/sfu/videolayerselector/framechain.go b/livekit/pkg/sfu/videolayerselector/framechain.go new file mode 100644 index 0000000..19ddf7e --- /dev/null +++ b/livekit/pkg/sfu/videolayerselector/framechain.go @@ -0,0 +1,139 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package videolayerselector + +import ( + dd "github.com/livekit/livekit-server/pkg/sfu/rtpextension/dependencydescriptor" + "github.com/livekit/protocol/logger" +) + +type FrameChain struct { + logger logger.Logger + decisions *SelectorDecisionCache + broken bool + chainIdx int + active bool + updatingActive bool + + expectFrames []uint64 +} + +func NewFrameChain(decisions *SelectorDecisionCache, chainIdx int, logger logger.Logger) *FrameChain { + return &FrameChain{ + logger: logger, + decisions: decisions, + broken: true, + chainIdx: chainIdx, + active: false, + } +} + +func (fc *FrameChain) OnFrame(extFrameNum uint64, fd *dd.FrameDependencyTemplate) bool { + if !fc.active { + return false + } + + if len(fd.ChainDiffs) <= fc.chainIdx { + fc.logger.Warnw("invalid frame chain diff", nil, "chanIdx", fc.chainIdx, "frame", extFrameNum, "fd", fd) + return fc.broken + } + + // A decodable frame with frame_chain_fdiff equal to 0 indicates that the Chain is intact. + if fd.ChainDiffs[fc.chainIdx] == 0 { + if fc.broken { + fc.broken = false + // fc.logger.Debugw("frame chain intact", "chanIdx", fc.chainIdx, "frame", extFrameNum) + } + fc.expectFrames = fc.expectFrames[:0] + return true + } + + if fc.broken { + return false + } + + prevFrameInChain := extFrameNum - uint64(fd.ChainDiffs[fc.chainIdx]) + sd, err := fc.decisions.GetDecision(prevFrameInChain) + if err != nil { + fc.logger.Debugw("could not get decision", "err", err, "chanIdx", fc.chainIdx, "frame", extFrameNum, "prevFrame", prevFrameInChain) + } + + var intact bool + switch { + case sd == selectorDecisionForwarded: + intact = true + + case sd == selectorDecisionUnknown: + // If the previous frame is unknown, means it has not arrived but could be recovered by NACK / out-of-order arrival, + // set up a expected callback here to determine if the chain is broken or intact + if fc.decisions.ExpectDecision(prevFrameInChain, fc.OnExpectFrameChanged) { + intact = true + fc.expectFrames = append(fc.expectFrames, prevFrameInChain) + } + } + + if !intact { + fc.broken = true + // fc.logger.Debugw("frame chain broken", "chanIdx", fc.chainIdx, "sd", sd, "frame", extFrameNum, "prevFrame", prevFrameInChain) + } + return intact +} + +func (fc *FrameChain) OnExpectFrameChanged(frameNum uint64, decision selectorDecision) { + if fc.broken { + return + } + + for i, f := range fc.expectFrames { + if f == frameNum { + if decision != selectorDecisionForwarded { + fc.broken = true + // fc.logger.Debugw("frame chain broken", "chanIdx", fc.chainIdx, "sd", decision, "frame", frameNum) + } + fc.expectFrames[i] = fc.expectFrames[len(fc.expectFrames)-1] + fc.expectFrames = fc.expectFrames[:len(fc.expectFrames)-1] + break + } + } +} + +func (fc *FrameChain) Broken() bool { + return fc.broken +} + +func (fc *FrameChain) BeginUpdateActive() { + fc.updatingActive = false +} + +func (fc *FrameChain) UpdateActive(active bool) { + fc.updatingActive = fc.updatingActive || active +} + +func (fc *FrameChain) EndUpdateActive() { + active := fc.updatingActive + fc.updatingActive = false + + if active == fc.active { + return + } + + // if the chain transit from inactive to active, reset broken to wait a decodable SWITCH frame + if !fc.active { + fc.broken = true + fc.logger.Debugw("frame chain broken by inactive", "chanIdx", fc.chainIdx) + } + + fc.active = active +} diff --git a/livekit/pkg/sfu/videolayerselector/framenumberwrapper.go b/livekit/pkg/sfu/videolayerselector/framenumberwrapper.go new file mode 100644 index 0000000..43c9c02 --- /dev/null +++ b/livekit/pkg/sfu/videolayerselector/framenumberwrapper.go @@ -0,0 +1,58 @@ +// Copyright 2024 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package videolayerselector + +import "github.com/livekit/protocol/logger" + +type FrameNumberWrapper struct { + offset uint64 + last uint64 + inited bool + logger logger.Logger +} + +// UpdateAndGet returns the wrapped frame number from the given frame number, and updates the offset to +// make sure the returned frame number is always inorder. Should only updateOffset if the new frame is a keyframe +// because frame dependencies uses on the frame number diff so frames inside a GOP should have the same offset. +func (f *FrameNumberWrapper) UpdateAndGet(new uint64, updateOffset bool) uint64 { + if !f.inited { + f.last = new + f.inited = true + return new + } + + if new <= f.last { + return new + f.offset + } + + if updateOffset { + new16 := uint16(new + f.offset) + last16 := uint16(f.last + f.offset) + // if new frame number wraps around and is considered as earlier by client, increase offset to make it later + if diff := new16 - last16; diff > 0x8000 || (diff == 0x8000 && new16 < last16) { + // increase offset by 6000, nearly 10 seconds for 30fps video with 3 spatial layers + prevOffset := f.offset + f.offset += uint64(65535 - diff + 6000) + + f.logger.Debugw("wrap around frame number seen, update offset", "new", new, "last", f.last, "offset", f.offset, "prevOffset", prevOffset, "lastWrapFn", last16, "newWrapFn", new16) + } + } + f.last = new + return new + f.offset +} + +func (f *FrameNumberWrapper) LastOrigin() uint64 { + return f.last +} diff --git a/livekit/pkg/sfu/videolayerselector/framenumberwrapper_test.go b/livekit/pkg/sfu/videolayerselector/framenumberwrapper_test.go new file mode 100644 index 0000000..4d0f6f2 --- /dev/null +++ b/livekit/pkg/sfu/videolayerselector/framenumberwrapper_test.go @@ -0,0 +1,108 @@ +// Copyright 2024 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package videolayerselector + +import ( + "testing" + + "math/rand" + + "github.com/stretchr/testify/require" + + "github.com/livekit/mediatransportutil/pkg/utils" + "github.com/livekit/protocol/logger" +) + +func TestFrameNumberWrapper(t *testing.T) { + + logger.InitFromConfig(&logger.Config{Level: "debug"}, t.Name()) + + fnWrap := &FrameNumberWrapper{logger: logger.GetLogger()} + + fnWrapAround := utils.NewWrapAround[uint16, uint64](utils.WrapAroundParams{IsRestartAllowed: false}) + + firstF := uint16(1000) + + testFrameOrder := func(frame uint16, isKeyFrame bool, frame2 uint16, isKeyFrame2, expectInorder bool) { + frameUnwrap := fnWrapAround.Update(frame).ExtendedVal + wrappedFrame := uint16(fnWrap.UpdateAndGet(frameUnwrap, isKeyFrame)) + + // make sure wrap around always get in order frame number + fnWrapAround.Update(frame + (frame2-frame)/2) + + frame2Unwrap := fnWrapAround.Update(frame2).ExtendedVal + wrappedFrame2 := uint16(fnWrap.UpdateAndGet(frame2Unwrap, isKeyFrame2)) + // keeps order + require.Equal(t, expectInorder, inOrder(wrappedFrame2, wrappedFrame), "frame %d, frame2 %d, wrappedFrame %d, wrapped Frame2 %d, frameUnwrap %d, frame2Unwrap %d", frame, frame2, wrappedFrame, wrappedFrame2, frameUnwrap, frame2Unwrap) + // frame number diff should be the same if frame2 is not a key frame + if !isKeyFrame2 { + require.Equal(t, frame2-frame, wrappedFrame2-wrappedFrame) + } + } + + secondF := getFrame(firstF, true) + testFrameOrder(firstF, true, secondF, false, true) + + // non key frame keeps diff and order + for i := 0; i < 100; i++ { + // frame in order + firstF = secondF + secondF = getFrame(firstF, true) + testFrameOrder(firstF, false, secondF, false, true) + + // frame out of order + firstF = secondF + secondF = getFrame(firstF, false) + // it is possible that an out of order non-keyframe has been converted to in order frame number if the diff is 32768 + // that is ok because the client can't decode in such case and always need to wait for the key frame. + // so it is just a failure of test case and increase the frame number here. + if secondF-firstF == 0x8000 { + secondF++ + } + testFrameOrder(firstF, false, secondF, false, false) + + // key frame in order + firstF = secondF + secondF = getFrame(firstF, true) + testFrameOrder(firstF, false, secondF, true, true) + + // frame in order + firstF = secondF + secondF = getFrame(firstF, true) + testFrameOrder(firstF, false, secondF, false, true) + + // key frame out of order but should be in order after wrap around + firstF = secondF + secondF = getFrame(firstF, false) + testFrameOrder(firstF, false, secondF, true, true) + } +} + +func inOrder(a, b uint16) bool { + return a-b < 0x8000 || (a-b == 0x8000 && a > b) +} + +func getFrame(base uint16, inorder bool) uint16 { + if inorder { + return base + uint16(rand.Intn(0x8000)) + } + + for { + ret := base + uint16(rand.Intn(0x8000)) + 0x8000 + if !inOrder(ret, base) { + return ret + } + } +} diff --git a/livekit/pkg/sfu/videolayerselector/null.go b/livekit/pkg/sfu/videolayerselector/null.go new file mode 100644 index 0000000..644a8f9 --- /dev/null +++ b/livekit/pkg/sfu/videolayerselector/null.go @@ -0,0 +1,29 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package videolayerselector + +import ( + "github.com/livekit/protocol/logger" +) + +type Null struct { + *Base +} + +func NewNull(logger logger.Logger) *Null { + return &Null{ + Base: NewBase(logger), + } +} diff --git a/livekit/pkg/sfu/videolayerselector/selectordecisioncache.go b/livekit/pkg/sfu/videolayerselector/selectordecisioncache.go new file mode 100644 index 0000000..ea60086 --- /dev/null +++ b/livekit/pkg/sfu/videolayerselector/selectordecisioncache.go @@ -0,0 +1,197 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package videolayerselector + +import ( + "fmt" +) + +// ---------------------------------------------------------------------- + +type selectorDecision int + +const ( + selectorDecisionMissing selectorDecision = iota + selectorDecisionDropped + selectorDecisionForwarded + selectorDecisionUnknown +) + +func (s selectorDecision) String() string { + switch s { + case selectorDecisionMissing: + return "MISSING" + case selectorDecisionDropped: + return "DROPPED" + case selectorDecisionForwarded: + return "FORWARDED" + case selectorDecisionUnknown: + return "UNKNOWN" + default: + return fmt.Sprintf("%d", int(s)) + } +} + +// ---------------------------------------------------------------------- + +type SelectorDecisionCache struct { + initialized bool + base uint64 + last uint64 + masks []uint64 + numEntries uint64 + numNackEntries uint64 + + onExpectEntityChanged map[uint64][]func(entity uint64, decision selectorDecision) +} + +func NewSelectorDecisionCache(maxNumElements uint64, numNackEntries uint64) *SelectorDecisionCache { + numElements := (maxNumElements*2 + 63) / 64 + return &SelectorDecisionCache{ + masks: make([]uint64, numElements), + numEntries: numElements * 32, // 2 bits per entry + numNackEntries: numNackEntries, + onExpectEntityChanged: make(map[uint64][]func(entity uint64, decision selectorDecision)), + } +} + +func (s *SelectorDecisionCache) AddForwarded(entity uint64) { + s.addEntity(entity, selectorDecisionForwarded) +} + +func (s *SelectorDecisionCache) AddDropped(entity uint64) { + s.addEntity(entity, selectorDecisionDropped) +} + +func (s *SelectorDecisionCache) GetDecision(entity uint64) (selectorDecision, error) { + if !s.initialized || entity < s.base { + return selectorDecisionMissing, nil + } + + if entity > s.last { + return selectorDecisionUnknown, nil + } + + offset := s.last - entity + if offset >= s.numEntries { + // asking for something too old + return selectorDecisionMissing, fmt.Errorf("too old, oldest: %d, asking: %d", s.last-s.numEntries+1, entity) + } + + return s.getEntity(entity), nil +} + +func (s *SelectorDecisionCache) ExpectDecision(entity uint64, f func(entity uint64, decision selectorDecision)) bool { + if !s.initialized || entity < s.base { + return false + } + + if entity < s.last { + offset := s.last - entity + if offset >= s.numEntries { + return false // too old + } + } + + s.onExpectEntityChanged[entity] = append(s.onExpectEntityChanged[entity], f) + return true +} + +func (s *SelectorDecisionCache) addEntity(entity uint64, sd selectorDecision) { + if !s.initialized { + s.initialized = true + s.base = entity + s.last = entity + s.setEntity(entity, sd) + return + } + + if entity <= s.base { + // before base, too old + return + } + + if entity <= s.last { + s.setEntity(entity, sd) + return + } + + for e := s.last + 1; e != entity; e++ { + s.setEntity(e, selectorDecisionUnknown) + } + + // update [last+1-nack, entity-nack) to missing + missingStart := s.last + if missingStart > s.numNackEntries+s.base { + missingStart -= s.numNackEntries + } else { + missingStart = s.base + } + missingEnd := entity + if missingEnd > s.numNackEntries+s.base { + missingEnd -= s.numNackEntries + } else { + missingEnd = s.base + } + if missingEnd > missingStart { + for e := missingStart; e != missingEnd; e++ { + s.setEntityIfUnknown(e, selectorDecisionMissing) + } + } + + s.setEntity(entity, sd) + s.last = entity + + for e, fns := range s.onExpectEntityChanged { + if e+s.numEntries < s.last { + delete(s.onExpectEntityChanged, e) + for _, f := range fns { + f(e, selectorDecisionMissing) + } + } + } +} + +func (s *SelectorDecisionCache) setEntityIfUnknown(entity uint64, sd selectorDecision) { + if s.getEntity(entity) == selectorDecisionUnknown { + s.setEntity(entity, sd) + } +} + +func (s *SelectorDecisionCache) setEntity(entity uint64, sd selectorDecision) { + index, bitpos := s.getPos(entity) + s.masks[index] &= ^(0x3 << bitpos) // clear before bitwise OR + s.masks[index] |= (uint64(sd) & 0x3) << bitpos + + if sd != selectorDecisionUnknown { + if fns, ok := s.onExpectEntityChanged[entity]; ok { + delete(s.onExpectEntityChanged, entity) + for _, f := range fns { + f(entity, sd) + } + } + } +} + +func (s *SelectorDecisionCache) getEntity(entity uint64) selectorDecision { + index, bitpos := s.getPos(entity) + return selectorDecision((s.masks[index] >> bitpos) & 0x3) +} + +func (s *SelectorDecisionCache) getPos(entity uint64) (int, int) { + // 2 bits per entity, a uint64 mask can hold 32 entities + offset := (entity - s.base) % s.numEntries + return int(offset >> 5), int(offset&0x1F) * 2 +} diff --git a/livekit/pkg/sfu/videolayerselector/simulcast.go b/livekit/pkg/sfu/videolayerselector/simulcast.go new file mode 100644 index 0000000..664b2a7 --- /dev/null +++ b/livekit/pkg/sfu/videolayerselector/simulcast.go @@ -0,0 +1,143 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package videolayerselector + +import ( + "github.com/livekit/livekit-server/pkg/sfu/buffer" + "github.com/livekit/protocol/logger" +) + +type Simulcast struct { + *Base +} + +func NewSimulcast(logger logger.Logger) *Simulcast { + return &Simulcast{ + Base: NewBase(logger), + } +} + +func NewSimulcastFromOther(vls VideoLayerSelector) *Simulcast { + switch vls := vls.(type) { + case *Null: + return &Simulcast{ + Base: vls.Base, + } + + case *Simulcast: + return &Simulcast{ + Base: vls.Base, + } + + case *DependencyDescriptor: + return &Simulcast{ + Base: vls.Base, + } + + case *VP9: + return &Simulcast{ + Base: vls.Base, + } + + default: + return nil + } +} + +func (s *Simulcast) IsOvershootOkay() bool { + return true +} + +func (s *Simulcast) Select(extPkt *buffer.ExtPacket, layer int32) (result VideoLayerSelectorResult) { + populateSwitches := func(isActive bool, reason string) { + result.IsSwitching = true + + if !isActive { + result.IsResuming = true + } + + if reason != "" { + s.logger.Debugw( + reason, + "previous", s.previousLayer, + "current", s.currentLayer, + "previousTarget", s.previousTargetLayer, + "target", s.targetLayer, + "max", s.maxLayer, + "layer", layer, + "req", s.requestSpatial, + "maxSeen", s.maxSeenLayer, + "feed", extPkt.Packet.SSRC, + ) + } + } + + if s.currentLayer.Spatial != s.targetLayer.Spatial { + currentLayer := s.currentLayer + + // Two things to check when not locked to target + // 1. Opportunistic layer upgrade - needs a key frame + // 2. Need to downgrade - needs a key frame + isActive := s.currentLayer.IsValid() + found := false + reason := "" + if extPkt.IsKeyFrame { + if layer > s.currentLayer.Spatial && layer <= s.targetLayer.Spatial { + reason = "upgrading layer" + found = true + } + + if layer < s.currentLayer.Spatial && layer >= s.targetLayer.Spatial { + reason = "downgrading layer" + found = true + } + + if found { + currentLayer.Spatial = layer + currentLayer.Temporal = extPkt.VideoLayer.Temporal + } + } + + if found { + s.previousLayer = s.currentLayer + s.currentLayer = currentLayer + + s.previousTargetLayer = s.targetLayer + if s.currentLayer.Spatial >= s.maxLayer.Spatial || s.currentLayer.Spatial == s.maxSeenLayer.Spatial { + s.targetLayer.Spatial = s.currentLayer.Spatial + } + + populateSwitches(isActive, reason) + } + } + + // if locked to higher than max layer due to overshoot, check if it can be dialed back + if s.currentLayer.Spatial > s.maxLayer.Spatial && layer <= s.maxLayer.Spatial && extPkt.IsKeyFrame { + s.previousLayer = s.currentLayer + s.currentLayer.Spatial = layer + + s.previousTargetLayer = s.targetLayer + if s.currentLayer.Spatial >= s.maxLayer.Spatial || s.currentLayer.Spatial == s.maxSeenLayer.Spatial { + s.targetLayer.Spatial = layer + } + + populateSwitches(true, "adjusting overshoot") + } + + result.RTPMarker = extPkt.Packet.Marker + result.IsSelected = layer == s.currentLayer.Spatial + result.IsRelevant = false + return +} diff --git a/livekit/pkg/sfu/videolayerselector/temporallayerselector/null.go b/livekit/pkg/sfu/videolayerselector/temporallayerselector/null.go new file mode 100644 index 0000000..51f2fb5 --- /dev/null +++ b/livekit/pkg/sfu/videolayerselector/temporallayerselector/null.go @@ -0,0 +1,31 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package temporallayerselector + +import ( + "github.com/livekit/livekit-server/pkg/sfu/buffer" +) + +type Null struct{} + +func NewNull() *Null { + return &Null{} +} + +func Select(_extPkt *buffer.ExtPacket, current int32, _target int32) (this int32, next int32) { + this = current + next = current + return +} diff --git a/livekit/pkg/sfu/videolayerselector/temporallayerselector/temporallayerselector.go b/livekit/pkg/sfu/videolayerselector/temporallayerselector/temporallayerselector.go new file mode 100644 index 0000000..1d691b4 --- /dev/null +++ b/livekit/pkg/sfu/videolayerselector/temporallayerselector/temporallayerselector.go @@ -0,0 +1,21 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package temporallayerselector + +import "github.com/livekit/livekit-server/pkg/sfu/buffer" + +type TemporalLayerSelector interface { + Select(extPkt *buffer.ExtPacket, current int32, target int32) (this int32, next int32) +} diff --git a/livekit/pkg/sfu/videolayerselector/temporallayerselector/vp8.go b/livekit/pkg/sfu/videolayerselector/temporallayerselector/vp8.go new file mode 100644 index 0000000..c0a86c4 --- /dev/null +++ b/livekit/pkg/sfu/videolayerselector/temporallayerselector/vp8.go @@ -0,0 +1,56 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package temporallayerselector + +import ( + "github.com/livekit/livekit-server/pkg/sfu/buffer" + "github.com/livekit/protocol/logger" +) + +type VP8 struct { + logger logger.Logger +} + +func NewVP8(logger logger.Logger) *VP8 { + return &VP8{ + logger: logger, + } +} + +func (v *VP8) Select(extPkt *buffer.ExtPacket, current int32, target int32) (this int32, next int32) { + this = current + next = current + if current == target { + return + } + + vp8, ok := extPkt.Payload.(buffer.VP8) + if !ok { + return + } + + tid := extPkt.Temporal + if current < target { + if tid > current && tid <= target && vp8.S { + this = tid + next = tid + } + } else { + if extPkt.Packet.Marker { + next = target + } + } + return +} diff --git a/livekit/pkg/sfu/videolayerselector/videolayerselector.go b/livekit/pkg/sfu/videolayerselector/videolayerselector.go new file mode 100644 index 0000000..883e46a --- /dev/null +++ b/livekit/pkg/sfu/videolayerselector/videolayerselector.go @@ -0,0 +1,65 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package videolayerselector + +import ( + "github.com/livekit/livekit-server/pkg/sfu/buffer" + "github.com/livekit/livekit-server/pkg/sfu/videolayerselector/temporallayerselector" + "github.com/livekit/protocol/logger" +) + +type VideoLayerSelectorResult struct { + IsSelected bool + IsRelevant bool + IsSwitching bool + IsResuming bool + RTPMarker bool + DependencyDescriptorExtension []byte +} + +type VideoLayerSelector interface { + getBase() *Base + + getLogger() logger.Logger + + IsOvershootOkay() bool + + SetTemporalLayerSelector(tls temporallayerselector.TemporalLayerSelector) + + SetMax(maxLayer buffer.VideoLayer) + SetMaxSpatial(layer int32) + SetMaxTemporal(layer int32) + GetMax() buffer.VideoLayer + + SetTarget(targetLayer buffer.VideoLayer) + GetTarget() buffer.VideoLayer + + SetRequestSpatial(layer int32) + GetRequestSpatial() int32 + + CheckSync() (locked bool, layer int32) + + SetMaxSeen(maxSeenLayer buffer.VideoLayer) + SetMaxSeenSpatial(layer int32) + SetMaxSeenTemporal(layer int32) + GetMaxSeen() buffer.VideoLayer + + SetCurrent(currentLayer buffer.VideoLayer) + GetCurrent() buffer.VideoLayer + + Select(extPkt *buffer.ExtPacket, layer int32) VideoLayerSelectorResult + SelectTemporal(extPkt *buffer.ExtPacket) int32 + Rollback() +} diff --git a/livekit/pkg/sfu/videolayerselector/vp9.go b/livekit/pkg/sfu/videolayerselector/vp9.go new file mode 100644 index 0000000..4b11d37 --- /dev/null +++ b/livekit/pkg/sfu/videolayerselector/vp9.go @@ -0,0 +1,108 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package videolayerselector + +import ( + "github.com/pion/rtp/codecs" + + "github.com/livekit/livekit-server/pkg/sfu/buffer" + "github.com/livekit/protocol/logger" +) + +type VP9 struct { + *Base +} + +func NewVP9(logger logger.Logger) *VP9 { + return &VP9{ + Base: NewBase(logger), + } +} + +func NewVP9FromOther(vls VideoLayerSelector) *VP9 { + return &VP9{Base: vls.getBase()} +} + +func (v *VP9) IsOvershootOkay() bool { + return false +} + +func (v *VP9) Select(extPkt *buffer.ExtPacket, _layer int32) (result VideoLayerSelectorResult) { + vp9, ok := extPkt.Payload.(codecs.VP9Packet) + if !ok { + return + } + + currentLayer := v.currentLayer + if v.currentLayer != v.targetLayer { + updatedLayer := v.currentLayer + + if !v.currentLayer.IsValid() { + if !extPkt.IsKeyFrame { + return + } + + updatedLayer = extPkt.VideoLayer + } else { + if v.currentLayer.Temporal != v.targetLayer.Temporal { + if v.currentLayer.Temporal < v.targetLayer.Temporal { + // temporal scale up + if extPkt.VideoLayer.Temporal > v.currentLayer.Temporal && extPkt.VideoLayer.Temporal <= v.targetLayer.Temporal && vp9.U && vp9.B { + currentLayer.Temporal = extPkt.VideoLayer.Temporal + updatedLayer.Temporal = extPkt.VideoLayer.Temporal + } + } else { + // temporal scale down + if vp9.E { + updatedLayer.Temporal = v.targetLayer.Temporal + } + } + } + + if v.currentLayer.Spatial != v.targetLayer.Spatial { + if v.currentLayer.Spatial < v.targetLayer.Spatial { + // spatial scale up + if extPkt.VideoLayer.Spatial > v.currentLayer.Spatial && extPkt.VideoLayer.Spatial <= v.targetLayer.Spatial && !vp9.P && vp9.B { + currentLayer.Spatial = extPkt.VideoLayer.Spatial + updatedLayer.Spatial = extPkt.VideoLayer.Spatial + } + } else { + // spatial scale down + if vp9.E { + updatedLayer.Spatial = v.targetLayer.Spatial + } + } + } + } + + if updatedLayer != v.currentLayer { + result.IsSwitching = true + if !v.currentLayer.IsValid() && updatedLayer.IsValid() { + result.IsResuming = true + } + + v.previousLayer = v.currentLayer + v.currentLayer = updatedLayer + } + } + + result.RTPMarker = extPkt.Packet.Marker + if vp9.E && extPkt.VideoLayer.Spatial == currentLayer.Spatial && (vp9.P || v.targetLayer.Spatial <= v.currentLayer.Spatial) { + result.RTPMarker = true + } + result.IsSelected = !extPkt.VideoLayer.GreaterThan(currentLayer) + result.IsRelevant = true + return +} diff --git a/livekit/pkg/telemetry/analyticsservice.go b/livekit/pkg/telemetry/analyticsservice.go new file mode 100644 index 0000000..5868ca8 --- /dev/null +++ b/livekit/pkg/telemetry/analyticsservice.go @@ -0,0 +1,118 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package telemetry + +import ( + "context" + + "go.uber.org/atomic" + "google.golang.org/protobuf/types/known/timestamppb" + + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/observability/roomobs" + "github.com/livekit/protocol/rpc" + "github.com/livekit/protocol/utils/guid" + + "github.com/livekit/livekit-server/pkg/config" + "github.com/livekit/livekit-server/pkg/routing" +) + +//counterfeiter:generate . AnalyticsService +type AnalyticsService interface { + SendStats(ctx context.Context, stats []*livekit.AnalyticsStat) + SendEvent(ctx context.Context, events *livekit.AnalyticsEvent) + SendNodeRoomStates(ctx context.Context, nodeRooms *livekit.AnalyticsNodeRooms) + RoomProjectReporter(ctx context.Context) roomobs.ProjectReporter +} + +// ---------------------------- + +var _ AnalyticsService = &NullAnalyticService{} + +type NullAnalyticService struct{} + +func (n NullAnalyticService) SendStats(_ context.Context, _ []*livekit.AnalyticsStat) {} +func (n NullAnalyticService) SendEvent(_ context.Context, _ *livekit.AnalyticsEvent) {} +func (n NullAnalyticService) SendNodeRoomStates(_ context.Context, _ *livekit.AnalyticsNodeRooms) {} +func (n NullAnalyticService) RoomProjectReporter(_ctx context.Context) roomobs.ProjectReporter { + return nil +} + +// ---------------------------- + +type analyticsService struct { + analyticsKey string + nodeID string + sequenceNumber atomic.Uint64 + + events rpc.AnalyticsRecorderService_IngestEventsClient + stats rpc.AnalyticsRecorderService_IngestStatsClient + nodeRooms rpc.AnalyticsRecorderService_IngestNodeRoomStatesClient +} + +func NewAnalyticsService(_ *config.Config, currentNode routing.LocalNode) AnalyticsService { + return &analyticsService{ + analyticsKey: "", // TODO: conf.AnalyticsKey + nodeID: string(currentNode.NodeID()), + } +} + +func (a *analyticsService) SendStats(_ context.Context, stats []*livekit.AnalyticsStat) { + if a.stats == nil { + return + } + + for _, stat := range stats { + stat.Id = guid.New("AS_") + stat.AnalyticsKey = a.analyticsKey + stat.Node = a.nodeID + } + if err := a.stats.Send(&livekit.AnalyticsStats{Stats: stats}); err != nil { + logger.Errorw("failed to send stats", err) + } +} + +func (a *analyticsService) SendEvent(_ context.Context, event *livekit.AnalyticsEvent) { + if a.events == nil { + return + } + + event.Id = guid.New("AE_") + event.NodeId = a.nodeID + event.AnalyticsKey = a.analyticsKey + if err := a.events.Send(&livekit.AnalyticsEvents{ + Events: []*livekit.AnalyticsEvent{event}, + }); err != nil { + logger.Errorw("failed to send event", err, "eventType", event.Type.String()) + } +} + +func (a *analyticsService) SendNodeRoomStates(_ context.Context, nodeRooms *livekit.AnalyticsNodeRooms) { + if a.nodeRooms == nil { + return + } + + nodeRooms.NodeId = a.nodeID + nodeRooms.SequenceNumber = a.sequenceNumber.Add(1) + nodeRooms.Timestamp = timestamppb.Now() + if err := a.nodeRooms.Send(nodeRooms); err != nil { + logger.Errorw("failed to send node room states", err) + } +} + +func (a *analyticsService) RoomProjectReporter(_ context.Context) roomobs.ProjectReporter { + return roomobs.NewNoopProjectReporter() +} diff --git a/livekit/pkg/telemetry/events.go b/livekit/pkg/telemetry/events.go new file mode 100644 index 0000000..a6c0064 --- /dev/null +++ b/livekit/pkg/telemetry/events.go @@ -0,0 +1,617 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package telemetry + +import ( + "context" + "time" + + "google.golang.org/protobuf/types/known/timestamppb" + + "github.com/livekit/livekit-server/pkg/sfu/mime" + "github.com/livekit/livekit-server/pkg/telemetry/prometheus" + "github.com/livekit/protocol/egress" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/utils/guid" + "github.com/livekit/protocol/webhook" +) + +func (t *telemetryService) NotifyEvent(ctx context.Context, event *livekit.WebhookEvent, opts ...webhook.NotifyOption) { + if t.notifier == nil { + return + } + + event.CreatedAt = time.Now().Unix() + event.Id = guid.New("EV_") + + if err := t.notifier.QueueNotify(ctx, event, opts...); err != nil { + logger.Warnw("failed to notify webhook", err, "event", event.Event) + } +} + +func (t *telemetryService) RoomStarted(ctx context.Context, room *livekit.Room) { + t.enqueue(func() { + t.NotifyEvent(ctx, &livekit.WebhookEvent{ + Event: webhook.EventRoomStarted, + Room: room, + }) + + t.SendEvent(ctx, &livekit.AnalyticsEvent{ + Type: livekit.AnalyticsEventType_ROOM_CREATED, + Timestamp: ×tamppb.Timestamp{Seconds: room.CreationTime}, + Room: room, + }) + }) +} + +func (t *telemetryService) RoomEnded(ctx context.Context, room *livekit.Room) { + t.enqueue(func() { + t.NotifyEvent(ctx, &livekit.WebhookEvent{ + Event: webhook.EventRoomFinished, + Room: room, + }) + + t.SendEvent(ctx, &livekit.AnalyticsEvent{ + Type: livekit.AnalyticsEventType_ROOM_ENDED, + Timestamp: timestamppb.Now(), + RoomId: room.Sid, + Room: room, + }) + }) +} + +func (t *telemetryService) ParticipantJoined( + ctx context.Context, + room *livekit.Room, + participant *livekit.ParticipantInfo, + clientInfo *livekit.ClientInfo, + clientMeta *livekit.AnalyticsClientMeta, + shouldSendEvent bool, + guard *ReferenceGuard, +) { + t.enqueue(func() { + _, found := t.getOrCreateWorker( + ctx, + livekit.RoomID(room.Sid), + livekit.RoomName(room.Name), + livekit.ParticipantID(participant.Sid), + livekit.ParticipantIdentity(participant.Identity), + guard, + ) + if !found { + prometheus.IncrementParticipantRtcConnected(1) + prometheus.AddParticipant() + } + + if shouldSendEvent { + ev := newParticipantEvent(livekit.AnalyticsEventType_PARTICIPANT_JOINED, room, participant) + ev.ClientInfo = clientInfo + ev.ClientMeta = clientMeta + t.SendEvent(ctx, ev) + } + }) +} + +func (t *telemetryService) ParticipantActive( + ctx context.Context, + room *livekit.Room, + participant *livekit.ParticipantInfo, + clientMeta *livekit.AnalyticsClientMeta, + isMigration bool, + guard *ReferenceGuard, +) { + t.enqueue(func() { + if !isMigration { + // a participant is considered "joined" only when they become "active" + t.NotifyEvent(ctx, &livekit.WebhookEvent{ + Event: webhook.EventParticipantJoined, + Room: room, + Participant: participant, + }) + } + + worker, found := t.getOrCreateWorker( + ctx, + livekit.RoomID(room.Sid), + livekit.RoomName(room.Name), + livekit.ParticipantID(participant.Sid), + livekit.ParticipantIdentity(participant.Identity), + guard, + ) + if !found { + // need to also account for participant count + prometheus.AddParticipant() + } + worker.SetConnected() + + ev := newParticipantEvent(livekit.AnalyticsEventType_PARTICIPANT_ACTIVE, room, participant) + ev.ClientMeta = clientMeta + t.SendEvent(ctx, ev) + }) +} + +func (t *telemetryService) ParticipantResumed( + ctx context.Context, + room *livekit.Room, + participant *livekit.ParticipantInfo, + nodeID livekit.NodeID, + reason livekit.ReconnectReason, +) { + t.enqueue(func() { + // create a worker if needed. + // + // Signalling channel stats collector and media channel stats collector could both call + // ParticipantJoined and ParticipantLeft. + // + // On a resume, the signalling channel collector would call `ParticipantLeft` which would close + // the corresponding participant's stats worker. + // + // So, on a successful resume, create the worker if needed. + _, found := t.getOrCreateWorker( + ctx, + livekit.RoomID(room.Sid), + livekit.RoomName(room.Name), + livekit.ParticipantID(participant.Sid), + livekit.ParticipantIdentity(participant.Identity), + nil, + ) + if !found { + prometheus.AddParticipant() + } + + ev := newParticipantEvent(livekit.AnalyticsEventType_PARTICIPANT_RESUMED, room, participant) + ev.ClientMeta = &livekit.AnalyticsClientMeta{ + Node: string(nodeID), + ReconnectReason: reason, + } + t.SendEvent(ctx, ev) + }) +} + +func (t *telemetryService) ParticipantLeft(ctx context.Context, + room *livekit.Room, + participant *livekit.ParticipantInfo, + shouldSendEvent bool, + guard *ReferenceGuard, +) { + t.enqueue(func() { + isConnected := false + if worker, ok := t.getWorker(livekit.ParticipantID(participant.Sid)); ok { + isConnected = worker.IsConnected() + if worker.Close(guard) { + prometheus.SubParticipant() + } + } + + if shouldSendEvent { + webhookEvent := webhook.EventParticipantLeft + analyticsEvent := livekit.AnalyticsEventType_PARTICIPANT_LEFT + if !isConnected { + webhookEvent = webhook.EventParticipantConnectionAborted + analyticsEvent = livekit.AnalyticsEventType_PARTICIPANT_CONNECTION_ABORTED + } + t.NotifyEvent(ctx, &livekit.WebhookEvent{ + Event: webhookEvent, + Room: room, + Participant: participant, + }) + + t.SendEvent(ctx, newParticipantEvent(analyticsEvent, room, participant)) + } + }) +} + +func (t *telemetryService) TrackPublishRequested( + ctx context.Context, + participantID livekit.ParticipantID, + identity livekit.ParticipantIdentity, + track *livekit.TrackInfo, +) { + t.enqueue(func() { + prometheus.RecordTrackPublishAttempt(track.Type.String()) + room := t.getRoomDetails(participantID) + ev := newTrackEvent(livekit.AnalyticsEventType_TRACK_PUBLISH_REQUESTED, room, participantID, track) + if ev.Participant != nil { + ev.Participant.Identity = string(identity) + } + t.SendEvent(ctx, ev) + }) +} + +func (t *telemetryService) TrackPublished( + ctx context.Context, + participantID livekit.ParticipantID, + identity livekit.ParticipantIdentity, + track *livekit.TrackInfo, + shouldSendEvent bool, +) { + t.enqueue(func() { + prometheus.AddPublishedTrack(track.Type.String()) + prometheus.RecordTrackPublishSuccess(track.Type.String()) + if !shouldSendEvent { + return + } + + room := t.getRoomDetails(participantID) + participant := &livekit.ParticipantInfo{ + Sid: string(participantID), + Identity: string(identity), + } + t.NotifyEvent(ctx, &livekit.WebhookEvent{ + Event: webhook.EventTrackPublished, + Room: room, + Participant: participant, + Track: track, + }) + + ev := newTrackEvent(livekit.AnalyticsEventType_TRACK_PUBLISHED, room, participantID, track) + ev.Participant = participant + t.SendEvent(ctx, ev) + }) +} + +func (t *telemetryService) TrackPublishedUpdate(ctx context.Context, participantID livekit.ParticipantID, track *livekit.TrackInfo) { + t.enqueue(func() { + room := t.getRoomDetails(participantID) + t.SendEvent(ctx, newTrackEvent(livekit.AnalyticsEventType_TRACK_PUBLISHED_UPDATE, room, participantID, track)) + }) +} + +func (t *telemetryService) TrackMaxSubscribedVideoQuality( + ctx context.Context, + participantID livekit.ParticipantID, + track *livekit.TrackInfo, + mime mime.MimeType, + maxQuality livekit.VideoQuality, +) { + t.enqueue(func() { + room := t.getRoomDetails(participantID) + ev := newTrackEvent(livekit.AnalyticsEventType_TRACK_MAX_SUBSCRIBED_VIDEO_QUALITY, room, participantID, track) + ev.MaxSubscribedVideoQuality = maxQuality + ev.Mime = mime.String() + t.SendEvent(ctx, ev) + }) +} + +func (t *telemetryService) TrackSubscribeRequested( + ctx context.Context, + participantID livekit.ParticipantID, + track *livekit.TrackInfo, +) { + t.enqueue(func() { + prometheus.RecordTrackSubscribeAttempt() + + room := t.getRoomDetails(participantID) + ev := newTrackEvent(livekit.AnalyticsEventType_TRACK_SUBSCRIBE_REQUESTED, room, participantID, track) + t.SendEvent(ctx, ev) + }) +} + +func (t *telemetryService) TrackSubscribed( + ctx context.Context, + participantID livekit.ParticipantID, + track *livekit.TrackInfo, + publisher *livekit.ParticipantInfo, + shouldSendEvent bool, +) { + t.enqueue(func() { + prometheus.RecordTrackSubscribeSuccess(track.Type.String()) + + if !shouldSendEvent { + return + } + + room := t.getRoomDetails(participantID) + ev := newTrackEvent(livekit.AnalyticsEventType_TRACK_SUBSCRIBED, room, participantID, track) + ev.Publisher = publisher + t.SendEvent(ctx, ev) + }) +} + +func (t *telemetryService) TrackSubscribeFailed( + ctx context.Context, + participantID livekit.ParticipantID, + trackID livekit.TrackID, + err error, + isUserError bool, +) { + t.enqueue(func() { + prometheus.RecordTrackSubscribeFailure(err, isUserError) + + room := t.getRoomDetails(participantID) + ev := newTrackEvent(livekit.AnalyticsEventType_TRACK_SUBSCRIBE_FAILED, room, participantID, &livekit.TrackInfo{ + Sid: string(trackID), + }) + ev.Error = err.Error() + t.SendEvent(ctx, ev) + }) +} + +func (t *telemetryService) TrackUnsubscribed( + ctx context.Context, + participantID livekit.ParticipantID, + track *livekit.TrackInfo, + shouldSendEvent bool, +) { + t.enqueue(func() { + prometheus.RecordTrackUnsubscribed(track.Type.String()) + + if shouldSendEvent { + room := t.getRoomDetails(participantID) + t.SendEvent(ctx, newTrackEvent(livekit.AnalyticsEventType_TRACK_UNSUBSCRIBED, room, participantID, track)) + } + }) +} + +func (t *telemetryService) TrackUnpublished( + ctx context.Context, + participantID livekit.ParticipantID, + identity livekit.ParticipantIdentity, + track *livekit.TrackInfo, + shouldSendEvent bool, +) { + t.enqueue(func() { + prometheus.SubPublishedTrack(track.Type.String()) + if !shouldSendEvent { + return + } + + room := t.getRoomDetails(participantID) + participant := &livekit.ParticipantInfo{ + Sid: string(participantID), + Identity: string(identity), + } + t.NotifyEvent(ctx, &livekit.WebhookEvent{ + Event: webhook.EventTrackUnpublished, + Room: room, + Participant: participant, + Track: track, + }) + + t.SendEvent(ctx, newTrackEvent(livekit.AnalyticsEventType_TRACK_UNPUBLISHED, room, participantID, track)) + }) +} + +func (t *telemetryService) TrackMuted( + ctx context.Context, + participantID livekit.ParticipantID, + track *livekit.TrackInfo, +) { + t.enqueue(func() { + room := t.getRoomDetails(participantID) + t.SendEvent(ctx, newTrackEvent(livekit.AnalyticsEventType_TRACK_MUTED, room, participantID, track)) + }) +} + +func (t *telemetryService) TrackUnmuted( + ctx context.Context, + participantID livekit.ParticipantID, + track *livekit.TrackInfo, +) { + t.enqueue(func() { + room := t.getRoomDetails(participantID) + t.SendEvent(ctx, newTrackEvent(livekit.AnalyticsEventType_TRACK_UNMUTED, room, participantID, track)) + }) +} + +func (t *telemetryService) TrackPublishRTPStats( + ctx context.Context, + participantID livekit.ParticipantID, + trackID livekit.TrackID, + mimeType mime.MimeType, + layer int, + stats *livekit.RTPStats, +) { + t.enqueue(func() { + room := t.getRoomDetails(participantID) + ev := newRoomEvent(livekit.AnalyticsEventType_TRACK_PUBLISH_STATS, room) + ev.ParticipantId = string(participantID) + ev.TrackId = string(trackID) + ev.Mime = mimeType.String() + ev.VideoLayer = int32(layer) + ev.RtpStats = stats + t.SendEvent(ctx, ev) + }) +} + +func (t *telemetryService) TrackSubscribeRTPStats( + ctx context.Context, + participantID livekit.ParticipantID, + trackID livekit.TrackID, + mimeType mime.MimeType, + stats *livekit.RTPStats, +) { + t.enqueue(func() { + room := t.getRoomDetails(participantID) + ev := newRoomEvent(livekit.AnalyticsEventType_TRACK_SUBSCRIBE_STATS, room) + ev.ParticipantId = string(participantID) + ev.TrackId = string(trackID) + ev.Mime = mimeType.String() + ev.RtpStats = stats + t.SendEvent(ctx, ev) + }) +} + +func (t *telemetryService) NotifyEgressEvent(ctx context.Context, event string, info *livekit.EgressInfo) { + opts := egress.GetEgressNotifyOptions(info) + + t.NotifyEvent(ctx, &livekit.WebhookEvent{ + Event: event, + EgressInfo: info, + }, opts...) +} + +func (t *telemetryService) EgressStarted(ctx context.Context, info *livekit.EgressInfo) { + + t.enqueue(func() { + t.NotifyEgressEvent(ctx, webhook.EventEgressStarted, info) + + t.SendEvent(ctx, newEgressEvent(livekit.AnalyticsEventType_EGRESS_STARTED, info)) + }) +} + +func (t *telemetryService) EgressUpdated(ctx context.Context, info *livekit.EgressInfo) { + t.enqueue(func() { + t.NotifyEgressEvent(ctx, webhook.EventEgressUpdated, info) + + t.SendEvent(ctx, newEgressEvent(livekit.AnalyticsEventType_EGRESS_UPDATED, info)) + }) +} + +func (t *telemetryService) EgressEnded(ctx context.Context, info *livekit.EgressInfo) { + t.enqueue(func() { + t.NotifyEgressEvent(ctx, webhook.EventEgressEnded, info) + + t.SendEvent(ctx, newEgressEvent(livekit.AnalyticsEventType_EGRESS_ENDED, info)) + }) +} + +func (t *telemetryService) IngressCreated(ctx context.Context, info *livekit.IngressInfo) { + t.enqueue(func() { + t.SendEvent(ctx, newIngressEvent(livekit.AnalyticsEventType_INGRESS_CREATED, info)) + }) +} + +func (t *telemetryService) IngressDeleted(ctx context.Context, info *livekit.IngressInfo) { + t.enqueue(func() { + t.SendEvent(ctx, newIngressEvent(livekit.AnalyticsEventType_INGRESS_DELETED, info)) + }) +} + +func (t *telemetryService) IngressStarted(ctx context.Context, info *livekit.IngressInfo) { + t.enqueue(func() { + t.NotifyEvent(ctx, &livekit.WebhookEvent{ + Event: webhook.EventIngressStarted, + IngressInfo: info, + }) + + t.SendEvent(ctx, newIngressEvent(livekit.AnalyticsEventType_INGRESS_STARTED, info)) + }) +} + +func (t *telemetryService) IngressUpdated(ctx context.Context, info *livekit.IngressInfo) { + t.enqueue(func() { + t.SendEvent(ctx, newIngressEvent(livekit.AnalyticsEventType_INGRESS_UPDATED, info)) + }) +} + +func (t *telemetryService) IngressEnded(ctx context.Context, info *livekit.IngressInfo) { + t.enqueue(func() { + t.NotifyEvent(ctx, &livekit.WebhookEvent{ + Event: webhook.EventIngressEnded, + IngressInfo: info, + }) + + t.SendEvent(ctx, newIngressEvent(livekit.AnalyticsEventType_INGRESS_ENDED, info)) + }) +} + +func (t *telemetryService) Report(ctx context.Context, reportInfo *livekit.ReportInfo) { + t.enqueue(func() { + ev := &livekit.AnalyticsEvent{ + Type: livekit.AnalyticsEventType_REPORT, + Timestamp: timestamppb.Now(), + Report: reportInfo, + } + t.SendEvent(ctx, ev) + }) +} + +func (t *telemetryService) APICall(ctx context.Context, apiCallInfo *livekit.APICallInfo) { + t.enqueue(func() { + ev := &livekit.AnalyticsEvent{ + Type: livekit.AnalyticsEventType_API_CALL, + Timestamp: timestamppb.Now(), + ApiCall: apiCallInfo, + } + t.SendEvent(ctx, ev) + }) +} + +func (t *telemetryService) Webhook(ctx context.Context, webhookInfo *livekit.WebhookInfo) { + t.enqueue(func() { + ev := &livekit.AnalyticsEvent{ + Type: livekit.AnalyticsEventType_WEBHOOK, + Timestamp: timestamppb.Now(), + Webhook: webhookInfo, + } + t.SendEvent(ctx, ev) + }) +} + +// returns a livekit.Room with only name and sid filled out +// returns nil if room is not found +func (t *telemetryService) getRoomDetails(participantID livekit.ParticipantID) *livekit.Room { + if worker, ok := t.getWorker(participantID); ok { + return &livekit.Room{ + Sid: string(worker.roomID), + Name: string(worker.roomName), + } + } + + return nil +} + +func newRoomEvent(event livekit.AnalyticsEventType, room *livekit.Room) *livekit.AnalyticsEvent { + ev := &livekit.AnalyticsEvent{ + Type: event, + Timestamp: timestamppb.Now(), + } + if room != nil { + ev.Room = room + ev.RoomId = room.Sid + } + return ev +} + +func newParticipantEvent(event livekit.AnalyticsEventType, room *livekit.Room, participant *livekit.ParticipantInfo) *livekit.AnalyticsEvent { + ev := newRoomEvent(event, room) + if participant != nil { + ev.ParticipantId = participant.Sid + ev.Participant = participant + } + return ev +} + +func newTrackEvent(event livekit.AnalyticsEventType, room *livekit.Room, participantID livekit.ParticipantID, track *livekit.TrackInfo) *livekit.AnalyticsEvent { + ev := newParticipantEvent(event, room, &livekit.ParticipantInfo{ + Sid: string(participantID), + }) + if track != nil { + ev.TrackId = track.Sid + ev.Track = track + } + return ev +} + +func newEgressEvent(event livekit.AnalyticsEventType, egress *livekit.EgressInfo) *livekit.AnalyticsEvent { + return &livekit.AnalyticsEvent{ + Type: event, + Timestamp: timestamppb.Now(), + EgressId: egress.EgressId, + RoomId: egress.RoomId, + Egress: egress, + } +} + +func newIngressEvent(event livekit.AnalyticsEventType, ingress *livekit.IngressInfo) *livekit.AnalyticsEvent { + return &livekit.AnalyticsEvent{ + Type: event, + Timestamp: timestamppb.Now(), + IngressId: ingress.IngressId, + Ingress: ingress, + } +} diff --git a/livekit/pkg/telemetry/events_test.go b/livekit/pkg/telemetry/events_test.go new file mode 100644 index 0000000..0be7de1 --- /dev/null +++ b/livekit/pkg/telemetry/events_test.go @@ -0,0 +1,243 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package telemetry_test + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/livekit/livekit-server/pkg/telemetry" + "github.com/livekit/protocol/livekit" +) + +func Test_OnParticipantJoin_EventIsSent(t *testing.T) { + fixture := createFixture() + + // prepare + room := &livekit.Room{Sid: "RoomSid", Name: "RoomName"} + partSID := "part1" + clientInfo := &livekit.ClientInfo{ + Sdk: 2, + Version: "v1", + Os: "mac", + OsVersion: "v1", + DeviceModel: "DM1", + Browser: "chrome", + BrowserVersion: "97.0.1", + } + clientMeta := &livekit.AnalyticsClientMeta{ + Region: "dark-side", + Node: "moon", + ClientAddr: "127.0.0.1", + ClientConnectTime: 420, + } + participantInfo := &livekit.ParticipantInfo{Sid: partSID} + guard := &telemetry.ReferenceGuard{} + + // do + fixture.sut.ParticipantJoined(context.Background(), room, participantInfo, clientInfo, clientMeta, true, guard) + time.Sleep(time.Millisecond * 500) + + // test + require.Equal(t, 1, fixture.analytics.SendEventCallCount()) + _, event := fixture.analytics.SendEventArgsForCall(0) + require.Equal(t, livekit.AnalyticsEventType_PARTICIPANT_JOINED, event.Type) + require.Equal(t, partSID, event.ParticipantId) + require.Equal(t, participantInfo, event.Participant) + require.Equal(t, room.Sid, event.RoomId) + require.Equal(t, room, event.Room) + + require.Equal(t, clientInfo.Sdk, event.ClientInfo.Sdk) + require.Equal(t, clientInfo.Version, event.ClientInfo.Version) + require.Equal(t, clientInfo.Os, event.ClientInfo.Os) + require.Equal(t, clientInfo.OsVersion, event.ClientInfo.OsVersion) + require.Equal(t, clientInfo.DeviceModel, event.ClientInfo.DeviceModel) + require.Equal(t, clientInfo.Browser, event.ClientInfo.Browser) + require.Equal(t, clientInfo.BrowserVersion, event.ClientInfo.BrowserVersion) + + require.Equal(t, clientMeta.Region, event.ClientMeta.Region) + require.Equal(t, clientMeta.Node, event.ClientMeta.Node) + require.Equal(t, clientMeta.ClientAddr, event.ClientMeta.ClientAddr) + require.Equal(t, clientMeta.ClientConnectTime, event.ClientMeta.ClientConnectTime) +} + +func Test_OnParticipantLeft_EventIsSent(t *testing.T) { + fixture := createFixture() + + // prepare + room := &livekit.Room{Sid: "RoomSid", Name: "RoomName"} + partSID := "part1" + participantInfo := &livekit.ParticipantInfo{Sid: partSID} + guard := &telemetry.ReferenceGuard{} + + // do + fixture.sut.ParticipantActive(context.Background(), room, participantInfo, &livekit.AnalyticsClientMeta{}, false, guard) + fixture.sut.ParticipantLeft(context.Background(), room, participantInfo, true, guard) + time.Sleep(time.Millisecond * 500) + + // test + require.Equal(t, 2, fixture.analytics.SendEventCallCount()) + _, event := fixture.analytics.SendEventArgsForCall(1) + require.Equal(t, livekit.AnalyticsEventType_PARTICIPANT_LEFT, event.Type) + require.Equal(t, partSID, event.ParticipantId) + require.Equal(t, room.Sid, event.RoomId) + require.Equal(t, room, event.Room) +} + +func Test_OnTrackUpdate_EventIsSent(t *testing.T) { + fixture := createFixture() + + // prepare + partID := "part1" + trackID := "track1" + layer := &livekit.VideoLayer{ + Quality: livekit.VideoQuality_HIGH, + Width: uint32(360), + Height: uint32(720), + Bitrate: 2048, + } + + trackInfo := &livekit.TrackInfo{ + Sid: trackID, + Type: livekit.TrackType_VIDEO, + Muted: false, + Simulcast: false, + DisableDtx: false, + Layers: []*livekit.VideoLayer{layer}, + } + + // do + fixture.sut.TrackPublishedUpdate(context.Background(), livekit.ParticipantID(partID), trackInfo) + time.Sleep(time.Millisecond * 500) + + // test + require.Equal(t, 1, fixture.analytics.SendEventCallCount()) + _, event := fixture.analytics.SendEventArgsForCall(0) + require.Equal(t, livekit.AnalyticsEventType_TRACK_PUBLISHED_UPDATE, event.Type) + require.Equal(t, partID, event.ParticipantId) + + require.Equal(t, trackID, event.Track.Sid) + require.NotNil(t, event.Track.Layers) + require.Equal(t, layer.Width, event.Track.Layers[0].Width) + require.Equal(t, layer.Height, event.Track.Layers[0].Height) + require.Equal(t, layer.Quality, event.Track.Layers[0].Quality) + +} + +func Test_OnParticipantActive_EventIsSent(t *testing.T) { + fixture := createFixture() + + // prepare participant to change status + room := &livekit.Room{Sid: "RoomSid", Name: "RoomName"} + partSID := "part1" + + clientInfo := &livekit.ClientInfo{ + Sdk: 2, + Version: "v1", + Os: "mac", + OsVersion: "v1", + DeviceModel: "DM1", + Browser: "chrome", + BrowserVersion: "97.0.1", + } + clientMeta := &livekit.AnalyticsClientMeta{ + Region: "dark-side", + Node: "moon", + ClientAddr: "127.0.0.1", + } + participantInfo := &livekit.ParticipantInfo{Sid: partSID} + guard := &telemetry.ReferenceGuard{} + + // do + fixture.sut.ParticipantJoined(context.Background(), room, participantInfo, clientInfo, clientMeta, true, guard) + time.Sleep(time.Millisecond * 500) + + // test + require.Equal(t, 1, fixture.analytics.SendEventCallCount()) + _, event := fixture.analytics.SendEventArgsForCall(0) + + // test + // do + clientMetaConnect := &livekit.AnalyticsClientMeta{ + ClientConnectTime: 420, + } + + fixture.sut.ParticipantActive(context.Background(), room, participantInfo, clientMetaConnect, false, guard) + time.Sleep(time.Millisecond * 500) + + require.Equal(t, 2, fixture.analytics.SendEventCallCount()) + _, eventActive := fixture.analytics.SendEventArgsForCall(1) + require.Equal(t, livekit.AnalyticsEventType_PARTICIPANT_ACTIVE, eventActive.Type) + require.Equal(t, partSID, eventActive.ParticipantId) + require.Equal(t, room.Sid, eventActive.RoomId) + require.Equal(t, room, event.Room) + + require.Equal(t, clientMetaConnect.ClientConnectTime, eventActive.ClientMeta.ClientConnectTime) +} + +func Test_OnTrackSubscribed_EventIsSent(t *testing.T) { + fixture := createFixture() + + // prepare participant to change status + room := &livekit.Room{Sid: "RoomSid", Name: "RoomName"} + partSID := "part1" + publisherInfo := &livekit.ParticipantInfo{Sid: "pub1", Identity: "publisher"} + trackInfo := &livekit.TrackInfo{Sid: "tr1", Type: livekit.TrackType_VIDEO} + + clientInfo := &livekit.ClientInfo{ + Sdk: 2, + Version: "v1", + Os: "mac", + OsVersion: "v1", + DeviceModel: "DM1", + Browser: "chrome", + BrowserVersion: "97.0.1", + } + clientMeta := &livekit.AnalyticsClientMeta{ + Region: "dark-side", + Node: "moon", + ClientAddr: "127.0.0.1", + } + participantInfo := &livekit.ParticipantInfo{Sid: partSID} + guard := &telemetry.ReferenceGuard{} + + // do + fixture.sut.ParticipantJoined(context.Background(), room, participantInfo, clientInfo, clientMeta, true, guard) + time.Sleep(time.Millisecond * 500) + + // test + require.Equal(t, 1, fixture.analytics.SendEventCallCount()) + _, event := fixture.analytics.SendEventArgsForCall(0) + require.Equal(t, room, event.Room) + + // do + fixture.sut.TrackSubscribed(context.Background(), livekit.ParticipantID(partSID), trackInfo, publisherInfo, true) + time.Sleep(time.Millisecond * 500) + + require.Eventually(t, func() bool { + return fixture.analytics.SendEventCallCount() == 2 + }, time.Second, time.Millisecond*50, "expected send event to be called twice") + _, eventTrackSubscribed := fixture.analytics.SendEventArgsForCall(1) + require.Equal(t, livekit.AnalyticsEventType_TRACK_SUBSCRIBED, eventTrackSubscribed.Type) + require.Equal(t, partSID, eventTrackSubscribed.ParticipantId) + require.Equal(t, trackInfo.Sid, eventTrackSubscribed.Track.Sid) + require.Equal(t, trackInfo.Type, eventTrackSubscribed.Track.Type) + require.Equal(t, publisherInfo.Sid, eventTrackSubscribed.Publisher.Sid) + require.Equal(t, publisherInfo.Identity, eventTrackSubscribed.Publisher.Identity) + +} diff --git a/livekit/pkg/telemetry/prometheus/datapacket.go b/livekit/pkg/telemetry/prometheus/datapacket.go new file mode 100644 index 0000000..541016d --- /dev/null +++ b/livekit/pkg/telemetry/prometheus/datapacket.go @@ -0,0 +1,75 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package prometheus + +import ( + "slices" + "strings" + + "github.com/prometheus/client_golang/prometheus" + + "github.com/livekit/protocol/livekit" +) + +var ( + promDataPacketStreamLabels = []string{"type", "mime_type"} + promDataPacketStreamMimeTypes = []string{"text", "image", "application", "audio", "video"} + + promDataPacketStreamDestCount *prometheus.HistogramVec + promDataPacketStreamSize *prometheus.HistogramVec +) + +func initDataPacketStats(nodeID string, nodeType livekit.NodeType) { + promDataPacketStreamDestCount = prometheus.NewHistogramVec(prometheus.HistogramOpts{ + Namespace: livekitNamespace, + Subsystem: "datapacket_stream", + Name: "dest_count", + ConstLabels: prometheus.Labels{"node_id": nodeID, "node_type": nodeType.String()}, + Buckets: []float64{1, 2, 3, 4, 5, 10, 15, 25, 50}, + }, promDataPacketStreamLabels) + promDataPacketStreamSize = prometheus.NewHistogramVec(prometheus.HistogramOpts{ + Namespace: livekitNamespace, + Subsystem: "datapacket_stream", + Name: "bytes", + ConstLabels: prometheus.Labels{"node_id": nodeID, "node_type": nodeType.String()}, + Buckets: []float64{128, 512, 2048, 8192, 32768, 131072, 524288, 2097152, 8388608, 33554432}, + }, promDataPacketStreamLabels) + + prometheus.MustRegister(promDataPacketStreamDestCount) + prometheus.MustRegister(promDataPacketStreamSize) +} + +func RecordDataPacketStream(h *livekit.DataStream_Header, destCount int) { + streamType := "unknown" + switch h.ContentHeader.(type) { + case *livekit.DataStream_Header_TextHeader: + streamType = "text" + case *livekit.DataStream_Header_ByteHeader: + streamType = "bytes" + } + + mimeType := strings.ToLower(h.MimeType) + if i := strings.IndexByte(mimeType, '/'); i != -1 { + mimeType = mimeType[:i] + } + if !slices.Contains(promDataPacketStreamMimeTypes, mimeType) { + mimeType = "unknown" + } + + promDataPacketStreamDestCount.WithLabelValues(streamType, mimeType).Observe(float64(destCount)) + if h.TotalLength != nil { + promDataPacketStreamSize.WithLabelValues(streamType, mimeType).Observe(float64(*h.TotalLength)) + } +} diff --git a/livekit/pkg/telemetry/prometheus/debug.go b/livekit/pkg/telemetry/prometheus/debug.go new file mode 100644 index 0000000..1f74598 --- /dev/null +++ b/livekit/pkg/telemetry/prometheus/debug.go @@ -0,0 +1,40 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package prometheus + +import ( + "github.com/prometheus/client_golang/prometheus" + + "github.com/livekit/protocol/livekit" +) + +var ( + refCounts *prometheus.GaugeVec +) + +func initDebugStats(nodeID string, nodeType livekit.NodeType) { + refCounts = prometheus.NewGaugeVec(prometheus.GaugeOpts{ + Namespace: livekitNamespace, + Subsystem: "debug", + Name: "ref_count", + ConstLabels: prometheus.Labels{"node_id": nodeID, "node_type": nodeType.String()}, + }, []string{"referrer"}) + + prometheus.MustRegister(refCounts) +} + +func AddRef(referrer string, n int) { + refCounts.WithLabelValues(referrer).Add(float64(n)) +} diff --git a/livekit/pkg/telemetry/prometheus/node.go b/livekit/pkg/telemetry/prometheus/node.go new file mode 100644 index 0000000..c00a574 --- /dev/null +++ b/livekit/pkg/telemetry/prometheus/node.go @@ -0,0 +1,289 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package prometheus + +import ( + "time" + + "github.com/prometheus/client_golang/prometheus" + "github.com/twitchtv/twirp" + "go.uber.org/atomic" + + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/rpc" + "github.com/livekit/protocol/utils/hwstats" + "github.com/livekit/protocol/webhook" +) + +const ( + livekitNamespace string = "livekit" +) + +var ( + initialized atomic.Bool + + promMessageCounter *prometheus.CounterVec + promServiceOperationCounter *prometheus.CounterVec + promTwirpRequestStatusCounter *prometheus.CounterVec + + sysPacketsStart uint32 + sysDroppedPacketsStart uint32 + promSysPacketGauge *prometheus.GaugeVec + + cpuStats *hwstats.CPUStats + memoryStats *hwstats.MemoryStats +) + +func Init(nodeID string, nodeType livekit.NodeType) error { + if initialized.Swap(true) { + return nil + } + + promMessageCounter = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Namespace: livekitNamespace, + Subsystem: "node", + Name: "messages", + ConstLabels: prometheus.Labels{"node_id": nodeID, "node_type": nodeType.String()}, + }, + []string{"type", "status", "direction"}, + ) + + promServiceOperationCounter = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Namespace: livekitNamespace, + Subsystem: "node", + Name: "service_operation", + ConstLabels: prometheus.Labels{"node_id": nodeID, "node_type": nodeType.String()}, + }, + []string{"type", "status", "error_type"}, + ) + + promTwirpRequestStatusCounter = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Namespace: livekitNamespace, + Subsystem: "node", + Name: "twirp_request_status", + ConstLabels: prometheus.Labels{"node_id": nodeID, "node_type": nodeType.String()}, + }, + []string{"service", "method", "status", "code"}, + ) + + promSysPacketGauge = prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Namespace: livekitNamespace, + Subsystem: "node", + Name: "packet_total", + ConstLabels: prometheus.Labels{"node_id": nodeID, "node_type": nodeType.String()}, + Help: "System level packet count. Count starts at 0 when service is first started.", + }, + []string{"type"}, + ) + + prometheus.MustRegister(promMessageCounter) + prometheus.MustRegister(promServiceOperationCounter) + prometheus.MustRegister(promTwirpRequestStatusCounter) + prometheus.MustRegister(promSysPacketGauge) + + sysPacketsStart, sysDroppedPacketsStart, _ = getTCStats() + + initPacketStats(nodeID, nodeType) + initRoomStats(nodeID, nodeType) + rpc.InitPSRPCStats(prometheus.Labels{"node_id": nodeID, "node_type": nodeType.String()}) + webhook.InitWebhookStats(prometheus.Labels{"node_id": nodeID, "node_type": nodeType.String()}) + initQualityStats(nodeID, nodeType) + initDataPacketStats(nodeID, nodeType) + initDebugStats(nodeID, nodeType) + + var err error + cpuStats, err = hwstats.NewCPUStats(nil) + if err != nil { + return err + } + + memoryStats, err = hwstats.NewMemoryStats() + if err != nil { + return err + } + + return nil +} + +func GetNodeStats(nodeStartedAt int64, prevStats []*livekit.NodeStats, rateIntervals []time.Duration) (*livekit.NodeStats, error) { + loadAvg, err := getLoadAvg() + if err != nil { + return nil, err + } + + // On MacOS, get "\"vm_stat\": executable file not found in $PATH" although it is in /usr/bin + // So, do not error out. Use the information if it is available. + memUsed, memTotal, _ := memoryStats.GetMemory() + + sysPackets, sysDroppedPackets, _ := getTCStats() + promSysPacketGauge.WithLabelValues("out").Set(float64(sysPackets - sysPacketsStart)) + promSysPacketGauge.WithLabelValues("dropped").Set(float64(sysDroppedPackets - sysDroppedPacketsStart)) + + stats := &livekit.NodeStats{ + StartedAt: nodeStartedAt, + UpdatedAt: time.Now().Unix(), + NumRooms: roomCurrent.Load(), + NumClients: participantCurrent.Load(), + NumTracksIn: trackPublishedCurrent.Load(), + NumTracksOut: trackSubscribedCurrent.Load(), + NumTrackPublishAttempts: trackPublishAttempts.Load(), + NumTrackPublishSuccess: trackPublishSuccess.Load(), + NumTrackPublishCancels: trackPublishCancels.Load(), + NumTrackSubscribeAttempts: trackSubscribeAttempts.Load(), + NumTrackSubscribeSuccess: trackSubscribeSuccess.Load(), + NumTrackSubscribeCancels: trackSubscribeCancels.Load(), + BytesIn: bytesIn.Load(), + BytesOut: bytesOut.Load(), + PacketsIn: packetsIn.Load(), + PacketsOut: packetsOut.Load(), + RetransmitBytesOut: retransmitBytes.Load(), + RetransmitPacketsOut: retransmitPackets.Load(), + NackTotal: nackTotal.Load(), + ParticipantSignalConnected: participantSignalConnected.Load(), + ParticipantRtcInit: participantRTCInit.Load(), + ParticipantRtcConnected: participantRTCConnected.Load(), + ParticipantRtcCanceled: participantRTCCanceled.Load(), + ForwardLatency: forwardLatency.Load(), + ForwardJitter: forwardJitter.Load(), + NumCpus: uint32(cpuStats.NumCPU()), // this will round down to the nearest integer + CpuLoad: float32(cpuStats.GetCPULoad()), + MemoryTotal: memTotal, + MemoryUsed: memUsed, + LoadAvgLast1Min: float32(loadAvg.Loadavg1), + LoadAvgLast5Min: float32(loadAvg.Loadavg5), + LoadAvgLast15Min: float32(loadAvg.Loadavg15), + SysPacketsOut: sysPackets, + SysPacketsDropped: sysDroppedPackets, + } + + for _, rateInterval := range rateIntervals { + for idx := len(prevStats) - 1; idx >= 0; idx-- { + prev := prevStats[idx] + if prev == nil { + continue + } + + if stats.UpdatedAt-prev.UpdatedAt >= int64(rateInterval.Seconds()) { + if rate := getNodeStatsRate(append(prevStats[idx:], stats)); rate != nil { + stats.Rates = append(stats.Rates, rate) + } + break + } + } + } + + return stats, nil +} + +func getNodeStatsRate(statsHistory []*livekit.NodeStats) *livekit.NodeStatsRate { + if len(statsHistory) == 0 { + return nil + } + + elapsed := statsHistory[len(statsHistory)-1].UpdatedAt - statsHistory[0].UpdatedAt + if elapsed <= 0 { + return nil + } + + // time weighted averages + var cpuLoad, memoryUsed, memoryTotal, memoryLoad float32 + for idx := len(statsHistory) - 1; idx > 0; idx-- { + stats := statsHistory[idx] + prevStats := statsHistory[idx-1] + if stats == nil || prevStats == nil { + continue + } + + spanElapsed := stats.UpdatedAt - prevStats.UpdatedAt + if spanElapsed <= 0 { + continue + } + + cpuLoad += stats.CpuLoad * float32(spanElapsed) + memoryUsed += float32(stats.MemoryUsed) * float32(spanElapsed) + memoryTotal += float32(stats.MemoryTotal) * float32(spanElapsed) + if stats.MemoryTotal > 0 { + memoryLoad += float32(stats.MemoryUsed) / float32(stats.MemoryTotal) * float32(spanElapsed) + } + } + + earlier := statsHistory[0] + later := statsHistory[len(statsHistory)-1] + rate := &livekit.NodeStatsRate{ + StartedAt: earlier.UpdatedAt, + EndedAt: later.UpdatedAt, + Duration: elapsed, + BytesIn: perSec(earlier.BytesIn, later.BytesIn, elapsed), + BytesOut: perSec(earlier.BytesOut, later.BytesOut, elapsed), + PacketsIn: perSec(earlier.PacketsIn, later.PacketsIn, elapsed), + PacketsOut: perSec(earlier.PacketsOut, later.PacketsOut, elapsed), + RetransmitBytesOut: perSec(earlier.RetransmitBytesOut, later.RetransmitBytesOut, elapsed), + RetransmitPacketsOut: perSec(earlier.RetransmitPacketsOut, later.RetransmitPacketsOut, elapsed), + NackTotal: perSec(earlier.NackTotal, later.NackTotal, elapsed), + ParticipantSignalConnected: perSec(earlier.ParticipantSignalConnected, later.ParticipantSignalConnected, elapsed), + ParticipantRtcInit: perSec(earlier.ParticipantRtcInit, later.ParticipantRtcInit, elapsed), + ParticipantRtcConnected: perSec(earlier.ParticipantRtcConnected, later.ParticipantRtcConnected, elapsed), + ParticipantRtcCanceled: perSec(earlier.ParticipantRtcCanceled, later.ParticipantRtcCanceled, elapsed), + SysPacketsOut: perSec(uint64(earlier.SysPacketsOut), uint64(later.SysPacketsOut), elapsed), + SysPacketsDropped: perSec(uint64(earlier.SysPacketsDropped), uint64(later.SysPacketsDropped), elapsed), + TrackPublishAttempts: perSec(uint64(earlier.NumTrackPublishAttempts), uint64(later.NumTrackPublishAttempts), elapsed), + TrackPublishSuccess: perSec(uint64(earlier.NumTrackPublishSuccess), uint64(later.NumTrackPublishSuccess), elapsed), + TrackPublishCancels: perSec(uint64(earlier.NumTrackPublishCancels), uint64(later.NumTrackPublishCancels), elapsed), + TrackSubscribeAttempts: perSec(uint64(earlier.NumTrackSubscribeAttempts), uint64(later.NumTrackSubscribeAttempts), elapsed), + TrackSubscribeSuccess: perSec(uint64(earlier.NumTrackSubscribeSuccess), uint64(later.NumTrackSubscribeSuccess), elapsed), + TrackSubscribeCancels: perSec(uint64(earlier.NumTrackSubscribeCancels), uint64(later.NumTrackSubscribeCancels), elapsed), + CpuLoad: cpuLoad / float32(elapsed), + MemoryLoad: memoryLoad / float32(elapsed), + MemoryUsed: memoryUsed / float32(elapsed), + MemoryTotal: memoryTotal / float32(elapsed), + } + return rate +} + +func perSec(prev, curr uint64, secs int64) float32 { + return float32(curr-prev) / float32(secs) +} + +func RecordSignalRequestSuccess() { + promMessageCounter.WithLabelValues("signal", "success", "request").Add(1) +} + +func RecordSignalRequestFailure() { + promMessageCounter.WithLabelValues("signal", "failure", "request").Add(1) +} + +func RecordSignalResponseSuccess() { + promMessageCounter.WithLabelValues("signal", "success", "response").Add(1) +} + +func RecordSignalResponseFailure() { + promMessageCounter.WithLabelValues("signal", "failure", "response").Add(1) +} + +func RecordServiceOperationSuccess(op string) { + promServiceOperationCounter.WithLabelValues(op, "success", "").Add(1) +} + +func RecordServiceOperationError(op string, error string) { + promServiceOperationCounter.WithLabelValues(op, "error", error).Add(1) +} + +func RecordTwirpRequestStatus(service string, method string, statusFamily string, code twirp.ErrorCode) { + promTwirpRequestStatusCounter.WithLabelValues(service, method, statusFamily, string(code)).Add(1) +} diff --git a/livekit/pkg/telemetry/prometheus/node_linux.go b/livekit/pkg/telemetry/prometheus/node_linux.go new file mode 100644 index 0000000..baa6919 --- /dev/null +++ b/livekit/pkg/telemetry/prometheus/node_linux.go @@ -0,0 +1,46 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build linux +// +build linux + +package prometheus + +import ( + "fmt" + + "github.com/florianl/go-tc" +) + +func getTCStats() (packets, drops uint32, err error) { + rtnl, err := tc.Open(&tc.Config{}) + if err != nil { + err = fmt.Errorf("could not open rtnetlink socket: %v", err) + return + } + defer rtnl.Close() + + qdiscs, err := rtnl.Qdisc().Get() + if err != nil { + err = fmt.Errorf("could not get qdiscs: %v", err) + return + } + + for _, qdisc := range qdiscs { + packets = packets + qdisc.Stats.Packets + drops = drops + qdisc.Stats.Drops + } + + return +} diff --git a/livekit/pkg/telemetry/prometheus/node_nonlinux.go b/livekit/pkg/telemetry/prometheus/node_nonlinux.go new file mode 100644 index 0000000..22fc11c --- /dev/null +++ b/livekit/pkg/telemetry/prometheus/node_nonlinux.go @@ -0,0 +1,22 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build !linux + +package prometheus + +func getTCStats() (packets, drops uint32, err error) { + // linux only + return +} diff --git a/livekit/pkg/telemetry/prometheus/node_nonwindows.go b/livekit/pkg/telemetry/prometheus/node_nonwindows.go new file mode 100644 index 0000000..b765ce2 --- /dev/null +++ b/livekit/pkg/telemetry/prometheus/node_nonwindows.go @@ -0,0 +1,56 @@ +//go:build !windows + +/* + * Copyright 2023 LiveKit, Inc + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package prometheus + +import ( + "runtime" + "sync" + + "github.com/mackerelio/go-osstat/cpu" + "github.com/mackerelio/go-osstat/loadavg" +) + +var ( + cpuStatsLock sync.RWMutex + lastCPUTotal, lastCPUIdle uint64 +) + +func getLoadAvg() (*loadavg.Stats, error) { + return loadavg.Get() +} + +func getCPUStats() (cpuLoad float32, numCPUs uint32, err error) { + cpuInfo, err := cpu.Get() + if err != nil { + return + } + + cpuStatsLock.Lock() + if lastCPUTotal > 0 && lastCPUTotal < cpuInfo.Total { + cpuLoad = 1 - float32(cpuInfo.Idle-lastCPUIdle)/float32(cpuInfo.Total-lastCPUTotal) + } + + lastCPUTotal = cpuInfo.Total + lastCPUIdle = cpuInfo.Idle + cpuStatsLock.Unlock() + + numCPUs = uint32(runtime.NumCPU()) + + return +} diff --git a/livekit/pkg/telemetry/prometheus/node_windows.go b/livekit/pkg/telemetry/prometheus/node_windows.go new file mode 100644 index 0000000..eb0ba48 --- /dev/null +++ b/livekit/pkg/telemetry/prometheus/node_windows.go @@ -0,0 +1,29 @@ +//go:build windows + +/* + * Copyright 2023 LiveKit, Inc + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package prometheus + +import "github.com/mackerelio/go-osstat/loadavg" + +func getLoadAvg() (*loadavg.Stats, error) { + return &loadavg.Stats{}, nil +} + +func getCPUStats() (cpuLoad float32, numCPUs uint32, err error) { + return 1, 1, nil +} diff --git a/livekit/pkg/telemetry/prometheus/packets.go b/livekit/pkg/telemetry/prometheus/packets.go new file mode 100644 index 0000000..576bb65 --- /dev/null +++ b/livekit/pkg/telemetry/prometheus/packets.go @@ -0,0 +1,359 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package prometheus + +import ( + "github.com/prometheus/client_golang/prometheus" + "go.uber.org/atomic" + + "github.com/livekit/protocol/livekit" +) + +type Direction string + +const ( + Incoming Direction = "incoming" + Outgoing Direction = "outgoing" +) + +type TransmissionType string + +const ( + TransmissionInitial TransmissionType = "initial" + TransmissionRetransmit TransmissionType = "retransmit" +) + +var ( + bytesIn atomic.Uint64 + bytesOut atomic.Uint64 + packetsIn atomic.Uint64 + packetsOut atomic.Uint64 + nackTotal atomic.Uint64 + retransmitBytes atomic.Uint64 + retransmitPackets atomic.Uint64 + participantSignalConnected atomic.Uint64 + participantRTCConnected atomic.Uint64 + participantRTCInit atomic.Uint64 + participantRTCCanceled atomic.Uint64 + forwardLatency atomic.Uint32 + forwardJitter atomic.Uint32 + + promPacketLabels = []string{"direction", "transmission", "country"} + promPacketTotal *prometheus.CounterVec + promPacketBytes *prometheus.CounterVec + promRTCPLabels = []string{"direction", "country"} + promStreamLabels = []string{"direction", "source", "type", "country"} + promNackTotal *prometheus.CounterVec + promPliTotal *prometheus.CounterVec + promFirTotal *prometheus.CounterVec + promPacketLossTotal *prometheus.CounterVec + promPacketLoss *prometheus.HistogramVec + promPacketOutOfOrderTotal *prometheus.CounterVec + promPacketOutOfOrder *prometheus.HistogramVec + promJitter *prometheus.HistogramVec + promRTT *prometheus.HistogramVec + promParticipantJoin *prometheus.CounterVec + promConnections *prometheus.GaugeVec + promForwardLatency prometheus.Gauge + promForwardJitter prometheus.Gauge + promForwardLatencyHist prometheus.Histogram + + promPacketTotalIncomingInitial prometheus.Counter + promPacketTotalIncomingRetransmit prometheus.Counter + promPacketTotalOutgoingInitial prometheus.Counter + promPacketTotalOutgoingRetransmit prometheus.Counter + promPacketBytesIncomingInitial prometheus.Counter + promPacketBytesIncomingRetransmit prometheus.Counter + promPacketBytesOutgoingInitial prometheus.Counter + promPacketBytesOutgoingRetransmit prometheus.Counter +) + +func initPacketStats(nodeID string, nodeType livekit.NodeType) { + promPacketTotal = prometheus.NewCounterVec(prometheus.CounterOpts{ + Namespace: livekitNamespace, + Subsystem: "packet", + Name: "total", + ConstLabels: prometheus.Labels{"node_id": nodeID, "node_type": nodeType.String()}, + }, promPacketLabels) + promPacketBytes = prometheus.NewCounterVec(prometheus.CounterOpts{ + Namespace: livekitNamespace, + Subsystem: "packet", + Name: "bytes", + ConstLabels: prometheus.Labels{"node_id": nodeID, "node_type": nodeType.String()}, + }, promPacketLabels) + promNackTotal = prometheus.NewCounterVec(prometheus.CounterOpts{ + Namespace: livekitNamespace, + Subsystem: "nack", + Name: "total", + ConstLabels: prometheus.Labels{"node_id": nodeID, "node_type": nodeType.String()}, + }, promRTCPLabels) + promPliTotal = prometheus.NewCounterVec(prometheus.CounterOpts{ + Namespace: livekitNamespace, + Subsystem: "pli", + Name: "total", + ConstLabels: prometheus.Labels{"node_id": nodeID, "node_type": nodeType.String()}, + }, promRTCPLabels) + promFirTotal = prometheus.NewCounterVec(prometheus.CounterOpts{ + Namespace: livekitNamespace, + Subsystem: "fir", + Name: "total", + ConstLabels: prometheus.Labels{"node_id": nodeID, "node_type": nodeType.String()}, + }, promRTCPLabels) + promPacketLossTotal = prometheus.NewCounterVec(prometheus.CounterOpts{ + Namespace: livekitNamespace, + Subsystem: "packet_loss", + Name: "total", + ConstLabels: prometheus.Labels{"node_id": nodeID, "node_type": nodeType.String()}, + }, promStreamLabels) + promPacketLoss = prometheus.NewHistogramVec(prometheus.HistogramOpts{ + Namespace: livekitNamespace, + Subsystem: "packet_loss", + Name: "percent", + ConstLabels: prometheus.Labels{"node_id": nodeID, "node_type": nodeType.String()}, + Buckets: []float64{0.0, 0.1, 0.3, 0.5, 0.7, 1, 5, 10, 40, 100}, + }, promStreamLabels) + promPacketOutOfOrderTotal = prometheus.NewCounterVec(prometheus.CounterOpts{ + Namespace: livekitNamespace, + Subsystem: "packet_out_of_order", + Name: "total", + ConstLabels: prometheus.Labels{"node_id": nodeID, "node_type": nodeType.String()}, + }, promStreamLabels) + promPacketOutOfOrder = prometheus.NewHistogramVec(prometheus.HistogramOpts{ + Namespace: livekitNamespace, + Subsystem: "packet_out_of_order", + Name: "percent", + ConstLabels: prometheus.Labels{"node_id": nodeID, "node_type": nodeType.String()}, + Buckets: []float64{0.0, 0.1, 0.3, 0.5, 0.7, 1, 5, 10, 40, 100}, + }, promStreamLabels) + promJitter = prometheus.NewHistogramVec(prometheus.HistogramOpts{ + Namespace: livekitNamespace, + Subsystem: "jitter", + Name: "us", + ConstLabels: prometheus.Labels{"node_id": nodeID, "node_type": nodeType.String()}, + // 1ms, 10ms, 30ms, 50ms, 70ms, 100ms, 300ms, 600ms, 1s + Buckets: []float64{1000, 10000, 30000, 50000, 70000, 100000, 300000, 600000, 1000000}, + }, promStreamLabels) + promRTT = prometheus.NewHistogramVec(prometheus.HistogramOpts{ + Namespace: livekitNamespace, + Subsystem: "rtt", + Name: "ms", + ConstLabels: prometheus.Labels{"node_id": nodeID, "node_type": nodeType.String()}, + Buckets: []float64{50, 100, 150, 200, 250, 500, 750, 1000, 5000, 10000}, + }, promStreamLabels) + promParticipantJoin = prometheus.NewCounterVec(prometheus.CounterOpts{ + Namespace: livekitNamespace, + Subsystem: "participant_join", + Name: "total", + ConstLabels: prometheus.Labels{"node_id": nodeID, "node_type": nodeType.String()}, + }, []string{"state"}) + promConnections = prometheus.NewGaugeVec(prometheus.GaugeOpts{ + Namespace: livekitNamespace, + Subsystem: "connection", + Name: "total", + ConstLabels: prometheus.Labels{"node_id": nodeID, "node_type": nodeType.String()}, + }, []string{"kind"}) + promForwardLatency = prometheus.NewGauge(prometheus.GaugeOpts{ + Namespace: livekitNamespace, + Subsystem: "forward", + Name: "latency", + ConstLabels: prometheus.Labels{"node_id": nodeID, "node_type": nodeType.String()}, + }) + promForwardJitter = prometheus.NewGauge(prometheus.GaugeOpts{ + Namespace: livekitNamespace, + Subsystem: "forward", + Name: "jitter", + ConstLabels: prometheus.Labels{"node_id": nodeID, "node_type": nodeType.String()}, + }) + promForwardLatencyHist = prometheus.NewHistogram(prometheus.HistogramOpts{ + Namespace: livekitNamespace, + Subsystem: "forward_latency", + Name: "ns", + ConstLabels: prometheus.Labels{"node_id": nodeID, "node_type": nodeType.String()}, + // 50us, 100us, 250us, 500us, 1ms, 2ms, 3ms, 5ms, 10ms, 20ms + Buckets: []float64{ + 50 * 1000, + 100 * 1000, + 250 * 1000, + 500 * 1000, + 1 * 1000 * 1000, + 2 * 1000 * 1000, + 3 * 1000 * 1000, + 5 * 1000 * 1000, + 10 * 1000 * 1000, + 20 * 1000 * 1000, + }, + }) + + prometheus.MustRegister(promPacketTotal) + prometheus.MustRegister(promPacketBytes) + prometheus.MustRegister(promNackTotal) + prometheus.MustRegister(promPliTotal) + prometheus.MustRegister(promFirTotal) + prometheus.MustRegister(promPacketLossTotal) + prometheus.MustRegister(promPacketLoss) + prometheus.MustRegister(promPacketOutOfOrderTotal) + prometheus.MustRegister(promPacketOutOfOrder) + prometheus.MustRegister(promJitter) + prometheus.MustRegister(promRTT) + prometheus.MustRegister(promParticipantJoin) + prometheus.MustRegister(promConnections) + prometheus.MustRegister(promForwardLatency) + prometheus.MustRegister(promForwardJitter) + prometheus.MustRegister(promForwardLatencyHist) +} + +func IncrementPackets(country string, direction Direction, count uint64, retransmit bool) { + var transmission TransmissionType + if retransmit { + transmission = TransmissionRetransmit + } else { + transmission = TransmissionInitial + } + promPacketTotal.WithLabelValues(string(direction), string(transmission), country).Add(float64(count)) + + if direction == Incoming { + packetsIn.Add(count) + } else { + packetsOut.Add(count) + if retransmit { + retransmitPackets.Add(count) + } + } +} + +func IncrementBytes(country string, direction Direction, count uint64, retransmit bool) { + var transmission TransmissionType + if retransmit { + transmission = TransmissionRetransmit + } else { + transmission = TransmissionInitial + } + promPacketBytes.WithLabelValues(string(direction), string(transmission), country).Add(float64(count)) + + if direction == Incoming { + bytesIn.Add(count) + } else { + bytesOut.Add(count) + if retransmit { + retransmitBytes.Add(count) + } + } +} + +func IncrementRTCP(country string, direction Direction, nack, pli, fir uint32) { + if nack > 0 { + promNackTotal.WithLabelValues(string(direction), country).Add(float64(nack)) + nackTotal.Add(uint64(nack)) + } + if pli > 0 { + promPliTotal.WithLabelValues(string(direction), country).Add(float64(pli)) + } + if fir > 0 { + promFirTotal.WithLabelValues(string(direction), country).Add(float64(fir)) + } +} + +func RecordPacketLoss( + country string, + direction Direction, + trackSource livekit.TrackSource, + trackType livekit.TrackType, + lost uint32, + total uint32, +) { + if total > 0 { + promPacketLoss.WithLabelValues(string(direction), trackSource.String(), trackType.String(), country).Observe(float64(lost) / float64(total) * 100) + } + if lost > 0 { + promPacketLossTotal.WithLabelValues(string(direction), trackSource.String(), trackType.String(), country).Add(float64(lost)) + } +} + +func RecordPacketOutOfOrder(country string, direction Direction, trackSource livekit.TrackSource, trackType livekit.TrackType, ooo, total uint32) { + if total > 0 { + promPacketOutOfOrder.WithLabelValues(string(direction), trackSource.String(), trackType.String(), country).Observe(float64(ooo) / float64(total) * 100) + } + if ooo > 0 { + promPacketOutOfOrderTotal.WithLabelValues(string(direction), trackSource.String(), trackType.String(), country).Add(float64(ooo)) + } +} + +func RecordJitter(country string, direction Direction, trackSource livekit.TrackSource, trackType livekit.TrackType, jitter uint32) { + if jitter > 0 { + promJitter.WithLabelValues(string(direction), trackSource.String(), trackType.String(), country).Observe(float64(jitter)) + } +} + +func RecordRTT(country string, direction Direction, trackSource livekit.TrackSource, trackType livekit.TrackType, rtt uint32) { + if rtt > 0 { + promRTT.WithLabelValues(string(direction), trackSource.String(), trackType.String(), country).Observe(float64(rtt)) + } +} + +func IncrementParticipantJoin(join uint32) { + if join > 0 { + participantSignalConnected.Add(uint64(join)) + promParticipantJoin.WithLabelValues("signal_connected").Add(float64(join)) + } +} + +func IncrementParticipantJoinFail(join uint32) { + if join > 0 { + promParticipantJoin.WithLabelValues("signal_failed").Add(float64(join)) + } +} + +func IncrementParticipantRtcInit(join uint32) { + if join > 0 { + participantRTCInit.Add(uint64(join)) + promParticipantJoin.WithLabelValues("rtc_init").Add(float64(join)) + } +} + +func IncrementParticipantRtcConnected(join uint32) { + if join > 0 { + participantRTCConnected.Add(uint64(join)) + promParticipantJoin.WithLabelValues("rtc_connected").Add(float64(join)) + } +} + +func IncrementParticipantRtcCanceled(numCancels uint64) { + if numCancels > 0 { + participantRTCConnected.Add(numCancels) + promParticipantJoin.WithLabelValues("rtc_canceled").Add(float64(numCancels)) + } +} + +func AddConnection(direction Direction) { + promConnections.WithLabelValues(string(direction)).Add(1) +} + +func SubConnection(direction Direction) { + promConnections.WithLabelValues(string(direction)).Sub(1) +} + +func RecordForwardLatencySample(forwardLatency int64) { + promForwardLatencyHist.Observe(float64(forwardLatency)) +} + +func RecordForwardLatency(longTermLatencyAvg uint32) { + forwardLatency.Store(longTermLatencyAvg) + promForwardLatency.Set(float64(longTermLatencyAvg)) +} + +func RecordForwardJitter(longTermJitterAvg uint32) { + forwardJitter.Store(longTermJitterAvg) + promForwardJitter.Set(float64(longTermJitterAvg)) +} diff --git a/livekit/pkg/telemetry/prometheus/quality.go b/livekit/pkg/telemetry/prometheus/quality.go new file mode 100644 index 0000000..b55a3f8 --- /dev/null +++ b/livekit/pkg/telemetry/prometheus/quality.go @@ -0,0 +1,51 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package prometheus + +import ( + "github.com/prometheus/client_golang/prometheus" + + "github.com/livekit/protocol/livekit" +) + +var ( + qualityRating prometheus.Histogram + qualityScore prometheus.Histogram +) + +func initQualityStats(nodeID string, nodeType livekit.NodeType) { + qualityRating = prometheus.NewHistogram(prometheus.HistogramOpts{ + Namespace: livekitNamespace, + Subsystem: "quality", + Name: "rating", + ConstLabels: prometheus.Labels{"node_id": nodeID, "node_type": nodeType.String()}, + Buckets: []float64{0, 1, 2}, + }) + qualityScore = prometheus.NewHistogram(prometheus.HistogramOpts{ + Namespace: livekitNamespace, + Subsystem: "quality", + Name: "score", + ConstLabels: prometheus.Labels{"node_id": nodeID, "node_type": nodeType.String()}, + Buckets: []float64{1.0, 2.0, 2.5, 3.0, 3.25, 3.5, 3.75, 4.0, 4.25, 4.5}, + }) + + prometheus.MustRegister(qualityRating) + prometheus.MustRegister(qualityScore) +} + +func RecordQuality(rating livekit.ConnectionQuality, score float32) { + qualityRating.Observe(float64(rating)) + qualityScore.Observe(float64(score)) +} diff --git a/livekit/pkg/telemetry/prometheus/rooms.go b/livekit/pkg/telemetry/prometheus/rooms.go new file mode 100644 index 0000000..00ec4b1 --- /dev/null +++ b/livekit/pkg/telemetry/prometheus/rooms.go @@ -0,0 +1,271 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package prometheus + +import ( + "strconv" + "time" + + "github.com/prometheus/client_golang/prometheus" + "go.uber.org/atomic" + + "github.com/livekit/protocol/livekit" +) + +var ( + roomCurrent atomic.Int32 + participantCurrent atomic.Int32 + trackPublishedCurrent atomic.Int32 + trackSubscribedCurrent atomic.Int32 + trackPublishAttempts atomic.Int32 + trackPublishSuccess atomic.Int32 + trackPublishCancels atomic.Int32 + trackSubscribeAttempts atomic.Int32 + trackSubscribeSuccess atomic.Int32 + trackSubscribeCancels atomic.Int32 + // count the number of failures that are due to user error (permissions, track doesn't exist), so we could compute + // success rate by subtracting this from total attempts + trackSubscribeUserError atomic.Int32 + + promRoomCurrent prometheus.Gauge + promRoomDuration prometheus.Histogram + promParticipantCurrent prometheus.Gauge + promTrackPublishedCurrent *prometheus.GaugeVec + promTrackSubscribedCurrent *prometheus.GaugeVec + promTrackPublishCounter *prometheus.CounterVec + promTrackSubscribeCounter *prometheus.CounterVec + promSessionStartTime *prometheus.HistogramVec + promSessionDuration *prometheus.HistogramVec + promPubSubTime *prometheus.HistogramVec +) + +func initRoomStats(nodeID string, nodeType livekit.NodeType) { + promRoomCurrent = prometheus.NewGauge(prometheus.GaugeOpts{ + Namespace: livekitNamespace, + Subsystem: "room", + Name: "total", + ConstLabels: prometheus.Labels{"node_id": nodeID, "node_type": nodeType.String()}, + }) + promRoomDuration = prometheus.NewHistogram(prometheus.HistogramOpts{ + Namespace: livekitNamespace, + Subsystem: "room", + Name: "duration_seconds", + ConstLabels: prometheus.Labels{"node_id": nodeID, "node_type": nodeType.String()}, + Buckets: []float64{ + 5, 10, 60, 5 * 60, 10 * 60, 30 * 60, 60 * 60, 2 * 60 * 60, 5 * 60 * 60, 10 * 60 * 60, + }, + }) + promParticipantCurrent = prometheus.NewGauge(prometheus.GaugeOpts{ + Namespace: livekitNamespace, + Subsystem: "participant", + Name: "total", + ConstLabels: prometheus.Labels{"node_id": nodeID, "node_type": nodeType.String()}, + }) + promTrackPublishedCurrent = prometheus.NewGaugeVec(prometheus.GaugeOpts{ + Namespace: livekitNamespace, + Subsystem: "track", + Name: "published_total", + ConstLabels: prometheus.Labels{"node_id": nodeID, "node_type": nodeType.String()}, + }, []string{"kind"}) + promTrackSubscribedCurrent = prometheus.NewGaugeVec(prometheus.GaugeOpts{ + Namespace: livekitNamespace, + Subsystem: "track", + Name: "subscribed_total", + ConstLabels: prometheus.Labels{"node_id": nodeID, "node_type": nodeType.String()}, + }, []string{"kind"}) + promTrackPublishCounter = prometheus.NewCounterVec(prometheus.CounterOpts{ + Namespace: livekitNamespace, + Subsystem: "track", + Name: "publish_counter", + ConstLabels: prometheus.Labels{"node_id": nodeID, "node_type": nodeType.String()}, + }, []string{"kind", "state"}) + promTrackSubscribeCounter = prometheus.NewCounterVec(prometheus.CounterOpts{ + Namespace: livekitNamespace, + Subsystem: "track", + Name: "subscribe_counter", + ConstLabels: prometheus.Labels{"node_id": nodeID, "node_type": nodeType.String()}, + }, []string{"state", "error"}) + promSessionStartTime = prometheus.NewHistogramVec(prometheus.HistogramOpts{ + Namespace: livekitNamespace, + Subsystem: "session", + Name: "start_time_ms", + ConstLabels: prometheus.Labels{"node_id": nodeID, "node_type": nodeType.String()}, + Buckets: prometheus.ExponentialBucketsRange(100, 10000, 15), + }, []string{"protocol_version"}) + promSessionDuration = prometheus.NewHistogramVec(prometheus.HistogramOpts{ + Namespace: livekitNamespace, + Subsystem: "session", + Name: "duration_ms", + ConstLabels: prometheus.Labels{"node_id": nodeID, "node_type": nodeType.String()}, + Buckets: prometheus.ExponentialBucketsRange(100, 4*60*60*1000, 15), + }, []string{"protocol_version"}) + promPubSubTime = prometheus.NewHistogramVec(prometheus.HistogramOpts{ + Namespace: livekitNamespace, + Subsystem: "pubsubtime", + Name: "ms", + ConstLabels: prometheus.Labels{"node_id": nodeID, "node_type": nodeType.String()}, + Buckets: []float64{100, 200, 500, 700, 1000, 5000, 10000}, + }, append(promStreamLabels, "sdk", "kind", "count")) + + prometheus.MustRegister(promRoomCurrent) + prometheus.MustRegister(promRoomDuration) + prometheus.MustRegister(promParticipantCurrent) + prometheus.MustRegister(promTrackPublishedCurrent) + prometheus.MustRegister(promTrackSubscribedCurrent) + prometheus.MustRegister(promTrackPublishCounter) + prometheus.MustRegister(promTrackSubscribeCounter) + prometheus.MustRegister(promSessionStartTime) + prometheus.MustRegister(promSessionDuration) + prometheus.MustRegister(promPubSubTime) +} + +func RoomStarted() { + promRoomCurrent.Add(1) + roomCurrent.Inc() +} + +func RoomEnded(startedAt time.Time) { + if !startedAt.IsZero() { + promRoomDuration.Observe(float64(time.Since(startedAt)) / float64(time.Second)) + } + promRoomCurrent.Sub(1) + roomCurrent.Dec() +} + +func AddParticipant() { + promParticipantCurrent.Add(1) + participantCurrent.Inc() +} + +func SubParticipant() { + promParticipantCurrent.Sub(1) + participantCurrent.Dec() +} + +func AddPublishedTrack(kind string) { + promTrackPublishedCurrent.WithLabelValues(kind).Add(1) + trackPublishedCurrent.Inc() +} + +func SubPublishedTrack(kind string) { + promTrackPublishedCurrent.WithLabelValues(kind).Sub(1) + trackPublishedCurrent.Dec() +} + +func RecordTrackPublishAttempt(kind string) { + trackPublishAttempts.Inc() + promTrackPublishCounter.WithLabelValues(kind, "attempt").Inc() +} + +func RecordTrackPublishSuccess(kind string) { + trackPublishSuccess.Inc() + promTrackPublishCounter.WithLabelValues(kind, "success").Inc() +} + +func RecordTrackPublishCancels(kind string, numCancels int32) { + trackPublishCancels.Add(numCancels) + promTrackPublishCounter.WithLabelValues(kind, "cancel").Add(float64(numCancels)) +} + +func RecordPublishTime( + country string, + source livekit.TrackSource, + trackType livekit.TrackType, + d time.Duration, + sdk livekit.ClientInfo_SDK, + kind livekit.ParticipantInfo_Kind, +) { + recordPubSubTime(true, country, source, trackType, d, sdk, kind, 1) +} + +func RecordSubscribeTime( + country string, + source livekit.TrackSource, + trackType livekit.TrackType, + d time.Duration, + sdk livekit.ClientInfo_SDK, + kind livekit.ParticipantInfo_Kind, + count int, +) { + recordPubSubTime(false, country, source, trackType, d, sdk, kind, count) +} + +func recordPubSubTime( + isPublish bool, + country string, + source livekit.TrackSource, + trackType livekit.TrackType, + d time.Duration, + sdk livekit.ClientInfo_SDK, + kind livekit.ParticipantInfo_Kind, + count int, +) { + direction := "subscribe" + if isPublish { + direction = "publish" + } + promPubSubTime.WithLabelValues( + direction, + source.String(), + trackType.String(), + country, + sdk.String(), + kind.String(), + strconv.Itoa(count), + ).Observe(float64(d.Milliseconds())) +} + +func RecordTrackSubscribeSuccess(kind string) { + // modify both current and total counters + promTrackSubscribedCurrent.WithLabelValues(kind).Add(1) + trackSubscribedCurrent.Inc() + + promTrackSubscribeCounter.WithLabelValues("success", "").Inc() + trackSubscribeSuccess.Inc() +} + +func RecordTrackUnsubscribed(kind string) { + // unsubscribed modifies current counter, but we leave the total values alone since they + // are used to compute rate + promTrackSubscribedCurrent.WithLabelValues(kind).Sub(1) + trackSubscribedCurrent.Dec() +} + +func RecordTrackSubscribeAttempt() { + trackSubscribeAttempts.Inc() + promTrackSubscribeCounter.WithLabelValues("attempt", "").Inc() +} + +func RecordTrackSubscribeFailure(err error, isUserError bool) { + promTrackSubscribeCounter.WithLabelValues("failure", err.Error()).Inc() + + if isUserError { + trackSubscribeUserError.Inc() + trackSubscribeCancels.Inc() + } +} + +func RecordTrackSubscribeCancels(numCancels int32) { + trackSubscribeCancels.Add(numCancels) + promTrackSubscribeCounter.WithLabelValues("cancel", "").Add(float64(numCancels)) +} + +func RecordSessionStartTime(protocolVersion int, d time.Duration) { + promSessionStartTime.WithLabelValues(strconv.Itoa(protocolVersion)).Observe(float64(d.Milliseconds())) +} + +func RecordSessionDuration(protocolVersion int, d time.Duration) { + promSessionDuration.WithLabelValues(strconv.Itoa(protocolVersion)).Observe(float64(d.Milliseconds())) +} diff --git a/livekit/pkg/telemetry/signalanddatastats.go b/livekit/pkg/telemetry/signalanddatastats.go new file mode 100644 index 0000000..898aa98 --- /dev/null +++ b/livekit/pkg/telemetry/signalanddatastats.go @@ -0,0 +1,277 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package telemetry + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/frostbyte73/core" + "go.uber.org/atomic" + + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/observability/roomobs" + "github.com/livekit/protocol/utils" +) + +type BytesTrackType string + +const ( + BytesTrackTypeData BytesTrackType = "DT" + BytesTrackTypeSignal BytesTrackType = "SG" +) + +// ------------------------------- + +type TrafficTotals struct { + At time.Time + SendBytes uint64 + SendMessages uint32 + RecvBytes uint64 + RecvMessages uint32 +} + +// -------------------------------- + +// stats for signal and data channel +type BytesTrackStats struct { + country string + trackID livekit.TrackID + pID livekit.ParticipantID + send, recv atomic.Uint64 + sendMessages, recvMessages atomic.Uint32 + totalSendBytes, totalRecvBytes atomic.Uint64 + totalSendMessages, totalRecvMessages atomic.Uint32 + telemetry TelemetryService + reporter roomobs.TrackReporter + done core.Fuse +} + +func NewBytesTrackStats( + country string, + trackID livekit.TrackID, + pID livekit.ParticipantID, + telemetry TelemetryService, + participantReporter roomobs.ParticipantSessionReporter, +) *BytesTrackStats { + s := &BytesTrackStats{ + country: country, + trackID: trackID, + pID: pID, + telemetry: telemetry, + reporter: participantReporter.WithTrack(trackID.String()), + } + go s.worker() + return s +} + +func (s *BytesTrackStats) AddBytes(bytes uint64, isSend bool) { + if isSend { + s.send.Add(bytes) + s.sendMessages.Inc() + s.totalSendBytes.Add(bytes) + s.totalSendMessages.Inc() + + s.reporter.Tx(func(tx roomobs.TrackTx) { + tx.ReportType(roomobs.TrackTypeData) + tx.ReportSendBytes(uint32(bytes)) + tx.ReportSendPackets(1) + }) + } else { + s.recv.Add(bytes) + s.recvMessages.Inc() + s.totalRecvBytes.Add(bytes) + s.totalRecvMessages.Inc() + + s.reporter.Tx(func(tx roomobs.TrackTx) { + tx.ReportType(roomobs.TrackTypeData) + tx.ReportRecvBytes(uint32(bytes)) + tx.ReportRecvPackets(1) + }) + } +} + +func (s *BytesTrackStats) GetTrafficTotals() *TrafficTotals { + return &TrafficTotals{ + At: time.Now(), + SendBytes: s.totalSendBytes.Load(), + SendMessages: s.totalSendMessages.Load(), + RecvBytes: s.totalRecvBytes.Load(), + RecvMessages: s.totalRecvMessages.Load(), + } +} + +func (s *BytesTrackStats) Stop() { + s.done.Break() +} + +func (s *BytesTrackStats) report() { + if recv := s.recv.Swap(0); recv > 0 { + packets := s.recvMessages.Swap(0) + s.telemetry.TrackStats( + StatsKeyForData(s.country, livekit.StreamType_UPSTREAM, s.pID, s.trackID), + &livekit.AnalyticsStat{ + Streams: []*livekit.AnalyticsStream{ + { + PrimaryBytes: recv, + PrimaryPackets: packets, + }, + }, + }, + ) + } + + if send := s.send.Swap(0); send > 0 { + packets := s.sendMessages.Swap(0) + s.telemetry.TrackStats( + StatsKeyForData(s.country, livekit.StreamType_DOWNSTREAM, s.pID, s.trackID), + &livekit.AnalyticsStat{ + Streams: []*livekit.AnalyticsStream{ + { + PrimaryBytes: send, + PrimaryPackets: packets, + }, + }, + }, + ) + } +} + +func (s *BytesTrackStats) worker() { + ticker := time.NewTicker(5 * time.Second) + defer func() { + ticker.Stop() + s.report() + }() + + for { + select { + case <-s.done.Watch(): + return + case <-ticker.C: + s.report() + } + } +} + +// ----------------------------------------------------------------------- + +type BytesSignalStats struct { + BytesTrackStats + ctx context.Context + + guard ReferenceGuard + + participantResolver roomobs.ParticipantReporterResolver + trackResolver roomobs.KeyResolver + + mu sync.Mutex + ri *livekit.Room + pi *livekit.ParticipantInfo + stopped chan struct{} +} + +func NewBytesSignalStats( + ctx context.Context, + telemetry TelemetryService, +) *BytesSignalStats { + projectReporter := telemetry.RoomProjectReporter(ctx) + participantReporter, participantReporterResolver := roomobs.DeferredParticipantReporter(projectReporter) + trackReporter, trackReporterResolver := participantReporter.WithDeferredTrack() + return &BytesSignalStats{ + BytesTrackStats: BytesTrackStats{ + telemetry: telemetry, + reporter: trackReporter, + }, + ctx: ctx, + participantResolver: participantReporterResolver, + trackResolver: trackReporterResolver, + } +} + +func (s *BytesSignalStats) ResolveRoom(ri *livekit.Room) { + s.mu.Lock() + defer s.mu.Unlock() + if s.ri == nil && ri.GetSid() != "" { + s.ri = &livekit.Room{ + Sid: ri.Sid, + Name: ri.Name, + } + s.maybeStart() + } +} + +func (s *BytesSignalStats) ResolveParticipant(pi *livekit.ParticipantInfo) { + s.mu.Lock() + defer s.mu.Unlock() + if s.pi == nil && pi != nil { + s.pi = &livekit.ParticipantInfo{ + Sid: pi.Sid, + Identity: pi.Identity, + } + s.maybeStart() + } +} + +func (s *BytesSignalStats) Reset() { + s.mu.Lock() + defer s.mu.Unlock() + if s.stopped != nil { + s.done.Break() + <-s.stopped + s.stopped = nil + s.done = core.Fuse{} + } + s.ri = nil + s.pi = nil + + s.participantResolver.Reset() + s.trackResolver.Reset() +} + +func (s *BytesSignalStats) maybeStart() { + if s.ri == nil || s.pi == nil { + return + } + + s.pID = livekit.ParticipantID(s.pi.Sid) + s.trackID = BytesTrackIDForParticipantID(BytesTrackTypeSignal, s.pID) + + s.participantResolver.Resolve( + livekit.RoomName(s.ri.Name), + livekit.RoomID(s.ri.Sid), + livekit.ParticipantIdentity(s.pi.Identity), + livekit.ParticipantID(s.pi.Sid), + ) + s.trackResolver.Resolve(string(s.trackID)) + + s.telemetry.ParticipantJoined(s.ctx, s.ri, s.pi, nil, nil, false, &s.guard) + s.stopped = make(chan struct{}) + go s.worker() +} + +func (s *BytesSignalStats) worker() { + s.BytesTrackStats.worker() + s.telemetry.ParticipantLeft(s.ctx, s.ri, s.pi, false, &s.guard) + close(s.stopped) +} + +// ----------------------------------------------------------------------- + +func BytesTrackIDForParticipantID(typ BytesTrackType, participantID livekit.ParticipantID) livekit.TrackID { + return livekit.TrackID(fmt.Sprintf("%s%s%s", utils.TrackPrefix, typ, participantID)) +} diff --git a/livekit/pkg/telemetry/stats.go b/livekit/pkg/telemetry/stats.go new file mode 100644 index 0000000..7ff082a --- /dev/null +++ b/livekit/pkg/telemetry/stats.go @@ -0,0 +1,114 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package telemetry + +import ( + "github.com/livekit/livekit-server/pkg/telemetry/prometheus" + "github.com/livekit/protocol/livekit" +) + +type StatsKey struct { + country string + streamType livekit.StreamType + participantID livekit.ParticipantID + trackID livekit.TrackID + trackSource livekit.TrackSource + trackType livekit.TrackType + track bool +} + +func StatsKeyForTrack( + country string, + streamType livekit.StreamType, + participantID livekit.ParticipantID, + trackID livekit.TrackID, + trackSource livekit.TrackSource, + trackType livekit.TrackType, +) StatsKey { + return StatsKey{ + country: country, + streamType: streamType, + participantID: participantID, + trackID: trackID, + trackSource: trackSource, + trackType: trackType, + track: true, + } +} + +func StatsKeyForData( + country string, + streamType livekit.StreamType, + participantID livekit.ParticipantID, + trackID livekit.TrackID, +) StatsKey { + return StatsKey{ + country: country, + streamType: streamType, + participantID: participantID, + trackID: trackID, + } +} + +func (t *telemetryService) TrackStats(key StatsKey, stat *livekit.AnalyticsStat) { + t.enqueue(func() { + direction := prometheus.Incoming + if key.streamType == livekit.StreamType_DOWNSTREAM { + direction = prometheus.Outgoing + } + + nacks := uint32(0) + plis := uint32(0) + firs := uint32(0) + packets := uint32(0) + bytes := uint64(0) + retransmitBytes := uint64(0) + retransmitPackets := uint32(0) + for _, stream := range stat.Streams { + nacks += stream.Nacks + plis += stream.Plis + firs += stream.Firs + packets += stream.PrimaryPackets + stream.PaddingPackets + bytes += stream.PrimaryBytes + stream.PaddingBytes + if key.streamType == livekit.StreamType_DOWNSTREAM { + retransmitPackets += stream.RetransmitPackets + retransmitBytes += stream.RetransmitBytes + } else { + // for upstream, we don't account for these separately for now + packets += stream.RetransmitPackets + bytes += stream.RetransmitBytes + } + if key.track { + prometheus.RecordPacketLoss(key.country, direction, key.trackSource, key.trackType, stream.PacketsLost, stream.PrimaryPackets+stream.PaddingPackets) + prometheus.RecordPacketOutOfOrder(key.country, direction, key.trackSource, key.trackType, stream.PacketsOutOfOrder, stream.PrimaryPackets+stream.PaddingPackets) + prometheus.RecordRTT(key.country, direction, key.trackSource, key.trackType, stream.Rtt) + prometheus.RecordJitter(key.country, direction, key.trackSource, key.trackType, stream.Jitter) + } + } + prometheus.IncrementRTCP(key.country, direction, nacks, plis, firs) + prometheus.IncrementPackets(key.country, direction, uint64(packets), false) + prometheus.IncrementBytes(key.country, direction, bytes, false) + if retransmitPackets != 0 { + prometheus.IncrementPackets(key.country, direction, uint64(retransmitPackets), true) + } + if retransmitBytes != 0 { + prometheus.IncrementBytes(key.country, direction, retransmitBytes, true) + } + + if worker, ok := t.getWorker(key.participantID); ok { + worker.OnTrackStat(key.trackID, key.streamType, stat) + } + }) +} diff --git a/livekit/pkg/telemetry/stats_test.go b/livekit/pkg/telemetry/stats_test.go new file mode 100644 index 0000000..8db18df --- /dev/null +++ b/livekit/pkg/telemetry/stats_test.go @@ -0,0 +1,632 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package telemetry_test + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/livekit/livekit-server/pkg/telemetry" + "github.com/livekit/protocol/livekit" + + "github.com/livekit/livekit-server/pkg/telemetry/prometheus" + "github.com/livekit/livekit-server/pkg/telemetry/telemetryfakes" +) + +func init() { + prometheus.Init("test", livekit.NodeType_SERVER) +} + +type telemetryServiceFixture struct { + sut telemetry.TelemetryService + analytics *telemetryfakes.FakeAnalyticsService +} + +func createFixture() *telemetryServiceFixture { + fixture := &telemetryServiceFixture{} + fixture.analytics = &telemetryfakes.FakeAnalyticsService{} + fixture.sut = telemetry.NewTelemetryService(nil, fixture.analytics) + return fixture +} + +func Test_ParticipantAndRoomDataAreSentWithAnalytics(t *testing.T) { + fixture := createFixture() + + // prepare + room := &livekit.Room{Sid: "RoomSid", Name: "RoomName"} + partSID := livekit.ParticipantID("part1") + clientInfo := &livekit.ClientInfo{Sdk: 2} + participantInfo := &livekit.ParticipantInfo{Sid: string(partSID)} + guard := &telemetry.ReferenceGuard{} + fixture.sut.ParticipantJoined(context.Background(), room, participantInfo, clientInfo, nil, true, guard) + + // do + packet := 33 + stat := &livekit.AnalyticsStat{Streams: []*livekit.AnalyticsStream{{PrimaryBytes: uint64(packet)}}} + fixture.sut.TrackStats(telemetry.StatsKeyForData("test", livekit.StreamType_DOWNSTREAM, partSID, ""), stat) + + // flush + fixture.flush() + + // test + require.Equal(t, 1, fixture.analytics.SendStatsCallCount()) + _, stats := fixture.analytics.SendStatsArgsForCall(0) + require.Equal(t, 1, len(stats)) + require.Equal(t, livekit.StreamType_DOWNSTREAM, stats[0].Kind) + require.Equal(t, string(partSID), stats[0].ParticipantId) + require.Equal(t, room.Sid, stats[0].RoomId) + require.Equal(t, room.Name, stats[0].RoomName) +} + +func Test_OnDownstreamPackets(t *testing.T) { + fixture := createFixture() + + // prepare + room := &livekit.Room{} + partSID := livekit.ParticipantID("part1") + clientInfo := &livekit.ClientInfo{Sdk: 2} + participantInfo := &livekit.ParticipantInfo{Sid: string(partSID)} + guard := &telemetry.ReferenceGuard{} + fixture.sut.ParticipantJoined(context.Background(), room, participantInfo, clientInfo, nil, true, guard) + + // do + packets := []int{33, 23} + totalBytes := packets[0] + packets[1] + totalPackets := len(packets) + trackID := livekit.TrackID("trackID") + for i := range packets { + stat := &livekit.AnalyticsStat{Streams: []*livekit.AnalyticsStream{{PrimaryBytes: uint64(packets[i]), PrimaryPackets: uint32(1)}}} + fixture.sut.TrackStats(telemetry.StatsKeyForData("test", livekit.StreamType_DOWNSTREAM, partSID, trackID), stat) + } + + // flush + fixture.flush() + + // test + require.Equal(t, 1, fixture.analytics.SendStatsCallCount()) + _, stats := fixture.analytics.SendStatsArgsForCall(0) + require.Equal(t, 1, len(stats)) + require.Equal(t, livekit.StreamType_DOWNSTREAM, stats[0].Kind) + require.Equal(t, totalBytes, int(stats[0].Streams[0].PrimaryBytes)) + require.Equal(t, totalPackets, int(stats[0].Streams[0].PrimaryPackets)) + require.Equal(t, string(trackID), stats[0].TrackId) +} + +func Test_OnDownstreamPackets_SeveralTracks(t *testing.T) { + fixture := createFixture() + + // prepare + room := &livekit.Room{} + partSID := livekit.ParticipantID("part1") + clientInfo := &livekit.ClientInfo{Sdk: 2} + participantInfo := &livekit.ParticipantInfo{Sid: string(partSID)} + guard := &telemetry.ReferenceGuard{} + fixture.sut.ParticipantJoined(context.Background(), room, participantInfo, clientInfo, nil, true, guard) + + // do + packet1 := 33 + trackID1 := livekit.TrackID("trackID1") + stat1 := &livekit.AnalyticsStat{Streams: []*livekit.AnalyticsStream{{PrimaryBytes: uint64(packet1), PrimaryPackets: 1}}} + fixture.sut.TrackStats(telemetry.StatsKeyForData("test", livekit.StreamType_DOWNSTREAM, partSID, trackID1), stat1) + + packet2 := 23 + trackID2 := livekit.TrackID("trackID2") + stat2 := &livekit.AnalyticsStat{Streams: []*livekit.AnalyticsStream{{PrimaryBytes: uint64(packet2), PrimaryPackets: 1}}} + fixture.sut.TrackStats(telemetry.StatsKeyForData("test", livekit.StreamType_DOWNSTREAM, partSID, trackID2), stat2) + + // flush + fixture.flush() + + // test + require.Equal(t, 1, fixture.analytics.SendStatsCallCount()) + _, stats := fixture.analytics.SendStatsArgsForCall(0) + require.Equal(t, 2, len(stats)) + + found1 := false + found2 := false + for _, sentStat := range stats { + if livekit.TrackID(sentStat.TrackId) == trackID1 { + found1 = true + require.Equal(t, packet1, int(sentStat.Streams[0].PrimaryBytes)) + require.Equal(t, 1, int(sentStat.Streams[0].PrimaryPackets)) + } else if livekit.TrackID(sentStat.TrackId) == trackID2 { + found2 = true + require.Equal(t, packet2, int(sentStat.Streams[0].PrimaryBytes)) + require.Equal(t, 1, int(sentStat.Streams[0].PrimaryPackets)) + } + } + require.True(t, found1) + require.True(t, found2) +} + +func Test_OnDownStreamStat(t *testing.T) { + fixture := createFixture() + + // prepare + room := &livekit.Room{} + partSID := livekit.ParticipantID("part1") + participantInfo := &livekit.ParticipantInfo{Sid: string(partSID)} + guard := &telemetry.ReferenceGuard{} + fixture.sut.ParticipantJoined(context.Background(), room, participantInfo, nil, nil, true, guard) + + // do + stat1 := &livekit.AnalyticsStat{ + Streams: []*livekit.AnalyticsStream{ + { + PrimaryBytes: 1, + PrimaryPackets: 1, + PacketsLost: 3, + Nacks: 1, + Plis: 1, + Rtt: 23, + Jitter: 3, + }, + }, + } + trackID := livekit.TrackID("trackID1") + fixture.sut.TrackStats(telemetry.StatsKeyForData("test", livekit.StreamType_DOWNSTREAM, partSID, trackID), stat1) + + stat2 := &livekit.AnalyticsStat{ + Streams: []*livekit.AnalyticsStream{ + { + PrimaryBytes: 2, + PrimaryPackets: 2, + PacketsLost: 4, + Nacks: 1, + Plis: 1, + Firs: 1, + Rtt: 10, + Jitter: 5, + }, + }, + } + fixture.sut.TrackStats(telemetry.StatsKeyForData("test", livekit.StreamType_DOWNSTREAM, partSID, trackID), stat2) + + // flush + fixture.flush() + + // test + require.Equal(t, 1, fixture.analytics.SendStatsCallCount()) + _, stats := fixture.analytics.SendStatsArgsForCall(0) + require.Equal(t, 1, len(stats)) + require.Equal(t, livekit.StreamType_DOWNSTREAM, stats[0].Kind) + require.Equal(t, 2, int(stats[0].Streams[0].Nacks)) + require.Equal(t, 2, int(stats[0].Streams[0].Plis)) + require.Equal(t, 1, int(stats[0].Streams[0].Firs)) + require.Equal(t, 23, int(stats[0].Streams[0].Rtt)) // max of RTT + require.Equal(t, 5, int(stats[0].Streams[0].Jitter)) // max of jitter + require.Equal(t, 7, int(stats[0].Streams[0].PacketsLost)) // coalesced delta packet losses + require.Equal(t, string(trackID), stats[0].TrackId) +} + +func Test_PacketLostDiffShouldBeSentToTelemetry(t *testing.T) { + fixture := createFixture() + + // prepare + room := &livekit.Room{} + partSID := livekit.ParticipantID("part1") + participantInfo := &livekit.ParticipantInfo{Sid: string(partSID)} + guard := &telemetry.ReferenceGuard{} + fixture.sut.ParticipantJoined(context.Background(), room, participantInfo, nil, nil, true, guard) + + // do + trackID := livekit.TrackID("trackID1") + stat1 := &livekit.AnalyticsStat{ + Streams: []*livekit.AnalyticsStream{ + { + PrimaryBytes: 1, + PrimaryPackets: 1, + PacketsLost: 1, + }, + }, + } + fixture.sut.TrackStats(telemetry.StatsKeyForData("test", livekit.StreamType_DOWNSTREAM, partSID, trackID), stat1) // there should be bytes reported so that stats are sent + + // flush + fixture.flush() + + stat2 := &livekit.AnalyticsStat{ + Streams: []*livekit.AnalyticsStream{ + { + PrimaryBytes: 2, + PrimaryPackets: 2, + PacketsLost: 4, + }, + }, + } + fixture.sut.TrackStats(telemetry.StatsKeyForData("test", livekit.StreamType_DOWNSTREAM, partSID, trackID), stat2) + + // flush + fixture.flush() + + // test + require.Equal(t, 2, fixture.analytics.SendStatsCallCount()) // 2 calls to fixture.sut.FlushStats() + _, stats := fixture.analytics.SendStatsArgsForCall(0) + require.Equal(t, 1, len(stats)) + require.Equal(t, livekit.StreamType_DOWNSTREAM, stats[0].Kind) + require.Equal(t, 1, int(stats[0].Streams[0].PacketsLost)) // see pkts1 + + _, stats = fixture.analytics.SendStatsArgsForCall(1) + require.Equal(t, 1, len(stats)) + require.Equal(t, livekit.StreamType_DOWNSTREAM, stats[0].Kind) + require.Equal(t, 4, int(stats[0].Streams[0].PacketsLost)) // delta loss should be sent as is +} + +func Test_OnDownStreamRTCP_SeveralTracks(t *testing.T) { + fixture := createFixture() + + // prepare + room := &livekit.Room{} + partSID := livekit.ParticipantID("part1") + participantInfo := &livekit.ParticipantInfo{Sid: string(partSID)} + guard := &telemetry.ReferenceGuard{} + fixture.sut.ParticipantJoined(context.Background(), room, participantInfo, nil, nil, true, guard) + + // do + trackID1 := livekit.TrackID("trackID1") + trackID2 := livekit.TrackID("trackID2") + stat1 := &livekit.AnalyticsStat{ + Streams: []*livekit.AnalyticsStream{ + { + PrimaryBytes: 1, + PrimaryPackets: 1, + }, + }, + } + fixture.sut.TrackStats(telemetry.StatsKeyForData("test", livekit.StreamType_DOWNSTREAM, partSID, trackID1), stat1) // there should be bytes reported so that stats are sent + + stat2 := &livekit.AnalyticsStat{ + Streams: []*livekit.AnalyticsStream{ + { + PrimaryBytes: 2, + PrimaryPackets: 2, + Nacks: 1, + }, + }, + } + fixture.sut.TrackStats(telemetry.StatsKeyForData("test", livekit.StreamType_DOWNSTREAM, partSID, trackID1), stat2) + + stat3 := &livekit.AnalyticsStat{ + Streams: []*livekit.AnalyticsStream{ + { + PrimaryBytes: 3, + PrimaryPackets: 3, + Firs: 1, + }, + }, + } + fixture.sut.TrackStats(telemetry.StatsKeyForData("test", livekit.StreamType_DOWNSTREAM, partSID, trackID2), stat3) + + // flush + fixture.flush() + + // test + require.Equal(t, 1, fixture.analytics.SendStatsCallCount()) + _, stats := fixture.analytics.SendStatsArgsForCall(0) + require.Equal(t, 2, len(stats)) + + found1 := false + found2 := false + for _, sentStat := range stats { + if livekit.TrackID(sentStat.TrackId) == trackID1 { + found1 = true + require.Equal(t, livekit.StreamType_DOWNSTREAM, sentStat.Kind) + require.Equal(t, 1, int(sentStat.Streams[0].Nacks)) // see pkts1 above + } else if livekit.TrackID(sentStat.TrackId) == trackID2 { + found2 = true + require.Equal(t, livekit.StreamType_DOWNSTREAM, sentStat.Kind) + require.Equal(t, 1, int(sentStat.Streams[0].Firs)) // see pkts2 above + } + } + require.True(t, found1) + require.True(t, found2) +} + +func Test_OnUpstreamStat(t *testing.T) { + fixture := createFixture() + + // prepare + room := &livekit.Room{} + partSID := livekit.ParticipantID("part1") + participantInfo := &livekit.ParticipantInfo{Sid: string(partSID)} + guard := &telemetry.ReferenceGuard{} + fixture.sut.ParticipantJoined(context.Background(), room, participantInfo, nil, nil, true, guard) + + // do + stat1 := &livekit.AnalyticsStat{ + Streams: []*livekit.AnalyticsStream{ + { + PrimaryBytes: 1, + PrimaryPackets: 1, + PacketsLost: 3, + Nacks: 1, + Plis: 1, + Firs: 1, + Rtt: 13, + Jitter: 5, + }, + }, + } + trackID := livekit.TrackID("trackID") + + fixture.sut.TrackStats(telemetry.StatsKeyForData("test", livekit.StreamType_UPSTREAM, partSID, trackID), stat1) + + stat2 := &livekit.AnalyticsStat{ + Streams: []*livekit.AnalyticsStream{ + { + PrimaryBytes: 2, + PrimaryPackets: 2, + PacketsLost: 4, + Nacks: 1, + Plis: 1, + Firs: 1, + Rtt: 33, + Jitter: 2, + }, + }, + } + fixture.sut.TrackStats(telemetry.StatsKeyForData("test", livekit.StreamType_UPSTREAM, partSID, trackID), stat2) + + // flush + fixture.flush() + + // test + require.Equal(t, 1, fixture.analytics.SendStatsCallCount()) + _, stats := fixture.analytics.SendStatsArgsForCall(0) + require.Equal(t, 1, len(stats)) + require.Equal(t, livekit.StreamType_UPSTREAM, stats[0].Kind) + require.Equal(t, 2, int(stats[0].Streams[0].Nacks)) + require.Equal(t, 2, int(stats[0].Streams[0].Plis)) + require.Equal(t, 2, int(stats[0].Streams[0].Firs)) + require.Equal(t, 33, int(stats[0].Streams[0].Rtt)) // max of RTT + require.Equal(t, 5, int(stats[0].Streams[0].Jitter)) // max of jitter + require.Equal(t, 7, int(stats[0].Streams[0].PacketsLost)) // coalesced delta packet losses + require.Equal(t, string(trackID), stats[0].TrackId) +} + +func Test_OnUpstreamRTCP_SeveralTracks(t *testing.T) { + fixture := createFixture() + + // prepare + room := &livekit.Room{} + partSID := livekit.ParticipantID("part1") + identity := livekit.ParticipantIdentity("part1Identity") + participantInfo := &livekit.ParticipantInfo{Sid: string(partSID), Identity: string(identity)} + guard := &telemetry.ReferenceGuard{} + fixture.sut.ParticipantJoined(context.Background(), room, participantInfo, nil, nil, true, guard) + + // there should be bytes reported so that stats are sent + totalBytes := 1 + totalPackets := 1 + trackID1 := livekit.TrackID("trackID1") + trackID2 := livekit.TrackID("trackID2") + + stat1 := &livekit.AnalyticsStat{ + Streams: []*livekit.AnalyticsStream{ + { + PrimaryBytes: uint64(totalBytes), + PrimaryPackets: uint32(totalPackets), + }, + }, + } + fixture.sut.TrackStats(telemetry.StatsKeyForData("test", livekit.StreamType_UPSTREAM, partSID, trackID1), stat1) + fixture.sut.TrackStats(telemetry.StatsKeyForData("test", livekit.StreamType_UPSTREAM, partSID, trackID2), stat1) // using same buffer is not correct but for test it is fine + + // do + totalBytes++ + totalPackets++ + stat2 := &livekit.AnalyticsStat{ + Streams: []*livekit.AnalyticsStream{ + { + PrimaryBytes: uint64(totalBytes), + PrimaryPackets: uint32(totalPackets), + Nacks: 1, + }, + }, + } + fixture.sut.TrackStats(telemetry.StatsKeyForData("test", livekit.StreamType_UPSTREAM, partSID, trackID1), stat2) + + stat3 := &livekit.AnalyticsStat{ + Streams: []*livekit.AnalyticsStream{ + { + PrimaryBytes: uint64(totalBytes), + PrimaryPackets: uint32(totalPackets), + Firs: 1, + }, + }, + } + fixture.sut.TrackStats(telemetry.StatsKeyForData("test", livekit.StreamType_UPSTREAM, partSID, trackID2), stat3) + + // flush + fixture.flush() + + // test + require.Equal(t, 1, fixture.analytics.SendStatsCallCount()) + _, stats := fixture.analytics.SendStatsArgsForCall(0) + require.Equal(t, 2, len(stats)) + + found1 := false + found2 := false + for _, sentStat := range stats { + if livekit.TrackID(sentStat.TrackId) == trackID1 { + found1 = true + require.Equal(t, livekit.StreamType_UPSTREAM, sentStat.Kind) + require.Equal(t, 1, int(sentStat.Streams[0].Nacks)) // see pkts1 above + } else if livekit.TrackID(sentStat.TrackId) == trackID2 { + found2 = true + require.Equal(t, livekit.StreamType_UPSTREAM, sentStat.Kind) + require.Equal(t, 1, int(sentStat.Streams[0].Firs)) // see pkts2 above + } + require.Equal(t, 3, int(sentStat.Streams[0].PrimaryBytes)) + require.Equal(t, 3, int(sentStat.Streams[0].PrimaryPackets)) + } + require.True(t, found1) + require.True(t, found2) + + // remove 1 track - track stats were flushed above, so no more calls to SendStats + fixture.sut.TrackUnpublished(context.Background(), partSID, identity, &livekit.TrackInfo{Sid: string(trackID2)}, true) + + // flush + fixture.flush() + + require.Equal(t, 1, fixture.analytics.SendStatsCallCount()) +} + +func Test_AnalyticsSentWhenParticipantLeaves(t *testing.T) { + fixture := createFixture() + + // prepare + room := &livekit.Room{} + partSID := "part1" + participantInfo := &livekit.ParticipantInfo{Sid: partSID} + guard := &telemetry.ReferenceGuard{} + fixture.sut.ParticipantJoined(context.Background(), room, participantInfo, nil, nil, true, guard) + + // do + fixture.sut.ParticipantLeft(context.Background(), room, participantInfo, true, guard) + + // should not be called if there are no track stats + time.Sleep(time.Millisecond * 500) + require.Equal(t, 0, fixture.analytics.SendStatsCallCount()) +} + +func Test_AddUpTrack(t *testing.T) { + fixture := createFixture() + + // prepare + room := &livekit.Room{} + partSID := livekit.ParticipantID("part1") + participantInfo := &livekit.ParticipantInfo{Sid: string(partSID)} + guard := &telemetry.ReferenceGuard{} + fixture.sut.ParticipantJoined(context.Background(), room, participantInfo, nil, nil, true, guard) + + // do + var totalBytes uint64 = 3 + var totalPackets uint32 = 3 + + stat := &livekit.AnalyticsStat{ + Streams: []*livekit.AnalyticsStream{ + { + PrimaryBytes: totalBytes, + PrimaryPackets: totalPackets, + }, + }, + } + trackID := livekit.TrackID("trackID") + fixture.sut.TrackStats(telemetry.StatsKeyForData("test", livekit.StreamType_UPSTREAM, partSID, trackID), stat) + + // flush + fixture.flush() + + // test + require.Equal(t, 1, fixture.analytics.SendStatsCallCount()) + _, stats := fixture.analytics.SendStatsArgsForCall(0) + require.Equal(t, 1, len(stats)) + require.Equal(t, livekit.StreamType_UPSTREAM, stats[0].Kind) + require.Equal(t, totalBytes, stats[0].Streams[0].PrimaryBytes) + require.Equal(t, totalPackets, stats[0].Streams[0].PrimaryPackets) + require.Equal(t, string(trackID), stats[0].TrackId) +} + +func Test_AddUpTrack_SeveralBuffers_Simulcast(t *testing.T) { + fixture := createFixture() + + // prepare + room := &livekit.Room{} + partSID := livekit.ParticipantID("part1") + participantInfo := &livekit.ParticipantInfo{Sid: string(partSID)} + guard := &telemetry.ReferenceGuard{} + fixture.sut.ParticipantJoined(context.Background(), room, participantInfo, nil, nil, true, guard) + + // do + trackID := livekit.TrackID("trackID") + stat1 := &livekit.AnalyticsStat{ + Streams: []*livekit.AnalyticsStream{ + { + PrimaryBytes: 1, + PrimaryPackets: 1, + }, + { + PrimaryBytes: 2, + PrimaryPackets: 2, + }, + }, + } + fixture.sut.TrackStats(telemetry.StatsKeyForData("test", livekit.StreamType_UPSTREAM, partSID, trackID), stat1) + + // flush + fixture.flush() + + // test + require.Equal(t, 1, fixture.analytics.SendStatsCallCount()) + _, stats := fixture.analytics.SendStatsArgsForCall(0) + require.Equal(t, 1, len(stats)) + require.Equal(t, livekit.StreamType_UPSTREAM, stats[0].Kind) + // should be a consolidated stream + require.Equal(t, stat1.Streams[0].PrimaryBytes+stat1.Streams[1].PrimaryBytes, stats[0].Streams[0].PrimaryBytes) + require.Equal(t, stat1.Streams[0].PrimaryPackets+stat1.Streams[1].PrimaryPackets, stats[0].Streams[0].PrimaryPackets) + require.Equal(t, string(trackID), stats[0].TrackId) +} + +func Test_BothDownstreamAndUpstreamStatsAreSentTogether(t *testing.T) { + fixture := createFixture() + + // prepare + room := &livekit.Room{} + partSID := livekit.ParticipantID("part1") + participantInfo := &livekit.ParticipantInfo{Sid: string(partSID)} + guard := &telemetry.ReferenceGuard{} + fixture.sut.ParticipantJoined(context.Background(), room, participantInfo, nil, nil, true, guard) + + // do + // upstream bytes + stat1 := &livekit.AnalyticsStat{ + Streams: []*livekit.AnalyticsStream{ + { + PrimaryBytes: 3, + PrimaryPackets: 3, + }, + }, + } + fixture.sut.TrackStats(telemetry.StatsKeyForData("test", livekit.StreamType_UPSTREAM, partSID, "trackID"), stat1) + // downstream bytes + stat2 := &livekit.AnalyticsStat{ + Streams: []*livekit.AnalyticsStream{ + { + PrimaryBytes: 1, + PrimaryPackets: 1, + }, + }, + } + fixture.sut.TrackStats(telemetry.StatsKeyForData("test", livekit.StreamType_DOWNSTREAM, partSID, "trackID1"), stat2) + + // flush + fixture.flush() + + // test + require.Equal(t, 1, fixture.analytics.SendStatsCallCount()) + _, stats := fixture.analytics.SendStatsArgsForCall(0) + require.Equal(t, 2, len(stats)) + require.Equal(t, livekit.StreamType_UPSTREAM, stats[0].Kind) + require.Equal(t, livekit.StreamType_DOWNSTREAM, stats[1].Kind) +} + +func (f *telemetryServiceFixture) flush() { + time.Sleep(time.Millisecond * 500) + f.sut.FlushStats() +} diff --git a/livekit/pkg/telemetry/statsconn.go b/livekit/pkg/telemetry/statsconn.go new file mode 100644 index 0000000..308dc38 --- /dev/null +++ b/livekit/pkg/telemetry/statsconn.go @@ -0,0 +1,132 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package telemetry + +import ( + "net" + + "github.com/pion/turn/v4" + + "github.com/livekit/livekit-server/pkg/telemetry/prometheus" +) + +type Listener struct { + net.Listener +} + +func NewListener(l net.Listener) *Listener { + return &Listener{Listener: l} +} + +func (l *Listener) Accept() (net.Conn, error) { + conn, err := l.Listener.Accept() + if err != nil { + return nil, err + } + + return NewConn(conn, prometheus.Incoming), nil +} + +type Conn struct { + net.Conn + direction prometheus.Direction +} + +func NewConn(c net.Conn, direction prometheus.Direction) *Conn { + prometheus.AddConnection(direction) + return &Conn{Conn: c, direction: direction} +} + +func (c *Conn) Read(b []byte) (n int, err error) { + n, err = c.Conn.Read(b) + if n > 0 { + prometheus.IncrementBytes("", prometheus.Incoming, uint64(n), false) + prometheus.IncrementPackets("", prometheus.Incoming, 1, false) + } + return +} + +func (c *Conn) Write(b []byte) (n int, err error) { + n, err = c.Conn.Write(b) + if n > 0 { + prometheus.IncrementBytes("", prometheus.Outgoing, uint64(n), false) + prometheus.IncrementPackets("", prometheus.Outgoing, 1, false) + } + return +} + +func (c *Conn) Close() error { + prometheus.SubConnection(c.direction) + return c.Conn.Close() +} + +type PacketConn struct { + net.PacketConn + direction prometheus.Direction +} + +func NewPacketConn(c net.PacketConn, direction prometheus.Direction) *PacketConn { + prometheus.AddConnection(direction) + return &PacketConn{PacketConn: c, direction: direction} +} + +func (c *PacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + n, addr, err = c.PacketConn.ReadFrom(p) + if n > 0 { + prometheus.IncrementBytes("", prometheus.Incoming, uint64(n), false) + prometheus.IncrementPackets("", prometheus.Incoming, 1, false) + } + return +} + +func (c *PacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + n, err = c.PacketConn.WriteTo(p, addr) + if n > 0 { + prometheus.IncrementBytes("", prometheus.Outgoing, uint64(n), false) + prometheus.IncrementPackets("", prometheus.Outgoing, 1, false) + } + return +} + +func (c *PacketConn) Close() error { + prometheus.SubConnection(c.direction) + return c.PacketConn.Close() +} + +type RelayAddressGenerator struct { + turn.RelayAddressGenerator +} + +func NewRelayAddressGenerator(g turn.RelayAddressGenerator) *RelayAddressGenerator { + return &RelayAddressGenerator{RelayAddressGenerator: g} +} + +func (g *RelayAddressGenerator) AllocatePacketConn(network string, requestedPort int) (net.PacketConn, net.Addr, error) { + conn, addr, err := g.RelayAddressGenerator.AllocatePacketConn(network, requestedPort) + if err != nil { + return nil, addr, err + } + + return NewPacketConn(conn, prometheus.Outgoing), addr, err +} + +func (g *RelayAddressGenerator) AllocateConn(network string, requestedPort int) (net.Conn, net.Addr, error) { + conn, addr, err := g.RelayAddressGenerator.AllocateConn(network, requestedPort) + if err != nil { + return nil, addr, err + } + + return NewConn(conn, prometheus.Outgoing), addr, err +} diff --git a/livekit/pkg/telemetry/statsworker.go b/livekit/pkg/telemetry/statsworker.go new file mode 100644 index 0000000..51eec63 --- /dev/null +++ b/livekit/pkg/telemetry/statsworker.go @@ -0,0 +1,362 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package telemetry + +import ( + "context" + "sync" + "time" + + "google.golang.org/protobuf/types/known/timestamppb" + + "github.com/livekit/livekit-server/pkg/utils" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + protoutils "github.com/livekit/protocol/utils" +) + +type ReferenceGuard struct { + activated, released bool +} + +type ReferenceCount struct { + count int +} + +func (s *ReferenceCount) Activate(guard *ReferenceGuard) { + if guard != nil && !guard.activated { + guard.activated = true + s.count++ + } +} + +func (s *ReferenceCount) Release(guard *ReferenceGuard) bool { + if guard == nil || !guard.activated || guard.released { + return false + } + guard.released = true + s.count-- + return s.count == 0 +} + +// StatsWorker handles participant stats +type StatsWorker struct { + next *StatsWorker + + ctx context.Context + t TelemetryService + roomID livekit.RoomID + roomName livekit.RoomName + participantID livekit.ParticipantID + participantIdentity livekit.ParticipantIdentity + isConnected bool + + lock sync.RWMutex + outgoingPerTrack map[livekit.TrackID][]*livekit.AnalyticsStat + incomingPerTrack map[livekit.TrackID][]*livekit.AnalyticsStat + refCount ReferenceCount + closedAt time.Time +} + +func newStatsWorker( + ctx context.Context, + t TelemetryService, + roomID livekit.RoomID, + roomName livekit.RoomName, + participantID livekit.ParticipantID, + identity livekit.ParticipantIdentity, + guard *ReferenceGuard, +) *StatsWorker { + s := &StatsWorker{ + ctx: ctx, + t: t, + roomID: roomID, + roomName: roomName, + participantID: participantID, + participantIdentity: identity, + outgoingPerTrack: make(map[livekit.TrackID][]*livekit.AnalyticsStat), + incomingPerTrack: make(map[livekit.TrackID][]*livekit.AnalyticsStat), + } + s.refCount.Activate(guard) + return s +} + +func (s *StatsWorker) OnTrackStat(trackID livekit.TrackID, direction livekit.StreamType, stat *livekit.AnalyticsStat) { + s.lock.Lock() + if direction == livekit.StreamType_DOWNSTREAM { + s.outgoingPerTrack[trackID] = append(s.outgoingPerTrack[trackID], stat) + } else { + s.incomingPerTrack[trackID] = append(s.incomingPerTrack[trackID], stat) + } + s.lock.Unlock() +} + +func (s *StatsWorker) ParticipantID() livekit.ParticipantID { + return s.participantID +} + +func (s *StatsWorker) SetConnected() { + s.lock.Lock() + s.isConnected = true + s.lock.Unlock() +} + +func (s *StatsWorker) IsConnected() bool { + s.lock.RLock() + defer s.lock.RUnlock() + + return s.isConnected +} + +func (s *StatsWorker) Flush(now time.Time, closeWait time.Duration) bool { + ts := timestamppb.New(now) + + s.lock.Lock() + stats := make([]*livekit.AnalyticsStat, 0, len(s.incomingPerTrack)+len(s.outgoingPerTrack)) + + incomingPerTrack := s.incomingPerTrack + s.incomingPerTrack = make(map[livekit.TrackID][]*livekit.AnalyticsStat) + + outgoingPerTrack := s.outgoingPerTrack + s.outgoingPerTrack = make(map[livekit.TrackID][]*livekit.AnalyticsStat) + + closed := !s.closedAt.IsZero() && now.Sub(s.closedAt) > closeWait + s.lock.Unlock() + + stats = s.collectStats(ts, livekit.StreamType_UPSTREAM, incomingPerTrack, stats) + stats = s.collectStats(ts, livekit.StreamType_DOWNSTREAM, outgoingPerTrack, stats) + if len(stats) > 0 { + s.t.SendStats(s.ctx, stats) + } + + return closed +} + +func (s *StatsWorker) Close(guard *ReferenceGuard) bool { + s.lock.Lock() + defer s.lock.Unlock() + + if !s.refCount.Release(guard) { + return false + } + + ok := s.closedAt.IsZero() + if ok { + s.closedAt = time.Now() + } + return ok +} + +func (s *StatsWorker) Closed(guard *ReferenceGuard) bool { + s.lock.Lock() + defer s.lock.Unlock() + if s.closedAt.IsZero() { + s.refCount.Activate(guard) + return false + } + return true +} + +func (s *StatsWorker) collectStats( + ts *timestamppb.Timestamp, + streamType livekit.StreamType, + perTrack map[livekit.TrackID][]*livekit.AnalyticsStat, + stats []*livekit.AnalyticsStat, +) []*livekit.AnalyticsStat { + for trackID, analyticsStats := range perTrack { + coalesced := coalesce(analyticsStats) + if coalesced == nil { + continue + } + + coalesced.TimeStamp = ts + coalesced.TrackId = string(trackID) + coalesced.Kind = streamType + coalesced.RoomId = string(s.roomID) + coalesced.ParticipantId = string(s.participantID) + coalesced.RoomName = string(s.roomName) + stats = append(stats, coalesced) + } + return stats +} + +// ------------------------------------------------------------------------- + +// create a single stream and single video layer post aggregation +func coalesce(stats []*livekit.AnalyticsStat) *livekit.AnalyticsStat { + if len(stats) == 0 { + return nil + } + + // find aggregates across streams + startTime := time.Time{} + endTime := time.Time{} + scoreSum := float32(0.0) // used for average + minScore := float32(0.0) // min score in batched stats + var scores []float32 // used for median + maxRtt := uint32(0) + maxJitter := uint32(0) + coalescedVideoLayers := make(map[int32]*livekit.AnalyticsVideoLayer) + coalescedStream := &livekit.AnalyticsStream{} + for _, stat := range stats { + if !isValid(stat) { + logger.Warnw("telemetry skipping invalid stat", nil, "stat", stat) + continue + } + + // only consider non-zero scores + if stat.Score > 0 { + if minScore == 0 { + minScore = stat.Score + } else if stat.Score < minScore { + minScore = stat.Score + } + scoreSum += stat.Score + scores = append(scores, stat.Score) + } + + for _, analyticsStream := range stat.Streams { + start := analyticsStream.StartTime.AsTime() + if startTime.IsZero() || startTime.After(start) { + startTime = start + } + + end := analyticsStream.EndTime.AsTime() + if endTime.IsZero() || endTime.Before(end) { + endTime = end + } + + if analyticsStream.Rtt > maxRtt { + maxRtt = analyticsStream.Rtt + } + + if analyticsStream.Jitter > maxJitter { + maxJitter = analyticsStream.Jitter + } + + coalescedStream.PrimaryPackets += analyticsStream.PrimaryPackets + coalescedStream.PrimaryBytes += analyticsStream.PrimaryBytes + coalescedStream.RetransmitPackets += analyticsStream.RetransmitPackets + coalescedStream.RetransmitBytes += analyticsStream.RetransmitBytes + coalescedStream.PaddingPackets += analyticsStream.PaddingPackets + coalescedStream.PaddingBytes += analyticsStream.PaddingBytes + coalescedStream.PacketsLost += analyticsStream.PacketsLost + coalescedStream.PacketsOutOfOrder += analyticsStream.PacketsOutOfOrder + coalescedStream.Frames += analyticsStream.Frames + coalescedStream.Nacks += analyticsStream.Nacks + coalescedStream.Plis += analyticsStream.Plis + coalescedStream.Firs += analyticsStream.Firs + + for _, videoLayer := range analyticsStream.VideoLayers { + coalescedVideoLayer := coalescedVideoLayers[videoLayer.Layer] + if coalescedVideoLayer == nil { + coalescedVideoLayer = protoutils.CloneProto(videoLayer) + coalescedVideoLayers[videoLayer.Layer] = coalescedVideoLayer + } else { + coalescedVideoLayer.Packets += videoLayer.Packets + coalescedVideoLayer.Bytes += videoLayer.Bytes + coalescedVideoLayer.Frames += videoLayer.Frames + } + } + } + } + coalescedStream.StartTime = timestamppb.New(startTime) + coalescedStream.EndTime = timestamppb.New(endTime) + coalescedStream.Rtt = maxRtt + coalescedStream.Jitter = maxJitter + + // whittle it down to one video layer, just the max available layer + maxVideoLayer := int32(-1) + for _, coalescedVideoLayer := range coalescedVideoLayers { + if maxVideoLayer == -1 || maxVideoLayer < coalescedVideoLayer.Layer { + maxVideoLayer = coalescedVideoLayer.Layer + coalescedStream.VideoLayers = []*livekit.AnalyticsVideoLayer{coalescedVideoLayer} + } + } + + stat := &livekit.AnalyticsStat{ + MinScore: minScore, + MedianScore: utils.MedianFloat32(scores), + Streams: []*livekit.AnalyticsStream{coalescedStream}, + Mime: stats[len(stats)-1].Mime, // use the latest Mime + } + numScores := len(scores) + if numScores > 0 { + stat.Score = scoreSum / float32(numScores) + } + return stat +} + +type CondensedStat struct { + StartTime time.Time + EndTime time.Time + Bytes uint64 + Packets uint32 + PacketsLost uint32 + Frames uint32 +} + +func CondenseStat(stat *livekit.AnalyticsStat) (ps CondensedStat, ok bool) { + if ok = isValid(stat); !ok { + return + } + + for _, stream := range stat.Streams { + startTime := stream.StartTime.AsTime() + endTime := stream.EndTime.AsTime() + if ps.StartTime.IsZero() || startTime.Before(ps.StartTime) { + ps.StartTime = startTime + } + if endTime.After(ps.EndTime) { + ps.EndTime = endTime + } + + ps.Bytes += stream.PrimaryBytes + ps.Packets += stream.PrimaryPackets + ps.PacketsLost += stream.PacketsLost + ps.Frames += stream.Frames + } + + return +} + +func isValid(stat *livekit.AnalyticsStat) bool { + for _, analyticsStream := range stat.Streams { + if int32(analyticsStream.PrimaryPackets) < 0 || + int64(analyticsStream.PrimaryBytes) < 0 || + int32(analyticsStream.RetransmitPackets) < 0 || + int64(analyticsStream.RetransmitBytes) < 0 || + int32(analyticsStream.PaddingPackets) < 0 || + int64(analyticsStream.PaddingBytes) < 0 || + int32(analyticsStream.PacketsLost) < 0 || + int32(analyticsStream.PacketsOutOfOrder) < 0 || + int32(analyticsStream.Frames) < 0 || + int32(analyticsStream.Nacks) < 0 || + int32(analyticsStream.Plis) < 0 || + int32(analyticsStream.Firs) < 0 { + return false + } + + for _, videoLayer := range analyticsStream.VideoLayers { + if int32(videoLayer.Packets) < 0 || + int64(videoLayer.Bytes) < 0 || + int32(videoLayer.Frames) < 0 { + return false + } + } + } + + return true +} diff --git a/livekit/pkg/telemetry/statsworker_test.go b/livekit/pkg/telemetry/statsworker_test.go new file mode 100644 index 0000000..f7ae488 --- /dev/null +++ b/livekit/pkg/telemetry/statsworker_test.go @@ -0,0 +1,19 @@ +package telemetry + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestStatsWorker(t *testing.T) { + t.Run("reference counted close works", func(t *testing.T) { + var g0, g1 ReferenceGuard + w := newStatsWorker(t.Context(), nil, "", "", "", "", &g0) + require.False(t, w.Closed(&g1)) + require.False(t, w.Close(&g0)) + require.False(t, w.Closed(&g1)) + require.True(t, w.Close(&g1)) + require.True(t, w.Closed(&g1)) + }) +} diff --git a/livekit/pkg/telemetry/telemetryfakes/fake_analytics_service.go b/livekit/pkg/telemetry/telemetryfakes/fake_analytics_service.go new file mode 100644 index 0000000..14c4dea --- /dev/null +++ b/livekit/pkg/telemetry/telemetryfakes/fake_analytics_service.go @@ -0,0 +1,234 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package telemetryfakes + +import ( + "context" + "sync" + + "github.com/livekit/livekit-server/pkg/telemetry" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/observability/roomobs" +) + +type FakeAnalyticsService struct { + RoomProjectReporterStub func(context.Context) roomobs.ProjectReporter + roomProjectReporterMutex sync.RWMutex + roomProjectReporterArgsForCall []struct { + arg1 context.Context + } + roomProjectReporterReturns struct { + result1 roomobs.ProjectReporter + } + roomProjectReporterReturnsOnCall map[int]struct { + result1 roomobs.ProjectReporter + } + SendEventStub func(context.Context, *livekit.AnalyticsEvent) + sendEventMutex sync.RWMutex + sendEventArgsForCall []struct { + arg1 context.Context + arg2 *livekit.AnalyticsEvent + } + SendNodeRoomStatesStub func(context.Context, *livekit.AnalyticsNodeRooms) + sendNodeRoomStatesMutex sync.RWMutex + sendNodeRoomStatesArgsForCall []struct { + arg1 context.Context + arg2 *livekit.AnalyticsNodeRooms + } + SendStatsStub func(context.Context, []*livekit.AnalyticsStat) + sendStatsMutex sync.RWMutex + sendStatsArgsForCall []struct { + arg1 context.Context + arg2 []*livekit.AnalyticsStat + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeAnalyticsService) RoomProjectReporter(arg1 context.Context) roomobs.ProjectReporter { + fake.roomProjectReporterMutex.Lock() + ret, specificReturn := fake.roomProjectReporterReturnsOnCall[len(fake.roomProjectReporterArgsForCall)] + fake.roomProjectReporterArgsForCall = append(fake.roomProjectReporterArgsForCall, struct { + arg1 context.Context + }{arg1}) + stub := fake.RoomProjectReporterStub + fakeReturns := fake.roomProjectReporterReturns + fake.recordInvocation("RoomProjectReporter", []interface{}{arg1}) + fake.roomProjectReporterMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeAnalyticsService) RoomProjectReporterCallCount() int { + fake.roomProjectReporterMutex.RLock() + defer fake.roomProjectReporterMutex.RUnlock() + return len(fake.roomProjectReporterArgsForCall) +} + +func (fake *FakeAnalyticsService) RoomProjectReporterCalls(stub func(context.Context) roomobs.ProjectReporter) { + fake.roomProjectReporterMutex.Lock() + defer fake.roomProjectReporterMutex.Unlock() + fake.RoomProjectReporterStub = stub +} + +func (fake *FakeAnalyticsService) RoomProjectReporterArgsForCall(i int) context.Context { + fake.roomProjectReporterMutex.RLock() + defer fake.roomProjectReporterMutex.RUnlock() + argsForCall := fake.roomProjectReporterArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeAnalyticsService) RoomProjectReporterReturns(result1 roomobs.ProjectReporter) { + fake.roomProjectReporterMutex.Lock() + defer fake.roomProjectReporterMutex.Unlock() + fake.RoomProjectReporterStub = nil + fake.roomProjectReporterReturns = struct { + result1 roomobs.ProjectReporter + }{result1} +} + +func (fake *FakeAnalyticsService) RoomProjectReporterReturnsOnCall(i int, result1 roomobs.ProjectReporter) { + fake.roomProjectReporterMutex.Lock() + defer fake.roomProjectReporterMutex.Unlock() + fake.RoomProjectReporterStub = nil + if fake.roomProjectReporterReturnsOnCall == nil { + fake.roomProjectReporterReturnsOnCall = make(map[int]struct { + result1 roomobs.ProjectReporter + }) + } + fake.roomProjectReporterReturnsOnCall[i] = struct { + result1 roomobs.ProjectReporter + }{result1} +} + +func (fake *FakeAnalyticsService) SendEvent(arg1 context.Context, arg2 *livekit.AnalyticsEvent) { + fake.sendEventMutex.Lock() + fake.sendEventArgsForCall = append(fake.sendEventArgsForCall, struct { + arg1 context.Context + arg2 *livekit.AnalyticsEvent + }{arg1, arg2}) + stub := fake.SendEventStub + fake.recordInvocation("SendEvent", []interface{}{arg1, arg2}) + fake.sendEventMutex.Unlock() + if stub != nil { + fake.SendEventStub(arg1, arg2) + } +} + +func (fake *FakeAnalyticsService) SendEventCallCount() int { + fake.sendEventMutex.RLock() + defer fake.sendEventMutex.RUnlock() + return len(fake.sendEventArgsForCall) +} + +func (fake *FakeAnalyticsService) SendEventCalls(stub func(context.Context, *livekit.AnalyticsEvent)) { + fake.sendEventMutex.Lock() + defer fake.sendEventMutex.Unlock() + fake.SendEventStub = stub +} + +func (fake *FakeAnalyticsService) SendEventArgsForCall(i int) (context.Context, *livekit.AnalyticsEvent) { + fake.sendEventMutex.RLock() + defer fake.sendEventMutex.RUnlock() + argsForCall := fake.sendEventArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeAnalyticsService) SendNodeRoomStates(arg1 context.Context, arg2 *livekit.AnalyticsNodeRooms) { + fake.sendNodeRoomStatesMutex.Lock() + fake.sendNodeRoomStatesArgsForCall = append(fake.sendNodeRoomStatesArgsForCall, struct { + arg1 context.Context + arg2 *livekit.AnalyticsNodeRooms + }{arg1, arg2}) + stub := fake.SendNodeRoomStatesStub + fake.recordInvocation("SendNodeRoomStates", []interface{}{arg1, arg2}) + fake.sendNodeRoomStatesMutex.Unlock() + if stub != nil { + fake.SendNodeRoomStatesStub(arg1, arg2) + } +} + +func (fake *FakeAnalyticsService) SendNodeRoomStatesCallCount() int { + fake.sendNodeRoomStatesMutex.RLock() + defer fake.sendNodeRoomStatesMutex.RUnlock() + return len(fake.sendNodeRoomStatesArgsForCall) +} + +func (fake *FakeAnalyticsService) SendNodeRoomStatesCalls(stub func(context.Context, *livekit.AnalyticsNodeRooms)) { + fake.sendNodeRoomStatesMutex.Lock() + defer fake.sendNodeRoomStatesMutex.Unlock() + fake.SendNodeRoomStatesStub = stub +} + +func (fake *FakeAnalyticsService) SendNodeRoomStatesArgsForCall(i int) (context.Context, *livekit.AnalyticsNodeRooms) { + fake.sendNodeRoomStatesMutex.RLock() + defer fake.sendNodeRoomStatesMutex.RUnlock() + argsForCall := fake.sendNodeRoomStatesArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeAnalyticsService) SendStats(arg1 context.Context, arg2 []*livekit.AnalyticsStat) { + var arg2Copy []*livekit.AnalyticsStat + if arg2 != nil { + arg2Copy = make([]*livekit.AnalyticsStat, len(arg2)) + copy(arg2Copy, arg2) + } + fake.sendStatsMutex.Lock() + fake.sendStatsArgsForCall = append(fake.sendStatsArgsForCall, struct { + arg1 context.Context + arg2 []*livekit.AnalyticsStat + }{arg1, arg2Copy}) + stub := fake.SendStatsStub + fake.recordInvocation("SendStats", []interface{}{arg1, arg2Copy}) + fake.sendStatsMutex.Unlock() + if stub != nil { + fake.SendStatsStub(arg1, arg2) + } +} + +func (fake *FakeAnalyticsService) SendStatsCallCount() int { + fake.sendStatsMutex.RLock() + defer fake.sendStatsMutex.RUnlock() + return len(fake.sendStatsArgsForCall) +} + +func (fake *FakeAnalyticsService) SendStatsCalls(stub func(context.Context, []*livekit.AnalyticsStat)) { + fake.sendStatsMutex.Lock() + defer fake.sendStatsMutex.Unlock() + fake.SendStatsStub = stub +} + +func (fake *FakeAnalyticsService) SendStatsArgsForCall(i int) (context.Context, []*livekit.AnalyticsStat) { + fake.sendStatsMutex.RLock() + defer fake.sendStatsMutex.RUnlock() + argsForCall := fake.sendStatsArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeAnalyticsService) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeAnalyticsService) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ telemetry.AnalyticsService = new(FakeAnalyticsService) diff --git a/livekit/pkg/telemetry/telemetryfakes/fake_telemetry_service.go b/livekit/pkg/telemetry/telemetryfakes/fake_telemetry_service.go new file mode 100644 index 0000000..bc9109d --- /dev/null +++ b/livekit/pkg/telemetry/telemetryfakes/fake_telemetry_service.go @@ -0,0 +1,1642 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package telemetryfakes + +import ( + "context" + "sync" + + "github.com/livekit/livekit-server/pkg/sfu/mime" + "github.com/livekit/livekit-server/pkg/telemetry" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/observability/roomobs" +) + +type FakeTelemetryService struct { + APICallStub func(context.Context, *livekit.APICallInfo) + aPICallMutex sync.RWMutex + aPICallArgsForCall []struct { + arg1 context.Context + arg2 *livekit.APICallInfo + } + EgressEndedStub func(context.Context, *livekit.EgressInfo) + egressEndedMutex sync.RWMutex + egressEndedArgsForCall []struct { + arg1 context.Context + arg2 *livekit.EgressInfo + } + EgressStartedStub func(context.Context, *livekit.EgressInfo) + egressStartedMutex sync.RWMutex + egressStartedArgsForCall []struct { + arg1 context.Context + arg2 *livekit.EgressInfo + } + EgressUpdatedStub func(context.Context, *livekit.EgressInfo) + egressUpdatedMutex sync.RWMutex + egressUpdatedArgsForCall []struct { + arg1 context.Context + arg2 *livekit.EgressInfo + } + FlushStatsStub func() + flushStatsMutex sync.RWMutex + flushStatsArgsForCall []struct { + } + IngressCreatedStub func(context.Context, *livekit.IngressInfo) + ingressCreatedMutex sync.RWMutex + ingressCreatedArgsForCall []struct { + arg1 context.Context + arg2 *livekit.IngressInfo + } + IngressDeletedStub func(context.Context, *livekit.IngressInfo) + ingressDeletedMutex sync.RWMutex + ingressDeletedArgsForCall []struct { + arg1 context.Context + arg2 *livekit.IngressInfo + } + IngressEndedStub func(context.Context, *livekit.IngressInfo) + ingressEndedMutex sync.RWMutex + ingressEndedArgsForCall []struct { + arg1 context.Context + arg2 *livekit.IngressInfo + } + IngressStartedStub func(context.Context, *livekit.IngressInfo) + ingressStartedMutex sync.RWMutex + ingressStartedArgsForCall []struct { + arg1 context.Context + arg2 *livekit.IngressInfo + } + IngressUpdatedStub func(context.Context, *livekit.IngressInfo) + ingressUpdatedMutex sync.RWMutex + ingressUpdatedArgsForCall []struct { + arg1 context.Context + arg2 *livekit.IngressInfo + } + LocalRoomStateStub func(context.Context, *livekit.AnalyticsNodeRooms) + localRoomStateMutex sync.RWMutex + localRoomStateArgsForCall []struct { + arg1 context.Context + arg2 *livekit.AnalyticsNodeRooms + } + NotifyEgressEventStub func(context.Context, string, *livekit.EgressInfo) + notifyEgressEventMutex sync.RWMutex + notifyEgressEventArgsForCall []struct { + arg1 context.Context + arg2 string + arg3 *livekit.EgressInfo + } + ParticipantActiveStub func(context.Context, *livekit.Room, *livekit.ParticipantInfo, *livekit.AnalyticsClientMeta, bool, *telemetry.ReferenceGuard) + participantActiveMutex sync.RWMutex + participantActiveArgsForCall []struct { + arg1 context.Context + arg2 *livekit.Room + arg3 *livekit.ParticipantInfo + arg4 *livekit.AnalyticsClientMeta + arg5 bool + arg6 *telemetry.ReferenceGuard + } + ParticipantJoinedStub func(context.Context, *livekit.Room, *livekit.ParticipantInfo, *livekit.ClientInfo, *livekit.AnalyticsClientMeta, bool, *telemetry.ReferenceGuard) + participantJoinedMutex sync.RWMutex + participantJoinedArgsForCall []struct { + arg1 context.Context + arg2 *livekit.Room + arg3 *livekit.ParticipantInfo + arg4 *livekit.ClientInfo + arg5 *livekit.AnalyticsClientMeta + arg6 bool + arg7 *telemetry.ReferenceGuard + } + ParticipantLeftStub func(context.Context, *livekit.Room, *livekit.ParticipantInfo, bool, *telemetry.ReferenceGuard) + participantLeftMutex sync.RWMutex + participantLeftArgsForCall []struct { + arg1 context.Context + arg2 *livekit.Room + arg3 *livekit.ParticipantInfo + arg4 bool + arg5 *telemetry.ReferenceGuard + } + ParticipantResumedStub func(context.Context, *livekit.Room, *livekit.ParticipantInfo, livekit.NodeID, livekit.ReconnectReason) + participantResumedMutex sync.RWMutex + participantResumedArgsForCall []struct { + arg1 context.Context + arg2 *livekit.Room + arg3 *livekit.ParticipantInfo + arg4 livekit.NodeID + arg5 livekit.ReconnectReason + } + ReportStub func(context.Context, *livekit.ReportInfo) + reportMutex sync.RWMutex + reportArgsForCall []struct { + arg1 context.Context + arg2 *livekit.ReportInfo + } + RoomEndedStub func(context.Context, *livekit.Room) + roomEndedMutex sync.RWMutex + roomEndedArgsForCall []struct { + arg1 context.Context + arg2 *livekit.Room + } + RoomProjectReporterStub func(context.Context) roomobs.ProjectReporter + roomProjectReporterMutex sync.RWMutex + roomProjectReporterArgsForCall []struct { + arg1 context.Context + } + roomProjectReporterReturns struct { + result1 roomobs.ProjectReporter + } + roomProjectReporterReturnsOnCall map[int]struct { + result1 roomobs.ProjectReporter + } + RoomStartedStub func(context.Context, *livekit.Room) + roomStartedMutex sync.RWMutex + roomStartedArgsForCall []struct { + arg1 context.Context + arg2 *livekit.Room + } + SendEventStub func(context.Context, *livekit.AnalyticsEvent) + sendEventMutex sync.RWMutex + sendEventArgsForCall []struct { + arg1 context.Context + arg2 *livekit.AnalyticsEvent + } + SendNodeRoomStatesStub func(context.Context, *livekit.AnalyticsNodeRooms) + sendNodeRoomStatesMutex sync.RWMutex + sendNodeRoomStatesArgsForCall []struct { + arg1 context.Context + arg2 *livekit.AnalyticsNodeRooms + } + SendStatsStub func(context.Context, []*livekit.AnalyticsStat) + sendStatsMutex sync.RWMutex + sendStatsArgsForCall []struct { + arg1 context.Context + arg2 []*livekit.AnalyticsStat + } + TrackMaxSubscribedVideoQualityStub func(context.Context, livekit.ParticipantID, *livekit.TrackInfo, mime.MimeType, livekit.VideoQuality) + trackMaxSubscribedVideoQualityMutex sync.RWMutex + trackMaxSubscribedVideoQualityArgsForCall []struct { + arg1 context.Context + arg2 livekit.ParticipantID + arg3 *livekit.TrackInfo + arg4 mime.MimeType + arg5 livekit.VideoQuality + } + TrackMutedStub func(context.Context, livekit.ParticipantID, *livekit.TrackInfo) + trackMutedMutex sync.RWMutex + trackMutedArgsForCall []struct { + arg1 context.Context + arg2 livekit.ParticipantID + arg3 *livekit.TrackInfo + } + TrackPublishRTPStatsStub func(context.Context, livekit.ParticipantID, livekit.TrackID, mime.MimeType, int, *livekit.RTPStats) + trackPublishRTPStatsMutex sync.RWMutex + trackPublishRTPStatsArgsForCall []struct { + arg1 context.Context + arg2 livekit.ParticipantID + arg3 livekit.TrackID + arg4 mime.MimeType + arg5 int + arg6 *livekit.RTPStats + } + TrackPublishRequestedStub func(context.Context, livekit.ParticipantID, livekit.ParticipantIdentity, *livekit.TrackInfo) + trackPublishRequestedMutex sync.RWMutex + trackPublishRequestedArgsForCall []struct { + arg1 context.Context + arg2 livekit.ParticipantID + arg3 livekit.ParticipantIdentity + arg4 *livekit.TrackInfo + } + TrackPublishedStub func(context.Context, livekit.ParticipantID, livekit.ParticipantIdentity, *livekit.TrackInfo, bool) + trackPublishedMutex sync.RWMutex + trackPublishedArgsForCall []struct { + arg1 context.Context + arg2 livekit.ParticipantID + arg3 livekit.ParticipantIdentity + arg4 *livekit.TrackInfo + arg5 bool + } + TrackPublishedUpdateStub func(context.Context, livekit.ParticipantID, *livekit.TrackInfo) + trackPublishedUpdateMutex sync.RWMutex + trackPublishedUpdateArgsForCall []struct { + arg1 context.Context + arg2 livekit.ParticipantID + arg3 *livekit.TrackInfo + } + TrackStatsStub func(telemetry.StatsKey, *livekit.AnalyticsStat) + trackStatsMutex sync.RWMutex + trackStatsArgsForCall []struct { + arg1 telemetry.StatsKey + arg2 *livekit.AnalyticsStat + } + TrackSubscribeFailedStub func(context.Context, livekit.ParticipantID, livekit.TrackID, error, bool) + trackSubscribeFailedMutex sync.RWMutex + trackSubscribeFailedArgsForCall []struct { + arg1 context.Context + arg2 livekit.ParticipantID + arg3 livekit.TrackID + arg4 error + arg5 bool + } + TrackSubscribeRTPStatsStub func(context.Context, livekit.ParticipantID, livekit.TrackID, mime.MimeType, *livekit.RTPStats) + trackSubscribeRTPStatsMutex sync.RWMutex + trackSubscribeRTPStatsArgsForCall []struct { + arg1 context.Context + arg2 livekit.ParticipantID + arg3 livekit.TrackID + arg4 mime.MimeType + arg5 *livekit.RTPStats + } + TrackSubscribeRequestedStub func(context.Context, livekit.ParticipantID, *livekit.TrackInfo) + trackSubscribeRequestedMutex sync.RWMutex + trackSubscribeRequestedArgsForCall []struct { + arg1 context.Context + arg2 livekit.ParticipantID + arg3 *livekit.TrackInfo + } + TrackSubscribedStub func(context.Context, livekit.ParticipantID, *livekit.TrackInfo, *livekit.ParticipantInfo, bool) + trackSubscribedMutex sync.RWMutex + trackSubscribedArgsForCall []struct { + arg1 context.Context + arg2 livekit.ParticipantID + arg3 *livekit.TrackInfo + arg4 *livekit.ParticipantInfo + arg5 bool + } + TrackUnmutedStub func(context.Context, livekit.ParticipantID, *livekit.TrackInfo) + trackUnmutedMutex sync.RWMutex + trackUnmutedArgsForCall []struct { + arg1 context.Context + arg2 livekit.ParticipantID + arg3 *livekit.TrackInfo + } + TrackUnpublishedStub func(context.Context, livekit.ParticipantID, livekit.ParticipantIdentity, *livekit.TrackInfo, bool) + trackUnpublishedMutex sync.RWMutex + trackUnpublishedArgsForCall []struct { + arg1 context.Context + arg2 livekit.ParticipantID + arg3 livekit.ParticipantIdentity + arg4 *livekit.TrackInfo + arg5 bool + } + TrackUnsubscribedStub func(context.Context, livekit.ParticipantID, *livekit.TrackInfo, bool) + trackUnsubscribedMutex sync.RWMutex + trackUnsubscribedArgsForCall []struct { + arg1 context.Context + arg2 livekit.ParticipantID + arg3 *livekit.TrackInfo + arg4 bool + } + WebhookStub func(context.Context, *livekit.WebhookInfo) + webhookMutex sync.RWMutex + webhookArgsForCall []struct { + arg1 context.Context + arg2 *livekit.WebhookInfo + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeTelemetryService) APICall(arg1 context.Context, arg2 *livekit.APICallInfo) { + fake.aPICallMutex.Lock() + fake.aPICallArgsForCall = append(fake.aPICallArgsForCall, struct { + arg1 context.Context + arg2 *livekit.APICallInfo + }{arg1, arg2}) + stub := fake.APICallStub + fake.recordInvocation("APICall", []interface{}{arg1, arg2}) + fake.aPICallMutex.Unlock() + if stub != nil { + fake.APICallStub(arg1, arg2) + } +} + +func (fake *FakeTelemetryService) APICallCallCount() int { + fake.aPICallMutex.RLock() + defer fake.aPICallMutex.RUnlock() + return len(fake.aPICallArgsForCall) +} + +func (fake *FakeTelemetryService) APICallCalls(stub func(context.Context, *livekit.APICallInfo)) { + fake.aPICallMutex.Lock() + defer fake.aPICallMutex.Unlock() + fake.APICallStub = stub +} + +func (fake *FakeTelemetryService) APICallArgsForCall(i int) (context.Context, *livekit.APICallInfo) { + fake.aPICallMutex.RLock() + defer fake.aPICallMutex.RUnlock() + argsForCall := fake.aPICallArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeTelemetryService) EgressEnded(arg1 context.Context, arg2 *livekit.EgressInfo) { + fake.egressEndedMutex.Lock() + fake.egressEndedArgsForCall = append(fake.egressEndedArgsForCall, struct { + arg1 context.Context + arg2 *livekit.EgressInfo + }{arg1, arg2}) + stub := fake.EgressEndedStub + fake.recordInvocation("EgressEnded", []interface{}{arg1, arg2}) + fake.egressEndedMutex.Unlock() + if stub != nil { + fake.EgressEndedStub(arg1, arg2) + } +} + +func (fake *FakeTelemetryService) EgressEndedCallCount() int { + fake.egressEndedMutex.RLock() + defer fake.egressEndedMutex.RUnlock() + return len(fake.egressEndedArgsForCall) +} + +func (fake *FakeTelemetryService) EgressEndedCalls(stub func(context.Context, *livekit.EgressInfo)) { + fake.egressEndedMutex.Lock() + defer fake.egressEndedMutex.Unlock() + fake.EgressEndedStub = stub +} + +func (fake *FakeTelemetryService) EgressEndedArgsForCall(i int) (context.Context, *livekit.EgressInfo) { + fake.egressEndedMutex.RLock() + defer fake.egressEndedMutex.RUnlock() + argsForCall := fake.egressEndedArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeTelemetryService) EgressStarted(arg1 context.Context, arg2 *livekit.EgressInfo) { + fake.egressStartedMutex.Lock() + fake.egressStartedArgsForCall = append(fake.egressStartedArgsForCall, struct { + arg1 context.Context + arg2 *livekit.EgressInfo + }{arg1, arg2}) + stub := fake.EgressStartedStub + fake.recordInvocation("EgressStarted", []interface{}{arg1, arg2}) + fake.egressStartedMutex.Unlock() + if stub != nil { + fake.EgressStartedStub(arg1, arg2) + } +} + +func (fake *FakeTelemetryService) EgressStartedCallCount() int { + fake.egressStartedMutex.RLock() + defer fake.egressStartedMutex.RUnlock() + return len(fake.egressStartedArgsForCall) +} + +func (fake *FakeTelemetryService) EgressStartedCalls(stub func(context.Context, *livekit.EgressInfo)) { + fake.egressStartedMutex.Lock() + defer fake.egressStartedMutex.Unlock() + fake.EgressStartedStub = stub +} + +func (fake *FakeTelemetryService) EgressStartedArgsForCall(i int) (context.Context, *livekit.EgressInfo) { + fake.egressStartedMutex.RLock() + defer fake.egressStartedMutex.RUnlock() + argsForCall := fake.egressStartedArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeTelemetryService) EgressUpdated(arg1 context.Context, arg2 *livekit.EgressInfo) { + fake.egressUpdatedMutex.Lock() + fake.egressUpdatedArgsForCall = append(fake.egressUpdatedArgsForCall, struct { + arg1 context.Context + arg2 *livekit.EgressInfo + }{arg1, arg2}) + stub := fake.EgressUpdatedStub + fake.recordInvocation("EgressUpdated", []interface{}{arg1, arg2}) + fake.egressUpdatedMutex.Unlock() + if stub != nil { + fake.EgressUpdatedStub(arg1, arg2) + } +} + +func (fake *FakeTelemetryService) EgressUpdatedCallCount() int { + fake.egressUpdatedMutex.RLock() + defer fake.egressUpdatedMutex.RUnlock() + return len(fake.egressUpdatedArgsForCall) +} + +func (fake *FakeTelemetryService) EgressUpdatedCalls(stub func(context.Context, *livekit.EgressInfo)) { + fake.egressUpdatedMutex.Lock() + defer fake.egressUpdatedMutex.Unlock() + fake.EgressUpdatedStub = stub +} + +func (fake *FakeTelemetryService) EgressUpdatedArgsForCall(i int) (context.Context, *livekit.EgressInfo) { + fake.egressUpdatedMutex.RLock() + defer fake.egressUpdatedMutex.RUnlock() + argsForCall := fake.egressUpdatedArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeTelemetryService) FlushStats() { + fake.flushStatsMutex.Lock() + fake.flushStatsArgsForCall = append(fake.flushStatsArgsForCall, struct { + }{}) + stub := fake.FlushStatsStub + fake.recordInvocation("FlushStats", []interface{}{}) + fake.flushStatsMutex.Unlock() + if stub != nil { + fake.FlushStatsStub() + } +} + +func (fake *FakeTelemetryService) FlushStatsCallCount() int { + fake.flushStatsMutex.RLock() + defer fake.flushStatsMutex.RUnlock() + return len(fake.flushStatsArgsForCall) +} + +func (fake *FakeTelemetryService) FlushStatsCalls(stub func()) { + fake.flushStatsMutex.Lock() + defer fake.flushStatsMutex.Unlock() + fake.FlushStatsStub = stub +} + +func (fake *FakeTelemetryService) IngressCreated(arg1 context.Context, arg2 *livekit.IngressInfo) { + fake.ingressCreatedMutex.Lock() + fake.ingressCreatedArgsForCall = append(fake.ingressCreatedArgsForCall, struct { + arg1 context.Context + arg2 *livekit.IngressInfo + }{arg1, arg2}) + stub := fake.IngressCreatedStub + fake.recordInvocation("IngressCreated", []interface{}{arg1, arg2}) + fake.ingressCreatedMutex.Unlock() + if stub != nil { + fake.IngressCreatedStub(arg1, arg2) + } +} + +func (fake *FakeTelemetryService) IngressCreatedCallCount() int { + fake.ingressCreatedMutex.RLock() + defer fake.ingressCreatedMutex.RUnlock() + return len(fake.ingressCreatedArgsForCall) +} + +func (fake *FakeTelemetryService) IngressCreatedCalls(stub func(context.Context, *livekit.IngressInfo)) { + fake.ingressCreatedMutex.Lock() + defer fake.ingressCreatedMutex.Unlock() + fake.IngressCreatedStub = stub +} + +func (fake *FakeTelemetryService) IngressCreatedArgsForCall(i int) (context.Context, *livekit.IngressInfo) { + fake.ingressCreatedMutex.RLock() + defer fake.ingressCreatedMutex.RUnlock() + argsForCall := fake.ingressCreatedArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeTelemetryService) IngressDeleted(arg1 context.Context, arg2 *livekit.IngressInfo) { + fake.ingressDeletedMutex.Lock() + fake.ingressDeletedArgsForCall = append(fake.ingressDeletedArgsForCall, struct { + arg1 context.Context + arg2 *livekit.IngressInfo + }{arg1, arg2}) + stub := fake.IngressDeletedStub + fake.recordInvocation("IngressDeleted", []interface{}{arg1, arg2}) + fake.ingressDeletedMutex.Unlock() + if stub != nil { + fake.IngressDeletedStub(arg1, arg2) + } +} + +func (fake *FakeTelemetryService) IngressDeletedCallCount() int { + fake.ingressDeletedMutex.RLock() + defer fake.ingressDeletedMutex.RUnlock() + return len(fake.ingressDeletedArgsForCall) +} + +func (fake *FakeTelemetryService) IngressDeletedCalls(stub func(context.Context, *livekit.IngressInfo)) { + fake.ingressDeletedMutex.Lock() + defer fake.ingressDeletedMutex.Unlock() + fake.IngressDeletedStub = stub +} + +func (fake *FakeTelemetryService) IngressDeletedArgsForCall(i int) (context.Context, *livekit.IngressInfo) { + fake.ingressDeletedMutex.RLock() + defer fake.ingressDeletedMutex.RUnlock() + argsForCall := fake.ingressDeletedArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeTelemetryService) IngressEnded(arg1 context.Context, arg2 *livekit.IngressInfo) { + fake.ingressEndedMutex.Lock() + fake.ingressEndedArgsForCall = append(fake.ingressEndedArgsForCall, struct { + arg1 context.Context + arg2 *livekit.IngressInfo + }{arg1, arg2}) + stub := fake.IngressEndedStub + fake.recordInvocation("IngressEnded", []interface{}{arg1, arg2}) + fake.ingressEndedMutex.Unlock() + if stub != nil { + fake.IngressEndedStub(arg1, arg2) + } +} + +func (fake *FakeTelemetryService) IngressEndedCallCount() int { + fake.ingressEndedMutex.RLock() + defer fake.ingressEndedMutex.RUnlock() + return len(fake.ingressEndedArgsForCall) +} + +func (fake *FakeTelemetryService) IngressEndedCalls(stub func(context.Context, *livekit.IngressInfo)) { + fake.ingressEndedMutex.Lock() + defer fake.ingressEndedMutex.Unlock() + fake.IngressEndedStub = stub +} + +func (fake *FakeTelemetryService) IngressEndedArgsForCall(i int) (context.Context, *livekit.IngressInfo) { + fake.ingressEndedMutex.RLock() + defer fake.ingressEndedMutex.RUnlock() + argsForCall := fake.ingressEndedArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeTelemetryService) IngressStarted(arg1 context.Context, arg2 *livekit.IngressInfo) { + fake.ingressStartedMutex.Lock() + fake.ingressStartedArgsForCall = append(fake.ingressStartedArgsForCall, struct { + arg1 context.Context + arg2 *livekit.IngressInfo + }{arg1, arg2}) + stub := fake.IngressStartedStub + fake.recordInvocation("IngressStarted", []interface{}{arg1, arg2}) + fake.ingressStartedMutex.Unlock() + if stub != nil { + fake.IngressStartedStub(arg1, arg2) + } +} + +func (fake *FakeTelemetryService) IngressStartedCallCount() int { + fake.ingressStartedMutex.RLock() + defer fake.ingressStartedMutex.RUnlock() + return len(fake.ingressStartedArgsForCall) +} + +func (fake *FakeTelemetryService) IngressStartedCalls(stub func(context.Context, *livekit.IngressInfo)) { + fake.ingressStartedMutex.Lock() + defer fake.ingressStartedMutex.Unlock() + fake.IngressStartedStub = stub +} + +func (fake *FakeTelemetryService) IngressStartedArgsForCall(i int) (context.Context, *livekit.IngressInfo) { + fake.ingressStartedMutex.RLock() + defer fake.ingressStartedMutex.RUnlock() + argsForCall := fake.ingressStartedArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeTelemetryService) IngressUpdated(arg1 context.Context, arg2 *livekit.IngressInfo) { + fake.ingressUpdatedMutex.Lock() + fake.ingressUpdatedArgsForCall = append(fake.ingressUpdatedArgsForCall, struct { + arg1 context.Context + arg2 *livekit.IngressInfo + }{arg1, arg2}) + stub := fake.IngressUpdatedStub + fake.recordInvocation("IngressUpdated", []interface{}{arg1, arg2}) + fake.ingressUpdatedMutex.Unlock() + if stub != nil { + fake.IngressUpdatedStub(arg1, arg2) + } +} + +func (fake *FakeTelemetryService) IngressUpdatedCallCount() int { + fake.ingressUpdatedMutex.RLock() + defer fake.ingressUpdatedMutex.RUnlock() + return len(fake.ingressUpdatedArgsForCall) +} + +func (fake *FakeTelemetryService) IngressUpdatedCalls(stub func(context.Context, *livekit.IngressInfo)) { + fake.ingressUpdatedMutex.Lock() + defer fake.ingressUpdatedMutex.Unlock() + fake.IngressUpdatedStub = stub +} + +func (fake *FakeTelemetryService) IngressUpdatedArgsForCall(i int) (context.Context, *livekit.IngressInfo) { + fake.ingressUpdatedMutex.RLock() + defer fake.ingressUpdatedMutex.RUnlock() + argsForCall := fake.ingressUpdatedArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeTelemetryService) LocalRoomState(arg1 context.Context, arg2 *livekit.AnalyticsNodeRooms) { + fake.localRoomStateMutex.Lock() + fake.localRoomStateArgsForCall = append(fake.localRoomStateArgsForCall, struct { + arg1 context.Context + arg2 *livekit.AnalyticsNodeRooms + }{arg1, arg2}) + stub := fake.LocalRoomStateStub + fake.recordInvocation("LocalRoomState", []interface{}{arg1, arg2}) + fake.localRoomStateMutex.Unlock() + if stub != nil { + fake.LocalRoomStateStub(arg1, arg2) + } +} + +func (fake *FakeTelemetryService) LocalRoomStateCallCount() int { + fake.localRoomStateMutex.RLock() + defer fake.localRoomStateMutex.RUnlock() + return len(fake.localRoomStateArgsForCall) +} + +func (fake *FakeTelemetryService) LocalRoomStateCalls(stub func(context.Context, *livekit.AnalyticsNodeRooms)) { + fake.localRoomStateMutex.Lock() + defer fake.localRoomStateMutex.Unlock() + fake.LocalRoomStateStub = stub +} + +func (fake *FakeTelemetryService) LocalRoomStateArgsForCall(i int) (context.Context, *livekit.AnalyticsNodeRooms) { + fake.localRoomStateMutex.RLock() + defer fake.localRoomStateMutex.RUnlock() + argsForCall := fake.localRoomStateArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeTelemetryService) NotifyEgressEvent(arg1 context.Context, arg2 string, arg3 *livekit.EgressInfo) { + fake.notifyEgressEventMutex.Lock() + fake.notifyEgressEventArgsForCall = append(fake.notifyEgressEventArgsForCall, struct { + arg1 context.Context + arg2 string + arg3 *livekit.EgressInfo + }{arg1, arg2, arg3}) + stub := fake.NotifyEgressEventStub + fake.recordInvocation("NotifyEgressEvent", []interface{}{arg1, arg2, arg3}) + fake.notifyEgressEventMutex.Unlock() + if stub != nil { + fake.NotifyEgressEventStub(arg1, arg2, arg3) + } +} + +func (fake *FakeTelemetryService) NotifyEgressEventCallCount() int { + fake.notifyEgressEventMutex.RLock() + defer fake.notifyEgressEventMutex.RUnlock() + return len(fake.notifyEgressEventArgsForCall) +} + +func (fake *FakeTelemetryService) NotifyEgressEventCalls(stub func(context.Context, string, *livekit.EgressInfo)) { + fake.notifyEgressEventMutex.Lock() + defer fake.notifyEgressEventMutex.Unlock() + fake.NotifyEgressEventStub = stub +} + +func (fake *FakeTelemetryService) NotifyEgressEventArgsForCall(i int) (context.Context, string, *livekit.EgressInfo) { + fake.notifyEgressEventMutex.RLock() + defer fake.notifyEgressEventMutex.RUnlock() + argsForCall := fake.notifyEgressEventArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeTelemetryService) ParticipantActive(arg1 context.Context, arg2 *livekit.Room, arg3 *livekit.ParticipantInfo, arg4 *livekit.AnalyticsClientMeta, arg5 bool, arg6 *telemetry.ReferenceGuard) { + fake.participantActiveMutex.Lock() + fake.participantActiveArgsForCall = append(fake.participantActiveArgsForCall, struct { + arg1 context.Context + arg2 *livekit.Room + arg3 *livekit.ParticipantInfo + arg4 *livekit.AnalyticsClientMeta + arg5 bool + arg6 *telemetry.ReferenceGuard + }{arg1, arg2, arg3, arg4, arg5, arg6}) + stub := fake.ParticipantActiveStub + fake.recordInvocation("ParticipantActive", []interface{}{arg1, arg2, arg3, arg4, arg5, arg6}) + fake.participantActiveMutex.Unlock() + if stub != nil { + fake.ParticipantActiveStub(arg1, arg2, arg3, arg4, arg5, arg6) + } +} + +func (fake *FakeTelemetryService) ParticipantActiveCallCount() int { + fake.participantActiveMutex.RLock() + defer fake.participantActiveMutex.RUnlock() + return len(fake.participantActiveArgsForCall) +} + +func (fake *FakeTelemetryService) ParticipantActiveCalls(stub func(context.Context, *livekit.Room, *livekit.ParticipantInfo, *livekit.AnalyticsClientMeta, bool, *telemetry.ReferenceGuard)) { + fake.participantActiveMutex.Lock() + defer fake.participantActiveMutex.Unlock() + fake.ParticipantActiveStub = stub +} + +func (fake *FakeTelemetryService) ParticipantActiveArgsForCall(i int) (context.Context, *livekit.Room, *livekit.ParticipantInfo, *livekit.AnalyticsClientMeta, bool, *telemetry.ReferenceGuard) { + fake.participantActiveMutex.RLock() + defer fake.participantActiveMutex.RUnlock() + argsForCall := fake.participantActiveArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3, argsForCall.arg4, argsForCall.arg5, argsForCall.arg6 +} + +func (fake *FakeTelemetryService) ParticipantJoined(arg1 context.Context, arg2 *livekit.Room, arg3 *livekit.ParticipantInfo, arg4 *livekit.ClientInfo, arg5 *livekit.AnalyticsClientMeta, arg6 bool, arg7 *telemetry.ReferenceGuard) { + fake.participantJoinedMutex.Lock() + fake.participantJoinedArgsForCall = append(fake.participantJoinedArgsForCall, struct { + arg1 context.Context + arg2 *livekit.Room + arg3 *livekit.ParticipantInfo + arg4 *livekit.ClientInfo + arg5 *livekit.AnalyticsClientMeta + arg6 bool + arg7 *telemetry.ReferenceGuard + }{arg1, arg2, arg3, arg4, arg5, arg6, arg7}) + stub := fake.ParticipantJoinedStub + fake.recordInvocation("ParticipantJoined", []interface{}{arg1, arg2, arg3, arg4, arg5, arg6, arg7}) + fake.participantJoinedMutex.Unlock() + if stub != nil { + fake.ParticipantJoinedStub(arg1, arg2, arg3, arg4, arg5, arg6, arg7) + } +} + +func (fake *FakeTelemetryService) ParticipantJoinedCallCount() int { + fake.participantJoinedMutex.RLock() + defer fake.participantJoinedMutex.RUnlock() + return len(fake.participantJoinedArgsForCall) +} + +func (fake *FakeTelemetryService) ParticipantJoinedCalls(stub func(context.Context, *livekit.Room, *livekit.ParticipantInfo, *livekit.ClientInfo, *livekit.AnalyticsClientMeta, bool, *telemetry.ReferenceGuard)) { + fake.participantJoinedMutex.Lock() + defer fake.participantJoinedMutex.Unlock() + fake.ParticipantJoinedStub = stub +} + +func (fake *FakeTelemetryService) ParticipantJoinedArgsForCall(i int) (context.Context, *livekit.Room, *livekit.ParticipantInfo, *livekit.ClientInfo, *livekit.AnalyticsClientMeta, bool, *telemetry.ReferenceGuard) { + fake.participantJoinedMutex.RLock() + defer fake.participantJoinedMutex.RUnlock() + argsForCall := fake.participantJoinedArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3, argsForCall.arg4, argsForCall.arg5, argsForCall.arg6, argsForCall.arg7 +} + +func (fake *FakeTelemetryService) ParticipantLeft(arg1 context.Context, arg2 *livekit.Room, arg3 *livekit.ParticipantInfo, arg4 bool, arg5 *telemetry.ReferenceGuard) { + fake.participantLeftMutex.Lock() + fake.participantLeftArgsForCall = append(fake.participantLeftArgsForCall, struct { + arg1 context.Context + arg2 *livekit.Room + arg3 *livekit.ParticipantInfo + arg4 bool + arg5 *telemetry.ReferenceGuard + }{arg1, arg2, arg3, arg4, arg5}) + stub := fake.ParticipantLeftStub + fake.recordInvocation("ParticipantLeft", []interface{}{arg1, arg2, arg3, arg4, arg5}) + fake.participantLeftMutex.Unlock() + if stub != nil { + fake.ParticipantLeftStub(arg1, arg2, arg3, arg4, arg5) + } +} + +func (fake *FakeTelemetryService) ParticipantLeftCallCount() int { + fake.participantLeftMutex.RLock() + defer fake.participantLeftMutex.RUnlock() + return len(fake.participantLeftArgsForCall) +} + +func (fake *FakeTelemetryService) ParticipantLeftCalls(stub func(context.Context, *livekit.Room, *livekit.ParticipantInfo, bool, *telemetry.ReferenceGuard)) { + fake.participantLeftMutex.Lock() + defer fake.participantLeftMutex.Unlock() + fake.ParticipantLeftStub = stub +} + +func (fake *FakeTelemetryService) ParticipantLeftArgsForCall(i int) (context.Context, *livekit.Room, *livekit.ParticipantInfo, bool, *telemetry.ReferenceGuard) { + fake.participantLeftMutex.RLock() + defer fake.participantLeftMutex.RUnlock() + argsForCall := fake.participantLeftArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3, argsForCall.arg4, argsForCall.arg5 +} + +func (fake *FakeTelemetryService) ParticipantResumed(arg1 context.Context, arg2 *livekit.Room, arg3 *livekit.ParticipantInfo, arg4 livekit.NodeID, arg5 livekit.ReconnectReason) { + fake.participantResumedMutex.Lock() + fake.participantResumedArgsForCall = append(fake.participantResumedArgsForCall, struct { + arg1 context.Context + arg2 *livekit.Room + arg3 *livekit.ParticipantInfo + arg4 livekit.NodeID + arg5 livekit.ReconnectReason + }{arg1, arg2, arg3, arg4, arg5}) + stub := fake.ParticipantResumedStub + fake.recordInvocation("ParticipantResumed", []interface{}{arg1, arg2, arg3, arg4, arg5}) + fake.participantResumedMutex.Unlock() + if stub != nil { + fake.ParticipantResumedStub(arg1, arg2, arg3, arg4, arg5) + } +} + +func (fake *FakeTelemetryService) ParticipantResumedCallCount() int { + fake.participantResumedMutex.RLock() + defer fake.participantResumedMutex.RUnlock() + return len(fake.participantResumedArgsForCall) +} + +func (fake *FakeTelemetryService) ParticipantResumedCalls(stub func(context.Context, *livekit.Room, *livekit.ParticipantInfo, livekit.NodeID, livekit.ReconnectReason)) { + fake.participantResumedMutex.Lock() + defer fake.participantResumedMutex.Unlock() + fake.ParticipantResumedStub = stub +} + +func (fake *FakeTelemetryService) ParticipantResumedArgsForCall(i int) (context.Context, *livekit.Room, *livekit.ParticipantInfo, livekit.NodeID, livekit.ReconnectReason) { + fake.participantResumedMutex.RLock() + defer fake.participantResumedMutex.RUnlock() + argsForCall := fake.participantResumedArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3, argsForCall.arg4, argsForCall.arg5 +} + +func (fake *FakeTelemetryService) Report(arg1 context.Context, arg2 *livekit.ReportInfo) { + fake.reportMutex.Lock() + fake.reportArgsForCall = append(fake.reportArgsForCall, struct { + arg1 context.Context + arg2 *livekit.ReportInfo + }{arg1, arg2}) + stub := fake.ReportStub + fake.recordInvocation("Report", []interface{}{arg1, arg2}) + fake.reportMutex.Unlock() + if stub != nil { + fake.ReportStub(arg1, arg2) + } +} + +func (fake *FakeTelemetryService) ReportCallCount() int { + fake.reportMutex.RLock() + defer fake.reportMutex.RUnlock() + return len(fake.reportArgsForCall) +} + +func (fake *FakeTelemetryService) ReportCalls(stub func(context.Context, *livekit.ReportInfo)) { + fake.reportMutex.Lock() + defer fake.reportMutex.Unlock() + fake.ReportStub = stub +} + +func (fake *FakeTelemetryService) ReportArgsForCall(i int) (context.Context, *livekit.ReportInfo) { + fake.reportMutex.RLock() + defer fake.reportMutex.RUnlock() + argsForCall := fake.reportArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeTelemetryService) RoomEnded(arg1 context.Context, arg2 *livekit.Room) { + fake.roomEndedMutex.Lock() + fake.roomEndedArgsForCall = append(fake.roomEndedArgsForCall, struct { + arg1 context.Context + arg2 *livekit.Room + }{arg1, arg2}) + stub := fake.RoomEndedStub + fake.recordInvocation("RoomEnded", []interface{}{arg1, arg2}) + fake.roomEndedMutex.Unlock() + if stub != nil { + fake.RoomEndedStub(arg1, arg2) + } +} + +func (fake *FakeTelemetryService) RoomEndedCallCount() int { + fake.roomEndedMutex.RLock() + defer fake.roomEndedMutex.RUnlock() + return len(fake.roomEndedArgsForCall) +} + +func (fake *FakeTelemetryService) RoomEndedCalls(stub func(context.Context, *livekit.Room)) { + fake.roomEndedMutex.Lock() + defer fake.roomEndedMutex.Unlock() + fake.RoomEndedStub = stub +} + +func (fake *FakeTelemetryService) RoomEndedArgsForCall(i int) (context.Context, *livekit.Room) { + fake.roomEndedMutex.RLock() + defer fake.roomEndedMutex.RUnlock() + argsForCall := fake.roomEndedArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeTelemetryService) RoomProjectReporter(arg1 context.Context) roomobs.ProjectReporter { + fake.roomProjectReporterMutex.Lock() + ret, specificReturn := fake.roomProjectReporterReturnsOnCall[len(fake.roomProjectReporterArgsForCall)] + fake.roomProjectReporterArgsForCall = append(fake.roomProjectReporterArgsForCall, struct { + arg1 context.Context + }{arg1}) + stub := fake.RoomProjectReporterStub + fakeReturns := fake.roomProjectReporterReturns + fake.recordInvocation("RoomProjectReporter", []interface{}{arg1}) + fake.roomProjectReporterMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeTelemetryService) RoomProjectReporterCallCount() int { + fake.roomProjectReporterMutex.RLock() + defer fake.roomProjectReporterMutex.RUnlock() + return len(fake.roomProjectReporterArgsForCall) +} + +func (fake *FakeTelemetryService) RoomProjectReporterCalls(stub func(context.Context) roomobs.ProjectReporter) { + fake.roomProjectReporterMutex.Lock() + defer fake.roomProjectReporterMutex.Unlock() + fake.RoomProjectReporterStub = stub +} + +func (fake *FakeTelemetryService) RoomProjectReporterArgsForCall(i int) context.Context { + fake.roomProjectReporterMutex.RLock() + defer fake.roomProjectReporterMutex.RUnlock() + argsForCall := fake.roomProjectReporterArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeTelemetryService) RoomProjectReporterReturns(result1 roomobs.ProjectReporter) { + fake.roomProjectReporterMutex.Lock() + defer fake.roomProjectReporterMutex.Unlock() + fake.RoomProjectReporterStub = nil + fake.roomProjectReporterReturns = struct { + result1 roomobs.ProjectReporter + }{result1} +} + +func (fake *FakeTelemetryService) RoomProjectReporterReturnsOnCall(i int, result1 roomobs.ProjectReporter) { + fake.roomProjectReporterMutex.Lock() + defer fake.roomProjectReporterMutex.Unlock() + fake.RoomProjectReporterStub = nil + if fake.roomProjectReporterReturnsOnCall == nil { + fake.roomProjectReporterReturnsOnCall = make(map[int]struct { + result1 roomobs.ProjectReporter + }) + } + fake.roomProjectReporterReturnsOnCall[i] = struct { + result1 roomobs.ProjectReporter + }{result1} +} + +func (fake *FakeTelemetryService) RoomStarted(arg1 context.Context, arg2 *livekit.Room) { + fake.roomStartedMutex.Lock() + fake.roomStartedArgsForCall = append(fake.roomStartedArgsForCall, struct { + arg1 context.Context + arg2 *livekit.Room + }{arg1, arg2}) + stub := fake.RoomStartedStub + fake.recordInvocation("RoomStarted", []interface{}{arg1, arg2}) + fake.roomStartedMutex.Unlock() + if stub != nil { + fake.RoomStartedStub(arg1, arg2) + } +} + +func (fake *FakeTelemetryService) RoomStartedCallCount() int { + fake.roomStartedMutex.RLock() + defer fake.roomStartedMutex.RUnlock() + return len(fake.roomStartedArgsForCall) +} + +func (fake *FakeTelemetryService) RoomStartedCalls(stub func(context.Context, *livekit.Room)) { + fake.roomStartedMutex.Lock() + defer fake.roomStartedMutex.Unlock() + fake.RoomStartedStub = stub +} + +func (fake *FakeTelemetryService) RoomStartedArgsForCall(i int) (context.Context, *livekit.Room) { + fake.roomStartedMutex.RLock() + defer fake.roomStartedMutex.RUnlock() + argsForCall := fake.roomStartedArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeTelemetryService) SendEvent(arg1 context.Context, arg2 *livekit.AnalyticsEvent) { + fake.sendEventMutex.Lock() + fake.sendEventArgsForCall = append(fake.sendEventArgsForCall, struct { + arg1 context.Context + arg2 *livekit.AnalyticsEvent + }{arg1, arg2}) + stub := fake.SendEventStub + fake.recordInvocation("SendEvent", []interface{}{arg1, arg2}) + fake.sendEventMutex.Unlock() + if stub != nil { + fake.SendEventStub(arg1, arg2) + } +} + +func (fake *FakeTelemetryService) SendEventCallCount() int { + fake.sendEventMutex.RLock() + defer fake.sendEventMutex.RUnlock() + return len(fake.sendEventArgsForCall) +} + +func (fake *FakeTelemetryService) SendEventCalls(stub func(context.Context, *livekit.AnalyticsEvent)) { + fake.sendEventMutex.Lock() + defer fake.sendEventMutex.Unlock() + fake.SendEventStub = stub +} + +func (fake *FakeTelemetryService) SendEventArgsForCall(i int) (context.Context, *livekit.AnalyticsEvent) { + fake.sendEventMutex.RLock() + defer fake.sendEventMutex.RUnlock() + argsForCall := fake.sendEventArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeTelemetryService) SendNodeRoomStates(arg1 context.Context, arg2 *livekit.AnalyticsNodeRooms) { + fake.sendNodeRoomStatesMutex.Lock() + fake.sendNodeRoomStatesArgsForCall = append(fake.sendNodeRoomStatesArgsForCall, struct { + arg1 context.Context + arg2 *livekit.AnalyticsNodeRooms + }{arg1, arg2}) + stub := fake.SendNodeRoomStatesStub + fake.recordInvocation("SendNodeRoomStates", []interface{}{arg1, arg2}) + fake.sendNodeRoomStatesMutex.Unlock() + if stub != nil { + fake.SendNodeRoomStatesStub(arg1, arg2) + } +} + +func (fake *FakeTelemetryService) SendNodeRoomStatesCallCount() int { + fake.sendNodeRoomStatesMutex.RLock() + defer fake.sendNodeRoomStatesMutex.RUnlock() + return len(fake.sendNodeRoomStatesArgsForCall) +} + +func (fake *FakeTelemetryService) SendNodeRoomStatesCalls(stub func(context.Context, *livekit.AnalyticsNodeRooms)) { + fake.sendNodeRoomStatesMutex.Lock() + defer fake.sendNodeRoomStatesMutex.Unlock() + fake.SendNodeRoomStatesStub = stub +} + +func (fake *FakeTelemetryService) SendNodeRoomStatesArgsForCall(i int) (context.Context, *livekit.AnalyticsNodeRooms) { + fake.sendNodeRoomStatesMutex.RLock() + defer fake.sendNodeRoomStatesMutex.RUnlock() + argsForCall := fake.sendNodeRoomStatesArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeTelemetryService) SendStats(arg1 context.Context, arg2 []*livekit.AnalyticsStat) { + var arg2Copy []*livekit.AnalyticsStat + if arg2 != nil { + arg2Copy = make([]*livekit.AnalyticsStat, len(arg2)) + copy(arg2Copy, arg2) + } + fake.sendStatsMutex.Lock() + fake.sendStatsArgsForCall = append(fake.sendStatsArgsForCall, struct { + arg1 context.Context + arg2 []*livekit.AnalyticsStat + }{arg1, arg2Copy}) + stub := fake.SendStatsStub + fake.recordInvocation("SendStats", []interface{}{arg1, arg2Copy}) + fake.sendStatsMutex.Unlock() + if stub != nil { + fake.SendStatsStub(arg1, arg2) + } +} + +func (fake *FakeTelemetryService) SendStatsCallCount() int { + fake.sendStatsMutex.RLock() + defer fake.sendStatsMutex.RUnlock() + return len(fake.sendStatsArgsForCall) +} + +func (fake *FakeTelemetryService) SendStatsCalls(stub func(context.Context, []*livekit.AnalyticsStat)) { + fake.sendStatsMutex.Lock() + defer fake.sendStatsMutex.Unlock() + fake.SendStatsStub = stub +} + +func (fake *FakeTelemetryService) SendStatsArgsForCall(i int) (context.Context, []*livekit.AnalyticsStat) { + fake.sendStatsMutex.RLock() + defer fake.sendStatsMutex.RUnlock() + argsForCall := fake.sendStatsArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeTelemetryService) TrackMaxSubscribedVideoQuality(arg1 context.Context, arg2 livekit.ParticipantID, arg3 *livekit.TrackInfo, arg4 mime.MimeType, arg5 livekit.VideoQuality) { + fake.trackMaxSubscribedVideoQualityMutex.Lock() + fake.trackMaxSubscribedVideoQualityArgsForCall = append(fake.trackMaxSubscribedVideoQualityArgsForCall, struct { + arg1 context.Context + arg2 livekit.ParticipantID + arg3 *livekit.TrackInfo + arg4 mime.MimeType + arg5 livekit.VideoQuality + }{arg1, arg2, arg3, arg4, arg5}) + stub := fake.TrackMaxSubscribedVideoQualityStub + fake.recordInvocation("TrackMaxSubscribedVideoQuality", []interface{}{arg1, arg2, arg3, arg4, arg5}) + fake.trackMaxSubscribedVideoQualityMutex.Unlock() + if stub != nil { + fake.TrackMaxSubscribedVideoQualityStub(arg1, arg2, arg3, arg4, arg5) + } +} + +func (fake *FakeTelemetryService) TrackMaxSubscribedVideoQualityCallCount() int { + fake.trackMaxSubscribedVideoQualityMutex.RLock() + defer fake.trackMaxSubscribedVideoQualityMutex.RUnlock() + return len(fake.trackMaxSubscribedVideoQualityArgsForCall) +} + +func (fake *FakeTelemetryService) TrackMaxSubscribedVideoQualityCalls(stub func(context.Context, livekit.ParticipantID, *livekit.TrackInfo, mime.MimeType, livekit.VideoQuality)) { + fake.trackMaxSubscribedVideoQualityMutex.Lock() + defer fake.trackMaxSubscribedVideoQualityMutex.Unlock() + fake.TrackMaxSubscribedVideoQualityStub = stub +} + +func (fake *FakeTelemetryService) TrackMaxSubscribedVideoQualityArgsForCall(i int) (context.Context, livekit.ParticipantID, *livekit.TrackInfo, mime.MimeType, livekit.VideoQuality) { + fake.trackMaxSubscribedVideoQualityMutex.RLock() + defer fake.trackMaxSubscribedVideoQualityMutex.RUnlock() + argsForCall := fake.trackMaxSubscribedVideoQualityArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3, argsForCall.arg4, argsForCall.arg5 +} + +func (fake *FakeTelemetryService) TrackMuted(arg1 context.Context, arg2 livekit.ParticipantID, arg3 *livekit.TrackInfo) { + fake.trackMutedMutex.Lock() + fake.trackMutedArgsForCall = append(fake.trackMutedArgsForCall, struct { + arg1 context.Context + arg2 livekit.ParticipantID + arg3 *livekit.TrackInfo + }{arg1, arg2, arg3}) + stub := fake.TrackMutedStub + fake.recordInvocation("TrackMuted", []interface{}{arg1, arg2, arg3}) + fake.trackMutedMutex.Unlock() + if stub != nil { + fake.TrackMutedStub(arg1, arg2, arg3) + } +} + +func (fake *FakeTelemetryService) TrackMutedCallCount() int { + fake.trackMutedMutex.RLock() + defer fake.trackMutedMutex.RUnlock() + return len(fake.trackMutedArgsForCall) +} + +func (fake *FakeTelemetryService) TrackMutedCalls(stub func(context.Context, livekit.ParticipantID, *livekit.TrackInfo)) { + fake.trackMutedMutex.Lock() + defer fake.trackMutedMutex.Unlock() + fake.TrackMutedStub = stub +} + +func (fake *FakeTelemetryService) TrackMutedArgsForCall(i int) (context.Context, livekit.ParticipantID, *livekit.TrackInfo) { + fake.trackMutedMutex.RLock() + defer fake.trackMutedMutex.RUnlock() + argsForCall := fake.trackMutedArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeTelemetryService) TrackPublishRTPStats(arg1 context.Context, arg2 livekit.ParticipantID, arg3 livekit.TrackID, arg4 mime.MimeType, arg5 int, arg6 *livekit.RTPStats) { + fake.trackPublishRTPStatsMutex.Lock() + fake.trackPublishRTPStatsArgsForCall = append(fake.trackPublishRTPStatsArgsForCall, struct { + arg1 context.Context + arg2 livekit.ParticipantID + arg3 livekit.TrackID + arg4 mime.MimeType + arg5 int + arg6 *livekit.RTPStats + }{arg1, arg2, arg3, arg4, arg5, arg6}) + stub := fake.TrackPublishRTPStatsStub + fake.recordInvocation("TrackPublishRTPStats", []interface{}{arg1, arg2, arg3, arg4, arg5, arg6}) + fake.trackPublishRTPStatsMutex.Unlock() + if stub != nil { + fake.TrackPublishRTPStatsStub(arg1, arg2, arg3, arg4, arg5, arg6) + } +} + +func (fake *FakeTelemetryService) TrackPublishRTPStatsCallCount() int { + fake.trackPublishRTPStatsMutex.RLock() + defer fake.trackPublishRTPStatsMutex.RUnlock() + return len(fake.trackPublishRTPStatsArgsForCall) +} + +func (fake *FakeTelemetryService) TrackPublishRTPStatsCalls(stub func(context.Context, livekit.ParticipantID, livekit.TrackID, mime.MimeType, int, *livekit.RTPStats)) { + fake.trackPublishRTPStatsMutex.Lock() + defer fake.trackPublishRTPStatsMutex.Unlock() + fake.TrackPublishRTPStatsStub = stub +} + +func (fake *FakeTelemetryService) TrackPublishRTPStatsArgsForCall(i int) (context.Context, livekit.ParticipantID, livekit.TrackID, mime.MimeType, int, *livekit.RTPStats) { + fake.trackPublishRTPStatsMutex.RLock() + defer fake.trackPublishRTPStatsMutex.RUnlock() + argsForCall := fake.trackPublishRTPStatsArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3, argsForCall.arg4, argsForCall.arg5, argsForCall.arg6 +} + +func (fake *FakeTelemetryService) TrackPublishRequested(arg1 context.Context, arg2 livekit.ParticipantID, arg3 livekit.ParticipantIdentity, arg4 *livekit.TrackInfo) { + fake.trackPublishRequestedMutex.Lock() + fake.trackPublishRequestedArgsForCall = append(fake.trackPublishRequestedArgsForCall, struct { + arg1 context.Context + arg2 livekit.ParticipantID + arg3 livekit.ParticipantIdentity + arg4 *livekit.TrackInfo + }{arg1, arg2, arg3, arg4}) + stub := fake.TrackPublishRequestedStub + fake.recordInvocation("TrackPublishRequested", []interface{}{arg1, arg2, arg3, arg4}) + fake.trackPublishRequestedMutex.Unlock() + if stub != nil { + fake.TrackPublishRequestedStub(arg1, arg2, arg3, arg4) + } +} + +func (fake *FakeTelemetryService) TrackPublishRequestedCallCount() int { + fake.trackPublishRequestedMutex.RLock() + defer fake.trackPublishRequestedMutex.RUnlock() + return len(fake.trackPublishRequestedArgsForCall) +} + +func (fake *FakeTelemetryService) TrackPublishRequestedCalls(stub func(context.Context, livekit.ParticipantID, livekit.ParticipantIdentity, *livekit.TrackInfo)) { + fake.trackPublishRequestedMutex.Lock() + defer fake.trackPublishRequestedMutex.Unlock() + fake.TrackPublishRequestedStub = stub +} + +func (fake *FakeTelemetryService) TrackPublishRequestedArgsForCall(i int) (context.Context, livekit.ParticipantID, livekit.ParticipantIdentity, *livekit.TrackInfo) { + fake.trackPublishRequestedMutex.RLock() + defer fake.trackPublishRequestedMutex.RUnlock() + argsForCall := fake.trackPublishRequestedArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3, argsForCall.arg4 +} + +func (fake *FakeTelemetryService) TrackPublished(arg1 context.Context, arg2 livekit.ParticipantID, arg3 livekit.ParticipantIdentity, arg4 *livekit.TrackInfo, arg5 bool) { + fake.trackPublishedMutex.Lock() + fake.trackPublishedArgsForCall = append(fake.trackPublishedArgsForCall, struct { + arg1 context.Context + arg2 livekit.ParticipantID + arg3 livekit.ParticipantIdentity + arg4 *livekit.TrackInfo + arg5 bool + }{arg1, arg2, arg3, arg4, arg5}) + stub := fake.TrackPublishedStub + fake.recordInvocation("TrackPublished", []interface{}{arg1, arg2, arg3, arg4, arg5}) + fake.trackPublishedMutex.Unlock() + if stub != nil { + fake.TrackPublishedStub(arg1, arg2, arg3, arg4, arg5) + } +} + +func (fake *FakeTelemetryService) TrackPublishedCallCount() int { + fake.trackPublishedMutex.RLock() + defer fake.trackPublishedMutex.RUnlock() + return len(fake.trackPublishedArgsForCall) +} + +func (fake *FakeTelemetryService) TrackPublishedCalls(stub func(context.Context, livekit.ParticipantID, livekit.ParticipantIdentity, *livekit.TrackInfo, bool)) { + fake.trackPublishedMutex.Lock() + defer fake.trackPublishedMutex.Unlock() + fake.TrackPublishedStub = stub +} + +func (fake *FakeTelemetryService) TrackPublishedArgsForCall(i int) (context.Context, livekit.ParticipantID, livekit.ParticipantIdentity, *livekit.TrackInfo, bool) { + fake.trackPublishedMutex.RLock() + defer fake.trackPublishedMutex.RUnlock() + argsForCall := fake.trackPublishedArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3, argsForCall.arg4, argsForCall.arg5 +} + +func (fake *FakeTelemetryService) TrackPublishedUpdate(arg1 context.Context, arg2 livekit.ParticipantID, arg3 *livekit.TrackInfo) { + fake.trackPublishedUpdateMutex.Lock() + fake.trackPublishedUpdateArgsForCall = append(fake.trackPublishedUpdateArgsForCall, struct { + arg1 context.Context + arg2 livekit.ParticipantID + arg3 *livekit.TrackInfo + }{arg1, arg2, arg3}) + stub := fake.TrackPublishedUpdateStub + fake.recordInvocation("TrackPublishedUpdate", []interface{}{arg1, arg2, arg3}) + fake.trackPublishedUpdateMutex.Unlock() + if stub != nil { + fake.TrackPublishedUpdateStub(arg1, arg2, arg3) + } +} + +func (fake *FakeTelemetryService) TrackPublishedUpdateCallCount() int { + fake.trackPublishedUpdateMutex.RLock() + defer fake.trackPublishedUpdateMutex.RUnlock() + return len(fake.trackPublishedUpdateArgsForCall) +} + +func (fake *FakeTelemetryService) TrackPublishedUpdateCalls(stub func(context.Context, livekit.ParticipantID, *livekit.TrackInfo)) { + fake.trackPublishedUpdateMutex.Lock() + defer fake.trackPublishedUpdateMutex.Unlock() + fake.TrackPublishedUpdateStub = stub +} + +func (fake *FakeTelemetryService) TrackPublishedUpdateArgsForCall(i int) (context.Context, livekit.ParticipantID, *livekit.TrackInfo) { + fake.trackPublishedUpdateMutex.RLock() + defer fake.trackPublishedUpdateMutex.RUnlock() + argsForCall := fake.trackPublishedUpdateArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeTelemetryService) TrackStats(arg1 telemetry.StatsKey, arg2 *livekit.AnalyticsStat) { + fake.trackStatsMutex.Lock() + fake.trackStatsArgsForCall = append(fake.trackStatsArgsForCall, struct { + arg1 telemetry.StatsKey + arg2 *livekit.AnalyticsStat + }{arg1, arg2}) + stub := fake.TrackStatsStub + fake.recordInvocation("TrackStats", []interface{}{arg1, arg2}) + fake.trackStatsMutex.Unlock() + if stub != nil { + fake.TrackStatsStub(arg1, arg2) + } +} + +func (fake *FakeTelemetryService) TrackStatsCallCount() int { + fake.trackStatsMutex.RLock() + defer fake.trackStatsMutex.RUnlock() + return len(fake.trackStatsArgsForCall) +} + +func (fake *FakeTelemetryService) TrackStatsCalls(stub func(telemetry.StatsKey, *livekit.AnalyticsStat)) { + fake.trackStatsMutex.Lock() + defer fake.trackStatsMutex.Unlock() + fake.TrackStatsStub = stub +} + +func (fake *FakeTelemetryService) TrackStatsArgsForCall(i int) (telemetry.StatsKey, *livekit.AnalyticsStat) { + fake.trackStatsMutex.RLock() + defer fake.trackStatsMutex.RUnlock() + argsForCall := fake.trackStatsArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeTelemetryService) TrackSubscribeFailed(arg1 context.Context, arg2 livekit.ParticipantID, arg3 livekit.TrackID, arg4 error, arg5 bool) { + fake.trackSubscribeFailedMutex.Lock() + fake.trackSubscribeFailedArgsForCall = append(fake.trackSubscribeFailedArgsForCall, struct { + arg1 context.Context + arg2 livekit.ParticipantID + arg3 livekit.TrackID + arg4 error + arg5 bool + }{arg1, arg2, arg3, arg4, arg5}) + stub := fake.TrackSubscribeFailedStub + fake.recordInvocation("TrackSubscribeFailed", []interface{}{arg1, arg2, arg3, arg4, arg5}) + fake.trackSubscribeFailedMutex.Unlock() + if stub != nil { + fake.TrackSubscribeFailedStub(arg1, arg2, arg3, arg4, arg5) + } +} + +func (fake *FakeTelemetryService) TrackSubscribeFailedCallCount() int { + fake.trackSubscribeFailedMutex.RLock() + defer fake.trackSubscribeFailedMutex.RUnlock() + return len(fake.trackSubscribeFailedArgsForCall) +} + +func (fake *FakeTelemetryService) TrackSubscribeFailedCalls(stub func(context.Context, livekit.ParticipantID, livekit.TrackID, error, bool)) { + fake.trackSubscribeFailedMutex.Lock() + defer fake.trackSubscribeFailedMutex.Unlock() + fake.TrackSubscribeFailedStub = stub +} + +func (fake *FakeTelemetryService) TrackSubscribeFailedArgsForCall(i int) (context.Context, livekit.ParticipantID, livekit.TrackID, error, bool) { + fake.trackSubscribeFailedMutex.RLock() + defer fake.trackSubscribeFailedMutex.RUnlock() + argsForCall := fake.trackSubscribeFailedArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3, argsForCall.arg4, argsForCall.arg5 +} + +func (fake *FakeTelemetryService) TrackSubscribeRTPStats(arg1 context.Context, arg2 livekit.ParticipantID, arg3 livekit.TrackID, arg4 mime.MimeType, arg5 *livekit.RTPStats) { + fake.trackSubscribeRTPStatsMutex.Lock() + fake.trackSubscribeRTPStatsArgsForCall = append(fake.trackSubscribeRTPStatsArgsForCall, struct { + arg1 context.Context + arg2 livekit.ParticipantID + arg3 livekit.TrackID + arg4 mime.MimeType + arg5 *livekit.RTPStats + }{arg1, arg2, arg3, arg4, arg5}) + stub := fake.TrackSubscribeRTPStatsStub + fake.recordInvocation("TrackSubscribeRTPStats", []interface{}{arg1, arg2, arg3, arg4, arg5}) + fake.trackSubscribeRTPStatsMutex.Unlock() + if stub != nil { + fake.TrackSubscribeRTPStatsStub(arg1, arg2, arg3, arg4, arg5) + } +} + +func (fake *FakeTelemetryService) TrackSubscribeRTPStatsCallCount() int { + fake.trackSubscribeRTPStatsMutex.RLock() + defer fake.trackSubscribeRTPStatsMutex.RUnlock() + return len(fake.trackSubscribeRTPStatsArgsForCall) +} + +func (fake *FakeTelemetryService) TrackSubscribeRTPStatsCalls(stub func(context.Context, livekit.ParticipantID, livekit.TrackID, mime.MimeType, *livekit.RTPStats)) { + fake.trackSubscribeRTPStatsMutex.Lock() + defer fake.trackSubscribeRTPStatsMutex.Unlock() + fake.TrackSubscribeRTPStatsStub = stub +} + +func (fake *FakeTelemetryService) TrackSubscribeRTPStatsArgsForCall(i int) (context.Context, livekit.ParticipantID, livekit.TrackID, mime.MimeType, *livekit.RTPStats) { + fake.trackSubscribeRTPStatsMutex.RLock() + defer fake.trackSubscribeRTPStatsMutex.RUnlock() + argsForCall := fake.trackSubscribeRTPStatsArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3, argsForCall.arg4, argsForCall.arg5 +} + +func (fake *FakeTelemetryService) TrackSubscribeRequested(arg1 context.Context, arg2 livekit.ParticipantID, arg3 *livekit.TrackInfo) { + fake.trackSubscribeRequestedMutex.Lock() + fake.trackSubscribeRequestedArgsForCall = append(fake.trackSubscribeRequestedArgsForCall, struct { + arg1 context.Context + arg2 livekit.ParticipantID + arg3 *livekit.TrackInfo + }{arg1, arg2, arg3}) + stub := fake.TrackSubscribeRequestedStub + fake.recordInvocation("TrackSubscribeRequested", []interface{}{arg1, arg2, arg3}) + fake.trackSubscribeRequestedMutex.Unlock() + if stub != nil { + fake.TrackSubscribeRequestedStub(arg1, arg2, arg3) + } +} + +func (fake *FakeTelemetryService) TrackSubscribeRequestedCallCount() int { + fake.trackSubscribeRequestedMutex.RLock() + defer fake.trackSubscribeRequestedMutex.RUnlock() + return len(fake.trackSubscribeRequestedArgsForCall) +} + +func (fake *FakeTelemetryService) TrackSubscribeRequestedCalls(stub func(context.Context, livekit.ParticipantID, *livekit.TrackInfo)) { + fake.trackSubscribeRequestedMutex.Lock() + defer fake.trackSubscribeRequestedMutex.Unlock() + fake.TrackSubscribeRequestedStub = stub +} + +func (fake *FakeTelemetryService) TrackSubscribeRequestedArgsForCall(i int) (context.Context, livekit.ParticipantID, *livekit.TrackInfo) { + fake.trackSubscribeRequestedMutex.RLock() + defer fake.trackSubscribeRequestedMutex.RUnlock() + argsForCall := fake.trackSubscribeRequestedArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeTelemetryService) TrackSubscribed(arg1 context.Context, arg2 livekit.ParticipantID, arg3 *livekit.TrackInfo, arg4 *livekit.ParticipantInfo, arg5 bool) { + fake.trackSubscribedMutex.Lock() + fake.trackSubscribedArgsForCall = append(fake.trackSubscribedArgsForCall, struct { + arg1 context.Context + arg2 livekit.ParticipantID + arg3 *livekit.TrackInfo + arg4 *livekit.ParticipantInfo + arg5 bool + }{arg1, arg2, arg3, arg4, arg5}) + stub := fake.TrackSubscribedStub + fake.recordInvocation("TrackSubscribed", []interface{}{arg1, arg2, arg3, arg4, arg5}) + fake.trackSubscribedMutex.Unlock() + if stub != nil { + fake.TrackSubscribedStub(arg1, arg2, arg3, arg4, arg5) + } +} + +func (fake *FakeTelemetryService) TrackSubscribedCallCount() int { + fake.trackSubscribedMutex.RLock() + defer fake.trackSubscribedMutex.RUnlock() + return len(fake.trackSubscribedArgsForCall) +} + +func (fake *FakeTelemetryService) TrackSubscribedCalls(stub func(context.Context, livekit.ParticipantID, *livekit.TrackInfo, *livekit.ParticipantInfo, bool)) { + fake.trackSubscribedMutex.Lock() + defer fake.trackSubscribedMutex.Unlock() + fake.TrackSubscribedStub = stub +} + +func (fake *FakeTelemetryService) TrackSubscribedArgsForCall(i int) (context.Context, livekit.ParticipantID, *livekit.TrackInfo, *livekit.ParticipantInfo, bool) { + fake.trackSubscribedMutex.RLock() + defer fake.trackSubscribedMutex.RUnlock() + argsForCall := fake.trackSubscribedArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3, argsForCall.arg4, argsForCall.arg5 +} + +func (fake *FakeTelemetryService) TrackUnmuted(arg1 context.Context, arg2 livekit.ParticipantID, arg3 *livekit.TrackInfo) { + fake.trackUnmutedMutex.Lock() + fake.trackUnmutedArgsForCall = append(fake.trackUnmutedArgsForCall, struct { + arg1 context.Context + arg2 livekit.ParticipantID + arg3 *livekit.TrackInfo + }{arg1, arg2, arg3}) + stub := fake.TrackUnmutedStub + fake.recordInvocation("TrackUnmuted", []interface{}{arg1, arg2, arg3}) + fake.trackUnmutedMutex.Unlock() + if stub != nil { + fake.TrackUnmutedStub(arg1, arg2, arg3) + } +} + +func (fake *FakeTelemetryService) TrackUnmutedCallCount() int { + fake.trackUnmutedMutex.RLock() + defer fake.trackUnmutedMutex.RUnlock() + return len(fake.trackUnmutedArgsForCall) +} + +func (fake *FakeTelemetryService) TrackUnmutedCalls(stub func(context.Context, livekit.ParticipantID, *livekit.TrackInfo)) { + fake.trackUnmutedMutex.Lock() + defer fake.trackUnmutedMutex.Unlock() + fake.TrackUnmutedStub = stub +} + +func (fake *FakeTelemetryService) TrackUnmutedArgsForCall(i int) (context.Context, livekit.ParticipantID, *livekit.TrackInfo) { + fake.trackUnmutedMutex.RLock() + defer fake.trackUnmutedMutex.RUnlock() + argsForCall := fake.trackUnmutedArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeTelemetryService) TrackUnpublished(arg1 context.Context, arg2 livekit.ParticipantID, arg3 livekit.ParticipantIdentity, arg4 *livekit.TrackInfo, arg5 bool) { + fake.trackUnpublishedMutex.Lock() + fake.trackUnpublishedArgsForCall = append(fake.trackUnpublishedArgsForCall, struct { + arg1 context.Context + arg2 livekit.ParticipantID + arg3 livekit.ParticipantIdentity + arg4 *livekit.TrackInfo + arg5 bool + }{arg1, arg2, arg3, arg4, arg5}) + stub := fake.TrackUnpublishedStub + fake.recordInvocation("TrackUnpublished", []interface{}{arg1, arg2, arg3, arg4, arg5}) + fake.trackUnpublishedMutex.Unlock() + if stub != nil { + fake.TrackUnpublishedStub(arg1, arg2, arg3, arg4, arg5) + } +} + +func (fake *FakeTelemetryService) TrackUnpublishedCallCount() int { + fake.trackUnpublishedMutex.RLock() + defer fake.trackUnpublishedMutex.RUnlock() + return len(fake.trackUnpublishedArgsForCall) +} + +func (fake *FakeTelemetryService) TrackUnpublishedCalls(stub func(context.Context, livekit.ParticipantID, livekit.ParticipantIdentity, *livekit.TrackInfo, bool)) { + fake.trackUnpublishedMutex.Lock() + defer fake.trackUnpublishedMutex.Unlock() + fake.TrackUnpublishedStub = stub +} + +func (fake *FakeTelemetryService) TrackUnpublishedArgsForCall(i int) (context.Context, livekit.ParticipantID, livekit.ParticipantIdentity, *livekit.TrackInfo, bool) { + fake.trackUnpublishedMutex.RLock() + defer fake.trackUnpublishedMutex.RUnlock() + argsForCall := fake.trackUnpublishedArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3, argsForCall.arg4, argsForCall.arg5 +} + +func (fake *FakeTelemetryService) TrackUnsubscribed(arg1 context.Context, arg2 livekit.ParticipantID, arg3 *livekit.TrackInfo, arg4 bool) { + fake.trackUnsubscribedMutex.Lock() + fake.trackUnsubscribedArgsForCall = append(fake.trackUnsubscribedArgsForCall, struct { + arg1 context.Context + arg2 livekit.ParticipantID + arg3 *livekit.TrackInfo + arg4 bool + }{arg1, arg2, arg3, arg4}) + stub := fake.TrackUnsubscribedStub + fake.recordInvocation("TrackUnsubscribed", []interface{}{arg1, arg2, arg3, arg4}) + fake.trackUnsubscribedMutex.Unlock() + if stub != nil { + fake.TrackUnsubscribedStub(arg1, arg2, arg3, arg4) + } +} + +func (fake *FakeTelemetryService) TrackUnsubscribedCallCount() int { + fake.trackUnsubscribedMutex.RLock() + defer fake.trackUnsubscribedMutex.RUnlock() + return len(fake.trackUnsubscribedArgsForCall) +} + +func (fake *FakeTelemetryService) TrackUnsubscribedCalls(stub func(context.Context, livekit.ParticipantID, *livekit.TrackInfo, bool)) { + fake.trackUnsubscribedMutex.Lock() + defer fake.trackUnsubscribedMutex.Unlock() + fake.TrackUnsubscribedStub = stub +} + +func (fake *FakeTelemetryService) TrackUnsubscribedArgsForCall(i int) (context.Context, livekit.ParticipantID, *livekit.TrackInfo, bool) { + fake.trackUnsubscribedMutex.RLock() + defer fake.trackUnsubscribedMutex.RUnlock() + argsForCall := fake.trackUnsubscribedArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3, argsForCall.arg4 +} + +func (fake *FakeTelemetryService) Webhook(arg1 context.Context, arg2 *livekit.WebhookInfo) { + fake.webhookMutex.Lock() + fake.webhookArgsForCall = append(fake.webhookArgsForCall, struct { + arg1 context.Context + arg2 *livekit.WebhookInfo + }{arg1, arg2}) + stub := fake.WebhookStub + fake.recordInvocation("Webhook", []interface{}{arg1, arg2}) + fake.webhookMutex.Unlock() + if stub != nil { + fake.WebhookStub(arg1, arg2) + } +} + +func (fake *FakeTelemetryService) WebhookCallCount() int { + fake.webhookMutex.RLock() + defer fake.webhookMutex.RUnlock() + return len(fake.webhookArgsForCall) +} + +func (fake *FakeTelemetryService) WebhookCalls(stub func(context.Context, *livekit.WebhookInfo)) { + fake.webhookMutex.Lock() + defer fake.webhookMutex.Unlock() + fake.WebhookStub = stub +} + +func (fake *FakeTelemetryService) WebhookArgsForCall(i int) (context.Context, *livekit.WebhookInfo) { + fake.webhookMutex.RLock() + defer fake.webhookMutex.RUnlock() + argsForCall := fake.webhookArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeTelemetryService) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeTelemetryService) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ telemetry.TelemetryService = new(FakeTelemetryService) diff --git a/livekit/pkg/telemetry/telemetryservice.go b/livekit/pkg/telemetry/telemetryservice.go new file mode 100644 index 0000000..cd48cd3 --- /dev/null +++ b/livekit/pkg/telemetry/telemetryservice.go @@ -0,0 +1,310 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package telemetry + +import ( + "context" + "sync" + "time" + + "github.com/livekit/livekit-server/pkg/sfu/mime" + "github.com/livekit/livekit-server/pkg/utils" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/webhook" +) + +//go:generate go run github.com/maxbrunsfeld/counterfeiter/v6 -generate + +//counterfeiter:generate . TelemetryService +type TelemetryService interface { + // TrackStats is called periodically for each track in both directions (published/subscribed) + TrackStats(key StatsKey, stat *livekit.AnalyticsStat) + + // events + RoomStarted(ctx context.Context, room *livekit.Room) + RoomEnded(ctx context.Context, room *livekit.Room) + // ParticipantJoined - a participant establishes signal connection to a room + ParticipantJoined(ctx context.Context, room *livekit.Room, participant *livekit.ParticipantInfo, clientInfo *livekit.ClientInfo, clientMeta *livekit.AnalyticsClientMeta, shouldSendEvent bool, guard *ReferenceGuard) + // ParticipantActive - a participant establishes media connection + ParticipantActive(ctx context.Context, room *livekit.Room, participant *livekit.ParticipantInfo, clientMeta *livekit.AnalyticsClientMeta, isMigration bool, guard *ReferenceGuard) + // ParticipantResumed - there has been an ICE restart or connection resume attempt, and we've received their signal connection + ParticipantResumed(ctx context.Context, room *livekit.Room, participant *livekit.ParticipantInfo, nodeID livekit.NodeID, reason livekit.ReconnectReason) + // ParticipantLeft - the participant leaves the room, only sent if ParticipantActive has been called before + ParticipantLeft(ctx context.Context, room *livekit.Room, participant *livekit.ParticipantInfo, shouldSendEvent bool, guard *ReferenceGuard) + // TrackPublishRequested - a publication attempt has been received + TrackPublishRequested(ctx context.Context, participantID livekit.ParticipantID, identity livekit.ParticipantIdentity, track *livekit.TrackInfo) + // TrackPublished - a publication attempt has been successful + TrackPublished(ctx context.Context, participantID livekit.ParticipantID, identity livekit.ParticipantIdentity, track *livekit.TrackInfo, shouldSendEvent bool) + // TrackUnpublished - a participant unpublished a track + TrackUnpublished(ctx context.Context, participantID livekit.ParticipantID, identity livekit.ParticipantIdentity, track *livekit.TrackInfo, shouldSendEvent bool) + // TrackSubscribeRequested - a participant requested to subscribe to a track + TrackSubscribeRequested(ctx context.Context, participantID livekit.ParticipantID, track *livekit.TrackInfo) + // TrackSubscribed - a participant subscribed to a track successfully + TrackSubscribed(ctx context.Context, participantID livekit.ParticipantID, track *livekit.TrackInfo, publisher *livekit.ParticipantInfo, shouldSendEvent bool) + // TrackUnsubscribed - a participant unsubscribed from a track successfully + TrackUnsubscribed(ctx context.Context, participantID livekit.ParticipantID, track *livekit.TrackInfo, shouldSendEvent bool) + // TrackSubscribeFailed - failure to subscribe to a track + TrackSubscribeFailed(ctx context.Context, participantID livekit.ParticipantID, trackID livekit.TrackID, err error, isUserError bool) + // TrackMuted - the publisher has muted the Track + TrackMuted(ctx context.Context, participantID livekit.ParticipantID, track *livekit.TrackInfo) + // TrackUnmuted - the publisher has muted the Track + TrackUnmuted(ctx context.Context, participantID livekit.ParticipantID, track *livekit.TrackInfo) + // TrackPublishedUpdate - track metadata has been updated + TrackPublishedUpdate(ctx context.Context, participantID livekit.ParticipantID, track *livekit.TrackInfo) + // TrackMaxSubscribedVideoQuality - publisher is notified of the max quality subscribers desire + TrackMaxSubscribedVideoQuality(ctx context.Context, participantID livekit.ParticipantID, track *livekit.TrackInfo, mime mime.MimeType, maxQuality livekit.VideoQuality) + TrackPublishRTPStats(ctx context.Context, participantID livekit.ParticipantID, trackID livekit.TrackID, mimeType mime.MimeType, layer int, stats *livekit.RTPStats) + TrackSubscribeRTPStats(ctx context.Context, participantID livekit.ParticipantID, trackID livekit.TrackID, mimeType mime.MimeType, stats *livekit.RTPStats) + EgressStarted(ctx context.Context, info *livekit.EgressInfo) + EgressUpdated(ctx context.Context, info *livekit.EgressInfo) + EgressEnded(ctx context.Context, info *livekit.EgressInfo) + IngressCreated(ctx context.Context, info *livekit.IngressInfo) + IngressDeleted(ctx context.Context, info *livekit.IngressInfo) + IngressStarted(ctx context.Context, info *livekit.IngressInfo) + IngressUpdated(ctx context.Context, info *livekit.IngressInfo) + IngressEnded(ctx context.Context, info *livekit.IngressInfo) + LocalRoomState(ctx context.Context, info *livekit.AnalyticsNodeRooms) + Report(ctx context.Context, reportInfo *livekit.ReportInfo) + APICall(ctx context.Context, apiCallInfo *livekit.APICallInfo) + Webhook(ctx context.Context, webhookInfo *livekit.WebhookInfo) + + // helpers + AnalyticsService + NotifyEgressEvent(ctx context.Context, event string, info *livekit.EgressInfo) + FlushStats() +} + +// ----------------------------- + +var _ TelemetryService = (*NullTelemetryService)(nil) + +type NullTelemetryService struct { + NullAnalyticService +} + +func (n NullTelemetryService) TrackStats(key StatsKey, stat *livekit.AnalyticsStat) {} +func (n NullTelemetryService) RoomStarted(ctx context.Context, room *livekit.Room) {} +func (n NullTelemetryService) RoomEnded(ctx context.Context, room *livekit.Room) {} +func (n NullTelemetryService) ParticipantJoined(ctx context.Context, room *livekit.Room, participant *livekit.ParticipantInfo, clientInfo *livekit.ClientInfo, clientMeta *livekit.AnalyticsClientMeta, shouldSendEvent bool, guard *ReferenceGuard) { +} +func (n NullTelemetryService) ParticipantActive(ctx context.Context, room *livekit.Room, participant *livekit.ParticipantInfo, clientMeta *livekit.AnalyticsClientMeta, isMigration bool, guard *ReferenceGuard) { +} +func (n NullTelemetryService) ParticipantResumed(ctx context.Context, room *livekit.Room, participant *livekit.ParticipantInfo, nodeID livekit.NodeID, reason livekit.ReconnectReason) { +} +func (n NullTelemetryService) ParticipantLeft(ctx context.Context, room *livekit.Room, participant *livekit.ParticipantInfo, shouldSendEvent bool, guard *ReferenceGuard) { +} +func (n NullTelemetryService) TrackPublishRequested(ctx context.Context, participantID livekit.ParticipantID, identity livekit.ParticipantIdentity, track *livekit.TrackInfo) { +} +func (n NullTelemetryService) TrackPublished(ctx context.Context, participantID livekit.ParticipantID, identity livekit.ParticipantIdentity, track *livekit.TrackInfo, shouldSendEvent bool) { +} +func (n NullTelemetryService) TrackUnpublished(ctx context.Context, participantID livekit.ParticipantID, identity livekit.ParticipantIdentity, track *livekit.TrackInfo, shouldSendEvent bool) { +} +func (n NullTelemetryService) TrackSubscribeRequested(ctx context.Context, participantID livekit.ParticipantID, track *livekit.TrackInfo) { +} +func (n NullTelemetryService) TrackSubscribed(ctx context.Context, participantID livekit.ParticipantID, track *livekit.TrackInfo, publisher *livekit.ParticipantInfo, shouldSendEvent bool) { +} +func (n NullTelemetryService) TrackUnsubscribed(ctx context.Context, participantID livekit.ParticipantID, track *livekit.TrackInfo, shouldSendEvent bool) { +} +func (n NullTelemetryService) TrackSubscribeFailed(ctx context.Context, participantID livekit.ParticipantID, trackID livekit.TrackID, err error, isUserError bool) { +} +func (n NullTelemetryService) TrackMuted(ctx context.Context, participantID livekit.ParticipantID, track *livekit.TrackInfo) { +} +func (n NullTelemetryService) TrackUnmuted(ctx context.Context, participantID livekit.ParticipantID, track *livekit.TrackInfo) { +} +func (n NullTelemetryService) TrackPublishedUpdate(ctx context.Context, participantID livekit.ParticipantID, track *livekit.TrackInfo) { +} +func (n NullTelemetryService) TrackMaxSubscribedVideoQuality(ctx context.Context, participantID livekit.ParticipantID, track *livekit.TrackInfo, mime mime.MimeType, maxQuality livekit.VideoQuality) { +} +func (n NullTelemetryService) TrackPublishRTPStats(ctx context.Context, participantID livekit.ParticipantID, trackID livekit.TrackID, mimeType mime.MimeType, layer int, stats *livekit.RTPStats) { +} +func (n NullTelemetryService) TrackSubscribeRTPStats(ctx context.Context, participantID livekit.ParticipantID, trackID livekit.TrackID, mimeType mime.MimeType, stats *livekit.RTPStats) { +} +func (n NullTelemetryService) EgressStarted(ctx context.Context, info *livekit.EgressInfo) {} +func (n NullTelemetryService) EgressUpdated(ctx context.Context, info *livekit.EgressInfo) {} +func (n NullTelemetryService) EgressEnded(ctx context.Context, info *livekit.EgressInfo) {} +func (n NullTelemetryService) IngressCreated(ctx context.Context, info *livekit.IngressInfo) {} +func (n NullTelemetryService) IngressDeleted(ctx context.Context, info *livekit.IngressInfo) {} +func (n NullTelemetryService) IngressStarted(ctx context.Context, info *livekit.IngressInfo) {} +func (n NullTelemetryService) IngressUpdated(ctx context.Context, info *livekit.IngressInfo) {} +func (n NullTelemetryService) IngressEnded(ctx context.Context, info *livekit.IngressInfo) {} +func (n NullTelemetryService) LocalRoomState(ctx context.Context, info *livekit.AnalyticsNodeRooms) {} +func (n NullTelemetryService) Report(ctx context.Context, reportInfo *livekit.ReportInfo) {} +func (n NullTelemetryService) APICall(ctx context.Context, apiCallInfo *livekit.APICallInfo) {} +func (n NullTelemetryService) Webhook(ctx context.Context, webhookInfo *livekit.WebhookInfo) {} +func (n NullTelemetryService) NotifyEgressEvent(ctx context.Context, event string, info *livekit.EgressInfo) { +} +func (n NullTelemetryService) FlushStats() {} + +// ----------------------------- + +const ( + workerCleanupWait = 3 * time.Minute + jobsQueueMinSize = 2048 + + telemetryStatsUpdateInterval = time.Second * 30 +) + +type telemetryService struct { + AnalyticsService + + notifier webhook.QueuedNotifier + jobsQueue *utils.OpsQueue + + workersMu sync.RWMutex + workers map[livekit.ParticipantID]*StatsWorker + workerList *StatsWorker + + flushMu sync.Mutex +} + +func NewTelemetryService(notifier webhook.QueuedNotifier, analytics AnalyticsService) TelemetryService { + t := &telemetryService{ + AnalyticsService: analytics, + notifier: notifier, + jobsQueue: utils.NewOpsQueue(utils.OpsQueueParams{ + Name: "telemetry", + MinSize: jobsQueueMinSize, + FlushOnStop: true, + Logger: logger.GetLogger(), + }), + workers: make(map[livekit.ParticipantID]*StatsWorker), + } + if t.notifier != nil { + t.notifier.RegisterProcessedHook(func(ctx context.Context, whi *livekit.WebhookInfo) { + t.Webhook(ctx, whi) + }) + } + + t.jobsQueue.Start() + go t.run() + + return t +} + +func (t *telemetryService) FlushStats() { + t.flushMu.Lock() + defer t.flushMu.Unlock() + + t.workersMu.RLock() + worker := t.workerList + t.workersMu.RUnlock() + + now := time.Now() + var prev, reap *StatsWorker + for worker != nil { + next := worker.next + if closed := worker.Flush(now, workerCleanupWait); closed { + if prev == nil { + // this worker was at the head of the list + t.workersMu.Lock() + p := &t.workerList + for *p != worker { + // new workers have been added. scan until we find the one + // immediately before this + prev = *p + p = &prev.next + } + *p = worker.next + t.workersMu.Unlock() + } else { + prev.next = worker.next + } + + worker.next = reap + reap = worker + } else { + prev = worker + } + worker = next + } + + if reap != nil { + t.workersMu.Lock() + for reap != nil { + if reap == t.workers[reap.participantID] { + delete(t.workers, reap.participantID) + } + reap = reap.next + } + t.workersMu.Unlock() + } +} + +func (t *telemetryService) run() { + for range time.Tick(telemetryStatsUpdateInterval) { + t.FlushStats() + } +} + +func (t *telemetryService) enqueue(op func()) { + t.jobsQueue.Enqueue(op) +} + +func (t *telemetryService) getWorker(participantID livekit.ParticipantID) (worker *StatsWorker, ok bool) { + t.workersMu.RLock() + defer t.workersMu.RUnlock() + + worker, ok = t.workers[participantID] + return +} + +func (t *telemetryService) getOrCreateWorker( + ctx context.Context, + roomID livekit.RoomID, + roomName livekit.RoomName, + participantID livekit.ParticipantID, + participantIdentity livekit.ParticipantIdentity, + guard *ReferenceGuard, +) (*StatsWorker, bool) { + t.workersMu.Lock() + defer t.workersMu.Unlock() + + worker, ok := t.workers[participantID] + if ok && !worker.Closed(guard) { + return worker, true + } + + existingIsConnected := false + if ok { + existingIsConnected = worker.IsConnected() + } + + worker = newStatsWorker( + ctx, + t, + roomID, + roomName, + participantID, + participantIdentity, + guard, + ) + if existingIsConnected { + worker.SetConnected() + } + + t.workers[participantID] = worker + + worker.next = t.workerList + t.workerList = worker + + return worker, false +} + +func (t *telemetryService) LocalRoomState(ctx context.Context, info *livekit.AnalyticsNodeRooms) { + t.enqueue(func() { + t.SendNodeRoomStates(ctx, info) + }) +} diff --git a/livekit/pkg/testutils/timeout.go b/livekit/pkg/testutils/timeout.go new file mode 100644 index 0000000..7745b2e --- /dev/null +++ b/livekit/pkg/testutils/timeout.go @@ -0,0 +1,48 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testutils + +import ( + "context" + "testing" + "time" +) + +var ( + ConnectTimeout = 30 * time.Second +) + +func WithTimeout(t *testing.T, f func() string, timeouts ...time.Duration) { + timeout := ConnectTimeout + if len(timeouts) > 0 { + timeout = timeouts[0] + } + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + lastErr := "" + for { + select { + case <-ctx.Done(): + if lastErr != "" { + t.Fatalf("did not reach expected state after %v: %s", timeout, lastErr) + } + case <-time.After(10 * time.Millisecond): + lastErr = f() + if lastErr == "" { + return + } + } + } +} diff --git a/livekit/pkg/utils/changenotifier.go b/livekit/pkg/utils/changenotifier.go new file mode 100644 index 0000000..2757e78 --- /dev/null +++ b/livekit/pkg/utils/changenotifier.go @@ -0,0 +1,112 @@ +/* + * Copyright 2023 LiveKit, Inc + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package utils + +import "sync" + +type ChangeNotifier struct { + lock sync.Mutex + observers map[string]func() +} + +func NewChangeNotifier() *ChangeNotifier { + return &ChangeNotifier{ + observers: make(map[string]func()), + } +} + +func (n *ChangeNotifier) AddObserver(key string, onChanged func()) { + n.lock.Lock() + defer n.lock.Unlock() + + n.observers[key] = onChanged +} + +func (n *ChangeNotifier) RemoveObserver(key string) { + n.lock.Lock() + defer n.lock.Unlock() + + delete(n.observers, key) +} + +func (n *ChangeNotifier) HasObservers() bool { + n.lock.Lock() + defer n.lock.Unlock() + + return len(n.observers) > 0 +} + +func (n *ChangeNotifier) NotifyChanged() { + n.lock.Lock() + if len(n.observers) == 0 { + n.lock.Unlock() + return + } + observers := make([]func(), 0, len(n.observers)) + for _, f := range n.observers { + observers = append(observers, f) + } + n.lock.Unlock() + + go func() { + for _, f := range observers { + f() + } + }() +} + +type ChangeNotifierManager struct { + lock sync.Mutex + notifiers map[string]*ChangeNotifier +} + +func NewChangeNotifierManager() *ChangeNotifierManager { + return &ChangeNotifierManager{ + notifiers: make(map[string]*ChangeNotifier), + } +} + +func (m *ChangeNotifierManager) GetNotifier(key string) *ChangeNotifier { + m.lock.Lock() + defer m.lock.Unlock() + + return m.notifiers[key] +} + +func (m *ChangeNotifierManager) GetOrCreateNotifier(key string) *ChangeNotifier { + m.lock.Lock() + defer m.lock.Unlock() + + if notifier, ok := m.notifiers[key]; ok { + return notifier + } + + notifier := NewChangeNotifier() + m.notifiers[key] = notifier + return notifier +} + +func (m *ChangeNotifierManager) RemoveNotifier(key string, force bool) { + m.lock.Lock() + defer m.lock.Unlock() + + if notifier, ok := m.notifiers[key]; ok { + if force || !notifier.HasObservers() { + delete(m.notifiers, key) + } + } +} diff --git a/livekit/pkg/utils/context.go b/livekit/pkg/utils/context.go new file mode 100644 index 0000000..4ad9086 --- /dev/null +++ b/livekit/pkg/utils/context.go @@ -0,0 +1,48 @@ +/* + * Copyright 2023 LiveKit, Inc + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package utils + +import ( + "context" + + "github.com/livekit/protocol/logger" +) + +type attemptKey struct{} +type loggerKey = struct{} + +func ContextWithAttempt(ctx context.Context, attempt int) context.Context { + return context.WithValue(ctx, attemptKey{}, attempt) +} + +func GetAttempt(ctx context.Context) int { + if attempt, ok := ctx.Value(attemptKey{}).(int); ok { + return attempt + } + return 0 +} + +func ContextWithLogger(ctx context.Context, logger logger.Logger) context.Context { + return context.WithValue(ctx, loggerKey{}, logger) +} + +func GetLogger(ctx context.Context) logger.Logger { + if l, ok := ctx.Value(loggerKey{}).(logger.Logger); ok { + return l + } + return logger.GetLogger() +} diff --git a/livekit/pkg/utils/iceconfigcache.go b/livekit/pkg/utils/iceconfigcache.go new file mode 100644 index 0000000..e42ed2b --- /dev/null +++ b/livekit/pkg/utils/iceconfigcache.go @@ -0,0 +1,42 @@ +package utils + +import ( + "time" + + "github.com/jellydator/ttlcache/v3" + + "github.com/livekit/protocol/livekit" +) + +const ( + iceConfigTTLMin = 5 * time.Minute +) + +type IceConfigCache[T comparable] struct { + c *ttlcache.Cache[T, *livekit.ICEConfig] +} + +func NewIceConfigCache[T comparable](ttl time.Duration) *IceConfigCache[T] { + cache := ttlcache.New( + ttlcache.WithTTL[T, *livekit.ICEConfig](max(ttl, iceConfigTTLMin)), + ttlcache.WithDisableTouchOnHit[T, *livekit.ICEConfig](), + ) + go cache.Start() + + return &IceConfigCache[T]{cache} +} + +func (icc *IceConfigCache[T]) Stop() { + icc.c.Stop() +} + +func (icc *IceConfigCache[T]) Put(key T, iceConfig *livekit.ICEConfig) { + icc.c.Set(key, iceConfig, ttlcache.DefaultTTL) +} + +func (icc *IceConfigCache[T]) Get(key T) *livekit.ICEConfig { + if it := icc.c.Get(key); it != nil { + return it.Value() + } + return &livekit.ICEConfig{} +} diff --git a/livekit/pkg/utils/iceconfigcache_test.go b/livekit/pkg/utils/iceconfigcache_test.go new file mode 100644 index 0000000..78feebb --- /dev/null +++ b/livekit/pkg/utils/iceconfigcache_test.go @@ -0,0 +1,18 @@ +package utils + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/livekit/protocol/livekit" +) + +func TestIceConfigCache(t *testing.T) { + cache := NewIceConfigCache[string](10 * time.Second) + t.Cleanup(cache.Stop) + + cache.Put("test", &livekit.ICEConfig{}) + require.NotNil(t, cache) +} diff --git a/livekit/pkg/utils/incrementaldispatcher.go b/livekit/pkg/utils/incrementaldispatcher.go new file mode 100644 index 0000000..845ed3e --- /dev/null +++ b/livekit/pkg/utils/incrementaldispatcher.go @@ -0,0 +1,86 @@ +/* + * Copyright 2024 LiveKit, Inc + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package utils + +import ( + "sync" + + "github.com/frostbyte73/core" +) + +// IncrementalDispatcher is a dispatcher that allows multiple consumers to consume items as they become +// available, while producers can add items at anytime. +type IncrementalDispatcher[T any] struct { + done core.Fuse + lock sync.RWMutex + cond *sync.Cond + items []T +} + +func NewIncrementalDispatcher[T any]() *IncrementalDispatcher[T] { + p := &IncrementalDispatcher[T]{} + p.cond = sync.NewCond(&p.lock) + return p +} + +func (d *IncrementalDispatcher[T]) Add(item T) { + if d.done.IsBroken() { + return + } + d.lock.Lock() + d.items = append(d.items, item) + d.cond.Broadcast() + d.lock.Unlock() +} + +func (d *IncrementalDispatcher[T]) Done() { + d.lock.Lock() + d.done.Break() + d.cond.Broadcast() + d.lock.Unlock() +} + +func (d *IncrementalDispatcher[T]) ForEach(fn func(T)) { + idx := 0 + dispatchFromIdx := func() { + var itemsToDispatch []T + d.lock.RLock() + for idx < len(d.items) { + itemsToDispatch = append(itemsToDispatch, d.items[idx]) + idx++ + } + d.lock.RUnlock() + for _, item := range itemsToDispatch { + fn(item) + } + } + for !d.done.IsBroken() { + dispatchFromIdx() + d.lock.Lock() + // need to check again because Done may have been called while dispatching + if d.done.IsBroken() { + d.lock.Unlock() + break + } + if idx == len(d.items) { + d.cond.Wait() + } + d.lock.Unlock() + } + + dispatchFromIdx() +} diff --git a/livekit/pkg/utils/incrementaldispatcher_test.go b/livekit/pkg/utils/incrementaldispatcher_test.go new file mode 100644 index 0000000..4b756b9 --- /dev/null +++ b/livekit/pkg/utils/incrementaldispatcher_test.go @@ -0,0 +1,97 @@ +/* + * Copyright 2024 LiveKit, Inc + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package utils_test + +import ( + "fmt" + "sync" + "testing" + "time" + + "go.uber.org/atomic" + + "github.com/stretchr/testify/require" + + "github.com/livekit/livekit-server/pkg/testutils" + "github.com/livekit/livekit-server/pkg/utils" +) + +func TestForEach(t *testing.T) { + producer := utils.NewIncrementalDispatcher[int]() + go func() { + defer producer.Done() + producer.Add(1) + producer.Add(2) + producer.Add(3) + }() + + sum := 0 + producer.ForEach(func(item int) { + sum += item + }) + + require.Equal(t, 6, sum) +} + +func TestConcurrentConsumption(t *testing.T) { + producer := utils.NewIncrementalDispatcher[int]() + numConsumers := 100 + sums := make([]atomic.Int32, numConsumers) + var wg sync.WaitGroup + + for i := range numConsumers { + wg.Add(1) + i := i + go func() { + defer wg.Done() + producer.ForEach(func(item int) { + sums[i].Add(int32(item)) + }) + }() + } + + // Add items + expectedSum := 0 + for i := range 20 { + expectedSum += i + producer.Add(i) + } + + for i := range numConsumers { + testutils.WithTimeout(t, func() string { + if sums[i].Load() != int32(expectedSum) { + return fmt.Sprintf("consumer %d did not consume all the items. expected %d, actual: %d", + i, expectedSum, sums[i].Load()) + } + return "" + }, time.Second) + } + + // keep adding and ensure it's consumed + for i := 20; i < 30; i++ { + expectedSum += i + producer.Add(i) + } + + // wait for all consumers to finish + producer.Done() + wg.Wait() + + for i := range numConsumers { + require.Equal(t, int32(expectedSum), sums[i].Load(), "consumer %d did not match", i) + } +} diff --git a/livekit/pkg/utils/logging.go b/livekit/pkg/utils/logging.go new file mode 100644 index 0000000..5e9a3db --- /dev/null +++ b/livekit/pkg/utils/logging.go @@ -0,0 +1,28 @@ +/* + * Copyright 2023 LiveKit, Inc + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package utils + +const ( + ComponentPub = "pub" + ComponentSub = "sub" + ComponentRoom = "room" + ComponentAPI = "api" + ComponentTransport = "transport" + ComponentSFU = "sfu" + // transport subcomponents + ComponentCongestionControl = "cc" +) diff --git a/livekit/pkg/utils/math.go b/livekit/pkg/utils/math.go new file mode 100644 index 0000000..9e75c1a --- /dev/null +++ b/livekit/pkg/utils/math.go @@ -0,0 +1,36 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import "sort" + +// MedianFloat32 gets median value for an array of float32 +func MedianFloat32(input []float32) float32 { + num := len(input) + if num == 0 { + return 0 + } else if num == 1 { + return input[0] + } + sort.Slice(input, func(i, j int) bool { + return input[i] < input[j] + }) + if num%2 != 0 { + return input[num/2] + } + left := input[num/2-1] + right := input[num/2] + return (left + right) / 2 +} diff --git a/livekit/pkg/utils/opsqueue.go b/livekit/pkg/utils/opsqueue.go new file mode 100644 index 0000000..65e0daf --- /dev/null +++ b/livekit/pkg/utils/opsqueue.go @@ -0,0 +1,157 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "sync" + + "github.com/gammazero/deque" + + "github.com/livekit/protocol/logger" +) + +type OpsQueueParams struct { + Name string + MinSize uint + FlushOnStop bool + Logger logger.Logger +} + +type UntypedQueueOp func() + +func (op UntypedQueueOp) run() { + op() +} + +type OpsQueue struct { + opsQueueBase[UntypedQueueOp] +} + +func NewOpsQueue(params OpsQueueParams) *OpsQueue { + return &OpsQueue{*newOpsQueueBase[UntypedQueueOp](params)} +} + +type typedQueueOp[T any] struct { + fn func(T) + arg T +} + +func (op typedQueueOp[T]) run() { + op.fn(op.arg) +} + +type TypedOpsQueue[T any] struct { + opsQueueBase[typedQueueOp[T]] +} + +func NewTypedOpsQueue[T any](params OpsQueueParams) *TypedOpsQueue[T] { + return &TypedOpsQueue[T]{*newOpsQueueBase[typedQueueOp[T]](params)} +} + +func (oq *TypedOpsQueue[T]) Enqueue(fn func(T), arg T) { + oq.opsQueueBase.Enqueue(typedQueueOp[T]{fn, arg}) +} + +type opsQueueItem interface { + run() +} + +type opsQueueBase[T opsQueueItem] struct { + params OpsQueueParams + + lock sync.Mutex + ops deque.Deque[T] + wake chan struct{} + isStarted bool + doneChan chan struct{} + isStopped bool +} + +func newOpsQueueBase[T opsQueueItem](params OpsQueueParams) *opsQueueBase[T] { + o := &opsQueueBase[T]{ + params: params, + wake: make(chan struct{}, 1), + doneChan: make(chan struct{}), + } + o.ops.SetBaseCap(int(min(params.MinSize, 128))) + return o +} + +func (oq *opsQueueBase[T]) Start() { + oq.lock.Lock() + if oq.isStarted { + oq.lock.Unlock() + return + } + + oq.isStarted = true + oq.lock.Unlock() + + go oq.process() +} + +func (oq *opsQueueBase[T]) Stop() <-chan struct{} { + oq.lock.Lock() + if oq.isStopped { + oq.lock.Unlock() + return oq.doneChan + } + + oq.isStopped = true + close(oq.wake) + oq.lock.Unlock() + return oq.doneChan +} + +func (oq *opsQueueBase[T]) Enqueue(op T) { + oq.lock.Lock() + defer oq.lock.Unlock() + + if oq.isStopped { + return + } + + oq.ops.PushBack(op) + if oq.ops.Len() == 1 { + select { + case oq.wake <- struct{}{}: + default: + } + } +} + +func (oq *opsQueueBase[T]) process() { + defer close(oq.doneChan) + + for { + <-oq.wake + for { + oq.lock.Lock() + if oq.isStopped && (!oq.params.FlushOnStop || oq.ops.Len() == 0) { + oq.lock.Unlock() + return + } + + if oq.ops.Len() == 0 { + oq.lock.Unlock() + break + } + op := oq.ops.PopFront() + oq.lock.Unlock() + + op.run() + } + } +} diff --git a/livekit/pkg/utils/protocol.go b/livekit/pkg/utils/protocol.go new file mode 100644 index 0000000..77d324b --- /dev/null +++ b/livekit/pkg/utils/protocol.go @@ -0,0 +1,31 @@ +/* + * Copyright 2023 LiveKit, Inc + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package utils + +import ( + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/utils" +) + +func ClientInfoWithoutAddress(c *livekit.ClientInfo) *livekit.ClientInfo { + if c == nil { + return nil + } + clone := utils.CloneProto(c) + clone.Address = "" + return clone +} diff --git a/livekit/pkg/utils/slice.go b/livekit/pkg/utils/slice.go new file mode 100644 index 0000000..1d692ef --- /dev/null +++ b/livekit/pkg/utils/slice.go @@ -0,0 +1,11 @@ +package utils + +import ( + "cmp" + "slices" +) + +func DedupeSlice[T cmp.Ordered](s []T) []T { + slices.Sort(s) + return slices.Compact(s) +} diff --git a/livekit/renovate.json b/livekit/renovate.json new file mode 100644 index 0000000..b472aef --- /dev/null +++ b/livekit/renovate.json @@ -0,0 +1,33 @@ +{ + "$schema": "https://docs.renovatebot.com/renovate-schema.json", + "extends": ["config:base"], + "commitBody": "Generated by renovateBot", + "packageRules": [ + { + "schedule": "before 6am every monday", + "matchManagers": ["github-actions"], + "groupName": "github workflows" + }, + { + "schedule": "before 6am every monday", + "matchManagers": ["dockerfile"], + "groupName": "docker deps" + }, + { + "schedule": "before 6am every monday", + "matchManagers": ["gomod"], + "groupName": "go deps" + }, + { + "matchManagers": ["gomod"], + "matchPackagePrefixes": ["github.com/pion"], + "groupName": "pion deps" + }, + { + "matchManagers": ["gomod"], + "matchPackagePrefixes": ["github.com/livekit"], + "groupName": "livekit deps" + } + ], + "postUpdateOptions": ["gomodTidy"] +} diff --git a/livekit/test/agent.go b/livekit/test/agent.go new file mode 100644 index 0000000..3640524 --- /dev/null +++ b/livekit/test/agent.go @@ -0,0 +1,198 @@ +// Copyright 2024 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +import ( + "fmt" + "net/http" + "net/url" + "sync" + + "github.com/gorilla/websocket" + "go.uber.org/atomic" + "google.golang.org/protobuf/proto" + + "github.com/livekit/protocol/livekit" +) + +type agentClient struct { + mu sync.Mutex + conn *websocket.Conn + + registered atomic.Int32 + roomAvailability atomic.Int32 + roomJobs atomic.Int32 + publisherAvailability atomic.Int32 + publisherJobs atomic.Int32 + participantAvailability atomic.Int32 + participantJobs atomic.Int32 + + requestedJobs chan *livekit.Job + + done chan struct{} +} + +func newAgentClient(token string, port uint32) (*agentClient, error) { + host := fmt.Sprintf("ws://localhost:%d", port) + u, err := url.Parse(host + "/agent") + if err != nil { + return nil, err + } + requestHeader := make(http.Header) + requestHeader.Set("Authorization", "Bearer "+token) + + connectUrl := u.String() + conn, _, err := websocket.DefaultDialer.Dial(connectUrl, requestHeader) + if err != nil { + return nil, err + } + + return &agentClient{ + conn: conn, + requestedJobs: make(chan *livekit.Job, 100), + done: make(chan struct{}), + }, nil +} + +func (c *agentClient) Run(jobType livekit.JobType, namespace string) (err error) { + go c.read() + + switch jobType { + case livekit.JobType_JT_ROOM: + err = c.write(&livekit.WorkerMessage{ + Message: &livekit.WorkerMessage_Register{ + Register: &livekit.RegisterWorkerRequest{ + Type: livekit.JobType_JT_ROOM, + Version: "version", + Namespace: &namespace, + }, + }, + }) + + case livekit.JobType_JT_PUBLISHER: + err = c.write(&livekit.WorkerMessage{ + Message: &livekit.WorkerMessage_Register{ + Register: &livekit.RegisterWorkerRequest{ + Type: livekit.JobType_JT_PUBLISHER, + Version: "version", + Namespace: &namespace, + }, + }, + }) + + case livekit.JobType_JT_PARTICIPANT: + err = c.write(&livekit.WorkerMessage{ + Message: &livekit.WorkerMessage_Register{ + Register: &livekit.RegisterWorkerRequest{ + Type: livekit.JobType_JT_PARTICIPANT, + Version: "version", + Namespace: &namespace, + }, + }, + }) + } + + return err +} + +func (c *agentClient) read() { + for { + select { + case <-c.done: + return + default: + _, b, err := c.conn.ReadMessage() + if err != nil { + return + } + + msg := &livekit.ServerMessage{} + if err = proto.Unmarshal(b, msg); err != nil { + return + } + + switch m := msg.Message.(type) { + case *livekit.ServerMessage_Assignment: + go c.handleAssignment(m.Assignment) + case *livekit.ServerMessage_Availability: + go c.handleAvailability(m.Availability) + case *livekit.ServerMessage_Register: + go c.handleRegister(m.Register) + } + } + } +} + +func (c *agentClient) handleAssignment(req *livekit.JobAssignment) { + switch req.Job.Type { + case livekit.JobType_JT_ROOM: + c.roomJobs.Inc() + case livekit.JobType_JT_PUBLISHER: + c.publisherJobs.Inc() + case livekit.JobType_JT_PARTICIPANT: + c.participantJobs.Inc() + } +} + +func (c *agentClient) handleAvailability(req *livekit.AvailabilityRequest) { + switch req.Job.Type { + case livekit.JobType_JT_ROOM: + c.roomAvailability.Inc() + case livekit.JobType_JT_PUBLISHER: + c.publisherAvailability.Inc() + case livekit.JobType_JT_PARTICIPANT: + c.participantAvailability.Inc() + } + + c.requestedJobs <- req.Job + + c.write(&livekit.WorkerMessage{ + Message: &livekit.WorkerMessage_Availability{ + Availability: &livekit.AvailabilityResponse{ + JobId: req.Job.Id, + Available: true, + }, + }, + }) +} + +func (c *agentClient) handleRegister(req *livekit.RegisterWorkerResponse) { + c.registered.Inc() +} + +func (c *agentClient) write(msg *livekit.WorkerMessage) error { + c.mu.Lock() + defer c.mu.Unlock() + + select { + case <-c.done: + return nil + default: + b, err := proto.Marshal(msg) + if err != nil { + return err + } + + return c.conn.WriteMessage(websocket.BinaryMessage, b) + } +} + +func (c *agentClient) close() { + c.mu.Lock() + close(c.done) + _ = c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + _ = c.conn.Close() + c.mu.Unlock() +} diff --git a/livekit/test/agent_test.go b/livekit/test/agent_test.go new file mode 100644 index 0000000..9b31452 --- /dev/null +++ b/livekit/test/agent_test.go @@ -0,0 +1,244 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +import ( + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/livekit/livekit-server/pkg/testutils" + "github.com/livekit/protocol/auth" + "github.com/livekit/protocol/livekit" +) + +var ( + RegisterTimeout = 2 * time.Second + AssignJobTimeout = 3 * time.Second +) + +func TestAgents(t *testing.T) { + for _, testRTCServicePath := range testRTCServicePaths { + t.Run(fmt.Sprintf("testRTCServicePath=%s", testRTCServicePath.String()), func(t *testing.T) { + _, finish := setupSingleNodeTest("TestAgents") + defer finish() + + ac1, err := newAgentClient(agentToken(), defaultServerPort) + require.NoError(t, err) + ac2, err := newAgentClient(agentToken(), defaultServerPort) + require.NoError(t, err) + ac3, err := newAgentClient(agentToken(), defaultServerPort) + require.NoError(t, err) + ac4, err := newAgentClient(agentToken(), defaultServerPort) + require.NoError(t, err) + ac5, err := newAgentClient(agentToken(), defaultServerPort) + require.NoError(t, err) + ac6, err := newAgentClient(agentToken(), defaultServerPort) + require.NoError(t, err) + defer ac1.close() + defer ac2.close() + defer ac3.close() + defer ac4.close() + defer ac5.close() + defer ac6.close() + ac1.Run(livekit.JobType_JT_ROOM, "default") + ac2.Run(livekit.JobType_JT_ROOM, "default") + ac3.Run(livekit.JobType_JT_PUBLISHER, "default") + ac4.Run(livekit.JobType_JT_PUBLISHER, "default") + ac5.Run(livekit.JobType_JT_PARTICIPANT, "default") + ac6.Run(livekit.JobType_JT_PARTICIPANT, "default") + + testutils.WithTimeout(t, func() string { + if ac1.registered.Load() != 1 || ac2.registered.Load() != 1 || ac3.registered.Load() != 1 || ac4.registered.Load() != 1 || ac5.registered.Load() != 1 || ac6.registered.Load() != 1 { + return "worker not registered" + } + + return "" + }, RegisterTimeout) + + c1 := createRTCClient("c1", defaultServerPort, testRTCServicePath, nil) + c2 := createRTCClient("c2", defaultServerPort, testRTCServicePath, nil) + waitUntilConnected(t, c1, c2) + + // publish 2 tracks + t1, err := c1.AddStaticTrack("audio/opus", "audio", "micro") + require.NoError(t, err) + defer t1.Stop() + t2, err := c1.AddStaticTrack("video/vp8", "video", "webcam") + require.NoError(t, err) + defer t2.Stop() + + testutils.WithTimeout(t, func() string { + if ac1.roomJobs.Load()+ac2.roomJobs.Load() != 1 { + return "room job not assigned" + } + + if ac3.publisherJobs.Load()+ac4.publisherJobs.Load() != 1 { + return fmt.Sprintf("publisher jobs not assigned, ac3: %d, ac4: %d", ac3.publisherJobs.Load(), ac4.publisherJobs.Load()) + } + + if ac5.participantJobs.Load()+ac6.participantJobs.Load() != 2 { + return fmt.Sprintf("participant jobs not assigned, ac5: %d, ac6: %d", ac5.participantJobs.Load(), ac6.participantJobs.Load()) + } + + return "" + }, 6*time.Second) + + // publish 2 tracks + t3, err := c2.AddStaticTrack("audio/opus", "audio", "micro") + require.NoError(t, err) + defer t3.Stop() + t4, err := c2.AddStaticTrack("video/vp8", "video", "webcam") + require.NoError(t, err) + defer t4.Stop() + + testutils.WithTimeout(t, func() string { + if ac1.roomJobs.Load()+ac2.roomJobs.Load() != 1 { + return "room job must be assigned 1 time" + } + + if ac3.publisherJobs.Load()+ac4.publisherJobs.Load() != 2 { + return "2 publisher jobs must assigned" + } + + if ac5.participantJobs.Load()+ac6.participantJobs.Load() != 2 { + return "2 participant jobs must assigned" + } + + return "" + }, AssignJobTimeout) + }) + } +} + +func TestAgentNamespaces(t *testing.T) { + for _, testRTCServicePath := range testRTCServicePaths { + t.Run(fmt.Sprintf("testRTCServicePath=%s", testRTCServicePath.String()), func(t *testing.T) { + _, finish := setupSingleNodeTest("TestAgentNamespaces") + defer finish() + + ac1, err := newAgentClient(agentToken(), defaultServerPort) + require.NoError(t, err) + ac2, err := newAgentClient(agentToken(), defaultServerPort) + require.NoError(t, err) + defer ac1.close() + defer ac2.close() + ac1.Run(livekit.JobType_JT_ROOM, "namespace1") + ac2.Run(livekit.JobType_JT_ROOM, "namespace2") + + _, err = roomClient.CreateRoom(contextWithToken(createRoomToken()), &livekit.CreateRoomRequest{ + Name: testRoom, + Agents: []*livekit.RoomAgentDispatch{ + {}, + { + AgentName: "ag", + }, + }, + }) + require.NoError(t, err) + + testutils.WithTimeout(t, func() string { + if ac1.registered.Load() != 1 || ac2.registered.Load() != 1 { + return "worker not registered" + } + return "" + }, RegisterTimeout) + + c1 := createRTCClient("c1", defaultServerPort, testRTCServicePath, nil) + waitUntilConnected(t, c1) + + testutils.WithTimeout(t, func() string { + if ac1.roomJobs.Load() != 1 || ac2.roomJobs.Load() != 1 { + return "room job not assigned" + } + + job1 := <-ac1.requestedJobs + job2 := <-ac2.requestedJobs + + if job1.Namespace != "namespace1" { + return "namespace is not 'namespace'" + } + + if job2.Namespace != "namespace2" { + return "namespace is not 'namespace2'" + } + + if job1.Id == job2.Id { + return "job ids are the same" + } + + return "" + }, AssignJobTimeout) + }) + } +} + +func TestAgentMultiNode(t *testing.T) { + for _, testRTCServicePath := range testRTCServicePaths { + t.Run(fmt.Sprintf("testRTCServicePath=%s", testRTCServicePath.String()), func(t *testing.T) { + _, _, finish := setupMultiNodeTest("TestAgentMultiNode") + defer finish() + + ac1, err := newAgentClient(agentToken(), defaultServerPort) + require.NoError(t, err) + ac2, err := newAgentClient(agentToken(), defaultServerPort) + require.NoError(t, err) + defer ac1.close() + defer ac2.close() + ac1.Run(livekit.JobType_JT_ROOM, "default") + ac2.Run(livekit.JobType_JT_PUBLISHER, "default") + + testutils.WithTimeout(t, func() string { + if ac1.registered.Load() != 1 || ac2.registered.Load() != 1 { + return "worker not registered" + } + return "" + }, RegisterTimeout) + + c1 := createRTCClient("c1", secondServerPort, testRTCServicePath, nil) // Create a room on the second node + waitUntilConnected(t, c1) + + t1, err := c1.AddStaticTrack("audio/opus", "audio", "micro") + require.NoError(t, err) + defer t1.Stop() + + time.Sleep(time.Second * 10) + + testutils.WithTimeout(t, func() string { + if ac1.roomJobs.Load() != 1 { + return "room job not assigned" + } + + if ac2.publisherJobs.Load() != 1 { + return "participant job not assigned" + } + + return "" + }, AssignJobTimeout) + }) + } +} + +func agentToken() string { + at := auth.NewAccessToken(testApiKey, testApiSecret). + AddGrant(&auth.VideoGrant{Agent: true}) + t, err := at.ToJWT() + if err != nil { + panic(err) + } + return t +} diff --git a/livekit/test/client/client.go b/livekit/test/client/client.go new file mode 100644 index 0000000..f98e82d --- /dev/null +++ b/livekit/test/client/client.go @@ -0,0 +1,1317 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package client + +import ( + "context" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "maps" + "math/rand/v2" + "net/http" + "net/url" + "path/filepath" + "runtime" + "sync" + "time" + + "github.com/gorilla/websocket" + "github.com/pion/rtcp" + "github.com/pion/rtp" + "github.com/pion/webrtc/v4" + "github.com/thoas/go-funk" + "go.uber.org/atomic" + "google.golang.org/protobuf/proto" + + "github.com/livekit/mediatransportutil/pkg/rtcconfig" + "github.com/livekit/protocol/auth" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/signalling" + + "github.com/livekit/livekit-server/pkg/rtc" + "github.com/livekit/livekit-server/pkg/rtc/datatrack" + "github.com/livekit/livekit-server/pkg/rtc/transport/transportfakes" + "github.com/livekit/livekit-server/pkg/rtc/types" + "github.com/livekit/livekit-server/pkg/sfu/buffer" + "github.com/livekit/livekit-server/pkg/sfu/mime" +) + +type SignalRequestHandler func(msg *livekit.SignalRequest) error +type SignalRequestInterceptor func(msg *livekit.SignalRequest, next SignalRequestHandler) error +type SignalResponseHandler func(msg *livekit.SignalResponse) error +type SignalResponseInterceptor func(msg *livekit.SignalResponse, next SignalResponseHandler) error + +type RTCClient struct { + useSinglePeerConnection bool + id livekit.ParticipantID + conn *websocket.Conn + publisher *rtc.PCTransport + subscriber *rtc.PCTransport + // sid => track + localTracks map[string]webrtc.TrackLocal + trackSenders map[string]*webrtc.RTPSender + lock sync.Mutex + wsLock sync.Mutex + ctx context.Context + cancel context.CancelFunc + me *webrtc.MediaEngine // optional, populated only when receiving tracks + subscribedTracks map[livekit.ParticipantID][]*webrtc.TrackRemote + localParticipant *livekit.ParticipantInfo + remoteParticipants map[livekit.ParticipantID]*livekit.ParticipantInfo + + signalRequestInterceptor SignalRequestInterceptor + signalResponseInterceptor SignalResponseInterceptor + + icQueue [2]atomic.Pointer[webrtc.ICECandidate] + + subscriberAsPrimary atomic.Bool + publisherFullyEstablished atomic.Bool + subscriberFullyEstablished atomic.Bool + pongReceivedAt atomic.Int64 + + // tracks waiting to be acked, cid => trackInfo + pendingPublishedTracks map[string]*livekit.TrackInfo + + // remote tracks waiting to be processed + pendingRemoteTracks []*webrtc.TrackRemote + + pendingTrackWriters []TrackWriter + OnConnected func() + OnDataReceived func(data []byte, sid string) + OnDataUnlabeledReceived func(data []byte) + refreshToken string + + // map of livekit.ParticipantID and last packet + lastPackets map[livekit.ParticipantID]*rtp.Packet + bytesReceived map[livekit.ParticipantID]uint64 + + subscriptionResponse atomic.Pointer[livekit.SubscriptionResponse] + + nextDataTrackHandle atomic.Uint32 + pendingPublishedDataTracks map[uint16]*livekit.DataTrackInfo + pendingDataTrackWriters []TrackWriter + subscribedDataTracks map[livekit.ParticipantID]map[uint16]*DataTrackRemote +} + +var ( + // minimal settings only with stun server + rtcConf = webrtc.Configuration{ + ICEServers: []webrtc.ICEServer{ + { + URLs: []string{"stun:stun.l.google.com:19302"}, + }, + }, + } + extMimeMapping = map[string]string{ + ".ivf": mime.MimeTypeVP8.String(), + ".h264": mime.MimeTypeH264.String(), + ".ogg": mime.MimeTypeOpus.String(), + } +) + +type Options struct { + AutoSubscribe bool + AutoSubscribeDataTrack bool + Publish string + Attributes map[string]string + ClientInfo *livekit.ClientInfo + DisabledCodecs []webrtc.RTPCodecCapability + TokenCustomizer func(token *auth.AccessToken, grants *auth.VideoGrant) + SignalRequestInterceptor SignalRequestInterceptor + SignalResponseInterceptor SignalResponseInterceptor + UseJoinRequestQueryParam bool + RTCServicePath string +} + +func NewWebSocketConn(host, token string, opts *Options) (*websocket.Conn, error) { + rtcServicePath := "/rtc" + if opts != nil && opts.RTCServicePath != "" { + rtcServicePath = opts.RTCServicePath + } + parsedURL, err := url.Parse(host + rtcServicePath) + if err != nil { + return nil, err + } + requestHeader := make(http.Header) + SetAuthorizationToken(requestHeader, token) + + connectUrl := parsedURL.String() + if opts != nil && opts.UseJoinRequestQueryParam { + clientInfo := &livekit.ClientInfo{ + Os: runtime.GOOS, + Sdk: livekit.ClientInfo_GO, + Protocol: int32(types.CurrentProtocol), + } + if opts.ClientInfo != nil { + clientInfo = opts.ClientInfo + } + + connectionSettings := &livekit.ConnectionSettings{ + AutoSubscribe: opts.AutoSubscribe, + AutoSubscribeDataTrack: &opts.AutoSubscribeDataTrack, + } + + joinRequest := &livekit.JoinRequest{ + ClientInfo: clientInfo, + ConnectionSettings: connectionSettings, + ParticipantAttributes: opts.Attributes, + } + + if marshalled, err := proto.Marshal(joinRequest); err == nil { + wrapped := &livekit.WrappedJoinRequest{ + JoinRequest: marshalled, + } + if marshalled, err := proto.Marshal(wrapped); err == nil { + connectUrl += fmt.Sprintf("?join_request=%s", base64.URLEncoding.EncodeToString(marshalled)) + } + } + } else { + connectUrl += fmt.Sprintf("?protocol=%d", types.CurrentProtocol) + + sdk := "go" + if opts != nil { + connectUrl += fmt.Sprintf("&auto_subscribe=%t", opts.AutoSubscribe) + connectUrl += fmt.Sprintf("&auto_subscribe_data_track=%t", opts.AutoSubscribeDataTrack) + if opts.Publish != "" { + connectUrl += encodeQueryParam("publish", opts.Publish) + } + if len(opts.Attributes) != 0 { + data, err := json.Marshal(opts.Attributes) + if err != nil { + return nil, err + } + connectUrl += encodeQueryParam("attributes", base64.URLEncoding.EncodeToString(data)) + } + if opts.ClientInfo != nil { + if opts.ClientInfo.DeviceModel != "" { + connectUrl += encodeQueryParam("device_model", opts.ClientInfo.DeviceModel) + } + if opts.ClientInfo.Os != "" { + connectUrl += encodeQueryParam("os", opts.ClientInfo.Os) + } + if opts.ClientInfo.Sdk != livekit.ClientInfo_UNKNOWN { + sdk = opts.ClientInfo.Sdk.String() + } + } + } + connectUrl += encodeQueryParam("sdk", sdk) + } + + logger.Infow("connecting to", "url", parsedURL.String()) + conn, _, err := websocket.DefaultDialer.Dial(connectUrl, requestHeader) + return conn, err +} + +func SetAuthorizationToken(header http.Header, token string) { + header.Set("Authorization", "Bearer "+token) +} + +func NewRTCClient(conn *websocket.Conn, useSinglePeerConnection bool, opts *Options) (*RTCClient, error) { + var err error + + c := &RTCClient{ + useSinglePeerConnection: useSinglePeerConnection, + conn: conn, + localTracks: make(map[string]webrtc.TrackLocal), + trackSenders: make(map[string]*webrtc.RTPSender), + pendingPublishedTracks: make(map[string]*livekit.TrackInfo), + subscribedTracks: make(map[livekit.ParticipantID][]*webrtc.TrackRemote), + remoteParticipants: make(map[livekit.ParticipantID]*livekit.ParticipantInfo), + me: &webrtc.MediaEngine{}, + lastPackets: make(map[livekit.ParticipantID]*rtp.Packet), + bytesReceived: make(map[livekit.ParticipantID]uint64), + pendingPublishedDataTracks: make(map[uint16]*livekit.DataTrackInfo), + subscribedDataTracks: make(map[livekit.ParticipantID]map[uint16]*DataTrackRemote), + } + c.nextDataTrackHandle.Store(uint32(rand.IntN(8192))) + c.ctx, c.cancel = context.WithCancel(context.Background()) + + conf := rtc.WebRTCConfig{ + WebRTCConfig: rtcconfig.WebRTCConfig{ + Configuration: rtcConf, + }, + } + conf.SettingEngine.SetLite(false) + conf.SettingEngine.SetAnsweringDTLSRole(webrtc.DTLSRoleClient) + ff := buffer.NewFactoryOfBufferFactory(500, 200) + conf.SetBufferFactory(ff.CreateBufferFactory()) + var codecs []*livekit.Codec + for _, codec := range []*livekit.Codec{ + { + Mime: "audio/opus", + }, + { + Mime: "video/vp8", + }, + { + Mime: "video/h264", + }, + } { + var disabled bool + if opts != nil { + for _, dc := range opts.DisabledCodecs { + if mime.IsMimeTypeStringEqual(dc.MimeType, codec.Mime) && (dc.SDPFmtpLine == "" || dc.SDPFmtpLine == codec.FmtpLine) { + disabled = true + break + } + } + } + if !disabled { + codecs = append(codecs, codec) + } + } + + // + // The signal targets are from point of view of server. + // From client side, they are flipped, + // i. e. the publisher transport on client side has SUBSCRIBER signal target (i. e. publisher is offerer). + // Same applies for subscriber transport also + // + publisherHandler := &transportfakes.FakeHandler{} + c.publisher, err = rtc.NewPCTransport(rtc.TransportParams{ + Config: &conf, + DirectionConfig: conf.Subscriber, + EnabledCodecs: codecs, + IsOfferer: true, + IsSendSide: true, + Handler: publisherHandler, + DatachannelMaxReceiverBufferSize: 1500, + DatachannelSlowThreshold: 1024 * 1024 * 1024, + FireOnTrackBySdp: true, + EnableDataTracks: true, + }) + if err != nil { + return nil, err + } + + publisherHandler.OnICECandidateCalls(func(ic *webrtc.ICECandidate, t livekit.SignalTarget) error { + return c.SendIceCandidate(ic, livekit.SignalTarget_PUBLISHER) + }) + publisherHandler.OnTrackCalls(func(track *webrtc.TrackRemote, rtpReceiver *webrtc.RTPReceiver) { + go c.processRemoteTrack(track) + }) + publisherHandler.OnDataMessageCalls(c.handleDataMessage) + publisherHandler.OnDataMessageUnlabeledCalls(c.handleDataMessageUnlabeled) + publisherHandler.OnDataTrackMessageCalls(c.handleDataTrackMessage) + publisherHandler.OnInitialConnectedCalls(func() { + logger.Debugw("publisher initial connected", "participant", c.localParticipant.Identity) + + c.lock.Lock() + defer c.lock.Unlock() + for _, tw := range c.pendingTrackWriters { + if err := tw.Start(); err != nil { + logger.Errorw("track writer error", err) + } + } + c.pendingTrackWriters = nil + + for _, dtw := range c.pendingDataTrackWriters { + if err := dtw.Start(); err != nil { + logger.Errorw("data track writer error", err) + } + } + c.pendingDataTrackWriters = nil + + if c.OnConnected != nil { + go c.OnConnected() + } + }) + publisherHandler.OnOfferCalls(c.onOffer) + publisherHandler.OnFullyEstablishedCalls(func() { + logger.Debugw("publisher fully established", "participant", c.localParticipant.Identity, "pID", c.localParticipant.Sid) + c.publisherFullyEstablished.Store(true) + }) + + ordered := true + if err := c.publisher.CreateDataChannel(rtc.ReliableDataChannel, &webrtc.DataChannelInit{ + Ordered: &ordered, + }); err != nil { + return nil, err + } + + if err := c.publisher.CreateDataChannel("pubraw", &webrtc.DataChannelInit{ + Ordered: &ordered, + }); err != nil { + return nil, err + } + + ordered = false + maxRetransmits := uint16(0) + if err := c.publisher.CreateDataChannel(rtc.LossyDataChannel, &webrtc.DataChannelInit{ + Ordered: &ordered, + MaxRetransmits: &maxRetransmits, + }); err != nil { + return nil, err + } + + if err := c.publisher.CreateDataChannel(rtc.DataTrackDataChannel, &webrtc.DataChannelInit{ + Ordered: &ordered, + MaxRetransmits: &maxRetransmits, + }); err != nil { + return nil, err + } + + if !c.useSinglePeerConnection { + subscriberHandler := &transportfakes.FakeHandler{} + c.subscriber, err = rtc.NewPCTransport(rtc.TransportParams{ + Config: &conf, + DirectionConfig: conf.Publisher, + EnabledCodecs: codecs, + Handler: subscriberHandler, + DatachannelMaxReceiverBufferSize: 1500, + DatachannelSlowThreshold: 1024 * 1024 * 1024, + FireOnTrackBySdp: true, + EnableDataTracks: true, + }) + if err != nil { + return nil, err + } + + ordered := true + if err := c.subscriber.CreateReadableDataChannel("subraw", &webrtc.DataChannelInit{ + Ordered: &ordered, + }); err != nil { + return nil, err + } + + subscriberHandler.OnICECandidateCalls(func(ic *webrtc.ICECandidate, t livekit.SignalTarget) error { + if ic == nil { + return nil + } + return c.SendIceCandidate(ic, livekit.SignalTarget_SUBSCRIBER) + }) + subscriberHandler.OnTrackCalls(func(track *webrtc.TrackRemote, rtpReceiver *webrtc.RTPReceiver) { + go c.processRemoteTrack(track) + }) + subscriberHandler.OnDataMessageCalls(c.handleDataMessage) + subscriberHandler.OnDataMessageUnlabeledCalls(c.handleDataMessageUnlabeled) + subscriberHandler.OnDataTrackMessageCalls(c.handleDataTrackMessage) + subscriberHandler.OnInitialConnectedCalls(func() { + logger.Debugw("subscriber initial connected", "participant", c.localParticipant.Identity) + + c.lock.Lock() + defer c.lock.Unlock() + for _, tw := range c.pendingTrackWriters { + if err := tw.Start(); err != nil { + logger.Errorw("track writer error", err) + } + } + c.pendingTrackWriters = nil + + for _, dtw := range c.pendingDataTrackWriters { + if err := dtw.Start(); err != nil { + logger.Errorw("data track writer error", err) + } + } + c.pendingDataTrackWriters = nil + + if c.OnConnected != nil { + go c.OnConnected() + } + }) + subscriberHandler.OnFullyEstablishedCalls(func() { + logger.Debugw("subscriber fully established", "participant", c.localParticipant.Identity, "pID", c.localParticipant.Sid) + c.subscriberFullyEstablished.Store(true) + }) + subscriberHandler.OnAnswerCalls(func(answer webrtc.SessionDescription, answerId uint32, _midToTrackID map[string]string) error { + // send remote an answer + logger.Infow( + "sending subscriber answer", + "participant", c.localParticipant.Identity, + "sdp", answer, + ) + return c.SendRequest(&livekit.SignalRequest{ + Message: &livekit.SignalRequest_Answer{ + Answer: signalling.ToProtoSessionDescription(answer, answerId, nil), + }, + }) + }) + } else { + go c.ensurePublisherConnected() + } + + if opts != nil { + c.signalRequestInterceptor = opts.SignalRequestInterceptor + c.signalResponseInterceptor = opts.SignalResponseInterceptor + } + + return c, nil +} + +func (c *RTCClient) ID() livekit.ParticipantID { + return c.id +} + +// create an offer for the server +func (c *RTCClient) Run() error { + c.conn.SetCloseHandler(func(code int, text string) error { + // when closed, stop connection + logger.Infow("connection closed", "code", code, "text", text) + c.Stop() + return nil + }) + + // run the session + for { + res, err := c.ReadResponse() + if errors.Is(err, io.EOF) { + return nil + } else if err != nil { + logger.Errorw("error while reading", err) + return err + } + if c.signalResponseInterceptor != nil { + err = c.signalResponseInterceptor(res, c.handleSignalResponse) + } else { + err = c.handleSignalResponse(res) + } + if err != nil { + return err + } + } +} + +func (c *RTCClient) handleSignalResponse(res *livekit.SignalResponse) error { + switch msg := res.Message.(type) { + case *livekit.SignalResponse_Join: + c.localParticipant = msg.Join.Participant + c.id = livekit.ParticipantID(msg.Join.Participant.Sid) + c.lock.Lock() + for _, p := range msg.Join.OtherParticipants { + c.remoteParticipants[livekit.ParticipantID(p.Sid)] = p + } + c.lock.Unlock() + // if publish only, negotiate + if !msg.Join.SubscriberPrimary { + c.subscriberAsPrimary.Store(false) + c.publisher.Negotiate(false) + } else { + c.subscriberAsPrimary.Store(true) + } + + if c.subscriber != nil { + logger.Infow("join accepted, awaiting offer", "participant", msg.Join.Participant.Identity) + } else { + logger.Infow("join accepted", "participant", msg.Join.Participant.Identity) + } + + case *livekit.SignalResponse_Answer: + logger.Infow( + "received server answer", + "participant", c.localParticipant.Identity, + "answer", msg.Answer.Sdp, + ) + c.handleAnswer(signalling.FromProtoSessionDescription(msg.Answer)) + + case *livekit.SignalResponse_Offer: + desc, offerId, midToTrackID := signalling.FromProtoSessionDescription(msg.Offer) + logger.Infow( + "received server offer", + "participant", c.localParticipant.Identity, + "sdp", desc, + "offerId", offerId, + "midToTrackID", midToTrackID, + ) + c.handleOffer(desc, offerId, midToTrackID) + + case *livekit.SignalResponse_Trickle: + candidateInit, err := signalling.FromProtoTrickle(msg.Trickle) + if err != nil { + return err + } + if msg.Trickle.Target == livekit.SignalTarget_PUBLISHER { + c.publisher.AddICECandidate(candidateInit) + } else { + c.subscriber.AddICECandidate(candidateInit) + } + + case *livekit.SignalResponse_Update: + c.lock.Lock() + for _, p := range msg.Update.Participants { + if livekit.ParticipantID(p.Sid) != c.id { + if p.State != livekit.ParticipantInfo_DISCONNECTED { + c.remoteParticipants[livekit.ParticipantID(p.Sid)] = p + } else { + delete(c.remoteParticipants, livekit.ParticipantID(p.Sid)) + } + } + } + c.lock.Unlock() + + case *livekit.SignalResponse_TrackPublished: + logger.Debugw( + "track published", + "participant", c.localParticipant.Identity, + "cid", msg.TrackPublished.Cid, + "trackID", msg.TrackPublished.Track.Sid, + "trackName", msg.TrackPublished.Track.Name, + ) + c.lock.Lock() + c.pendingPublishedTracks[msg.TrackPublished.Cid] = msg.TrackPublished.Track + c.lock.Unlock() + + case *livekit.SignalResponse_RefreshToken: + c.lock.Lock() + c.refreshToken = msg.RefreshToken + c.lock.Unlock() + + case *livekit.SignalResponse_TrackUnpublished: + sid := msg.TrackUnpublished.TrackSid + c.lock.Lock() + if sender := c.trackSenders[sid]; sender != nil { + if err := c.publisher.RemoveTrack(sender); err != nil { + logger.Errorw("Could not unpublish track", err) + } + c.publisher.Negotiate(false) + } + delete(c.trackSenders, sid) + delete(c.localTracks, sid) + c.lock.Unlock() + + case *livekit.SignalResponse_Pong: + c.pongReceivedAt.Store(msg.Pong) + + case *livekit.SignalResponse_SubscriptionResponse: + c.subscriptionResponse.Store(msg.SubscriptionResponse) + + case *livekit.SignalResponse_MediaSectionsRequirement: + logger.Infow( + "received media sections requirement", + "participant", c.localParticipant.Identity, + "numAudios", msg.MediaSectionsRequirement.NumAudios, + "numVideos", msg.MediaSectionsRequirement.NumVideos, + ) + c.handleMediaSectionsRequirement(msg.MediaSectionsRequirement) + + case *livekit.SignalResponse_PublishDataTrackResponse: + logger.Debugw( + "data track published", + "participant", c.localParticipant.Identity, + "trackID", msg.PublishDataTrackResponse.Info.Sid, + "trackHandle", msg.PublishDataTrackResponse.Info.PubHandle, + "trackName", msg.PublishDataTrackResponse.Info.Name, + ) + c.lock.Lock() + c.pendingPublishedDataTracks[uint16(msg.PublishDataTrackResponse.Info.PubHandle)] = msg.PublishDataTrackResponse.Info + c.lock.Unlock() + + case *livekit.SignalResponse_DataTrackSubscriberHandles: + logger.Infow( + "received data track subscriber handles", + "participant", c.localParticipant.Identity, + "handles", msg.DataTrackSubscriberHandles.SubHandles, + ) + c.lock.Lock() + // create new remote data tracks if one does not exist for a handle + for handle, publishedDataTrack := range msg.DataTrackSubscriberHandles.SubHandles { + publisherID := livekit.ParticipantID(publishedDataTrack.PublisherSid) + tracks := c.subscribedDataTracks[publisherID] + if tracks == nil { + c.subscribedDataTracks[publisherID] = make(map[uint16]*DataTrackRemote) + tracks = c.subscribedDataTracks[publisherID] + } + if tracks[uint16(handle)] == nil { + tracks[uint16(handle)] = NewDataTrackRemote( + livekit.ParticipantIdentity(publishedDataTrack.PublisherIdentity), + livekit.ParticipantID(publishedDataTrack.PublisherSid), + uint16(handle), + livekit.TrackID(publishedDataTrack.TrackSid), + logger.GetLogger().WithValues("participant", c.localParticipant.Identity, "pID", c.localParticipant.Sid), + ) + } + } + + // delete remote data tracks that have gone away + for publisherID, tracks := range c.subscribedDataTracks { + for handle, dataTrackRemote := range tracks { + if msg.DataTrackSubscriberHandles.SubHandles[uint32(handle)] == nil { + dataTrackRemote.Close() + delete(tracks, handle) + if len(tracks) == 0 { + delete(c.subscribedDataTracks, publisherID) + } + } + } + } + c.lock.Unlock() + } + return nil +} + +func (c *RTCClient) WaitUntilConnected() error { + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + for { + select { + case <-ctx.Done(): + id := string(c.ID()) + if c.localParticipant != nil { + id = c.localParticipant.Identity + } + return fmt.Errorf("%s could not connect after timeout", id) + case <-time.After(10 * time.Millisecond): + if c.subscriberAsPrimary.Load() { + if c.subscriberFullyEstablished.Load() { + return nil + } + } else { + if c.publisherFullyEstablished.Load() { + return nil + } + } + } + } +} + +func (c *RTCClient) ReadResponse() (*livekit.SignalResponse, error) { + for { + // handle special messages and pass on the rest + messageType, payload, err := c.conn.ReadMessage() + if err != nil { + return nil, err + } + + if c.ctx.Err() != nil { + return nil, c.ctx.Err() + } + + msg := &livekit.SignalResponse{} + switch messageType { + case websocket.PingMessage: + _ = c.conn.WriteMessage(websocket.PongMessage, nil) + continue + case websocket.BinaryMessage: + // protobuf encoded + err := proto.Unmarshal(payload, msg) + return msg, err + default: + return nil, fmt.Errorf("unexpected message received: %v", messageType) + } + } +} + +func (c *RTCClient) SubscribedTracks() map[livekit.ParticipantID][]*webrtc.TrackRemote { + // create a copy of this + c.lock.Lock() + defer c.lock.Unlock() + tracks := make(map[livekit.ParticipantID][]*webrtc.TrackRemote, len(c.subscribedTracks)) + maps.Copy(tracks, c.subscribedTracks) + return tracks +} + +func (c *RTCClient) SubscribedDataTracks() map[livekit.ParticipantID]map[uint16]*DataTrackRemote { + // create a copy of this + c.lock.Lock() + defer c.lock.Unlock() + tracks := make(map[livekit.ParticipantID]map[uint16]*DataTrackRemote, len(c.subscribedDataTracks)) + for publisherID, sts := range c.subscribedDataTracks { + tracks[publisherID] = make(map[uint16]*DataTrackRemote) + maps.Copy(tracks[publisherID], sts) + } + return tracks +} + +func (c *RTCClient) RemoteParticipants() []*livekit.ParticipantInfo { + c.lock.Lock() + defer c.lock.Unlock() + return funk.Values(c.remoteParticipants).([]*livekit.ParticipantInfo) +} + +func (c *RTCClient) GetRemoteParticipant(sid livekit.ParticipantID) *livekit.ParticipantInfo { + c.lock.Lock() + defer c.lock.Unlock() + return c.remoteParticipants[sid] +} + +func (c *RTCClient) Stop() { + logger.Infow("stopping client", "ID", c.ID()) + _ = c.SendRequest(&livekit.SignalRequest{ + Message: &livekit.SignalRequest_Leave{ + Leave: &livekit.LeaveRequest{ + Reason: livekit.DisconnectReason_CLIENT_INITIATED, + Action: livekit.LeaveRequest_DISCONNECT, + }, + }, + }) + c.publisherFullyEstablished.Store(false) + c.subscriberFullyEstablished.Store(false) + _ = c.conn.Close() + if c.publisher != nil { + c.publisher.Close() + } + if c.subscriber != nil { + c.subscriber.Close() + } + c.cancel() +} + +func (c *RTCClient) RefreshToken() string { + c.lock.Lock() + defer c.lock.Unlock() + return c.refreshToken +} + +func (c *RTCClient) PongReceivedAt() int64 { + return c.pongReceivedAt.Load() +} + +func (c *RTCClient) GetSubscriptionResponseAndClear() *livekit.SubscriptionResponse { + return c.subscriptionResponse.Swap(nil) +} + +func (c *RTCClient) SendPing() error { + return c.SendRequest(&livekit.SignalRequest{ + Message: &livekit.SignalRequest_Ping{ + Ping: time.Now().UnixNano(), + }, + }) +} + +func (c *RTCClient) SendRequest(msg *livekit.SignalRequest) error { + if c.signalRequestInterceptor != nil { + return c.signalRequestInterceptor(msg, c.sendRequest) + } else { + return c.sendRequest(msg) + } +} + +func (c *RTCClient) sendRequest(msg *livekit.SignalRequest) error { + payload, err := proto.Marshal(msg) + if err != nil { + return err + } + + c.wsLock.Lock() + defer c.wsLock.Unlock() + return c.conn.WriteMessage(websocket.BinaryMessage, payload) +} + +func (c *RTCClient) SendIceCandidate(ic *webrtc.ICECandidate, target livekit.SignalTarget) error { + prevIC := c.icQueue[target].Swap(ic) + if prevIC == nil { + return nil + } + + return c.SendRequest(&livekit.SignalRequest{ + Message: &livekit.SignalRequest_Trickle{ + Trickle: signalling.ToProtoTrickle(prevIC.ToJSON(), target, ic == nil), + }, + }) +} + +func (c *RTCClient) SetAttributes(attrs map[string]string) error { + return c.SendRequest(&livekit.SignalRequest{ + Message: &livekit.SignalRequest_UpdateMetadata{ + UpdateMetadata: &livekit.UpdateParticipantMetadata{ + Attributes: attrs, + }, + }, + }) +} + +func (c *RTCClient) hasPrimaryEverConnected() bool { + if c.subscriberAsPrimary.Load() { + return c.subscriber.HasEverConnected() + } else { + return c.publisher.HasEverConnected() + } +} + +type AddTrackParams struct { + NoWriter bool +} + +type AddTrackOption func(params *AddTrackParams) + +func AddTrackNoWriter() AddTrackOption { + return func(params *AddTrackParams) { + params.NoWriter = true + } +} + +func (c *RTCClient) AddTrack(track *webrtc.TrackLocalStaticSample, path string, opts ...AddTrackOption) (writer TrackWriter, err error) { + var params AddTrackParams + for _, opt := range opts { + opt(¶ms) + } + trackType := livekit.TrackType_AUDIO + if track.Kind() == webrtc.RTPCodecTypeVideo { + trackType = livekit.TrackType_VIDEO + } + + sender, _, err := c.publisher.AddTrack(track, types.AddTrackParams{}, nil, rtc.RTCPFeedbackConfig{}) + if err != nil { + logger.Errorw( + "add track failed", err, + "participant", c.localParticipant.Identity, + "pID", c.localParticipant.Sid, + "trackID", track.ID(), + ) + return + } + + if err = c.SendAddTrack(track.ID(), track.Codec().MimeType, track.StreamID(), trackType); err != nil { + return + } + + // wait till track published message is received + timeout := time.After(5 * time.Second) + var ti *livekit.TrackInfo + for { + select { + case <-timeout: + return nil, errors.New("could not publish track after timeout") + default: + c.lock.Lock() + ti = c.pendingPublishedTracks[track.ID()] + if ti != nil { + delete(c.pendingPublishedTracks, track.ID()) + c.lock.Unlock() + break + } + c.lock.Unlock() + time.Sleep(50 * time.Millisecond) + } + if ti != nil { + break + } + } + + c.lock.Lock() + defer c.lock.Unlock() + + c.localTracks[ti.Sid] = track + c.trackSenders[ti.Sid] = sender + c.publisher.Negotiate(false) + + if !params.NoWriter { + writer = NewTrackWriter(c.ctx, track, path) + + // write tracks only after connection established + if c.hasPrimaryEverConnected() { + err = writer.Start() + } else { + c.pendingTrackWriters = append(c.pendingTrackWriters, writer) + } + } + + return +} + +func (c *RTCClient) AddStaticTrack(mime string, id string, label string, opts ...AddTrackOption) (writer TrackWriter, err error) { + return c.AddStaticTrackWithCodec(webrtc.RTPCodecCapability{MimeType: mime}, id, label, opts...) +} + +func (c *RTCClient) AddStaticTrackWithCodec(codec webrtc.RTPCodecCapability, id string, label string, opts ...AddTrackOption) (writer TrackWriter, err error) { + track, err := webrtc.NewTrackLocalStaticSample(codec, id, label) + if err != nil { + return + } + + return c.AddTrack(track, "", opts...) +} + +func (c *RTCClient) AddFileTrack(path string, id string, label string) (writer TrackWriter, err error) { + // determine file mime + mime, ok := extMimeMapping[filepath.Ext(path)] + if !ok { + return nil, fmt.Errorf("%s has an unsupported extension", filepath.Base(path)) + } + + logger.Debugw("adding file track", "mime", mime) + + track, err := webrtc.NewTrackLocalStaticSample( + webrtc.RTPCodecCapability{MimeType: mime}, + id, + label, + ) + if err != nil { + return + } + + return c.AddTrack(track, path) +} + +// send AddTrack command to server to initiate server-side negotiation +func (c *RTCClient) SendAddTrack(cid string, mimeType string, name string, trackType livekit.TrackType) error { + return c.SendRequest(&livekit.SignalRequest{ + Message: &livekit.SignalRequest_AddTrack{ + AddTrack: &livekit.AddTrackRequest{ + Cid: cid, + Name: name, + Type: trackType, + SimulcastCodecs: []*livekit.SimulcastCodec{ + { + Cid: cid, + Codec: mimeType, + }, + }, + }, + }, + }) +} + +func (c *RTCClient) PublishData(data []byte, kind livekit.DataPacket_Kind) error { + if err := c.ensurePublisherConnected(); err != nil { + return err + } + + dpData, err := proto.Marshal(&livekit.DataPacket{ + Value: &livekit.DataPacket_User{ + User: &livekit.UserPacket{Payload: data}, + }, + }) + if err != nil { + return err + } + + return c.publisher.SendDataMessage(kind, dpData) +} + +func (c *RTCClient) PublishDataUnlabeled(data []byte) error { + if err := c.ensurePublisherConnected(); err != nil { + return err + } + + return c.publisher.SendDataMessageUnlabeled(data, true, "test") +} + +func (c *RTCClient) PublishDataTrack() (writer TrackWriter, err error) { + if err = c.ensurePublisherConnected(); err != nil { + return + } + + dataTrackHandle := uint16(c.nextDataTrackHandle.Inc()) + if err = c.SendRequest(&livekit.SignalRequest{ + Message: &livekit.SignalRequest_PublishDataTrackRequest{ + PublishDataTrackRequest: &livekit.PublishDataTrackRequest{ + PubHandle: uint32(dataTrackHandle), + Name: fmt.Sprintf("data_track_%d", dataTrackHandle), + }, + }, + }); err != nil { + return + } + + // wait till data track published message is received + timeout := time.After(5 * time.Second) + var dti *livekit.DataTrackInfo + for { + select { + case <-timeout: + return nil, errors.New("could not publish data track after timeout") + default: + c.lock.Lock() + dti = c.pendingPublishedDataTracks[dataTrackHandle] + if dti != nil { + delete(c.pendingPublishedDataTracks, dataTrackHandle) + c.lock.Unlock() + break + } + c.lock.Unlock() + time.Sleep(50 * time.Millisecond) + } + if dti != nil { + break + } + } + + c.lock.Lock() + defer c.lock.Unlock() + + writer = NewDataTrackWriter(c.ctx, dataTrackHandle, c.publisher) + + // write data tracks only after connection established + if c.hasPrimaryEverConnected() { + err = writer.Start() + } else { + c.pendingDataTrackWriters = append(c.pendingDataTrackWriters, writer) + } + return +} + +func (c *RTCClient) GetPublishedTrackIDs() []string { + c.lock.Lock() + defer c.lock.Unlock() + var trackIDs []string + for key := range c.localTracks { + trackIDs = append(trackIDs, key) + } + return trackIDs +} + +// LastAnswer return SDP of the last answer for the publisher connection +func (c *RTCClient) LastAnswer() *webrtc.SessionDescription { + return c.publisher.CurrentRemoteDescription() +} + +func (c *RTCClient) ensurePublisherConnected() error { + if c.publisher.HasEverConnected() { + return nil + } + + // start negotiating + c.publisher.Negotiate(false) + + // wait until connected, increase wait time since it takes more than 10s sometimes on GH + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + for { + select { + case <-ctx.Done(): + return fmt.Errorf("could not connect publisher after timeout") + case <-time.After(10 * time.Millisecond): + if c.publisherFullyEstablished.Load() { + return nil + } + } + } +} + +func (c *RTCClient) handleDataMessage(kind livekit.DataPacket_Kind, data []byte) { + dp := &livekit.DataPacket{} + err := proto.Unmarshal(data, dp) + if err != nil { + return + } + dp.Kind = kind + if val, ok := dp.Value.(*livekit.DataPacket_User); ok { + if c.OnDataReceived != nil { + c.OnDataReceived(val.User.Payload, val.User.ParticipantSid) + } + } +} + +func (c *RTCClient) handleDataMessageUnlabeled(data []byte) { + if c.OnDataUnlabeledReceived != nil { + c.OnDataUnlabeledReceived(data) + } +} + +func (c *RTCClient) handleDataTrackMessage(data []byte, _arrivalTime int64) { + var packet datatrack.Packet + if err := packet.Unmarshal(data); err != nil { + return + } + + var dataTrackRemote *DataTrackRemote + c.lock.Lock() + for _, tracks := range c.subscribedDataTracks { + if tracks[packet.Handle] != nil { + dataTrackRemote = tracks[packet.Handle] + break + } + } + c.lock.Unlock() + + if dataTrackRemote != nil { + dataTrackRemote.PacketReceived(&packet) + } +} + +// handles a server initiated offer, handle on subscriber PC +func (c *RTCClient) handleOffer(desc webrtc.SessionDescription, offerId uint32, _midToTrackID map[string]string) { + logger.Infow("handling server offer", "participant", c.localParticipant.Identity) + c.subscriber.HandleRemoteDescription(desc, offerId) + c.processPendingRemoteTracks() +} + +// the client handles answer on the publisher PC +func (c *RTCClient) handleAnswer(desc webrtc.SessionDescription, answerId uint32, _midToTrackID map[string]string) { + logger.Infow("handling server answer", "participant", c.localParticipant.Identity) + + // remote answered the offer, establish connection + c.publisher.HandleRemoteDescription(desc, answerId) + c.processPendingRemoteTracks() +} + +// the client handles media sections requirement on the publisher PC +func (c *RTCClient) handleMediaSectionsRequirement(mediaSectionsRequirement *livekit.MediaSectionsRequirement) { + addTransceivers := func(kind webrtc.RTPCodecType, count uint32) { + for range count { + if _, err := c.publisher.AddTransceiverFromKind( + kind, + webrtc.RTPTransceiverInit{ + Direction: webrtc.RTPTransceiverDirectionRecvonly, + }, + ); err != nil { + logger.Warnw( + "could not add transceiver", err, + "participant", c.localParticipant.Identity, + "kind", kind, + ) + } else { + logger.Infow( + "added transceiver of kind", + "participant", c.localParticipant.Identity, + "kind", kind, + ) + } + } + } + + addTransceivers(webrtc.RTPCodecTypeAudio, mediaSectionsRequirement.NumAudios) + addTransceivers(webrtc.RTPCodecTypeVideo, mediaSectionsRequirement.NumVideos) + c.publisher.Negotiate(false) +} + +func (c *RTCClient) onOffer(offer webrtc.SessionDescription, offerId uint32, midToTrackID map[string]string) error { + if c.localParticipant != nil { + logger.Infow("starting negotiation", "participant", c.localParticipant.Identity) + logger.Infow( + "sending publisher offer", + "participant", c.localParticipant.Identity, + "offer", offer, + "midToTrackID", midToTrackID, + ) + } + return c.SendRequest(&livekit.SignalRequest{ + Message: &livekit.SignalRequest_Offer{ + Offer: signalling.ToProtoSessionDescription(offer, offerId, nil), + }, + }) +} + +func (c *RTCClient) processPendingRemoteTracks() { + c.lock.Lock() + pendingRemoteTracks := c.pendingRemoteTracks + c.pendingRemoteTracks = nil + c.lock.Unlock() + + for _, pendingRemoteTrack := range pendingRemoteTracks { + go c.processRemoteTrack(pendingRemoteTrack) + } +} + +func (c *RTCClient) processRemoteTrack(track *webrtc.TrackRemote) { + lastUpdate := time.Time{} + + // because of FireOnTrackBySdp, it is possible get an empty streamID + // if media comes before SDP, cache and try later + streamID := track.StreamID() + if streamID == "" { + logger.Infow( + "client caching track", + "participant", c.localParticipant.Identity, + "pID", c.ID(), + "codec", track.Codec(), + "ssrc", track.SSRC(), + ) + c.lock.Lock() + c.pendingRemoteTracks = append(c.pendingRemoteTracks, track) + c.lock.Unlock() + return + } + + publisherID, trackID := rtc.UnpackStreamID(streamID) + if trackID == "" { + trackID = livekit.TrackID(track.ID()) + } + c.lock.Lock() + c.subscribedTracks[publisherID] = append(c.subscribedTracks[publisherID], track) + c.lock.Unlock() + + logger.Infow( + "client added track", + "participant", c.localParticipant.Identity, + "pID", c.ID(), + "publisherID", publisherID, + "trackID", trackID, + "codec", track.Codec(), + "ssrc", track.SSRC(), + ) + + defer func() { + c.lock.Lock() + c.subscribedTracks[publisherID] = funk.Without(c.subscribedTracks[publisherID], track).([]*webrtc.TrackRemote) + c.lock.Unlock() + }() + + numBytes := 0 + for { + pkt, _, err := track.ReadRTP() + if c.ctx.Err() != nil { + break + } + if rtc.IsEOF(err) { + logger.Infow( + "client track removed", + "participant", c.localParticipant.Identity, + "pID", c.ID(), + "publisherID", publisherID, + "trackID", trackID, + "codec", track.Codec(), + "ssrc", track.SSRC(), + ) + break + } + if err != nil { + logger.Warnw("error reading RTP", err) + continue + } + c.lock.Lock() + c.lastPackets[publisherID] = pkt + c.bytesReceived[publisherID] += uint64(pkt.MarshalSize()) + c.lock.Unlock() + numBytes += pkt.MarshalSize() + if time.Since(lastUpdate) > 30*time.Second { + logger.Infow( + "consumed from participant", + "participant", c.localParticipant.Identity, + "pID", c.ID(), + "publisherID", publisherID, + "trackID", trackID, + "size", numBytes, + ) + lastUpdate = time.Now() + } + } +} + +func (c *RTCClient) BytesReceived() uint64 { + var total uint64 + c.lock.Lock() + for _, size := range c.bytesReceived { + total += size + } + c.lock.Unlock() + return total +} + +func (c *RTCClient) SendNacks(count int) { + var packets []rtcp.Packet + c.lock.Lock() + for _, pkt := range c.lastPackets { + seqs := make([]uint16, 0, count) + for i := range count { + seqs = append(seqs, pkt.SequenceNumber-uint16(i)) + } + packets = append(packets, &rtcp.TransportLayerNack{ + MediaSSRC: pkt.SSRC, + Nacks: rtcp.NackPairsFromSequenceNumbers(seqs), + }) + } + c.lock.Unlock() + + _ = c.subscriber.WriteRTCP(packets) +} + +func encodeQueryParam(key, value string) string { + return fmt.Sprintf("&%s=%s", url.QueryEscape(key), url.QueryEscape(value)) +} diff --git a/livekit/test/client/datachannel_reader.go b/livekit/test/client/datachannel_reader.go new file mode 100644 index 0000000..eb3b642 --- /dev/null +++ b/livekit/test/client/datachannel_reader.go @@ -0,0 +1,31 @@ +package client + +import ( + "time" + + "github.com/livekit/livekit-server/pkg/sfu/datachannel" +) + +type DataChannelReader struct { + bitrate *datachannel.BitrateCalculator + target int +} + +func NewDataChannelReader(bitrate int) *DataChannelReader { + return &DataChannelReader{ + target: bitrate, + bitrate: datachannel.NewBitrateCalculator(datachannel.BitrateDuration*5, datachannel.BitrateWindow), + } +} + +func (d *DataChannelReader) Read(p []byte, sid string) { + for { + if bitrate, ok := d.bitrate.ForceBitrate(time.Now()); ok && bitrate > 0 && bitrate > d.target { + time.Sleep(10 * time.Millisecond) + d.bitrate.AddBytes(0, 0, time.Now()) + continue + } + break + } + d.bitrate.AddBytes(len(p), 0, time.Now()) +} diff --git a/livekit/test/client/datatrack_remote.go b/livekit/test/client/datatrack_remote.go new file mode 100644 index 0000000..3e9b1df --- /dev/null +++ b/livekit/test/client/datatrack_remote.go @@ -0,0 +1,86 @@ +package client + +import ( + "github.com/frostbyte73/core" + "github.com/livekit/livekit-server/pkg/rtc/datatrack" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "go.uber.org/atomic" +) + +type DataTrackRemote struct { + publisherIdentity livekit.ParticipantIdentity + publisherID livekit.ParticipantID + handle uint16 + trackID livekit.TrackID + logger logger.Logger + numReceivedPackets atomic.Uint32 + + closed core.Fuse +} + +func NewDataTrackRemote( + publisherIdentity livekit.ParticipantIdentity, + publisherID livekit.ParticipantID, + handle uint16, + trackID livekit.TrackID, + logger logger.Logger, +) *DataTrackRemote { + logger.Infow( + "creating data track remote", + "publisherIdentity", publisherIdentity, + "publisherID", publisherID, + "handle", handle, + "trackID", trackID, + ) + return &DataTrackRemote{ + publisherIdentity: publisherIdentity, + publisherID: publisherID, + handle: handle, + trackID: trackID, + logger: logger, + } +} + +func (d *DataTrackRemote) Close() { + d.logger.Infow( + "closing data track remote", + "publisherIdentity", d.publisherIdentity, + "publisherID", d.publisherID, + "handle", d.handle, + "trackID", d.trackID, + ) + d.closed.Break() +} + +func (d *DataTrackRemote) Handle() uint16 { + return d.handle +} + +func (d *DataTrackRemote) ID() livekit.TrackID { + return d.trackID +} + +func (d *DataTrackRemote) PacketReceived(packet *datatrack.Packet) { + if d.closed.IsBroken() { + return + } + + valid := true + if len(packet.Payload) == 0 { + valid = false + } + for i := range packet.Payload { + if packet.Payload[i] != byte(255-i) { + valid = false + break + } + } + if valid { + d.numReceivedPackets.Inc() + } +} + +func (d *DataTrackRemote) NumReceivedPackets() uint32 { + return d.numReceivedPackets.Load() +} diff --git a/livekit/test/client/datatrack_writer.go b/livekit/test/client/datatrack_writer.go new file mode 100644 index 0000000..82c6a18 --- /dev/null +++ b/livekit/test/client/datatrack_writer.go @@ -0,0 +1,79 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package client + +import ( + "context" + "math/rand" + "time" + + "github.com/livekit/livekit-server/pkg/rtc/datatrack" + "github.com/livekit/livekit-server/pkg/rtc/types" + "github.com/livekit/protocol/logger" +) + +type dataTrackWriter struct { + ctx context.Context + cancel context.CancelFunc + handle uint16 + transport types.DataTrackTransport +} + +func NewDataTrackWriter(ctx context.Context, handle uint16, transport types.DataTrackTransport) TrackWriter { + ctx, cancel := context.WithCancel(ctx) + return &dataTrackWriter{ + ctx: ctx, + cancel: cancel, + handle: handle, + transport: transport, + } +} + +func (d *dataTrackWriter) Start() error { + go d.writeFrames() + return nil +} + +func (d *dataTrackWriter) Stop() { + d.cancel() +} + +func (d *dataTrackWriter) writeFrames() { + seqNum := uint16(0) + frameNum := uint16(0) + for { + select { + case <-d.ctx.Done(): + return + + default: + packets := datatrack.GenerateRawDataPackets(d.handle, seqNum, frameNum, 1, rand.Intn(2048)+1, 100*time.Millisecond) + for _, packet := range packets { + if err := d.transport.SendDataTrackMessage(packet); err != nil { + logger.Errorw("could not send data track packet", err) + } + } + + if len(packets) != 0 { + var lastPacket datatrack.Packet + if err := lastPacket.Unmarshal(packets[len(packets)-1]); err == nil { + seqNum = lastPacket.SequenceNumber + 1 + frameNum = lastPacket.FrameNumber + 1 + } + } + time.Sleep(100 * time.Millisecond) + } + } +} diff --git a/livekit/test/client/trackwriter.go b/livekit/test/client/trackwriter.go new file mode 100644 index 0000000..eae2da4 --- /dev/null +++ b/livekit/test/client/trackwriter.go @@ -0,0 +1,191 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package client + +import ( + "context" + "io" + "os" + "time" + + "github.com/pion/webrtc/v4" + "github.com/pion/webrtc/v4/pkg/media" + "github.com/pion/webrtc/v4/pkg/media/h264reader" + "github.com/pion/webrtc/v4/pkg/media/ivfreader" + "github.com/pion/webrtc/v4/pkg/media/oggreader" + + "github.com/livekit/livekit-server/pkg/sfu/mime" + "github.com/livekit/protocol/logger" +) + +type TrackWriter interface { + Start() error + Stop() +} + +// Writes a file to an RTP track. +// makes it easier to debug and create RTP streams +type trackWriter struct { + ctx context.Context + cancel context.CancelFunc + track *webrtc.TrackLocalStaticSample + filePath string + mime mime.MimeType + + ogg *oggreader.OggReader + ivfheader *ivfreader.IVFFileHeader + ivf *ivfreader.IVFReader + h264 *h264reader.H264Reader +} + +func NewTrackWriter(ctx context.Context, track *webrtc.TrackLocalStaticSample, filePath string) TrackWriter { + ctx, cancel := context.WithCancel(ctx) + return &trackWriter{ + ctx: ctx, + cancel: cancel, + track: track, + filePath: filePath, + mime: mime.NormalizeMimeType(track.Codec().MimeType), + } +} + +func (w *trackWriter) Start() error { + if w.filePath == "" { + go w.writeNull() + return nil + } + + file, err := os.Open(w.filePath) + if err != nil { + return err + } + + logger.Debugw( + "starting track writer", + "trackID", w.track.ID(), + "mime", w.mime, + ) + switch w.mime { + case mime.MimeTypeOpus: + w.ogg, _, err = oggreader.NewWith(file) + if err != nil { + return err + } + go w.writeOgg() + case mime.MimeTypeVP8: + w.ivf, w.ivfheader, err = ivfreader.NewWith(file) + if err != nil { + return err + } + go w.writeVP8() + case mime.MimeTypeH264: + w.h264, err = h264reader.NewReader(file) + if err != nil { + return err + } + go w.writeH264() + } + return nil +} + +func (w *trackWriter) Stop() { + w.cancel() +} + +func (w *trackWriter) writeNull() { + defer w.onWriteComplete() + sample := media.Sample{Data: []byte{0x0, 0xff, 0xff, 0xff, 0xff}, Duration: 30 * time.Millisecond} + h264Sample := media.Sample{Data: []byte{0x00, 0x00, 0x00, 0x01, 0x7, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x8, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x5, 0xff, 0xff, 0xff, 0xff}, Duration: 30 * time.Millisecond} + for { + select { + case <-time.After(20 * time.Millisecond): + if w.mime == mime.MimeTypeH264 { + w.track.WriteSample(h264Sample) + } else { + w.track.WriteSample(sample) + } + case <-w.ctx.Done(): + return + } + } +} + +func (w *trackWriter) writeOgg() { + // Keep track of last granule, the difference is the amount of samples in the buffer + var lastGranule uint64 + for { + if w.ctx.Err() != nil { + return + } + pageData, pageHeader, err := w.ogg.ParseNextPage() + if err == io.EOF { + logger.Debugw("all audio samples parsed and sent") + w.onWriteComplete() + return + } + + if err != nil { + logger.Errorw("could not parse ogg page", err) + return + } + + // The amount of samples is the difference between the last and current timestamp + sampleCount := float64(pageHeader.GranulePosition - lastGranule) + lastGranule = pageHeader.GranulePosition + sampleDuration := time.Duration((sampleCount/48000)*1000) * time.Millisecond + + if err = w.track.WriteSample(media.Sample{Data: pageData, Duration: sampleDuration}); err != nil { + logger.Errorw("could not write sample", err) + return + } + + time.Sleep(sampleDuration) + } +} + +func (w *trackWriter) writeVP8() { + // Send our video file frame at a time. Pace our sending such that we send it at the same speed it should be played back as. + // This isn't required since the video is timestamped, but we will such much higher loss if we send all at once. + sleepTime := time.Millisecond * time.Duration((float32(w.ivfheader.TimebaseNumerator)/float32(w.ivfheader.TimebaseDenominator))*1000) + for { + if w.ctx.Err() != nil { + return + } + frame, _, err := w.ivf.ParseNextFrame() + if err == io.EOF { + logger.Debugw("all video frames parsed and sent") + w.onWriteComplete() + return + } + + if err != nil { + logger.Errorw("could not parse VP8 frame", err) + return + } + + time.Sleep(sleepTime) + if err = w.track.WriteSample(media.Sample{Data: frame, Duration: time.Second}); err != nil { + logger.Errorw("could not write sample", err) + return + } + } +} + +func (w *trackWriter) writeH264() { + // TODO: this is harder +} + +func (w *trackWriter) onWriteComplete() { +} diff --git a/livekit/test/integration_helpers.go b/livekit/test/integration_helpers.go new file mode 100644 index 0000000..1011214 --- /dev/null +++ b/livekit/test/integration_helpers.go @@ -0,0 +1,361 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +import ( + "context" + "fmt" + "net/http" + "sync" + "testing" + "time" + + "github.com/redis/go-redis/v9" + "github.com/twitchtv/twirp" + + "github.com/livekit/mediatransportutil/pkg/rtcconfig" + "github.com/livekit/protocol/auth" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/utils/guid" + + "github.com/livekit/livekit-server/pkg/config" + "github.com/livekit/livekit-server/pkg/routing" + "github.com/livekit/livekit-server/pkg/service" + "github.com/livekit/livekit-server/pkg/telemetry/prometheus" + "github.com/livekit/livekit-server/pkg/testutils" + testclient "github.com/livekit/livekit-server/test/client" +) + +const ( + testApiKey = "apikey" + testApiSecret = "apiSecretExtendTo32BytesAsThatIsMinimum" + testRoom = "mytestroom" + defaultServerPort = 7880 + secondServerPort = 8880 + nodeID1 = "node-1" + nodeID2 = "node-2" + + syncDelay = 100 * time.Millisecond + // if there are deadlocks, it's helpful to set a short test timeout (i.e. go test -timeout=30s) + // let connection timeout happen + // connectTimeout = 5000 * time.Second +) + +var roomClient livekit.RoomService + +func init() { + config.InitLoggerFromConfig(&config.DefaultConfig.Logging) + + prometheus.Init("test", livekit.NodeType_SERVER) +} + +func setupSingleNodeTest(name string) (*service.LivekitServer, func()) { + logger.Infow("----------------STARTING TEST----------------", "test", name) + s := createSingleNodeServer(nil) + go func() { + if err := s.Start(); err != nil { + logger.Errorw("server returned error", err) + } + }() + + waitForServerToStart(s) + + return s, func() { + s.Stop(true) + logger.Infow("----------------FINISHING TEST----------------", "test", name) + } +} + +func setupMultiNodeTest(name string) (*service.LivekitServer, *service.LivekitServer, func()) { + logger.Infow("----------------STARTING TEST----------------", "test", name) + s1 := createMultiNodeServer(guid.New(nodeID1), defaultServerPort) + s2 := createMultiNodeServer(guid.New(nodeID2), secondServerPort) + go s1.Start() + go s2.Start() + + waitForServerToStart(s1) + waitForServerToStart(s2) + + return s1, s2, func() { + s1.Stop(true) + s2.Stop(true) + redisClient().FlushAll(context.Background()) + logger.Infow("----------------FINISHING TEST----------------", "test", name) + } +} + +func contextWithToken(token string) context.Context { + header := make(http.Header) + testclient.SetAuthorizationToken(header, token) + tctx, err := twirp.WithHTTPRequestHeaders(context.Background(), header) + if err != nil { + panic(err) + } + return tctx +} + +func waitForServerToStart(s *service.LivekitServer) { + // wait till ready + ctx, cancel := context.WithTimeout(context.Background(), testutils.ConnectTimeout) + defer cancel() + for { + select { + case <-ctx.Done(): + panic("could not start server after timeout") + case <-time.After(10 * time.Millisecond): + if s.IsRunning() { + // ensure we can connect to it + res, err := http.Get(fmt.Sprintf("http://localhost:%d", s.HTTPPort())) + if err == nil && res.StatusCode == http.StatusOK { + return + } + } + } + } +} + +func waitUntilConnected(t *testing.T, clients ...*testclient.RTCClient) { + logger.Infow("waiting for clients to become connected") + wg := sync.WaitGroup{} + for i := range clients { + c := clients[i] + wg.Add(1) + go func() { + defer wg.Done() + err := c.WaitUntilConnected() + if err != nil { + t.Error(err) + } + }() + } + wg.Wait() + if t.Failed() { + t.FailNow() + } +} + +func createSingleNodeServer(configUpdater func(*config.Config)) *service.LivekitServer { + var err error + conf, err := config.NewConfig("", true, nil, nil) + if err != nil { + panic(fmt.Sprintf("could not create config: %v", err)) + } + conf.Keys = map[string]string{testApiKey: testApiSecret} + conf.EnableDataTracks = true + if configUpdater != nil { + configUpdater(conf) + } + + currentNode, err := routing.NewLocalNode(conf) + if err != nil { + panic(fmt.Sprintf("could not create local node: %v", err)) + } + currentNode.SetNodeID(livekit.NodeID(guid.New(nodeID1))) + + s, err := service.InitializeServer(conf, currentNode) + if err != nil { + panic(fmt.Sprintf("could not create server: %v", err)) + } + + roomClient = livekit.NewRoomServiceJSONClient(fmt.Sprintf("http://localhost:%d", defaultServerPort), &http.Client{}) + return s +} + +func createMultiNodeServer(nodeID string, port uint32) *service.LivekitServer { + var err error + conf, err := config.NewConfig("", true, nil, nil) + if err != nil { + panic(fmt.Sprintf("could not create config: %v", err)) + } + conf.Port = port + conf.RTC.UDPPort = rtcconfig.PortRange{Start: int(port) + 1} + conf.RTC.TCPPort = port + 2 + conf.Redis.Address = "localhost:6379" + conf.Keys = map[string]string{testApiKey: testApiSecret} + conf.EnableDataTracks = true + + currentNode, err := routing.NewLocalNode(conf) + if err != nil { + panic(err) + } + currentNode.SetNodeID(livekit.NodeID(nodeID)) + + // redis routing and store + s, err := service.InitializeServer(conf, currentNode) + if err != nil { + panic(fmt.Sprintf("could not create server: %v", err)) + } + + roomClient = livekit.NewRoomServiceJSONClient(fmt.Sprintf("http://localhost:%d", port), &http.Client{}) + return s +} + +type testRTCServicePath int + +const ( + testRTCServicePathv0 testRTCServicePath = iota + testRTCServicePathv0SinglePeerConnection + testRTCServicePathv1 +) + +func (t testRTCServicePath) String() string { + switch t { + case testRTCServicePathv0: + return "v0" + case testRTCServicePathv0SinglePeerConnection: + return "v0-single-peer-connection" + case testRTCServicePathv1: + return "v1" + default: + return fmt.Sprintf("unknown: %d", t) + } +} + +var testRTCServicePaths = []testRTCServicePath{ + testRTCServicePathv0, + testRTCServicePathv0SinglePeerConnection, + testRTCServicePathv1, +} + +func testRTCServicePathToTestClientOptions(testRTCServicePath testRTCServicePath, opts *testclient.Options) { + if opts == nil { + return + } + + switch testRTCServicePath { + case testRTCServicePathv0: + opts.RTCServicePath = "/rtc" + case testRTCServicePathv0SinglePeerConnection: + opts.RTCServicePath = "/rtc" + opts.UseJoinRequestQueryParam = true + case testRTCServicePathv1: + opts.RTCServicePath = "/rtc/v1" + opts.UseJoinRequestQueryParam = true + default: + opts.RTCServicePath = "/rtc" + } +} + +// creates a client and runs against server +func createRTCClient(name string, port int, testRTCServicePath testRTCServicePath, opts *testclient.Options) *testclient.RTCClient { + var customizer func(token *auth.AccessToken, grants *auth.VideoGrant) + if opts != nil { + customizer = opts.TokenCustomizer + } + token := joinToken(testRoom, name, customizer) + + return createRTCClientWithToken(token, port, testRTCServicePath, opts) +} + +// creates a client and runs against server +func createRTCClientWithToken(token string, port int, testRTCServicePath testRTCServicePath, opts *testclient.Options) *testclient.RTCClient { + if opts == nil { + opts = &testclient.Options{ + AutoSubscribe: true, + } + } + testRTCServicePathToTestClientOptions(testRTCServicePath, opts) + ws, err := testclient.NewWebSocketConn(fmt.Sprintf("ws://localhost:%d", port), token, opts) + if err != nil { + panic(err) + } + + c, err := testclient.NewRTCClient(ws, opts.UseJoinRequestQueryParam, opts) + if err != nil { + panic(err) + } + + go c.Run() + + return c +} + +func redisClient() *redis.Client { + return redis.NewClient(&redis.Options{ + Addr: "localhost:6379", + }) +} + +func joinToken(room, name string, customFn func(token *auth.AccessToken, grants *auth.VideoGrant)) string { + at := auth.NewAccessToken(testApiKey, testApiSecret). + SetIdentity(name). + SetName(name). + SetMetadata("metadata" + name) + grant := &auth.VideoGrant{RoomJoin: true, Room: room} + if customFn != nil { + customFn(at, grant) + } + at.AddGrant(grant) + t, err := at.ToJWT() + if err != nil { + panic(err) + } + return t +} + +func joinTokenWithGrant(name string, grant *auth.VideoGrant) string { + at := auth.NewAccessToken(testApiKey, testApiSecret). + AddGrant(grant). + SetIdentity(name). + SetName(name) + t, err := at.ToJWT() + if err != nil { + panic(err) + } + return t +} + +func createRoomToken() string { + at := auth.NewAccessToken(testApiKey, testApiSecret). + AddGrant(&auth.VideoGrant{RoomCreate: true}) + t, err := at.ToJWT() + if err != nil { + panic(err) + } + return t +} + +func adminRoomToken(name string) string { + at := auth.NewAccessToken(testApiKey, testApiSecret). + AddGrant(&auth.VideoGrant{RoomAdmin: true, Room: name}) + t, err := at.ToJWT() + if err != nil { + panic(err) + } + return t +} + +func listRoomToken() string { + at := auth.NewAccessToken(testApiKey, testApiSecret). + AddGrant(&auth.VideoGrant{RoomList: true}) + t, err := at.ToJWT() + if err != nil { + panic(err) + } + return t +} + +func stopWriters(writers ...testclient.TrackWriter) { + for _, w := range writers { + w.Stop() + } +} + +func stopClients(clients ...*testclient.RTCClient) { + for _, c := range clients { + c.Stop() + } +} diff --git a/livekit/test/multinode_roomservice_test.go b/livekit/test/multinode_roomservice_test.go new file mode 100644 index 0000000..de74572 --- /dev/null +++ b/livekit/test/multinode_roomservice_test.go @@ -0,0 +1,189 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/livekit/protocol/livekit" + + "github.com/livekit/livekit-server/pkg/testutils" +) + +func TestMultiNodeRoomList(t *testing.T) { + if testing.Short() { + t.SkipNow() + return + } + _, _, finish := setupMultiNodeTest("TestMultiNodeRoomList") + defer finish() + + roomServiceListRoom(t) +} + +// update room metadata when it's empty +func TestMultiNodeUpdateRoomMetadata(t *testing.T) { + if testing.Short() { + t.SkipNow() + return + } + + t.Run("when room is empty", func(t *testing.T) { + _, _, finish := setupMultiNodeTest("TestMultiNodeUpdateRoomMetadata_empty") + defer finish() + + _, err := roomClient.CreateRoom(contextWithToken(createRoomToken()), &livekit.CreateRoomRequest{ + Name: "emptyRoom", + }) + require.NoError(t, err) + + rm, err := roomClient.UpdateRoomMetadata(contextWithToken(adminRoomToken("emptyRoom")), &livekit.UpdateRoomMetadataRequest{ + Room: "emptyRoom", + Metadata: "updated metadata", + }) + require.NoError(t, err) + require.Equal(t, "updated metadata", rm.Metadata) + }) + + t.Run("when room has a participant", func(t *testing.T) { + for _, testRTCServicePath := range testRTCServicePaths { + t.Run(fmt.Sprintf("testRTCServicePath=%s", testRTCServicePath.String()), func(t *testing.T) { + _, _, finish := setupMultiNodeTest("TestMultiNodeUpdateRoomMetadata_with_participant") + defer finish() + + c1 := createRTCClient("c1", defaultServerPort, testRTCServicePath, nil) + waitUntilConnected(t, c1) + defer c1.Stop() + + _, err := roomClient.CreateRoom(contextWithToken(createRoomToken()), &livekit.CreateRoomRequest{ + Name: "emptyRoom", + }) + require.NoError(t, err) + + rm, err := roomClient.UpdateRoomMetadata(contextWithToken(adminRoomToken("emptyRoom")), &livekit.UpdateRoomMetadataRequest{ + Room: "emptyRoom", + Metadata: "updated metadata", + }) + require.NoError(t, err) + require.Equal(t, "updated metadata", rm.Metadata) + }) + } + }) +} + +// remove a participant +func TestMultiNodeRemoveParticipant(t *testing.T) { + if testing.Short() { + t.SkipNow() + return + } + + for _, testRTCServicePath := range testRTCServicePaths { + t.Run(fmt.Sprintf("testRTCServicePath=%s", testRTCServicePath.String()), func(t *testing.T) { + _, _, finish := setupMultiNodeTest("TestMultiNodeRemoveParticipant") + defer finish() + + c1 := createRTCClient("mn_remove_participant", defaultServerPort, testRTCServicePath, nil) + defer c1.Stop() + waitUntilConnected(t, c1) + + ctx := contextWithToken(adminRoomToken(testRoom)) + _, err := roomClient.RemoveParticipant(ctx, &livekit.RoomParticipantIdentity{ + Room: testRoom, + Identity: "mn_remove_participant", + }) + require.NoError(t, err) + + // participant list doesn't show the participant + listRes, err := roomClient.ListParticipants(ctx, &livekit.ListParticipantsRequest{ + Room: testRoom, + }) + require.NoError(t, err) + require.Len(t, listRes.Participants, 0) + }) + } +} + +// update participant metadata +func TestMultiNodeUpdateParticipantMetadata(t *testing.T) { + for _, testRTCServicePath := range testRTCServicePaths { + t.Run(fmt.Sprintf("testRTCServicePath=%s", testRTCServicePath.String()), func(t *testing.T) { + _, _, finish := setupMultiNodeTest("TestMultiNodeUpdateParticipantMetadata") + defer finish() + + c1 := createRTCClient("update_participant_metadata", defaultServerPort, testRTCServicePath, nil) + defer c1.Stop() + waitUntilConnected(t, c1) + + ctx := contextWithToken(adminRoomToken(testRoom)) + res, err := roomClient.UpdateParticipant(ctx, &livekit.UpdateParticipantRequest{ + Room: testRoom, + Identity: "update_participant_metadata", + Metadata: "the new metadata", + }) + require.NoError(t, err) + require.Equal(t, "the new metadata", res.Metadata) + }) + } +} + +// admin mute published track +func TestMultiNodeMutePublishedTrack(t *testing.T) { + for _, testRTCServicePath := range testRTCServicePaths { + t.Run(fmt.Sprintf("testRTCServicePath=%s", testRTCServicePath.String()), func(t *testing.T) { + _, _, finish := setupMultiNodeTest("TestMultiNodeMutePublishedTrack") + defer finish() + + identity := "mute_published_track" + c1 := createRTCClient(identity, defaultServerPort, testRTCServicePath, nil) + defer c1.Stop() + waitUntilConnected(t, c1) + + writers := publishTracksForClients(t, c1) + defer stopWriters(writers...) + + trackIDs := c1.GetPublishedTrackIDs() + require.NotEmpty(t, trackIDs) + + ctx := contextWithToken(adminRoomToken(testRoom)) + // wait for it to be published before + testutils.WithTimeout(t, func() string { + res, err := roomClient.GetParticipant(ctx, &livekit.RoomParticipantIdentity{ + Room: testRoom, + Identity: identity, + }) + require.NoError(t, err) + if len(res.Tracks) == 2 { + return "" + } else { + return fmt.Sprintf("expected 2 tracks to be published, actual: %d", len(res.Tracks)) + } + }) + + res, err := roomClient.MutePublishedTrack(ctx, &livekit.MuteRoomTrackRequest{ + Room: testRoom, + Identity: identity, + TrackSid: trackIDs[0], + Muted: true, + }) + require.NoError(t, err) + require.Equal(t, trackIDs[0], res.Track.Sid) + require.True(t, res.Track.Muted) + }) + } +} diff --git a/livekit/test/multinode_test.go b/livekit/test/multinode_test.go new file mode 100644 index 0000000..9b4d5ac --- /dev/null +++ b/livekit/test/multinode_test.go @@ -0,0 +1,427 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +import ( + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/livekit/protocol/auth" + "github.com/livekit/protocol/livekit" + + "github.com/livekit/livekit-server/pkg/rtc" + "github.com/livekit/livekit-server/pkg/testutils" + "github.com/livekit/livekit-server/test/client" +) + +func TestMultiNodeRouting(t *testing.T) { + if testing.Short() { + t.SkipNow() + return + } + + _, _, finish := setupMultiNodeTest("TestMultiNodeRouting") + defer finish() + + // creating room on node 1 + _, err := roomClient.CreateRoom(contextWithToken(createRoomToken()), &livekit.CreateRoomRequest{ + Name: testRoom, + }) + require.NoError(t, err) + + for _, testRTCServicePath := range testRTCServicePaths { + t.Run(fmt.Sprintf("testRTCServicePath=%s", testRTCServicePath.String()), func(t *testing.T) { + // one node connecting to node 1, and another connecting to node 2 + c1 := createRTCClient("c1", defaultServerPort, testRTCServicePath, nil) + c2 := createRTCClient("c2", secondServerPort, testRTCServicePath, nil) + waitUntilConnected(t, c1, c2) + defer stopClients(c1, c2) + + // c1 publishing, and c2 receiving + t1, err := c1.AddStaticTrack("audio/opus", "audio", "webcam") + require.NoError(t, err) + if t1 != nil { + defer t1.Stop() + } + + testutils.WithTimeout(t, func() string { + if len(c2.SubscribedTracks()) == 0 { + return "c2 received no tracks" + } + if len(c2.SubscribedTracks()[c1.ID()]) != 1 { + return "c2 didn't receive track published by c1" + } + tr1 := c2.SubscribedTracks()[c1.ID()][0] + streamID, _ := rtc.UnpackStreamID(tr1.StreamID()) + require.Equal(t, c1.ID(), streamID) + return "" + }) + + remoteC1 := c2.GetRemoteParticipant(c1.ID()) + require.Equal(t, "c1", remoteC1.Name) + require.Equal(t, "metadatac1", remoteC1.Metadata) + }) + } +} + +func TestConnectWithoutCreation(t *testing.T) { + if testing.Short() { + t.SkipNow() + return + } + + _, _, finish := setupMultiNodeTest("TestConnectWithoutCreation") + defer finish() + + for _, testRTCServicePath := range testRTCServicePaths { + t.Run(fmt.Sprintf("testRTCServicePath=%s", testRTCServicePath.String()), func(t *testing.T) { + c1 := createRTCClient("c1", defaultServerPort, testRTCServicePath, nil) + waitUntilConnected(t, c1) + + c1.Stop() + }) + } +} + +func TestMultinodePublishingUponJoining(t *testing.T) { + if testing.Short() { + t.SkipNow() + return + } + _, _, finish := setupMultiNodeTest("TestMultinodePublishingUponJoining") + defer finish() + + scenarioPublishingUponJoining(t) +} + +func TestMultinodeReceiveBeforePublish(t *testing.T) { + if testing.Short() { + t.SkipNow() + return + } + _, _, finish := setupMultiNodeTest("TestMultinodeReceiveBeforePublish") + defer finish() + + scenarioReceiveBeforePublish(t) +} + +// reconnecting to the same room, after one of the servers has gone away +func TestMultinodeReconnectAfterNodeShutdown(t *testing.T) { + if testing.Short() { + t.SkipNow() + return + } + + for _, testRTCServicePath := range testRTCServicePaths { + t.Run(fmt.Sprintf("testRTCServicePath=%s", testRTCServicePath.String()), func(t *testing.T) { + _, s2, finish := setupMultiNodeTest("TestMultinodeReconnectAfterNodeShutdown") + defer finish() + + // creating room on node 1 + _, err := roomClient.CreateRoom(contextWithToken(createRoomToken()), &livekit.CreateRoomRequest{ + Name: testRoom, + NodeId: s2.Node().Id, + }) + require.NoError(t, err) + + // one node connecting to node 1, and another connecting to node 2 + c1 := createRTCClient("c1", defaultServerPort, testRTCServicePath, nil) + c2 := createRTCClient("c2", secondServerPort, testRTCServicePath, nil) + + waitUntilConnected(t, c1, c2) + stopClients(c1, c2) + + // stop s2, and connect to room again + s2.Stop(true) + + time.Sleep(syncDelay) + + c3 := createRTCClient("c3", defaultServerPort, testRTCServicePath, nil) + waitUntilConnected(t, c3) + }) + } +} + +func TestMultinodeDataPublishing(t *testing.T) { + if testing.Short() { + t.SkipNow() + return + } + + _, _, finish := setupMultiNodeTest("TestMultinodeDataPublishing") + defer finish() + + scenarioDataPublish(t) + scenarioDataUnlabeledPublish(t) + scenarioDataTracksPublishingUponJoining(t) +} + +func TestMultiNodeJoinAfterClose(t *testing.T) { + if testing.Short() { + t.SkipNow() + return + } + + _, _, finish := setupMultiNodeTest("TestMultiNodeJoinAfterClose") + defer finish() + + scenarioJoinClosedRoom(t) +} + +func TestMultiNodeCloseNonRTCRoom(t *testing.T) { + if testing.Short() { + t.SkipNow() + return + } + + _, _, finish := setupMultiNodeTest("closeNonRTCRoom") + defer finish() + + closeNonRTCRoom(t) +} + +// ensure that token accurately reflects out of band updates +func TestMultiNodeRefreshToken(t *testing.T) { + _, _, finish := setupMultiNodeTest("TestMultiNodeJoinAfterClose") + defer finish() + + for _, testRTCServicePath := range testRTCServicePaths { + t.Run(fmt.Sprintf("testRTCServicePath=%s", testRTCServicePath.String()), func(t *testing.T) { + // a participant joining with full permissions + c1 := createRTCClient("c1", defaultServerPort, testRTCServicePath, nil) + waitUntilConnected(t, c1) + + // update permissions and metadata + ctx := contextWithToken(adminRoomToken(testRoom)) + _, err := roomClient.UpdateParticipant(ctx, &livekit.UpdateParticipantRequest{ + Room: testRoom, + Identity: "c1", + Permission: &livekit.ParticipantPermission{ + CanPublish: false, + CanSubscribe: true, + }, + Metadata: "metadata", + }) + require.NoError(t, err) + + testutils.WithTimeout(t, func() string { + if c1.RefreshToken() == "" { + return "did not receive refresh token" + } + // parse token to ensure it's correct + verifier, err := auth.ParseAPIToken(c1.RefreshToken()) + require.NoError(t, err) + + _, grants, err := verifier.Verify(testApiSecret) + require.NoError(t, err) + + if grants.Metadata != "metadata" { + return "metadata did not match" + } + if *grants.Video.CanPublish { + return "canPublish should be false" + } + if *grants.Video.CanPublishData { + return "canPublishData should be false" + } + if !*grants.Video.CanSubscribe { + return "canSubscribe should be true" + } + return "" + }) + }) + } +} + +// ensure that token accurately reflects out of band updates +func TestMultiNodeUpdateAttributes(t *testing.T) { + if testing.Short() { + t.SkipNow() + return + } + + _, _, finish := setupMultiNodeTest("TestMultiNodeUpdateAttributes") + defer finish() + + for _, testRTCServicePath := range testRTCServicePaths { + t.Run(fmt.Sprintf("testRTCServicePath=%s", testRTCServicePath.String()), func(t *testing.T) { + c1 := createRTCClient("au1", defaultServerPort, testRTCServicePath, &client.Options{ + TokenCustomizer: func(token *auth.AccessToken, grants *auth.VideoGrant) { + token.SetAttributes(map[string]string{ + "mykey": "au1", + }) + }, + }) + c2 := createRTCClient("au2", secondServerPort, testRTCServicePath, &client.Options{ + TokenCustomizer: func(token *auth.AccessToken, grants *auth.VideoGrant) { + token.SetAttributes(map[string]string{ + "mykey": "au2", + }) + grants.SetCanUpdateOwnMetadata(true) + }, + }) + waitUntilConnected(t, c1, c2) + + testutils.WithTimeout(t, func() string { + rc2 := c1.GetRemoteParticipant(c2.ID()) + rc1 := c2.GetRemoteParticipant(c1.ID()) + if rc2 == nil || rc1 == nil { + return "participants could not see each other" + } + if rc1.Attributes == nil || rc1.Attributes["mykey"] != "au1" { + return "rc1's initial attributes are incorrect" + } + if rc2.Attributes == nil || rc2.Attributes["mykey"] != "au2" { + return "rc2's initial attributes are incorrect" + } + return "" + }) + + // this one should not go through + _ = c1.SetAttributes(map[string]string{"mykey": "shouldnotchange"}) + _ = c2.SetAttributes(map[string]string{"secondkey": "au2"}) + + // updates using room API should succeed + _, err := roomClient.UpdateParticipant(contextWithToken(adminRoomToken(testRoom)), &livekit.UpdateParticipantRequest{ + Room: testRoom, + Identity: "au1", + Attributes: map[string]string{ + "secondkey": "au1", + }, + }) + require.NoError(t, err) + + testutils.WithTimeout(t, func() string { + rc1 := c2.GetRemoteParticipant(c1.ID()) + rc2 := c1.GetRemoteParticipant(c2.ID()) + if rc1.Attributes["secondkey"] != "au1" { + return "au1's attribute update failed" + } + if rc2.Attributes["secondkey"] != "au2" { + return "au2's attribute update failed" + } + if rc1.Attributes["mykey"] != "au1" { + return "au1's mykey should not change" + } + if rc2.Attributes["mykey"] != "au2" { + return "au2's mykey should not change" + } + return "" + }) + }) + } +} + +func TestMultiNodeRevokePublishPermission(t *testing.T) { + _, _, finish := setupMultiNodeTest("TestMultiNodeRevokePublishPermission") + defer finish() + + for _, testRTCServicePath := range testRTCServicePaths { + t.Run(fmt.Sprintf("testRTCServicePath=%s", testRTCServicePath.String()), func(t *testing.T) { + c1 := createRTCClient("c1", defaultServerPort, testRTCServicePath, nil) + c2 := createRTCClient("c2", secondServerPort, testRTCServicePath, nil) + waitUntilConnected(t, c1, c2) + + // c1 publishes a track for c2 + writers := publishTracksForClients(t, c1) + defer stopWriters(writers...) + + testutils.WithTimeout(t, func() string { + if len(c2.SubscribedTracks()[c1.ID()]) != 2 { + return "c2 did not receive c1's tracks" + } + return "" + }) + + // revoke permission + ctx := contextWithToken(adminRoomToken(testRoom)) + _, err := roomClient.UpdateParticipant(ctx, &livekit.UpdateParticipantRequest{ + Room: testRoom, + Identity: "c1", + Permission: &livekit.ParticipantPermission{ + CanPublish: false, + CanPublishData: true, + CanSubscribe: true, + }, + }) + require.NoError(t, err) + + // ensure c1 no longer has track published, c2 no longer see track under C1 + testutils.WithTimeout(t, func() string { + if len(c1.GetPublishedTrackIDs()) != 0 { + return "c1 did not unpublish tracks" + } + remoteC1 := c2.GetRemoteParticipant(c1.ID()) + if remoteC1 == nil { + return "c2 doesn't know about c1" + } + if len(remoteC1.Tracks) != 0 { + return "c2 still has c1's tracks" + } + return "" + }) + }) + } +} + +func TestCloseDisconnectedParticipantOnSignalClose(t *testing.T) { + _, _, finish := setupMultiNodeTest("TestCloseDisconnectedParticipantOnSignalClose") + defer finish() + + for _, testRTCServicePath := range testRTCServicePaths { + t.Run(fmt.Sprintf("testRTCServicePath=%s", testRTCServicePath.String()), func(t *testing.T) { + c1 := createRTCClient("c1", secondServerPort, testRTCServicePath, nil) + waitUntilConnected(t, c1) + + c2 := createRTCClient("c2", defaultServerPort, testRTCServicePath, &client.Options{ + SignalRequestInterceptor: func(msg *livekit.SignalRequest, next client.SignalRequestHandler) error { + switch msg.Message.(type) { + case *livekit.SignalRequest_Offer, *livekit.SignalRequest_Answer, *livekit.SignalRequest_Leave: + return nil + default: + return next(msg) + } + }, + SignalResponseInterceptor: func(msg *livekit.SignalResponse, next client.SignalResponseHandler) error { + switch msg.Message.(type) { + case *livekit.SignalResponse_Offer, *livekit.SignalResponse_Answer: + return nil + default: + return next(msg) + } + }, + }) + + testutils.WithTimeout(t, func() string { + if len(c1.RemoteParticipants()) != 1 { + return "c1 did not see c2 join" + } + return "" + }) + + c2.Stop() + + testutils.WithTimeout(t, func() string { + if len(c1.RemoteParticipants()) != 0 { + return "c1 did not see c2 removed" + } + return "" + }) + }) + } +} diff --git a/livekit/test/scenarios.go b/livekit/test/scenarios.go new file mode 100644 index 0000000..aee71fe --- /dev/null +++ b/livekit/test/scenarios.go @@ -0,0 +1,388 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +import ( + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/require" + "go.uber.org/atomic" + + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + + "github.com/livekit/livekit-server/pkg/testutils" + testclient "github.com/livekit/livekit-server/test/client" +) + +// a scenario with lots of clients connecting, publishing, and leaving at random periods +func scenarioPublishingUponJoining(t *testing.T) { + for _, testRTCServicePath := range testRTCServicePaths { + t.Run(fmt.Sprintf("testRTCServicePath=%s", testRTCServicePath.String()), func(t *testing.T) { + c1 := createRTCClient("puj_1", defaultServerPort, testRTCServicePath, nil) + c2 := createRTCClient("puj_2", secondServerPort, testRTCServicePath, &testclient.Options{AutoSubscribe: true}) + c3 := createRTCClient("puj_3", defaultServerPort, testRTCServicePath, &testclient.Options{AutoSubscribe: true}) + defer stopClients(c1, c2, c3) + + waitUntilConnected(t, c1, c2, c3) + + // c1 and c2 publishing, c3 just receiving + writers := publishTracksForClients(t, c1, c2) + defer stopWriters(writers...) + + logger.Infow("waiting to receive tracks from c1 and c2") + testutils.WithTimeout(t, func() string { + tracks := c3.SubscribedTracks() + if len(tracks[c1.ID()]) != 2 { + return "did not receive tracks from c1" + } + if len(tracks[c2.ID()]) != 2 { + return "did not receive tracks from c2" + } + return "" + }) + + // after a delay, c2 reconnects, then publishing + time.Sleep(syncDelay) + c2.Stop() + + logger.Infow("waiting for c2 tracks to be gone") + testutils.WithTimeout(t, func() string { + tracks := c3.SubscribedTracks() + + if len(tracks[c1.ID()]) != 2 { + return fmt.Sprintf("c3 should be subscribed to 2 tracks from c1, actual: %d", len(tracks[c1.ID()])) + } + if len(tracks[c2.ID()]) != 0 { + return fmt.Sprintf("c3 should be subscribed to 0 tracks from c2, actual: %d", len(tracks[c2.ID()])) + } + if len(c1.SubscribedTracks()[c2.ID()]) != 0 { + return fmt.Sprintf("c3 should be subscribed to 0 tracks from c2, actual: %d", len(c1.SubscribedTracks()[c2.ID()])) + } + return "" + }) + + logger.Infow("c2 reconnecting") + // connect to a diff port + c2 = createRTCClient("puj_2", defaultServerPort, testRTCServicePath, nil) + defer c2.Stop() + waitUntilConnected(t, c2) + writers = publishTracksForClients(t, c2) + defer stopWriters(writers...) + + testutils.WithTimeout(t, func() string { + tracks := c3.SubscribedTracks() + // "new c2 tracks should be published again", + if len(tracks[c2.ID()]) != 2 { + return fmt.Sprintf("c3 should be subscribed to 2 tracks from c2, actual: %d", len(tracks[c2.ID()])) + } + if len(c1.SubscribedTracks()[c2.ID()]) != 2 { + return fmt.Sprintf("c1 should be subscribed to 2 tracks from c2, actual: %d", len(c1.SubscribedTracks()[c2.ID()])) + } + return "" + }) + }) + } +} + +func scenarioReceiveBeforePublish(t *testing.T) { + for _, testRTCServicePath := range testRTCServicePaths { + t.Run(fmt.Sprintf("testRTCServicePath=%s", testRTCServicePath.String()), func(t *testing.T) { + c1 := createRTCClient("rbp_1", defaultServerPort, testRTCServicePath, nil) + c2 := createRTCClient("rbp_2", defaultServerPort, testRTCServicePath, nil) + + waitUntilConnected(t, c1, c2) + defer stopClients(c1, c2) + + // c1 publishes + writers := publishTracksForClients(t, c1) + defer stopWriters(writers...) + + // c2 should see some bytes flowing through + testutils.WithTimeout(t, func() string { + if c2.BytesReceived() > 20 { + return "" + } else { + return fmt.Sprintf("c2 only received %d bytes", c2.BytesReceived()) + } + }) + + // now publish on C2 + writers = publishTracksForClients(t, c2) + defer stopWriters(writers...) + + testutils.WithTimeout(t, func() string { + if len(c1.SubscribedTracks()[c2.ID()]) == 2 { + return "" + } else { + return fmt.Sprintf("expected c1 to receive 2 tracks from c2, actual: %d", len(c1.SubscribedTracks()[c2.ID()])) + } + }) + + // now leave, and ensure that it's immediate + c2.Stop() + + testutils.WithTimeout(t, func() string { + if len(c1.RemoteParticipants()) > 0 { + return fmt.Sprintf("expected no remote participants, actual: %v", c1.RemoteParticipants()) + } + return "" + }) + }) + } +} + +func scenarioDataPublish(t *testing.T) { + for _, testRTCServicePath := range testRTCServicePaths { + t.Run(fmt.Sprintf("scenarioDataPublish/testRTCServicePath=%s", testRTCServicePath.String()), func(t *testing.T) { + c1 := createRTCClient("dp1", defaultServerPort, testRTCServicePath, nil) + c2 := createRTCClient("dp2", secondServerPort, testRTCServicePath, nil) + waitUntilConnected(t, c1, c2) + defer stopClients(c1, c2) + + payload := "test bytes" + + received := atomic.NewBool(false) + c2.OnDataReceived = func(data []byte, sid string) { + if string(data) == payload && livekit.ParticipantID(sid) == c1.ID() { + received.Store(true) + } + } + + require.NoError(t, c1.PublishData([]byte(payload), livekit.DataPacket_RELIABLE)) + + testutils.WithTimeout(t, func() string { + if received.Load() { + return "" + } else { + return "c2 did not receive published data" + } + }) + }) + } +} + +func scenarioDataUnlabeledPublish(t *testing.T) { + for _, testRTCServicePath := range testRTCServicePaths { + t.Run(fmt.Sprintf("scenarioDataUnlabeledPublish/testRTCServicePath=%s", testRTCServicePath.String()), func(t *testing.T) { + c1 := createRTCClient("dp1", defaultServerPort, testRTCServicePath, nil) + c2 := createRTCClient("dp2", secondServerPort, testRTCServicePath, nil) + waitUntilConnected(t, c1, c2) + defer stopClients(c1, c2) + + payload := "test unlabeled bytes" + + received := atomic.NewBool(false) + c2.OnDataReceived = func(data []byte, _sid string) { + if string(data) == payload { + received.Store(true) + } + } + + require.NoError(t, c1.PublishDataUnlabeled([]byte(payload))) + + testutils.WithTimeout(t, func() string { + if received.Load() { + return "" + } else { + return "c2 did not receive published data unlabeled" + } + }) + }) + } +} + +func scenarioDataTracksPublishingUponJoining(t *testing.T) { + for _, testRTCServicePath := range testRTCServicePaths { + t.Run(fmt.Sprintf("scenarioDataTracksPublishingUponJoining/testRTCServicePath=%s", testRTCServicePath.String()), func(t *testing.T) { + c1 := createRTCClient("dtpuj_1", defaultServerPort, testRTCServicePath, &testclient.Options{AutoSubscribeDataTrack: true}) + c2 := createRTCClient("dtpuj_2", secondServerPort, testRTCServicePath, &testclient.Options{AutoSubscribeDataTrack: true}) + c3 := createRTCClient("dtpuj_3", defaultServerPort, testRTCServicePath, &testclient.Options{AutoSubscribeDataTrack: true}) + defer stopClients(c1, c2, c3) + + waitUntilConnected(t, c1, c2, c3) + + // c1 and c2 publishing, c3 just receiving + writers := publishDataTracksForClients(t, c1, c2) + defer stopWriters(writers...) + + logger.Infow("waiting to receive data tracks from c1 and c2") + testutils.WithTimeout(t, func() string { + tracks := c3.SubscribedDataTracks() + if len(tracks[c1.ID()]) != 2 { + return "did not receive data tracks from c1" + } + if len(tracks[c2.ID()]) != 2 { + return "did not receive data tracks from c2" + } + for _, dts := range tracks { + for _, dt := range dts { + if dt.NumReceivedPackets() == 0 { + return fmt.Sprintf("no packets received from %s", dt.ID()) + } + } + } + return "" + }) + + // after a delay, c2 reconnects, then publishing + time.Sleep(syncDelay) + c2.Stop() + + logger.Infow("waiting for c2 data tracks to be gone") + testutils.WithTimeout(t, func() string { + tracks := c3.SubscribedDataTracks() + + if len(tracks[c1.ID()]) != 2 { + return fmt.Sprintf("c3 should be subscribed to 2 data tracks from c1, actual: %d", len(tracks[c1.ID()])) + } + if len(tracks[c2.ID()]) != 0 { + return fmt.Sprintf("c3 should be subscribed to 0 data tracks from c2, actual: %d", len(tracks[c2.ID()])) + } + if len(c1.SubscribedDataTracks()[c2.ID()]) != 0 { + return fmt.Sprintf("c3 should be subscribed to 0 data tracks from c2, actual: %d", len(c1.SubscribedTracks()[c2.ID()])) + } + return "" + }) + + logger.Infow("c2 reconnecting") + // connect to a diff port + c2 = createRTCClient("dtpuj_2", defaultServerPort, testRTCServicePath, &testclient.Options{AutoSubscribeDataTrack: true}) + defer c2.Stop() + waitUntilConnected(t, c2) + writers = publishDataTracksForClients(t, c2) + defer stopWriters(writers...) + + testutils.WithTimeout(t, func() string { + tracks := c3.SubscribedDataTracks() + // new c2 data tracks should be published again + if len(tracks[c2.ID()]) != 2 { + return fmt.Sprintf("c3 should be subscribed to 2 data tracks from c2, actual: %d", len(tracks[c2.ID()])) + } + for _, dt := range tracks[c2.ID()] { + if dt.NumReceivedPackets() == 0 { + return fmt.Sprintf("c3 did not receive packets from c2 data track after reconnecting %s", dt.ID()) + } + } + + if len(c1.SubscribedDataTracks()[c2.ID()]) != 2 { + return fmt.Sprintf("c1 should be subscribed to 2 data tracks from c2, actual: %d", len(c1.SubscribedTracks()[c2.ID()])) + } + for _, dt := range c1.SubscribedDataTracks()[c2.ID()] { + if dt.NumReceivedPackets() == 0 { + return fmt.Sprintf("c1 did not receive packets from c2 data track after reconnecting %s", dt.ID()) + } + } + return "" + }) + }) + } +} + +func scenarioJoinClosedRoom(t *testing.T) { + for _, testRTCServicePath := range testRTCServicePaths { + t.Run(fmt.Sprintf("testRTCServicePath=%s", testRTCServicePath.String()), func(t *testing.T) { + c1 := createRTCClient("jcr1", defaultServerPort, testRTCServicePath, nil) + waitUntilConnected(t, c1) + + // close room with room client + _, err := roomClient.DeleteRoom(contextWithToken(createRoomToken()), &livekit.DeleteRoomRequest{ + Room: testRoom, + }) + require.NoError(t, err) + + // now join again + c2 := createRTCClient("jcr2", defaultServerPort, testRTCServicePath, nil) + waitUntilConnected(t, c2) + stopClients(c2) + }) + } +} + +// close a room that has been created, but no participant has joined +func closeNonRTCRoom(t *testing.T) { + createCtx := contextWithToken(createRoomToken()) + _, err := roomClient.CreateRoom(createCtx, &livekit.CreateRoomRequest{ + Name: testRoom, + }) + require.NoError(t, err) + + _, err = roomClient.DeleteRoom(createCtx, &livekit.DeleteRoomRequest{ + Room: testRoom, + }) + require.NoError(t, err) +} + +func publishTracksForClients(t *testing.T, clients ...*testclient.RTCClient) []testclient.TrackWriter { + logger.Infow("publishing tracks for clients") + var writers []testclient.TrackWriter + for i := range clients { + c := clients[i] + tw, err := c.AddStaticTrack("audio/opus", "audio", "webcam") + require.NoError(t, err) + writers = append(writers, tw) + + tw, err = c.AddStaticTrack("video/vp8", "video", "webcam") + require.NoError(t, err) + writers = append(writers, tw) + } + return writers +} + +func publishDataTracksForClients(t *testing.T, clients ...*testclient.RTCClient) []testclient.TrackWriter { + logger.Infow("publishing data tracks for clients") + var writers []testclient.TrackWriter + for i := range clients { + c := clients[i] + for range 2 { + dtw, err := c.PublishDataTrack() + require.NoError(t, err) + writers = append(writers, dtw) + } + } + return writers +} + +// Room service tests + +func roomServiceListRoom(t *testing.T) { + createCtx := contextWithToken(createRoomToken()) + listCtx := contextWithToken(listRoomToken()) + // create rooms + _, err := roomClient.CreateRoom(createCtx, &livekit.CreateRoomRequest{ + Name: testRoom, + }) + require.NoError(t, err) + _, err = roomClient.CreateRoom(contextWithToken(createRoomToken()), &livekit.CreateRoomRequest{ + Name: "yourroom", + }) + require.NoError(t, err) + + t.Run("list all rooms", func(t *testing.T) { + res, err := roomClient.ListRooms(listCtx, &livekit.ListRoomsRequest{}) + require.NoError(t, err) + require.Len(t, res.Rooms, 2) + }) + t.Run("list specific rooms", func(t *testing.T) { + res, err := roomClient.ListRooms(listCtx, &livekit.ListRoomsRequest{ + Names: []string{"yourroom"}, + }) + require.NoError(t, err) + require.Len(t, res.Rooms, 1) + require.Equal(t, "yourroom", res.Rooms[0].Name) + }) +} diff --git a/livekit/test/singlenode_test.go b/livekit/test/singlenode_test.go new file mode 100644 index 0000000..03b9e46 --- /dev/null +++ b/livekit/test/singlenode_test.go @@ -0,0 +1,1102 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +import ( + "context" + "encoding/binary" + "errors" + "fmt" + "net/http" + "reflect" + "strings" + "testing" + "time" + + "github.com/pion/sdp/v3" + "github.com/pion/webrtc/v4" + "github.com/stretchr/testify/require" + "github.com/thoas/go-funk" + "github.com/twitchtv/twirp" + "go.uber.org/atomic" + + "github.com/livekit/protocol/auth" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + + "github.com/livekit/livekit-server/pkg/config" + "github.com/livekit/livekit-server/pkg/rtc" + "github.com/livekit/livekit-server/pkg/sfu/datachannel" + "github.com/livekit/livekit-server/pkg/sfu/mime" + "github.com/livekit/livekit-server/pkg/testutils" + testclient "github.com/livekit/livekit-server/test/client" +) + +const ( + waitTick = 10 * time.Millisecond + waitTimeout = 5 * time.Second +) + +func TestClientCouldConnect(t *testing.T) { + if testing.Short() { + t.SkipNow() + return + } + + _, finish := setupSingleNodeTest("TestClientCouldConnect") + defer finish() + + for _, testRTCServicePath := range testRTCServicePaths { + t.Run(fmt.Sprintf("testRTCServicePath=%s", testRTCServicePath.String()), func(t *testing.T) { + c1 := createRTCClient("c1", defaultServerPort, testRTCServicePath, nil) + c2 := createRTCClient("c2", defaultServerPort, testRTCServicePath, nil) + waitUntilConnected(t, c1, c2) + + // ensure they both see each other + testutils.WithTimeout(t, func() string { + if len(c1.RemoteParticipants()) == 0 { + return "c1 did not see c2" + } + if len(c2.RemoteParticipants()) == 0 { + return "c2 did not see c1" + } + return "" + }) + }) + } +} + +func TestClientConnectDuplicate(t *testing.T) { + if testing.Short() { + t.SkipNow() + return + } + + _, finish := setupSingleNodeTest("TestClientConnectDuplicate") + defer finish() + + for _, testRTCServicePath := range testRTCServicePaths { + t.Run(fmt.Sprintf("testRTCServicePath=%s", testRTCServicePath.String()), func(t *testing.T) { + grant := &auth.VideoGrant{RoomJoin: true, Room: testRoom} + grant.SetCanPublish(true) + grant.SetCanSubscribe(true) + token := joinTokenWithGrant("c1", grant) + c1 := createRTCClientWithToken(token, defaultServerPort, testRTCServicePath, nil) + + // publish 2 tracks + t1, err := c1.AddStaticTrack("audio/opus", "audio", "webcam") + require.NoError(t, err) + defer t1.Stop() + t2, err := c1.AddStaticTrack("video/vp8", "video", "webcam") + require.NoError(t, err) + defer t2.Stop() + + c2 := createRTCClient("c2", defaultServerPort, testRTCServicePath, nil) + waitUntilConnected(t, c1, c2) + + opts := &testclient.Options{ + Publish: "duplicate_connection", + } + testutils.WithTimeout(t, func() string { + if len(c2.SubscribedTracks()) == 0 { + return "c2 didn't subscribe to anything" + } + // should have received two tracks + if len(c2.SubscribedTracks()[c1.ID()]) != 2 { + return "c2 didn't subscribe to both tracks from c1" + } + + // participant ID can be appended with '#..' . but should contain orig id as prefix + tr1 := c2.SubscribedTracks()[c1.ID()][0] + participantId1, _ := rtc.UnpackStreamID(tr1.StreamID()) + require.Equal(t, c1.ID(), participantId1) + tr2 := c2.SubscribedTracks()[c1.ID()][1] + participantId2, _ := rtc.UnpackStreamID(tr2.StreamID()) + require.Equal(t, c1.ID(), participantId2) + return "" + }) + + c1Dup := createRTCClientWithToken(token, defaultServerPort, testRTCServicePath, opts) + + waitUntilConnected(t, c1Dup) + + t3, err := c1Dup.AddStaticTrack("video/vp8", "video", "webcam") + require.NoError(t, err) + defer t3.Stop() + + testutils.WithTimeout(t, func() string { + if len(c2.SubscribedTracks()[c1Dup.ID()]) != 1 { + return "c2 was not subscribed to track from duplicated c1" + } + + tr3 := c2.SubscribedTracks()[c1Dup.ID()][0] + participantId3, _ := rtc.UnpackStreamID(tr3.StreamID()) + require.Contains(t, c1Dup.ID(), participantId3) + + return "" + }) + }) + } +} + +func TestSinglePublisher(t *testing.T) { + if testing.Short() { + t.SkipNow() + return + } + + s, finish := setupSingleNodeTest("TestSinglePublisher") + defer finish() + + for _, testRTCServicePath := range testRTCServicePaths { + t.Run(fmt.Sprintf("testRTCServicePath=%s", testRTCServicePath.String()), func(t *testing.T) { + c1 := createRTCClient("c1", defaultServerPort, testRTCServicePath, nil) + c2 := createRTCClient("c2", defaultServerPort, testRTCServicePath, nil) + waitUntilConnected(t, c1, c2) + + // publish an audio and video track and ensure clients receive it ok + t1, err := c1.AddStaticTrack("audio/OPUS", "audio", "webcamaudio") + require.NoError(t, err) + defer t1.Stop() + t2, err := c1.AddStaticTrack("video/vp8", "video", "webcamvideo") + require.NoError(t, err) + defer t2.Stop() + + testutils.WithTimeout(t, func() string { + if len(c2.SubscribedTracks()) == 0 { + return "c2 was not subscribed to anything" + } + // should have received two tracks + if len(c2.SubscribedTracks()[c1.ID()]) != 2 { + return "c2 didn't subscribe to both tracks from c1" + } + + tr1 := c2.SubscribedTracks()[c1.ID()][0] + participantId, _ := rtc.UnpackStreamID(tr1.StreamID()) + require.Equal(t, c1.ID(), participantId) + return "" + }) + // ensure mime type is received + remoteC1 := c2.GetRemoteParticipant(c1.ID()) + audioTrack := funk.Find(remoteC1.Tracks, func(ti *livekit.TrackInfo) bool { + return ti.Name == "webcamaudio" + }).(*livekit.TrackInfo) + require.Equal(t, "audio/opus", audioTrack.MimeType) + + // a new client joins and should get the initial stream + c3 := createRTCClient("c3", defaultServerPort, testRTCServicePath, nil) + + // ensure that new client that has joined also received tracks + waitUntilConnected(t, c3) + testutils.WithTimeout(t, func() string { + if len(c3.SubscribedTracks()) == 0 { + return "c3 didn't subscribe to anything" + } + // should have received two tracks + if len(c3.SubscribedTracks()[c1.ID()]) != 2 { + return "c3 didn't subscribe to tracks from c1" + } + return "" + }) + + // ensure that the track ids are generated by server + tracks := c3.SubscribedTracks()[c1.ID()] + for _, tr := range tracks { + require.True(t, strings.HasPrefix(tr.ID(), "TR_"), "track should begin with TR") + } + + // when c3 disconnects, ensure subscriber is cleaned up correctly + c3.Stop() + + testutils.WithTimeout(t, func() string { + room := s.RoomManager().GetRoom(context.Background(), testRoom) + p := room.GetParticipant("c1") + require.NotNil(t, p) + + for _, t := range p.GetPublishedTracks() { + if t.IsSubscriber(c3.ID()) { + return "c3 was not a subscriber of c1's tracks" + } + } + return "" + }) + }) + } +} + +func Test_WhenAutoSubscriptionDisabled_ClientShouldNotReceiveAnyPublishedTracks(t *testing.T) { + if testing.Short() { + t.SkipNow() + return + } + + _, finish := setupSingleNodeTest("Test_WhenAutoSubscriptionDisabled_ClientShouldNotReceiveAnyPublishedTracks") + defer finish() + + for _, testRTCServicePath := range testRTCServicePaths { + t.Run(fmt.Sprintf("testRTCServicePath=%s", testRTCServicePath.String()), func(t *testing.T) { + opts := testclient.Options{AutoSubscribe: false} + publisher := createRTCClient("publisher", defaultServerPort, testRTCServicePath, &opts) + client := createRTCClient("client", defaultServerPort, testRTCServicePath, &opts) + defer publisher.Stop() + defer client.Stop() + waitUntilConnected(t, publisher, client) + + track, err := publisher.AddStaticTrack("audio/opus", "audio", "webcam") + require.NoError(t, err) + defer track.Stop() + + time.Sleep(syncDelay) + + require.Empty(t, client.SubscribedTracks()[publisher.ID()]) + }) + } +} + +func Test_RenegotiationWithDifferentCodecs(t *testing.T) { + if testing.Short() { + t.SkipNow() + return + } + + _, finish := setupSingleNodeTest("TestRenegotiationWithDifferentCodecs") + defer finish() + + for _, testRTCServicePath := range testRTCServicePaths { + t.Run(fmt.Sprintf("testRTCServicePath=%s", testRTCServicePath.String()), func(t *testing.T) { + c1 := createRTCClient("c1", defaultServerPort, testRTCServicePath, nil) + c2 := createRTCClient("c2", defaultServerPort, testRTCServicePath, nil) + waitUntilConnected(t, c1, c2) + + // publish a vp8 video track and ensure clients receive it ok + t1, err := c1.AddStaticTrack("audio/opus", "audio", "webcam") + require.NoError(t, err) + defer t1.Stop() + t2, err := c1.AddStaticTrack("video/vp8", "video", "webcam") + require.NoError(t, err) + defer t2.Stop() + + testutils.WithTimeout(t, func() string { + if len(c2.SubscribedTracks()) == 0 { + return "c2 was not subscribed to anything" + } + // should have received two tracks + if len(c2.SubscribedTracks()[c1.ID()]) != 2 { + return "c2 was not subscribed to tracks from c1" + } + + tracks := c2.SubscribedTracks()[c1.ID()] + for _, t := range tracks { + if mime.IsMimeTypeStringVP8(t.Codec().MimeType) { + return "" + + } + } + return "did not receive track with vp8" + }) + + t3, err := c1.AddStaticTrackWithCodec(webrtc.RTPCodecCapability{ + MimeType: "video/h264", + ClockRate: 90000, + SDPFmtpLine: "level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=42e01f", + }, "videoscreen", "screen") + defer t3.Stop() + require.NoError(t, err) + + testutils.WithTimeout(t, func() string { + if len(c2.SubscribedTracks()) == 0 { + return "c2's not subscribed to anything" + } + // should have received three tracks + if len(c2.SubscribedTracks()[c1.ID()]) != 3 { + return "c2's not subscribed to 3 tracks from c1" + } + + var vp8Found, h264Found bool + tracks := c2.SubscribedTracks()[c1.ID()] + for _, t := range tracks { + if mime.IsMimeTypeStringVP8(t.Codec().MimeType) { + vp8Found = true + } else if mime.IsMimeTypeStringH264(t.Codec().MimeType) { + h264Found = true + } + } + if !vp8Found { + return "did not receive track with vp8" + } + if !h264Found { + return "did not receive track with h264" + } + return "" + }) + }) + } +} + +func TestSingleNodeRoomList(t *testing.T) { + if testing.Short() { + t.SkipNow() + return + } + _, finish := setupSingleNodeTest("TestSingleNodeRoomList") + defer finish() + + roomServiceListRoom(t) +} + +func TestSingleNodeUpdateParticipant(t *testing.T) { + if testing.Short() { + t.SkipNow() + return + } + _, finish := setupSingleNodeTest("TestSingleNodeRoomList") + defer finish() + + adminCtx := contextWithToken(adminRoomToken(testRoom)) + t.Run("update nonexistent participant", func(t *testing.T) { + _, err := roomClient.UpdateParticipant(adminCtx, &livekit.UpdateParticipantRequest{ + Room: testRoom, + Identity: "nonexistent", + Permission: &livekit.ParticipantPermission{ + CanPublish: true, + }, + }) + require.Error(t, err) + var twErr twirp.Error + require.True(t, errors.As(err, &twErr)) + require.Equal(t, twirp.NotFound, twErr.Code()) + }) +} + +// Ensure that CORS headers are returned +func TestSingleNodeCORS(t *testing.T) { + if testing.Short() { + t.SkipNow() + return + } + s, finish := setupSingleNodeTest("TestSingleNodeCORS") + defer finish() + + req, err := http.NewRequest("POST", fmt.Sprintf("http://localhost:%d", s.HTTPPort()), nil) + require.NoError(t, err) + req.Header.Set("Authorization", "bearer xyz") + req.Header.Set("Origin", "testhost.com") + res, err := http.DefaultClient.Do(req) + require.NoError(t, err) + require.Equal(t, "testhost.com", res.Header.Get("Access-Control-Allow-Origin")) +} + +func TestSingleNodeDoubleSlash(t *testing.T) { + if testing.Short() { + t.SkipNow() + return + } + s, finish := setupSingleNodeTest("TestSingleNodeDoubleSlash") + defer finish() + // client contains trailing slash in URL, causing path to contain double // + // without our middleware, this would cause a 302 redirect + roomClient = livekit.NewRoomServiceJSONClient(fmt.Sprintf("http://localhost:%d/", s.HTTPPort()), &http.Client{}) + _, err := roomClient.ListRooms(contextWithToken(listRoomToken()), &livekit.ListRoomsRequest{}) + require.NoError(t, err) +} + +func TestPingPong(t *testing.T) { + if testing.Short() { + t.SkipNow() + return + } + _, finish := setupSingleNodeTest("TestPingPong") + defer finish() + + for _, testRTCServicePath := range testRTCServicePaths { + t.Run(fmt.Sprintf("testRTCServicePath=%s", testRTCServicePath.String()), func(t *testing.T) { + c1 := createRTCClient("c1", defaultServerPort, testRTCServicePath, nil) + waitUntilConnected(t, c1) + + require.NoError(t, c1.SendPing()) + require.Eventually(t, func() bool { + return c1.PongReceivedAt() > 0 + }, time.Second, 10*time.Millisecond) + }) + } +} + +func TestSingleNodeJoinAfterClose(t *testing.T) { + if testing.Short() { + t.SkipNow() + return + } + + _, finish := setupSingleNodeTest("TestJoinAfterClose") + defer finish() + + scenarioJoinClosedRoom(t) +} + +func TestSingleNodeCloseNonRTCRoom(t *testing.T) { + if testing.Short() { + t.SkipNow() + return + } + + _, finish := setupSingleNodeTest("closeNonRTCRoom") + defer finish() + + closeNonRTCRoom(t) +} + +func TestAutoCreate(t *testing.T) { + if testing.Short() { + t.SkipNow() + return + } + disableAutoCreate := func(conf *config.Config) { + conf.Room.AutoCreate = false + } + t.Run("cannot join if room isn't created", func(t *testing.T) { + s := createSingleNodeServer(disableAutoCreate) + go func() { + if err := s.Start(); err != nil { + logger.Errorw("server returned error", err) + } + }() + defer s.Stop(true) + + waitForServerToStart(s) + + for _, testRTCServicePath := range testRTCServicePaths { + t.Run(fmt.Sprintf("testRTCServicePath=%s", testRTCServicePath.String()), func(t *testing.T) { + token := joinToken(testRoom, "start-before-create", nil) + opts := &testclient.Options{} + testRTCServicePathToTestClientOptions(testRTCServicePath, opts) + _, err := testclient.NewWebSocketConn( + fmt.Sprintf("ws://localhost:%d", defaultServerPort), + token, + opts, + ) + require.Error(t, err) + + // second join should also fail + token = joinToken(testRoom, "start-before-create-2", nil) + _, err = testclient.NewWebSocketConn( + fmt.Sprintf("ws://localhost:%d", defaultServerPort), + token, + opts, + ) + require.Error(t, err) + }) + } + }) + + t.Run("join with explicit createRoom", func(t *testing.T) { + s := createSingleNodeServer(disableAutoCreate) + go func() { + if err := s.Start(); err != nil { + logger.Errorw("server returned error", err) + } + }() + defer s.Stop(true) + + waitForServerToStart(s) + + // explicitly create + _, err := roomClient.CreateRoom(contextWithToken(createRoomToken()), &livekit.CreateRoomRequest{Name: testRoom}) + require.NoError(t, err) + + for _, testRTCServicePath := range testRTCServicePaths { + t.Run(fmt.Sprintf("testRTCServicePath=%s", testRTCServicePath.String()), func(t *testing.T) { + c1 := createRTCClient("join-after-create", defaultServerPort, testRTCServicePath, nil) + waitUntilConnected(t, c1) + + c1.Stop() + }) + } + }) +} + +// don't give user subscribe permissions initially, and ensure autosubscribe is triggered afterwards +func TestSingleNodeUpdateSubscriptionPermissions(t *testing.T) { + if testing.Short() { + t.SkipNow() + return + } + _, finish := setupSingleNodeTest("TestSingleNodeUpdateSubscriptionPermissions") + defer finish() + + for _, testRTCServicePath := range testRTCServicePaths { + t.Run(fmt.Sprintf("testRTCServicePath=%s", testRTCServicePath.String()), func(t *testing.T) { + pub := createRTCClient("pub", defaultServerPort, testRTCServicePath, nil) + + grant := &auth.VideoGrant{RoomJoin: true, Room: testRoom} + grant.SetCanSubscribe(false) + at := auth.NewAccessToken(testApiKey, testApiSecret). + AddGrant(grant). + SetIdentity("sub") + token, err := at.ToJWT() + require.NoError(t, err) + sub := createRTCClientWithToken(token, defaultServerPort, testRTCServicePath, nil) + + waitUntilConnected(t, pub, sub) + + writers := publishTracksForClients(t, pub) + defer stopWriters(writers...) + + // wait sub receives tracks + testutils.WithTimeout(t, func() string { + pubRemote := sub.GetRemoteParticipant(pub.ID()) + if pubRemote == nil { + return "could not find remote publisher" + } + if len(pubRemote.Tracks) != 2 { + return "did not receive metadata for published tracks" + } + return "" + }) + + // set permissions out of band + ctx := contextWithToken(adminRoomToken(testRoom)) + _, err = roomClient.UpdateParticipant(ctx, &livekit.UpdateParticipantRequest{ + Room: testRoom, + Identity: "sub", + Permission: &livekit.ParticipantPermission{ + CanSubscribe: true, + CanPublish: true, + }, + }) + require.NoError(t, err) + + testutils.WithTimeout(t, func() string { + tracks := sub.SubscribedTracks()[pub.ID()] + if len(tracks) == 2 { + return "" + } else { + return fmt.Sprintf("expected 2 tracks subscribed, actual: %d", len(tracks)) + } + }) + }) + } +} + +func TestSingleNodeAttributes(t *testing.T) { + if testing.Short() { + t.SkipNow() + return + } + _, finish := setupSingleNodeTest("TestSingleNodeAttributes") + defer finish() + + for _, testRTCServicePath := range testRTCServicePaths { + t.Run(fmt.Sprintf("testRTCServicePath=%s", testRTCServicePath.String()), func(t *testing.T) { + pub := createRTCClient("pub", defaultServerPort, testRTCServicePath, &testclient.Options{ + Attributes: map[string]string{ + "b": "2", + "c": "3", + }, + TokenCustomizer: func(token *auth.AccessToken, grants *auth.VideoGrant) { + T := true + grants.CanUpdateOwnMetadata = &T + token.SetAttributes(map[string]string{ + "a": "0", + "b": "1", + }) + }, + }) + + grant := &auth.VideoGrant{RoomJoin: true, Room: testRoom} + grant.SetCanSubscribe(false) + at := auth.NewAccessToken(testApiKey, testApiSecret). + SetVideoGrant(grant). + SetIdentity("sub") + token, err := at.ToJWT() + require.NoError(t, err) + sub := createRTCClientWithToken(token, defaultServerPort, testRTCServicePath, nil) + + waitUntilConnected(t, pub, sub) + + // wait sub receives initial attributes + testutils.WithTimeout(t, func() string { + pubRemote := sub.GetRemoteParticipant(pub.ID()) + if pubRemote == nil { + return "could not find remote publisher" + } + attrs := pubRemote.Attributes + if !reflect.DeepEqual(attrs, map[string]string{ + "a": "0", + "b": "2", + "c": "3", + }) { + return fmt.Sprintf("did not receive expected attributes: %v", attrs) + } + return "" + }) + }) + } +} + +// TestDeviceCodecOverride checks that codecs that are incompatible with a device is not +// negotiated by the server +func TestDeviceCodecOverride(t *testing.T) { + if testing.Short() { + t.SkipNow() + return + } + + _, finish := setupSingleNodeTest("TestDeviceCodecOverride") + defer finish() + + for _, testRTCServicePath := range testRTCServicePaths { + t.Run(fmt.Sprintf("testRTCServicePath=%s", testRTCServicePath.String()), func(t *testing.T) { + // simulate device that isn't compatible with H.264 + c1 := createRTCClient("c1", defaultServerPort, testRTCServicePath, &testclient.Options{ + ClientInfo: &livekit.ClientInfo{ + Os: "android", + DeviceModel: "Xiaomi 2201117TI", + }, + }) + defer c1.Stop() + waitUntilConnected(t, c1) + + // it doesn't really matter what the codec set here is, uses default Pion MediaEngine codecs + tw, err := c1.AddStaticTrack("video/h264", "video", "webcam") + require.NoError(t, err) + defer stopWriters(tw) + + var desc *sdp.MediaDescription + require.Eventually(t, func() bool { + lastAnswer := c1.LastAnswer() + if lastAnswer == nil { + return false + } + + sd := webrtc.SessionDescription{ + Type: webrtc.SDPTypeAnswer, + SDP: lastAnswer.SDP, + } + answer, err := sd.Unmarshal() + require.NoError(t, err) + + // video and data channel + if len(answer.MediaDescriptions) < 2 { + return false + } + + for _, md := range answer.MediaDescriptions { + if md.MediaName.Media == "video" { + desc = md + break + } + } + return desc != nil + }, waitTimeout, waitTick, "did not receive answer") + + hasSeenVP8 := false + for _, a := range desc.Attributes { + if a.Key == "rtpmap" { + require.NotContains(t, a.Value, mime.MimeTypeCodecH264.String(), "should not contain H264 codec") + if strings.Contains(a.Value, mime.MimeTypeCodecVP8.String()) { + hasSeenVP8 = true + } + } + } + require.True(t, hasSeenVP8, "should have seen VP8 codec in SDP") + }) + } +} + +func TestSubscribeToCodecUnsupported(t *testing.T) { + if testing.Short() { + t.SkipNow() + return + } + + _, finish := setupSingleNodeTest("TestSubscribeToCodecUnsupported") + defer finish() + + for _, testRTCServicePath := range testRTCServicePaths { + t.Run(fmt.Sprintf("testRTCServicePath=%s", testRTCServicePath.String()), func(t *testing.T) { + c1 := createRTCClient("c1", defaultServerPort, testRTCServicePath, nil) + // create a client that doesn't support H264 + c2 := createRTCClient("c2", defaultServerPort, testRTCServicePath, &testclient.Options{ + AutoSubscribe: true, + DisabledCodecs: []webrtc.RTPCodecCapability{ + {MimeType: "video/H264"}, + }, + }) + waitUntilConnected(t, c1, c2) + + // publish a vp8 video track and ensure c2 receives it ok + t1, err := c1.AddStaticTrack("audio/opus", "audio", "webcam") + require.NoError(t, err) + defer t1.Stop() + t2, err := c1.AddStaticTrack("video/vp8", "video", "webcam") + require.NoError(t, err) + defer t2.Stop() + + testutils.WithTimeout(t, func() string { + if len(c2.SubscribedTracks()) == 0 { + return "c2 was not subscribed to anything" + } + // should have received two tracks + if len(c2.SubscribedTracks()[c1.ID()]) != 2 { + return "c2 was not subscribed to tracks from c1" + } + + tracks := c2.SubscribedTracks()[c1.ID()] + for _, t := range tracks { + if mime.IsMimeTypeStringVP8(t.Codec().MimeType) { + return "" + } + } + return "did not receive track with vp8" + }) + require.Nil(t, c2.GetSubscriptionResponseAndClear()) + + // publish a h264 track and ensure c2 got subscription error + t3, err := c1.AddStaticTrackWithCodec(webrtc.RTPCodecCapability{ + MimeType: "video/h264", + ClockRate: 90000, + SDPFmtpLine: "level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=42e01f", + }, "videoscreen", "screen") + defer t3.Stop() + require.NoError(t, err) + + var h264TrackID string + require.Eventually(t, func() bool { + remoteC1 := c2.GetRemoteParticipant(c1.ID()) + require.NotNil(t, remoteC1) + for _, track := range remoteC1.Tracks { + if mime.IsMimeTypeStringH264(track.MimeType) { + h264TrackID = track.Sid + return true + } + } + return false + }, time.Second, 10*time.Millisecond, "did not receive track info with h264") + + require.Eventually(t, func() bool { + sr := c2.GetSubscriptionResponseAndClear() + if sr == nil { + return false + } + require.Equal(t, h264TrackID, sr.TrackSid) + require.Equal(t, livekit.SubscriptionError_SE_CODEC_UNSUPPORTED, sr.Err) + return true + }, 5*time.Second, 10*time.Millisecond, "did not receive subscription response") + + // publish another vp8 track again, ensure the transport recovered by sfu and c2 can receive it + t4, err := c1.AddStaticTrack("video/vp8", "video2", "webcam2") + require.NoError(t, err) + defer t4.Stop() + + testutils.WithTimeout(t, func() string { + if len(c2.SubscribedTracks()) == 0 { + return "c2 was not subscribed to anything" + } + // should have received two tracks + if len(c2.SubscribedTracks()[c1.ID()]) != 3 { + return "c2 was not subscribed to tracks from c1" + } + + var vp8Count int + tracks := c2.SubscribedTracks()[c1.ID()] + for _, t := range tracks { + if mime.IsMimeTypeStringVP8(t.Codec().MimeType) { + vp8Count++ + } + } + if vp8Count == 2 { + return "" + } + return "did not 2 receive track with vp8" + }) + require.Nil(t, c2.GetSubscriptionResponseAndClear()) + }) + } +} + +func TestDataPublishSlowSubscriber(t *testing.T) { + if testing.Short() { + t.SkipNow() + return + } + + dataChannelSlowThreshold := 21024 + + logger.Infow("----------------STARTING TEST----------------", "test", t.Name()) + s := createSingleNodeServer(func(c *config.Config) { + c.RTC.DatachannelSlowThreshold = dataChannelSlowThreshold + }) + go func() { + if err := s.Start(); err != nil { + logger.Errorw("server returned error", err) + } + }() + + waitForServerToStart(s) + + defer func() { + s.Stop(true) + logger.Infow("----------------FINISHING TEST----------------", "test", t.Name()) + }() + + for _, testRTCServicePath := range testRTCServicePaths { + t.Run(fmt.Sprintf("testRTCServicePath=%s", testRTCServicePath.String()), func(t *testing.T) { + pub := createRTCClient("pub", defaultServerPort, testRTCServicePath, nil) + fastSub := createRTCClient("fastSub", defaultServerPort, testRTCServicePath, nil) + slowSubNotDrop := createRTCClient("slowSubNotDrop", defaultServerPort, testRTCServicePath, nil) + slowSubDrop := createRTCClient("slowSubDrop", defaultServerPort, testRTCServicePath, nil) + waitUntilConnected(t, pub, fastSub, slowSubDrop, slowSubNotDrop) + defer func() { + pub.Stop() + fastSub.Stop() + slowSubNotDrop.Stop() + slowSubDrop.Stop() + }() + + // no data should be dropped for fast subscriber + var fastDataIndex atomic.Uint64 + fastSub.OnDataReceived = func(data []byte, sid string) { + idx := binary.BigEndian.Uint64(data[len(data)-8:]) + require.Equal(t, fastDataIndex.Load()+1, idx) + fastDataIndex.Store(idx) + } + + // no data should be dropped for slow subscriber that is above threshold + var slowNoDropDataIndex atomic.Uint64 + var drainSlowSubNotDrop atomic.Bool + slowNoDropReader := testclient.NewDataChannelReader(dataChannelSlowThreshold * 2) + slowSubNotDrop.OnDataReceived = func(data []byte, sid string) { + idx := binary.BigEndian.Uint64(data[len(data)-8:]) + require.Equal(t, slowNoDropDataIndex.Load()+1, idx) + slowNoDropDataIndex.Store(idx) + if !drainSlowSubNotDrop.Load() { + slowNoDropReader.Read(data, sid) + } + } + + // data should be dropped for slow subscriber that is below threshold + var slowDropDataIndex atomic.Uint64 + dropped := make(chan struct{}) + slowDropReader := testclient.NewDataChannelReader(dataChannelSlowThreshold / 2) + slowSubDrop.OnDataReceived = func(data []byte, sid string) { + select { + case <-dropped: + return + default: + } + idx := binary.BigEndian.Uint64(data[len(data)-8:]) + if idx != slowDropDataIndex.Load()+1 { + close(dropped) + } + slowDropDataIndex.Store(idx) + slowDropReader.Read(data, sid) + } + + // publisher sends data as fast as possible, it will block by the slowest subscriber above the slow threshold + var ( + blocked atomic.Bool + stopWrite atomic.Bool + writeIdx atomic.Uint64 + ) + writeStopped := make(chan struct{}) + go func() { + defer close(writeStopped) + var i int + buf := make([]byte, 100) + for !stopWrite.Load() { + i++ + binary.BigEndian.PutUint64(buf[len(buf)-8:], uint64(i)) + if err := pub.PublishData(buf, livekit.DataPacket_RELIABLE); err != nil { + if errors.Is(err, datachannel.ErrDataDroppedBySlowReader) { + blocked.Store(true) + i-- + continue + } else { + t.Log("error writing", err) + break + } + } + writeIdx.Store(uint64(i)) + } + }() + + <-dropped + + time.Sleep(time.Second) + blocked.Store(false) + require.Eventually(t, func() bool { return blocked.Load() }, 30*time.Second, 100*time.Millisecond) + stopWrite.Store(true) + <-writeStopped + drainSlowSubNotDrop.Store(true) + require.Eventually(t, func() bool { + return writeIdx.Load() == fastDataIndex.Load() && + writeIdx.Load() == slowNoDropDataIndex.Load() + }, 10*time.Second, 50*time.Millisecond, "writeIdx %d, fast %d, slowNoDrop %d", writeIdx.Load(), fastDataIndex.Load(), slowNoDropDataIndex.Load()) + }) + } +} + +func TestFireTrackBySdp(t *testing.T) { + _, finish := setupSingleNodeTest("TestFireTrackBySdp") + defer finish() + + var cases = []struct { + name string + codecs []webrtc.RTPCodecCapability + pubSDK livekit.ClientInfo_SDK + }{ + { + name: "js client could pub a/v tracks", + codecs: []webrtc.RTPCodecCapability{ + {MimeType: mime.MimeTypeH264.String()}, + {MimeType: mime.MimeTypeOpus.String()}, + }, + pubSDK: livekit.ClientInfo_JS, + }, + { + name: "go client could pub audio tracks", + codecs: []webrtc.RTPCodecCapability{ + {MimeType: "audio/opus"}, + }, + pubSDK: livekit.ClientInfo_GO, + }, + } + + for _, c := range cases { + codecs, sdk := c.codecs, c.pubSDK + t.Run(c.name, func(t *testing.T) { + for _, testRTCServicePath := range testRTCServicePaths { + t.Run(fmt.Sprintf("testRTCServicePath=%s", testRTCServicePath.String()), func(t *testing.T) { + c1 := createRTCClient(c.name+"_c1", defaultServerPort, testRTCServicePath, &testclient.Options{ + ClientInfo: &livekit.ClientInfo{ + Sdk: sdk, + }, + }) + c2 := createRTCClient(c.name+"_c2", defaultServerPort, testRTCServicePath, &testclient.Options{ + AutoSubscribe: true, + ClientInfo: &livekit.ClientInfo{ + Sdk: livekit.ClientInfo_JS, + }, + }) + waitUntilConnected(t, c1, c2) + defer func() { + c1.Stop() + c2.Stop() + }() + + // publish tracks and don't write any packets + for _, codec := range codecs { + _, err := c1.AddStaticTrackWithCodec(codec, codec.MimeType, codec.MimeType, testclient.AddTrackNoWriter()) + require.NoError(t, err) + } + + require.Eventually(t, func() bool { + return len(c2.SubscribedTracks()[c1.ID()]) == len(codecs) + }, 5*time.Second, 10*time.Millisecond) + + var found int + for _, pubTrack := range c1.GetPublishedTrackIDs() { + t.Log("pub track", pubTrack) + tracks := c2.SubscribedTracks()[c1.ID()] + for _, track := range tracks { + t.Log("sub track", track.ID(), track.Codec()) + if track.Codec().PayloadType == 0 && track.ID() == pubTrack { + found++ + break + } + } + } + require.Equal(t, len(codecs), found) + }) + } + }) + } +} + +func TestSinglePublisherDataTrack(t *testing.T) { + if testing.Short() { + t.SkipNow() + return + } + + s, finish := setupSingleNodeTest("TestSinglePublisherDataTrack") + defer finish() + + for _, testRTCServicePath := range testRTCServicePaths { + t.Run(fmt.Sprintf("testRTCServicePath=%s", testRTCServicePath.String()), func(t *testing.T) { + c1 := createRTCClient("c1", defaultServerPort, testRTCServicePath, &testclient.Options{AutoSubscribeDataTrack: true}) + c2 := createRTCClient("c2", defaultServerPort, testRTCServicePath, &testclient.Options{AutoSubscribeDataTrack: true}) + waitUntilConnected(t, c1, c2) + + // publish a couple of data tracks and ensure clients receive it ok + dt1, err := c1.PublishDataTrack() + require.NoError(t, err) + defer dt1.Stop() + + dt2, err := c1.PublishDataTrack() + require.NoError(t, err) + defer dt2.Stop() + + testutils.WithTimeout(t, func() string { + if len(c2.SubscribedDataTracks()) == 0 { + return "c2 was not subscribed to any data tracks" + } + // should have received two data tracks + if len(c2.SubscribedDataTracks()[c1.ID()]) != 2 { + return "c2 didn't subscribe to both data tracks from c1" + } + return "" + }) + + // a new client joins and should get the initial stream + c3 := createRTCClient("c3", defaultServerPort, testRTCServicePath, &testclient.Options{AutoSubscribeDataTrack: true}) + + // ensure that new client that has joined also received data tracks + waitUntilConnected(t, c3) + testutils.WithTimeout(t, func() string { + if len(c3.SubscribedDataTracks()) == 0 { + return "c3 didn't subscribe to any data tracks" + } + // should have received two data tracks + if len(c3.SubscribedDataTracks()[c1.ID()]) != 2 { + return "c3 didn't subscribe to tracks from c1" + } + return "" + }) + + // ensure that the data track ids are generated by server + tracks := c3.SubscribedDataTracks()[c1.ID()] + for _, tr := range tracks { + require.True(t, strings.HasPrefix(string(tr.ID()), "DTR_"), "data track should begin with DTR") + } + + // when c3 disconnects, ensure subscriber is cleaned up correctly + c3.Stop() + + testutils.WithTimeout(t, func() string { + room := s.RoomManager().GetRoom(context.Background(), testRoom) + p := room.GetParticipant("c1") + require.NotNil(t, p) + + for _, t := range p.GetPublishedDataTracks() { + if t.IsSubscriber(c3.ID()) { + return "c3 was not a subscriber of c1's data tracks" + } + } + return "" + }) + }) + } +} diff --git a/livekit/test/webhook_test.go b/livekit/test/webhook_test.go new file mode 100644 index 0000000..8ea62ee --- /dev/null +++ b/livekit/test/webhook_test.go @@ -0,0 +1,241 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +import ( + "context" + "errors" + "fmt" + "net" + "net/http" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/encoding/protojson" + + "github.com/livekit/protocol/auth" + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/utils/guid" + "github.com/livekit/protocol/webhook" + + "github.com/livekit/livekit-server/pkg/config" + "github.com/livekit/livekit-server/pkg/routing" + "github.com/livekit/livekit-server/pkg/rtc/types" + "github.com/livekit/livekit-server/pkg/service" + "github.com/livekit/livekit-server/pkg/testutils" +) + +func TestWebhooks(t *testing.T) { + server, ts, finish, err := setupServerWithWebhook() + require.NoError(t, err) + defer finish() + + for _, testRTCServicePath := range testRTCServicePaths { + t.Run(fmt.Sprintf("testRTCServicePath=%s", testRTCServicePath.String()), func(t *testing.T) { + c1 := createRTCClient("c1", defaultServerPort, testRTCServicePath, nil) + waitUntilConnected(t, c1) + testutils.WithTimeout(t, func() string { + if ts.GetEvent(webhook.EventRoomStarted) == nil { + return "did not receive RoomStarted" + } + if ts.GetEvent(webhook.EventParticipantJoined) == nil { + return "did not receive ParticipantJoined" + } + return "" + }) + + // first participant join should have started the room + started := ts.GetEvent(webhook.EventRoomStarted) + require.Equal(t, testRoom, started.Room.Name) + require.NotEmpty(t, started.Id) + require.Greater(t, started.CreatedAt, time.Now().Unix()-100) + require.GreaterOrEqual(t, time.Now().Unix(), started.CreatedAt) + joined := ts.GetEvent(webhook.EventParticipantJoined) + require.Equal(t, "c1", joined.Participant.Identity) + ts.ClearEvents() + + // another participant joins + c2 := createRTCClient("c2", defaultServerPort, testRTCServicePath, nil) + waitUntilConnected(t, c2) + defer c2.Stop() + testutils.WithTimeout(t, func() string { + if ts.GetEvent(webhook.EventParticipantJoined) == nil { + return "did not receive ParticipantJoined" + } + return "" + }) + joined = ts.GetEvent(webhook.EventParticipantJoined) + require.Equal(t, "c2", joined.Participant.Identity) + ts.ClearEvents() + + // track published + writers := publishTracksForClients(t, c1) + defer stopWriters(writers...) + testutils.WithTimeout(t, func() string { + ev := ts.GetEvent(webhook.EventTrackPublished) + if ev == nil { + return "did not receive TrackPublished" + } + require.NotNil(t, ev.Track, "TrackPublished did not include trackInfo") + require.Equal(t, string(c1.ID()), ev.Participant.Sid) + return "" + }) + ts.ClearEvents() + + // first participant leaves + c1.Stop() + testutils.WithTimeout(t, func() string { + if ts.GetEvent(webhook.EventParticipantLeft) == nil { + return "did not receive ParticipantLeft" + } + return "" + }) + left := ts.GetEvent(webhook.EventParticipantLeft) + require.Equal(t, "c1", left.Participant.Identity) + ts.ClearEvents() + + // room closed + rm := server.RoomManager().GetRoom(context.Background(), testRoom) + rm.Close(types.ParticipantCloseReasonNone) + testutils.WithTimeout(t, func() string { + if ts.GetEvent(webhook.EventRoomFinished) == nil { + return "did not receive RoomFinished" + } + return "" + }) + require.Equal(t, testRoom, ts.GetEvent(webhook.EventRoomFinished).Room.Name) + }) + } +} + +func setupServerWithWebhook() (server *service.LivekitServer, testServer *webhookTestServer, finishFunc func(), err error) { + conf, err := config.NewConfig("", true, nil, nil) + if err != nil { + panic(fmt.Sprintf("could not create config: %v", err)) + } + conf.WebHook.URLs = []string{"http://localhost:7890"} + conf.WebHook.APIKey = testApiKey + conf.Keys = map[string]string{testApiKey: testApiSecret} + + testServer = newTestServer(":7890") + if err = testServer.Start(); err != nil { + return + } + + currentNode, err := routing.NewLocalNode(conf) + if err != nil { + return + } + currentNode.SetNodeID(livekit.NodeID(guid.New(nodeID1))) + + server, err = service.InitializeServer(conf, currentNode) + if err != nil { + return + } + + go func() { + if err := server.Start(); err != nil { + logger.Errorw("server returned error", err) + } + }() + + waitForServerToStart(server) + + finishFunc = func() { + server.Stop(true) + testServer.Stop() + } + return +} + +type webhookTestServer struct { + server *http.Server + events map[string]*livekit.WebhookEvent + lock sync.Mutex + provider auth.KeyProvider +} + +func newTestServer(addr string) *webhookTestServer { + s := &webhookTestServer{ + events: make(map[string]*livekit.WebhookEvent), + provider: auth.NewFileBasedKeyProviderFromMap(map[string]string{testApiKey: testApiSecret}), + } + s.server = &http.Server{ + Addr: addr, + Handler: s, + } + return s +} + +func (s *webhookTestServer) ServeHTTP(_ http.ResponseWriter, r *http.Request) { + data, err := webhook.Receive(r, s.provider) + if err != nil { + logger.Errorw("could not receive webhook", err) + return + } + + event := livekit.WebhookEvent{} + if err = protojson.Unmarshal(data, &event); err != nil { + logger.Errorw("could not unmarshal event", err) + return + } + + s.lock.Lock() + s.events[event.Event] = &event + s.lock.Unlock() +} + +func (s *webhookTestServer) GetEvent(name string) *livekit.WebhookEvent { + s.lock.Lock() + defer s.lock.Unlock() + return s.events[name] +} + +func (s *webhookTestServer) ClearEvents() { + s.lock.Lock() + s.events = make(map[string]*livekit.WebhookEvent) + s.lock.Unlock() +} + +func (s *webhookTestServer) Start() error { + l, err := net.Listen("tcp", s.server.Addr) + if err != nil { + return err + } + go s.server.Serve(l) + + // wait for webhook server to start + ctx, cancel := context.WithTimeout(context.Background(), testutils.ConnectTimeout) + defer cancel() + for { + select { + case <-ctx.Done(): + return errors.New("could not start webhook server after timeout") + case <-time.After(10 * time.Millisecond): + // ensure we can connect to it + res, err := http.Get(fmt.Sprintf("http://localhost%s", s.server.Addr)) + if err == nil && res.StatusCode == http.StatusOK { + return nil + } + } + } +} + +func (s *webhookTestServer) Stop() { + _ = s.server.Shutdown(context.Background()) +} diff --git a/livekit/tools/tools.go b/livekit/tools/tools.go new file mode 100644 index 0000000..f77dcd8 --- /dev/null +++ b/livekit/tools/tools.go @@ -0,0 +1,23 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build tools +// +build tools + +package tools + +import ( + _ "github.com/google/wire/cmd/wire" + _ "github.com/maxbrunsfeld/counterfeiter/v6" +) diff --git a/livekit/version/version.go b/livekit/version/version.go new file mode 100644 index 0000000..e9829ea --- /dev/null +++ b/livekit/version/version.go @@ -0,0 +1,17 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package version + +const Version = "1.9.11"